@@ -48,14 +48,17 @@ namespace kpipeline
4848 std::mutex completion_mutex;
4949 std::condition_variable cv_completion;
5050
51- std::function<void (const std::string&)> schedule_next;
52- std::function<void (const std::string&)> prune_branch;
51+ // ======================== 统一调度逻辑 ========================
5352
54- auto task_wrapper = [this , &ws, &schedule_next, &profiler, enable_profiling, &cv_completion
55- ](const std::string& node_name)
53+ // Core per-node task: runs a single node and then reports it through on_finish.
// - If graph_failed_ is already set, the node is not executed; on_finish is
//   called right away and the task returns.
// - Otherwise nodes_.at(node_name)->Execute(ws) is invoked. A std::exception
//   thrown inside the try block is stored in first_exception_ (first one only,
//   guarded by exception_mutex_) and graph_failed_ is set.
// - on_finish(node_name) is then called whether Execute succeeded or threw.
// NOTE(review): only std::exception is caught; anything else thrown would leave
// this lambda without reaching on_finish — confirm that cannot happen.
// NOTE(review): cv_completion is still captured on the '+' side, but its only
// visible use (notify_one in the catch block) is removed — check the lines
// hidden by the hunk boundary before dropping the capture.
54+ auto task_wrapper = [this , &ws, &profiler, enable_profiling, &cv_completion](const std::string& node_name,
55+ const std::function<void (const std::string&)>& on_finish)
5656 {
57- // If the graph has already failed, do not run any new task
58- if (graph_failed_.load ()) { return ; }
57+ if (graph_failed_.load ())
58+ {
59+ on_finish (node_name); // graph already failed: report the node as finished without running it
60+ return ;
61+ }
5962
6063 try
6164 {
@@ -65,109 +68,107 @@ namespace kpipeline
6568 nodes_.at (node_name)->Execute (ws);
6669
6770 if (enable_profiling) profiler.End (node_name, start_time);
68- schedule_next (node_name);
6971 }
7072 catch (const std::exception& e)
7173 {
7274 std::lock_guard<std::mutex> lock (exception_mutex_);
73- // Record only the first exception that occurs
7475 if (!first_exception_)
7576 {
7677 first_exception_ = std::current_exception ();
7778 graph_failed_ = true ;
7879 LOG_INFO (" !!! Graph execution failed in node '{}'. Halting all operations. Error:{}" , node_name, e.what ());
7980 }
80- // Wake the main thread so it can finish early and rethrow the exception
81- cv_completion.notify_one ();
8281 }
// Completion is reported on both the success path and the exception path.
82+ on_finish (node_name);
8383 };
8484
85- schedule_next = [this , &ws, &adj, &in_degree, &pool, &finished_nodes_count, total_graph_nodes, &cv_completion, &
86- task_wrapper, &prune_branch](const std::string& completed_node_name)
85+ // Unified completion handler, called once for every node that ran, was skipped or was pruned.
// It (1) bumps finished_nodes_count and notifies cv_completion when the count
// reaches total_graph_nodes, then (2) decrements the in-degree of each
// successor listed in adj and, for every successor that reaches zero, enqueues
// a pool task that either runs it (task_wrapper) or prunes it (calls this
// handler again for the successor).
// Declared as a std::function first so the lambda can capture itself by reference.
// NOTE(review): on the '+' side the counter bump and notify_one happen BEFORE
// adj / in_degree are read, whereas the removed '-' lines did them last. These
// containers are captured by reference; if the waiter can return and destroy
// them once it is notified, the accesses below would race with that — confirm
// the lifetime/ordering of the pool versus these locals.
// NOTE(review): notify_one is issued without holding completion_mutex (it is
// not captured here), so a notification could arrive between the waiter's
// predicate check and its block — verify against the wait site.
86+ std::function<void (const std::string&)> on_node_finished;
87+ on_node_finished =
88+ [this , &ws, &adj, &in_degree, &pool, &finished_nodes_count, total_graph_nodes, &cv_completion, &task_wrapper, &
89+ on_node_finished](const std::string& finished_node_name)
8790 {
88- if (graph_failed_. load () )
91+ if (finished_nodes_count. fetch_add ( 1 ) + 1 == total_graph_nodes )
8992 {
90- // If the graph has failed we still bump the counter so the wait can eventually end, but schedule nothing further
91- if (finished_nodes_count.fetch_add (1 ) + 1 == total_graph_nodes) { cv_completion.notify_one (); }
92- return ;
93+ cv_completion.notify_one ();
9394 }
9495
95- if (adj.count (completed_node_name ))
96+ if (adj.count (finished_node_name ))
9697 {
97- for (const auto & successor_name : adj.at (completed_node_name ))
98+ for (const auto & successor_name : adj.at (finished_node_name ))
9899 {
// NOTE(review): this decrement may run on several pool threads at once; the
// element type of in_degree is not visible here — confirm it is atomic.
99100 if (--in_degree.at (successor_name) == 0 )
100101 {
101102 bool can_run = true ;
102- const auto & successor_node = nodes_.at (successor_name);
103- for (const auto & control_input : successor_node->GetControlInputs ())
104- {
105- if (!ws.Has (control_input))
106- {
107- can_run = false ;
108- break ;
109- }
110- }
111- if (can_run)
103+ if (graph_failed_.load ())
112104 {
113- pool.Enqueue (task_wrapper, successor_name);
105+ // If the graph has failed, no successor is allowed to run
106+ can_run = false ;
114107 }
115108 else
116109 {
117- prune_branch (successor_name);
// A successor may run only if the workspace has every one of its control inputs.
110+ const auto & successor_node = nodes_.at (successor_name);
111+ for (const auto & control_input : successor_node->GetControlInputs ())
112+ {
113+ if (!ws.Has (control_input))
114+ {
115+ can_run = false ;
116+ break ;
117+ }
118+ }
118119 }
120+
121+ // Whether the successor runs or is pruned, submit it as a new pool task so the current thread is released
// can_run and successor_name are captured by value; task_wrapper and
// on_node_finished by reference. `this` is captured but not used in this body.
122+ pool.Enqueue ([this , can_run, successor_name, &task_wrapper, &on_node_finished]()
123+ {
124+ if (can_run)
125+ {
126+ task_wrapper (successor_name, on_node_finished);
127+ }
128+ else
129+ {
130+ LOG_INFO (" > Pruning branch at node: {}" , successor_name);
131+ on_node_finished (successor_name); // a pruned node is reported as finished directly
132+ }
133+ });
119134 }
120135 }
121136 }
122-
123- if (finished_nodes_count.fetch_add (1 ) + 1 == total_graph_nodes)
124- {
125- cv_completion.notify_one ();
126- }
127137 };
128138
129- prune_branch = [this , &adj, &in_degree, &finished_nodes_count, total_graph_nodes, &cv_completion, &prune_branch](
130- const std::string& pruned_node_name)
131- {
132- if (graph_failed_.load ())
133- {
134- if (finished_nodes_count.fetch_add (1 ) + 1 == total_graph_nodes) { cv_completion.notify_one (); }
135- return ;
136- }
137-
138- if (adj.count (pruned_node_name))
139- {
140- for (const auto & successor_name : adj.at (pruned_node_name))
141- {
142- if (--in_degree.at (successor_name) == 0 ) { prune_branch (successor_name); }
143- }
144- }
145- if (finished_nodes_count.fetch_add (1 ) + 1 == total_graph_nodes)
146- {
147- cv_completion.notify_one ();
148- }
149- };
139+ // ======================== 结束 ========================
150140
// Start-up and completion wait: enqueue every node whose in-degree is zero,
// reject a non-empty graph that has no such node, then block on cv_completion
// until the wait predicate below holds.
151141 LOG_INFO (" --- Starting Graph Execution with {} threads ---" , num_threads);
142+ bool has_started = false ;
152143 for (const auto & [name, node] : nodes_)
153144 {
154145 if (in_degree.at (name) == 0 )
155146 {
156- pool.Enqueue (task_wrapper, name);
147+ has_started = true ;
148+ // Launch an initial (zero in-degree) node
149+ pool.Enqueue (task_wrapper, name, on_node_finished);
157150 }
158151 }
159152
// Nothing was enqueued although the graph has nodes, so no completion would
// ever be reported; fail fast instead of waiting.
153+ if (!has_started && total_graph_nodes > 0 )
154+ {
155+ throw PipelineException (" Graph has no entry points, but is not empty." );
156+ }
157+
// NOTE(review): completion_mutex is taken only here; in the lines visible in
// this diff the notifying side updates the counter and calls notify_one
// without it, so a wakeup could be missed between the predicate check and the
// block — confirm.
160158 std::unique_lock<std::mutex> lock (completion_mutex);
161- // Wait condition: all nodes finished, or graph execution failed
162159 cv_completion.wait (lock, [&]
163160 {
161+ // Wait condition: all nodes finished (executed or pruned), or graph execution failed
// NOTE(review): on the '+' side nothing visible notifies when graph_failed_
// is set, so the failure half of this predicate is only re-evaluated on the
// final notify (or a spurious wakeup) — confirm this is intended.
164- return finished_nodes_count.load () == total_graph_nodes || graph_failed_.load ();
162+ return finished_nodes_count.load () >= total_graph_nodes || graph_failed_.load ();
165163 });
166164
165+ // Before rethrowing, make sure the thread pool is shut down properly so its destructor does not find tasks still queued
166+ // Per the original author, ThreadPool's destructor waits for all tasks to finish, which is what is needed here (ThreadPool is not visible in this diff — TODO confirm)
167+ // A forced stop would require a more elaborate stop mechanism in ThreadPool
168+
167169 if (graph_failed_.load ())
168170 {
169171 LOG_WARN (" --- Graph Execution Halted Due to Error ---" );
170- // 等待线程池中的现有任务完成或退出,以避免悬空引用
171172 }
172173 else
173174 {
@@ -179,7 +180,6 @@ namespace kpipeline
179180 profiler.PrintReport ();
180181 }
181182
182- // 如果有异常,就在主线程中重新抛出它
183183 if (first_exception_)
184184 {
185185 std::rethrow_exception (first_exception_);
@@ -251,8 +251,6 @@ namespace kpipeline
251251 }
252252
253253 std::map<std::string, std::shared_ptr<Node>> nodes_;
254-
255- // --- 用于快速失败的成员 ---
256254 std::atomic<bool > graph_failed_;
257255 std::mutex exception_mutex_;
258256 std::exception_ptr first_exception_;
0 commit comments