@@ -30,55 +30,139 @@ REGISTER_OP(Convolution)
3030 })
3131 */
3232 .translate_v2([](std::shared_ptr<graph::GNode> curr) -> std::string {
33- auto ir_template =
34- R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where HO in @height@, WO in @width@; )" ;
35-
3633 auto _op = static_pointer_cast<nnfusion::op::Convolution>(curr->get_op_ptr ());
3734 NNFUSION_CHECK_NOT_NULLPTR (_op) << " Node type is not " << curr->get_op_ptr ()->get_op_type ();
38- const auto & dilation_h = _op->get_window_dilation_strides ()[0 ];
39- const auto & dilation_w = _op->get_window_dilation_strides ()[1 ];
40- const auto & stride_h = _op->get_window_movement_strides ()[0 ];
41- const auto & stride_w = _op->get_window_movement_strides ()[1 ];
42- const auto & is_nchw = _op->get_data_format () == " NCHW" ;
43- const auto & padding_below = _op->get_padding_below ();
44- const auto & padding_above = _op->get_padding_above ();
45- const auto & padding_h = _op->get_padding_below ()[0 ];
46- const auto & padding_w = _op->get_padding_below ()[1 ];
47- const auto & kernel_size_h =
48- is_nchw ? curr->get_input_shape (1 )[2 ] : curr->get_input_shape (1 )[0 ];
49- const auto & kernel_size_w =
50- is_nchw ? curr->get_input_shape (1 )[3 ] : curr->get_input_shape (1 )[1 ];
51- const auto & in_shape = curr->get_input_shape (0 );
52- const auto & out_shape = curr->get_output_shape (0 );
53- const std::string data_format = is_nchw ? " nchw" : " nhwc" ;
54- NNFUSION_CHECK (dilation_h == 1 ) << " Not support other dilation yet." ;
55- NNFUSION_CHECK (dilation_w == 1 ) << " Not support other dilation yet." ;
56- NNFUSION_CHECK (padding_below == padding_above)
57- << " Asymetric padding is not supported by now." ;
58- nnfusion::op::OpConfig::any config;
59- std::string HO = " -@pad_0@ + KH + HO * " + to_string (stride_h);
60- std::string WO = " -@pad_1@ + KW + WO * " + to_string (stride_w);
61- std::string shape_template =
62- is_nchw ? " [N, C, " + HO + " , " + WO + " ]" : " [N, " + HO + " , " + WO + " , C]" ;
63- config[" input1_layout" ] = is_nchw ? " [F, C, KH, KW]" : " [KH, KW, C, F]" ;
64- config[" output0_layout" ] = is_nchw ? " [N, F, HO, WO]" : " [N, HO, WO, F]" ;
65- config[" height" ] = is_nchw ? out_shape[2 ] : out_shape[1 ];
66- config[" width" ] = is_nchw ? out_shape[3 ] : out_shape[2 ];
67- config[" pad_0" ] = to_string (padding_h);
68- config[" pad_1" ] = to_string (padding_w);
69- config[" input0_layout" ] = op::create_code_from_template (shape_template, config);
7035
71- std::string pad_cond;
72- if (padding_h || padding_w)
36+ if (_op->get_data_format () == " NCHW" || _op->get_data_format () == " NHWC" ) // Conv2D
7337 {
74- config[" in_height" ] = is_nchw ? in_shape[2 ] : in_shape[1 ];
75- config[" in_width" ] = is_nchw ? in_shape[3 ] : in_shape[2 ];
76- auto pad_template = " .when([" + HO + " >= 0, " + HO + " < @in_height@, " + WO +
77- " >= 0, " + WO +
78- " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))" ;
79- pad_cond = op::create_code_from_template (pad_template, config);
38+ auto ir_template =
39+ R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where HO in @height@, WO in @width@; )" ;
40+
41+ const auto & dilation_h = _op->get_window_dilation_strides ()[0 ];
42+ const auto & dilation_w = _op->get_window_dilation_strides ()[1 ];
43+ const auto & stride_h = _op->get_window_movement_strides ()[0 ];
44+ const auto & stride_w = _op->get_window_movement_strides ()[1 ];
45+ const auto & is_nchw = _op->get_data_format () == " NCHW" ;
46+ const auto & padding_below = _op->get_padding_below ();
47+ const auto & padding_above = _op->get_padding_above ();
48+ const auto & padding_h = _op->get_padding_below ()[0 ];
49+ const auto & padding_w = _op->get_padding_below ()[1 ];
50+ const auto & kernel_size_h =
51+ is_nchw ? curr->get_input_shape (1 )[2 ] : curr->get_input_shape (1 )[0 ];
52+ const auto & kernel_size_w =
53+ is_nchw ? curr->get_input_shape (1 )[3 ] : curr->get_input_shape (1 )[1 ];
54+ const auto & in_shape = curr->get_input_shape (0 );
55+ const auto & out_shape = curr->get_output_shape (0 );
56+ const std::string data_format = is_nchw ? " nchw" : " nhwc" ;
57+ if (dilation_h != 1 || dilation_w != 1 )
58+ {
59+ NNFUSION_LOG (NNFUSION_WARNING) << " Not support other dilation yet." ;
60+ return " " ;
61+ }
62+ if (padding_below != padding_above)
63+ {
64+ NNFUSION_LOG (NNFUSION_WARNING) << " Asymetric padding is not supported by now." ;
65+ return " " ;
66+ }
67+ // NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet.";
68+ // NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet.";
69+ // NNFUSION_CHECK(padding_below == padding_above)
70+ // << "Asymetric padding is not supported by now.";
71+ nnfusion::op::OpConfig::any config;
72+ std::string HO = " -@pad_0@ + KH + HO * " + to_string (stride_h);
73+ std::string WO = " -@pad_1@ + KW + WO * " + to_string (stride_w);
74+ std::string shape_template =
75+ is_nchw ? " [N, C, " + HO + " , " + WO + " ]" : " [N, " + HO + " , " + WO + " , C]" ;
76+ config[" input1_layout" ] = is_nchw ? " [F, C, KH, KW]" : " [KH, KW, C, F]" ;
77+ config[" output0_layout" ] = is_nchw ? " [N, F, HO, WO]" : " [N, HO, WO, F]" ;
78+ config[" height" ] = is_nchw ? out_shape[2 ] : out_shape[1 ];
79+ config[" width" ] = is_nchw ? out_shape[3 ] : out_shape[2 ];
80+ config[" pad_0" ] = to_string (padding_h);
81+ config[" pad_1" ] = to_string (padding_w);
82+ config[" input0_layout" ] = op::create_code_from_template (shape_template, config);
83+
84+ std::string pad_cond;
85+ if (padding_h || padding_w)
86+ {
87+ config[" in_height" ] = is_nchw ? in_shape[2 ] : in_shape[1 ];
88+ config[" in_width" ] = is_nchw ? in_shape[3 ] : in_shape[2 ];
89+ auto pad_template =
90+ " .when([" + HO + " >= 0, " + HO + " < @in_height@, " + WO + " >= 0, " + WO +
91+ " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))" ;
92+ pad_cond = op::create_code_from_template (pad_template, config);
93+ }
94+ config[" pad_cond" ] = pad_cond;
95+
96+ return op::create_code_from_template (ir_template, config);
97+ }
98+ else if (_op->get_data_format () == " NCDHW" ) // Conv3D
99+ {
100+ auto ir_template =
101+ R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where DO in @depth@, HO in @height@, WO in @width@; )" ;
102+
103+ const auto & dilation_d = _op->get_window_dilation_strides ()[0 ];
104+ const auto & dilation_h = _op->get_window_dilation_strides ()[1 ];
105+ const auto & dilation_w = _op->get_window_dilation_strides ()[2 ];
106+ const auto & stride_d = _op->get_window_movement_strides ()[0 ];
107+ const auto & stride_h = _op->get_window_movement_strides ()[1 ];
108+ const auto & stride_w = _op->get_window_movement_strides ()[2 ];
109+ const auto & padding_below = _op->get_padding_below ();
110+ const auto & padding_above = _op->get_padding_above ();
111+ const auto & padding_d = _op->get_padding_below ()[0 ];
112+ const auto & padding_h = _op->get_padding_below ()[1 ];
113+ const auto & padding_w = _op->get_padding_below ()[2 ];
114+ const auto & kernel_size_d = curr->get_input_shape (1 )[2 ];
115+ const auto & kernel_size_h = curr->get_input_shape (1 )[3 ];
116+ const auto & kernel_size_w = curr->get_input_shape (1 )[4 ];
117+ const auto & in_shape = curr->get_input_shape (0 );
118+ const auto & out_shape = curr->get_output_shape (0 );
119+ const std::string data_format = " NCDHW" ;
120+ if (dilation_d != 1 || dilation_h != 1 || dilation_w != 1 )
121+ {
122+ NNFUSION_LOG (NNFUSION_WARNING) << " Not support other dilation yet." ;
123+ return " " ;
124+ }
125+ if (padding_below != padding_above)
126+ {
127+ NNFUSION_LOG (NNFUSION_WARNING) << " Asymetric padding is not supported by now." ;
128+ return " " ;
129+ }
130+ // NNFUSION_CHECK(dilation_d == 1) << "Not support other dilation yet.";
131+ // NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet.";
132+ // NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet.";
133+ // NNFUSION_CHECK(padding_below == padding_above)
134+ // << "Asymetric padding is not supported by now.";
135+ nnfusion::op::OpConfig::any config;
136+ std::string DO = " -@pad_0@ + KD + DO * " + to_string (stride_d);
137+ std::string HO = " -@pad_1@ + KH + HO * " + to_string (stride_h);
138+ std::string WO = " -@pad_2@ + KW + WO * " + to_string (stride_w);
139+ std::string shape_template = " [N, C, " + DO + " , " + HO + " , " + WO + " ]" ;
140+ config[" input1_layout" ] = " [F, C, KD, KH, KW]" ;
141+ config[" output0_layout" ] = " [N, F, DO, HO, WO]" ;
142+ config[" depth" ] = out_shape[2 ];
143+ config[" height" ] = out_shape[3 ];
144+ config[" width" ] = out_shape[4 ];
145+ config[" pad_0" ] = to_string (padding_d);
146+ config[" pad_1" ] = to_string (padding_h);
147+ config[" pad_2" ] = to_string (padding_w);
148+ config[" input0_layout" ] = op::create_code_from_template (shape_template, config);
149+
150+ std::string pad_cond;
151+ if (padding_d || padding_h || padding_w)
152+ {
153+ config[" in_depth" ] = in_shape[2 ];
154+ config[" in_height" ] = in_shape[3 ];
155+ config[" in_width" ] = in_shape[4 ];
156+ auto pad_template =
157+ " .when([" + DO + " >= 0, " + DO + " < @in_depth@, " + HO + " >= 0, " + HO +
158+ " < @in_height@, " + WO + " >= 0, " + WO +
159+ " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))" ;
160+ pad_cond = op::create_code_from_template (pad_template, config);
161+ }
162+ config[" pad_cond" ] = pad_cond;
163+
164+ return op::create_code_from_template (ir_template, config);
80165 }
81- config[" pad_cond" ] = pad_cond;
82166
83- return op::create_code_from_template (ir_template, config) ;
167+ return " " ;
84168 });
0 commit comments