Skip to content

Commit 7faaf84

Browse files
committed
Fix inpalce annotation for fused kernel
1 parent 1a9bc65 commit 7faaf84

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

src/nnfusion/core/graph/gnode.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr<Graph> graph)
416416
m_op_ctxs.push_back(ctx);
417417
}
418418

419+
std::unordered_map<std::shared_ptr<GNode>, std::unordered_map<size_t, size_t>> input_id_map;
419420
// Register input tensors
420421
for (const auto& m_node : m_order_nodes)
421422
{
@@ -430,6 +431,7 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr<Graph> graph)
430431
set_input(input_id, m_node->get_inputs().at(in_edge->get_dst_input()));
431432
graph->add_edge(
432433
in_edge->get_src(), in_edge->get_src_output(), shared_from_this(), input_id);
434+
input_id_map[m_node][in_edge->get_dst_input()] = input_id;
433435
}
434436
}
435437
// Add control-edges as inputs of fused node
@@ -461,6 +463,29 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr<Graph> graph)
461463
has_output = true;
462464
set_output(get_output_size(),
463465
m_node->get_outputs().at(out_edge->get_src_output()));
466+
467+
// get inplace annotation
468+
auto op = std::dynamic_pointer_cast<Op>(m_node->get_op_ptr());
469+
auto op_annotations = op->get_op_annotations();
470+
if (op_annotations)
471+
{
472+
auto oi_pairs = op_annotations->get_in_place_oi_pairs();
473+
for (auto oi_pair : oi_pairs)
474+
{
475+
auto iter = input_id_map.find(m_node);
476+
if (iter != input_id_map.end() && iter->second.count(oi_pair.input) > 0)
477+
{
478+
auto fused_op =
479+
std::dynamic_pointer_cast<Op>(shared_from_this()->get_op_ptr());
480+
AddInplace(fused_op,
481+
get_output_size() - 1,
482+
iter->second[oi_pair.input],
483+
oi_pair.destructive,
484+
oi_pair.force_inplace);
485+
//NNFUSION_LOG(INFO) << "========================: node=" << m_node->get_op_type() << ", oi: <" << oi_pair.output << ", " << oi_pair.input << ">";
486+
}
487+
}
488+
}
464489
}
465490
graph->add_edge(shared_from_this(),
466491
get_output_size() - 1,

0 commit comments

Comments
 (0)