
Commit 1a9bc65
Merge Conv3d IR
2 parents: e53d642 + cfae096

5 files changed: 166 additions & 66 deletions


artifacts/README.md

Lines changed: 5 additions & 2 deletions
@@ -1,5 +1,8 @@
 # OSDI'20 Artifacts Evaluation
 
-OSDI'20 Artifact Evaluation of paper #292, titled "[Rammer: Enabling Holistic Deep Learning Compiler Optimizations with rTasks](https://www.usenix.org/conference/osdi20/presentation/ma)".
+- OSDI'20 Artifact Evaluation of paper #292, titled "[Rammer: Enabling Holistic Deep Learning Compiler Optimizations with rTasks](https://www.usenix.org/conference/osdi20/presentation/ma)".
+**Please refer to the [osdi20_artifact branch](https://github.com/microsoft/nnfusion/tree/osdi20_artifact/artifacts)**
 
-**Please refer to the [osdi20_artifact branch](https://github.com/microsoft/nnfusion/tree/osdi20_artifact/artifacts)**
+
+- OSDI'22 Artifact Evaluation of paper #158, titled "[Roller: Fast and Efficient Tensor Compilation for Deep Learning](https://www.usenix.org/conference/osdi22/presentation/zhu)".
+**Please refer to the [osdi22_artifact branch](https://github.com/microsoft/nnfusion/tree/osdi22_artifact/artifacts)**

src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp

Lines changed: 129 additions & 45 deletions
@@ -30,55 +30,139 @@ REGISTER_OP(Convolution)
         })
     */
     .translate_v2([](std::shared_ptr<graph::GNode> curr) -> std::string {
-        auto ir_template =
-            R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where HO in @height@, WO in @width@; )";
-
         auto _op = static_pointer_cast<nnfusion::op::Convolution>(curr->get_op_ptr());
         NNFUSION_CHECK_NOT_NULLPTR(_op) << "Node type is not " << curr->get_op_ptr()->get_op_type();
-        const auto& dilation_h = _op->get_window_dilation_strides()[0];
-        const auto& dilation_w = _op->get_window_dilation_strides()[1];
-        const auto& stride_h = _op->get_window_movement_strides()[0];
-        const auto& stride_w = _op->get_window_movement_strides()[1];
-        const auto& is_nchw = _op->get_data_format() == "NCHW";
-        const auto& padding_below = _op->get_padding_below();
-        const auto& padding_above = _op->get_padding_above();
-        const auto& padding_h = _op->get_padding_below()[0];
-        const auto& padding_w = _op->get_padding_below()[1];
-        const auto& kernel_size_h =
-            is_nchw ? curr->get_input_shape(1)[2] : curr->get_input_shape(1)[0];
-        const auto& kernel_size_w =
-            is_nchw ? curr->get_input_shape(1)[3] : curr->get_input_shape(1)[1];
-        const auto& in_shape = curr->get_input_shape(0);
-        const auto& out_shape = curr->get_output_shape(0);
-        const std::string data_format = is_nchw ? "nchw" : "nhwc";
-        NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet.";
-        NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet.";
-        NNFUSION_CHECK(padding_below == padding_above)
-            << "Asymetric padding is not supported by now.";
-        nnfusion::op::OpConfig::any config;
-        std::string HO = "-@pad_0@ + KH + HO * " + to_string(stride_h);
-        std::string WO = "-@pad_1@ + KW + WO * " + to_string(stride_w);
-        std::string shape_template =
-            is_nchw ? "[N, C, " + HO + ", " + WO + "]" : "[N, " + HO + ", " + WO + ", C]";
-        config["input1_layout"] = is_nchw ? "[F, C, KH, KW]" : "[KH, KW, C, F]";
-        config["output0_layout"] = is_nchw ? "[N, F, HO, WO]" : "[N, HO, WO, F]";
-        config["height"] = is_nchw ? out_shape[2] : out_shape[1];
-        config["width"] = is_nchw ? out_shape[3] : out_shape[2];
-        config["pad_0"] = to_string(padding_h);
-        config["pad_1"] = to_string(padding_w);
-        config["input0_layout"] = op::create_code_from_template(shape_template, config);
 
-        std::string pad_cond;
-        if (padding_h || padding_w)
+        if (_op->get_data_format() == "NCHW" || _op->get_data_format() == "NHWC") // Conv2D
         {
-            config["in_height"] = is_nchw ? in_shape[2] : in_shape[1];
-            config["in_width"] = is_nchw ? in_shape[3] : in_shape[2];
-            auto pad_template = ".when([" + HO + " >= 0, " + HO + " < @in_height@, " + WO +
-                                " >= 0, " + WO +
-                                " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))";
-            pad_cond = op::create_code_from_template(pad_template, config);
+            auto ir_template =
+                R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where HO in @height@, WO in @width@; )";
+
+            const auto& dilation_h = _op->get_window_dilation_strides()[0];
+            const auto& dilation_w = _op->get_window_dilation_strides()[1];
+            const auto& stride_h = _op->get_window_movement_strides()[0];
+            const auto& stride_w = _op->get_window_movement_strides()[1];
+            const auto& is_nchw = _op->get_data_format() == "NCHW";
+            const auto& padding_below = _op->get_padding_below();
+            const auto& padding_above = _op->get_padding_above();
+            const auto& padding_h = _op->get_padding_below()[0];
+            const auto& padding_w = _op->get_padding_below()[1];
+            const auto& kernel_size_h =
+                is_nchw ? curr->get_input_shape(1)[2] : curr->get_input_shape(1)[0];
+            const auto& kernel_size_w =
+                is_nchw ? curr->get_input_shape(1)[3] : curr->get_input_shape(1)[1];
+            const auto& in_shape = curr->get_input_shape(0);
+            const auto& out_shape = curr->get_output_shape(0);
+            const std::string data_format = is_nchw ? "nchw" : "nhwc";
+            if (dilation_h != 1 || dilation_w != 1)
+            {
+                NNFUSION_LOG(NNFUSION_WARNING) << "Not support other dilation yet.";
+                return "";
+            }
+            if (padding_below != padding_above)
+            {
+                NNFUSION_LOG(NNFUSION_WARNING) << "Asymetric padding is not supported by now.";
+                return "";
+            }
+            // NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet.";
+            // NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet.";
+            // NNFUSION_CHECK(padding_below == padding_above)
+            //     << "Asymetric padding is not supported by now.";
+            nnfusion::op::OpConfig::any config;
+            std::string HO = "-@pad_0@ + KH + HO * " + to_string(stride_h);
+            std::string WO = "-@pad_1@ + KW + WO * " + to_string(stride_w);
+            std::string shape_template =
+                is_nchw ? "[N, C, " + HO + ", " + WO + "]" : "[N, " + HO + ", " + WO + ", C]";
+            config["input1_layout"] = is_nchw ? "[F, C, KH, KW]" : "[KH, KW, C, F]";
+            config["output0_layout"] = is_nchw ? "[N, F, HO, WO]" : "[N, HO, WO, F]";
+            config["height"] = is_nchw ? out_shape[2] : out_shape[1];
+            config["width"] = is_nchw ? out_shape[3] : out_shape[2];
+            config["pad_0"] = to_string(padding_h);
+            config["pad_1"] = to_string(padding_w);
+            config["input0_layout"] = op::create_code_from_template(shape_template, config);
+
+            std::string pad_cond;
+            if (padding_h || padding_w)
+            {
+                config["in_height"] = is_nchw ? in_shape[2] : in_shape[1];
+                config["in_width"] = is_nchw ? in_shape[3] : in_shape[2];
+                auto pad_template =
+                    ".when([" + HO + " >= 0, " + HO + " < @in_height@, " + WO + " >= 0, " + WO +
+                    " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))";
+                pad_cond = op::create_code_from_template(pad_template, config);
+            }
+            config["pad_cond"] = pad_cond;
+
+            return op::create_code_from_template(ir_template, config);
+        }
+        else if (_op->get_data_format() == "NCDHW") // Conv3D
+        {
+            auto ir_template =
+                R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where DO in @depth@, HO in @height@, WO in @width@; )";
+
+            const auto& dilation_d = _op->get_window_dilation_strides()[0];
+            const auto& dilation_h = _op->get_window_dilation_strides()[1];
+            const auto& dilation_w = _op->get_window_dilation_strides()[2];
+            const auto& stride_d = _op->get_window_movement_strides()[0];
+            const auto& stride_h = _op->get_window_movement_strides()[1];
+            const auto& stride_w = _op->get_window_movement_strides()[2];
+            const auto& padding_below = _op->get_padding_below();
+            const auto& padding_above = _op->get_padding_above();
+            const auto& padding_d = _op->get_padding_below()[0];
+            const auto& padding_h = _op->get_padding_below()[1];
+            const auto& padding_w = _op->get_padding_below()[2];
+            const auto& kernel_size_d = curr->get_input_shape(1)[2];
+            const auto& kernel_size_h = curr->get_input_shape(1)[3];
+            const auto& kernel_size_w = curr->get_input_shape(1)[4];
+            const auto& in_shape = curr->get_input_shape(0);
+            const auto& out_shape = curr->get_output_shape(0);
+            const std::string data_format = "NCDHW";
+            if (dilation_d != 1 || dilation_h != 1 || dilation_w != 1)
+            {
+                NNFUSION_LOG(NNFUSION_WARNING) << "Not support other dilation yet.";
+                return "";
+            }
+            if (padding_below != padding_above)
+            {
+                NNFUSION_LOG(NNFUSION_WARNING) << "Asymetric padding is not supported by now.";
+                return "";
+            }
+            // NNFUSION_CHECK(dilation_d == 1) << "Not support other dilation yet.";
+            // NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet.";
+            // NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet.";
+            // NNFUSION_CHECK(padding_below == padding_above)
+            //     << "Asymetric padding is not supported by now.";
+            nnfusion::op::OpConfig::any config;
+            std::string DO = "-@pad_0@ + KD + DO * " + to_string(stride_d);
+            std::string HO = "-@pad_1@ + KH + HO * " + to_string(stride_h);
+            std::string WO = "-@pad_2@ + KW + WO * " + to_string(stride_w);
+            std::string shape_template = "[N, C, " + DO + ", " + HO + ", " + WO + "]";
+            config["input1_layout"] = "[F, C, KD, KH, KW]";
+            config["output0_layout"] = "[N, F, DO, HO, WO]";
+            config["depth"] = out_shape[2];
+            config["height"] = out_shape[3];
+            config["width"] = out_shape[4];
+            config["pad_0"] = to_string(padding_d);
+            config["pad_1"] = to_string(padding_h);
+            config["pad_2"] = to_string(padding_w);
+            config["input0_layout"] = op::create_code_from_template(shape_template, config);
+
+            std::string pad_cond;
+            if (padding_d || padding_h || padding_w)
+            {
+                config["in_depth"] = in_shape[2];
+                config["in_height"] = in_shape[3];
+                config["in_width"] = in_shape[4];
+                auto pad_template =
+                    ".when([" + DO + " >= 0, " + DO + " < @in_depth@, " + HO + " >= 0, " + HO +
+                    " < @in_height@, " + WO + " >= 0, " + WO +
+                    " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))";
+                pad_cond = op::create_code_from_template(pad_template, config);
+            }
+            config["pad_cond"] = pad_cond;
+
+            return op::create_code_from_template(ir_template, config);
         }
-        config["pad_cond"] = pad_cond;
 
-        return op::create_code_from_template(ir_template, config);
+        return "";
     });
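To make the translate_v2 change concrete, here is roughly what the new Conv3D branch emits once the @...@ placeholders are substituted. This is a hand-instantiated sketch in the Antares-style expression language the template targets, assuming a hypothetical NCDHW input of shape [1, 16, 8, 28, 28], filters [32, 16, 3, 3, 3], unit strides and dilation, and symmetric padding of 1 (so every output extent matches its input); tensor names output0/input0/input1 stand for the substituted operands, and this is not captured compiler output:

output0[N, F, DO, HO, WO] +=! input0[N, C, -1 + KD + DO * 1, -1 + KH + HO * 1, -1 + KW + WO * 1].when([-1 + KD + DO * 1 >= 0, -1 + KD + DO * 1 < 8, -1 + KH + HO * 1 >= 0, -1 + KH + HO * 1 < 28, -1 + KW + WO * 1 >= 0, -1 + KW + WO * 1 < 28], const(0.0).cast(input0[N, C, -1 + KD + DO * 1, -1 + KH + HO * 1, -1 + KW + WO * 1].dtype())) * input1[F, C, KD, KH, KW] where DO in 8, HO in 28, WO in 28;

Note the behavioral change in both branches: unsupported dilation or asymmetric padding now logs a warning and returns an empty string, which presumably lets kernel selection fall back to a non-IR emitter, instead of aborting through NNFUSION_CHECK.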

src/nnfusion/core/operators/util/validation_util.cpp

Lines changed: 22 additions & 13 deletions
@@ -176,26 +176,33 @@ std::tuple<nnfusion::element::Type, nnfusion::PartialShape>
         << "), padding above (" << data_padding_above << "), filter strides (" << filter_strides
         << "), and filter dilation (" << filter_dilation << ") do not match.";
 
-    OP_VALIDATION(op, data_format == "NCW" || data_format == "NCHW" || data_format == "NHWC")
-        << "data format must be Conv1D: NCW, Conv2D: NCHW or NHWC.";
+    OP_VALIDATION(op,
+                  data_format == "NCW" || data_format == "NCHW" || data_format == "NHWC" ||
+                      data_format == "NCDHW")
+        << "data format must be Conv1D: NCW, Conv2D: NCHW or NHWC, Conv3D: NCDHW.";
 
     nnfusion::Dimension batch_size =
         (data_batch_shape.rank().is_static() ? data_batch_shape[0]
                                              : nnfusion::Dimension::dynamic());
     nnfusion::Dimension data_channel_count =
         (data_batch_shape.rank().is_static()
-             ? (data_format == "NCW" || data_format == "NCHW") ? data_batch_shape[1]
-                                                               : data_batch_shape[3]
+             ? (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW")
+                   ? data_batch_shape[1]
+                   : data_batch_shape[3]
              : nnfusion::Dimension::dynamic());
     nnfusion::PartialShape data_spatial_shape(nnfusion::PartialShape::dynamic(spatial_rank));
 
     nnfusion::Dimension filter_output_channel_count =
         (filters_shape.rank().is_static()
-             ? (data_format == "NCW" || data_format == "NCHW") ? filters_shape[0] : filters_shape[3]
+             ? (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW")
+                   ? filters_shape[0]
+                   : filters_shape[3]
             : nnfusion::Dimension::dynamic());
     nnfusion::Dimension filter_input_channel_count =
         (filters_shape.rank().is_static()
-             ? (data_format == "NCW" || data_format == "NCHW") ? filters_shape[1] : filters_shape[2]
+             ? (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW")
+                   ? filters_shape[1]
+                   : filters_shape[2]
             : nnfusion::Dimension::dynamic());
     nnfusion::PartialShape filter_spatial_shape(nnfusion::PartialShape::dynamic(spatial_rank));
 
@@ -207,16 +214,18 @@ std::tuple<nnfusion::element::Type, nnfusion::PartialShape>
     {
         if (data_batch_shape.rank().is_static())
         {
-            data_spatial_shape[i] = (data_format == "NCW" || data_format == "NCHW")
-                                        ? data_batch_shape[i + 2]
-                                        : data_batch_shape[i + 1];
+            data_spatial_shape[i] =
+                (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW")
+                    ? data_batch_shape[i + 2]
+                    : data_batch_shape[i + 1];
         }
 
         if (filters_shape.rank().is_static())
         {
-            filter_spatial_shape[i] = (data_format == "NCW" || data_format == "NCHW")
-                                          ? filters_shape[i + 2]
-                                          : filters_shape[i];
+            filter_spatial_shape[i] =
+                (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW")
+                    ? filters_shape[i + 2]
+                    : filters_shape[i];
         }
     }
 
@@ -253,7 +262,7 @@ std::tuple<nnfusion::element::Type, nnfusion::PartialShape>
 
     nnfusion::PartialShape batch_output_shape(nnfusion::PartialShape::dynamic(spatial_rank + 2));
 
-    if (data_format == "NCW" || data_format == "NCHW")
+    if (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW")
     {
         batch_output_shape[0] = batch_size;
         batch_output_shape[1] = filter_output_channel_count;
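For the symmetric-padding, dilation-1 convolutions the new IR path accepts, the spatial extents that this validation ultimately produces follow the standard formula out = (in + 2 * pad - kernel) / stride + 1. A minimal self-contained sketch (conv_out_dim is a hypothetical helper for illustration, not a function in validation_util.cpp):

#include <cassert>
#include <cstdint>

// Output extent of one spatial dimension, assuming pad_below == pad_above
// and dilation == 1 (the only case the Conv3D IR above supports).
int64_t conv_out_dim(int64_t in, int64_t kernel, int64_t pad, int64_t stride)
{
    return (in + 2 * pad - kernel) / stride + 1;
}

int main()
{
    assert(conv_out_dim(8, 3, 1, 1) == 8);   // depth:  8 -> 8
    assert(conv_out_dim(28, 3, 1, 2) == 14); // height: 28 -> 14 at stride 2
    return 0;
}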

src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp

Lines changed: 1 addition & 1 deletion
@@ -1371,7 +1371,7 @@ cmake_minimum_required(VERSION 3.5)
 
 SET(SRC "nnfusion_rt.cu" CACHE STRING "codegen source file")
 SET(TARGET_NAME "nnfusion_naive_rt" CACHE STRING "codegen target name")
-SET(CUDA_ARCH "-gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" CACHE STRING "target architecture")
+SET(CUDA_ARCH "-gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86" CACHE STRING "target architecture")
 
 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE Release)
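The two extra -gencode pairs extend the generated CMake project to Ampere GPUs: compute_80/sm_80 covers A100-class parts and compute_86/sm_86 covers the RTX 30 series, on top of the existing Pascal/Volta/Turing targets (building these requires a CUDA 11.1 or newer toolchain). Because CUDA_ARCH is a CACHE STRING, it can still be overridden at configure time with a standard CMake cache flag, e.g. to build only for sm_80:

cmake -DCUDA_ARCH="-gencode arch=compute_80,code=sm_80" <path-to-generated-nnfusion_rt>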

src/nnfusion/frontend/onnx_import/op/conv.cpp

Lines changed: 9 additions & 5 deletions
@@ -78,10 +78,10 @@ namespace nnfusion
                 {
                     conv_data_format = "NCHW";
                 }
-                // else if (data_shape.size() == 5)
-                // {
-                //     conv_data_format = "NCDHW";
-                // }
+                else if (data_shape.size() == 5)
+                {
+                    conv_data_format = "NCDHW";
+                }
                 else
                 {
                     NNFUSION_CHECK_FAIL() << "Convolution with dimensions of "
@@ -168,7 +168,7 @@ namespace nnfusion
                         strides, dilations, padding_below, padding_above, conv_data_format);
                     conv_node = m_graph->add_node_and_edge(conv_op, {data, filters});
                 }
-                else
+                else if (conv_data_format == "NCHW")
                 {
                     // split data and filters for group conv
                     std::size_t n_data_channels{data_shape.at(1)};
@@ -264,6 +264,10 @@ namespace nnfusion
                             convolution_nodes);
                     }
                 }
+                else
+                {
+                    NNFUSION_CHECK_FAIL() << "Not support this Convolution yet.";
+                }
 
                 // add bias
                 if (input_indexes.size() == 3)
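Taken together, the importer now routes rank-5 ONNX Conv inputs to the NCDHW path for the non-grouped case, keeps the group-splitting path NCHW-only, and fails fast on anything else (including grouped NCDHW convolutions). A condensed sketch of the resulting rank-to-format dispatch (infer_conv_data_format is a hypothetical helper written for illustration; conv.cpp inlines this logic):

#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

std::string infer_conv_data_format(const std::vector<std::size_t>& data_shape)
{
    if (data_shape.size() == 4)
        return "NCHW";  // Conv2D
    if (data_shape.size() == 5)
        return "NCDHW"; // Conv3D, enabled by this commit
    // Other ranks are rejected, mirroring the NNFUSION_CHECK_FAIL branch above.
    throw std::runtime_error("Convolution with dimensions of " +
                             std::to_string(data_shape.size()) +
                             " not supported");
}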
