|
| 1 | +"""Feature Pyramid Network (FPN) on top of ResNets and task-specific heads on top of it. |
| 2 | +
|
| 3 | +See: |
| 4 | +- https://arxiv.org/abs/1612.03144 - Feature Pyramid Networks for Object Detection |
| 5 | +- http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf - A Unified Architecture for Instance |
| 6 | + and Semantic Segmentation |
| 7 | +
|
| 8 | +""" |
| 9 | + |
| 10 | +import torch |
| 11 | +import torch.nn as nn |
| 12 | + |
| 13 | +from torchvision.models import resnet50 |
| 14 | + |
| 15 | + |
| 16 | +class FPN(nn.Module): |
| 17 | + """Feature Pyramid Network (FPN): top-down architecture with lateral connections. |
| 18 | + Can be used as feature extractor for object detection or segmentation. |
| 19 | + """ |
| 20 | + |
| 21 | + def __init__(self, num_filters=256, pretrained=True): |
| 22 | + """Creates an `FPN` instance for feature extraction. |
| 23 | +
|
| 24 | + Args: |
| 25 | + num_filters: the number of filters in each output pyramid level |
| 26 | + pretrained: use ImageNet pre-trained backbone feature extractor |
| 27 | + """ |
| 28 | + |
| 29 | + super().__init__() |
| 30 | + |
| 31 | + self.resnet = resnet50(pretrained=pretrained) |
| 32 | + |
| 33 | + self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool) |
| 34 | + self.enc1 = self.resnet.layer1 # 256 |
| 35 | + self.enc2 = self.resnet.layer2 # 512 |
| 36 | + self.enc3 = self.resnet.layer3 # 1024 |
| 37 | + self.enc4 = self.resnet.layer4 # 2048 |
| 38 | + |
| 39 | + self.lateral4 = nn.Conv2d(2048, num_filters, kernel_size=1, bias=False) |
| 40 | + self.lateral3 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False) |
| 41 | + self.lateral2 = nn.Conv2d(512, num_filters, kernel_size=1, bias=False) |
| 42 | + self.lateral1 = nn.Conv2d(256, num_filters, kernel_size=1, bias=False) |
| 43 | + |
| 44 | + self.smooth4 = Conv3x3(num_filters, num_filters) |
| 45 | + self.smooth3 = Conv3x3(num_filters, num_filters) |
| 46 | + self.smooth2 = Conv3x3(num_filters, num_filters) |
| 47 | + self.smooth1 = Conv3x3(num_filters, num_filters) |
| 48 | + |
| 49 | + def forward(self, x): |
| 50 | + # Bottom-up pathway, from ResNet |
| 51 | + |
| 52 | + enc0 = self.enc0(x) |
| 53 | + enc1 = self.enc1(enc0) # 256 |
| 54 | + enc2 = self.enc2(enc1) # 512 |
| 55 | + enc3 = self.enc3(enc2) # 1024 |
| 56 | + enc4 = self.enc4(enc3) # 2048 |
| 57 | + |
| 58 | + # Lateral connections |
| 59 | + |
| 60 | + lateral4 = self.lateral4(enc4) |
| 61 | + lateral3 = self.lateral3(enc3) |
| 62 | + lateral2 = self.lateral2(enc2) |
| 63 | + lateral1 = self.lateral1(enc1) |
| 64 | + |
| 65 | + # Top-down pathway |
| 66 | + |
| 67 | + map4 = lateral4 |
| 68 | + map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest") |
| 69 | + map2 = lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest") |
| 70 | + map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest") |
| 71 | + |
| 72 | + # Reduce aliasing effect of upsampling |
| 73 | + |
| 74 | + map4 = self.smooth4(map4) |
| 75 | + map3 = self.smooth3(map3) |
| 76 | + map2 = self.smooth2(map2) |
| 77 | + map1 = self.smooth1(map1) |
| 78 | + |
| 79 | + return map1, map2, map3, map4 |
| 80 | + |
| 81 | + |
| 82 | +class FPNSegmentation(nn.Module): |
| 83 | + """Semantic segmentation model on top of a Feature Pyramid Network (FPN). |
| 84 | + """ |
| 85 | + |
| 86 | + def __init__(self, num_classes, num_filters=128, num_filters_fpn=256, pretrained=True): |
| 87 | + """Creates an `FPNSegmentation` instance for feature extraction. |
| 88 | +
|
| 89 | + Args: |
| 90 | + num_classes: number of classes to predict |
| 91 | + num_filters: the number of filters in each segmentation head pyramid level |
| 92 | + num_filters_fpn: the number of filters in each FPN output pyramid level |
| 93 | + pretrained: use ImageNet pre-trained backbone feature extractor |
| 94 | + """ |
| 95 | + |
| 96 | + super().__init__() |
| 97 | + |
| 98 | + # Feature Pyramid Network (FPN) with four feature maps of resolutions |
| 99 | + # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps. |
| 100 | + |
| 101 | + self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained) |
| 102 | + |
| 103 | + # The segmentation heads on top of the FPN |
| 104 | + |
| 105 | + self.head1 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters)) |
| 106 | + self.head2 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters)) |
| 107 | + self.head3 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters)) |
| 108 | + self.head4 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters)) |
| 109 | + |
| 110 | + self.final = nn.Conv2d(4 * num_filters, num_classes, kernel_size=3, padding=1) |
| 111 | + |
| 112 | + def forward(self, x): |
| 113 | + map1, map2, map3, map4 = self.fpn(x) |
| 114 | + |
| 115 | + map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest") |
| 116 | + map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest") |
| 117 | + map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest") |
| 118 | + map1 = self.head1(map1) |
| 119 | + |
| 120 | + final = self.final(torch.cat([map4, map3, map2, map1], dim=1)) |
| 121 | + |
| 122 | + return nn.functional.upsample(final, scale_factor=4, mode="bilinear", align_corners=False) |
| 123 | + |
| 124 | + |
| 125 | +class Conv1x1(nn.Module): |
| 126 | + def __init__(self, num_in, num_out): |
| 127 | + super().__init__() |
| 128 | + self.block = nn.Conv2d(num_in, num_out, kernel_size=1, bias=False) |
| 129 | + |
| 130 | + def forward(self, x): |
| 131 | + return self.block(x) |
| 132 | + |
| 133 | + |
| 134 | +class Conv3x3(nn.Module): |
| 135 | + def __init__(self, num_in, num_out): |
| 136 | + super().__init__() |
| 137 | + self.block = nn.Conv2d(num_in, num_out, kernel_size=3, padding=1, bias=False) |
| 138 | + |
| 139 | + def forward(self, x): |
| 140 | + return self.block(x) |
0 commit comments