Skip to content

Commit be05041

Browse files
committed
fix: 线程数少于扇出数的问题
1 parent a546d69 · commit be05041

File tree

1 file changed

+61
-63
lines changed

1 file changed

+61
-63
lines changed

kpipeline/graph.h

Lines changed: 61 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -48,14 +48,17 @@ namespace kpipeline
4848
std::mutex completion_mutex;
4949
std::condition_variable cv_completion;
5050

51-
std::function<void(const std::string&)> schedule_next;
52-
std::function<void(const std::string&)> prune_branch;
51+
// ======================== 统一调度逻辑 ========================
5352

54-
auto task_wrapper = [this, &ws, &schedule_next, &profiler, enable_profiling, &cv_completion
55-
](const std::string& node_name)
53+
// 单个节点执行的核心逻辑
54+
auto task_wrapper = [this, &ws, &profiler, enable_profiling, &cv_completion](const std::string& node_name,
55+
const std::function<void(const std::string&)>& on_finish)
5656
{
57-
// 如果图已经失败,则不执行任何新任务
58-
if (graph_failed_.load()) { return; }
57+
if (graph_failed_.load())
58+
{
59+
on_finish(node_name); // 如果图已失败,直接标记为完成
60+
return;
61+
}
5962

6063
try
6164
{
@@ -65,109 +68,107 @@ namespace kpipeline
6568
nodes_.at(node_name)->Execute(ws);
6669

6770
if (enable_profiling) profiler.End(node_name, start_time);
68-
schedule_next(node_name);
6971
}
7072
catch (const std::exception& e)
7173
{
7274
std::lock_guard<std::mutex> lock(exception_mutex_);
73-
// 只记录第一个发生的异常
7475
if (!first_exception_)
7576
{
7677
first_exception_ = std::current_exception();
7778
graph_failed_ = true;
7879
LOG_INFO("!!! Graph execution failed in node '{}'. Halting all operations. Error:{}", node_name, e.what());
7980
}
80-
// 唤醒主线程,让它提前结束并重新抛出异常
81-
cv_completion.notify_one();
8281
}
82+
on_finish(node_name);
8383
};
8484

85-
schedule_next = [this, &ws, &adj, &in_degree, &pool, &finished_nodes_count, total_graph_nodes, &cv_completion, &
86-
task_wrapper, &prune_branch](const std::string& completed_node_name)
85+
// 统一的完成处理函数
86+
std::function<void(const std::string&)> on_node_finished;
87+
on_node_finished =
88+
[this, &ws, &adj, &in_degree, &pool, &finished_nodes_count, total_graph_nodes, &cv_completion, &task_wrapper, &
89+
on_node_finished](const std::string& finished_node_name)
8790
{
88-
if (graph_failed_.load())
91+
if (finished_nodes_count.fetch_add(1) + 1 == total_graph_nodes)
8992
{
90-
// 如果图已失败,我们仍需增加计数器以最终结束等待,但不再调度
91-
if (finished_nodes_count.fetch_add(1) + 1 == total_graph_nodes) { cv_completion.notify_one(); }
92-
return;
93+
cv_completion.notify_one();
9394
}
9495

95-
if (adj.count(completed_node_name))
96+
if (adj.count(finished_node_name))
9697
{
97-
for (const auto& successor_name : adj.at(completed_node_name))
98+
for (const auto& successor_name : adj.at(finished_node_name))
9899
{
99100
if (--in_degree.at(successor_name) == 0)
100101
{
101102
bool can_run = true;
102-
const auto& successor_node = nodes_.at(successor_name);
103-
for (const auto& control_input : successor_node->GetControlInputs())
104-
{
105-
if (!ws.Has(control_input))
106-
{
107-
can_run = false;
108-
break;
109-
}
110-
}
111-
if (can_run)
103+
if (graph_failed_.load())
112104
{
113-
pool.Enqueue(task_wrapper, successor_name);
105+
// 如果图已失败,所有后续节点都不能运行
106+
can_run = false;
114107
}
115108
else
116109
{
117-
prune_branch(successor_name);
110+
const auto& successor_node = nodes_.at(successor_name);
111+
for (const auto& control_input : successor_node->GetControlInputs())
112+
{
113+
if (!ws.Has(control_input))
114+
{
115+
can_run = false;
116+
break;
117+
}
118+
}
118119
}
120+
121+
// 无论是执行还是剪枝,都提交一个新任务来处理,以释放当前线程
122+
pool.Enqueue([this, can_run, successor_name, &task_wrapper, &on_node_finished]()
123+
{
124+
if (can_run)
125+
{
126+
task_wrapper(successor_name, on_node_finished);
127+
}
128+
else
129+
{
130+
LOG_INFO(" > Pruning branch at node: {}", successor_name);
131+
on_node_finished(successor_name); // 被剪枝的节点直接标记为完成
132+
}
133+
});
119134
}
120135
}
121136
}
122-
123-
if (finished_nodes_count.fetch_add(1) + 1 == total_graph_nodes)
124-
{
125-
cv_completion.notify_one();
126-
}
127137
};
128138

129-
prune_branch = [this, &adj, &in_degree, &finished_nodes_count, total_graph_nodes, &cv_completion, &prune_branch](
130-
const std::string& pruned_node_name)
131-
{
132-
if (graph_failed_.load())
133-
{
134-
if (finished_nodes_count.fetch_add(1) + 1 == total_graph_nodes) { cv_completion.notify_one(); }
135-
return;
136-
}
137-
138-
if (adj.count(pruned_node_name))
139-
{
140-
for (const auto& successor_name : adj.at(pruned_node_name))
141-
{
142-
if (--in_degree.at(successor_name) == 0) { prune_branch(successor_name); }
143-
}
144-
}
145-
if (finished_nodes_count.fetch_add(1) + 1 == total_graph_nodes)
146-
{
147-
cv_completion.notify_one();
148-
}
149-
};
139+
// ======================== 结束 ========================
150140

151141
LOG_INFO("--- Starting Graph Execution with {} threads ---", num_threads);
142+
bool has_started = false;
152143
for (const auto& [name, node] : nodes_)
153144
{
154145
if (in_degree.at(name) == 0)
155146
{
156-
pool.Enqueue(task_wrapper, name);
147+
has_started = true;
148+
// 启动初始节点
149+
pool.Enqueue(task_wrapper, name, on_node_finished);
157150
}
158151
}
159152

153+
if (!has_started && total_graph_nodes > 0)
154+
{
155+
throw PipelineException("Graph has no entry points, but is not empty.");
156+
}
157+
160158
std::unique_lock<std::mutex> lock(completion_mutex);
161-
// 等待条件:所有节点完成 或 图执行失败
162159
cv_completion.wait(lock, [&]
163160
{
164-
return finished_nodes_count.load() == total_graph_nodes || graph_failed_.load();
161+
// 等待条件:所有节点完成(无论是执行还是剪枝) 或 图执行失败
162+
return finished_nodes_count.load() >= total_graph_nodes || graph_failed_.load();
165163
});
166164

165+
// 在重新抛出异常前,确保线程池被正确关闭,避免析构函数中任务队列还有任务
166+
// ThreadPool 的析构函数会等待所有任务完成,这正是我们需要的
167+
// 如果需要强制停止,ThreadPool 需要一个更复杂的停止机制
168+
167169
if (graph_failed_.load())
168170
{
169171
LOG_WARN("--- Graph Execution Halted Due to Error ---");
170-
// 等待线程池中的现有任务完成或退出,以避免悬空引用
171172
}
172173
else
173174
{
@@ -179,7 +180,6 @@ namespace kpipeline
179180
profiler.PrintReport();
180181
}
181182

182-
// 如果有异常,就在主线程中重新抛出它
183183
if (first_exception_)
184184
{
185185
std::rethrow_exception(first_exception_);
@@ -251,8 +251,6 @@ namespace kpipeline
251251
}
252252

253253
std::map<std::string, std::shared_ptr<Node>> nodes_;
254-
255-
// --- 用于快速失败的成员 ---
256254
std::atomic<bool> graph_failed_;
257255
std::mutex exception_mutex_;
258256
std::exception_ptr first_exception_;

0 commit comments

Comments
 (0)