# pipeline_server/detectors/retinanet/retinanet_fusion.py
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.autograd import Variable

from detectors.retinanet.fpn import FPN101, FPN50

class RetinaNet(nn.Module):
    """RetinaNet variant that fuses the two lowest FPN levels (P3 and P4) into a
    single feature map before the box-regression and classification heads."""

    def __init__(self, num_classes=20, num_anchors=5):
        super(RetinaNet, self).__init__()
        self.num_anchors = num_anchors
        self.fpn = FPN50()
        self.num_classes = num_classes
        self.p4_layer = self._fusion_head(128)   # reduce P4 from 256 to 128 channels
        self.p3_layer = self._fusion_head(128)   # reduce P3 from 256 to 128 channels
        self.loc_head = self._make_head(self.num_anchors * 4)
        self.cls_head = self._make_head(self.num_anchors * self.num_classes)
        self.up_sample_layer = self.up_sample()
        self.down_sample_layer = self.down_sample()

    def forward(self, x):
        fms = self.fpn(x)
        fms = fms[:2]  # keep only the two lowest pyramid levels (P3, P4)
        # fms = [fms]  # one layer
        p4 = self.p4_layer(fms[1])
        # p4 = nn.functional.interpolate(p4, scale_factor=2, mode='bicubic', align_corners=True)
        p3 = self.p3_layer(fms[0])
        p3 = self.down_sample_layer(p3)  # bring P3 down to P4's spatial size
        # print(p3.shape, p4.shape)
        fusion = torch.cat([p3, p4], 1)  # [N, 128, H, W] + [N, 128, H, W] -> [N, 256, H, W]
        loc_pred = self.loc_head(fusion).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4)
        # [N, num_anchors*4, H, W] -> [N, H, W, num_anchors*4] -> [N, H*W*num_anchors, 4]
        cls_pred = self.cls_head(fusion).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.num_classes)
        # [N, num_anchors*num_classes, H, W] -> [N, H, W, num_anchors*num_classes] -> [N, H*W*num_anchors, num_classes]
        return loc_pred, cls_pred

    def up_sample(self):
        # Defined for completeness but not used in forward().
        up = nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1)
        return up

    def down_sample(self):
        # Stride-2 conv that halves the spatial size of the P3 branch before fusion.
        layer = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        return layer

    def _make_head(self, out_planes):
        # Four 3x3 conv + ReLU blocks, then a final 3x3 conv projecting the
        # 256-channel fused map to out_planes prediction channels.
        layers = []
        for _ in range(4):
            layers.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))
        layers.append(nn.Conv2d(256, out_planes, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)

    def _fusion_head(self, out_planes):
        # Two 3x3 conv + ReLU blocks, then a 3x3 conv that maps a 256-channel
        # FPN level to out_planes channels (128 here) ahead of concatenation.
        layers = []
        for _ in range(2):
            layers.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))
        layers.append(nn.Conv2d(256, out_planes, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)

    def freeze_bn(self):
        '''Freeze BatchNorm layers.'''
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
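
# A minimal usage sketch for freeze_bn(), not part of the original file: calling
# it after .train() is an assumption about the intended fine-tuning workflow
# (keep pretrained BatchNorm statistics fixed while the rest of the net trains).
#   net = RetinaNet(num_classes=2, num_anchors=3)
#   net.train()      # training mode for everything else
#   net.freeze_bn()  # BatchNorm layers stay in eval mode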

def test():
    # img = cv2.imread(path_img, 1)  # path_img is undefined here; the smoke test uses a random tensor instead
    net = RetinaNet()
    loc_preds, cls_preds = net(Variable(torch.randn(2, 3, 224, 224)))
    print(loc_preds.size())
    print(cls_preds.size())
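
# Rough shape check for test() above (an assumption: FPN50's first two levels
# use the usual pyramid strides of 8 and 16). For a 224x224 input the fused map
# is 14x14, so with the defaults num_anchors=5 / num_classes=20:
#   loc_preds: [2, 14*14*5, 4]  == [2, 980, 4]
#   cls_preds: [2, 14*14*5, 20] == [2, 980, 20]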

def img_transform(img_in, transform):
    """
    Preprocess an image and convert it to the B*C*H*W layout the model expects.
    :param img_in: np.array image
    :param transform: torchvision transform pipeline applied to the PIL image
    :return: torch.Tensor of shape [1, C, H, W]
    """
    img = img_in.copy()
    img = Image.fromarray(np.uint8(img))
    img = transform(img)
    img = img.unsqueeze(0)  # C*H*W --> B*C*H*W
    return img
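
# A hedged usage example for img_transform; the Resize size and Normalize
# statistics below are illustrative assumptions, not values taken from this repo.
#   import torchvision.transforms as transforms
#   preprocess = transforms.Compose([
#       transforms.Resize((224, 224)),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                            std=[0.229, 0.224, 0.225]),
#   ])
#   batch = img_transform(image_array, preprocess)  # -> [1, 3, 224, 224]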

if __name__ == '__main__':
    import cv2
    import torchvision.transforms as transforms

    # Build a 2-class, 3-anchor variant and re-export the weights of a
    # previously pickled full model as a plain state_dict.
    net = RetinaNet(2, 3)
    model_dir = '/home/robert/Models/Retinanet_inference_example/checkpoint/drone_collection_KNN_15m/'
    print(net._modules.keys(), net.fpn._modules)
    # net.load_state_dict(torch.load(model_dir + 'final_model_dict.pkl'))
    net = torch.load(model_dir + 'final_model.pkl')  # load the full pickled model
    torch.save(net.state_dict(), model_dir + 'extra_dict.pkl')  # save weights only
    print(net.state_dict().keys())
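
    # Sketch of the reverse step (an assumption about how the exported weights
    # are meant to be reloaded; it mirrors the commented-out load_state_dict call above):
    #   fresh_net = RetinaNet(2, 3)
    #   fresh_net.load_state_dict(torch.load(model_dir + 'extra_dict.pkl'))
    #   fresh_net.eval()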