@@ -416,6 +416,7 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr<Graph> graph)
416416 m_op_ctxs.push_back (ctx);
417417 }
418418
419+ std::unordered_map<std::shared_ptr<GNode>, std::unordered_map<size_t , size_t >> input_id_map;
419420 // Register input tensors
420421 for (const auto & m_node : m_order_nodes)
421422 {
@@ -430,6 +431,7 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr<Graph> graph)
430431 set_input (input_id, m_node->get_inputs ().at (in_edge->get_dst_input ()));
431432 graph->add_edge (
432433 in_edge->get_src (), in_edge->get_src_output (), shared_from_this (), input_id);
434+ input_id_map[m_node][in_edge->get_dst_input ()] = input_id;
433435 }
434436 }
435437 // Add control-edges as inputs of fused node
@@ -461,6 +463,29 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr<Graph> graph)
461463 has_output = true ;
462464 set_output (get_output_size (),
463465 m_node->get_outputs ().at (out_edge->get_src_output ()));
466+
467+ // get inplace annotation
468+ auto op = std::dynamic_pointer_cast<Op>(m_node->get_op_ptr ());
469+ auto op_annotations = op->get_op_annotations ();
470+ if (op_annotations)
471+ {
472+ auto oi_pairs = op_annotations->get_in_place_oi_pairs ();
473+ for (auto oi_pair : oi_pairs)
474+ {
475+ auto iter = input_id_map.find (m_node);
476+ if (iter != input_id_map.end () && iter->second .count (oi_pair.input ) > 0 )
477+ {
478+ auto fused_op =
479+ std::dynamic_pointer_cast<Op>(shared_from_this ()->get_op_ptr ());
480+ AddInplace (fused_op,
481+ get_output_size () - 1 ,
482+ iter->second [oi_pair.input ],
483+ oi_pair.destructive ,
484+ oi_pair.force_inplace );
485+ // NNFUSION_LOG(INFO) << "========================: node=" << m_node->get_op_type() << ", oi: <" << oi_pair.output << ", " << oi_pair.input << ">";
486+ }
487+ }
488+ }
464489 }
465490 graph->add_edge (shared_from_this (),
466491 get_output_size () - 1 ,
0 commit comments