import torch
from torch import nn
import timm
from hybridnets.model import BiFPN, Regressor, Classifier, BiFPNDecoder
from utils.utils import Anchors
from hybridnets.model import SegmentationHead
from encoders import get_encoder
class HybridNetsBackbone(nn.Module):
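    """HybridNets backbone: an EfficientNet (or any timm) encoder feeds a stacked
    BiFPN; the fused features drive anchor-based detection heads (Regressor,
    Classifier) and a BiFPNDecoder + SegmentationHead for segmentation.
    """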
def __init__(self, num_classes=80, compound_coef=0, seg_classes=1, backbone_name=None, **kwargs):
super(HybridNetsBackbone, self).__init__()
self.compound_coef = compound_coef
self.seg_classes = seg_classes
self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7]
self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384]
self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8]
self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5]
self.pyramid_levels = [5, 5, 5, 5, 5, 5, 5, 5, 6]
        self.anchor_scale = [1.25] * 9
self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
conv_channel_coef = {
# the channels of P3/P4/P5.
0: [40, 112, 320],
1: [40, 112, 320],
2: [48, 120, 352],
3: [48, 136, 384],
4: [56, 160, 448],
5: [64, 176, 512],
6: [72, 200, 576],
7: [72, 200, 576],
8: [80, 224, 640],
}
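        # Anchors per spatial location = aspect ratios x scales (3 x 3 = 9 by default).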
num_anchors = len(self.aspect_ratios) * self.num_scales
        self.bifpn = nn.Sequential(
            *[BiFPN(self.fpn_num_filters[self.compound_coef],
                    conv_channel_coef[compound_coef],
                    _ == 0,  # only the first cell adapts the raw backbone channels
                    attention=compound_coef < 6,
                    use_p8=compound_coef > 7)
              for _ in range(self.fpn_cell_repeats[compound_coef])])
self.num_classes = num_classes
self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
num_layers=self.box_class_repeats[self.compound_coef],
pyramid_levels=self.pyramid_levels[self.compound_coef])
        # Modified by Dat Vu: replaced the original DecoderModule with a BiFPN-based decoder.
self.bifpndecoder = BiFPNDecoder(pyramid_channels=self.fpn_num_filters[self.compound_coef])
        self.segmentation_head = SegmentationHead(
            in_channels=64,
            # Multi-class segmentation gets an extra background channel and a softmax;
            # single-class segmentation keeps one channel with a sigmoid.
            out_channels=self.seg_classes + 1 if self.seg_classes > 1 else self.seg_classes,
            activation='softmax2d' if self.seg_classes > 1 else 'sigmoid',
            kernel_size=1,
            upsampling=4,
        )
self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
num_classes=num_classes,
num_layers=self.box_class_repeats[self.compound_coef],
pyramid_levels=self.pyramid_levels[self.compound_coef])
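        # The anchor pyramid starts at P3, so 5 levels span P3-P7 (strides 8-128).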
self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(),
**kwargs)
if backbone_name:
            # Use timm to swap in any backbone you prefer:
            # https://github.com/rwightman/pytorch-image-models
            self.encoder = timm.create_model(backbone_name, pretrained=True, features_only=True,
                                             out_indices=(1, 2, 3, 4))  # P2-P5; forward() unpacks four feature maps
else:
# EfficientNet_Pytorch
self.encoder = get_encoder(
'efficientnet-b' + str(self.backbone_compound_coef[compound_coef]),
in_channels=3,
depth=5,
weights='imagenet',
)
self.initialize_decoder(self.bifpndecoder)
self.initialize_head(self.segmentation_head)
self.initialize_decoder(self.bifpn)
def freeze_bn(self):
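        """Put every BatchNorm2d into eval mode so running statistics are not updated."""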
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def forward(self, inputs):
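        """Single forward pass producing detection and segmentation outputs.

        Returns (BiFPN features P3-P7, box regression, class scores, anchors,
        segmentation map).
        """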
        max_size = inputs.shape[-1]  # input width; currently unused in this forward pass
        p2, p3, p4, p5 = self.encoder(inputs)[-4:]  # last four encoder stages (strides 4, 8, 16, 32)
features = (p3, p4, p5)
features = self.bifpn(features)
        p3, p4, p5, p6, p7 = features
        outputs = self.bifpndecoder((p2, p3, p4, p5, p6, p7))
segmentation = self.segmentation_head(outputs)
regression = self.regressor(features)
classification = self.classifier(features)
anchors = self.anchors(inputs, inputs.dtype)
return features, regression, classification, anchors, segmentation
def initialize_decoder(self, module):
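        """Kaiming-init convolutions, constant-init BatchNorm, Xavier-init Linear layers."""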
for m in module.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def initialize_head(self, module):
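        """Xavier-initialize the Conv2d/Linear layers of a prediction head."""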
for m in module.modules():
if isinstance(m, (nn.Linear, nn.Conv2d)):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
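
if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original file. Assumes the
    # HybridNets repo modules imported above are available and that the
    # EfficientNet-B0 'imagenet' weights can be downloaded. 384x640 is used
    # because both sides are divisible by the largest pyramid stride (128).
    model = HybridNetsBackbone(num_classes=80, compound_coef=0, seg_classes=1)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 384, 640)
        features, regression, classification, anchors, segmentation = model(dummy)
    print('segmentation:  ', tuple(segmentation.shape))    # expected (1, 1, 384, 640)
    print('classification:', tuple(classification.shape))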