From e9e8576977f2351381691455f0519aad429dfae5 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 28 Nov 2025 09:37:05 -0700 Subject: [PATCH] datadeps: Add Shard support --- src/datadeps.jl | 30 ++++++++++++++++++++++++++---- src/memory-spaces.jl | 2 ++ src/utils/chunks.jl | 12 +++++++----- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/datadeps.jl b/src/datadeps.jl index d20bda647..ae60b91f5 100644 --- a/src/datadeps.jl +++ b/src/datadeps.jl @@ -263,7 +263,7 @@ function is_writedep(arg, deps, task::DTask) end # Aliasing state setup -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) +function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask, proc::Processor) # Populate task dependencies dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}() @@ -278,6 +278,11 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) # Skip non-aliasing arguments type_may_alias(typeof(arg)) || continue + # Unwrap Shards + if arg isa Shard + arg = shard_unwrap(arg, proc) + end + # Add all aliasing dependencies for (dep_mod, readdep, writedep) in deps if state.aliasing @@ -592,9 +597,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue) pressures = Dict{Processor,Int}() proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) for (spec, task) in queue.seen_tasks[task_order] - # Populate all task dependencies - populate_task_info!(state, spec, task) - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) scheduler = queue.scheduler if scheduler == :naive @@ -737,6 +739,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue) @assert our_proc in all_procs our_space = only(memory_spaces(our_proc)) + # Populate all task dependencies + populate_task_info!(state, spec, task, our_proc) + # Find the scope for this task (and its copies) if task_scope == scope # Optimize for the common case, cache the proc=>scope mapping @@ 
-776,6 +781,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue) continue end + # Unwrap Shards + if arg isa Shard + arg = shard_unwrap(arg, our_proc) + end + # Is the source of truth elsewhere? arg_remote = get!(get!(IdDict{Any,Any}, state.remote_args, our_space), arg) do generate_slot!(state, our_space, arg) @@ -851,6 +861,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue) arg = arg isa DTask ? fetch(arg; raw=true) : arg type_may_alias(typeof(arg)) || continue supports_inplace_move(state, arg) || continue + + # Unwrap Shards + if arg isa Shard + arg = shard_unwrap(arg, our_proc) + end + if queue.aliasing for (dep_mod, _, writedep) in deps ainfo = aliasing(astate, arg, dep_mod) @@ -884,6 +900,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue) arg, deps = unwrap_inout(arg) arg = arg isa DTask ? fetch(arg; raw=true) : arg type_may_alias(typeof(arg)) || continue + + # Unwrap Shards + if arg isa Shard + arg = shard_unwrap(arg, our_proc) + end + if queue.aliasing for (dep_mod, _, writedep) in deps ainfo = aliasing(astate, arg, dep_mod) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 9f65a1a21..2753385f7 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -75,6 +75,7 @@ end type_may_alias(::Type{String}) = false type_may_alias(::Type{Symbol}) = false type_may_alias(::Type{<:Type}) = false +type_may_alias(::Type{Shard}) = true type_may_alias(::Type{C}) where C<:Chunk{T} where T = type_may_alias(T) function type_may_alias(::Type{T}) where T if isbitstype(T) @@ -213,6 +214,7 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() +aliasing(::Shard) = throw(ArgumentError("Cannot resolve aliasing for Shard")) aliasing(x::Chunk, T) = remotecall_fetch(root_worker_id(x.processor), x, T) do x, T aliasing(unwrap(x), T) end diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 400b49332..33a322f7e 100644 --- a/src/utils/chunks.jl +++ 
b/src/utils/chunks.jl
@@ -110,23 +110,38 @@ macro shard(exs...)
     end
 end
 
-function move(from_proc::Processor, to_proc::Processor, shard::Shard)
+"""
+    shard_unwrap(shard::Shard, proc::Processor)
+
+Return the entry of `shard.chunks` keyed by `proc` or, failing that, by the
+nearest ancestor of `proc` (walking `Dagger.get_parent`) that has an entry.
+
+Throws `KeyError(proc)`, naming the originally requested processor, when
+neither `proc` nor any ancestor has an entry.
+"""
+function shard_unwrap(shard::Shard, proc::Processor)
     # Match either this proc or some ancestor
     # N.B. This behavior may bypass the piece's scope restriction
-    proc = to_proc
+    # `proc` is rebound while walking ancestors; remember the requested
+    # processor so a failed lookup reports it rather than the root ancestor
+    requested = proc
     if haskey(shard.chunks, proc)
-        return move(from_proc, to_proc, shard.chunks[proc])
+        return shard.chunks[proc]
     end
     parent = Dagger.get_parent(proc)
     while parent != proc
         proc = parent
         parent = Dagger.get_parent(proc)
         if haskey(shard.chunks, proc)
-            return move(from_proc, to_proc, shard.chunks[proc])
+            return shard.chunks[proc]
         end
     end
 
-    throw(KeyError(to_proc))
+    throw(KeyError(requested))
+end
+# Moving a Shard moves only the piece that `shard_unwrap` selects for `to_proc`
+function move(from_proc::Processor, to_proc::Processor, shard::Shard)
+    return move(from_proc, to_proc, shard_unwrap(shard, to_proc))
 end
 Base.iterate(s::Shard) = iterate(values(s.chunks))
 Base.iterate(s::Shard, state) = iterate(values(s.chunks), state)