diff --git a/unet/blocks.py b/unet/blocks.py index ec7aaa2..77387bb 100644 --- a/unet/blocks.py +++ b/unet/blocks.py @@ -123,9 +123,11 @@ def __init__(self, in_channels, middle_channels, out_channels, softmax=False): nn.BatchNorm2d(middle_channels), nn.ReLU(inplace=True), nn.Conv2d(middle_channels, out_channels, kernel_size=1), - nn.Softmax(dim=1) ] + if softmax: + layers.append(nn.Softmax(dim=1)) + self.first = nn.Sequential(*layers) def forward(self, x): @@ -249,4 +251,4 @@ def __init__(self, in_channels, middle_channels, out_channels, softmax=False): self.first = nn.Sequential(*layers) def forward(self, x): - return self.first(x) \ No newline at end of file + return self.first(x) diff --git a/unet/unet.py b/unet/unet.py index 5f55d7f..b7bcc0a 100755 --- a/unet/unet.py +++ b/unet/unet.py @@ -21,7 +21,7 @@ def __init__(self, in_channels, out_channels, conv_depths=(64, 128, 256, 512, 10 decoder_layers = [] decoder_layers.extend([Decoder2D(2 * conv_depths[i + 1], 2 * conv_depths[i], 2 * conv_depths[i], conv_depths[i]) for i in reversed(range(len(conv_depths)-2))]) - decoder_layers.append(Last2D(conv_depths[1], conv_depths[0], out_channels)) + decoder_layers.append(Last2D(2*conv_depths[0], conv_depths[0], out_channels)) # encoder, center and decoder layers self.encoder_layers = nn.Sequential(*encoder_layers)