From 2021f6caeb7dee363d1176d99388ca1c71f56bc9 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sun, 10 Aug 2025 18:10:14 +0000 Subject: [PATCH 01/24] datadeps: Fix views and implement remainder copies This commit fixes two major issues in Datadeps that caused incorrect results with views and ChunkView. First, when presented with arguments that alias (such as an `Array` and a view of that array), `generate_slot!` would separately `move` these values onto the destination processor (without considering how they alias with each other), which could break the aliasing that they previously had on their originating processor. This meant that certain algorithms which used views together with the underlying arrays as arguments would get incorrect results in a distributed setting, because `generate_slot!` would break aliasing and cause data to not be updated correctly during copies. This commit adds helpers (specifically `aliased_object!`) which allow objects like views and `ChunkView` to declare the underlying parent array as an object that may need to be tracked separately from the surrounding structure; `aliased_object!` keeps track of other such declared objects that have been allocated on the destination processor, and replaces the source object with the destination object during `move`. By default, all arguments are now provided directly to `aliased_object!` to perform this replacement, but this can be customized by overloading `move_rewrap` (which `SubArray` and `ChunkView` now overload). Secondly, even with objects now properly aliasing on remote processors, Datadeps did not have a clear way to copy only the changed portions of an argument. For example, when only a view of an array is updated on a remote processor, and the next task will then need the full parent array on the same remote processor, how does Datadeps copy over only the portions of the parent array that aren't yet up-to-date on the remote? 
The answer is that it didn't; it would do a full copy of the parent array to the remote, which would then destroy the changes made to the underlying view. This commit overhauls the copying machinery to properly calculate this difference (termed the "remainder"), based on the target ainfo and all previously-updated ainfos, and schedules a "remainder copy" to copy only the exact bytes that are not yet updated on the remote. Additionally, it may schedule copies from multiple other remote processors to the "target" remote processor as necessary, in case portions of an aliased object exist on multiple distinct processors. This machinery is driven by a new interval tree implementation, which allows efficient calculation of differences between sets of memory spans, and uses `unsafe_copyto!` to handle arbitrary data. --- src/Dagger.jl | 8 +- src/datadeps.jl | 1082 --------------------------------- src/datadeps/aliasing.jl | 700 +++++++++++++++++++++ src/datadeps/chunkview.jl | 64 ++ src/datadeps/interval_tree.jl | 349 +++++++++++ src/datadeps/queue.jl | 500 +++++++++++++++ src/datadeps/remainders.jl | 443 ++++++++++++++ src/memory-spaces.jl | 96 +-- src/utils/dagdebug.jl | 25 + 9 files changed, 2121 insertions(+), 1146 deletions(-) delete mode 100644 src/datadeps.jl create mode 100644 src/datadeps/aliasing.jl create mode 100644 src/datadeps/chunkview.jl create mode 100644 src/datadeps/interval_tree.jl create mode 100644 src/datadeps/queue.jl create mode 100644 src/datadeps/remainders.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index fa30c7c1a..987963b34 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -83,7 +83,13 @@ include("utils/caching.jl") include("sch/Sch.jl"); using .Sch # Data dependency task queue -include("datadeps.jl") +include("datadeps/aliasing.jl") +include("datadeps/chunkview.jl") +include("datadeps/interval_tree.jl") +include("datadeps/remainders.jl") +include("datadeps/queue.jl") + +# Stencils include("utils/haloarray.jl") include("stencil.jl") diff 
--git a/src/datadeps.jl b/src/datadeps.jl deleted file mode 100644 index d20bda647..000000000 --- a/src/datadeps.jl +++ /dev/null @@ -1,1082 +0,0 @@ -import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv - -export In, Out, InOut, Deps, spawn_datadeps - -"Specifies a read-only dependency." -struct In{T} - x::T -end -"Specifies a write-only dependency." -struct Out{T} - x::T -end -"Specifies a read-write dependency." -struct InOut{T} - x::T -end -"Specifies one or more dependencies." -struct Deps{T,DT<:Tuple} - x::T - deps::DT -end -Deps(x, deps...) = Deps(x, deps) - -struct DataDepsTaskQueue <: AbstractTaskQueue - # The queue above us - upper_queue::AbstractTaskQueue - # The set of tasks that have already been seen - seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing} - # The data-dependency graph of all tasks - g::Union{SimpleDiGraph{Int},Nothing} - # The mapping from task to graph ID - task_to_id::Union{Dict{DTask,Int},Nothing} - # How to traverse the dependency graph when launching tasks - traversal::Symbol - # Which scheduler to use to assign tasks to processors - scheduler::Symbol - - # Whether aliasing across arguments is possible - # The fields following only apply when aliasing==true - aliasing::Bool - - function DataDepsTaskQueue(upper_queue; - traversal::Symbol=:inorder, - scheduler::Symbol=:naive, - aliasing::Bool=true) - seen_tasks = Pair{DTaskSpec,DTask}[] - g = SimpleDiGraph() - task_to_id = Dict{DTask,Int}() - return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, - aliasing) - end -end - -function unwrap_inout(arg) - readdep = false - writedep = false - if arg isa In - readdep = true - arg = arg.x - elseif arg isa Out - writedep = true - arg = arg.x - elseif arg isa InOut - readdep = true - writedep = true - arg = arg.x - elseif arg isa Deps - alldeps = Tuple[] - for dep in arg.deps - dep_mod, inner_deps = unwrap_inout(dep) - for (_, readdep, writedep) in inner_deps - push!(alldeps, (dep_mod, 
readdep, writedep)) - end - end - arg = arg.x - return arg, alldeps - else - readdep = true - end - return arg, Tuple[(identity, readdep, writedep)] -end - -function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.seen_tasks, spec) -end -function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.seen_tasks, specs) -end - -_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? objectid(arg) : hash(arg, h) -_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) -_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) - -struct ArgumentWrapper - arg - dep_mod - hash::UInt - - function ArgumentWrapper(arg, dep_mod) - h = hash(dep_mod) - h = _identity_hash(arg, h) - return new(arg, dep_mod, h) - end -end -Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) -Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = - aw1.hash == aw2.hash - -struct DataDepsAliasingState - # Track original and current data locations - # We track data => space - data_origin::Dict{AliasingWrapper,MemorySpace} - data_locality::Dict{AliasingWrapper,MemorySpace} - - # Track writers ("owners") and readers - ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} - ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} - ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} - - # Cache ainfo lookups - ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} - - function DataDepsAliasingState() - data_origin = Dict{AliasingWrapper,MemorySpace}() - data_locality = Dict{AliasingWrapper,MemorySpace}() - - ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() - ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() - - ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() - - return 
new(data_origin, data_locality, - ainfos_owner, ainfos_readers, ainfos_overlaps, - ainfo_cache) - end -end -struct DataDepsNonAliasingState - # Track original and current data locations - # We track data => space - data_origin::IdDict{Any,MemorySpace} - data_locality::IdDict{Any,MemorySpace} - - # Track writers ("owners") and readers - args_owner::IdDict{Any,Union{Pair{DTask,Int},Nothing}} - args_readers::IdDict{Any,Vector{Pair{DTask,Int}}} - - function DataDepsNonAliasingState() - data_origin = IdDict{Any,MemorySpace}() - data_locality = IdDict{Any,MemorySpace}() - - args_owner = IdDict{Any,Union{Pair{DTask,Int},Nothing}}() - args_readers = IdDict{Any,Vector{Pair{DTask,Int}}}() - - return new(data_origin, data_locality, - args_owner, args_readers) - end -end -struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState}} - # Whether aliasing is being analyzed - aliasing::Bool - - # The ordered list of tasks and their read/write dependencies - dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}} - - # The mapping of memory space to remote argument copies - remote_args::Dict{MemorySpace,IdDict{Any,Any}} - - # Cache of whether arguments supports in-place move - supports_inplace_cache::IdDict{Any,Bool} - - # The aliasing analysis state - alias_state::State - - function DataDepsState(aliasing::Bool) - dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}[] - remote_args = Dict{MemorySpace,IdDict{Any,Any}}() - supports_inplace_cache = IdDict{Any,Bool}() - if aliasing - state = DataDepsAliasingState() - else - state = DataDepsNonAliasingState() - end - return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state) - end -end - -function aliasing(astate::DataDepsAliasingState, arg, dep_mod) - aw = ArgumentWrapper(arg, dep_mod) - get!(astate.ainfo_cache, aw) do - return AliasingWrapper(aliasing(arg, dep_mod)) - end -end - -function 
supports_inplace_move(state::DataDepsState, arg) - return get!(state.supports_inplace_cache, arg) do - return supports_inplace_move(arg) - end -end - -# Determine which arguments could be written to, and thus need tracking - -"Whether `arg` has any writedep in this datadeps region." -function has_writedep(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - # Check if we are writing to this memory - writedep = any(dep->dep[3], deps) - if writedep - arg_has_writedep[arg] = true - return true - end - - # Check if another task is writing to this memory - for (_, taskdeps) in state.dependencies - for (_, other_arg_writedep, _, _, other_arg) in taskdeps - other_arg_writedep || continue - if arg === other_arg - return true - end - end - end - - return false -end -""" -Whether `arg` has any writedep at or before executing `task` in this -datadeps region. -""" -function has_writedep(state::DataDepsState, arg, deps, task::DTask) - is_writedep(arg, deps, task) && return true - if state.aliasing - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, other_ainfo, _, _) in other_taskdeps - writedep || continue - for (dep_mod, _, _) in deps - ainfo = aliasing(state.alias_state, arg, dep_mod) - if will_alias(ainfo, other_ainfo) - return true - end - end - end - if task === other_task - return false - end - end - else - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, _, _, other_arg) in other_taskdeps - writedep || continue - if arg === other_arg - return true - end - end - if task === other_task - return false - end - end - end - error("Task isn't in argdeps set") -end -"Whether `arg` is written to by `task`." 
-function is_writedep(arg, deps, task::DTask) - return any(dep->dep[3], deps) -end - -# Aliasing state setup -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) - # Populate task dependencies - dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}() - - # Track the task's arguments and access patterns - for (idx, _arg) in enumerate(spec.fargs) - # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(value(_arg)) - - # Unwrap the Chunk underlying any DTask arguments - arg = arg isa DTask ? fetch(arg; raw=true) : arg - - # Skip non-aliasing arguments - type_may_alias(typeof(arg)) || continue - - # Add all aliasing dependencies - for (dep_mod, readdep, writedep) in deps - if state.aliasing - ainfo = aliasing(state.alias_state, arg, dep_mod) - else - ainfo = AliasingWrapper(UnknownAliasing()) - end - push!(dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg)) - end - - # Populate argument write info - populate_argument_info!(state, arg, deps) - end - - # Track the task result too - # N.B. 
We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this - push!(dependencies_to_add, (false, false, AliasingWrapper(UnknownAliasing()), identity, task)) - - # Record argument/result dependencies - push!(state.dependencies, task => dependencies_to_add) -end -function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, arg, deps) - astate = state.alias_state - for (dep_mod, readdep, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - - # Initialize owner and readers - if !haskey(astate.ainfos_owner, ainfo) - overlaps = Set{AliasingWrapper}() - push!(overlaps, ainfo) - for other_ainfo in keys(astate.ainfos_owner) - ainfo == other_ainfo && continue - if will_alias(ainfo, other_ainfo) - push!(overlaps, other_ainfo) - push!(astate.ainfos_overlaps[other_ainfo], ainfo) - end - end - astate.ainfos_overlaps[ainfo] = overlaps - astate.ainfos_owner[ainfo] = nothing - astate.ainfos_readers[ainfo] = Pair{DTask,Int}[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, ainfo) - astate.data_locality[ainfo] = memory_space(arg) - astate.data_origin[ainfo] = memory_space(arg) - end - end -end -function populate_argument_info!(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - astate = state.alias_state - # Initialize owner and readers - if !haskey(astate.args_owner, arg) - astate.args_owner[arg] = nothing - astate.args_readers[arg] = DTask[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, arg) - astate.data_locality[arg] = memory_space(arg) - astate.data_origin[arg] = memory_space(arg) - end -end -function populate_return_info!(state::DataDepsState{DataDepsAliasingState}, task, space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - # FIXME: We don't yet know about ainfos for this task -end -function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, task, 
space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - astate.data_locality[task] = space - astate.data_origin[task] = space -end - -""" - supports_inplace_move(x) -> Bool - -Returns `false` if `x` doesn't support being copied into from another object -like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting -to copy between values which don't support mutation or otherwise don't have an -implemented `move!` and want to skip in-place copies. When this returns -`false`, datadeps will instead perform out-of-place copies for each non-local -use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` -region returns. -""" -supports_inplace_move(x) = true -supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) -function supports_inplace_move(c::Chunk) - # FIXME: Use MemPool.access_ref - pid = root_worker_id(c.processor) - if pid == myid() - return supports_inplace_move(poolget(c.handle)) - else - return remotecall_fetch(supports_inplace_move, pid, c) - end -end -supports_inplace_move(::Function) = false - -# Read/write dependency management -function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) - _get_read_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end -function get_read_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end - -function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - other_task_write_num = astate.ainfos_owner[other_ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with writer via $ainfo -> $other_ainfo" - other_task_write_num === nothing && continue - other_task, 
other_write_num = other_task_write_num - write_num == other_write_num && continue - @dagdebug nothing :spawn_datadeps "Sync with writer via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end -end -function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with reader via $ainfo -> $other_ainfo" - other_tasks = astate.ainfos_readers[other_ainfo] - for (other_task, other_write_num) in other_tasks - write_num == other_write_num && continue - @dagdebug nothing :spawn_datadeps "Sync with reader via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function add_writer!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - state.alias_state.ainfos_owner[ainfo] = task=>write_num - empty!(state.alias_state.ainfos_readers[ainfo]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, ainfo, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - push!(state.alias_state.ainfos_readers[ainfo], task=>write_num) -end - -function _get_write_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - other_task_write_num = state.alias_state.args_owner[arg] - if other_task_write_num !== nothing - other_task, other_write_num = other_task_write_num - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function _get_read_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - for (other_task, other_write_num) in state.alias_state.args_readers[arg] - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function 
add_writer!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - state.alias_state.args_owner[arg] = task=>write_num - empty!(state.alias_state.args_readers[arg]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, arg, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - push!(state.alias_state.args_readers[arg], task=>write_num) -end - -# Make a copy of each piece of data on each worker -# memory_space => {arg => copy_of_arg} -isremotehandle(x) = false -isremotehandle(x::DTask) = true -isremotehandle(x::Chunk) = true -function generate_slot!(state::DataDepsState, dest_space, data) - if data isa DTask - data = fetch(data; raw=true) - end - orig_space = memory_space(data) - to_proc = first(processors(dest_space)) - from_proc = first(processors(orig_space)) - dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) - # Fast path for local data or data already in a Chunk - data_chunk = tochunk(data, from_proc) - dest_space_args[data] = data_chunk - @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc - @assert memory_space(data_chunk) == orig_space - else - to_w = root_worker_id(dest_space) - ctx = Sch.eager_context() - id = rand(Int) - dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data - timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_converted = move(from_proc, to_proc, data) - data_chunk = tochunk(data_converted, to_proc) - @assert processor(data_chunk) in processors(dest_space) - @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. 
$(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - if orig_space != dest_space - @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - end - return data_chunk - end - timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=dest_space_args[data])) - end - return dest_space_args[data] -end - -struct DataDepsSchedulerState - task_to_spec::Dict{DTask,DTaskSpec} - assignments::Dict{DTask,MemorySpace} - dependencies::Dict{DTask,Set{DTask}} - task_completions::Dict{DTask,UInt64} - space_completions::Dict{MemorySpace,UInt64} - capacities::Dict{MemorySpace,Int} - - function DataDepsSchedulerState() - return new(Dict{DTask,DTaskSpec}(), - Dict{DTask,MemorySpace}(), - Dict{DTask,Set{DTask}}(), - Dict{DTask,UInt64}(), - Dict{MemorySpace,UInt64}(), - Dict{MemorySpace,Int}()) - end -end - -function distribute_tasks!(queue::DataDepsTaskQueue) - #= TODO: Improvements to be made: - # - Support for copying non-AbstractArray arguments - # - Parallelize read copies - # - Unreference unused slots - # - Reuse memory when possible - # - Account for differently-sized data - =# - - # Get the set of all processors to be scheduled on - all_procs = Processor[] - scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) - end - filter!(proc->!isa(constrain(ExactScope(proc), scope), - InvalidScope), - all_procs) - if isempty(all_procs) - throw(Sch.SchedulingException("No processors available, try widening scope")) - end - exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) - @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 - end 
- - # Round-robin assign tasks to processors - upper_queue = get_options(:task_queue) - - traversal = queue.traversal - if traversal == :inorder - # As-is - task_order = Colon() - elseif traversal == :bfs - # BFS - task_order = Int[1] - to_walk = Int[1] - seen = Set{Int}([1]) - while !isempty(to_walk) - # N.B. next_root has already been seen - next_root = popfirst!(to_walk) - for v in outneighbors(queue.g, next_root) - if !(v in seen) - push!(task_order, v) - push!(seen, v) - push!(to_walk, v) - end - end - end - elseif traversal == :dfs - # DFS (modified with backtracking) - task_order = Int[] - to_walk = Int[1] - seen = Set{Int}() - while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) - next_root = popfirst!(to_walk) - if !(next_root in seen) - iv = inneighbors(queue.g, next_root) - if all(v->v in seen, iv) - push!(task_order, next_root) - push!(seen, next_root) - ov = outneighbors(queue.g, next_root) - prepend!(to_walk, ov) - else - push!(to_walk, next_root) - end - end - end - else - throw(ArgumentError("Invalid traversal mode: $traversal")) - end - - state = DataDepsState(queue.aliasing) - astate = state.alias_state - sstate = DataDepsSchedulerState() - for proc in all_procs - space = only(memory_spaces(proc)) - get!(()->0, sstate.capacities, space) - sstate.capacities[space] += 1 - end - - # Start launching tasks and necessary copies - write_num = 1 - proc_idx = 1 - pressures = Dict{Processor,Int}() - proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for (spec, task) in queue.seen_tasks[task_order] - # Populate all task dependencies - populate_task_info!(state, spec, task) - - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - scheduler = queue.scheduler - if scheduler == :naive - raw_args = map(arg->tochunk(value(arg)), spec.fargs) - our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock 
begin - # Calculate costs per processor and select the most optimal - # FIXME: This should consider any already-allocated slots, - # whether they are up-to-date, and if not, the cost of moving - # data to them - procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) - return first(procs) - end - end - elseif scheduler == :smart - raw_args = map(filter(arg->haskey(astate.data_locality, value(arg)), spec.fargs)) do arg - arg_chunk = tochunk(last(arg)) - # Only the owned slot is valid - # FIXME: Track up-to-date copies and pass all of those - return arg_chunk => data_locality[arg] - end - f_chunk = tochunk(value(f)) - our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - tx_rate = sch_state.transfer_rate[] - - costs = Dict{Processor,Float64}() - for proc in all_procs - # Filter out chunks that are already local - chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) - - # Estimate network transfer costs based on data size - # N.B. `affinity(x)` really means "data size of `x`" - # N.B. 
We treat same-worker transfers as having zero transfer cost - tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(pressures, proc, UInt64(0)) - costs[proc] = est_time_util + (tx_cost/tx_rate) - end - - # Look up estimated task cost - sig = Sch.signature(sch_state, f, map(first, chunks_locality)) - task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) - - # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(all_procs)) - procs = getindex.(Ref(all_procs), P) - - # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - best_proc = first(procs) - return best_proc, task_pressure - end - end - # FIXME: Pressure should be decreased by pressure of syncdeps on same processor - pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure - elseif scheduler == :ultra - args = Base.mapany(spec.fargs) do arg - pos, data = arg - data, _ = unwrap_inout(data) - if data isa DTask - data = fetch(data; raw=true) - end - return pos => tochunk(data) - end - f_chunk = tochunk(value(f)) - task_time = remotecall_fetch(1, f_chunk, args) do f, args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - return @lock sch_state.lock begin - sig = Sch.signature(sch_state, f, args) - return get(sch_state.signature_time_cost, sig, 1000^3) - end - end - - # FIXME: Copy deps are computed eagerly - deps = get(Set{Any}, spec.options, :syncdeps) - - # Find latest time-to-completion of all syncdeps - deps_completed = UInt64(0) - for dep in deps - haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded - deps_completed = max(deps_completed, sstate.task_completions[dep]) - end - - # Find latest time-to-completion of each memory space - # FIXME: Figure out space completions based on optimal packing - spaces_completed = Dict{MemorySpace,UInt64}() - for space in exec_spaces - completed = 
UInt64(0) - for (task, other_space) in sstate.assignments - space == other_space || continue - completed = max(completed, sstate.task_completions[task]) - end - spaces_completed[space] = completed - end - - # Choose the earliest-available memory space and processor - # FIXME: Consider move time - move_time = UInt64(0) - local our_space_completed - while true - our_space_completed, our_space = findmin(spaces_completed) - our_space_procs = filter(proc->proc in all_procs, processors(our_space)) - if isempty(our_space_procs) - delete!(spaces_completed, our_space) - continue - end - our_proc = rand(our_space_procs) - break - end - - sstate.task_to_spec[task] = spec - sstate.assignments[task] = our_space - sstate.task_completions[task] = our_space_completed + move_time + task_time - elseif scheduler == :roundrobin - our_proc = all_procs[proc_idx] - if task_scope == scope - # all_procs is already limited to scope - else - if isa(constrain(task_scope, scope), InvalidScope) - throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) - end - while !proc_in_scope(our_proc, task_scope) - proc_idx = mod1(proc_idx + 1, length(all_procs)) - our_proc = all_procs[proc_idx] - end - end - else - error("Invalid scheduler: $sched") - end - @assert our_proc in all_procs - our_space = only(memory_spaces(our_proc)) - - # Find the scope for this task (and its copies) - if task_scope == scope - # Optimize for the common case, cache the proc=>scope mapping - our_scope = get!(proc_to_scope_lfu, our_proc) do - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), scope) - end - else - # Use the provided scope and constrain it to the available processors - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) - end - if our_scope isa InvalidScope - throw(Sch.SchedulingException("Scopes 
are not compatible: $(our_scope.x), $(our_scope.y)")) - end - - f = spec.fargs[1] - f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - - # Copy raw task arguments for analysis - task_args = map(copy, spec.fargs) - - # Copy args from local to remote - for (idx, _arg) in enumerate(task_args) - # Is the data writeable? - arg, deps = unwrap_inout(value(_arg)) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (unwritten)" - spec.fargs[idx].value = arg - continue - end - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (non-writeable)" - spec.fargs[idx].value = arg - continue - end - - # Is the source of truth elsewhere? - arg_remote = get!(get!(IdDict{Any,Any}, state.remote_args, our_space), arg) do - generate_slot!(state, our_space, arg) - end - if queue.aliasing - for (dep_mod, _, _) in deps - ainfo = aliasing(astate, arg, dep_mod) - data_space = astate.data_locality[ainfo] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) - add_writer!(state, ainfo, copy_to, write_num) - - 
astate.data_locality[ainfo] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Skipped copy-to (local): $data_space" - end - end - else - data_space = astate.data_locality[arg] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) - add_writer!(state, arg, copy_to, write_num) - - astate.data_locality[arg] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (local): $data_space" - end - end - spec.fargs[idx].value = arg_remote - end - write_num += 1 - - # Validate that we're not accidentally performing a copy - for (idx, _arg) in enumerate(spec.fargs) - _, deps = unwrap_inout(value(task_args[idx])) - # N.B. 
We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results - arg = value(_arg) - if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) - arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(value(f))))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" - end - end - - # Calculate this task's syncdeps - if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{ThunkSyncdep}() - end - syncdeps = spec.options.syncdeps - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - supports_inplace_move(state, arg) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as writer" - get_write_deps!(state, ainfo, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as reader" - get_read_deps!(state, ainfo, task, write_num, syncdeps) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as writer" - get_write_deps!(state, arg, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as reader" - get_read_deps!(state, arg, task, write_num, syncdeps) - end - end - end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) $(length(syncdeps)) syncdeps" - - # Launch user's task - spec.options.scope = our_scope - spec.options.exec_scope = our_scope - enqueue!(upper_queue, spec=>task) - - # Update read/write tracking for arguments - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? 
fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Set as owner" - add_writer!(state, ainfo, task, write_num) - else - add_reader!(state, ainfo, task, write_num) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Set as owner" - add_writer!(state, arg, task, write_num) - else - add_reader!(state, arg, task, write_num) - end - end - end - - # Update tracking for return value - populate_return_info!(state, task, our_space) - - write_num += 1 - proc_idx = mod1(proc_idx + 1, length(all_procs)) - end - - # Copy args from remote to local - if queue.aliasing - # We need to replay the writes from all tasks in-order (skipping any - # outdated write owners), to ensure that overlapping writes are applied - # in the correct order - - # First, find the latest owners of each live ainfo - arg_writes = IdDict{Any,Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}}() - for (task, taskdeps) in state.dependencies - for (_, writedep, ainfo, dep_mod, arg) in taskdeps - writedep || continue - haskey(astate.data_locality, ainfo) || continue - @assert haskey(astate.ainfos_owner, ainfo) "Missing ainfo: $ainfo ($dep_mod($(typeof(arg))))" - - # Skip virtual writes from task result aliasing - # FIXME: Make this less bad - if arg isa DTask && dep_mod === identity && ainfo.inner isa UnknownAliasing - continue - end - - # Skip non-writeable arguments - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - continue - end - - # Get the set of writers - ainfo_writes = get!(Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}, arg_writes, arg) - - #= FIXME: If we fully overlap any writer, evict them - idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes) - 
deleteat!(ainfo_writes, idxs) - =# - - # Make ourselves the latest writer - push!(ainfo_writes, (ainfo, dep_mod, astate.data_locality[ainfo])) - end - end - - # Then, replay the writes from each owner in-order - # FIXME: write_num should advance across overlapping ainfo's, as - # writes must be ordered sequentially - for (arg, ainfo_writes) in arg_writes - if length(ainfo_writes) > 1 - # FIXME: Remove me - deleteat!(ainfo_writes, 1:length(ainfo_writes)-1) - end - for (ainfo, dep_mod, data_remote_space) in ainfo_writes - # Is the source of truth elsewhere? - data_local_space = astate.data_origin[ainfo] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "[$dep_mod] Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_local_space), arg) do - generate_slot!(state, data_local_space, arg) - end - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = UnionScope(map(ExactScope, collect(processors(data_local_space)))...) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "[$dep_mod] Skipped copy-from (local): $data_remote_space" - end - end - end - else - for arg in keys(astate.data_origin) - # Is the data previously written? - arg, deps = unwrap_inout(arg) - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)" - end - - # Can the data be written back to? 
- if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - end - - # Is the source of truth elsewhere? - data_remote_space = astate.data_locality[arg] - data_local_space = astate.data_origin[arg] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = state.remote_args[data_local_space][arg] - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = ExactScope(data_local_proc) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "Skipped copy-from (local): $data_remote_space" - end - end - end -end - -""" - spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) - -Constructs a "datadeps" (data dependencies) region and calls `f` within it. -Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or -`InOut` to indicate whether the task will read, write, or read+write that -argument, respectively. 
These argument dependencies will be used to specify -which tasks depend on each other based on the following rules: - -- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other -- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects -- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel -- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies -- An `In` dependency synchronizes with any previous `Out` dependencies -- If unspecified, an `In` dependency is assumed - -In general, the result of executing tasks following the above rules will be -equivalent to simply executing tasks sequentially and in order of submission. -Of course, if dependencies are incorrectly specified, undefined behavior (and -unexpected results) may occur. - -Unlike other Dagger tasks, tasks executed within a datadeps region are allowed -to write to their arguments when annotated with `Out` or `InOut` -appropriately. - -At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks -to complete, rethrowing the first error, if any. The result of `f` will be -returned from `spawn_datadeps`. - -The keyword argument `traversal` controls the order that tasks are launched by -the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling -or Depth-First Scheduling, respectively. All traversal orders respect the -dependencies and ordering of the launched tasks, but may provide better or -worse performance for a given set of datadeps tasks. This argument is -experimental and subject to change. 
-""" -function spawn_datadeps(f::Base.Callable; static::Bool=true, - traversal::Symbol=:inorder, - scheduler::Union{Symbol,Nothing}=nothing, - aliasing::Bool=true, - launch_wait::Union{Bool,Nothing}=nothing) - if !static - throw(ArgumentError("Dynamic scheduling is no longer available")) - end - wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol - launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool - if launch_wait - result = spawn_bulk() do - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - else - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - result = with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - return result - end -end -const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) -const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl new file mode 100644 index 000000000..1d2a15a34 --- /dev/null +++ b/src/datadeps/aliasing.jl @@ -0,0 +1,700 @@ +import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv + +export In, Out, InOut, Deps, spawn_datadeps + +#= +============================================================================== + DATADEPS ALIASING AND DATA MOVEMENT SYSTEM +============================================================================== + +This file implements the data dependencies system for Dagger tasks, which allows +tasks to write to their arguments in a controlled manner. The system maintains +data coherency across distributed workers by tracking aliasing relationships +and orchestrating data movement operations. 
+ +OVERVIEW: +--------- +The datadeps system enables parallel execution of tasks that modify shared data +by analyzing memory aliasing relationships and scheduling appropriate data +transfers. The core challenge is maintaining coherency when aliased data (e.g., +an array and its views) needs to be accessed by tasks running on different workers. + +KEY CONCEPTS: +------------- + +1. ALIASING ANALYSIS: + - Every mutable argument is analyzed for its memory access pattern + - Memory spans are computed to determine which bytes in memory are accessed + - Objects that access overlapping memory spans are considered "aliasing" + - Examples: An array A and view(A, 2:3, 2:3) alias each other + +2. DATA LOCALITY TRACKING: + - The system tracks where the "source of truth" for each piece of data lives + - As tasks execute and modify data, the source of truth may move between workers + - Each aliasing region can have its own independent source of truth location + +3. ALIASED OBJECT MANAGEMENT: + - When copying arguments between workers, the system tracks "aliased objects" + - This ensures that if both an array and its view need to be copied to a worker, + only one copy of the underlying array is made, with the view pointing to it + - The aliased_object!() functions manage this sharing + +THE DISTRIBUTED ALIASING PROBLEM: +--------------------------------- + +In a multithreaded environment, aliasing "just works" because all tasks operate +on the same memory. However, in a distributed environment, arguments must be +copied between workers, which breaks aliasing relationships. 
+ +Consider this scenario: +```julia +A = rand(4, 4) +vA = view(A, 2:3, 2:3) + +Dagger.spawn_datadeps() do + Dagger.@spawn inc!(InOut(A), 1) # Task 1: increment all of A + Dagger.@spawn inc!(InOut(vA), 2) # Task 2: increment view of A +end +``` + +MULTITHREADED BEHAVIOR (WORKS): +- Both tasks run on the same worker +- They operate on the same memory, with proper dependency tracking +- Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) + +DISTRIBUTED BEHAVIOR (THE PROBLEM): +- Tasks may be scheduled on different workers +- Each argument must be copied to the destination worker +- Without special handling, we would copy A to worker1 and vA to worker2 +- This creates two separate arrays, breaking the aliasing relationship +- Updates to the view on worker2 don't affect the array on worker1 + +THE SOLUTION - PARTIAL DATA MOVEMENT: +------------------------------------- + +The datadeps system solves this by: + +1. UNIFIED ALLOCATION: + - When copying aliased objects, ensure only one underlying array exists per worker + - Use aliased_object!() to detect and reuse existing allocations + - Views on the destination worker point to the shared underlying array + +2. PARTIAL DATA TRANSFER: + - Instead of copying entire objects, only transfer the "dirty" regions + - This minimizes network traffic and maximizes parallelism + - Uses the move!(dep_mod, ...) function with dependency modifiers + +3. REMAINDER TRACKING: + - When a partial region is updated, track what parts still need updating + - Before a task needs the full object, copy the remaining "clean" regions + - This preserves all updates while avoiding overwrites + +EXAMPLE EXECUTION FLOW: +----------------------- + +Given: A = 4x4 array, vA = view(A, 2:3, 2:3) +Tasks: T1 modifies InOut(A), T2 modifies InOut(vA) + +1. INITIAL STATE: + - A and vA both exist on worker0 (main worker) + - A's data_locality = worker0, vA's data_locality = worker0 + +2. 
T1 SCHEDULED ON WORKER1: + - Copy A from worker0 to worker1 + - T1 executes, modifying all of A on worker1 + - Update: A's data_locality = worker1, A is now "dirty" on worker1 + +3. T2 SCHEDULED ON WORKER2: + - T2 needs vA, but vA aliases with A (which was modified by T1) + - Copy vA-region of A from worker1 to worker2 + - This is a PARTIAL copy - only the 2:3, 2:3 region + - Create vA on worker2 pointing to the appropriate region + - T2 executes, modifying vA region on worker2 + - Update: vA's data_locality = worker2 + +4. FINAL SYNCHRONIZATION: + - Some future task needs the complete A + - A needs to be assembled from: worker1 (non-vA regions) + worker2 (vA region) + - REMAINDER COPY: Copy non-vA regions from worker1 to worker2 + - OR INVERSE: Copy vA-region from worker2 to worker1, then copy full A + +MEMORY SPAN COMPUTATION: +------------------------ + +The system uses memory spans to determine aliasing and compute remainders: + +- ContiguousAliasing: Single contiguous memory region (e.g., full array) +- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) +- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) +- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) + +Remainder computation involves: +1. Computing memory spans for all overlapping aliasing objects +2. Finding the set difference: full_object_spans - updated_spans +3. Creating a "remainder aliasing" object representing the not-yet-updated regions +4. Performing move! with this remainder object to copy only needed data + +DATA MOVEMENT FUNCTIONS: +------------------------ + +move!(dep_mod, to_space, from_space, to, from): +- The core in-place data movement function +- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) 
+- Supports partial copies via dependency modifiers + +move_rewrap(): +- Handles copying of wrapped objects (SubArrays, ChunkViews) +- Ensures aliased objects are reused on destination worker + +enqueue_copy_to!(): +- Schedules data movement tasks before user tasks +- Ensures data is up-to-date on the worker where a task will run + +CURRENT LIMITATIONS AND TODOS: +------------------------------- + +1. REMAINDER COMPUTATION: + - The system currently handles simple overlaps but needs sophisticated + remainder calculation for complex aliasing patterns + - Need functions to compute span set differences + +2. ORDERING DEPENDENCIES: + - Need to ensure remainder copies happen in correct order + - Must not overwrite more recent updates with stale data + +3. COMPLEX ALIASING PATTERNS: + - Multiple overlapping views of the same array + - Nested aliasing structures (views of views) + - Mixed aliasing types (diagonal + triangular regions) + +4. PERFORMANCE OPTIMIZATION: + - Minimize number of copy operations + - Batch compatible transfers + - Optimize for common access patterns +=# + +"Specifies a read-only dependency." +struct In{T} + x::T +end +"Specifies a write-only dependency." +struct Out{T} + x::T +end +"Specifies a read-write dependency." +struct InOut{T} + x::T +end +"Specifies one or more dependencies." +struct Deps{T,DT<:Tuple} + x::T + deps::DT +end +Deps(x, deps...) 
= Deps(x, deps) + +function unwrap_inout(arg) + readdep = false + writedep = false + if arg isa In + readdep = true + arg = arg.x + elseif arg isa Out + writedep = true + arg = arg.x + elseif arg isa InOut + readdep = true + writedep = true + arg = arg.x + elseif arg isa Deps + alldeps = Tuple[] + for dep in arg.deps + dep_mod, inner_deps = unwrap_inout(dep) + for (_, readdep, writedep) in inner_deps + push!(alldeps, (dep_mod, readdep, writedep)) + end + end + arg = arg.x + return arg, alldeps + else + readdep = true + end + return arg, Tuple[(identity, readdep, writedep)] +end + +_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? objectid(arg) : hash(arg, h) +_identity_hash(arg::Chunk, h::UInt=UInt(0)) = hash(arg.handle, hash(Chunk, h)) +_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) +_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) + +struct ArgumentWrapper + arg + dep_mod + hash::UInt + + function ArgumentWrapper(arg, dep_mod) + h = hash(dep_mod) + h = _identity_hash(arg, h) + return new(arg, dep_mod, h) + end +end +Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) +Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = + aw1.hash == aw2.hash +Base.isequal(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = + aw1.hash == aw2.hash + +struct HistoryEntry + ainfo::AliasingWrapper + space::MemorySpace + write_num::Int +end + +struct DataDepsState + # The mapping of original raw argument to its Chunk + raw_arg_to_chunk::IdDict{Any,Chunk} + + # The origin memory space of each argument + # Used to track the original location of an argument, for final copy-from + arg_origin::IdDict{Any,MemorySpace} + + # The mapping of memory space to argument to remote argument copies + # Used to replace an argument with its remote copy + remote_args::Dict{MemorySpace,IdDict{Any,Chunk}} + + # The mapping of remote argument to original 
argument + remote_arg_to_original::IdDict{Any,Any} + + # The mapping of ainfo to argument and dep_mod + # Used to lookup which argument and dep_mod a given ainfo is generated from + # N.B. This is a mapping for remote argument copies + ainfo_arg::Dict{AliasingWrapper,ArgumentWrapper} + + # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to + # Updated when a new write happens on an overlapping ainfo + # Used by remainder copies to track which portions of an argument and dep_mod were written to elsewhere, through another argument + arg_history::Dict{ArgumentWrapper,Vector{HistoryEntry}} + + # The mapping of memory space and argument to the memory space of the last direct write + # Used by remainder copies to lookup the "backstop" if any portion of the target ainfo is not updated by the remainder + arg_owner::Dict{ArgumentWrapper,MemorySpace} + + # The overlap of each argument with every other argument, based on the ainfo overlaps + # Incrementally updated as new ainfos are created + # Used for fast history updates + arg_overlaps::Dict{ArgumentWrapper,Set{ArgumentWrapper}} + + # The mapping of, for a given memory space, the backing Chunks that an ainfo references + # Used by slot generation to replace the backing Chunks during move + ainfo_backing_chunk::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} + + # Cache of argument's supports_inplace_move query result + supports_inplace_cache::IdDict{Any,Bool} + + # Cache of argument and dep_mod to ainfo + # N.B. 
This is a mapping for remote argument copies + ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} + + # The overlapping ainfos for each ainfo + # Incrementally updated as new ainfos are created + # Used for fast will_alias lookups + ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} + + # Track writers ("owners") and readers + # Updated as new writer and reader tasks are launched + # Used by task dependency tracking to calculate syncdeps and ensure correct launch ordering + ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} + ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} + + function DataDepsState(aliasing::Bool) + if !aliasing + @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1 + end + + arg_to_chunk = IdDict{Any,Chunk}() + arg_origin = IdDict{Any,MemorySpace}() + remote_args = Dict{MemorySpace,IdDict{Any,Any}}() + remote_arg_to_original = IdDict{Any,Any}() + ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() + arg_owner = Dict{ArgumentWrapper,MemorySpace}() + arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() + ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() + arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() + + supports_inplace_cache = IdDict{Any,Bool}() + ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() + + ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() + + ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() + ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() + + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, + supports_inplace_cache, ainfo_cache, ainfos_overlaps, ainfos_owner, ainfos_readers) + end +end + +# N.B. 
arg_w must be the original argument wrapper, not a remote copy +function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) + # Grab the remote copy of the argument, and calculate the ainfo + remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) + remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) + + # Check if we already have the result cached + if haskey(state.ainfo_cache, remote_arg_w) + return state.ainfo_cache[remote_arg_w] + end + + # Calculate the ainfo + ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) + + # Cache the result + state.ainfo_cache[remote_arg_w] = ainfo + + # Update the mapping of ainfo to argument and dep_mod + state.ainfo_arg[ainfo] = remote_arg_w + + # Populate info for the new ainfo + populate_ainfo!(state, arg_w, ainfo, target_space) + + return ainfo +end + +function supports_inplace_move(state::DataDepsState, arg) + return get!(state.supports_inplace_cache, arg) do + return supports_inplace_move(arg) + end +end + +# Determine which arguments could be written to, and thus need tracking +"Whether `arg` is written to by `task`." +function is_writedep(arg, deps, task::DTask) + return any(dep->dep[3], deps) +end + +# Aliasing state setup +function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) + # Track the task's arguments and access patterns + for (idx, _arg) in enumerate(spec.fargs) + arg = value(_arg) + + # Unwrap In/InOut/Out wrappers and record dependencies + arg, deps = unwrap_inout(arg) + + # Unwrap the Chunk underlying any DTask arguments + arg = arg isa DTask ? 
fetch(arg; raw=true) : arg + + # Skip non-aliasing arguments + type_may_alias(typeof(arg)) || continue + + # Skip arguments not supporting in-place move + supports_inplace_move(state, arg) || continue + + # Generate a Chunk for the argument if necessary + if haskey(state.raw_arg_to_chunk, arg) + arg = state.raw_arg_to_chunk[arg] + else + if !(arg isa Chunk) + new_arg = tochunk(arg) + state.raw_arg_to_chunk[arg] = new_arg + arg = new_arg + else + state.raw_arg_to_chunk[arg] = arg + end + end + + # Track the origin space of the argument + origin_space = memory_space(arg) + state.arg_origin[arg] = origin_space + state.remote_arg_to_original[arg] = arg + + # Populate argument info for all aliasing dependencies + for (dep_mod, _, _) in deps + # Generate an ArgumentWrapper for the argument + aw = ArgumentWrapper(arg, dep_mod) + + # Populate argument info + populate_argument_info!(state, aw, origin_space) + end + end +end +function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, origin_space::MemorySpace) + # Initialize ownership and history + if !haskey(state.arg_owner, arg_w) + # N.B. This is valid (even if the backing data is up-to-date elsewhere), + # because we only use this to track the "backstop" if any portion of the + # target ainfo is not updated by the remainder (at which point, this + # is thus the correct owner). 
+ state.arg_owner[arg_w] = origin_space + + # Initialize the overlap set + state.arg_overlaps[arg_w] = Set{ArgumentWrapper}() + end + if !haskey(state.arg_history, arg_w) + state.arg_history[arg_w] = Vector{HistoryEntry}() + end + + # Calculate the ainfo (which will populate ainfo structures and merge history) + aliasing!(state, origin_space, arg_w) +end +function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) + # Initialize owner and readers + if !haskey(state.ainfos_owner, target_ainfo) + overlaps = Set{AliasingWrapper}() + push!(overlaps, target_ainfo) + for other_ainfo in keys(state.ainfos_owner) + target_ainfo == other_ainfo && continue + if will_alias(target_ainfo, other_ainfo) + # Mark us and them as overlapping + push!(overlaps, other_ainfo) + push!(state.ainfos_overlaps[other_ainfo], target_ainfo) + + # Add overlapping history to our own + other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] + other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) + push!(state.arg_overlaps[original_arg_w], other_arg_w) + push!(state.arg_overlaps[other_arg_w], original_arg_w) + merge_history!(state, original_arg_w, other_arg_w) + end + end + state.ainfos_overlaps[target_ainfo] = overlaps + state.ainfos_owner[target_ainfo] = nothing + state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[] + end +end +function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_w::ArgumentWrapper) + history = state.arg_history[arg_w] + @opcounter :merge_history + @opcounter :merge_history_complexity length(history) + largest_value_update!(length(history)) + origin_space = state.arg_origin[other_arg_w.arg] + for other_entry in state.arg_history[other_arg_w] + write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num) + range = searchsorted(history, write_num_tuple; 
by=x->x.write_num) + if !isempty(range) + # Find and skip duplicates + match = false + for source_idx in range + source_entry = history[source_idx] + if source_entry.ainfo == other_entry.ainfo && + source_entry.space == other_entry.space && + source_entry.write_num == other_entry.write_num + match = true + break + end + end + match && continue + + # Insert at the first position + idx = first(range) + else + # Insert at the last position + idx = length(history) + 1 + end + insert!(history, idx, other_entry) + end +end +function truncate_history!(state::DataDepsState, arg_w::ArgumentWrapper) + # FIXME: Do this continuously if possible + if haskey(state.arg_history, arg_w) && length(state.arg_history[arg_w]) > 100000 + origin_space = state.arg_origin[arg_w.arg] + @opcounter :truncate_history + _, last_idx = compute_remainder_for_arg!(state, origin_space, arg_w, 0; compute_syncdeps=false) + if last_idx > 0 + @opcounter :truncate_history_removed last_idx + deleteat!(state.arg_history[arg_w], 1:last_idx) + end + end +end + +""" + supports_inplace_move(x) -> Bool + +Returns `false` if `x` doesn't support being copied into from another object +like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting +to copy between values which don't support mutation or otherwise don't have an +implemented `move!` and want to skip in-place copies. When this returns +`false`, datadeps will instead perform out-of-place copies for each non-local +use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` +region returns. 
+""" +supports_inplace_move(x) = true +supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +function supports_inplace_move(c::Chunk) + # FIXME: Use MemPool.access_ref + pid = root_worker_id(c.processor) + if pid == myid() + return supports_inplace_move(poolget(c.handle)) + else + return remotecall_fetch(supports_inplace_move, pid, c) + end +end +supports_inplace_move(::Function) = false + +# Read/write dependency management +function get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + # We need to sync with both writers and readers + _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps) + _get_read_deps!(state, dest_space, ainfo, write_num, syncdeps) +end +function get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + # We only need to sync with writers, not readers + _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps) +end + +function _get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + ainfo.inner isa NoAliasing && return + for other_ainfo in state.ainfos_overlaps[ainfo] + other_task_write_num = state.ainfos_owner[other_ainfo] + @dagdebug nothing :spawn_datadeps_sync "Considering sync with writer via $ainfo -> $other_ainfo" + other_task_write_num === nothing && continue + other_task, other_write_num = other_task_write_num + write_num == other_write_num && continue + @dagdebug nothing :spawn_datadeps_sync "Sync with writer via $ainfo -> $other_ainfo" + push!(syncdeps, ThunkSyncdep(other_task)) + end +end +function _get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + ainfo.inner isa NoAliasing && return + for other_ainfo in state.ainfos_overlaps[ainfo] + @dagdebug nothing :spawn_datadeps_sync "Considering sync with reader via $ainfo -> $other_ainfo" + other_tasks = 
state.ainfos_readers[other_ainfo] + for (other_task, other_write_num) in other_tasks + write_num == other_write_num && continue + @dagdebug nothing :spawn_datadeps_sync "Sync with reader via $ainfo -> $other_ainfo" + push!(syncdeps, ThunkSyncdep(other_task)) + end + end +end +function add_writer!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num) + state.ainfos_owner[ainfo] = task=>write_num + empty!(state.ainfos_readers[ainfo]) + + # Clear the history for this target, since this is a new write event + empty!(state.arg_history[arg_w]) + + # Add our own history + push!(state.arg_history[arg_w], HistoryEntry(ainfo, dest_space, write_num)) + + # Find overlapping arguments and update their history + for other_arg_w in state.arg_overlaps[arg_w] + other_arg_w == arg_w && continue + push!(state.arg_history[other_arg_w], HistoryEntry(ainfo, dest_space, write_num)) + end + + # Record the last place we were fully written to + state.arg_owner[arg_w] = dest_space + + # Not necessary to assert a read, but conceptually it's true + add_reader!(state, arg_w, dest_space, ainfo, task, write_num) +end +function add_reader!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num) + push!(state.ainfos_readers[ainfo], task=>write_num) +end + +# Make a copy of each piece of data on each worker +# memory_space => {arg => copy_of_arg} +isremotehandle(x) = false +isremotehandle(x::DTask) = true +isremotehandle(x::Chunk) = true +function generate_slot!(state::DataDepsState, dest_space, data) + if data isa DTask + data = fetch(data; raw=true) + end + # N.B. We do not perform any sync/copy with the current owner of the data, + # because all we want here is to make a copy of some version of the data, + # even if the data is not up to date. 
+ orig_space = memory_space(data) + to_proc = first(processors(dest_space)) + from_proc = first(processors(orig_space)) + dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) + ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) + if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) + # Fast path for local data that's already in a Chunk or not a remote handle needing rewrapping + data_chunk = tochunk(data, from_proc) + else + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = move_rewrap(from_proc, to_proc, data) + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + end + @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. 
$(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" + dest_space_args[data] = data_chunk + state.remote_arg_to_original[data_chunk] = data + + ALIASED_OBJECT_CACHE[] = nothing + + return dest_space_args[data] +end +function get_or_generate_slot!(state, dest_space, data) + @assert !(data isa ArgumentWrapper) + if !haskey(state.remote_args, dest_space) + state.remote_args[dest_space] = IdDict{Any,Any}() + end + if !haskey(state.remote_args[dest_space], data) + return generate_slot!(state, dest_space, data) + end + return state.remote_args[dest_space][data] +end +function move_rewrap(from_proc::Processor, to_proc::Processor, data) + return aliased_object!(data) do data + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data + data_converted = move(from_proc, to_proc, data) + return tochunk(data_converted, to_proc) + end + end +end +const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) +@warn "Document these public methods" maxlog=1 +# TODO: Use state to cache aliasing() results +function declare_aliased_object!(x; ainfo=aliasing(x, identity)) + cache = ALIASED_OBJECT_CACHE[] + cache[ainfo] = x +end +function aliased_object!(x; ainfo=aliasing(x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] + else + @assert x isa Chunk "x must be a Chunk\nUse functor form of aliased_object!" 
        cache[ainfo] = x
        y = x
    end
    return y
end
# Functor form: on a cache miss, `f(x)` produces the backing Chunk to cache.
function aliased_object!(f, x; ainfo=aliasing(x, identity))
    cache = ALIASED_OBJECT_CACHE[]
    if haskey(cache, ainfo)
        y = cache[ainfo]
    else
        y = f(x)
        @assert y isa Chunk "Didn't get a Chunk from functor"
        cache[ainfo] = y
    end
    return y
end
# Unwrap the cached backing Chunk for the value wrapped by `x`.
function aliased_object_unwrap!(x::Chunk)
    y = unwrap(x)
    ainfo = aliasing(y, identity)
    return unwrap(aliased_object!(x; ainfo))
end

# Bookkeeping used by the cost-model schedulers in `distribute_tasks!`
# (task-to-space assignments, per-space capacities, and completion estimates).
struct DataDepsSchedulerState
    task_to_spec::Dict{DTask,DTaskSpec}
    assignments::Dict{DTask,MemorySpace}
    dependencies::Dict{DTask,Set{DTask}}
    task_completions::Dict{DTask,UInt64}
    space_completions::Dict{MemorySpace,UInt64}
    capacities::Dict{MemorySpace,Int}

    function DataDepsSchedulerState()
        return new(Dict{DTask,DTaskSpec}(),
                   Dict{DTask,MemorySpace}(),
                   Dict{DTask,Set{DTask}}(),
                   Dict{DTask,UInt64}(),
                   Dict{MemorySpace,UInt64}(),
                   Dict{MemorySpace,Int}())
    end
end
\ No newline at end of file
diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl
new file mode 100644
index 000000000..04b581c17
--- /dev/null
+++ b/src/datadeps/chunkview.jl
@@ -0,0 +1,64 @@
# A lazy view into the array wrapped by a Chunk; `slices` holds one index
# (Int, range, or Colon) per dimension.
struct ChunkView{N}
    chunk::Chunk
    slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}}
end

# Construct a ChunkView, bounds-checking `slices` against the Chunk's
# ArrayDomain when one is available (other domains are accepted unchecked).
function Base.view(c::Chunk, slices...)
    if c.domain isa ArrayDomain
        nd, sz = ndims(c.domain), size(c.domain)
        nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))"))

        for (i, s) in enumerate(slices)
            if s isa Int
                1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))"))
            elseif s isa AbstractRange
                # Empty ranges select nothing, so any bounds are acceptable
                isempty(s) && continue
                1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))"))
            elseif s === Colon()
                continue
            else
                throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon"))
            end
        end
    end

    return ChunkView(c, slices)
end

Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...)

# ChunkView aliasing must be resolved through the owning worker, never queried
# directly on the wrapper itself.
aliasing(x::ChunkView) =
    throw(ConcurrencyViolationError("Cannot query aliasing of a ChunkView directly"))
memory_space(x::ChunkView) = memory_space(x.chunk)
isremotehandle(x::ChunkView) = true

# This definition is here because it's so similar to ChunkView
# Move a SubArray by moving its parent (deduplicated via aliased_object!, so
# the parent and any other views of it share one destination copy), then
# re-creating the view over the moved parent on the destination worker.
function move_rewrap(from_proc::Processor, to_proc::Processor, v::SubArray)
    to_w = root_worker_id(to_proc)
    p_chunk = aliased_object!(parent(v)) do p
        return remotecall_fetch(to_w, from_proc, to_proc, p) do from_proc, to_proc, p
            return tochunk(move(from_proc, to_proc, p), to_proc)
        end
    end
    inds = parentindices(v)
    return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, inds) do from_proc, to_proc, p_chunk, inds
        p_new = move(from_proc, to_proc, p_chunk)
        v_new = view(p_new, inds...)
+ return tochunk(v_new, to_proc) + end +end +function move_rewrap(from_proc::Processor, to_proc::Processor, slice::ChunkView) + to_w = root_worker_id(to_proc) + p_chunk = aliased_object!(slice.chunk) do p_chunk + return remotecall_fetch(to_w, from_proc, to_proc, p_chunk) do from_proc, to_proc, p_chunk + return tochunk(move(from_proc, to_proc, p_chunk), to_proc) + end + end + return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, slice.slices) do from_proc, to_proc, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) + return tochunk(v_new, to_proc) + end +end + +Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file diff --git a/src/datadeps/interval_tree.jl b/src/datadeps/interval_tree.jl new file mode 100644 index 000000000..1075f5912 --- /dev/null +++ b/src/datadeps/interval_tree.jl @@ -0,0 +1,349 @@ +# Get the start address of a span +span_start(span::MemorySpan) = span.ptr.addr +span_start(span::LocalMemorySpan) = span.ptr +span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N)) +# Get the length of a span +span_len(span::MemorySpan) = span.len +span_len(span::LocalMemorySpan) = span.len +span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N)) + +# Get the end address of a span +span_end(span::MemorySpan) = span.ptr.addr + span.len +span_end(span::LocalMemorySpan) = span.ptr + span.len +span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N)) +mutable struct IntervalNode{M,E} + span::M + max_end::E # Maximum end value in this subtree + left::Union{IntervalNode{M,E}, Nothing} + right::Union{IntervalNode{M,E}, Nothing} + + IntervalNode(span::M) where M <: MemorySpan = new{M,UInt64}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocalMemorySpan) = new{LocalMemorySpan,UInt64}(span, span_end(span), nothing, nothing) + 
IntervalNode(span::ManyMemorySpan{N}) where N = new{ManyMemorySpan{N},ManyPair{N}}(span, span_end(span), nothing, nothing) +end + +mutable struct IntervalTree{M,E} + root::Union{IntervalNode{M,E}, Nothing} + + IntervalTree{M}() where M<:MemorySpan = new{M,UInt64}(nothing) + IntervalTree{LocalMemorySpan}() = new{LocalMemorySpan,UInt64}(nothing) + IntervalTree{ManyMemorySpan{N}}() where N = new{ManyMemorySpan{N},ManyPair{N}}(nothing) +end + +# Construct interval tree from unsorted set of spans +function IntervalTree{M}(spans) where M + tree = IntervalTree{M}() + for span in spans + insert!(tree, span) + end + return tree +end +IntervalTree(spans::Vector{M}) where M = IntervalTree{M}(spans) + +function Base.show(io::IO, tree::IntervalTree) + println(io, "$(typeof(tree)) (with $(length(tree)) spans):") + for (i, span) in enumerate(tree) + println(io, " $i: [$(span_start(span)), $(span_end(span))) (len=$(span_len(span)))") + end +end + +function Base.collect(tree::IntervalTree{M}) where M + result = M[] + for span in tree + push!(result, span) + end + return result +end + +function Base.iterate(tree::IntervalTree{M}) where M + state = Vector{M}() + if tree.root === nothing + return nothing + end + return iterate(tree.root) +end +function Base.iterate(tree::IntervalTree, state) + return iterate(tree.root, state) +end +function Base.iterate(root::IntervalNode{M,E}) where {M,E} + state = Vector{IntervalNode{M,E}}() + push!(state, root) + return iterate(root, state) +end +function Base.iterate(root::IntervalNode, state) + if isempty(state) + return nothing + end + current = popfirst!(state) + if current.right !== nothing + pushfirst!(state, current.right) + end + if current.left !== nothing + pushfirst!(state, current.left) + end + return current.span, state +end + +function Base.length(tree::IntervalTree) + result = 0 + for _ in tree + result += 1 + end + return result +end + +# Update max_end value for a node based on its children +function 
update_max_end!(node::IntervalNode)
    node.max_end = span_end(node.span)
    if node.left !== nothing
        node.max_end = max(node.max_end, node.left.max_end)
    end
    if node.right !== nothing
        node.max_end = max(node.max_end, node.right.max_end)
    end
end

# Insert a span into the interval tree
# N.B. Returns the inserted span, not the tree; empty spans are ignored.
function Base.insert!(tree::IntervalTree{M}, span::M) where M
    if !isempty(span)
        tree.root = insert_node!(tree.root, span)
    end
    return span
end

# Recursive BST insert keyed on span start; returns the (new) subtree root.
function insert_node!(::Nothing, span::M) where M
    return IntervalNode(span)
end
function insert_node!(node::IntervalNode{M,E}, span::M) where {M,E}
    if span_start(span) <= span_start(node.span)
        node.left = insert_node!(node.left, span)
    else
        node.right = insert_node!(node.right, span)
    end

    # Keep the augmented max_end invariant along the insertion path
    update_max_end!(node)
    return node
end

# Remove a specific span from the tree (split as needed)
# N.B. Returns the deleted span, not the tree; empty spans are ignored.
function Base.delete!(tree::IntervalTree{M}, span::M) where M
    if !isempty(span)
        tree.root = delete_node!(tree.root, span)
    end
    return span
end

# Recursive delete; returns the (new) subtree root. A stored span that only
# partially overlaps `span` is removed and its non-overlapping remnants are
# re-inserted, so the tree never retains any byte of the deleted region.
function delete_node!(::Nothing, span::M) where M
    return nothing
end
function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E}
    # Check for exact match first
    if span_start(node.span) == span_start(span) && span_len(node.span) == span_len(span)
        # Exact match, remove the node
        if node.left === nothing && node.right === nothing
            return nothing
        elseif node.left === nothing
            return node.right
        elseif node.right === nothing
            return node.left
        else
            # Node has two children - replace with inorder successor
            successor = find_min(node.right)
            node.span = successor.span
            node.right = delete_node!(node.right, successor.span)
        end
    # Check for overlap
    elseif spans_overlap(node.span, span)
        # Handle overlapping spans by removing current node and adding remainders
        original_span = node.span

        # Remove the current node first (same logic as exact match)
        if node.left === nothing && node.right === nothing
            # Leaf node - remove it and create a new subtree with remainders
            remaining_node = nothing
        elseif node.left === nothing
            remaining_node = node.right
        elseif node.right === nothing
            remaining_node = node.left
        else
            # Node has two children - replace with inorder successor
            successor = find_min(node.right)
            node.span = successor.span
            node.right = delete_node!(node.right, successor.span)
            remaining_node = node
        end

        # Calculate and insert the remaining portions
        original_start = span_start(original_span)
        original_end = span_end(original_span)
        del_start = span_start(span)
        del_end = span_end(span)

        # Left portion: exists if original starts before deleted span
        if original_start < del_start
            left_end = min(original_end, del_start)
            if left_end > original_start
                left_span = M(original_start, left_end - original_start)
                if !isempty(left_span)
                    remaining_node = insert_node!(remaining_node, left_span)
                end
            end
        end

        # Right portion: exists if original extends beyond deleted span
        if original_end > del_end
            right_start = max(original_start, del_end)
            if original_end > right_start
                right_span = M(right_start, original_end - right_start)
                if !isempty(right_span)
                    remaining_node = insert_node!(remaining_node, right_span)
                end
            end
        end

        return remaining_node
    elseif span_start(span) <= span_start(node.span)
        node.left = delete_node!(node.left, span)
    else
        node.right = delete_node!(node.right, span)
    end

    # NOTE(review): `node` cannot be `nothing` here (it's the typed argument),
    # so this guard looks redundant — confirm before simplifying.
    if node !== nothing
        update_max_end!(node)
    end
    return node
end

# Return the leftmost (minimum-start) node in this subtree.
function find_min(node::IntervalNode)
    while node.left !== nothing
        node = node.left
    end
    return node
end

# Check if two spans overlap
# Half-open interval semantics: [start, end) — touching spans don't overlap.
function spans_overlap(span1::MemorySpan, span2::MemorySpan)
    return span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1)
end
function spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan)
    return span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1)
end
function
spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N
    # N.B. The spans are assumed to be the same length and relative offset
    return spans_overlap(span1.spans[1], span2.spans[1])
end

# Find all spans that overlap with the given query span
# Returns the overlapping *portions* (clipped to `query`), not the full
# stored spans.
function find_overlapping(tree::IntervalTree{M}, query::M) where M
    result = M[]
    find_overlapping!(tree.root, query, result)
    return result
end

function find_overlapping!(::Nothing, query::M, result::Vector{M}) where M
    return
end
function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}) where {M,E}
    # Check if current node overlaps with query
    if spans_overlap(node.span, query)
        # Get the overlapping portion of the span
        overlap_start = max(span_start(node.span), span_start(query))
        overlap_end = min(span_end(node.span), span_end(query))
        overlap = M(overlap_start, overlap_end - overlap_start)
        push!(result, overlap)
    end

    # Recursively search left subtree if it might contain overlapping intervals
    # (pruned via the augmented max_end value)
    if node.left !== nothing && node.left.max_end > span_start(query)
        find_overlapping!(node.left, query, result)
    end

    # Recursively search right subtree if query extends beyond current node's start
    if node.right !== nothing && span_end(query) > span_start(node.span)
        find_overlapping!(node.right, query, result)
    end
end

# ============================================================================
# MAIN SUBTRACTION ALGORITHM
# ============================================================================

"""
    subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M

Subtract all spans in subtrahend_spans from the minuend_tree in-place.
The minuend_tree is modified to contain only the portions that remain after subtraction.

Time Complexity: O(M log N + M*K) where M = |subtrahend_spans|, N = |minuend nodes|,
                 K = average overlaps per subtrahend span
Space Complexity: O(K) temporary space per subtrahend span (for the overlap
list built by `find_overlapping`); the tree itself is modified in-place.

If `diff` is provided, add the overlapping spans to `diff`.
"""
function subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M
    for sub_span in subtrahend_spans
        subtract_single_span!(minuend_tree, sub_span, diff)
    end
end

"""
    subtract_single_span!(tree::IntervalTree, sub_span::MemorySpan, diff=nothing)

Subtract a single span from the interval tree. This function:
1. Finds all overlapping spans in the tree
2. Removes each overlapping span
3. Adds back the non-overlapping portions (left and/or right remnants)
4. If diff is provided, add the overlapping span to diff
"""
function subtract_single_span!(tree::IntervalTree{M}, sub_span::M, diff=nothing) where M
    # Find all spans that overlap with the subtrahend
    overlapping_spans = find_overlapping(tree, sub_span)

    # Process each overlapping span
    for overlap_span in overlapping_spans
        # Remove the overlapping span from the tree
        # N.B. `overlap_span` is already clipped to `sub_span`, so delete!'s
        # split logic removes exactly the overlapping bytes
        delete!(tree, overlap_span)

        # Calculate and add back the portions that should remain
        add_remaining_portions!(tree, overlap_span, sub_span)

        if diff !== nothing && !isempty(overlap_span)
            push!(diff, overlap_span)
        end
    end
end

"""
    add_remaining_portions!(tree::IntervalTree, original::MemorySpan, subtracted::MemorySpan)

After removing an overlapping span, add back the portions that don't overlap with the subtracted span.
There can be up to two remaining portions: left and right of the subtracted region.
"""
function add_remaining_portions!(tree::IntervalTree{M}, original::M, subtracted::M) where M
    original_start = span_start(original)
    original_end = span_end(original)
    sub_start = span_start(subtracted)
    sub_end = span_end(subtracted)

    # Left portion: exists if original starts before subtracted
    if original_start < sub_start
        left_end = min(original_end, sub_start)
        if left_end > original_start
            left_span = M(original_start, left_end - original_start)
            if !isempty(left_span)
                insert!(tree, left_span)
            end
        end
    end

    # Right portion: exists if original extends beyond subtracted
    if original_end > sub_end
        right_start = max(original_start, sub_end)
        if original_end > right_start
            right_span = M(right_start, original_end - right_start)
            if !isempty(right_span)
                insert!(tree, right_span)
            end
        end
    end
end
\ No newline at end of file
diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl
new file mode 100644
index 000000000..6fc85bd22
--- /dev/null
+++ b/src/datadeps/queue.jl
@@ -0,0 +1,500 @@
# Task queue that records submitted tasks instead of launching them; the
# recorded tasks are later analyzed and launched by `distribute_tasks!`.
struct DataDepsTaskQueue <: AbstractTaskQueue
    # The queue above us
    upper_queue::AbstractTaskQueue
    # The set of tasks that have already been seen
    seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing}
    # The data-dependency graph of all tasks
    g::Union{SimpleDiGraph{Int},Nothing}
    # The mapping from task to graph ID
    task_to_id::Union{Dict{DTask,Int},Nothing}
    # How to traverse the dependency graph when launching tasks
    traversal::Symbol
    # Which scheduler to use to assign tasks to processors
    scheduler::Symbol

    # Whether aliasing across arguments is possible
    # The fields following only apply when aliasing==true
    aliasing::Bool

    function DataDepsTaskQueue(upper_queue;
                               traversal::Symbol=:inorder,
                               scheduler::Symbol=:naive,
                               aliasing::Bool=true)
        seen_tasks = Pair{DTaskSpec,DTask}[]
        g = SimpleDiGraph()
        task_to_id = Dict{DTask,Int}()
        return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler,
                   aliasing)
    end
end

# Enqueueing just records the task(s) for later analysis/launch.
function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask})
    push!(queue.seen_tasks, spec)
end
function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}})
    append!(queue.seen_tasks, specs)
end

"""
    spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder)

Constructs a "datadeps" (data dependencies) region and calls `f` within it.
Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or
`InOut` to indicate whether the task will read, write, or read+write that
argument, respectively. These argument dependencies will be used to specify
which tasks depend on each other based on the following rules:

- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other
- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects
- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel
- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies
- An `In` dependency synchronizes with any previous `Out` dependencies
- If unspecified, an `In` dependency is assumed

In general, the result of executing tasks following the above rules will be
equivalent to simply executing tasks sequentially and in order of submission.
Of course, if dependencies are incorrectly specified, undefined behavior (and
unexpected results) may occur.

Unlike other Dagger tasks, tasks executed within a datadeps region are allowed
to write to their arguments when annotated with `Out` or `InOut`
appropriately.

At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks
to complete, rethrowing the first error, if any. The result of `f` will be
returned from `spawn_datadeps`.
The keyword argument `traversal` controls the order that tasks are launched by
the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling
or Depth-First Scheduling, respectively. All traversal orders respect the
dependencies and ordering of the launched tasks, but may provide better or
worse performance for a given set of datadeps tasks. This argument is
experimental and subject to change.
"""
function spawn_datadeps(f::Base.Callable; static::Bool=true,
                        traversal::Symbol=:inorder,
                        scheduler::Union{Symbol,Nothing}=nothing,
                        aliasing::Bool=true,
                        launch_wait::Union{Bool,Nothing}=nothing)
    if !static
        throw(ArgumentError("Dynamic scheduling is no longer available"))
    end
    wait_all(; check_errors=true) do
        # Keyword values fall back to their ScopedValue overrides, then defaults
        scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol
        launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
        if launch_wait
            # NOTE(review): this branch's `result` is the value of the spawn_bulk
            # block, and `with_options(f; ...)`'s return value is discarded — the
            # docstring promises `f`'s result is returned; confirm intended.
            result = spawn_bulk() do
                queue = DataDepsTaskQueue(get_options(:task_queue);
                                          traversal, scheduler, aliasing)
                with_options(f; task_queue=queue)
                distribute_tasks!(queue)
            end
        else
            queue = DataDepsTaskQueue(get_options(:task_queue);
                                      traversal, scheduler, aliasing)
            result = with_options(f; task_queue=queue)
            distribute_tasks!(queue)
        end
        return result
    end
end
# Process-global overrides for spawn_datadeps keyword defaults
const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing)
const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing)

# Analyze all recorded tasks, schedule them onto processors, insert the
# necessary data copies, and launch them on the upper queue.
function distribute_tasks!(queue::DataDepsTaskQueue)
    #= TODO: Improvements to be made:
    # - Support for copying non-AbstractArray arguments
    # - Parallelize read copies
    # - Unreference unused slots
    # - Reuse memory when possible
    # - Account for differently-sized data
    =#

    # Get the set of all processors to be scheduled on
    all_procs = Processor[]
    scope = get_compute_scope()
    for w in procs()
        append!(all_procs, get_processors(OSProc(w)))
    end
    filter!(proc->!isa(constrain(ExactScope(proc), scope),
                       InvalidScope),
all_procs) + if isempty(all_procs) + throw(Sch.SchedulingException("No processors available, try widening scope")) + end + exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) + if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) + @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 + end + + # Round-robin assign tasks to processors + upper_queue = get_options(:task_queue) + + traversal = queue.traversal + if traversal == :inorder + # As-is + task_order = Colon() + elseif traversal == :bfs + # BFS + task_order = Int[1] + to_walk = Int[1] + seen = Set{Int}([1]) + while !isempty(to_walk) + # N.B. next_root has already been seen + next_root = popfirst!(to_walk) + for v in outneighbors(queue.g, next_root) + if !(v in seen) + push!(task_order, v) + push!(seen, v) + push!(to_walk, v) + end + end + end + elseif traversal == :dfs + # DFS (modified with backtracking) + task_order = Int[] + to_walk = Int[1] + seen = Set{Int}() + while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) + next_root = popfirst!(to_walk) + if !(next_root in seen) + iv = inneighbors(queue.g, next_root) + if all(v->v in seen, iv) + push!(task_order, next_root) + push!(seen, next_root) + ov = outneighbors(queue.g, next_root) + prepend!(to_walk, ov) + else + push!(to_walk, next_root) + end + end + end + else + throw(ArgumentError("Invalid traversal mode: $traversal")) + end + + state = DataDepsState(queue.aliasing) + sstate = DataDepsSchedulerState() + for proc in all_procs + space = only(memory_spaces(proc)) + get!(()->0, sstate.capacities, space) + sstate.capacities[space] += 1 + end + + # Start launching tasks and necessary copies + write_num = 1 + proc_idx = 1 + pressures = Dict{Processor,Int}() + proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) + for (spec, task) in 
queue.seen_tasks[task_order] + # Populate all task dependencies + populate_task_info!(state, spec, task) + + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + scheduler = queue.scheduler + if scheduler == :naive + raw_args = map(arg->tochunk(value(arg)), spec.fargs) + our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + # Calculate costs per processor and select the most optimal + # FIXME: This should consider any already-allocated slots, + # whether they are up-to-date, and if not, the cost of moving + # data to them + procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) + return first(procs) + end + end + elseif scheduler == :smart + raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg + arg_chunk = tochunk(value(arg)) + # Only the owned slot is valid + # FIXME: Track up-to-date copies and pass all of those + return arg_chunk => data_locality[arg] + end + f_chunk = tochunk(value(spec.fargs[1])) + our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + tx_rate = sch_state.transfer_rate[] + + costs = Dict{Processor,Float64}() + for proc in all_procs + # Filter out chunks that are already local + chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) + + # Estimate network transfer costs based on data size + # N.B. `affinity(x)` really means "data size of `x`" + # N.B. 
We treat same-worker transfers as having zero transfer cost + tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) + + # Estimate total cost to move data and get task running after currently-scheduled tasks + est_time_util = get(pressures, proc, UInt64(0)) + costs[proc] = est_time_util + (tx_cost/tx_rate) + end + + # Look up estimated task cost + sig = Sch.signature(sch_state, f, map(first, chunks_locality)) + task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) + + # Shuffle procs around, so equally-costly procs are equally considered + P = randperm(length(all_procs)) + procs = getindex.(Ref(all_procs), P) + + # Sort by lowest cost first + sort!(procs, by=p->costs[p]) + + best_proc = first(procs) + return best_proc, task_pressure + end + end + # FIXME: Pressure should be decreased by pressure of syncdeps on same processor + pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure + elseif scheduler == :ultra + args = Base.mapany(spec.fargs) do arg + pos, data = arg + data, _ = unwrap_inout(data) + if data isa DTask + data = fetch(data; raw=true) + end + return pos => tochunk(data) + end + f_chunk = tochunk(value(spec.fargs[1])) + task_time = remotecall_fetch(1, f_chunk, args) do f, args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + return @lock sch_state.lock begin + sig = Sch.signature(sch_state, f, args) + return get(sch_state.signature_time_cost, sig, 1000^3) + end + end + + # FIXME: Copy deps are computed eagerly + deps = @something(spec.options.syncdeps, Set{Any}()) + + # Find latest time-to-completion of all syncdeps + deps_completed = UInt64(0) + for dep in deps + haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded + deps_completed = max(deps_completed, sstate.task_completions[dep]) + end + + # Find latest time-to-completion of each memory space + # FIXME: Figure out space completions based on optimal packing + spaces_completed = Dict{MemorySpace,UInt64}() + for space in 
exec_spaces + completed = UInt64(0) + for (task, other_space) in sstate.assignments + space == other_space || continue + completed = max(completed, sstate.task_completions[task]) + end + spaces_completed[space] = completed + end + + # Choose the earliest-available memory space and processor + # FIXME: Consider move time + move_time = UInt64(0) + local our_space_completed + while true + our_space_completed, our_space = findmin(spaces_completed) + our_space_procs = filter(proc->proc in all_procs, processors(our_space)) + if isempty(our_space_procs) + delete!(spaces_completed, our_space) + continue + end + our_proc = rand(our_space_procs) + break + end + + sstate.task_to_spec[task] = spec + sstate.assignments[task] = our_space + sstate.task_completions[task] = our_space_completed + move_time + task_time + elseif scheduler == :roundrobin + our_proc = all_procs[proc_idx] + if task_scope == scope + # all_procs is already limited to scope + else + if isa(constrain(task_scope, scope), InvalidScope) + throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) + end + while !proc_in_scope(our_proc, task_scope) + proc_idx = mod1(proc_idx + 1, length(all_procs)) + our_proc = all_procs[proc_idx] + end + end + else + error("Invalid scheduler: $sched") + end + @assert our_proc in all_procs + our_space = only(memory_spaces(our_proc)) + + # Find the scope for this task (and its copies) + if task_scope == scope + # Optimize for the common case, cache the proc=>scope mapping + our_scope = get!(proc_to_scope_lfu, our_proc) do + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) + end + else + # Use the provided scope and constrain it to the available processors + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + end + if our_scope isa InvalidScope + 
throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) + end + + f = spec.fargs[1] + f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + + # Copy raw task arguments for analysis + task_args = map(copy, spec.fargs) + + # Generate a list of ArgumentWrappers for each task argument + task_arg_ws = map(task_args) do _arg + arg = value(_arg) + arg, deps = unwrap_inout(arg) + arg = arg isa DTask ? fetch(arg; raw=true) : arg + if !type_may_alias(typeof(arg)) || !supports_inplace_move(state, arg) + return [(ArgumentWrapper(arg, identity), false, false)] + end + + # Get the Chunk for the argument + arg = state.raw_arg_to_chunk[arg] + + arg_ws = Tuple{ArgumentWrapper,Bool,Bool}[] + for (dep_mod, readdep, writedep) in deps + push!(arg_ws, (ArgumentWrapper(arg, dep_mod), readdep, writedep)) + end + return arg_ws + end + task_arg_ws = task_arg_ws::Vector{Vector{Tuple{ArgumentWrapper,Bool,Bool}}} + + # Truncate the history for each argument + for arg_ws in task_arg_ws + for (arg_w, _, _) in arg_ws + truncate_history!(state, arg_w) + end + end + + # Copy args from local to remote + for (idx, arg_ws) in enumerate(task_arg_ws) + arg = first(arg_ws)[1].arg + pos = raw_position(task_args[idx]) + + # Is the data written previously or now? + if !type_may_alias(typeof(arg)) + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + spec.fargs[idx].value = arg + continue + end + + # Is the data writeable? + if !supports_inplace_move(state, arg) + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + spec.fargs[idx].value = arg + continue + end + + # Is the source of truth elsewhere? 
+ arg_remote = get_or_generate_slot!(state, our_space, arg) + for (arg_w, _, _) in arg_ws + dep_mod = arg_w.dep_mod + remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num) + elseif remainder isa FullCopy + enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + end + end + spec.fargs[idx].value = arg_remote + end + write_num += 1 + + # Validate that we're not accidentally performing a copy + for (idx, _arg) in enumerate(spec.fargs) + arg = value(_arg) + _, deps = unwrap_inout(value(task_args[idx])) + # N.B. We only do this check when the argument supports in-place + # moves, because for the moment, we are not guaranteeing updates or + # write-back of results + if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) + arg_space = memory_space(arg) + @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" + end + end + + # Calculate this task's syncdeps + if spec.options.syncdeps === nothing + spec.options.syncdeps = Set{Any}() + end + syncdeps = spec.options.syncdeps + for (idx, arg_ws) in enumerate(task_arg_ws) + arg = first(arg_ws)[1].arg + type_may_alias(typeof(arg)) || continue + supports_inplace_move(state, arg) || continue + for (arg_w, _, writedep) in arg_ws + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if writedep + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + get_write_deps!(state, our_space, ainfo, write_num, syncdeps) + else + @dagdebug nothing :spawn_datadeps 
"($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + get_read_deps!(state, our_space, ainfo, write_num, syncdeps) + end + end + end + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + + # Launch user's task + spec.options.scope = our_scope + spec.options.exec_scope = our_scope + enqueue!(upper_queue, spec=>task) + + # Update read/write tracking for arguments + for (idx, arg_ws) in enumerate(task_arg_ws) + arg = first(arg_ws)[1].arg + type_may_alias(typeof(arg)) || continue + for (arg_w, _, writedep) in arg_ws + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if writedep + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + add_writer!(state, arg_w, our_space, ainfo, task, write_num) + else + add_reader!(state, arg_w, our_space, ainfo, task, write_num) + end + end + end + + write_num += 1 + proc_idx = mod1(proc_idx + 1, length(all_procs)) + end + + # Copy args from remote to local + # N.B. We sort the keys to ensure a deterministic order for uniformity + for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + arg = arg_w.arg + origin_space = state.arg_origin[arg] + remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) + enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num) + elseif remainder isa FullCopy + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) 
+ enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + end + end +end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl new file mode 100644 index 000000000..b420ca5d3 --- /dev/null +++ b/src/datadeps/remainders.jl @@ -0,0 +1,443 @@ +# Remainder tracking and computation functions + +""" + RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + +Represents the memory spans that remain after subtracting some regions from a base aliasing object. +This is used to perform partial data copies that only update the "remainder" regions. +""" +struct RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + space::S + spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}} + syncdeps::Set{ThunkSyncdep} +end +RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, syncdeps::Set{ThunkSyncdep}) where S = + RemainderAliasing{S}(space, spans, syncdeps) + +memory_spans(ra::RemainderAliasing) = ra.spans + +Base.hash(ra::RemainderAliasing, h::UInt) = hash(ra.spans, hash(RemainderAliasing, h)) +Base.:(==)(ra1::RemainderAliasing, ra2::RemainderAliasing) = ra1.spans == ra2.spans + +# Add will_alias support for RemainderAliasing +function will_alias(x::RemainderAliasing, y::AbstractAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::AbstractAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::RemainderAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +struct MultiRemainderAliasing <: AbstractAliasing + remainders::Vector{<:RemainderAliasing} +end +MultiRemainderAliasing() = MultiRemainderAliasing(RemainderAliasing[]) + +memory_spans(mra::MultiRemainderAliasing) = vcat(memory_spans.(mra.remainders)...) 
+ +Base.hash(mra::MultiRemainderAliasing, h::UInt) = hash(mra.remainders, hash(MultiRemainderAliasing, h)) +Base.:(==)(mra1::MultiRemainderAliasing, mra2::MultiRemainderAliasing) = mra1.remainders == mra2.remainders + +#= FIXME: Integrate with main documentation +Problem statement: + +Remainder copy calculation needs to ensure that, for a given argument and +dependency modifier, and for a given target memory space, any data not yet +updated (whether through this arg or through another that aliases) is added to +the remainder, while any data that has been updated is not in the remainder. +Remainder copies may be multi-part, as data may be spread across multiple other +memory spaces. + +Ainfo is not alone sufficient to identify the combination of argument and +dependency modifier, as ainfo is specific to an allocation in a given memory +space. Thus, this combination needs to be tracked together, and separately from +memory space. However, information may span multiple memory spaces (and thus +multiple ainfos), so we should try to make queries of cross-memory space +information fast, as they will need to be performed for every task, for every +combination. 
+ +Game Plan: + +- Use ArgumentWrapper to track this combination throughout the codebase, ideally generated just once +- Maintain the keying of remote_args only on argument, as the dependency modifier doesn’t affect the argument being passed into the task, so it should not factor into generating and tracking remote argument copies +- Add a structure to track the mapping from ArgumentWrapper to memory space to ainfo, as a quick way to lookup all ainfos needing to be considered +- When considering a remainder copy, only look at a single memory space’s ainfos at a time, as the ainfos should overlap exactly the same way on any memory space, and this allows us to use ainfo_overlaps to track overlaps +- Remainder copies will need to separately consider the source memory space, and the destination memory space when acquiring spans to copy to/from +- Memory spans for ainfos generated from the same ArgumentWrapper should be assumed to be paired in the same order, regardless of memory space, to ensure we can perform the translation from source to destination span address + - Alternatively, we might provide an API to take source and destination ainfos, and desired remainder memory spans, which then performs the copy for us +- When a task or copy writes to arguments, we should record this happening for all overlapping ainfos, in a manner that will be efficient to query from another memory space. We can probably walk backwards and attach this to a structure keyed on ArgumentWrapper, as that will be very efficient for later queries (because the history will now be linearized in one vector). +- Remainder copies will need to know, for all overlapping ainfos of the ArgumentWrapper ainfo at the target memory space, how recently that ainfo was updated relative to other ainfos, and relative to how recently the target ainfo was written. 
+ - The last time the target ainfo was written is the furthest back we need to consider, as the target data must have been fully up-to-date when that write completed. + - Consideration of updates should start at most recent first, walking backwards in time, as the most recent updates contain the up-to-date data. + - For each span under consideration, we should subtract from it the current remainder set, to ensure we only copy up-to-date data. + - We must add that span portion to the remainder set no matter what, but if it was updated on the target memory space, we don’t need to schedule a copy for it, since it’s already where it needs to be. + - Even before the last target write is seen, we are allowed to stop searching if we find that our target ainfo is fully covered (because this implies that the target ainfo is fully out-of-date). +=# + +struct FullCopy end + +""" + compute_remainder_for_arg!(state::DataDepsState, + target_space::MemorySpace, + arg_w::ArgumentWrapper, + write_num::Int; compute_syncdeps::Bool=true) + +Computes what remainder regions need to be copied to `target_space` before a task can access `arg_w`. +Returns a tuple of the remainder result and the index of the last history entry considered; the result is a `MultiRemainderAliasing` object representing the remainder, a `FullCopy()` if a full copy from the owner space is required, or `NoAliasing()` if no copy is needed. + +The algorithm starts by collecting the memory spans of `arg_w` in `target_space` - this is the "remainder". +When this remainder is empty, the algorithm will be finished. +Additionally, a dictionary is created to store the source and destination +memory spans (for each source memory space) that will be used to create the +`MultiRemainderAliasing` object - this is the "tracker". + +The algorithm walks backwards through the `arg_history` vector for `arg_w` +(which is an ordered list of all overlapping ainfos that were directly written to (potentially in a different memory space than `target_space`) +since the last time this `arg_w` was written to).
If this ainfo is in `target_space`, +then it is not under consideration; it is simply subtracted from the remainder with `subtract_spans!`, +and the algorithm goes to the next ainfo. Otherwise, the algorithm will consider this ainfo for tracking. + +For each overlapping ainfo (which lives in a different memory space than `target_space`) to be tracked, there exists a corresponding "mirror" ainfo in +`target_space`, which is the equivalent of the overlapping ainfo, but in +`target_space`. This mirror ainfo is assumed to have an identical number of memory spans as the overlapping ainfo, +and each memory span is assumed to be identical in size, but not necessarily identical in address. + +These three sets of memory spans (from the remainder, the overlapping ainfo, and the mirror ainfo) are then passed to `schedule_remainder!`. +This call will subtract the spans of the mirror ainfo from the remainder (as the two live in the same memory space and thus can be directly compared), +and will update the remainder accordingly. +Additionally, it will also use this subtraction to update the tracker, by adding the equivalent spans (mapped from mirror ainfo to overlapping ainfo) to the tracker as the source, +and the spans of the remainder as the destination. + +If the history is exhausted without the remainder becoming empty, then the +remaining data in `target_space` is assumed to be up-to-date (as the latest write +to `arg_w` is the furthest back we need to consider). + +Finally, the tracker is converted into a `MultiRemainderAliasing` object, +and returned.
+""" +function compute_remainder_for_arg!(state::DataDepsState, + target_space::MemorySpace, + arg_w::ArgumentWrapper, + write_num::Int; compute_syncdeps::Bool=true) + @label restart + + # Determine all memory spaces of the history + spaces_set = Set{MemorySpace}() + push!(spaces_set, target_space) + owner_space = state.arg_owner[arg_w] + push!(spaces_set, owner_space) + for entry in state.arg_history[arg_w] + push!(spaces_set, entry.space) + end + spaces = collect(spaces_set) + N = length(spaces) + + # Lookup all memory spans for arg_w in these spaces + target_ainfos = Vector{Vector{LocalMemorySpan}}() + for space in spaces + target_space_ainfo = aliasing!(state, space, arg_w) + spans = memory_spans(target_space_ainfo) + push!(target_ainfos, LocalMemorySpan.(spans)) + end + nspans = length(first(target_ainfos)) + + # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) + for entry in state.arg_history[arg_w] + if !in(entry.space, spaces) + @opcounter :compute_remainder_for_arg_restart + @goto restart + end + end + + # We may only need to schedule a full copy from the origin space to the + # target space if this is the first time we've written to `arg_w` + if isempty(state.arg_history[arg_w]) + if owner_space != target_space + return FullCopy(), 0 + else + return NoAliasing(), 0 + end + end + + # Create our remainder as an interval tree over all target ainfos + remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) + + # Create our tracker + tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() + + # Walk backwards through the history of writes to this target + # other_ainfo is the overlapping ainfo that was written to + # other_space is the memory space of the overlapping ainfo + last_idx = length(state.arg_history[arg_w]) + for idx in length(state.arg_history[arg_w]):-1:0 + if isempty(remainder) + # All done! 
+ last_idx = idx + break + end + + if idx > 0 + other_entry = state.arg_history[arg_w][idx] + other_ainfo = other_entry.ainfo + other_space = other_entry.space + else + # If we've reached the end of the history, evaluate ourselves + other_ainfo = aliasing!(state, owner_space, arg_w) + other_space = owner_space + end + + # Lookup all memory spans for arg_w in these spaces + other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) + other_ainfos = Vector{Vector{LocalMemorySpan}}() + for space in spaces + other_space_ainfo = aliasing!(state, space, other_arg_w) + spans = memory_spans(other_space_ainfo) + push!(other_ainfos, LocalMemorySpan.(spans)) + end + nspans = length(first(other_ainfos)) + other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] + + if other_space == target_space + # Only subtract, this data is already up-to-date in target_space + # N.B. We don't add to syncdeps here, because we'll see this ainfo + # in get_write_deps! 
+ @opcounter :compute_remainder_for_arg_subtract + subtract_spans!(remainder, other_many_spans) + continue + end + + # Subtract from remainder and schedule copy in tracker + other_space_idx = something(findfirst(==(other_space), spaces)) + target_space_idx = something(findfirst(==(target_space), spaces)) + tracker_other_space = get!(tracker, other_space) do + (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) + end + @opcounter :compute_remainder_for_arg_schedule + schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps + @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" + get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) + end + end + + if isempty(tracker) + return NoAliasing(), 0 + end + + # Return scheduled copies and the index of the last ainfo we considered + mra = MultiRemainderAliasing() + for space in spaces + if haskey(tracker, space) + spans, syncdeps = tracker[space] + if !isempty(spans) + push!(mra.remainders, RemainderAliasing(space, spans, syncdeps)) + end + end + end + return mra, last_idx +end + +### Memory Span Set Operations for Remainder Computation + +""" + schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) + +Calculates the difference between `remainder` and `other_many_spans`, subtracts +it from `remainder`, and then adds that difference to `tracker` as a scheduled +copy from `other_many_spans` to the subtraced portion of `remainder`. 
+""" +function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N + diff = Vector{ManyMemorySpan{N}}() + subtract_spans!(remainder, other_many_spans, diff) + + for span in diff + source_span = span.spans[source_space_idx] + dest_span = span.spans[dest_space_idx] + push!(tracker, (source_span, dest_span)) + end +end + +### Remainder copy functions + +""" + enqueue_remainder_copy_to!(state::DataDepsState, f, target_ainfo::AliasingWrapper, remainder_aliasing, dep_mod, arg, idx, + our_space::MemorySpace, our_scope, task::DTask, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object before a task runs. +""" +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) + end +end +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps 
= Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! + get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end +""" + enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing, + dest_scope, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object back to the original space.
+""" +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + dest_scope, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) + end +end +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing remainder copy-from for: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps = Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! 
+ get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# FIXME: Document me +function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, 
copy_task, write_num) +end +function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing full copy-from: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# Main copy function for RemainderAliasing +function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S + # Get the source data for each span + copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod + copies = Vector{UInt8}[] + for (from_span, _) in dep_mod.spans + copy = Vector{UInt8}(undef, from_span.len) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copy)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) + end + push!(copies, copy) + end + return copies + end + + # 
Copy the data into the destination object + for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copy)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) + end + end + + # Ensure that the data is visible + Core.Intrinsics.atomic_fence(:release) + + return +end diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 9f65a1a21..b1ff40d8f 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -30,7 +30,7 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement function unwrap(x::Chunk) - @assert root_worker_id(x.processor) == myid() + @assert x.handle.owner == myid() MemPool.poolget(x.handle) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = @@ -99,10 +99,13 @@ end RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) +# FIXME: Don't hardcode CPURAMMemorySpace +RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = RemotePtr(UInt(x), CPURAMMemorySpace(myid())) Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) @@ -116,13 +119,15 @@ struct MemorySpan{S} end MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = MemorySpan{S}(ptr, UInt(len)) - +MemorySpan{S}(addr::UInt, len::Integer) where S = + MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) +Base.isless(a::MemorySpan, 
b::MemorySpan) = a.ptr < b.ptr +Base.isempty(x::MemorySpan) = x.len == 0 abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) - struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 @@ -279,7 +284,7 @@ function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), RemotePtr{Cvoid}(pointer(x)), parentindices(x), - size(x), strides(parent(x))) + size(x), strides(x)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -402,70 +407,35 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan) return x_span.ptr <= y_end && y_span.ptr <= x_end end -struct ChunkView{N} - chunk::Chunk - slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} -end - -function Base.view(c::Chunk, slices...) - if c.domain isa ArrayDomain - nd, sz = ndims(c.domain), size(c.domain) - nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))")) - - for (i, s) in enumerate(slices) - if s isa Int - 1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s isa AbstractRange - isempty(s) && continue - 1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s === Colon() - continue - else - throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon")) - end - end - end +### More space-efficient memory spans - return ChunkView(c, slices) +struct LocalMemorySpan + ptr::UInt + len::UInt end +LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len) +Base.isempty(x::LocalMemorySpan) = x.len == 0 -Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) 
- -function aliasing(x::ChunkView{N}) where N - remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices - x = unwrap(x) - v = view(x, slices...) - return aliasing(v) - end +# FIXME: Store the length separately, since it's shared by all spans +struct ManyMemorySpan{N} + spans::NTuple{N,LocalMemorySpan} end -memory_space(x::ChunkView) = memory_space(x.chunk) -isremotehandle(x::ChunkView) = true +Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans) -#= -function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::ChunkView, from::ChunkView) - to_w = root_worker_id(to_space) - @assert to_w == myid() - to_raw = unwrap(to.chunk) - from_w = root_worker_id(from_space) - from_raw = to_w == from_w ? unwrap(from.chunk) : remotecall_fetch(f->copy(unwrap(f)), from_w, from.chunk) - from_view = view(from_raw, from.slices...) - to_view = view(to_raw, to.slices...) - move!(dep_mod, to_space, from_space, to_view, from_view) - return +struct ManyPair{N} <: Unsigned + pairs::NTuple{N,UInt} end -=# +Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair +Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) +Base.convert(::Type{ManyPair}, x::ManyPair) = x +Base.:+(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) +Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N)) +Base.:-(x::ManyPair) = error("Can't negate a ManyPair") +Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs +Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.string(x::ManyPair) = "ManyPair($(x.pairs))" -function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) - if from_proc == to_proc - return view(unwrap(slice.chunk), slice.slices...) 
- else - # Need to copy the underlying data, so collapse the view - from_w = root_worker_id(from_proc) - data = remotecall_fetch(from_w, slice.chunk, slice.slices) do chunk, slices - copy(view(unwrap(chunk), slices...)) - end - return move(from_proc, to_proc, data) - end -end +ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = + ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 615030400..873e47e79 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -35,3 +35,28 @@ macro dagdebug(thunk, category, msg, args...) end end) end + +# FIXME: Calculate fast-growth based on clock time, not iteration +const OPCOUNTER_CATEGORIES = Symbol[] +const OPCOUNTER_FAST_GROWTH_THRESHOLD = Ref(10_000_000) +struct OpCounter + value::Threads.Atomic{Int} +end +OpCounter() = OpCounter(Threads.Atomic{Int}(0)) +macro opcounter(category, count=1) + cat_sym = category.value + @gensym old + opcounter_sym = Symbol(:OPCOUNTER_, cat_sym) + if !isdefined(__module__, opcounter_sym) + __module__.eval(:(#=const=# $opcounter_sym = OpCounter())) + end + esc(quote + if $(QuoteNode(cat_sym)) in $OPCOUNTER_CATEGORIES + $old = Threads.atomic_add!($opcounter_sym.value, Int($count)) + if $old > 1 && (mod1($old, $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) == 1 || $count > $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) + println("Fast-growing counter: $($(QuoteNode(cat_sym))) = $($old)") + end + end + end) +end +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file From 876b82d43579af4db01444c475d78bceb4dbd9dd Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 15 Oct 2025 19:49:44 -0700 Subject: [PATCH 02/24] Add type-stable spawn code paths --- src/argument.jl | 32 +- src/datadeps/aliasing.jl | 74 +++-- src/datadeps/queue.jl | 
613 +++++++++++++++++++++------------------ src/queue.jl | 97 ++++--- src/submission.jl | 71 +++-- src/thunk.jl | 74 +++-- 6 files changed, 571 insertions(+), 390 deletions(-) diff --git a/src/argument.jl b/src/argument.jl index 94246a75e..849486e03 100644 --- a/src/argument.jl +++ b/src/argument.jl @@ -20,6 +20,7 @@ function pos_kw(pos::ArgPosition) @assert pos.kw != :NULL return pos.kw end + mutable struct Argument pos::ArgPosition value @@ -41,6 +42,35 @@ function Base.iterate(arg::Argument, state::Bool) return nothing end end - Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), arg.value) chunktype(arg::Argument) = chunktype(value(arg)) + +mutable struct TypedArgument{T} + pos::ArgPosition + value::T +end +TypedArgument(pos::Integer, value::T) where T = TypedArgument{T}(ArgPosition(true, pos, :NULL), value) +TypedArgument(kw::Symbol, value::T) where T = TypedArgument{T}(ArgPosition(false, 0, kw), value) +Base.setproperty!(arg::TypedArgument, name::Symbol, value::T) where T = + throw(ArgumentError("Cannot set properties of TypedArgument")) +ispositional(arg::TypedArgument) = ispositional(arg.pos) +iskw(arg::TypedArgument) = iskw(arg.pos) +pos_idx(arg::TypedArgument) = pos_idx(arg.pos) +pos_kw(arg::TypedArgument) = pos_kw(arg.pos) +raw_position(arg::TypedArgument) = raw_position(arg.pos) +value(arg::TypedArgument) = arg.value +valuetype(arg::TypedArgument{T}) where T = T +Base.iterate(arg::TypedArgument) = (arg.pos, true) +function Base.iterate(arg::TypedArgument, state::Bool) + if state + return (arg.value, false) + else + return nothing + end +end +Base.copy(arg::TypedArgument{T}) where T = TypedArgument{T}(ArgPosition(arg.pos), arg.value) +chunktype(arg::TypedArgument) = chunktype(value(arg)) + +Argument(arg::TypedArgument) = Argument(arg.pos, arg.value) + +const AnyArgument = Union{Argument, TypedArgument} \ No newline at end of file diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 1d2a15a34..87eba76df 100644 --- 
a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -192,6 +192,11 @@ struct Deps{T,DT<:Tuple} end Deps(x, deps...) = Deps(x, deps) +chunktype(::In{T}) where T = T +chunktype(::Out{T}) where T = T +chunktype(::InOut{T}) where T = T +chunktype(::Deps{T,DT}) where {T,DT} = T + function unwrap_inout(arg) readdep = false writedep = false @@ -372,48 +377,69 @@ function is_writedep(arg, deps, task::DTask) end # Aliasing state setup -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) +function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) # Track the task's arguments and access patterns - for (idx, _arg) in enumerate(spec.fargs) - arg = value(_arg) + return map_or_ntuple(task_args) do idx + _arg = task_args[idx] + + # Unwrap the argument + _arg_with_deps = value(_arg) + pos = _arg.pos # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(arg) + arg_pre_unwrap, deps = unwrap_inout(_arg_with_deps) # Unwrap the Chunk underlying any DTask arguments - arg = arg isa DTask ? fetch(arg; raw=true) : arg - - # Skip non-aliasing arguments - type_may_alias(typeof(arg)) || continue - - # Skip arguments not supporting in-place move - supports_inplace_move(state, arg) || continue + arg = arg_pre_unwrap isa DTask ? 
fetch(arg_pre_unwrap; raw=true) : arg_pre_unwrap + + # Skip non-aliasing arguments or arguments that don't support in-place move + may_alias = type_may_alias(typeof(arg)) + inplace_move = may_alias && supports_inplace_move(state, arg) + if !may_alias || !inplace_move + arg_w = ArgumentWrapper(arg, identity) + if is_typed(spec) + return TypedDataDepsTaskArgument(arg, pos, may_alias, inplace_move, (DataDepsTaskDependency(arg_w, false, false),)) + else + return DataDepsTaskArgument(arg, pos, may_alias, inplace_move, [DataDepsTaskDependency(arg_w, false, false)]) + end + end # Generate a Chunk for the argument if necessary if haskey(state.raw_arg_to_chunk, arg) - arg = state.raw_arg_to_chunk[arg] + arg_chunk = state.raw_arg_to_chunk[arg] else if !(arg isa Chunk) - new_arg = tochunk(arg) - state.raw_arg_to_chunk[arg] = new_arg - arg = new_arg + arg_chunk = tochunk(arg) + state.raw_arg_to_chunk[arg] = arg_chunk else state.raw_arg_to_chunk[arg] = arg + arg_chunk = arg end end # Track the origin space of the argument - origin_space = memory_space(arg) - state.arg_origin[arg] = origin_space - state.remote_arg_to_original[arg] = arg + origin_space = memory_space(arg_chunk) + state.arg_origin[arg_chunk] = origin_space + state.remote_arg_to_original[arg_chunk] = arg_chunk # Populate argument info for all aliasing dependencies - for (dep_mod, _, _) in deps - # Generate an ArgumentWrapper for the argument - aw = ArgumentWrapper(arg, dep_mod) - - # Populate argument info - populate_argument_info!(state, aw, origin_space) + # And return the argument, dependencies, and ArgumentWrappers + if is_typed(spec) + deps = Tuple(DataDepsTaskDependency(arg_chunk, dep) for dep in deps) + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return TypedDataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + else + deps = [DataDepsTaskDependency(arg_chunk, dep) for dep in deps] + 
map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return DataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) end end end diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 6fc85bd22..f8f907741 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -2,7 +2,7 @@ struct DataDepsTaskQueue <: AbstractTaskQueue # The queue above us upper_queue::AbstractTaskQueue # The set of tasks that have already been seen - seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing} + seen_tasks::Union{Vector{DTaskPair},Nothing} # The data-dependency graph of all tasks g::Union{SimpleDiGraph{Int},Nothing} # The mapping from task to graph ID @@ -20,7 +20,7 @@ struct DataDepsTaskQueue <: AbstractTaskQueue traversal::Symbol=:inorder, scheduler::Symbol=:naive, aliasing::Bool=true) - seen_tasks = Pair{DTaskSpec,DTask}[] + seen_tasks = DTaskPair[] g = SimpleDiGraph() task_to_id = Dict{DTask,Int}() return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, @@ -28,11 +28,11 @@ struct DataDepsTaskQueue <: AbstractTaskQueue end end -function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.seen_tasks, spec) +function enqueue!(queue::DataDepsTaskQueue, pair::DTaskPair) + push!(queue.seen_tasks, pair) end -function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.seen_tasks, specs) +function enqueue!(queue::DataDepsTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.seen_tasks, pairs) end """ @@ -116,12 +116,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue) for w in procs() append!(all_procs, get_processors(OSProc(w))) end - filter!(proc->!isa(constrain(ExactScope(proc), scope), - InvalidScope), - all_procs) + filter!(proc->proc_in_scope(proc, scope), all_procs) if isempty(all_procs) throw(Sch.SchedulingException("No processors available, try widening scope")) end + scope = 
UnionScope(map(ExactScope, all_procs)) exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 @@ -184,317 +183,367 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Start launching tasks and necessary copies write_num = 1 proc_idx = 1 - pressures = Dict{Processor,Int}() + #pressures = Dict{Processor,Int}() proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for (spec, task) in queue.seen_tasks[task_order] - # Populate all task dependencies - populate_task_info!(state, spec, task) - - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - scheduler = queue.scheduler - if scheduler == :naive - raw_args = map(arg->tochunk(value(arg)), spec.fargs) - our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - # Calculate costs per processor and select the most optimal - # FIXME: This should consider any already-allocated slots, - # whether they are up-to-date, and if not, the cost of moving - # data to them - procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) - return first(procs) - end - end - elseif scheduler == :smart - raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg - arg_chunk = tochunk(value(arg)) - # Only the owned slot is valid - # FIXME: Track up-to-date copies and pass all of those - return arg_chunk => data_locality[arg] + for pair in queue.seen_tasks[task_order] + spec = pair.spec + task = pair.task + write_num, proc_idx = distribute_task!(queue, state, all_procs, scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx) + end + + # Copy args from 
remote to local + # N.B. We sort the keys to ensure a deterministic order for uniformity + for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + arg = arg_w.arg + origin_space = state.arg_origin[arg] + remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) + enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num) + elseif remainder isa FullCopy + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) + enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + end + end +end +struct DataDepsTaskDependency + arg_w::ArgumentWrapper + readdep::Bool + writedep::Bool +end +DataDepsTaskDependency(arg, dep) = + DataDepsTaskDependency(ArgumentWrapper(arg, dep[1]), dep[2], dep[3]) +struct DataDepsTaskArgument + arg + pos::ArgPosition + may_alias::Bool + inplace_move::Bool + deps::Vector{DataDepsTaskDependency} +end +struct TypedDataDepsTaskArgument{T,N} + arg::T + pos::ArgPosition + may_alias::Bool + inplace_move::Bool + deps::NTuple{N,DataDepsTaskDependency} +end +map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs)) +@inline map_or_ntuple(@specialize(f), xs::NTuple{N,T}) where {N,T} = ntuple(f, Val(N)) +function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed + @specialize spec fargs + + if typed + fargs::Tuple + else + fargs::Vector{Argument} + end + + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + scheduler = queue.scheduler + if scheduler == :naive + raw_args = 
map(arg->tochunk(value(arg)), spec.fargs) + our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + # Calculate costs per processor and select the most optimal + # FIXME: This should consider any already-allocated slots, + # whether they are up-to-date, and if not, the cost of moving + # data to them + procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) + return first(procs) end - f_chunk = tochunk(value(spec.fargs[1])) - our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - tx_rate = sch_state.transfer_rate[] - - costs = Dict{Processor,Float64}() - for proc in all_procs - # Filter out chunks that are already local - chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) - - # Estimate network transfer costs based on data size - # N.B. `affinity(x)` really means "data size of `x`" - # N.B. 
We treat same-worker transfers as having zero transfer cost - tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(pressures, proc, UInt64(0)) - costs[proc] = est_time_util + (tx_cost/tx_rate) - end - - # Look up estimated task cost - sig = Sch.signature(sch_state, f, map(first, chunks_locality)) - task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) - - # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(all_procs)) - procs = getindex.(Ref(all_procs), P) - - # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - best_proc = first(procs) - return best_proc, task_pressure + end + elseif scheduler == :smart + raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg + arg_chunk = tochunk(value(arg)) + # Only the owned slot is valid + # FIXME: Track up-to-date copies and pass all of those + return arg_chunk => data_locality[arg] + end + f_chunk = tochunk(value(spec.fargs[1])) + our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + tx_rate = sch_state.transfer_rate[] + + costs = Dict{Processor,Float64}() + for proc in all_procs + # Filter out chunks that are already local + chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) + + # Estimate network transfer costs based on data size + # N.B. `affinity(x)` really means "data size of `x`" + # N.B. 
We treat same-worker transfers as having zero transfer cost + tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) + + # Estimate total cost to move data and get task running after currently-scheduled tasks + est_time_util = get(pressures, proc, UInt64(0)) + costs[proc] = est_time_util + (tx_cost/tx_rate) end + + # Look up estimated task cost + sig = Sch.signature(sch_state, f, map(first, chunks_locality)) + task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) + + # Shuffle procs around, so equally-costly procs are equally considered + P = randperm(length(all_procs)) + procs = getindex.(Ref(all_procs), P) + + # Sort by lowest cost first + sort!(procs, by=p->costs[p]) + + best_proc = first(procs) + return best_proc, task_pressure end - # FIXME: Pressure should be decreased by pressure of syncdeps on same processor - pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure - elseif scheduler == :ultra - args = Base.mapany(spec.fargs) do arg - pos, data = arg - data, _ = unwrap_inout(data) - if data isa DTask - data = fetch(data; raw=true) - end - return pos => tochunk(data) + end + # FIXME: Pressure should be decreased by pressure of syncdeps on same processor + pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure + elseif scheduler == :ultra + args = Base.mapany(spec.fargs) do arg + pos, data = arg + data, _ = unwrap_inout(data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) end - f_chunk = tochunk(value(spec.fargs[1])) - task_time = remotecall_fetch(1, f_chunk, args) do f, args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - return @lock sch_state.lock begin - sig = Sch.signature(sch_state, f, args) - return get(sch_state.signature_time_cost, sig, 1000^3) - end + return pos => tochunk(data) + end + f_chunk = tochunk(value(spec.fargs[1])) + task_time = remotecall_fetch(1, f_chunk, args) do f, args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + return @lock 
sch_state.lock begin + sig = Sch.signature(sch_state, f, args) + return get(sch_state.signature_time_cost, sig, 1000^3) end + end - # FIXME: Copy deps are computed eagerly - deps = @something(spec.options.syncdeps, Set{Any}()) - - # Find latest time-to-completion of all syncdeps - deps_completed = UInt64(0) - for dep in deps - haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded - deps_completed = max(deps_completed, sstate.task_completions[dep]) - end + # FIXME: Copy deps are computed eagerly + deps = @something(spec.options.syncdeps, Set{Any}()) - # Find latest time-to-completion of each memory space - # FIXME: Figure out space completions based on optimal packing - spaces_completed = Dict{MemorySpace,UInt64}() - for space in exec_spaces - completed = UInt64(0) - for (task, other_space) in sstate.assignments - space == other_space || continue - completed = max(completed, sstate.task_completions[task]) - end - spaces_completed[space] = completed - end + # Find latest time-to-completion of all syncdeps + deps_completed = UInt64(0) + for dep in deps + haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded + deps_completed = max(deps_completed, sstate.task_completions[dep]) + end - # Choose the earliest-available memory space and processor - # FIXME: Consider move time - move_time = UInt64(0) - local our_space_completed - while true - our_space_completed, our_space = findmin(spaces_completed) - our_space_procs = filter(proc->proc in all_procs, processors(our_space)) - if isempty(our_space_procs) - delete!(spaces_completed, our_space) - continue - end - our_proc = rand(our_space_procs) - break + # Find latest time-to-completion of each memory space + # FIXME: Figure out space completions based on optimal packing + spaces_completed = Dict{MemorySpace,UInt64}() + for space in exec_spaces + completed = UInt64(0) + for (task, other_space) in sstate.assignments + space == other_space || continue + completed = max(completed, 
sstate.task_completions[task]) end + spaces_completed[space] = completed + end - sstate.task_to_spec[task] = spec - sstate.assignments[task] = our_space - sstate.task_completions[task] = our_space_completed + move_time + task_time - elseif scheduler == :roundrobin - our_proc = all_procs[proc_idx] - if task_scope == scope - # all_procs is already limited to scope - else - if isa(constrain(task_scope, scope), InvalidScope) - throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) - end - while !proc_in_scope(our_proc, task_scope) - proc_idx = mod1(proc_idx + 1, length(all_procs)) - our_proc = all_procs[proc_idx] - end + # Choose the earliest-available memory space and processor + # FIXME: Consider move time + move_time = UInt64(0) + local our_space_completed + while true + our_space_completed, our_space = findmin(spaces_completed) + our_space_procs = filter(proc->proc in all_procs, processors(our_space)) + if isempty(our_space_procs) + delete!(spaces_completed, our_space) + continue end - else - error("Invalid scheduler: $sched") + our_proc = rand(our_space_procs) + break end - @assert our_proc in all_procs - our_space = only(memory_spaces(our_proc)) - # Find the scope for this task (and its copies) + sstate.task_to_spec[task] = spec + sstate.assignments[task] = our_space + sstate.task_completions[task] = our_space_completed + move_time + task_time + elseif scheduler == :roundrobin + our_proc = all_procs[proc_idx] if task_scope == scope - # Optimize for the common case, cache the proc=>scope mapping - our_scope = get!(proc_to_scope_lfu, our_proc) do - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), scope) - end + # all_procs is already limited to scope else - # Use the provided scope and constrain it to the available processors - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - our_scope = 
constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + if isa(constrain(task_scope, scope), InvalidScope) + throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) + end + while !proc_in_scope(our_proc, task_scope) + proc_idx = mod1(proc_idx + 1, length(all_procs)) + our_proc = all_procs[proc_idx] + end end - if our_scope isa InvalidScope - throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) + else + error("Invalid scheduler: $sched") + end + @assert our_proc in all_procs + our_space = only(memory_spaces(our_proc)) + + # Find the scope for this task (and its copies) + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + if task_scope == scope + # Optimize for the common case, cache the proc=>scope mapping + our_scope = get!(proc_to_scope_lfu, our_proc) do + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) end + else + # Use the provided scope and constrain it to the available processors + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + end + if our_scope isa InvalidScope + throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) + end - f = spec.fargs[1] - f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - - # Copy raw task arguments for analysis - task_args = map(copy, spec.fargs) + f = spec.fargs[1] + # FIXME: May not be correct to move this under uniformity + #f.value = move(default_processor(), our_proc, value(f)) + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - # Generate a list of ArgumentWrappers for each task argument - task_arg_ws = map(task_args) do _arg - 
arg = value(_arg) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) || !supports_inplace_move(state, arg) - return [(ArgumentWrapper(arg, identity), false, false)] - end + # Copy raw task arguments for analysis + # N.B. Used later for checking dependencies + task_args = map_or_ntuple(idx->copy(spec.fargs[idx]), spec.fargs) - # Get the Chunk for the argument - arg = state.raw_arg_to_chunk[arg] + # Populate all task dependencies + task_arg_ws = populate_task_info!(state, task_args, spec, task) - arg_ws = Tuple{ArgumentWrapper,Bool,Bool}[] - for (dep_mod, readdep, writedep) in deps - push!(arg_ws, (ArgumentWrapper(arg, dep_mod), readdep, writedep)) - end - return arg_ws + # Truncate the history for each argument + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + truncate_history!(state, dep.arg_w) end - task_arg_ws = task_arg_ws::Vector{Vector{Tuple{ArgumentWrapper,Bool,Bool}}} + return + end - # Truncate the history for each argument - for arg_ws in task_arg_ws - for (arg_w, _, _) in arg_ws - truncate_history!(state, arg_w) - end + # Copy args from local to remote + remote_args = map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + pos = raw_position(arg_ws.pos) + + # Is the data written previously or now? + if !arg_ws.may_alias + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + return arg end - # Copy args from local to remote - for (idx, arg_ws) in enumerate(task_arg_ws) - arg = first(arg_ws)[1].arg - pos = raw_position(task_args[idx]) + # Is the data writeable? + if !arg_ws.inplace_move + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + return arg + end - # Is the data written previously or now? 
- if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" - spec.fargs[idx].value = arg - continue + # Is the source of truth elsewhere? + arg_remote = get_or_generate_slot!(state, our_space, arg) + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + dep_mod = arg_w.dep_mod + remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num) + elseif remainder isa FullCopy + enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" end + end + return arg_remote + end + write_num += 1 - # Is the data writeable? - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" - spec.fargs[idx].value = arg - continue - end + # Validate that we're not accidentally performing a copy + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = remote_args[idx] - # Is the source of truth elsewhere? 
- arg_remote = get_or_generate_slot!(state, our_space, arg) - for (arg_w, _, _) in arg_ws - dep_mod = arg_w.dep_mod - remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num) - if remainder isa MultiRemainderAliasing - enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num) - elseif remainder isa FullCopy - enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) - else - @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" - end - end - spec.fargs[idx].value = arg_remote - end - write_num += 1 - - # Validate that we're not accidentally performing a copy - for (idx, _arg) in enumerate(spec.fargs) - arg = value(_arg) - _, deps = unwrap_inout(value(task_args[idx])) - # N.B. We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results - if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) - arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" - end + # Get the dependencies again as (dep_mod, readdep, writedep) + deps = map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + (dep.arg_w.dep_mod, dep.readdep, dep.writedep) end - # Calculate this task's syncdeps - if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{Any}() - end - syncdeps = spec.options.syncdeps - for (idx, arg_ws) in enumerate(task_arg_ws) - arg = first(arg_ws)[1].arg - type_may_alias(typeof(arg)) || continue - supports_inplace_move(state, arg) || continue - for (arg_w, _, writedep) in arg_ws - ainfo = aliasing!(state, our_space, arg_w) - dep_mod = arg_w.dep_mod - if writedep - @dagdebug nothing :spawn_datadeps 
"($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" - get_write_deps!(state, our_space, ainfo, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" - get_read_deps!(state, our_space, ainfo, write_num, syncdeps) - end - end + # Check that any mutable and written arguments are already in the correct space + # N.B. We only do this check when the argument supports in-place + # moves, because for the moment, we are not guaranteeing updates or + # write-back of results + if is_writedep(arg, deps, task) && arg_ws.may_alias && arg_ws.inplace_move + arg_space = memory_space(arg) + @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" - - # Launch user's task - spec.options.scope = our_scope - spec.options.exec_scope = our_scope - enqueue!(upper_queue, spec=>task) - - # Update read/write tracking for arguments - for (idx, arg_ws) in enumerate(task_arg_ws) - arg = first(arg_ws)[1].arg - type_may_alias(typeof(arg)) || continue - for (arg_w, _, writedep) in arg_ws - ainfo = aliasing!(state, our_space, arg_w) - dep_mod = arg_w.dep_mod - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" - add_writer!(state, arg_w, our_space, ainfo, task, write_num) - else - add_reader!(state, arg_w, our_space, ainfo, task, write_num) - end + end + + # Calculate this task's syncdeps + if spec.options.syncdeps === nothing + spec.options.syncdeps = Set{ThunkSyncdep}() + end + syncdeps = spec.options.syncdeps + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = 
arg_w.dep_mod + if dep.writedep + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + get_write_deps!(state, our_space, ainfo, write_num, syncdeps) + else + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + get_read_deps!(state, our_space, ainfo, write_num, syncdeps) end end - - write_num += 1 - proc_idx = mod1(proc_idx + 1, length(all_procs)) + return end + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" - # Copy args from remote to local - # N.B. We sort the keys to ensure a deterministic order for uniformity - for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) - arg = arg_w.arg - origin_space = state.arg_origin[arg] - remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) - if remainder isa MultiRemainderAliasing - origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) - enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num) - elseif remainder isa FullCopy - origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) 
- enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num) + # Launch user's task + new_fargs = map_or_ntuple(task_arg_ws) do idx + if is_typed(spec) + return TypedArgument(task_arg_ws[idx].pos, remote_args[idx]) else - @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + return Argument(task_arg_ws[idx].pos, remote_args[idx]) end end + new_spec = DTaskSpec(new_fargs, spec.options) + new_spec.options.scope = our_scope + new_spec.options.exec_scope = our_scope + new_spec.options.occupancy = Dict(Any=>0) + enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) + + # Update read/write tracking for arguments + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + for dep in arg_ws.deps + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if dep.writedep + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + add_writer!(state, arg_w, our_space, ainfo, task, write_num) + else + add_reader!(state, arg_w, our_space, ainfo, task, write_num) + end + end + return + end + + write_num += 1 + proc_idx = mod1(proc_idx + 1, length(all_procs)) + + return write_num, proc_idx end diff --git a/src/queue.jl b/src/queue.jl index c8c6007ec..37947a0ac 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -1,32 +1,63 @@ -mutable struct DTaskSpec - fargs::Vector{Argument} +mutable struct DTaskSpec{typed,FA<:Tuple} + _fargs::Vector{Argument} + _typed_fargs::FA options::Options end +DTaskSpec(fargs::Vector{Argument}, options::Options) = + DTaskSpec{false, Tuple{}}(fargs, (), options) +DTaskSpec(fargs::FA, options::Options) where FA = + DTaskSpec{true, FA}(Argument[], fargs, options) +is_typed(spec::DTaskSpec{typed}) where typed = typed +function Base.getproperty(spec::DTaskSpec{typed}, field::Symbol) 
where typed + if field === :fargs + if typed + return getfield(spec, :_typed_fargs) + else + return getfield(spec, :_fargs) + end + else + return getfield(spec, field) + end +end + +struct DTaskPair + spec::DTaskSpec + task::DTask +end +is_typed(pair::DTaskPair) = is_typed(pair.spec) +Base.iterate(pair::DTaskPair) = (pair.spec, true) +function Base.iterate(pair::DTaskPair, state::Bool) + if state + return (pair.task, false) + else + return nothing + end +end abstract type AbstractTaskQueue end function enqueue! end struct DefaultTaskQueue <: AbstractTaskQueue end -enqueue!(::DefaultTaskQueue, spec::Pair{DTaskSpec,DTask}) = - eager_launch!(spec) -enqueue!(::DefaultTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) = - eager_launch!(specs) +enqueue!(::DefaultTaskQueue, pair::DTaskPair) = + eager_launch!(pair) +enqueue!(::DefaultTaskQueue, pairs::Vector{DTaskPair}) = + eager_launch!(pairs) -enqueue!(spec::Pair{DTaskSpec,DTask}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), spec) -enqueue!(specs::Vector{Pair{DTaskSpec,DTask}}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), specs) +enqueue!(pair::DTaskPair) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pair) +enqueue!(pairs::Vector{DTaskPair}) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pairs) struct LazyTaskQueue <: AbstractTaskQueue - tasks::Vector{Pair{DTaskSpec,DTask}} - LazyTaskQueue() = new(Pair{DTaskSpec,DTask}[]) + tasks::Vector{DTaskPair} + LazyTaskQueue() = new(DTaskPair[]) end -function enqueue!(queue::LazyTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec) +function enqueue!(queue::LazyTaskQueue, pair::DTaskPair) + push!(queue.tasks, pair) end -function enqueue!(queue::LazyTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.tasks, specs) +function enqueue!(queue::LazyTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.tasks, pairs) end function spawn_bulk(f::Base.Callable) queue = LazyTaskQueue() @@ -50,25 +81,25 @@ function 
_add_prev_deps!(queue::InOrderTaskQueue, spec::DTaskSpec) push!(syncdeps, ThunkSyncdep(task)) end end -function enqueue!(queue::InOrderTaskQueue, spec::Pair{DTaskSpec,DTask}) +function enqueue!(queue::InOrderTaskQueue, pair::DTaskPair) if length(queue.prev_tasks) > 0 - _add_prev_deps!(queue, first(spec)) + _add_prev_deps!(queue, pair.spec) empty!(queue.prev_tasks) end - push!(queue.prev_tasks, last(spec)) - enqueue!(queue.upper_queue, spec) + push!(queue.prev_tasks, pair.task) + enqueue!(queue.upper_queue, pair) end -function enqueue!(queue::InOrderTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) +function enqueue!(queue::InOrderTaskQueue, pairs::Vector{DTaskPair}) if length(queue.prev_tasks) > 0 - for (spec, task) in specs - _add_prev_deps!(queue, spec) + for pair in pairs + _add_prev_deps!(queue, pair.spec) end empty!(queue.prev_tasks) end - for (spec, task) in specs - push!(queue.prev_tasks, task) + for pair in pairs + push!(queue.prev_tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function spawn_sequential(f::Base.Callable) queue = InOrderTaskQueue(get_options(:task_queue, DefaultTaskQueue())) @@ -79,15 +110,15 @@ struct WaitAllQueue <: AbstractTaskQueue upper_queue::AbstractTaskQueue tasks::Vector{DTask} end -function enqueue!(queue::WaitAllQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec[2]) - enqueue!(queue.upper_queue, spec) +function enqueue!(queue::WaitAllQueue, pair::DTaskPair) + push!(queue.tasks, pair.task) + enqueue!(queue.upper_queue, pair) end -function enqueue!(queue::WaitAllQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - for (_, task) in specs - push!(queue.tasks, task) +function enqueue!(queue::WaitAllQueue, pairs::Vector{DTaskPair}) + for pair in pairs + push!(queue.tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function wait_all(f; check_errors::Bool=false) queue = WaitAllQueue(get_options(:task_queue, DefaultTaskQueue()), 
DTask[]) diff --git a/src/submission.jl b/src/submission.jl index 2e7b1c836..4ff4f2294 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -268,24 +268,29 @@ function eager_process_elem_submission_to_local!(id_map, arg::Argument) arg.value = Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) end end -function eager_process_args_submission_to_local!(id_map, spec_pair::Pair{DTaskSpec,DTask}) - spec, task = spec_pair +function eager_process_elem_submission_to_local(id_map, arg::TypedArgument{T}) where T + @assert !(T <: Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" + if T <: DTask && haskey(id_map, (value(arg)::DTask).uid) + #=FIXME:UNIQUE=# + return Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) + end + return arg +end +function eager_process_args_submission_to_local!(id_map, spec::DTaskSpec{false}) for arg in spec.fargs eager_process_elem_submission_to_local!(id_map, arg) end end -function eager_process_args_submission_to_local!(id_map, spec_pairs::Vector{Pair{DTaskSpec,DTask}}) - for spec_pair in spec_pairs - eager_process_args_submission_to_local!(id_map, spec_pair) - end +function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) + return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -function DTaskMetadata(spec::DTaskSpec) - f = value(spec.fargs[1]) +DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function eager_metadata(fargs) + f = value(fargs[1]) f = f isa StreamingFunction ? f.f : f - arg_types = ntuple(i->chunktype(value(spec.fargs[i+1])), length(spec.fargs)-1) - return_type = Base.promote_op(f, arg_types...) - return DTaskMetadata(return_type) + arg_types = ntuple(i->chunktype(value(fargs[i+1])), length(fargs)-1) + return Base.promote_op(f, arg_types...) 
end function eager_spawn(spec::DTaskSpec) @@ -298,48 +303,64 @@ end chunktype(t::DTask) = t.metadata.return_type -function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) +function eager_launch!(pair::DTaskPair) + spec = pair.spec + task = pair.task + # Assign a name, if specified eager_assign_name!(spec, task) # Lookup DTask -> ThunkID - lock(Sch.EAGER_ID_MAP) do id_map - eager_process_args_submission_to_local!(id_map, spec=>task) + fargs = lock(Sch.EAGER_ID_MAP) do id_map + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] + else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - spec.fargs, spec.options, true)) + fargs, spec.options, true)) task.thunk_ref = thunk_id.ref end -function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}}) - ntasks = length(specs) +# FIXME: Don't convert Tuple to Vector{Argument} +function eager_launch!(pairs::Vector{DTaskPair}) + ntasks = length(pairs) # Assign a name, if specified - for (spec, task) in specs - eager_assign_name!(spec, task) + for pair in pairs + eager_assign_name!(pair.spec, pair.task) end #=FIXME:REALLOC_N=# - uids = [task.uid for (_, task) in specs] - futures = [task.future for (_, task) in specs] + uids = [pair.task.uid for pair in pairs] + futures = [pair.task.future for pair in pairs] # Get all functions, args/kwargs, and options #=FIXME:REALLOC_N=# all_fargs = lock(Sch.EAGER_ID_MAP) do id_map # Lookup DTask -> ThunkID - eager_process_args_submission_to_local!(id_map, specs) - [spec.fargs for (spec, _) in specs] + return map(pairs) do pair + spec = pair.spec + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] 
+ else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end + end end - all_options = Options[spec.options for (spec, _) in specs] + all_options = Options[pair.spec.options for pair in pairs] # Submit the tasks #=FIXME:REALLOC=# thunk_ids = eager_submit!(PayloadMulti(ntasks, uids, futures, all_fargs, all_options, true)) for i in 1:ntasks - task = specs[i][2] + task = pairs[i].task task.thunk_ref = thunk_ids[i].ref end end diff --git a/src/thunk.jl b/src/thunk.jl index 482d66209..d1701e3ef 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -186,21 +186,19 @@ function args_kwargs_to_arguments(f, args, kwargs) end return args_kwargs end -function args_kwargs_to_arguments(f, args) - @nospecialize f args - args_kwargs = Argument[] - push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) - pos_ctr = 1 - for idx in 1:length(args) - pos, arg = args[idx]::Pair - if pos === nothing - push!(args_kwargs, Argument(pos_ctr, arg)) - pos_ctr += 1 +function args_kwargs_to_typedarguments(f, args, kwargs) + nargs = 1 + length(args) + length(kwargs) + return ntuple(nargs) do idx + if idx == 1 + return TypedArgument(ArgPosition(true, 0, :NULL), f) + elseif idx in 2:(1+length(args)) + arg = args[idx-1] + return TypedArgument(idx, arg) else - push!(args_kwargs, Argument(pos, arg)) + kw, value = kwargs[idx-length(args)-1] + return TypedArgument(kw, value) end end - return args_kwargs end """ @@ -491,7 +489,11 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) @gensym result return quote let - $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + $result = if $get_task_typed() + $typed_spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + else + $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + end if $(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->fetch($result; raw=true)))) end @@ -516,6 +518,9 @@ function _setindex!_return_value(A, value, idxs...) 
return value end +const TASK_TYPED = ScopedValue{Bool}(false) +get_task_typed() = TASK_TYPED[] + """ Dagger.spawn(f, args...; kwargs...) -> DTask @@ -526,6 +531,36 @@ Spawns a `DTask` that will call `f(args...; kwargs...)`. Also supports passing a function spawn(f, args...; kwargs...) @nospecialize f args kwargs + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Argument form + args_kwargs = args_kwargs_to_arguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function typed_spawn(f, args...; kwargs...) + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Tuple of TypedArgument form + args_kwargs = args_kwargs_to_typedarguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function _spawn(args_kwargs, task_options) # Get all scoped options and determine which propagate beyond this task scoped_options = get_options()::NamedTuple if haskey(scoped_options, :propagates) @@ -539,20 +574,9 @@ function spawn(f, args...; kwargs...) end append!(propagates, keys(scoped_options)::NTuple{N,Symbol} where N) - # Merge all passed options - if length(args) >= 1 && first(args) isa Options - # N.B. Make a defensive copy in case user aliases Options struct - task_options = copy(first(args)::Options) - args = args[2:end] - else - task_options = Options() - end # N.B. 
Merges into task_options options_merge!(task_options, scoped_options; override=false) - # Process the args and kwargs into Pair form - args_kwargs = args_kwargs_to_arguments(f, args, kwargs) - # Get task queue, and don't let it propagate task_queue = get(scoped_options, :task_queue, DefaultTaskQueue())::AbstractTaskQueue filter!(prop -> prop != :task_queue, propagates) @@ -568,7 +592,7 @@ function spawn(f, args...; kwargs...) task = eager_spawn(spec) # Enqueue the task into the task queue - enqueue!(task_queue, spec=>task) + enqueue!(task_queue, DTaskPair(spec, task)) return task end From f1528190e328a8bd806e1611a6a6d2d06e3d767d Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 15 Oct 2025 19:50:51 -0700 Subject: [PATCH 03/24] datadeps: Optimize ainfo aliasing lookups --- src/Dagger.jl | 4 +- src/datadeps/aliasing.jl | 40 ++-- src/memory-spaces.jl | 271 +++++++++++++++++------ src/{datadeps => utils}/interval_tree.jl | 96 ++++---- src/utils/memory-span.jl | 98 ++++++++ 5 files changed, 381 insertions(+), 128 deletions(-) rename src/{datadeps => utils}/interval_tree.jl (81%) create mode 100644 src/utils/memory-span.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 987963b34..102a76149 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -73,6 +73,9 @@ include("utils/fetch.jl") include("utils/chunks.jl") include("utils/logging.jl") include("submission.jl") +abstract type MemorySpace end +include("utils/memory-span.jl") +include("utils/interval_tree.jl") include("memory-spaces.jl") # Task scheduling @@ -85,7 +88,6 @@ include("sch/Sch.jl"); using .Sch # Data dependency task queue include("datadeps/aliasing.jl") include("datadeps/chunkview.jl") -include("datadeps/interval_tree.jl") include("datadeps/remainders.jl") include("datadeps/queue.jl") diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 87eba76df..6586409c4 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -299,6 +299,10 @@ struct DataDepsState # N.B. 
This is a mapping for remote argument copies ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} + # The oracle for aliasing lookups + # Used to populate ainfos_overlaps efficiently + ainfos_lookup::AliasingLookup + # The overlapping ainfos for each ainfo # Incrementally updated as new ainfos are created # Used for fast will_alias lookups @@ -328,13 +332,14 @@ struct DataDepsState supports_inplace_cache = IdDict{Any,Bool}() ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() + ainfos_lookup = AliasingLookup() ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, - supports_inplace_cache, ainfo_cache, ainfos_overlaps, ainfos_owner, ainfos_readers) + supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) end end @@ -463,27 +468,30 @@ function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, o aliasing!(state, origin_space, arg_w) end function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) - # Initialize owner and readers if !haskey(state.ainfos_owner, target_ainfo) + # Add ourselves to the lookup oracle + ainfo_idx = push!(state.ainfos_lookup, target_ainfo) + + # Find overlapping ainfos overlaps = Set{AliasingWrapper}() push!(overlaps, target_ainfo) - for other_ainfo in keys(state.ainfos_owner) + for other_ainfo in intersect(state.ainfos_lookup, target_ainfo; ainfo_idx) target_ainfo == other_ainfo && continue - if will_alias(target_ainfo, other_ainfo) - # Mark us and them as overlapping - push!(overlaps, other_ainfo) - push!(state.ainfos_overlaps[other_ainfo], target_ainfo) - - # Add overlapping history to our own - other_remote_arg_w = 
state.ainfo_arg[other_ainfo] - other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] - other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) - push!(state.arg_overlaps[original_arg_w], other_arg_w) - push!(state.arg_overlaps[other_arg_w], original_arg_w) - merge_history!(state, original_arg_w, other_arg_w) - end + # Mark us and them as overlapping + push!(overlaps, other_ainfo) + push!(state.ainfos_overlaps[other_ainfo], target_ainfo) + + # Add overlapping history to our own + other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] + other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) + push!(state.arg_overlaps[original_arg_w], other_arg_w) + push!(state.arg_overlaps[other_arg_w], original_arg_w) + merge_history!(state, original_arg_w, other_arg_w) end state.ainfos_overlaps[target_ainfo] = overlaps + + # Initialize owner and readers state.ainfos_owner[target_ainfo] = nothing state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[] end diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index b1ff40d8f..fcc3dbf0b 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -1,5 +1,3 @@ -abstract type MemorySpace end - struct CPURAMMemorySpace <: MemorySpace owner::Int end @@ -92,46 +90,16 @@ end may_alias(::MemorySpace, ::MemorySpace) = true may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner -struct RemotePtr{T,S<:MemorySpace} <: Ref{T} - addr::UInt - space::S -end -RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) -RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) -RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) -# FIXME: Don't hardcode CPURAMMemorySpace -RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) -Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = - RemotePtr(UInt(x), 
CPURAMMemorySpace(myid())) -Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = - RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) -Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr -Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) -Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) -function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) - @assert ptr1.space == ptr2.space - return ptr1.addr < ptr2.addr -end - -struct MemorySpan{S} - ptr::RemotePtr{Cvoid,S} - len::UInt -end -MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = - MemorySpan{S}(ptr, UInt(len)) -MemorySpan{S}(addr::UInt, len::Integer) where S = - MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) -Base.isless(a::MemorySpan, b::MemorySpan) = a.ptr < b.ptr -Base.isempty(x::MemorySpan) = x.len == 0 abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) -struct AliasingWrapper <: AbstractAliasing +### Type-generic aliasing info wrapper + +mutable struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 - AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner)) end memory_spans(x::AliasingWrapper) = memory_spans(x.inner) @@ -140,8 +108,204 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash -will_alias(x::AliasingWrapper, y::AliasingWrapper) = - will_alias(x.inner, y.inner) +will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) + +### Small dictionary type + +struct SmallDict{K,V} <: AbstractDict{K,V} + keys::Vector{K} + vals::Vector{V} +end 
+SmallDict{K,V}() where {K,V} = SmallDict{K,V}(Vector{K}(), Vector{V}()) +function Base.getindex(d::SmallDict{K,V}, key) where {K,V} + key_idx = findfirst(==(convert(K, key)), d.keys) + if key_idx === nothing + throw(KeyError(key)) + end + return @inbounds d.vals[key_idx] +end +function Base.setindex!(d::SmallDict{K,V}, val, key) where {K,V} + key_conv = convert(K, key) + key_idx = findfirst(==(key_conv), d.keys) + if key_idx === nothing + push!(d.keys, key_conv) + push!(d.vals, convert(V, val)) + else + d.vals[key_idx] = convert(V, val) + end + return val +end +Base.haskey(d::SmallDict{K,V}, key) where {K,V} = in(convert(K, key), d.keys) +Base.keys(d::SmallDict) = d.keys +Base.length(d::SmallDict) = length(d.keys) +Base.iterate(d::SmallDict) = iterate(d, 1) +Base.iterate(d::SmallDict, state) = state > length(d.keys) ? nothing : (d.keys[state] => d.vals[state], state+1) + +### Type-stable lookup structure for AliasingWrappers + +struct AliasingLookup + # The set of memory spaces that are being tracked + spaces::Vector{MemorySpace} + # The set of AliasingWrappers that are being tracked + # One entry for each AliasingWrapper + ainfos::Vector{AliasingWrapper} + # The memory spaces for each AliasingWrapper + # One entry for each AliasingWrapper + ainfos_spaces::Vector{Vector{Int}} + # The spans for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + spans::Vector{SmallDict{Int,Vector{LocalMemorySpan}}} + # The set of AliasingWrappers that only exist in a single memory space + # One entry for each AliasingWrapper + ainfos_only_space::Vector{Int} + # The bounding span for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + bounding_spans::Vector{SmallDict{Int,LocalMemorySpan}} + # The interval tree of the bounding spans for each AliasingWrapper + # One entry for each MemorySpace + bounding_spans_tree::Vector{IntervalTree{LocatorMemorySpan{Int},UInt64}} + + AliasingLookup() = new(MemorySpace[], + 
AliasingWrapper[], + Vector{Int}[], + SmallDict{Int,Vector{LocalMemorySpan}}[], + Int[], + SmallDict{Int,LocalMemorySpan}[], + IntervalTree{LocatorMemorySpan{Int},UInt64}[]) +end +function Base.push!(lookup::AliasingLookup, ainfo::AliasingWrapper) + # Update the set of memory spaces and spans, + # and find the bounding spans for this AliasingWrapper + spaces_set = Set{MemorySpace}(lookup.spaces) + self_spaces_set = Set{Int}() + spans = SmallDict{Int,Vector{LocalMemorySpan}}() + for span in memory_spans(ainfo) + space = span.ptr.space + if !in(space, spaces_set) + push!(spaces_set, space) + push!(lookup.spaces, space) + push!(lookup.bounding_spans_tree, IntervalTree{LocatorMemorySpan{Int}}()) + end + space_idx = findfirst(==(space), lookup.spaces) + push!(self_spaces_set, space_idx) + spans_in_space = get!(Vector{LocalMemorySpan}, spans, space_idx) + push!(spans_in_space, LocalMemorySpan(span)) + end + push!(lookup.ainfos_spaces, collect(self_spaces_set)) + push!(lookup.spans, spans) + + # Update the set of AliasingWrappers + push!(lookup.ainfos, ainfo) + ainfo_idx = length(lookup.ainfos) + + # Check if the AliasingWrapper only exists in a single memory space + if length(self_spaces_set) == 1 + space_idx = only(self_spaces_set) + push!(lookup.ainfos_only_space, space_idx) + else + push!(lookup.ainfos_only_space, 0) + end + + # Add the bounding spans for this AliasingWrapper + bounding_spans = SmallDict{Int,LocalMemorySpan}() + for space_idx in keys(spans) + space_spans = spans[space_idx] + bound_start = minimum(span_start, space_spans) + bound_end = maximum(span_end, space_spans) + bounding_span = LocalMemorySpan(bound_start, bound_end - bound_start) + bounding_spans[space_idx] = bounding_span + insert!(lookup.bounding_spans_tree[space_idx], LocatorMemorySpan(bounding_span, ainfo_idx)) + end + push!(lookup.bounding_spans, bounding_spans) + + return ainfo_idx +end +struct AliasingLookupFinder + lookup::AliasingLookup + ainfo::AliasingWrapper + ainfo_idx::Int + 
spaces_idx::Vector{Int} + to_consider::Vector{Int} +end +Base.eltype(::AliasingLookupFinder) = AliasingWrapper +Base.IteratorSize(::AliasingLookupFinder) = Base.SizeUnknown() +# FIXME: We should use a Dict{UInt,Int} to find the ainfo_idx instead of linear search +function Base.intersect(lookup::AliasingLookup, ainfo::AliasingWrapper; ainfo_idx=nothing) + if ainfo_idx === nothing + ainfo_idx = something(findfirst(==(ainfo), lookup.ainfos)) + end + spaces_idx = lookup.ainfos_spaces[ainfo_idx] + to_consider_spans = LocatorMemorySpan{Int}[] + for space_idx in spaces_idx + bounding_spans_tree = lookup.bounding_spans_tree[space_idx] + self_bounding_span = LocatorMemorySpan(lookup.bounding_spans[ainfo_idx][space_idx], 0) + find_overlapping!(bounding_spans_tree, self_bounding_span, to_consider_spans; exact=false) + end + to_consider = Int[locator.owner for locator in to_consider_spans] + @assert all(to_consider .> 0) + return AliasingLookupFinder(lookup, ainfo, ainfo_idx, spaces_idx, to_consider) +end +Base.iterate(finder::AliasingLookupFinder) = iterate(finder, 1) +function Base.iterate(finder::AliasingLookupFinder, cursor_ainfo_idx) + ainfo_spaces = nothing + cursor_space_idx = 1 + + # New ainfos enter here + @label ainfo_restart + + # Check if we've exhausted all ainfos + if cursor_ainfo_idx > length(finder.to_consider) + return nothing + end + ainfo_idx = finder.to_consider[cursor_ainfo_idx] + + # Find the appropriate memory spaces for this ainfo + if ainfo_spaces === nothing + ainfo_spaces = finder.lookup.ainfos_spaces[ainfo_idx] + end + + # New memory spaces (for the same ainfo) enter here + @label space_restart + + # Check if we've exhausted all memory spaces for this ainfo, and need to move to the next ainfo + if cursor_space_idx > length(ainfo_spaces) + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # Find the currently considered memory space for this ainfo + space_idx = ainfo_spaces[cursor_space_idx] + + # 
Check if this memory space is part of our target ainfo's spaces + if !(space_idx in finder.spaces_idx) + cursor_space_idx += 1 + @goto space_restart + end + + # Check if this ainfo's bounding span is part of our target ainfo's bounding span in this memory space + other_ainfo_bounding_span = finder.lookup.bounding_spans[ainfo_idx][space_idx] + self_bounding_span = finder.lookup.bounding_spans[finder.ainfo_idx][space_idx] + if !spans_overlap(other_ainfo_bounding_span, self_bounding_span) + cursor_space_idx += 1 + @goto space_restart + end + + # We have a overlapping bounds in the same memory space, so check if the ainfos are aliasing + # This is the slow path! + other_ainfo = finder.lookup.ainfos[ainfo_idx] + aliasing = will_alias(finder.ainfo, other_ainfo) + if !aliasing + cursor_ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # We overlap, so return the ainfo and the next ainfo index + return other_ainfo, cursor_ainfo_idx+1 +end struct NoAliasing <: AbstractAliasing end memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[] @@ -406,36 +570,3 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan) y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end end - -### More space-efficient memory spans - -struct LocalMemorySpan - ptr::UInt - len::UInt -end -LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len) -Base.isempty(x::LocalMemorySpan) = x.len == 0 - -# FIXME: Store the length separately, since it's shared by all spans -struct ManyMemorySpan{N} - spans::NTuple{N,LocalMemorySpan} -end -Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans) - -struct ManyPair{N} <: Unsigned - pairs::NTuple{N,UInt} -end -Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair -Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) -Base.convert(::Type{ManyPair}, x::ManyPair) = x -Base.:+(x::ManyPair{N}, 
y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) -Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N)) -Base.:-(x::ManyPair) = error("Can't negate a ManyPair") -Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs -Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] -Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] -Base.string(x::ManyPair) = "ManyPair($(x.pairs))" - -ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = - ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) - diff --git a/src/datadeps/interval_tree.jl b/src/utils/interval_tree.jl similarity index 81% rename from src/datadeps/interval_tree.jl rename to src/utils/interval_tree.jl index 1075f5912..e67f66b24 100644 --- a/src/datadeps/interval_tree.jl +++ b/src/utils/interval_tree.jl @@ -1,16 +1,3 @@ -# Get the start address of a span -span_start(span::MemorySpan) = span.ptr.addr -span_start(span::LocalMemorySpan) = span.ptr -span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N)) -# Get the length of a span -span_len(span::MemorySpan) = span.len -span_len(span::LocalMemorySpan) = span.len -span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N)) - -# Get the end address of a span -span_end(span::MemorySpan) = span.ptr.addr + span.len -span_end(span::LocalMemorySpan) = span.ptr + span.len -span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N)) mutable struct IntervalNode{M,E} span::M max_end::E # Maximum end value in this subtree @@ -20,6 +7,7 @@ mutable struct IntervalNode{M,E} IntervalNode(span::M) where M <: MemorySpan = new{M,UInt64}(span, span_end(span), nothing, nothing) IntervalNode(span::LocalMemorySpan) = new{LocalMemorySpan,UInt64}(span, span_end(span), nothing, nothing) IntervalNode(span::ManyMemorySpan{N}) where N = 
new{ManyMemorySpan{N},ManyPair{N}}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocatorMemorySpan{T}) where T = new{LocatorMemorySpan{T},UInt64}(span, span_end(span), nothing, nothing) end mutable struct IntervalTree{M,E} @@ -28,6 +16,7 @@ mutable struct IntervalTree{M,E} IntervalTree{M}() where M<:MemorySpan = new{M,UInt64}(nothing) IntervalTree{LocalMemorySpan}() = new{LocalMemorySpan,UInt64}(nothing) IntervalTree{ManyMemorySpan{N}}() where N = new{ManyMemorySpan{N},ManyPair{N}}(nothing) + IntervalTree{LocatorMemorySpan{T}}() where T = new{LocatorMemorySpan{T},UInt64}(nothing) end # Construct interval tree from unsorted set of spans @@ -94,19 +83,48 @@ end # Update max_end value for a node based on its children function update_max_end!(node::IntervalNode) - node.max_end = span_end(node.span) + max_end = span_end(node.span) if node.left !== nothing - node.max_end = max(node.max_end, node.left.max_end) + max_end = max(max_end, node.left.max_end) end if node.right !== nothing - node.max_end = max(node.max_end, node.right.max_end) + max_end = max(max_end, node.right.max_end) end + node.max_end = max_end end # Insert a span into the interval tree -function Base.insert!(tree::IntervalTree{M}, span::M) where M +function Base.insert!(tree::IntervalTree{M,E}, span::M) where {M,E} if !isempty(span) - tree.root = insert_node!(tree.root, span) + if tree.root === nothing + tree.root = IntervalNode(span) + update_max_end!(tree.root) + return span + end + #tree.root = insert_node!(tree.root, span) + to_update = Vector{IntervalNode{M,E}}() + prev_node = tree.root + cur_node = tree.root + while cur_node !== nothing + if span_start(span) <= span_start(cur_node.span) + cur_node = cur_node.left + else + cur_node = cur_node.right + end + if cur_node !== nothing + prev_node = cur_node + push!(to_update, cur_node) + end + end + if prev_node.left === nothing + prev_node.left = IntervalNode(span) + else + prev_node.right = IntervalNode(span) + end + for node_idx in 
eachindex(to_update) + node = to_update[node_idx] + update_max_end!(node) + end end return span end @@ -221,46 +239,42 @@ function find_min(node::IntervalNode) return node end -# Check if two spans overlap -function spans_overlap(span1::MemorySpan, span2::MemorySpan) - return span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) -end -function spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) - return span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) -end -function spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N - # N.B. The spans are assumed to be the same length and relative offset - return spans_overlap(span1.spans[1], span2.spans[1]) -end - # Find all spans that overlap with the given query span -function find_overlapping(tree::IntervalTree{M}, query::M) where M +function find_overlapping(tree::IntervalTree{M}, query::M; exact::Bool=true) where M result = M[] - find_overlapping!(tree.root, query, result) + find_overlapping!(tree.root, query, result; exact) + return result +end +function find_overlapping!(tree::IntervalTree{M}, query::M, result::Vector{M}; exact::Bool=true) where M + find_overlapping!(tree.root, query, result; exact) return result end -function find_overlapping!(::Nothing, query::M, result::Vector{M}) where M +function find_overlapping!(::Nothing, query::M, result::Vector{M}; exact::Bool=true) where M return end -function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}) where {M,E} +function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; exact::Bool=true) where {M,E} # Check if current node overlaps with query if spans_overlap(node.span, query) - # Get the overlapping portion of the span - overlap_start = max(span_start(node.span), span_start(query)) - overlap_end = min(span_end(node.span), span_end(query)) - overlap = M(overlap_start, overlap_end - overlap_start) - push!(result, overlap) + if exact + # Get the 
overlapping portion of the span + overlap_start = max(span_start(node.span), span_start(query)) + overlap_end = min(span_end(node.span), span_end(query)) + overlap = M(overlap_start, overlap_end - overlap_start) + push!(result, overlap) + else + push!(result, node.span) + end end # Recursively search left subtree if it might contain overlapping intervals if node.left !== nothing && node.left.max_end > span_start(query) - find_overlapping!(node.left, query, result) + find_overlapping!(node.left, query, result; exact) end # Recursively search right subtree if query extends beyond current node's start if node.right !== nothing && span_end(query) > span_start(node.span) - find_overlapping!(node.right, query, result) + find_overlapping!(node.right, query, result; exact) end end diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl new file mode 100644 index 000000000..91f291cbe --- /dev/null +++ b/src/utils/memory-span.jl @@ -0,0 +1,98 @@ +### Remote pointer type + +struct RemotePtr{T,S<:MemorySpace} <: Ref{T} + addr::UInt + space::S +end +RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) +RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) +RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) +# FIXME: Don't hardcode CPURAMMemorySpace +RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) +Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = + RemotePtr(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = + RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr +Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) +Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) +function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) + @assert ptr1.space == ptr2.space + 
return ptr1.addr < ptr2.addr +end + +### Generic memory spans + +struct MemorySpan{S} + ptr::RemotePtr{Cvoid,S} + len::UInt +end +MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = + MemorySpan{S}(ptr, UInt(len)) +MemorySpan{S}(addr::UInt, len::Integer) where S = + MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) +Base.isless(a::MemorySpan, b::MemorySpan) = a.ptr < b.ptr +Base.isempty(x::MemorySpan) = x.len == 0 +span_start(span::MemorySpan) = span.ptr.addr +span_len(span::MemorySpan) = span.len +span_end(span::MemorySpan) = span.ptr.addr + span.len +spans_overlap(span1::MemorySpan, span2::MemorySpan) = + span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) + +### More space-efficient memory spans + +struct LocalMemorySpan + ptr::UInt + len::UInt +end +LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len) +Base.isempty(x::LocalMemorySpan) = x.len == 0 +span_start(span::LocalMemorySpan) = span.ptr +span_len(span::LocalMemorySpan) = span.len +span_end(span::LocalMemorySpan) = span.ptr + span.len +spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) = + span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) + +# FIXME: Store the length separately, since it's shared by all spans +struct ManyMemorySpan{N} + spans::NTuple{N,LocalMemorySpan} +end +Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans) +span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N)) +span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N)) +span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N)) +spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N = + # N.B. 
The spans are assumed to be the same length and relative offset + spans_overlap(span1.spans[1], span2.spans[1]) + +struct ManyPair{N} <: Unsigned + pairs::NTuple{N,UInt} +end +Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair +Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) +Base.convert(::Type{ManyPair}, x::ManyPair) = x +Base.:+(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) +Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N)) +Base.:-(x::ManyPair) = error("Can't negate a ManyPair") +Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs +Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.string(x::ManyPair) = "ManyPair($(x.pairs))" + +ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = + ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) + +### Memory spans with ownership info + +struct LocatorMemorySpan{T} + span::LocalMemorySpan + owner::T +end +LocatorMemorySpan{T}(start::UInt64, len::UInt64) where T = # For interval tree + LocatorMemorySpan{T}(LocalMemorySpan(start, len), 0) +Base.isempty(x::LocatorMemorySpan) = span_len(x.span) == 0 +span_start(x::LocatorMemorySpan) = span_start(x.span) +span_end(x::LocatorMemorySpan) = span_end(x.span) +span_len(x::LocatorMemorySpan) = span_len(x.span) +spans_overlap(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T = + spans_overlap(span1.span, span2.span) \ No newline at end of file From 8b53c358e5f665a36d17e783474bdc58e6f2e02f Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 15 Oct 2025 17:51:50 -0700 Subject: [PATCH 04/24] datadeps: Optimize remote ArgumentWrapper lookup --- src/datadeps/aliasing.jl | 84 +++++++++++++++++++++++++--------------- src/memory-spaces.jl | 10 ++++- 2 files changed, 60 
insertions(+), 34 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 6586409c4..44401c38f 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -269,6 +269,9 @@ struct DataDepsState # The mapping of remote argument to original argument remote_arg_to_original::IdDict{Any,Any} + # The mapping of original argument wrapper to remote argument wrapper + remote_arg_w::Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}} + # The mapping of ainfo to argument and dep_mod # Used to lookup which argument and dep_mod a given ainfo is generated from # N.B. This is a mapping for remote argument copies @@ -323,6 +326,7 @@ struct DataDepsState arg_origin = IdDict{Any,MemorySpace}() remote_args = Dict{MemorySpace,IdDict{Any,Any}}() remote_arg_to_original = IdDict{Any,Any}() + remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() @@ -338,37 +342,11 @@ struct DataDepsState ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) end end -# N.B. 
arg_w must be the original argument wrapper, not a remote copy -function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) - # Grab the remote copy of the argument, and calculate the ainfo - remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) - remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) - - # Check if we already have the result cached - if haskey(state.ainfo_cache, remote_arg_w) - return state.ainfo_cache[remote_arg_w] - end - - # Calculate the ainfo - ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) - - # Cache the result - state.ainfo_cache[remote_arg_w] = ainfo - - # Update the mapping of ainfo to argument and dep_mod - state.ainfo_arg[ainfo] = remote_arg_w - - # Populate info for the new ainfo - populate_ainfo!(state, arg_w, ainfo, target_space) - - return ainfo -end - function supports_inplace_move(state::DataDepsState, arg) return get!(state.supports_inplace_cache, arg) do return supports_inplace_move(arg) @@ -467,6 +445,41 @@ function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, o # Calculate the ainfo (which will populate ainfo structures and merge history) aliasing!(state, origin_space, arg_w) end +# N.B. 
arg_w must be the original argument wrapper, not a remote copy +function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) + if haskey(state.remote_arg_w, arg_w) && haskey(state.remote_arg_w[arg_w], target_space) + remote_arg_w = @inbounds state.remote_arg_w[arg_w][target_space] + remote_arg = remote_arg_w.arg + else + # Grab the remote copy of the argument, and calculate the ainfo + remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) + remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) + get!(Dict{MemorySpace,ArgumentWrapper}, state.remote_arg_w, arg_w)[target_space] = remote_arg_w + end + + # Check if we already have the result cached + if haskey(state.ainfo_cache, remote_arg_w) + return state.ainfo_cache[remote_arg_w] + end + + # Calculate the ainfo + ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) + + # Cache the result + state.ainfo_cache[remote_arg_w] = ainfo + + # Update the mapping of ainfo to argument and dep_mod + if !haskey(state.ainfo_arg, ainfo) + state.ainfo_arg[ainfo] = remote_arg_w + else + @assert state.ainfo_arg[ainfo] == remote_arg_w + end + + # Populate info for the new ainfo + populate_ainfo!(state, arg_w, ainfo, target_space) + + return ainfo +end function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) if !haskey(state.ainfos_owner, target_ainfo) # Add ourselves to the lookup oracle @@ -673,11 +686,18 @@ function get_or_generate_slot!(state, dest_space, data) end function move_rewrap(from_proc::Processor, to_proc::Processor, data) return aliased_object!(data) do data - to_w = root_worker_id(to_proc) - return remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data - data_converted = move(from_proc, to_proc, data) - return tochunk(data_converted, to_proc) - end + return remotecall_endpoint(identity, from_proc, to_proc, only(memory_spaces(from_proc)), only(memory_spaces(to_proc)), data) + end +end +function 
remotecall_endpoint(f, from_proc, to_proc, orig_space, dest_space, data) + to_w = root_worker_id(to_proc) + if to_w == myid() + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc, dest_space) + end + return remotecall_fetch(to_w, from_proc, to_proc, dest_space, data) do from_proc, to_proc, dest_space, data + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc, dest_space) end end const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index fcc3dbf0b..4124bbba6 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -382,8 +382,14 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() -aliasing(x::Chunk, T) = remotecall_fetch(root_worker_id(x.processor), x, T) do x, T - aliasing(unwrap(x), T) +function aliasing(x::Chunk, T) + @assert x.handle isa DRef + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x), T) + end + return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T + aliasing(unwrap(x), T) + end end aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x aliasing(unwrap(x)) From e0bc71ab008e06fb7a89844c06a3e35f0f8d93f3 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 15 Oct 2025 17:52:46 -0700 Subject: [PATCH 05/24] thunk: Remove unnecessary scope allocations --- src/thunk.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index d1701e3ef..e13e299f0 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -17,8 +17,6 @@ function unset!(spec::ThunkSpec, _) spec.id = 0 spec.cache_ref = nothing spec.affinity = nothing - compute_scope = DefaultScope() - result_scope = AnyScope() spec.options = nothing end From 9dff3817d602fca2e77c8c9be320725222f43290 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 11 Nov 
2025 15:30:36 -0700 Subject: [PATCH 06/24] test/datadeps: Remove aliasing=false tests --- test/datadeps.jl | 411 +++++++++++++++++++++++------------------------ 1 file changed, 203 insertions(+), 208 deletions(-) diff --git a/test/datadeps.jl b/test/datadeps.jl index cd83be95f..4fb873454 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -177,16 +177,15 @@ end @everywhere mut_V!(V) = (V .= 1;) function test_datadeps(;args_chunks::Bool, args_thunks::Bool, - args_loc::Int, - aliasing::Bool) + args_loc::Int) # Returns last value - @test Dagger.spawn_datadeps(;aliasing) do + @test Dagger.spawn_datadeps() do 42 end == 42 # Tasks are started and finished as spawn_datadeps returns ts = [] - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do for i in 1:5 t = Dagger.@spawn sleep(0.1) @test !istaskstarted(t) @@ -195,7 +194,7 @@ function test_datadeps(;args_chunks::Bool, @test all(istaskdone, ts) # Rethrows any task exceptions - @test_throws Exception Dagger.spawn_datadeps(;aliasing) do + @test_throws Exception Dagger.spawn_datadeps() do Dagger.@spawn error("Test") end @@ -209,7 +208,7 @@ function test_datadeps(;args_chunks::Bool, # Task return values can be tracked ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do t1 = Dagger.@spawn fill(42, 1) push!(ts, t1) push!(ts, Dagger.@spawn copyto!(Out(A), In(t1))) @@ -224,7 +223,7 @@ function test_datadeps(;args_chunks::Bool, # R->R Non-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A))) push!(ts, Dagger.@spawn do_nothing(In(A))) end @@ -236,7 +235,7 @@ function test_datadeps(;args_chunks::Bool, # R->W Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A))) push!(ts, Dagger.@spawn do_nothing(Out(A))) end @@ -248,7 +247,7 @@ function test_datadeps(;args_chunks::Bool, # W->W Aliasing 
ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(Out(A))) push!(ts, Dagger.@spawn do_nothing(Out(A))) end @@ -260,7 +259,7 @@ function test_datadeps(;args_chunks::Bool, # R->R Non-Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) end @@ -272,7 +271,7 @@ function test_datadeps(;args_chunks::Bool, # R->W Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(In(A), In(A))) push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) end @@ -284,7 +283,7 @@ function test_datadeps(;args_chunks::Bool, # W->W Self-Aliasing ts = [] logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) push!(ts, Dagger.@spawn do_nothing(Out(A), Out(A))) end @@ -293,197 +292,195 @@ function test_datadeps(;args_chunks::Bool, test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) - if aliasing - function wrap_chunk_thunk(f, args...) - if args_thunks || args_chunks - result = Dagger.@spawn scope=Dagger.scope(worker=args_loc) f(args...) - if args_thunks - return result - elseif args_chunks - return fetch(result; raw=true) - end - else - # N.B. We don't allocate remotely for raw data - return f(args...) 
- end - end - B = wrap_chunk_thunk(rand, 4, 4) - - # Views - B_ul = wrap_chunk_thunk(view, B, 1:2, 1:2) - B_ur = wrap_chunk_thunk(view, B, 1:2, 3:4) - B_ll = wrap_chunk_thunk(view, B, 3:4, 1:2) - B_lr = wrap_chunk_thunk(view, B, 3:4, 3:4) - B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) - for (B_name, B_view) in ( - (:B_ul, B_ul), - (:B_ur, B_ur), - (:B_ll, B_ll), - (:B_lr, B_lr), - (:B_mid, B_mid)) - @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) - B_view === B_mid && continue - @test Dagger.will_alias(Dagger.aliasing(B_mid), Dagger.aliasing(B_view)) - end - local t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid - local t_ul2, t_ur2, t_ll2, t_lr2 - logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do - t_A = Dagger.@spawn do_nothing(InOut(A)) - t_B = Dagger.@spawn do_nothing(InOut(B)) - t_ul = Dagger.@spawn do_nothing(InOut(B_ul)) - t_ur = Dagger.@spawn do_nothing(InOut(B_ur)) - t_ll = Dagger.@spawn do_nothing(InOut(B_ll)) - t_lr = Dagger.@spawn do_nothing(InOut(B_lr)) - t_mid = Dagger.@spawn do_nothing(InOut(B_mid)) - t_ul2 = Dagger.@spawn do_nothing(InOut(B_ul)) - t_ur2 = Dagger.@spawn do_nothing(InOut(B_ur)) - t_ll2 = Dagger.@spawn do_nothing(InOut(B_ll)) - t_lr2 = Dagger.@spawn do_nothing(InOut(B_lr)) + function wrap_chunk_thunk(f, args...) + if args_thunks || args_chunks + result = Dagger.@spawn scope=Dagger.scope(worker=args_loc) f(args...) + if args_thunks + return result + elseif args_chunks + return fetch(result; raw=true) end + else + # N.B. We don't allocate remotely for raw data + return f(args...) 
end - tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid = - task_id.([t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid]) - tid_ul2, tid_ur2, tid_ll2, tid_lr2 = - task_id.([t_ul2, t_ur2, t_ll2, t_lr2]) - tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid, - tid_ul2, tid_ur2, tid_ll2, tid_lr2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - test_task_dominators(logs, tid_ul, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ur, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ll, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_lr, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_mid, [tid_B, tid_ul, tid_ur, tid_ll, tid_lr]; all_tids=tids_all) - test_task_dominators(logs, tid_ul2, [tid_B, tid_mid, tid_ul]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ur2, [tid_B, tid_mid, tid_ur]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ll2, [tid_B, tid_mid, tid_ll]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lr2, [tid_B, tid_mid, tid_lr]; all_tids=tids_all, nondom_check=false) - - # (Unit)Upper/LowerTriangular and Diagonal - B_upper = wrap_chunk_thunk(UpperTriangular, B) - B_unitupper = wrap_chunk_thunk(UnitUpperTriangular, B) - B_lower = wrap_chunk_thunk(LowerTriangular, B) - B_unitlower = wrap_chunk_thunk(UnitLowerTriangular, B) - for (B_name, B_view) in ( - (:B_upper, B_upper), - (:B_unitupper, B_unitupper), - (:B_lower, B_lower), - (:B_unitlower, B_unitlower)) - @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) - end - @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_lower)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B_unitlower)) - @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_unitupper)) - @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B_unitlower)) - - 
@test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B, Diagonal)) - @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B, Diagonal)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B, Diagonal)) - @test !Dagger.will_alias(Dagger.aliasing(B_unitlower), Dagger.aliasing(B, Diagonal)) - - local t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag - local t_upper2, t_unitupper2, t_lower2, t_unitlower2 - logs = with_logs() do - Dagger.spawn_datadeps(;aliasing) do - t_A = Dagger.@spawn do_nothing(InOut(A)) - t_B = Dagger.@spawn do_nothing(InOut(B)) - t_upper = Dagger.@spawn do_nothing(InOut(B_upper)) - t_unitupper = Dagger.@spawn do_nothing(InOut(B_unitupper)) - t_lower = Dagger.@spawn do_nothing(InOut(B_lower)) - t_unitlower = Dagger.@spawn do_nothing(InOut(B_unitlower)) - t_diag = Dagger.@spawn do_nothing(Deps(B, InOut(Diagonal))) - t_unitlower2 = Dagger.@spawn do_nothing(InOut(B_unitlower)) - t_lower2 = Dagger.@spawn do_nothing(InOut(B_lower)) - t_unitupper2 = Dagger.@spawn do_nothing(InOut(B_unitupper)) - t_upper2 = Dagger.@spawn do_nothing(InOut(B_upper)) - end + end + B = wrap_chunk_thunk(rand, 4, 4) + + # Views + B_ul = wrap_chunk_thunk(view, B, 1:2, 1:2) + B_ur = wrap_chunk_thunk(view, B, 1:2, 3:4) + B_ll = wrap_chunk_thunk(view, B, 3:4, 1:2) + B_lr = wrap_chunk_thunk(view, B, 3:4, 3:4) + B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) + for (B_name, B_view) in ( + (:B_ul, B_ul), + (:B_ur, B_ur), + (:B_ll, B_ll), + (:B_lr, B_lr), + (:B_mid, B_mid)) + @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) + B_view === B_mid && continue + @test Dagger.will_alias(Dagger.aliasing(B_mid), Dagger.aliasing(B_view)) + end + local t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid + local t_ul2, t_ur2, t_ll2, t_lr2 + logs = with_logs() do + Dagger.spawn_datadeps() do + t_A = Dagger.@spawn do_nothing(InOut(A)) + t_B = Dagger.@spawn do_nothing(InOut(B)) + t_ul = Dagger.@spawn do_nothing(InOut(B_ul)) + t_ur = 
Dagger.@spawn do_nothing(InOut(B_ur)) + t_ll = Dagger.@spawn do_nothing(InOut(B_ll)) + t_lr = Dagger.@spawn do_nothing(InOut(B_lr)) + t_mid = Dagger.@spawn do_nothing(InOut(B_mid)) + t_ul2 = Dagger.@spawn do_nothing(InOut(B_ul)) + t_ur2 = Dagger.@spawn do_nothing(InOut(B_ur)) + t_ll2 = Dagger.@spawn do_nothing(InOut(B_ll)) + t_lr2 = Dagger.@spawn do_nothing(InOut(B_lr)) end - tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag = - task_id.([t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag]) - tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2 = - task_id.([t_upper2, t_unitupper2, t_lower2, t_unitlower2]) - tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag, - tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - # FIXME: Proper non-dominance checks - test_task_dominators(logs, tid_upper, [tid_B]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower, [tid_B, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_diag, [tid_B, tid_upper, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower2, [tid_B, tid_lower, tid_unitlower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) - - # Additional 
aliasing tests - views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) - - A = wrap_chunk_thunk(identity, B) - - A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4) - A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4) - B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4) - B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4) - - A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1) - A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2) - B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1) - B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2) - - A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3) - B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) - - @test views_overlap(A_r1, A_r1) - @test views_overlap(B_r1, B_r1) - @test views_overlap(A_c1, A_c1) - @test views_overlap(B_c1, B_c1) - - @test views_overlap(A_r1, B_r1) - @test views_overlap(A_r2, B_r2) - @test views_overlap(A_c1, B_c1) - @test views_overlap(A_c2, B_c2) - - @test !views_overlap(A_r1, A_r2) - @test !views_overlap(B_r1, B_r2) - @test !views_overlap(A_c1, A_c2) - @test !views_overlap(B_c1, B_c2) - - @test views_overlap(A_r1, A_c1) - @test views_overlap(A_r1, B_c1) - @test views_overlap(A_r2, A_c2) - @test views_overlap(A_r2, B_c2) - - for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid)) - @test !views_overlap(A_r1, mid) - @test !views_overlap(B_r1, mid) - @test !views_overlap(A_c1, mid) - @test !views_overlap(B_c1, mid) - - @test views_overlap(A_r2, mid) - @test views_overlap(B_r2, mid) - @test views_overlap(A_c2, mid) - @test views_overlap(B_c2, mid) + end + tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid = + task_id.([t_A, t_B, t_ul, t_ur, t_ll, t_lr, t_mid]) + tid_ul2, tid_ur2, tid_ll2, tid_lr2 = + task_id.([t_ul2, t_ur2, t_ll2, t_lr2]) + tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid, + tid_ul2, tid_ur2, tid_ll2, tid_lr2] + test_task_dominators(logs, tid_A, []; all_tids=tids_all) + test_task_dominators(logs, tid_B, []; all_tids=tids_all) + test_task_dominators(logs, tid_ul, [tid_B]; all_tids=tids_all) + test_task_dominators(logs, 
tid_ur, [tid_B]; all_tids=tids_all) + test_task_dominators(logs, tid_ll, [tid_B]; all_tids=tids_all) + test_task_dominators(logs, tid_lr, [tid_B]; all_tids=tids_all) + test_task_dominators(logs, tid_mid, [tid_B, tid_ul, tid_ur, tid_ll, tid_lr]; all_tids=tids_all) + test_task_dominators(logs, tid_ul2, [tid_B, tid_mid, tid_ul]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_ur2, [tid_B, tid_mid, tid_ur]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_ll2, [tid_B, tid_mid, tid_ll]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_lr2, [tid_B, tid_mid, tid_lr]; all_tids=tids_all, nondom_check=false) + + # (Unit)Upper/LowerTriangular and Diagonal + B_upper = wrap_chunk_thunk(UpperTriangular, B) + B_unitupper = wrap_chunk_thunk(UnitUpperTriangular, B) + B_lower = wrap_chunk_thunk(LowerTriangular, B) + B_unitlower = wrap_chunk_thunk(UnitLowerTriangular, B) + for (B_name, B_view) in ( + (:B_upper, B_upper), + (:B_unitupper, B_unitupper), + (:B_lower, B_lower), + (:B_unitlower, B_unitlower)) + @test Dagger.will_alias(Dagger.aliasing(B), Dagger.aliasing(B_view)) + end + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_lower)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B_unitlower)) + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B_unitupper)) + @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B_unitlower)) + + @test Dagger.will_alias(Dagger.aliasing(B_upper), Dagger.aliasing(B, Diagonal)) + @test Dagger.will_alias(Dagger.aliasing(B_lower), Dagger.aliasing(B, Diagonal)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitupper), Dagger.aliasing(B, Diagonal)) + @test !Dagger.will_alias(Dagger.aliasing(B_unitlower), Dagger.aliasing(B, Diagonal)) + + local t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag + local t_upper2, t_unitupper2, t_lower2, t_unitlower2 + logs = with_logs() do + Dagger.spawn_datadeps() do + 
t_A = Dagger.@spawn do_nothing(InOut(A)) + t_B = Dagger.@spawn do_nothing(InOut(B)) + t_upper = Dagger.@spawn do_nothing(InOut(B_upper)) + t_unitupper = Dagger.@spawn do_nothing(InOut(B_unitupper)) + t_lower = Dagger.@spawn do_nothing(InOut(B_lower)) + t_unitlower = Dagger.@spawn do_nothing(InOut(B_unitlower)) + t_diag = Dagger.@spawn do_nothing(Deps(B, InOut(Diagonal))) + t_unitlower2 = Dagger.@spawn do_nothing(InOut(B_unitlower)) + t_lower2 = Dagger.@spawn do_nothing(InOut(B_lower)) + t_unitupper2 = Dagger.@spawn do_nothing(InOut(B_unitupper)) + t_upper2 = Dagger.@spawn do_nothing(InOut(B_upper)) end + end + tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag = + task_id.([t_A, t_B, t_upper, t_unitupper, t_lower, t_unitlower, t_diag]) + tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2 = + task_id.([t_upper2, t_unitupper2, t_lower2, t_unitlower2]) + tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag, + tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2] + test_task_dominators(logs, tid_A, []; all_tids=tids_all) + test_task_dominators(logs, tid_B, []; all_tids=tids_all) + # FIXME: Proper non-dominance checks + test_task_dominators(logs, tid_upper, [tid_B]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_unitupper, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_lower, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_unitlower, [tid_B, tid_lower]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_diag, [tid_B, tid_upper, tid_lower]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_unitlower2, [tid_B, tid_lower, tid_unitlower]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, 
tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false) + test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) + + # Additional aliasing tests + views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) + + A = wrap_chunk_thunk(identity, B) + + A_r1 = wrap_chunk_thunk(view, A, 1:1, 1:4) + A_r2 = wrap_chunk_thunk(view, A, 2:2, 1:4) + B_r1 = wrap_chunk_thunk(view, B, 1:1, 1:4) + B_r2 = wrap_chunk_thunk(view, B, 2:2, 1:4) + + A_c1 = wrap_chunk_thunk(view, A, 1:4, 1:1) + A_c2 = wrap_chunk_thunk(view, A, 1:4, 2:2) + B_c1 = wrap_chunk_thunk(view, B, 1:4, 1:1) + B_c2 = wrap_chunk_thunk(view, B, 1:4, 2:2) + + A_mid = wrap_chunk_thunk(view, A, 2:3, 2:3) + B_mid = wrap_chunk_thunk(view, B, 2:3, 2:3) + + @test views_overlap(A_r1, A_r1) + @test views_overlap(B_r1, B_r1) + @test views_overlap(A_c1, A_c1) + @test views_overlap(B_c1, B_c1) + + @test views_overlap(A_r1, B_r1) + @test views_overlap(A_r2, B_r2) + @test views_overlap(A_c1, B_c1) + @test views_overlap(A_c2, B_c2) + + @test !views_overlap(A_r1, A_r2) + @test !views_overlap(B_r1, B_r2) + @test !views_overlap(A_c1, A_c2) + @test !views_overlap(B_c1, B_c2) + + @test views_overlap(A_r1, A_c1) + @test views_overlap(A_r1, B_c1) + @test views_overlap(A_r2, A_c2) + @test views_overlap(A_r2, B_c2) + + for (name, mid) in ((:A_mid, A_mid), (:B_mid, B_mid)) + @test !views_overlap(A_r1, mid) + @test !views_overlap(B_r1, mid) + @test !views_overlap(A_c1, mid) + @test !views_overlap(B_c1, mid) + + @test views_overlap(A_r2, mid) + @test views_overlap(B_r2, mid) + @test views_overlap(A_c2, mid) + @test views_overlap(B_c2, mid) + end - @test views_overlap(A_mid, A_mid) - @test views_overlap(A_mid, B_mid) + @test views_overlap(A_mid, A_mid) + @test views_overlap(A_mid, B_mid) - # SubArray hashing - V = zeros(3) - Dagger.spawn_datadeps(;aliasing) do - Dagger.@spawn mut_V!(InOut(view(V, 1:2))) - 
Dagger.@spawn mut_V!(InOut(view(V, 2:3))) - end - @test fetch(V) == [1, 1, 1] + # SubArray hashing + V = zeros(3) + Dagger.spawn_datadeps() do + Dagger.@spawn mut_V!(InOut(view(V, 1:2))) + Dagger.@spawn mut_V!(InOut(view(V, 2:3))) end + @test fetch(V) == [1, 1, 1] # FIXME: Deps # Outer Scope - exec_procs = fetch.(Dagger.spawn_datadeps(;aliasing) do + exec_procs = fetch.(Dagger.spawn_datadeps() do [Dagger.@spawn Dagger.task_processor() for i in 1:10] end) unique!(exec_procs) @@ -499,7 +496,7 @@ function test_datadeps(;args_chunks::Bool, end # Inner Scope - @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps(;aliasing) do + @test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps() do Dagger.@spawn scope=Dagger.ExactScope(Dagger.ThreadProc(1, 5000)) 1+1 end @@ -528,7 +525,7 @@ function test_datadeps(;args_chunks::Bool, C = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(C) D = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(D) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do Dagger.@spawn add!(InOut(B), In(A)) Dagger.@spawn add!(InOut(C), In(A)) Dagger.@spawn add!(InOut(C), In(B)) @@ -545,7 +542,7 @@ function test_datadeps(;args_chunks::Bool, elseif args_thunks As = map(A->(Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(A)), As) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do to_reduce = Vector[] push!(to_reduce, As) while !isempty(to_reduce) @@ -576,7 +573,7 @@ function test_datadeps(;args_chunks::Bool, elseif args_thunks M = map(m->(Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(m)), M) end - Dagger.spawn_datadeps(;aliasing) do + Dagger.spawn_datadeps() do for k in range(1, mt) Dagger.@spawn LAPACK.potrf!('L', InOut(M[k, k])) for _m in range(k+1, mt) @@ -596,18 +593,16 @@ function test_datadeps(;args_chunks::Bool, @test isapprox(M_dense, expected) end -@testset "$(aliasing ? 
"With" : "Without") Aliasing Support" for aliasing in (true, false) - @testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) - args_chunks = args_mode == :Chunk - args_thunks = args_mode == :Thunk - for nw in (1, 2) - args_loc = nw == 2 ? 2 : 1 - for nt in (1, 2) - if nprocs() >= nw && Threads.nthreads() >= nt - @testset "$nw Workers, $nt Threads" begin - Dagger.with_options(;scope=Dagger.scope(workers=1:nw, threads=1:nt)) do - test_datadeps(;args_chunks, args_thunks, args_loc, aliasing) - end +@testset @testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) + args_chunks = args_mode == :Chunk + args_thunks = args_mode == :Thunk + for nw in (1, 2) + args_loc = nw == 2 ? 2 : 1 + for nt in (1, 2) + if nprocs() >= nw && Threads.nthreads() >= nt + @testset "$nw Workers, $nt Threads" begin + Dagger.with_options(;scope=Dagger.scope(workers=1:nw, threads=1:nt)) do + test_datadeps(;args_chunks, args_thunks, args_loc) end end end From bab76c36084c0a851d23eac10dba53f04bff69a4 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 15 Nov 2025 10:47:40 -0700 Subject: [PATCH 07/24] datadeps: ainfo_arg must track ainfo -> multiple arg_w --- src/datadeps/aliasing.jl | 22 +++++++++++----------- src/datadeps/remainders.jl | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 44401c38f..f3c8ef62e 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -275,7 +275,7 @@ struct DataDepsState # The mapping of ainfo to argument and dep_mod # Used to lookup which argument and dep_mod a given ainfo is generated from # N.B. 
This is a mapping for remote argument copies - ainfo_arg::Dict{AliasingWrapper,ArgumentWrapper} + ainfo_arg::Dict{AliasingWrapper,Set{ArgumentWrapper}} # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to # Updated when a new write happens on an overlapping ainfo @@ -327,7 +327,7 @@ struct DataDepsState remote_args = Dict{MemorySpace,IdDict{Any,Any}}() remote_arg_to_original = IdDict{Any,Any}() remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() - ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() + ainfo_arg = Dict{AliasingWrapper,Set{ArgumentWrapper}}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() @@ -470,10 +470,9 @@ function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::Argum # Update the mapping of ainfo to argument and dep_mod if !haskey(state.ainfo_arg, ainfo) - state.ainfo_arg[ainfo] = remote_arg_w - else - @assert state.ainfo_arg[ainfo] == remote_arg_w + state.ainfo_arg[ainfo] = Set{ArgumentWrapper}([remote_arg_w]) end + push!(state.ainfo_arg[ainfo], remote_arg_w) # Populate info for the new ainfo populate_ainfo!(state, arg_w, ainfo, target_space) @@ -495,12 +494,13 @@ function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, push!(state.ainfos_overlaps[other_ainfo], target_ainfo) # Add overlapping history to our own - other_remote_arg_w = state.ainfo_arg[other_ainfo] - other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] - other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) - push!(state.arg_overlaps[original_arg_w], other_arg_w) - push!(state.arg_overlaps[other_arg_w], original_arg_w) - merge_history!(state, original_arg_w, other_arg_w) + for other_remote_arg_w in state.ainfo_arg[other_ainfo] + other_arg = 
state.remote_arg_to_original[other_remote_arg_w.arg] + other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) + push!(state.arg_overlaps[original_arg_w], other_arg_w) + push!(state.arg_overlaps[other_arg_w], original_arg_w) + merge_history!(state, original_arg_w, other_arg_w) + end end state.ainfos_overlaps[target_ainfo] = overlaps diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index b420ca5d3..11e28d6dc 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -190,7 +190,7 @@ function compute_remainder_for_arg!(state::DataDepsState, end # Lookup all memory spans for arg_w in these spaces - other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) other_ainfos = Vector{Vector{LocalMemorySpan}}() for space in spaces From 2ef6c2613063032cac8c14a6e09bc5403272cebc Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 15 Nov 2025 10:49:06 -0700 Subject: [PATCH 08/24] datadeps: Fix broken ChunkView unwrapping --- src/datadeps/aliasing.jl | 18 +++++------------- src/datadeps/chunkview.jl | 29 +++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index f3c8ef62e..2efe1855b 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -645,9 +645,6 @@ isremotehandle(x) = false isremotehandle(x::DTask) = true isremotehandle(x::Chunk) = true function generate_slot!(state::DataDepsState, dest_space, data) - if data isa DTask - data = fetch(data; raw=true) - end # N.B. We do not perform any sync/copy with the current owner of the data, # because all we want here is to make a copy of some version of the data, # even if the data is not up to date. 
@@ -656,16 +653,11 @@ function generate_slot!(state::DataDepsState, dest_space, data) from_proc = first(processors(orig_space)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) - if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) - # Fast path for local data that's already in a Chunk or not a remote handle needing rewrapping - data_chunk = tochunk(data, from_proc) - else - ctx = Sch.eager_context() - id = rand(Int) - @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(from_proc, to_proc, data) - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) - end + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = move_rewrap(from_proc, to_proc, orig_space, dest_space, data) + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 04b581c17..e6e1d4840 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -27,38 +27,51 @@ end Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) 
-aliasing(x::ChunkView) = - throw(ConcurrencyViolationError("Cannot query aliasing of a ChunkView directly")) +function aliasing(x::ChunkView{N}) where N + return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices + x = unwrap(x) + v = view(x, slices...) + return aliasing(v) + end +end memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true # This definition is here because it's so similar to ChunkView -function move_rewrap(from_proc::Processor, to_proc::Processor, v::SubArray) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) to_w = root_worker_id(to_proc) p_chunk = aliased_object!(parent(v)) do p - return remotecall_fetch(to_w, from_proc, to_proc, p) do from_proc, to_proc, p + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p return tochunk(move(from_proc, to_proc, p), to_proc) end end inds = parentindices(v) - return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, inds) do from_proc, to_proc, p_chunk, inds + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) 
return tochunk(v_new, to_proc) end end -function move_rewrap(from_proc::Processor, to_proc::Processor, slice::ChunkView) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) to_w = root_worker_id(to_proc) p_chunk = aliased_object!(slice.chunk) do p_chunk - return remotecall_fetch(to_w, from_proc, to_proc, p_chunk) do from_proc, to_proc, p_chunk + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk return tochunk(move(from_proc, to_proc, p_chunk), to_proc) end end - return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, slice.slices) do from_proc, to_proc, p_chunk, inds + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) return tochunk(v_new, to_proc) end end +function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, slice.chunk, slice.slices) do from_proc, to_proc, chunk, slices + chunk_new = move(from_proc, to_proc, chunk) + v_new = view(chunk_new, slices...) + return tochunk(v_new, to_proc) + end +end Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) 
\ No newline at end of file From 13e4945bbeaaa2c457b03021131883b1aebc49dd Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 15 Nov 2025 10:49:48 -0700 Subject: [PATCH 09/24] datadeps: Signature fixups and small cleanups --- src/datadeps/aliasing.jl | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 2efe1855b..783ca9862 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -513,7 +513,6 @@ function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_ history = state.arg_history[arg_w] @opcounter :merge_history @opcounter :merge_history_complexity length(history) - largest_value_update!(length(history)) origin_space = state.arg_origin[other_arg_w.arg] for other_entry in state.arg_history[other_arg_w] write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num) @@ -676,29 +675,25 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function move_rewrap(from_proc::Processor, to_proc::Processor, data) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) return aliased_object!(data) do data return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) end end -function remotecall_endpoint(f, from_proc, to_proc, orig_space, dest_space, data) +function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) to_w = root_worker_id(to_proc) if to_w == myid() data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc, dest_space) + return tochunk(data_converted, to_proc) end - return remotecall_fetch(to_w, from_proc, to_proc, dest_space, data) do from_proc, to_proc, dest_space, data + return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data data_converted = 
f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc, dest_space) + return tochunk(data_converted, to_proc) end end const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) @warn "Document these public methods" maxlog=1 # TODO: Use state to cache aliasing() results -function declare_aliased_object!(x; ainfo=aliasing(x, identity)) - cache = ALIASED_OBJECT_CACHE[] - cache[ainfo] = x -end function aliased_object!(x; ainfo=aliasing(x, identity)) cache = ALIASED_OBJECT_CACHE[] if haskey(cache, ainfo) @@ -721,11 +716,6 @@ function aliased_object!(f, x; ainfo=aliasing(x, identity)) end return y end -function aliased_object_unwrap!(x::Chunk) - y = unwrap(x) - ainfo = aliasing(y, identity) - return unwrap(aliased_object!(x; ainfo)) -end struct DataDepsSchedulerState task_to_spec::Dict{DTask,DTaskSpec} From 293f33307b36d2f2a68034c2772a00f58b42c3b6 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 Dec 2025 16:22:21 -0700 Subject: [PATCH 10/24] datadeps: Fix aliased object detection around Chunks --- src/datadeps/aliasing.jl | 74 ++++++++++++++++++++++++++------------- src/datadeps/chunkview.jl | 53 +++++++++++++++++++++++++--- src/utils/chunks.jl | 3 ++ 3 files changed, 102 insertions(+), 28 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 783ca9862..82989a259 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -651,18 +651,20 @@ function generate_slot!(state::DataDepsState, dest_space, data) to_proc = first(processors(dest_space)) from_proc = first(processors(orig_space)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) + if !haskey(state.ainfo_backing_chunk, dest_space) + state.ainfo_backing_chunk[dest_space] = Dict{AbstractAliasing,Chunk}() + end + # FIXME: tochunk the cache just once per space + aliased_object_cache = 
AliasedObjectCache(tochunk(state.ainfo_backing_chunk[dest_space])) ctx = Sch.eager_context() id = rand(Int) @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(from_proc, to_proc, orig_space, dest_space, data) + data_chunk = move_rewrap(aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data - ALIASED_OBJECT_CACHE[] = nothing - return dest_space_args[data] end function get_or_generate_slot!(state, dest_space, data) @@ -675,8 +677,47 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) - return aliased_object!(data) do data +struct AliasedObjectCache + chunk::Chunk +end +@warn "Document these public methods" maxlog=1 +function Base.haskey(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(haskey, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + return haskey(cache_raw, ainfo) +end +function Base.getindex(cache::AliasedObjectCache, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(getindex, wid, cache, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + return getindex(cache_raw, ainfo) +end +function Base.setindex!(cache::AliasedObjectCache, 
value::Chunk, ainfo::AbstractAliasing) + wid = root_worker_id(cache.chunk) + if wid != myid() + return remotecall_fetch(setindex!, wid, cache, value, ainfo) + end + cache_raw = unwrap(cache.chunk)::Dict{AbstractAliasing,Chunk} + cache_raw[ainfo] = value + return +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) + # Unwrap so that we hit the right dispatch + wid = root_worker_id(data) + if wid != myid() + return remotecall_fetch(move_rewrap, wid, cache, from_proc, to_proc, from_space, to_space, data) + end + data_raw = unwrap(data) + return move_rewrap(cache, from_proc, to_proc, from_space, to_space, data_raw) +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + # For generic data + return aliased_object!(cache, data) do data return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) end end @@ -691,30 +732,15 @@ function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) return tochunk(data_converted, to_proc) end end -const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) -@warn "Document these public methods" maxlog=1 -# TODO: Use state to cache aliasing() results -function aliased_object!(x; ainfo=aliasing(x, identity)) - cache = ALIASED_OBJECT_CACHE[] - if haskey(cache, ainfo) - y = cache[ainfo] - else - @assert x isa Chunk "x must be a Chunk\nUse functor form of aliased_object!" 
- cache[ainfo] = x - y = x - end - return y -end -function aliased_object!(f, x; ainfo=aliasing(x, identity)) - cache = ALIASED_OBJECT_CACHE[] +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) if haskey(cache, ainfo) - y = cache[ainfo] + return cache[ainfo] else y = f(x) @assert y isa Chunk "Didn't get a Chunk from functor" cache[ainfo] = y + return y end - return y end struct DataDepsSchedulerState diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index e6e1d4840..60ded6151 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -38,9 +38,9 @@ memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true # This definition is here because it's so similar to ChunkView -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(parent(v)) do p + p_chunk = aliased_object!(cache, parent(v)) do p return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p return tochunk(move(from_proc, to_proc, p), to_proc) end @@ -52,9 +52,54 @@ function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::Memor return tochunk(v_new, to_proc) end end -function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) +# FIXME: Do this programmatically via recursive dispatch +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) + to_w = root_worker_id(to_proc) + p_chunk = 
aliased_object!(cache, parent(v)) do p + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p + return tochunk(move(from_proc, to_proc, p), to_proc) + end + end + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = $(wrapper)(p_new) + return tochunk(v_new, to_proc) + end + end +end +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) + to_w = root_worker_id(to_proc) + return aliased_object!(cache, v[]) do p + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p + return tochunk(Ref(move(from_proc, to_proc, p)), to_proc) + end + end +end +#= +function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T + if isstructtype(T) + # Check all object fields (recursive) + for field in fieldnames(T) + value = getfield(x, field) + new_value = aliased_object!(cache, value) do value + return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) + end + setfield!(x, field, new_value) + end + return x + else + @warn "Cannot move-rewrap object of type $T" + return x + end +end +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x +=# +function move_rewrap(cache::AliasedObjectCache, 
from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(slice.chunk) do p_chunk + p_chunk = aliased_object!(cache, slice.chunk) do p_chunk return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk return tochunk(move(from_proc, to_proc, p_chunk), to_proc) end diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 400b49332..9f0c3b487 100644 --- a/src/utils/chunks.jl +++ b/src/utils/chunks.jl @@ -174,6 +174,9 @@ function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) end tochunk(x::Thunk, proc=nothing, scope=nothing; kwargs...) = x +root_worker_id(chunk::Chunk) = root_worker_id(chunk.handle) +root_worker_id(dref::DRef) = dref.owner # FIXME: Migration + function savechunk(data, dir, f) sz = open(joinpath(dir, f), "w") do io serialize(io, MemPool.MMWrap(data)) From 87cdbe9f54106062f44f8edbf87f636d2f16df40 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 Dec 2025 16:27:03 -0700 Subject: [PATCH 11/24] datadeps: Validate ManyMemorySpan inner span lengths --- src/datadeps/remainders.jl | 7 +++++- src/utils/interval_tree.jl | 18 ++++++++++++---- src/utils/memory-span.jl | 44 +++++++++++++++++++++++++++++++++++++- 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 11e28d6dc..c7826fb90 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -163,6 +163,7 @@ function compute_remainder_for_arg!(state::DataDepsState, end # Create our remainder as an interval tree over all target ainfos + VERIFY_SPAN_CURRENT_OBJECT[] = arg_w.arg remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) # Create our tracker @@ -200,6 +201,9 @@ function compute_remainder_for_arg!(state::DataDepsState, end nspans = 
length(first(other_ainfos)) other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] + foreach(other_many_spans) do span + verify_span(span) + end if other_space == target_space # Only subtract, this data is already up-to-date in target_space @@ -223,6 +227,7 @@ function compute_remainder_for_arg!(state::DataDepsState, get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) end end + VERIFY_SPAN_CURRENT_OBJECT[] = nothing if isempty(tracker) return NoAliasing(), 0 @@ -253,10 +258,10 @@ copy from `other_many_spans` to the subtraced portion of `remainder`. function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N diff = Vector{ManyMemorySpan{N}}() subtract_spans!(remainder, other_many_spans, diff) - for span in diff source_span = span.spans[source_space_idx] dest_span = span.spans[dest_space_idx] + @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))" push!(tracker, (source_span, dest_span)) end end diff --git a/src/utils/interval_tree.jl b/src/utils/interval_tree.jl index e67f66b24..5b79d456d 100644 --- a/src/utils/interval_tree.jl +++ b/src/utils/interval_tree.jl @@ -25,6 +25,7 @@ function IntervalTree{M}(spans) where M for span in spans insert!(tree, span) end + verify_spans(tree) return tree end IntervalTree(spans::Vector{M}) where M = IntervalTree{M}(spans) @@ -44,6 +45,13 @@ function Base.collect(tree::IntervalTree{M}) where M return result end +# Useful for debugging when spans get misaligned +function verify_spans(tree::IntervalTree{ManyMemorySpan{N}}) where N + for span in tree + verify_span(span) + end +end + function Base.iterate(tree::IntervalTree{M}) where M state = Vector{M}() if tree.root === nothing @@ -196,6 +204,7 @@ function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} 
original_end = span_end(original_span) del_start = span_start(span) del_end = span_end(span) + verify_span(span) # Left portion: exists if original starts before deleted span if original_start < del_start @@ -258,10 +267,11 @@ function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; if spans_overlap(node.span, query) if exact # Get the overlapping portion of the span - overlap_start = max(span_start(node.span), span_start(query)) - overlap_end = min(span_end(node.span), span_end(query)) - overlap = M(overlap_start, overlap_end - overlap_start) - push!(result, overlap) + overlap = span_diff(node.span, query) + verify_span(overlap) + if !isempty(overlap) + push!(result, overlap) + end else push!(result, node.span) end diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl index 91f291cbe..c00d16c36 100644 --- a/src/utils/memory-span.jl +++ b/src/utils/memory-span.jl @@ -38,6 +38,18 @@ span_len(span::MemorySpan) = span.len span_end(span::MemorySpan) = span.ptr.addr + span.len spans_overlap(span1::MemorySpan, span2::MemorySpan) = span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) +function span_diff(span1::MemorySpan, span2::MemorySpan) + @assert span1.ptr.space == span2.ptr.space + start = max(span_start(span1), span_start(span2)) + stop = min(span_end(span1), span_end(span2)) + start_ptr = RemotePtr(start, span1.ptr.space) + if start < stop + len = stop - start + return MemorySpan(start_ptr, len) + else + return MemorySpan(start_ptr, 0) + end +end ### More space-efficient memory spans @@ -52,6 +64,16 @@ span_len(span::LocalMemorySpan) = span.len span_end(span::LocalMemorySpan) = span.ptr + span.len spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) = span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) +function span_diff(span1::LocalMemorySpan, span2::LocalMemorySpan) + start = max(span_start(span1), span_start(span2)) + stop = min(span_end(span1), span_end(span2)) + if start 
< stop + len = stop - start + return LocalMemorySpan(start, len) + else + return LocalMemorySpan(start, 0) + end +end # FIXME: Store the length separately, since it's shared by all spans struct ManyMemorySpan{N} @@ -64,6 +86,17 @@ span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.s spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N = # N.B. The spans are assumed to be the same length and relative offset spans_overlap(span1.spans[1], span2.spans[1]) +function span_diff(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N + verify_span(span1) + verify_span(span2) + span = ManyMemorySpan(ntuple(i -> span_diff(span1.spans[i], span2.spans[i]), N)) + verify_span(span) + return span +end +const VERIFY_SPAN_CURRENT_OBJECT = TaskLocalValue{Any}(()->nothing) +function verify_span(span::ManyMemorySpan{N}) where N + @assert allequal(span_len, span.spans) "All spans must be the same: $(map(span_len, span.spans))\nWhile processing $(typeof(VERIFY_SPAN_CURRENT_OBJECT[]))" +end struct ManyPair{N} <: Unsigned pairs::NTuple{N,UInt} @@ -78,6 +111,7 @@ Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] Base.string(x::ManyPair) = "ManyPair($(x.pairs))" +Base.show(io::IO, x::ManyPair) = print(io, string(x)) ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) @@ -95,4 +129,12 @@ span_start(x::LocatorMemorySpan) = span_start(x.span) span_end(x::LocatorMemorySpan) = span_end(x.span) span_len(x::LocatorMemorySpan) = span_len(x.span) spans_overlap(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T = - spans_overlap(span1.span, span2.span) \ No newline at end of file + spans_overlap(span1.span, span2.span) +function span_diff(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T + 
span = LocatorMemorySpan(span_diff(span1.span, span2.span), 0) + verify_span(span) + return span +end +function verify_span(span::LocatorMemorySpan{T}) where T + verify_span(span.span) +end \ No newline at end of file From d4d93305835c7d22e405d8aea76eca0bb54ab151 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 Dec 2025 16:28:00 -0700 Subject: [PATCH 12/24] datadeps: Optimize RemainderAliasing move! copies --- src/datadeps/remainders.jl | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index c7826fb90..bc8797c6e 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -417,28 +417,33 @@ end # Main copy function for RemainderAliasing function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S - # Get the source data for each span + # Copy the data from the source object copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod - copies = Vector{UInt8}[] - for (from_span, _) in dep_mod.spans - copy = Vector{UInt8}(undef, from_span.len) - GC.@preserve copy begin + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + offset = 1 + GC.@preserve copies begin + for (from_span, _) in dep_mod.spans from_ptr = Ptr{UInt8}(from_span.ptr) - to_ptr = Ptr{UInt8}(pointer(copy)) + to_ptr = Ptr{UInt8}(pointer(copies, offset)) unsafe_copyto!(to_ptr, from_ptr, from_span.len) + offset += from_span.len end - push!(copies, copy) end + @assert offset == len+1 return copies end # Copy the data into the destination object - for (copy, (_, to_span)) in zip(copies, dep_mod.spans) - GC.@preserve copy begin - from_ptr = Ptr{UInt8}(pointer(copy)) + offset = 1 + GC.@preserve copies begin + for (_, to_span) in dep_mod.spans + from_ptr = Ptr{UInt8}(pointer(copies, offset)) to_ptr = Ptr{UInt8}(to_span.ptr) unsafe_copyto!(to_ptr, from_ptr, 
to_span.len) + offset += to_span.len end + @assert offset == length(copies)+1 end # Ensure that the data is visible From 9996206cefc67e6b700c5d03ae527bbe0e1352c2 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 9 Dec 2025 16:29:34 -0700 Subject: [PATCH 13/24] datadeps: Overhaul Datadeps tests --- src/datadeps/queue.jl | 7 + src/datadeps/remainders.jl | 16 ++ test/datadeps.jl | 405 +++++++++++++++++++++++++++++-------- 3 files changed, 346 insertions(+), 82 deletions(-) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index f8f907741..8b92f3087 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -206,6 +206,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy_skip, (;id), (;)) + @maybelog ctx timespan_finish(ctx, :datadeps_copy_skip, (;id), (;thunk_id=0, from_space=origin_space, to_space=origin_space, arg_w, from_arg=arg, to_arg=arg)) end end end @@ -520,7 +524,10 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr new_spec.options.scope = our_scope new_spec.options.exec_scope = our_scope new_spec.options.occupancy = Dict(Any=>0) + ctx = Sch.eager_context() + @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) + @maybelog ctx timespan_finish(ctx, :datadeps_execute, (;thunk_id=task.uid), (;space=our_space, deps=task_arg_ws, args=remote_args)) # Update read/write tracking for arguments map_or_ntuple(task_arg_ws) do idx diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index bc8797c6e..1bd563d08 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -308,7 +308,11 @@ function 
enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) @@ -353,7 +357,11 @@ function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySp @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) @@ -382,7 +390,11 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task + ctx = Sch.eager_context() + 
id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) @@ -409,7 +421,11 @@ function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task + ctx = Sch.eager_context() + id = rand(UInt) + @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) diff --git a/test/datadeps.jl b/test/datadeps.jl index 4fb873454..05a3091af 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -1,4 +1,5 @@ -import Dagger: ChunkView, Chunk +import Dagger: ChunkView, Chunk, AbstractAliasing, MemorySpace, ArgumentWrapper +import Dagger: aliasing, memory_space using LinearAlgebra, Graphs @testset "Memory Aliasing" begin @@ -82,7 +83,7 @@ end end function with_logs(f) - Dagger.enable_logging!(;taskdeps=true, taskargs=true) + Dagger.enable_logging!(;taskdeps=true, taskargs=true, timeline=true) try f() return Dagger.fetch_logs!() @@ -108,68 +109,296 @@ 
function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int) end error("Task $tid not found in logs") end -function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=false) - g = SimpleDiGraph() - tid_to_v = Dict{Int,Int}() +function all_tasks_in_logs(logs::Dict) + all_tids = Int[] + for w in keys(logs) + _logs = logs[w] + for idx in 1:length(_logs[:core]) + core_log = _logs[:core][idx] + id_log = _logs[:id][idx] + if core_log.category == :add_thunk && core_log.kind == :finish + tid = id_log.thunk_id::Int + push!(all_tids, tid) + end + end + end + return all_tids +end +mutable struct FlowEntry + kind::Symbol + tid::Int + ainfo::AbstractAliasing + to_ainfo::AbstractAliasing + from_space::MemorySpace + to_space::MemorySpace + read::Bool + write::Bool +end +struct FlowCheck + read::Bool + write::Bool + arg_w::ArgumentWrapper + orig_ainfo::AbstractAliasing + orig_space::MemorySpace + function FlowCheck(kind, arg, dep_mod=identity) + if kind == :read + read = true + write = false + elseif kind == :write + read = false + write = true + elseif kind == :readwrite + read = true + write = true + else + error("Invalid kind: $kind") + end + arg_w = maybe_rewrap_arg_w(ArgumentWrapper(arg, dep_mod)) + return new(read, write, arg_w, aliasing(arg, dep_mod), memory_space(arg)) + end +end +struct FlowGraph + g::SimpleDiGraph + tid_to_v::Dict{Int,Int} + FlowGraph() = new(SimpleDiGraph(), Dict{Int,Int}()) +end +struct FlowState + flows::Dict{ArgumentWrapper,Vector{FlowEntry}} + graph::FlowGraph + FlowState() = new(Dict{ArgumentWrapper,Vector{FlowEntry}}(), FlowGraph()) +end +function maybe_rewrap_arg_w(arg_w::ArgumentWrapper) + arg = arg_w.arg + if arg isa DTask + arg = fetch(arg; raw=true) + end + if arg isa Chunk && Dagger.root_worker_id(arg) == myid() + arg = Dagger.unwrap(arg) + end + return ArgumentWrapper(arg, arg_w.dep_mod) +end +function build_dataflow(logs::Dict; verbose::Bool=false) + state = FlowState() + orig_ainfos = 
Dict{AbstractAliasing,AbstractAliasing}() + ainfo_arg_w = Dict{AbstractAliasing,ArgumentWrapper}() + + function add_execute!(arg_w, orig_ainfo, ainfo, tid, space, read, write) + ainfo_flows = get!(Vector{FlowEntry}, state.flows, arg_w) + # Skip duplicates (same arg 2+ times to same task) + dup_idx = findfirst(flow->flow.tid == tid, ainfo_flows) + if dup_idx === nothing + if !haskey(orig_ainfos, ainfo) + orig_ainfos[ainfo] = orig_ainfo + end + if !haskey(ainfo_arg_w, ainfo) + ainfo_arg_w[ainfo] = arg_w + end + verbose && println("Adding execute flow (tid $tid, space $space, read $read, write $write):\n $orig_ainfo ->\n $ainfo") + verbose && println(" $(arg_w.dep_mod), $(arg_w.arg)") + push!(ainfo_flows, FlowEntry(:execute, tid, ainfo, ainfo, space, space, read, write)) + else + # Union read and write fields + ainfo_flows[dup_idx].read |= read + ainfo_flows[dup_idx].write |= write + end + end + function add_copy!(arg_w, from_arg, to_arg, tid, from_space, to_space) + dep_mod = arg_w.dep_mod + from_ainfo = aliasing(from_arg, dep_mod) + to_ainfo = aliasing(to_arg, dep_mod) + if !haskey(orig_ainfos, from_ainfo) + orig_ainfos[from_ainfo] = from_ainfo + end + if !haskey(ainfo_arg_w, from_ainfo) + ainfo_arg_w[from_ainfo] = arg_w + end + if !haskey(ainfo_arg_w, to_ainfo) + ainfo_arg_w[to_ainfo] = arg_w + end + orig_ainfo = orig_ainfos[from_ainfo] + orig_ainfos[to_ainfo] = orig_ainfo + arg_flows = get!(Vector{FlowEntry}, state.flows, arg_w) + verbose && println("Adding copy flow (tid $tid, from_space $from_space, to_space $to_space):\n $orig_ainfo ->\n $to_ainfo") + verbose && println(" $(arg_w.dep_mod), $(arg_w.arg)") + push!(arg_flows, FlowEntry(:copy, tid, from_ainfo, to_ainfo, from_space, to_space, true, true)) + end + + # Populate graph from syncdeps seen = Set{Int}() - to_visit = copy(all_tids) + to_visit = all_tasks_in_logs(logs) while !isempty(to_visit) this_tid = popfirst!(to_visit) this_tid in seen && continue push!(seen, this_tid) - if !(this_tid in keys(tid_to_v)) 
- add_vertex!(g); tid_to_v[this_tid] = nv(g) + if !(this_tid in keys(state.graph.tid_to_v)) + add_vertex!(state.graph.g); state.graph.tid_to_v[this_tid] = nv(state.graph.g) end # Add syncdeps deps = taskdeps_for_task(logs, this_tid) for dep in deps - if !(dep in keys(tid_to_v)) - add_vertex!(g); tid_to_v[dep] = nv(g) + if !(dep in keys(state.graph.tid_to_v)) + add_vertex!(state.graph.g); state.graph.tid_to_v[dep] = nv(state.graph.g) end - add_edge!(g, tid_to_v[this_tid], tid_to_v[dep]) + add_edge!(state.graph.g, state.graph.tid_to_v[this_tid], state.graph.tid_to_v[dep]) push!(to_visit, dep) end end - state = dijkstra_shortest_paths(g, tid_to_v[tid]) - any_failed = false - @test !has_edge(g, tid_to_v[tid], tid_to_v[tid]) - any_failed |= has_edge(g, tid_to_v[tid], tid_to_v[tid]) - for dom in doms - @test state.pathcounts[tid_to_v[dom]] > 0 - if state.pathcounts[tid_to_v[dom]] == 0 - println("Expected dominance for $dom of $tid") - any_failed = true - end - end - if nondom_check - for nondom in all_tids - nondom == tid && continue - nondom in doms && continue - @test state.pathcounts[tid_to_v[nondom]] == 0 - if state.pathcounts[tid_to_v[nondom]] > 0 - println("Expected non-dominance for $nondom of $tid") - any_failed = true + + # Populate flows and graphs from datadeps logs + for w in keys(logs) + _logs = logs[w] + for idx in 1:length(_logs[:core]) + core_log = _logs[:core][idx] + id_log = _logs[:id][idx] + tl_log = _logs[:timeline][idx] + if core_log.category == :datadeps_execute && core_log.kind == :finish + tid = id_log.thunk_id + for (remote_arg, depset) in zip(tl_log.args, tl_log.deps) + for dep in depset.deps + arg_w = maybe_rewrap_arg_w(dep.arg_w) + orig_ainfo = aliasing(arg_w.arg, arg_w.dep_mod) + remote_ainfo = aliasing(remote_arg, arg_w.dep_mod) + space = memory_space(remote_arg) + add_execute!(arg_w, orig_ainfo, remote_ainfo, tid, space, dep.readdep, dep.writedep) + end + end + elseif (core_log.category == :datadeps_copy || core_log.category == 
:datadeps_copy_skip) && core_log.kind == :finish + tid = tl_log.thunk_id + from_space = tl_log.from_space + to_space = tl_log.to_space + from_arg = tl_log.from_arg + to_arg = tl_log.to_arg + arg_w = maybe_rewrap_arg_w(tl_log.arg_w) + add_copy!(arg_w, from_arg, to_arg, tid, from_space, to_space) + end + end + end + + return state +end +function test_dataflow(state::FlowState, checks...; verbose::Bool=true) + # Check that each ainfo starts and ends in the same space + for arg_w in keys(state.flows) + ainfo = aliasing(arg_w.arg, arg_w.dep_mod) + arg_flows = state.flows[arg_w] + orig_space = memory_space(arg_w.arg) #arg_flows[1].from_space + #=if ainfo != arg_flows[1].ainfo + verbose && println("Ainfo key $(ainfo) is not the same as the first flow's ainfo $(ainfo_flows[1].ainfo)") + return false + end=# + final_space = arg_flows[end].to_space + # FIXME: will_alias doesn't check across spaces + any_writes = any(flows->Dagger.will_alias(flows[1], ainfo) && any(flow->flow.write, flows[2]), state.flows) + if orig_space != final_space + if verbose + println("Arg ($(arg_w.dep_mod), $(arg_w.arg)) starts in $(orig_space) but ends in $(final_space)") + for flow in arg_flows + println(" $(flow.kind) $(flow.tid) $(flow.from_space) -> $(flow.to_space)") + end + end + return false + end + end + + # Check each flow against the previous flow, ensuring that the previous flow is a dominator of the current flow + # FIXME: Validate non-dominance when unnecessary? 
+ for arg_w in keys(state.flows) + arg_flows = state.flows[arg_w] + for (idx, flow) in enumerate(arg_flows) + if idx > 1 + prev_flow = arg_flows[idx-1] + if !prev_flow.write && !flow.write + # R->R don't depend on each other + continue + end + if !prev_flow.write && flow.write && prev_flow.kind == :execute && flow.kind == :copy && prev_flow.ainfo != flow.to_ainfo + # Copy only writes to a different ainfo, so don't depend on each other + continue + end + if flow.tid == 0 + # Ignore copy skip flows + continue + end + v = state.graph.tid_to_v[flow.tid] + prev_v = state.graph.tid_to_v[prev_flow.tid] + path_state = dijkstra_shortest_paths(state.graph.g, v; allpaths=true) + if path_state.pathcounts[prev_v] == 0 + if verbose + println("Flow $(idx-1) (tid $(prev_flow.tid), $(prev_flow.kind), R:$(prev_flow.read), W:$(prev_flow.write)) is not a dominator of flow $(idx) (tid $(flow.tid), $(flow.kind), R:$(flow.read), W:$(flow.write))") + @show length(state.flows[arg_w]) + for flow in state.flows[arg_w] + println(" $(flow.kind) $(flow.tid) $(flow.from_space) -> $(flow.to_space) (R:$(flow.read), W:$(flow.write))") + end + for flow in state.flows[arg_w] + println(" May write to: $(flow.to_ainfo)") + end + e_vs = collect(edges(state.graph.g)) + e_tids = map(e->Edge(only(filter(tv->tv[2]==src(e), state.graph.tid_to_v))[1], + only(filter(tv->tv[2]==dst(e), state.graph.tid_to_v))[1]), + e_vs) + sort!(e_tids) + for e in e_tids + s_tid, d_tid = src(e), dst(e) + println("Edge: $s_tid -(up)> $d_tid") + end + end + return false + end end end end - # For debugging purposes - if any_failed - println("Failure detected!") - println("Root: $tid") - println("Exp. 
doms: $doms") - println("All: $all_tids") - e_vs = collect(edges(g)) - e_tids = map(e->Edge(only(filter(tv->tv[2]==src(e), tid_to_v))[1], - only(filter(tv->tv[2]==dst(e), tid_to_v))[1]), - e_vs) - sort!(e_tids) - for e in e_tids - s_tid, d_tid = src(e), dst(e) - println("Edge: $s_tid -(up)> $d_tid") + # Walk through each check, ensuring that the current state of the flow matches the check + arg_locations = Dict{ArgumentWrapper,MemorySpace}() + flow_idxs = Dict{ArgumentWrapper,Int}(arg_w=>1 for arg_w in keys(state.flows)) + for (idx, check) in enumerate(checks) + # Record the original location of the ainfo + if !haskey(arg_locations, check.arg_w) + arg_locations[check.arg_w] = check.orig_space + end + + # Try to advance a flow + if !haskey(flow_idxs, check.arg_w) + if verbose + @warn "Didn't encounter argument ($(check.arg_w.dep_mod), $(check.arg_w.arg))" + println("Seen arguments:") + for arg_w in keys(state.flows) + println(" ($(arg_w.dep_mod), $(arg_w.arg))") + end + return false + end + end + flow_idx = flow_idxs[check.arg_w] + while true + if flow_idx > length(state.flows[check.arg_w]) + verbose && println("Exhausted all tasks while trying to find $(check.arg_w)") + return false + end + flow = state.flows[check.arg_w][flow_idx] + if flow.kind == :execute + # The current flow state must match the check + if flow.read == check.read && flow.write == check.write + # Match, move on to next check + flow_idx += 1 + break + else + verbose && println("Expected ($(check.read), $(check.write)), got ($(flow.read), $(flow.write))") + return false + end + elseif flow.kind == :copy + # We need to advance our ainfo location + # FIXME: Assert proper data progression (requires more complex tracking of other arguments) + #@assert flow.from_space == arg_locations[check.arg_w] + arg_locations[check.arg_w] = flow.to_space + flow_idx += 1 + end end + + flow_idxs[check.arg_w] = flow_idx end + + return true end @everywhere do_nothing(Xs...) 
= nothing @@ -205,8 +434,11 @@ function test_datadeps(;args_chunks::Bool, A = Dagger.@spawn scope=Dagger.scope(worker=args_loc) copy(A) end + @warn "Negative-test the test_dataflow helper" + # Task return values can be tracked ts = [] + local t1 logs = with_logs() do Dagger.spawn_datadeps() do t1 = Dagger.@spawn fill(42, 1) @@ -216,9 +448,12 @@ function test_datadeps(;args_chunks::Bool, end tid_1, tid_2 = task_id.(ts) @test fetch(A)[1] == 42.0 - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + # FIXME: We don't record the task as a syncdep, but instead internally `fetch` the chunk - test_task_dominators(logs, tid_2, [#=tid_1=#]; all_tids=[tid_1, tid_2]) + # We don't see the :readwrite because we don't see the use of t1 + #@test test_dataflow(state, FlowCheck(:readwrite, t1)) + @test test_dataflow(state, FlowCheck(:read, t1), FlowCheck(:write, A)) # R->R Non-Aliasing ts = [] @@ -229,8 +464,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2], nondom_check=false) + state = build_dataflow(logs) + test_dataflow(state, FlowCheck(:read, A), FlowCheck(:read, A)) # R->W Aliasing ts = [] @@ -241,8 +476,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A)) # W->W Aliasing ts = [] @@ -253,8 +488,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:write, A), FlowCheck(:write, 
A)) # R->R Non-Self-Aliasing ts = [] @@ -265,8 +500,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:read, A)) # R->W Self-Aliasing ts = [] @@ -277,8 +512,11 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A)) + if !test_dataflow(state, FlowCheck(:read, A), FlowCheck(:write, A)) + exit(1) + end # W->W Self-Aliasing ts = [] @@ -289,8 +527,8 @@ function test_datadeps(;args_chunks::Bool, end end tid_1, tid_2 = task_id.(ts) - test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2]) - test_task_dominators(logs, tid_2, [tid_1]; all_tids=[tid_1, tid_2]) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:write, A), FlowCheck(:write, A)) function wrap_chunk_thunk(f, args...) 
if args_thunks || args_chunks @@ -346,17 +584,16 @@ function test_datadeps(;args_chunks::Bool, task_id.([t_ul2, t_ur2, t_ll2, t_lr2]) tids_all = [tid_A, tid_B, tid_ul, tid_ur, tid_ll, tid_lr, tid_mid, tid_ul2, tid_ur2, tid_ll2, tid_lr2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - test_task_dominators(logs, tid_ul, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ur, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_ll, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_lr, [tid_B]; all_tids=tids_all) - test_task_dominators(logs, tid_mid, [tid_B, tid_ul, tid_ur, tid_ll, tid_lr]; all_tids=tids_all) - test_task_dominators(logs, tid_ul2, [tid_B, tid_mid, tid_ul]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ur2, [tid_B, tid_mid, tid_ur]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_ll2, [tid_B, tid_mid, tid_ll]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lr2, [tid_B, tid_mid, tid_lr]; all_tids=tids_all, nondom_check=false) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:readwrite, A)) + @test test_dataflow(state, FlowCheck(:readwrite, B)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ul)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ur)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_ll)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lr)) + for arg in [B_ul, B_ur, B_ll, B_lr] + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, arg), FlowCheck(:readwrite, B_mid), FlowCheck(:readwrite, arg)) + end # (Unit)Upper/LowerTriangular and Diagonal B_upper = wrap_chunk_thunk(UpperTriangular, B) @@ -403,18 +640,22 @@ function test_datadeps(;args_chunks::Bool, task_id.([t_upper2, t_unitupper2, t_lower2, 
t_unitlower2]) tids_all = [tid_A, tid_B, tid_upper, tid_unitupper, tid_lower, tid_unitlower, tid_diag, tid_upper2, tid_unitupper2, tid_lower2, tid_unitlower2] - test_task_dominators(logs, tid_A, []; all_tids=tids_all) - test_task_dominators(logs, tid_B, []; all_tids=tids_all) - # FIXME: Proper non-dominance checks - test_task_dominators(logs, tid_upper, [tid_B]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower, [tid_B, tid_upper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower, [tid_B, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_diag, [tid_B, tid_upper, tid_lower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitlower2, [tid_B, tid_lower, tid_unitlower]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_lower2, [tid_B, tid_lower, tid_unitlower, tid_diag, tid_unitlower2]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_unitupper2, [tid_B, tid_upper, tid_unitupper]; all_tids=tids_all, nondom_check=false) - test_task_dominators(logs, tid_upper2, [tid_B, tid_upper, tid_unitupper, tid_diag, tid_unitupper2]; all_tids=tids_all, nondom_check=false) + state = build_dataflow(logs) + @test test_dataflow(state, FlowCheck(:readwrite, A)) + @test test_dataflow(state, FlowCheck(:readwrite, B)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_lower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_lower), + FlowCheck(:readwrite, B, 
Diagonal)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitlower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_lower), FlowCheck(:readwrite, B_unitlower), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitlower), FlowCheck(:readwrite, B_lower)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper), + FlowCheck(:readwrite, B_unitupper)) + @test test_dataflow(state, FlowCheck(:readwrite, B), FlowCheck(:readwrite, B_upper), FlowCheck(:readwrite, B_unitupper), + FlowCheck(:readwrite, B, Diagonal), FlowCheck(:readwrite, B_unitupper), FlowCheck(:readwrite, B_upper)) # Additional aliasing tests views_overlap(x, y) = Dagger.will_alias(Dagger.aliasing(x), Dagger.aliasing(y)) @@ -593,7 +834,7 @@ function test_datadeps(;args_chunks::Bool, @test isapprox(M_dense, expected) end -@testset @testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) +@testset "$args_mode Data" for args_mode in (:Raw, :Chunk, :Thunk) args_chunks = args_mode == :Chunk args_thunks = args_mode == :Thunk for nw in (1, 2) From f745dbe6440879e1f94963245d2f2a365f3cdafb Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 10 Dec 2025 14:19:05 -0700 Subject: [PATCH 14/24] datadeps: Validate further that RemainderAliasing is not empty --- src/datadeps/remainders.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 1bd563d08..6be0160ee 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -221,15 +221,15 @@ function compute_remainder_for_arg!(state::DataDepsState, (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) end @opcounter :compute_remainder_for_arg_schedule - schedule_remainder!(tracker_other_space[1], 
other_space_idx, target_space_idx, remainder, other_many_spans) - if compute_syncdeps + has_overlap = schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps && has_overlap @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) end end VERIFY_SPAN_CURRENT_OBJECT[] = nothing - if isempty(tracker) + if isempty(tracker) || all(tracked->isempty(tracked[1]), values(tracker)) return NoAliasing(), 0 end @@ -243,6 +243,7 @@ function compute_remainder_for_arg!(state::DataDepsState, end end end + @assert !isempty(mra.remainders) "Expected at least one remainder (spaces: $spaces, tracker spaces: $(collect(keys(tracker))))" return mra, last_idx end @@ -264,6 +265,7 @@ function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_ @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))" push!(tracker, (source_span, dest_span)) end + return !isempty(diff) end ### Remainder copy functions From f2381a8bb5371db734d3a8057376cb6042c3dc83 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 12 Dec 2025 12:01:55 -0700 Subject: [PATCH 15/24] datadeps: Fix aliasing for degenerate views --- src/memory-spaces.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 4124bbba6..fcce572c4 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -451,10 +451,16 @@ end function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} if isbitstype(T) S = CPURAMMemorySpace - return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), + p = parent(x) + NA = ndims(p) + raw_inds = parentindices(x) + inds = ntuple(i->raw_inds[i] isa Integer ? 
(raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) + sz = ntuple(i->length(inds[i]), NA) + return StridedAliasing{T,NA,S}(RemotePtr{Cvoid}(pointer(p)), RemotePtr{Cvoid}(pointer(x)), - parentindices(x), - size(x), strides(x)) + inds, + sz, + strides(p)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) From 321f08ac1728b777595ddefcff6c5dd213aefd2e Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 23 Sep 2025 22:01:54 +0000 Subject: [PATCH 16/24] datadeps: Fix GPU execution --- ext/CUDAExt.jl | 9 +++++++ ext/IntelExt.jl | 9 +++++++ ext/MetalExt.jl | 9 +++++++ ext/OpenCLExt.jl | 9 +++++++ ext/ROCExt.jl | 9 +++++++ src/datadeps/remainders.jl | 54 ++++++++++++++++++++++++++++++-------- src/gpu.jl | 5 +++- src/memory-spaces.jl | 23 +++++++++------- 8 files changed, 106 insertions(+), 21 deletions(-) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 6b8c61f9a..9f9b8df4d 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -48,6 +48,13 @@ function Dagger.memory_space(x::CuArray) device_uuid = CUDA.uuid(dev) return CUDAVRAMMemorySpace(myid(), device_id, device_uuid) end +function Dagger.aliasing(x::CuArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + cuptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(cuptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::CuArrayDeviceProc) = Set([CUDAVRAMMemorySpace(proc.owner, proc.device, proc.device_uuid)]) Dagger.processors(space::CUDAVRAMMemorySpace) = Set([CuArrayDeviceProc(space.owner, space.device, space.device_uuid)]) @@ -75,6 +82,8 @@ function with_context!(space::CUDAVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device) end +Dagger.with_context!(proc::CuArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::CUDAVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = context() old_stream = stream() diff --git 
a/ext/IntelExt.jl b/ext/IntelExt.jl index 74253007d..08d54ee81 100644 --- a/ext/IntelExt.jl +++ b/ext/IntelExt.jl @@ -46,6 +46,13 @@ function Dagger.memory_space(x::oneArray) return IntelVRAMMemorySpace(myid(), device_id) end _device_id(dev::ZeDevice) = findfirst(other_dev->other_dev === dev, collect(oneAPI.devices())) +function Dagger.aliasing(x::oneArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::oneArrayDeviceProc) = Set([IntelVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::IntelVRAMMemorySpace) = Set([oneArrayDeviceProc(space.owner, space.device_id)]) @@ -68,6 +75,8 @@ function with_context!(space::IntelVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device_id) end +Dagger.with_context!(proc::oneArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::IntelVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_drv = driver() old_dev = device() diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 50cfc8905..21cea360a 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -43,6 +43,13 @@ function Dagger.memory_space(x::MtlArray) return MetalVRAMMemorySpace(myid(), device_id) end _device_id(dev::MtlDevice) = findfirst(other_dev->other_dev === dev, Metal.devices()) +function Dagger.aliasing(x::MtlArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::MtlArrayDeviceProc) = Set([MetalVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::MetalVRAMMemorySpace) = Set([MtlArrayDeviceProc(space.owner, space.device_id)]) @@ -66,6 +73,8 @@ end 
function with_context!(space::MetalVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() end +Dagger.with_context!(proc::MtlArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::MetalVRAMMemorySpace) = with_context!(space) function with_context(f, x) with_context!(x) return f() diff --git a/ext/OpenCLExt.jl b/ext/OpenCLExt.jl index fbf73de72..f8eac930c 100644 --- a/ext/OpenCLExt.jl +++ b/ext/OpenCLExt.jl @@ -44,6 +44,13 @@ function Dagger.memory_space(x::CLArray) idx = findfirst(==(queue), QUEUES) return CLMemorySpace(myid(), idx) end +function Dagger.aliasing(x::CLArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::CLArrayDeviceProc) = Set([CLMemorySpace(proc.owner, proc.device)]) Dagger.processors(space::CLMemorySpace) = Set([CLArrayDeviceProc(space.owner, space.device)]) @@ -71,6 +78,8 @@ function with_context!(space::CLMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device) end +Dagger.with_context!(proc::CLArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::CLMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = cl.context() old_queue = cl.queue() diff --git a/ext/ROCExt.jl b/ext/ROCExt.jl index 288c4744f..773c2bb95 100644 --- a/ext/ROCExt.jl +++ b/ext/ROCExt.jl @@ -39,6 +39,13 @@ end Dagger.root_worker_id(space::ROCVRAMMemorySpace) = space.owner Dagger.memory_space(x::ROCArray) = ROCVRAMMemorySpace(myid(), AMDGPU.device(x).device_id) +function Dagger.aliasing(x::ROCArray{T}) where T + space = Dagger.memory_space(x) + S = typeof(space) + gpu_ptr = pointer(x) + rptr = Dagger.RemotePtr{Cvoid}(UInt64(gpu_ptr), space) + return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x))) +end Dagger.memory_spaces(proc::ROCArrayDeviceProc) = 
Set([ROCVRAMMemorySpace(proc.owner, proc.device_id)]) Dagger.processors(space::ROCVRAMMemorySpace) = Set([ROCArrayDeviceProc(space.owner, space.device_id)]) @@ -67,6 +74,8 @@ function with_context!(space::ROCVRAMMemorySpace) @assert Dagger.root_worker_id(space) == myid() with_context!(space.device_id) end +Dagger.with_context!(proc::ROCArrayDeviceProc) = with_context!(proc) +Dagger.with_context!(space::ROCVRAMMemorySpace) = with_context!(space) function with_context(f, x) old_ctx = context() old_device = AMDGPU.device() diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 6be0160ee..a4b819b2a 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -143,6 +143,7 @@ function compute_remainder_for_arg!(state::DataDepsState, push!(target_ainfos, LocalMemorySpan.(spans)) end nspans = length(first(target_ainfos)) + @assert all(==(nspans), length.(target_ainfos)) "Aliasing info for $(typeof(arg_w.arg))[$(arg_w.dep_mod)] has different number of spans in different memory spaces" # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) 
for entry in state.arg_history[arg_w] @@ -435,33 +436,42 @@ end # Main copy function for RemainderAliasing function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S + # TODO: Support direct copy between GPU memory spaces + + @assert sizeof(eltype(chunktype(from))) == sizeof(eltype(chunktype(to))) "Source and destination chunks have different element sizes: $(sizeof(eltype(chunktype(from)))) != $(sizeof(eltype(chunktype(to))))" + # Copy the data from the source object - copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod + copies = remotecall_fetch(root_worker_id(from_space), from_space, dep_mod, from) do from_space, dep_mod, from len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) copies = Vector{UInt8}(undef, len) - offset = 1 + from_raw = unwrap(from) + offset = UInt64(1) + with_context!(from_space) GC.@preserve copies begin for (from_span, _) in dep_mod.spans - from_ptr = Ptr{UInt8}(from_span.ptr) - to_ptr = Ptr{UInt8}(pointer(copies, offset)) - unsafe_copyto!(to_ptr, from_ptr, from_span.len) + elsize = sizeof(eltype(from_raw)) + offset_n = UInt64((offset-1) / elsize) + UInt64(1) + n = UInt64(from_span.len / elsize) + read_remainder!(copies, offset_n, from_raw, from_span.ptr, n) offset += from_span.len end end - @assert offset == len+1 + @assert offset == len+UInt64(1) return copies end # Copy the data into the destination object - offset = 1 + offset = UInt64(1) + to_raw = unwrap(to) GC.@preserve copies begin for (_, to_span) in dep_mod.spans - from_ptr = Ptr{UInt8}(pointer(copies, offset)) - to_ptr = Ptr{UInt8}(to_span.ptr) - unsafe_copyto!(to_ptr, from_ptr, to_span.len) + elsize = sizeof(eltype(to_raw)) + offset_n = UInt64((offset-1) / elsize) + UInt64(1) + n = UInt64(to_span.len / elsize) + write_remainder!(copies, offset_n, to_raw, to_span.ptr, n) offset += to_span.len end - @assert offset == length(copies)+1 + @assert offset == length(copies)+UInt64(1) end # 
Ensure that the data is visible @@ -469,3 +479,25 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space: return end + +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, n::UInt64) + elsize = sizeof(eltype(from)) + from_offset = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) + from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} + copies_typed = unsafe_wrap(Vector{eltype(from)}, Ptr{eltype(from)}(pointer(copies, copies_offset)), n) + copyto!(copies_typed, 1, from_vec, Int(from_offset), Int(n)) +end +function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::SubArray, from_ptr::UInt64, n::UInt64) + read_remainder!(copies, copies_offset, parent(from), from_ptr, n) +end + +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::DenseArray, to_ptr::UInt64, n::UInt64) + elsize = sizeof(eltype(to)) + to_offset = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) + to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} + copies_typed = unsafe_wrap(Vector{eltype(to)}, Ptr{eltype(to)}(pointer(copies, copies_offset)), n) + copyto!(to_vec, Int(to_offset), copies_typed, 1, Int(n)) +end +function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::SubArray, to_ptr::UInt64, n::UInt64) + write_remainder!(copies, copies_offset, parent(to), to_ptr, n) +end diff --git a/src/gpu.jl b/src/gpu.jl index 06d749543..fa93f8076 100644 --- a/src/gpu.jl +++ b/src/gpu.jl @@ -100,4 +100,7 @@ function gpu_synchronize(kind::Symbol) gpu_synchronize(Val(kind)) end end -gpu_synchronize(::Val{:CPU}) = nothing \ No newline at end of file +gpu_synchronize(::Val{:CPU}) = nothing + +with_context!(proc::Processor) = nothing +with_context!(space::MemorySpace) = nothing \ No newline at end of file diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index fcce572c4..bd980a81d 100644 --- a/src/memory-spaces.jl +++ 
b/src/memory-spaces.jl @@ -1,6 +1,7 @@ struct CPURAMMemorySpace <: MemorySpace owner::Int end +CPURAMMemorySpace() = CPURAMMemorySpace(myid()) root_worker_id(space::CPURAMMemorySpace) = space.owner memory_space(x) = CPURAMMemorySpace(myid()) @@ -87,7 +88,8 @@ function type_may_alias(::Type{T}) where T return false end -may_alias(::MemorySpace, ::MemorySpace) = true +may_alias(::MemorySpace, ::MemorySpace) = false +may_alias(space1::M, space2::M) where M<:MemorySpace = space1 == space2 may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner abstract type AbstractAliasing end @@ -448,19 +450,22 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} +function aliasing(x::SubArray{T,N}) where {T,N} if isbitstype(T) - S = CPURAMMemorySpace p = parent(x) + space = memory_space(p) + S = typeof(space) + parent_ptr = RemotePtr{Cvoid}(UInt64(pointer(p)), space) + ptr = RemotePtr{Cvoid}(UInt64(pointer(x)), space) NA = ndims(p) raw_inds = parentindices(x) inds = ntuple(i->raw_inds[i] isa Integer ? (raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) sz = ntuple(i->length(inds[i]), NA) - return StridedAliasing{T,NA,S}(RemotePtr{Cvoid}(pointer(p)), - RemotePtr{Cvoid}(pointer(x)), - inds, - sz, - strides(p)) + return StridedAliasing{T,NA,S}(parent_ptr, + ptr, + inds, + sz, + strides(p)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -577,7 +582,7 @@ end function will_alias(x_span::MemorySpan, y_span::MemorySpan) may_alias(x_span.ptr.space, y_span.ptr.space) || return false # FIXME: Allow pointer conversion instead of just failing - @assert x_span.ptr.space == y_span.ptr.space + @assert x_span.ptr.space == y_span.ptr.space "Memory spans are in different spaces: $(x_span.ptr.space) vs. 
$(y_span.ptr.space)" x_end = x_span.ptr + x_span.len - 1 y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end From c671d24c7c7b3fac5baceb1fb3e3635dfe3a0942 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sun, 14 Dec 2025 13:08:32 -0500 Subject: [PATCH 17/24] Sch: Skip set_failed! store when result already set --- src/sch/util.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sch/util.jl b/src/sch/util.jl index 3f9d7b2f6..d3b7a4804 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -238,7 +238,7 @@ function set_failed!(state, origin, thunk=origin) @dagdebug thunk :finish "Setting as failed" filter!(x -> x !== thunk, state.ready) # N.B. If origin === thunk, we assume that the caller has already set the error - if origin !== thunk + if origin !== thunk && !has_result(state, thunk) origin_ex = load_result(state, origin) if origin_ex isa RemoteException origin_ex = origin_ex.captured From 0c19fa0e108947d9c0cb7b82593607e5d77a5543 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sun, 14 Dec 2025 13:09:22 -0500 Subject: [PATCH 18/24] scopes: Disallow constructing empty UnionScope --- src/scopes.jl | 3 +++ test/scopes.jl | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/scopes.jl b/src/scopes.jl index ba291bc2b..79190c292 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -40,6 +40,9 @@ struct UnionScope <: AbstractScope push!(scope_set, scope) end end + if isempty(scope_set) + throw(ArgumentError("Cannot construct UnionScope with no inner scopes")) + end return new((collect(scope_set)...,)) end end diff --git a/test/scopes.jl b/test/scopes.jl index fa5bf1135..55d15b349 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -123,8 +123,8 @@ us_es1_multi_ch = Dagger.tochunk(nothing, OSProc(), UnionScope(es1, es1)) @test fetch(Dagger.@spawn exact_scope_test(us_es1_multi_ch)) == es1.processor - # No inner scopes - @test UnionScope() isa UnionScope + # No inner scopes (disallowed) + 
@test_throws ArgumentError UnionScope() # Same inner scope @test fetch(Dagger.@spawn exact_scope_test(us_es1_ch, us_es1_ch)) == es1.processor @@ -165,7 +165,7 @@ @test Dagger.scope(:any) isa AnyScope @test Dagger.scope(:default) == DefaultScope() @test_throws ArgumentError Dagger.scope(:blah) - @test Dagger.scope(()) == UnionScope() + @test_throws ArgumentError Dagger.scope(()) @test Dagger.scope(worker=wid1) == Dagger.scope(workers=[wid1]) From 6a1bf1603b5d5c2eca96d96cd0dd70337899c72c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 15 Dec 2025 13:52:07 -0500 Subject: [PATCH 19/24] datadeps: Consolidate aliasing rewrap code --- src/datadeps/aliasing.jl | 97 ++++++++++++++++++++++++++++++--------- src/datadeps/chunkview.jl | 66 +------------------------- 2 files changed, 77 insertions(+), 86 deletions(-) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 82989a259..ef83006b9 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -25,19 +25,19 @@ KEY CONCEPTS: 1. ALIASING ANALYSIS: - Every mutable argument is analyzed for its memory access pattern - Memory spans are computed to determine which bytes in memory are accessed - - Objects that access overlapping memory spans are considered "aliasing" + - Arguments that access overlapping memory spans are considered "aliasing" - Examples: An array A and view(A, 2:3, 2:3) alias each other 2. DATA LOCALITY TRACKING: - The system tracks where the "source of truth" for each piece of data lives - As tasks execute and modify data, the source of truth may move between workers - - Each aliasing region can have its own independent source of truth location + - Each argument can have its own independent source of truth location 3. 
ALIASED OBJECT MANAGEMENT: - When copying arguments between workers, the system tracks "aliased objects" - This ensures that if both an array and its view need to be copied to a worker, only one copy of the underlying array is made, with the view pointing to it - - The aliased_object!() functions manage this sharing + - The aliased_object!() and move_rewrap() functions manage this sharing THE DISTRIBUTED ALIASING PROBLEM: --------------------------------- @@ -63,11 +63,9 @@ MULTITHREADED BEHAVIOR (WORKS): - Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) DISTRIBUTED BEHAVIOR (THE PROBLEM): -- Tasks may be scheduled on different workers - Each argument must be copied to the destination worker -- Without special handling, we would copy A to worker1 and vA to worker2 -- This creates two separate arrays, breaking the aliasing relationship -- Updates to the view on worker2 don't affect the array on worker1 +- Without special handling, we would copy A and vA independently to another worker +- This creates two separate arrays, breaking the aliasing relationship between A and vA THE SOLUTION - PARTIAL DATA MOVEMENT: ------------------------------------- @@ -706,6 +704,32 @@ function Base.setindex!(cache::AliasedObjectCache, value::Chunk, ainfo::Abstract cache_raw[ainfo] = value return end +function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) + if haskey(cache, ainfo) + return cache[ainfo] + else + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + cache[ainfo] = y + return y + end +end +function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) + to_w = root_worker_id(to_proc) + if to_w == myid() + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc) + end + return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data + data_converted = f(move(from_proc, to_proc, data)) + return 
tochunk(data_converted, to_proc) + end +end +function rewrap_aliased_object!(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x) + return aliased_object!(cache, x) do x + return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, x) + end +end function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) # Unwrap so that we hit the right dispatch wid = root_worker_id(data) @@ -721,27 +745,58 @@ function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::P return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) end end -function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) to_w = root_worker_id(to_proc) - if to_w == myid() - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) + inds = parentindices(v) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) 
+ return tochunk(v_new, to_proc) end - return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) +end +# FIXME: Do this programmatically via recursive dispatch +for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) + @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) + to_w = root_worker_id(to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = $(wrapper)(p_new) + return tochunk(v_new, to_proc) + end end end -function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) - if haskey(cache, ainfo) - return cache[ainfo] +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) + to_w = root_worker_id(to_proc) + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, v[]) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk + p_new = move(from_proc, to_proc, p_chunk) + v_new = Ref(p_new) + return tochunk(v_new, to_proc) + end +end +#= +function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T + if isstructtype(T) + # Check all object fields (recursive) + for field in fieldnames(T) + value = getfield(x, field) + new_value = aliased_object!(cache, value) do value + return move_rewrap_recursive(cache, from_proc, to_proc, 
from_space, to_space, value) + end + setfield!(x, field, new_value) + end + return x else - y = f(x) - @assert y isa Chunk "Didn't get a Chunk from functor" - cache[ainfo] = y - return y + @warn "Cannot move-rewrap object of type $T" + return x end end +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x +move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x +=# struct DataDepsSchedulerState task_to_spec::Dict{DTask,DTaskSpec} diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 60ded6151..42f32cca9 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -37,73 +37,9 @@ end memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true -# This definition is here because it's so similar to ChunkView -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) - to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(cache, parent(v)) do p - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p - return tochunk(move(from_proc, to_proc, p), to_proc) - end - end - inds = parentindices(v) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds - p_new = move(from_proc, to_proc, p_chunk) - v_new = view(p_new, inds...) 
- return tochunk(v_new, to_proc) - end -end -# FIXME: Do this programmatically via recursive dispatch -for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) - @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) - to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(cache, parent(v)) do p - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p - return tochunk(move(from_proc, to_proc, p), to_proc) - end - end - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk - p_new = move(from_proc, to_proc, p_chunk) - v_new = $(wrapper)(p_new) - return tochunk(v_new, to_proc) - end - end -end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) - to_w = root_worker_id(to_proc) - return aliased_object!(cache, v[]) do p - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p) do from_proc, to_proc, from_space, to_space, p - return tochunk(Ref(move(from_proc, to_proc, p)), to_proc) - end - end -end -#= -function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T - if isstructtype(T) - # Check all object fields (recursive) - for field in fieldnames(T) - value = getfield(x, field) - new_value = aliased_object!(cache, value) do value - return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) - end - setfield!(x, field, new_value) - end - return x - else - @warn "Cannot move-rewrap object of type $T" - return x - end -end -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, 
to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x -=# function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) to_w = root_worker_id(to_proc) - p_chunk = aliased_object!(cache, slice.chunk) do p_chunk - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk - return tochunk(move(from_proc, to_proc, p_chunk), to_proc) - end - end + p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, slice.chunk) return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds p_new = move(from_proc, to_proc, p_chunk) v_new = view(p_new, inds...) 
From e59fdd7fb10ac4fa086f4c150c59430666b95a34 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 15 Dec 2025 13:52:35 -0500 Subject: [PATCH 20/24] HaloArray: Add aliasing methods --- src/memory-spaces.jl | 7 +++++-- src/utils/haloarray.jl | 20 +++++++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index bd980a81d..d39e665cc 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -322,8 +322,11 @@ struct CombinedAliasing <: AbstractAliasing end function memory_spans(ca::CombinedAliasing) # FIXME: Don't hardcode CPURAMMemorySpace - all_spans = MemorySpan{CPURAMMemorySpace}[] - for sub_a in ca.sub_ainfos + if length(ca.sub_ainfos) == 0 + return MemorySpan{CPURAMMemorySpace}[] + end + all_spans = memory_spans(ca.sub_ainfos[1]) + for sub_a in ca.sub_ainfos[2:end] append!(all_spans, memory_spans(sub_a)) end return all_spans diff --git a/src/utils/haloarray.jl b/src/utils/haloarray.jl index 1fadbeeb6..c5990099e 100644 --- a/src/utils/haloarray.jl +++ b/src/utils/haloarray.jl @@ -99,4 +99,22 @@ Adapt.adapt_structure(to, H::Dagger.HaloArray) = HaloArray(Adapt.adapt(to, H.center), Adapt.adapt.(Ref(to), H.edges), Adapt.adapt.(Ref(to), H.corners), - H.halo_width) \ No newline at end of file + H.halo_width) + +function aliasing(A::HaloArray) + return CombinedAliasing([aliasing(A.center), aliasing(A.edges), aliasing(A.corners)]) +end +memory_space(A::HaloArray) = memory_space(A.center) +function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, A::HaloArray) + center_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.center) + edge_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.edges[i]), length(A.edges)) + corner_chunks = ntuple(i->rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, A.corners[i]), length(A.corners)) + 
halo_width = A.halo_width + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width) do from_proc, to_proc, from_space, to_space, center_chunk, edge_chunks, corner_chunks, halo_width + center_new = move(from_proc, to_proc, center_chunk) + edges_new = ntuple(i->move(from_proc, to_proc, edge_chunks[i]), length(edge_chunks)) + corners_new = ntuple(i->move(from_proc, to_proc, corner_chunks[i]), length(corner_chunks)) + return tochunk(HaloArray(center_new, edges_new, corners_new, halo_width), to_proc) + end +end \ No newline at end of file From 36b25dc4a60672230c0e145791cbe48a074e7663 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 15 Dec 2025 14:26:07 -0500 Subject: [PATCH 21/24] CI: Extend CUDA job time --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index e9177b9a9..ad94a7b6d 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -91,7 +91,7 @@ steps: codecov: true - label: Julia 1.11 (CUDA) - timeout_in_minutes: 20 + timeout_in_minutes: 30 <<: *gputest plugins: - JuliaCI/julia#v1: From 563caef66e403d30152b0871ca8e8e166aaf50e6 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 16 Dec 2025 11:12:12 -0500 Subject: [PATCH 22/24] datadeps: Make IntervalTree find_overlapping non-recursive --- src/utils/interval_tree.jl | 214 +++++++++++++++++++++++++------------ 1 file changed, 147 insertions(+), 67 deletions(-) diff --git a/src/utils/interval_tree.jl b/src/utils/interval_tree.jl index 5b79d456d..dda355756 100644 --- a/src/utils/interval_tree.jl +++ b/src/utils/interval_tree.jl @@ -140,15 +140,36 @@ end function insert_node!(::Nothing, span::M) where M return IntervalNode(span) end -function insert_node!(node::IntervalNode{M,E}, span::M) where {M,E} - if span_start(span) <= span_start(node.span) - node.left = insert_node!(node.left, span) - 
else - node.right = insert_node!(node.right, span) +function insert_node!(root::IntervalNode{M,E}, span::M) where {M,E} + # Use a stack to track the path for updating max_end after insertion + path = Vector{IntervalNode{M,E}}() + current = root + + # Traverse to find the insertion point + while current !== nothing + push!(path, current) + if span_start(span) <= span_start(current.span) + if current.left === nothing + current.left = IntervalNode(span) + break + end + current = current.left + else + if current.right === nothing + current.right = IntervalNode(span) + break + end + current = current.right + end end - update_max_end!(node) - return node + # Update max_end for all ancestors (process in reverse order) + while !isempty(path) + node = pop!(path) + update_max_end!(node) + end + + return root end # Remove a specific span from the tree (split as needed) @@ -162,44 +183,78 @@ end function delete_node!(::Nothing, span::M) where M return nothing end -function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} - # Check for exact match first - if span_start(node.span) == span_start(span) && span_len(node.span) == span_len(span) - # Exact match, remove the node - if node.left === nothing && node.right === nothing - return nothing - elseif node.left === nothing - return node.right - elseif node.right === nothing - return node.left +function delete_node!(root::IntervalNode{M,E}, span::M) where {M,E} - # Track the path to the target node: (node, direction_to_child) is replaced below + # Track the path to the target node: (node, direction_to_child) + path = Vector{Tuple{IntervalNode{M,E}, Symbol}}() + current = root + target = nothing + target_type = :none # :exact or :overlap + + # Phase 1: Search for target node + while current !== nothing + is_exact = span_start(current.span) == span_start(span) && span_len(current.span) == span_len(span) + is_overlap = !is_exact && spans_overlap(current.span, span) + + if is_exact + target = current + target_type = :exact + break + elseif is_overlap + target = current + target_type = :overlap + break + 
elseif span_start(span) <= span_start(current.span) + push!(path, (current, :left)) + current = current.left else - # Node has two children - replace with inorder successor - successor = find_min(node.right) - node.span = successor.span - node.right = delete_node!(node.right, successor.span) + push!(path, (current, :right)) + current = current.right end - # Check for overlap - elseif spans_overlap(node.span, span) - # Handle overlapping spans by removing current node and adding remainders - original_span = node.span - - # Remove the current node first (same logic as exact match) - if node.left === nothing && node.right === nothing - # Leaf node - remove it and create a new subtree with remainders - remaining_node = nothing - elseif node.left === nothing - remaining_node = node.right - elseif node.right === nothing - remaining_node = node.left + end + + if target === nothing + return root + end + + # Phase 2: Compute replacement for target node + original_span = target.span + succ_path = Vector{IntervalNode{M,E}}() # Path to successor (for max_end updates) + local replacement::Union{IntervalNode{M,E}, Nothing} + + if target.left === nothing && target.right === nothing + # Leaf node + replacement = nothing + elseif target.left === nothing + # Only right child + replacement = target.right + elseif target.right === nothing + # Only left child + replacement = target.left + else + # Two children - find and remove inorder successor + successor = find_min(target.right) + + if target.right === successor + # Successor is direct right child + target.right = successor.right else - # Node has two children - replace with inorder successor - successor = find_min(node.right) - node.span = successor.span - node.right = delete_node!(node.right, successor.span) - remaining_node = node + # Track path to successor for max_end updates + succ_parent = target.right + push!(succ_path, succ_parent) + while succ_parent.left !== successor + succ_parent = succ_parent.left + push!(succ_path, 
succ_parent) + end + # Remove successor by replacing with its right child + succ_parent.left = successor.right end - # Calculate and insert the remaining portions + target.span = successor.span + replacement = target + end + + # Phase 3: Handle overlap case - add remaining portions + if target_type == :overlap original_start = span_start(original_span) original_end = span_end(original_span) del_start = span_start(span) @@ -212,7 +267,7 @@ function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} if left_end > original_start left_span = M(original_start, left_end - original_start) if !isempty(left_span) - remaining_node = insert_node!(remaining_node, left_span) + replacement = insert_node!(replacement, left_span) end end end @@ -223,22 +278,39 @@ function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} if original_end > right_start right_span = M(right_start, original_end - right_start) if !isempty(right_span) - remaining_node = insert_node!(remaining_node, right_span) + replacement = insert_node!(replacement, right_span) end end end + end - return remaining_node - elseif span_start(span) <= span_start(node.span) - node.left = delete_node!(node.left, span) + # Phase 4: Update parent's child pointer + if isempty(path) + root = replacement else - node.right = delete_node!(node.right, span) + parent, dir = path[end] + if dir == :left + parent.left = replacement + else + parent.right = replacement + end end - if node !== nothing - update_max_end!(node) + # Phase 5: Update max_end in correct order (bottom-up) + # First: successor path (if any) + for i in length(succ_path):-1:1 + update_max_end!(succ_path[i]) end - return node + # Second: target node (if it wasn't removed) + if replacement === target + update_max_end!(target) + end + # Third: main path (ancestors of target) + for i in length(path):-1:1 + update_max_end!(path[i][1]) + end + + return root end function find_min(node::IntervalNode) @@ -263,28 +335,36 @@ function 
find_overlapping!(::Nothing, query::M, result::Vector{M}; exact::Bool=t return end function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; exact::Bool=true) where {M,E} - # Check if current node overlaps with query - if spans_overlap(node.span, query) - if exact - # Get the overlapping portion of the span - overlap = span_diff(node.span, query) - verify_span(overlap) - if !isempty(overlap) - push!(result, overlap) + # Use a queue for breadth-first traversal + queue = Vector{IntervalNode{M,E}}() + push!(queue, node) + + while !isempty(queue) + current = popfirst!(queue) + + # Check if current node overlaps with query + if spans_overlap(current.span, query) + if exact + # Get the overlapping portion of the span + overlap = span_diff(current.span, query) + verify_span(overlap) + if !isempty(overlap) + push!(result, overlap) + end + else + push!(result, current.span) end - else - push!(result, node.span) end - end - # Recursively search left subtree if it might contain overlapping intervals - if node.left !== nothing && node.left.max_end > span_start(query) - find_overlapping!(node.left, query, result; exact) - end + # Enqueue left subtree if it might contain overlapping intervals + if current.left !== nothing && current.left.max_end > span_start(query) + push!(queue, current.left) + end - # Recursively search right subtree if query extends beyond current node's start - if node.right !== nothing && span_end(query) > span_start(node.span) - find_overlapping!(node.right, query, result; exact) + # Enqueue right subtree if query extends beyond current node's start + if current.right !== nothing && span_end(query) > span_start(current.span) + push!(queue, current.right) + end end end From 6f9f98eee2c5bf42fc21343b5f22a0c9a5359dc8 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 16 Dec 2025 12:37:51 -0500 Subject: [PATCH 23/24] datadeps: Add TID to dagdebug statements --- src/datadeps/queue.jl | 17 +++++++++-------- 
src/datadeps/remainders.jl | 8 ++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 8b92f3087..c7b5e2bc1 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -405,9 +405,10 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr end f = spec.fargs[1] + tid = task.uid # FIXME: May not be correct to move this under uniformity #f.value = move(default_processor(), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" # Copy raw task arguments for analysis # N.B. Used later for checking dependencies @@ -434,13 +435,13 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr # Is the data written previously or now? if !arg_ws.may_alias - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" return arg end # Is the data writeable? 
if !arg_ws.inplace_move - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" return arg end @@ -457,7 +458,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" end end return arg_remote @@ -501,16 +502,16 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" get_write_deps!(state, our_space, ainfo, write_num, syncdeps) else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" get_read_deps!(state, our_space, ainfo, write_num, syncdeps) end end return end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" # Launch user's task new_fargs = map_or_ntuple(task_arg_ws) do idx @@ -540,7 +541,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task 
set as writer" + @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" add_writer!(state, arg_w, our_space, ainfo, task, write_num) else add_reader!(state, arg_w, our_space, ainfo, task, write_num) diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index a4b819b2a..67fdd2588 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -293,7 +293,7 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac # overwritten by more recent partial updates source_space = remainder_aliasing.space - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -308,7 +308,7 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac empty!(remainder_aliasing.syncdeps) # We can't bring these to move! 
get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task ctx = Sch.eager_context() @@ -377,7 +377,7 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: source_space = state.arg_owner[arg_w] target_ainfo = aliasing!(state, dest_space, arg_w) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -390,7 +390,7 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task ctx = Sch.eager_context() From 357a2d69b10cbe75daf9477e7f2f05c4ff991cbb Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 16 Dec 2025 17:26:06 -0500 Subject: [PATCH 24/24] fixup! 
scopes: Disallow constructing empty UnionScope --- test/task-affinity.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/task-affinity.jl b/test/task-affinity.jl index f1e26295a..ce898b476 100644 --- a/test/task-affinity.jl +++ b/test/task-affinity.jl @@ -135,7 +135,7 @@ @testset "Chunk function, scope, compute_scope and result_scope" begin @everywhere g(x, y) = x * 2 + y * 3 - n = cld(numscopes, 3) + n = fld(numscopes, 3) shuffle!(availscopes) scope_a = availscopes[1:n]