From 95608168e6d5bfa9b4b1926975381f5e3353108f Mon Sep 17 00:00:00 2001
From: NimaTorbati
Date: Tue, 4 Nov 2025 19:42:28 +0100
Subject: [PATCH] test

---
 ACS/model.py | 391 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 391 insertions(+)
 create mode 100644 ACS/model.py

diff --git a/ACS/model.py b/ACS/model.py
new file mode 100644
index 0000000..20540b1
--- /dev/null
+++ b/ACS/model.py
@@ -0,0 +1,391 @@
from transformers import SegformerModel, SegformerConfig
import segmentation_models_pytorch as smp
from ACS.unet_decoder.decoder import UnetDecoder
from segmentation_models_pytorch.decoders.unetplusplus.decoder import UnetPlusPlusDecoder

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


class TimmEncoderFixed(nn.Module):
    """Wraps a timm features-only backbone in the interface smp decoders expect."""

    def __init__(
        self,
        name,
        pretrained=True,
        in_channels=3,
        depth=5,
        output_stride=32,
        drop_rate=0.5,
        drop_path_rate=0.0,
    ):
        super().__init__()
        self.model = timm.create_model(
            name,
            in_chans=in_channels,
            features_only=True,
            pretrained=pretrained,
            out_indices=tuple(range(depth)),
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
        )

        self._in_channels = in_channels
        self._out_channels = [in_channels] + self.model.feature_info.channels()
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x):
        # Unlike smp encoders, this wrapper does not prepend the raw input `x`
        # to the feature list; DualEncoderUNet compensates when it slices
        # out_channels and filters features by resolution.
        return self.model(x)

    @property
    def out_channels(self):
        return self._out_channels

    @property
    def output_stride(self):
        return min(self._output_stride, 2 ** self._depth)


def get_timm_encoder(
    name,
    in_channels=3,
    depth=5,
    weights=False,
    output_stride=32,
    drop_rate=0.5,
    drop_path_rate=0.25,
):
    # `weights` is forwarded as timm's `pretrained` flag.
    return TimmEncoderFixed(
        name, weights, in_channels, depth, output_stride, drop_rate, drop_path_rate
    )


class FusionBlock0(nn.Module):
    """Simple fusion: out = concat([U, S]) along the channel dimension."""

    def forward(self, U, S):
        # Resize S if needed to match U's spatial size
        if S.shape[2:] != U.shape[2:]:
            S = F.interpolate(S, size=U.shape[2:], mode="bilinear", align_corners=False)
        return torch.cat([U, S], dim=1)


class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_pool = F.adaptive_avg_pool2d(x, 1).view(b, c)
        max_pool = F.adaptive_max_pool2d(x, 1).view(b, c)

        avg_out = self.mlp(avg_pool)
        max_out = self.mlp(max_pool)

        scale = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * scale


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        scale = self.sigmoid(self.conv(x_cat))
        return x * scale

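# A quick shape sanity check for the two attention blocks above (a sketch, not
# part of the model; the 64 channels and 32x32 maps are arbitrary choices):
#
#   ca = ChannelAttention(in_channels=64)
#   sa = SpatialAttention(kernel_size=7)
#   t = torch.randn(2, 64, 32, 32)
#   assert sa(ca(t)).shape == t.shape  # attention rescales activations, never reshapes
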
class FusionBlock1(nn.Module):
    """CBAM-style fusion: out = CBAM(concat([U, S]))."""

    def __init__(self, channels, reduction=16, spatial_kernel=7):
        super().__init__()
        self.channel_att = ChannelAttention(channels, reduction)
        self.spatial_att = SpatialAttention(spatial_kernel)

    def forward(self, U, S=None):
        x = U if S is None else torch.cat([U, S], dim=1)
        att = self.channel_att(x)
        att = self.spatial_att(att)
        return att


class DualEncoderUNet(nn.Module):
    def __init__(
        self,
        unet_encoder_name="resnet34",
        unet_encoder_weights=None,
        segformer_variant="nvidia/segformer-b2-finetuned-ade-512-512",
        classes=1,
        decoder_channels=(256, 128, 64, 32, 16),
        simple_fusion=0,
        regression=False,
        in_channels=3,
        freeze_segformer=False,
        freeze_unet=False,
        input_size=1024,
        decoder_type="unet",
        IgnoreBottleNeck=False,
        cof_seg=1,
        cof_unet=1,
        model_depth=5,
        instance_segmentation=False,
    ):
        super().__init__()
        self.classes = classes
        self.cof_seg = cof_seg
        self.cof_unet = cof_unet
        self.freeze_segformer = freeze_segformer
        self.freeze_unet = freeze_unet
        self.model_depth = model_depth
        self.instance_segmentation = instance_segmentation

        ## U-Net encoder
        if "convnext" in unet_encoder_name.lower():
            self.unet_encoder = get_timm_encoder(
                name="convnextv2_tiny.fcmae_ft_in22k_in1k",
                in_channels=in_channels,
                depth=model_depth,
                weights=True,
                output_stride=32,
                drop_rate=0.5,
                drop_path_rate=0.25,
            )
        else:
            self.unet_encoder = smp.encoders.get_encoder(
                unet_encoder_name,
                in_channels=in_channels,
                depth=model_depth,
                weights=unet_encoder_weights,
            )
        # Drop the raw-input entry so indices line up with the SegFormer stages,
        # e.g. [64, 64, 128, 256, 512] for resnet34 at depth 5.
        u_out_channels = self.unet_encoder.out_channels[1:]
        self.IgnoreBottleNeck = IgnoreBottleNeck

        ## SegFormer encoder
        seg_cfg = SegformerConfig.from_pretrained(segformer_variant)
        seg_cfg.output_hidden_states = True
        self.segformer = SegformerModel.from_pretrained(segformer_variant, config=seg_cfg)
        self.register_buffer("segformer_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("segformer_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # Optionally freeze either encoder
        if self.freeze_segformer:
            for p in self.segformer.parameters():
                p.requires_grad = False
        if self.freeze_unet:
            for p in self.unet_encoder.parameters():
                p.requires_grad = False

        s_expected = list(seg_cfg.hidden_sizes[-model_depth:])

        # One fusion block per skip level (0 .. model_depth-1)
        self.fusions = nn.ModuleList()
        for i in range(model_depth):
            u_ch = u_out_channels[i]
            s_ch = s_expected[i]
            if simple_fusion == 0:
                self.fusions.append(FusionBlock0())
            elif simple_fusion == 1:
                self.fusions.append(FusionBlock1(channels=u_ch + s_ch))

        # Both fusion variants concatenate along channels, so each decoder
        # skip carries u_ch + s_ch channels.
        encoder_channels_for_decoder = [
            s + u for s, u in zip(s_expected, u_out_channels)
        ]
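        # Worked example of the channel bookkeeping (illustrative numbers only,
        # assuming resnet34 + SegFormer-B2 at model_depth=4):
        #   u_out_channels = [64, 64, 128, 256]
        #   s_expected     = [64, 128, 320, 512]
        #   fused skips    = [128, 192, 448, 768]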
        # ----- choose decoder type -----
        if decoder_type.lower() == "unet":
            DecoderClass = UnetDecoder
        elif decoder_type.lower() in ("unet++", "unetplusplus"):
            DecoderClass = UnetPlusPlusDecoder
        else:
            raise ValueError(f"Unknown decoder_type '{decoder_type}'. Use 'unet' or 'unet++'.")

        n_blocks = model_depth

        self.dropout = nn.Dropout2d(p=0.3)

        decoder_kwargs = dict(
            encoder_channels=[in_channels] + encoder_channels_for_decoder,
            decoder_channels=decoder_channels[:model_depth],
            n_blocks=n_blocks,
            use_batchnorm=False,
        )
        if DecoderClass is UnetDecoder:
            # IgnoreBottleNeck is specific to the custom ACS UnetDecoder;
            # smp's UnetPlusPlusDecoder does not accept it.
            decoder_kwargs["IgnoreBottleNeck"] = self.IgnoreBottleNeck

        # Binary/semantic branch (always present)
        self.decoder2 = DecoderClass(**decoder_kwargs)
        self.segmentation_head2 = smp.base.SegmentationHead(
            in_channels=decoder_channels[:model_depth][-1],
            out_channels=classes,
            activation=None,  # leave logits for the loss function
            kernel_size=1,
        )

        if self.instance_segmentation:
            # Separate decoder + regression head for the HV / instance maps
            self.decoder1 = DecoderClass(**decoder_kwargs)
            self.segmentation_head1 = smp.base.SegmentationHead(
                in_channels=decoder_channels[:model_depth][-1],
                out_channels=5,  # instance channels
                activation=None,
                kernel_size=1,
            )

    def _filter_and_sort_unet_feats(self, u_feats, input_h):
        # smp encoders prepend the raw input to their feature list while the
        # timm wrapper does not; dropping maps at full input resolution
        # handles both cases uniformly.
        filtered = [f for f in u_feats if f.shape[2] < input_h]
        return sorted(filtered, key=lambda t: t.shape[2], reverse=True)

    def _sort_segf_feats(self, s_feats):
        return sorted(s_feats, key=lambda t: t.shape[2], reverse=True)

    def forward(self, x, debug_print_shapes=False):
        cof_seg = self.cof_seg
        cof_unet = self.cof_unet

        B, C, H_in, W_in = x.shape

        ### SegFormer forward pass
        # Normalize the first 3 channels for SegFormer; extra input channels
        # (e.g. depth) are only seen by the U-Net encoder.
        x_for_segformer = cof_seg * x[:, :3, :, :]
        x_for_segformer = (x_for_segformer - self.segformer_mean) / self.segformer_std
        s_all = self.segformer(pixel_values=x_for_segformer).hidden_states
        s_feats = self._sort_segf_feats(s_all[-self.model_depth:])

        ### U-Net encoder forward pass
        u_feats_all = self.unet_encoder(cof_unet * x)
        u_feats = self._filter_and_sort_unet_feats(u_feats_all, H_in)

        # The decoder expects the raw input first, then one fused skip per level
        skips = [x]
        for i in range(self.model_depth):
            U = u_feats[i]
            S = s_feats[i]
            if S.shape[2:] != U.shape[2:]:
                S = F.interpolate(S, size=U.shape[2:], mode="bilinear", align_corners=False)
            if debug_print_shapes:
                print(f"Fusing skip {i} shapes: U{tuple(U.shape)} S(resized) {tuple(S.shape)}")
            U = self.dropout(U)
            S = self.dropout(S)
            skips.append(self.fusions[i](U, S))

        if debug_print_shapes:
            print("== Skip tensors provided to decoder ==")
            for i, s in enumerate(skips, start=1):
                print(f"skip {i}: shape {tuple(s.shape)}")

        dec_out2 = self.decoder2(skips)
        out_binary = self.segmentation_head2(dec_out2)

        if self.instance_segmentation:
            dec_out1 = self.decoder1(skips)
            out_hv = self.segmentation_head1(dec_out1)
            return torch.cat([out_hv, out_binary], dim=1)
        return out_binary

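# How the two heads' outputs recombine (a sketch; assumes
# instance_segmentation=True and classes=1, so `out` stacks the 5 HV/instance
# channels ahead of the binary logit):
#
#   out = model(x)                              # (B, 5 + classes, H, W)
#   out_hv, out_binary = out[:, :5], out[:, 5:]
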
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DualEncoderUNet(
        unet_encoder_weights="imagenet",
        unet_encoder_name="convnext",
        segformer_variant="nvidia/segformer-b2-finetuned-ade-512-512",
        model_depth=4,
        cof_seg=0,
        simple_fusion=1,
        IgnoreBottleNeck=True,
    ).to(device)

    print(model)  # print network structure

    x = torch.randn(1, 3, 1024, 1024).to(device)

    with torch.no_grad():
        out = model(x, debug_print_shapes=True)
    print(f"output shape: {tuple(out.shape)}")
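
# Example training step against the binary head (a sketch; the DiceLoss choice
# and the `images`/`masks` tensors are placeholders, not part of this patch):
#
#   criterion = smp.losses.DiceLoss(mode="binary", from_logits=True)
#   logits = model(images)    # (B, classes, H, W) when instance_segmentation=False
#   loss = criterion(logits, masks)
#   loss.backward()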