Skip to content

Commit 784d63b

Browse files
committed
Implements Feature Pyramid Network (FPN), closes #60
1 parent 6a671a0 commit 784d63b

File tree

6 files changed

+150
-15
lines changed

6 files changed

+150
-15
lines changed

robosat/fpn.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""Feature Pyramid Network (FPN) on top of ResNets and task-specific heads on top of it.
2+
3+
See:
4+
- https://arxiv.org/abs/1612.03144 - Feature Pyramid Networks for Object Detection
5+
- http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf - A Unified Architecture for Instance
6+
and Semantic Segmentation
7+
8+
"""
9+
10+
import torch
11+
import torch.nn as nn
12+
13+
from torchvision.models import resnet50
14+
15+
16+
class FPN(nn.Module):
    """Feature Pyramid Network (FPN): top-down architecture with lateral connections.

    Can be used as feature extractor for object detection or segmentation.
    """

    def __init__(self, num_filters=256, pretrained=True):
        """Creates an `FPN` instance for feature extraction.

        Args:
          num_filters: the number of filters in each output pyramid level
          pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        # NOTE: the complete ResNet stays registered as a submodule, so the
        # encoder weights are reachable under both `resnet.*` and `enc*.*`.
        # Existing checkpoints depend on these attribute names.
        self.resnet = resnet50(pretrained=pretrained)

        # Bottom-up pathway: the ResNet stem followed by its four residual stages.
        self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool)
        self.enc1 = self.resnet.layer1  # 256
        self.enc2 = self.resnet.layer2  # 512
        self.enc3 = self.resnet.layer3  # 1024
        self.enc4 = self.resnet.layer4  # 2048

        # 1x1 lateral connections project every stage to `num_filters` channels.
        self.lateral4 = nn.Conv2d(2048, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(512, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(256, num_filters, kernel_size=1, bias=False)

        # 3x3 convolutions applied to the merged maps.
        self.smooth4 = Conv3x3(num_filters, num_filters)
        self.smooth3 = Conv3x3(num_filters, num_filters)
        self.smooth2 = Conv3x3(num_filters, num_filters)
        self.smooth1 = Conv3x3(num_filters, num_filters)

    @staticmethod
    def _upsample_like(x, reference):
        """Nearest-neighbor resizes `x` to the spatial size of `reference`.

        Resizing to the target size instead of by a fixed factor of two keeps
        the element-wise addition in the top-down pathway valid even when a
        backbone stage did not exactly halve the resolution. For exact factor
        two relations the result is the same as `scale_factor=2`.
        """

        return nn.functional.interpolate(x, size=reference.shape[-2:], mode="nearest")

    def forward(self, x):
        """Runs the backbone and builds the feature pyramid.

        Args:
          x: the input batch fed into the ResNet stem
             (presumably of shape N x 3 x H x W - confirm against callers)

        Returns:
          The tuple `(map1, map2, map3, map4)` of pyramid levels ordered from
          the finest to the coarsest level, each with `num_filters` channels.
        """

        # Bottom-up pathway, from ResNet

        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0)  # 256
        enc2 = self.enc2(enc1)  # 512
        enc3 = self.enc3(enc2)  # 1024
        enc4 = self.enc4(enc3)  # 2048

        # Lateral connections

        lateral4 = self.lateral4(enc4)
        lateral3 = self.lateral3(enc3)
        lateral2 = self.lateral2(enc2)
        lateral1 = self.lateral1(enc1)

        # Top-down pathway

        map4 = lateral4
        map3 = lateral3 + self._upsample_like(map4, lateral3)
        map2 = lateral2 + self._upsample_like(map3, lateral2)
        map1 = lateral1 + self._upsample_like(map2, lateral1)

        # Reduce aliasing effect of upsampling

        map4 = self.smooth4(map4)
        map3 = self.smooth3(map3)
        map2 = self.smooth2(map2)
        map1 = self.smooth1(map1)

        return map1, map2, map3, map4
80+
81+
82+
class FPNSegmentation(nn.Module):
    """Semantic segmentation model on top of a Feature Pyramid Network (FPN)."""

    def __init__(self, num_classes, num_filters=128, num_filters_fpn=256, pretrained=True):
        """Creates an `FPNSegmentation` instance for semantic segmentation.

        Args:
          num_classes: number of classes to predict
          num_filters: the number of filters in each segmentation head pyramid level
          num_filters_fpn: the number of filters in each FPN output pyramid level
          pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters_fpn` filters for all feature maps.

        self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)

        # The segmentation heads on top of the FPN

        self.head1 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head2 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head3 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head4 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))

        # Classifier on the channel-wise concatenation of the four head outputs.
        self.final = nn.Conv2d(4 * num_filters, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """Predicts per-pixel class scores for the input batch.

        Args:
          x: the input batch passed on to the FPN

        Returns:
          The class score map with `num_classes` channels, resized to the
          spatial size of `x`.
        """

        map1, map2, map3, map4 = self.fpn(x)

        # Bring the coarser levels to the size of the finest level. Using the
        # target size rather than fixed factors 8, 4, 2 gives the same result
        # for exact ratios and keeps the concatenation below valid otherwise.
        finest = map1.shape[-2:]

        map4 = nn.functional.interpolate(self.head4(map4), size=finest, mode="nearest")
        map3 = nn.functional.interpolate(self.head3(map3), size=finest, mode="nearest")
        map2 = nn.functional.interpolate(self.head2(map2), size=finest, mode="nearest")
        map1 = self.head1(map1)

        final = self.final(torch.cat([map4, map3, map2, map1], dim=1))

        # Resize the scores back to the input resolution so predictions line
        # up with the input pixel by pixel.
        return nn.functional.interpolate(final, size=x.shape[-2:], mode="bilinear", align_corners=False)
123+
124+
125+
class Conv1x1(nn.Module):
    """Pointwise (1x1) two-dimensional convolution without a bias term."""

    def __init__(self, num_in, num_out):
        """Creates a `Conv1x1` instance.

        Args:
          num_in: the number of input channels
          num_out: the number of output channels
        """

        super().__init__()

        # The attribute name `block` is part of the state dict keys.
        self.block = nn.Conv2d(in_channels=num_in, out_channels=num_out, kernel_size=(1, 1), bias=False)

    def forward(self, x):
        """Applies the convolution to `x` and returns the result."""

        out = self.block(x)
        return out
132+
133+
134+
class Conv3x3(nn.Module):
    """3x3 two-dimensional convolution with padding of one and without a bias term."""

    def __init__(self, num_in, num_out):
        """Creates a `Conv3x3` instance.

        Args:
          num_in: the number of input channels
          num_out: the number of output channels
        """

        super().__init__()

        # The attribute name `block` is part of the state dict keys.
        self.block = nn.Conv2d(
            in_channels=num_in, out_channels=num_out, kernel_size=(3, 3), padding=(1, 1), bias=False
        )

    def forward(self, x):
        """Applies the convolution to `x` and returns the result."""

        out = self.block(x)
        return out

robosat/tools/export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.autograd
66

77
from robosat.config import load_config
8-
from robosat.unet import UNet
8+
from robosat.fpn import FPNSegmentation
99

1010

1111
def add_parser(subparser):
@@ -25,7 +25,7 @@ def main(args):
2525
dataset = load_config(args.dataset)
2626

2727
num_classes = len(dataset["common"]["classes"])
28-
net = UNet(num_classes)
28+
net = FPNSegmentation(num_classes)
2929

3030
chkpt = torch.load(args.checkpoint, map_location="cpu")
3131
net.load_state_dict(chkpt)

robosat/tools/predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from PIL import Image
1515

1616
from robosat.datasets import BufferedSlippyMapDirectory
17-
from robosat.unet import UNet
17+
from robosat.fpn import FPNSegmentation
1818
from robosat.config import load_config
1919
from robosat.colors import continuous_palette_for_color
2020
from robosat.transforms import ConvertImageMode, ImageToTensor
@@ -59,7 +59,7 @@ def map_location(storage, _):
5959
# https://github.com/pytorch/pytorch/issues/7178
6060
chkpt = torch.load(args.checkpoint, map_location=map_location)
6161

62-
net = UNet(num_classes).to(device)
62+
net = FPNSegmentation(num_classes).to(device)
6363
net = nn.DataParallel(net)
6464

6565
if cuda:

robosat/tools/serve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from flask import Flask, send_file, render_template, abort
1717

1818
from robosat.tiles import fetch_image
19-
from robosat.unet import UNet
19+
from robosat.fpn import FPNSegmentation
2020
from robosat.config import load_config
2121
from robosat.colors import make_palette
2222
from robosat.transforms import ConvertImageMode, ImageToTensor
@@ -180,7 +180,7 @@ def map_location(storage, _):
180180

181181
num_classes = len(self.dataset["common"]["classes"])
182182

183-
net = UNet(num_classes).to(self.device)
183+
net = FPNSegmentation(num_classes).to(self.device)
184184
net = nn.DataParallel(net)
185185

186186
if self.cuda:

robosat/tools/train.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from robosat.datasets import SlippyMapTilesConcatenation
2727
from robosat.metrics import MeanIoU
2828
from robosat.losses import CrossEntropyLoss2d
29-
from robosat.unet import UNet
29+
from robosat.fpn import FPNSegmentation
3030
from robosat.utils import plot
3131
from robosat.config import load_config
3232

@@ -51,24 +51,18 @@ def main(args):
5151
if model["common"]["cuda"] and not torch.cuda.is_available():
5252
sys.exit("Error: CUDA requested but not available")
5353

54-
# if args.batch_size < 2:
55-
# sys.exit('Error: PSPNet requires more than one image for BatchNorm in Pyramid Pooling')
56-
5754
os.makedirs(model["common"]["checkpoint"], exist_ok=True)
5855

5956
num_classes = len(dataset["common"]["classes"])
60-
net = UNet(num_classes).to(device)
57+
net = FPNSegmentation(num_classes).to(device)
6158

6259
if model["common"]["cuda"]:
6360
torch.backends.cudnn.benchmark = True
6461
net = DataParallel(net)
6562

6663
optimizer = Adam(net.parameters(), lr=model["opt"]["lr"], weight_decay=model["opt"]["decay"])
67-
68-
weight = torch.Tensor(dataset["weights"]["values"])
69-
7064
criterion = CrossEntropyLoss2d(weight=weight).to(device)
71-
# criterion = FocalLoss2d(weight=weight).to(device)
65+
weight = torch.Tensor(dataset["weights"]["values"])
7266

7367
train_loader, val_loader = get_dataset_loaders(model, dataset)
7468

robosat/unet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
8484
8585
Args:
8686
num_classes: number of classes to predict.
87+
num_filters: the number of filters for the decoder block
8788
pretrained: use ImageNet pre-trained backbone feature extractor
8889
"""
8990

0 commit comments

Comments
 (0)