diff --git a/ganrectf/models.py b/ganrectf/models.py index eed0a92..4127b3d 100644 --- a/ganrectf/models.py +++ b/ganrectf/models.py @@ -156,7 +156,7 @@ def make_generator_3d(shape_input, conv_num, conv_size, dropout, output_num): for fc in fc_stack: x = fc(x) - x = tf.reshape(x, shape=[-1, img_w, img_w, 1]) + x = tf.reshape(x, shape=[-1, shape_input[1], shape_input[1], 1]) # Convolutions for conv in conv_stack: x = conv(x)