diff --git a/src/callbacks.jl b/src/callbacks.jl index a84d11e4c..804f9fcf1 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -188,15 +188,14 @@ end end integrator.sol.stats.ncondition += 1 - ivec = integrator.vector_event_last_time + ivec = enumerate(integrator.callback_cache.prev_simultaneous_events) prev_sign = @view(integrator.callback_cache.prev_sign[1:(callback.len)]) next_sign = @view(integrator.callback_cache.next_sign[1:(callback.len)]) if integrator.event_last_time == counter && - minimum(ODE_DEFAULT_NORM( - ArrayInterface.allowed_getindex(previous_condition, - ivec), integrator.t)) <= - 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) + minimum(minimum(idx -> ODE_DEFAULT_NORM(ArrayInterface.allowed_getindex(previous_condition, idx), integrator.t), + (idx for (idx, triggered) ∈ ivec if triggered), init=typemax(typeof(integrator.t)))) <= + 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) # If there was a previous event, utilize the derivative at the start to # chose the previous sign. If the derivative is positive at tprev, then @@ -215,7 +214,11 @@ end abst = integrator.tprev + integrator.dt * callback.repeat_nudge tmp_condition = get_condition(integrator, callback, abst) @. prev_sign = sign(previous_condition) - prev_sign[ivec] = sign(tmp_condition[ivec]) + for (idx, triggered) ∈ ivec + if triggered + prev_sign[idx] = sign(tmp_condition[idx]) + end + end else @. prev_sign = sign(previous_condition) end @@ -263,7 +266,6 @@ end interp_index = callback.interp_points end end - event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx end @@ -466,6 +468,10 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun callback, counter) if event_occurred + (; simultaneous_events, prev_simultaneous_events) = integrator.callback_cache + prev_simultaneous_events .= simultaneous_events + simultaneous_events .= false + if callback.condition === nothing new_t = zero(typeof(integrator.t)) min_event_idx = findfirst(isequal(1), event_idx) @@ -492,14 +498,13 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun Θ = top_t else if integrator.event_last_time == counter && - integrator.vector_event_last_time == idx && + prev_simultaneous_events[idx] && abs(zero_func(bottom_t)) <= 100abs(integrator.last_event_error) && prev_sign_index == 1 # Determined that there is an event by derivative # But floating point error may make the end point negative - bottom_t += integrator.dt * callback.repeat_nudge sign_top = sign(zero_func(top_t)) sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) && @@ -515,8 +520,12 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun end end if integrator.tdir * Θ < integrator.tdir * min_t + simultaneous_events .= false + end + if integrator.tdir * Θ <= integrator.tdir * min_t min_event_idx = idx min_t = Θ + simultaneous_events[idx] = true end end end @@ -532,9 +541,19 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun elseif interp_index != callback.interp_points && !isdiscrete(integrator.alg) new_t = ts[interp_index] - integrator.tprev min_event_idx = findfirst(isequal(1), event_idx) + for (i, idx) ∈ enumerate(event_idx) + if idx == 1 + simultaneous_events[i] = true + end + end else # If no solve and no interpolants, just use endpoint new_t = integrator.dt + for (i, idx) ∈ enumerate(event_idx) + if idx == 1 + simultaneous_events[i] = true + end + end min_event_idx = findfirst(isequal(1), event_idx) end end @@ -546,13 +565,12 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun if event_occurred && min_event_idx < 0 error("Callback handling failed. Please file an issue with code to reproduce.") end - - new_t, ArrayInterface.allowed_getindex(prev_sign, min_event_idx), - event_occurred::Bool, min_event_idx::Int + # We still pass around the min_event_idx for now because some stuff in OrdinaryDiffEqCore expects it to be an Int + new_t, prev_sign, event_occurred::Bool, min_event_idx::Int end function apply_callback!(integrator, - callback::Union{ContinuousCallback, VectorContinuousCallback}, + callback::ContinuousCallback, cb_time, prev_sign, event_idx) if isadaptive(integrator) set_proposed_dt!(integrator, @@ -610,6 +628,58 @@ function apply_callback!(integrator, false, saved_in_cb end +function apply_callback!(integrator, + callback::VectorContinuousCallback, + cb_time, prev_sign, min_event_idx) + if isadaptive(integrator) + set_proposed_dt!(integrator, + integrator.tdir * max(nextfloat(integrator.opts.dtmin), + integrator.tdir * callback.dtrelax * integrator.dt)) + end + + change_t_via_interpolation!( + integrator, integrator.tprev + cb_time, Val{:false}, callback.initializealg) + + # handle saveat + _, savedexactly = savevalues!(integrator) + saved_in_cb = true + + @inbounds if callback.save_positions[1] + # if already saved then skip saving + savedexactly || savevalues!(integrator, true) + end + + u_modified = false + for (i, triggered) ∈ enumerate(integrator.callback_cache.simultaneous_events) + if triggered + if prev_sign[i] < 0 callback.affect! !== nothing + callback.affect!(integrator, i) + u_modified = true + elseif prev_sign[i] > 0 && callback.affect_neg! !== nothing + callback.affect_neg!(integrator, i) + u_modified = true + end + end + end + integrator.u_modified = u_modified + if u_modified + reeval_internals_due_to_modification!( + integrator, callback_initializealg = callback.initializealg) + + @inbounds if callback.save_positions[2] + savevalues!(integrator, true) + if !isdefined(integrator.opts, :save_discretes) || integrator.opts.save_discretes + for i ∈ integrator.callback_cache.simultaneous_events + SciMLBase.save_discretes!(integrator, callback, i) + end + end + saved_in_cb = true + end + return true, saved_in_cb + end + false, saved_in_cb +end + #Base Case: Just one @inline function apply_discrete_callback!(integrator, callback::DiscreteCallback) saved_in_cb = false @@ -698,22 +768,28 @@ mutable struct CallbackCache{conditionType, signType} previous_condition::conditionType next_sign::signType prev_sign::signType + simultaneous_events::Vector{Bool} + prev_simultaneous_events::Vector{Bool} end function CallbackCache(u, max_len, ::Type{conditionType}, - ::Type{signType}) where {conditionType, signType} + ::Type{signType}) where {conditionType, signType} tmp_condition = similar(u, conditionType, max_len) previous_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) prev_sign = similar(u, signType, max_len) - CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign) + simultaneous_events = zeros(Bool, max_len) + prev_simultaneous_events = zeros(Bool, max_len) + CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign, simultaneous_events, prev_simultaneous_events) end function CallbackCache(max_len, ::Type{conditionType}, - ::Type{signType}) where {conditionType, signType} + ::Type{signType}) where {conditionType, signType} tmp_condition = zeros(conditionType, max_len) previous_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) prev_sign = zeros(signType, max_len) - CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign) + simultaneous_events = zeros(Bool, max_len) + prev_simultaneous_events = zeros(Bool, max_len) + CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign, simultaneous_events, prev_simultaneous_events) end