From aa6112a77a79454ec85e35b310f3e5432108ca0d Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Sat, 3 Aug 2024 09:23:45 +0200 Subject: [PATCH 1/2] fix: activation for last layer of UNet2D --- unet/blocks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unet/blocks.py b/unet/blocks.py index ec7aaa2..77387bb 100644 --- a/unet/blocks.py +++ b/unet/blocks.py @@ -123,9 +123,11 @@ def __init__(self, in_channels, middle_channels, out_channels, softmax=False): nn.BatchNorm2d(middle_channels), nn.ReLU(inplace=True), nn.Conv2d(middle_channels, out_channels, kernel_size=1), - nn.Softmax(dim=1) ] + if softmax: + layers.append(nn.Softmax(dim=1)) + self.first = nn.Sequential(*layers) def forward(self, x): @@ -249,4 +251,4 @@ def __init__(self, in_channels, middle_channels, out_channels, softmax=False): self.first = nn.Sequential(*layers) def forward(self, x): - return self.first(x) \ No newline at end of file + return self.first(x) From 5036b76ac18bf40623d70186853a9134026f90bb Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Sat, 3 Aug 2024 12:10:39 +0200 Subject: [PATCH 2/2] fix: input channels for Last2D conv block --- unet/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unet/unet.py b/unet/unet.py index 5f55d7f..b7bcc0a 100755 --- a/unet/unet.py +++ b/unet/unet.py @@ -21,7 +21,7 @@ def __init__(self, in_channels, out_channels, conv_depths=(64, 128, 256, 512, 10 decoder_layers = [] decoder_layers.extend([Decoder2D(2 * conv_depths[i + 1], 2 * conv_depths[i], 2 * conv_depths[i], conv_depths[i]) for i in reversed(range(len(conv_depths)-2))]) - decoder_layers.append(Last2D(conv_depths[1], conv_depths[0], out_channels)) + decoder_layers.append(Last2D(2*conv_depths[0], conv_depths[0], out_channels)) # encoder, center and decoder layers self.encoder_layers = nn.Sequential(*encoder_layers)