Spaces:
Running
on
T4
Running
on
T4
File size: 7,258 Bytes
d299fb7 d4b77ac d299fb7 6101b7d d4b77ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
from SiamMask.models.siammask_sharp import SiamMask
from SiamMask.models.features import MultiStageFeature
from SiamMask.models.rpn import RPN, DepthCorr
from SiamMask.models.mask import Mask
import torch
import torch.nn as nn
import torch.nn.functional as F
from SiamMask.utils.load_helper import load_pretrain
from SiamMask.experiments.siammask_sharp.resnet import resnet50
class ResDownS(nn.Module):
    """1x1 convolution + BatchNorm channel adjustment with a small-map crop.

    Args:
        inplane: number of input channels.
        outplane: number of output channels.
    """

    def __init__(self, inplane, outplane):
        super(ResDownS, self).__init__()
        projection = nn.Conv2d(inplane, outplane, kernel_size=1, bias=False)
        norm = nn.BatchNorm2d(outplane)
        # Kept as an nn.Sequential so parameter names stay downsample.0 / downsample.1.
        self.downsample = nn.Sequential(projection, norm)

    def forward(self, x):
        """Project ``x`` to ``outplane`` channels.

        If the projected map's last dimension is narrower than 20, a border of
        4 is trimmed from both of the last two dimensions; otherwise the map is
        returned uncropped.
        """
        out = self.downsample(x)
        if out.size(3) >= 20:
            return out
        trim = 4
        return out[:, :, trim:-trim, trim:-trim]
class ResDown(MultiStageFeature):
    """ResNet-50 backbone (through layer3) followed by a 1024 -> 256 adjustment.

    Args:
        pretrain: when true, load backbone weights from ``'resnet.model'``.
    """

    def __init__(self, pretrain=False):
        super(ResDown, self).__init__()
        self.features = resnet50(layer3=True, layer4=False)
        if pretrain:
            load_pretrain(self.features, 'resnet.model')
        self.downsample = ResDownS(1024, 256)
        # Staged-training configuration read by the MultiStageFeature base
        # (presumably through unfix()); exact semantics live in that class.
        self.layers = [self.downsample, self.features.layer2, self.features.layer3]
        self.train_nums = [1, 3]
        self.change_point = [0, 0.5]
        self.unfix(0.0)

    def param_groups(self, start_lr, feature_mult=1):
        """Build optimizer parameter groups from the trainable parameters.

        The adjustment layer uses ``start_lr * feature_mult``; the backbone
        uses one tenth of that. A module with no trainable parameters
        contributes no group.
        """
        base_lr = start_lr * feature_mult
        groups = []
        for module, scale in ((self.downsample, 1), (self.features, 0.1)):
            trainable = [p for p in module.parameters() if p.requires_grad]
            if trainable:
                groups.append({'params': trainable, 'lr': base_lr * scale})
        return groups

    def forward(self, x):
        """Return the adjusted map computed from the last backbone output."""
        backbone_outputs = self.features(x)
        return self.downsample(backbone_outputs[-1])

    def forward_all(self, x):
        """Return ``(backbone_outputs, adjusted)``.

        ``adjusted`` is the downsample of the last backbone output.
        """
        backbone_outputs = self.features(x)
        adjusted = self.downsample(backbone_outputs[-1])
        return backbone_outputs, adjusted
class UP(RPN):
    """RPN head made of two depth-wise correlation branches.

    Args:
        anchor_num: anchors per location. The classification branch emits
            ``2 * anchor_num`` channels and the regression branch
            ``4 * anchor_num``.
        feature_in: input channel width passed to each DepthCorr.
        feature_out: hidden channel width passed to each DepthCorr.
    """

    def __init__(self, anchor_num=5, feature_in=256, feature_out=256):
        super(UP, self).__init__()
        self.anchor_num = anchor_num
        self.feature_in = feature_in
        self.feature_out = feature_out
        self.cls_output = 2 * anchor_num
        self.loc_output = 4 * anchor_num
        self.cls = DepthCorr(feature_in, feature_out, self.cls_output)
        self.loc = DepthCorr(feature_in, feature_out, self.loc_output)

    def forward(self, z_f, x_f):
        """Correlate template features ``z_f`` with search features ``x_f``.

        Returns:
            ``(cls, loc)``: the classification and regression branch outputs.
        """
        return self.cls(z_f, x_f), self.loc(z_f, x_f)
class MaskCorr(Mask):
    """Mask head built on a depth-wise correlation with ``oSz ** 2`` outputs.

    Args:
        oSz: the head emits ``oSz ** 2`` output channels (presumably a
            flattened ``oSz x oSz`` mask; confirm against the Mask base class).
    """

    def __init__(self, oSz=63):
        super(MaskCorr, self).__init__()
        self.oSz = oSz
        mask_channels = oSz ** 2
        self.mask = DepthCorr(256, 256, mask_channels)

    def forward(self, z, x):
        """Correlate template ``z`` with search ``x`` through the mask head."""
        head = self.mask
        return head(z, x)
class Refine(nn.Module):
    """Mask refinement head.

    Inputs:
    - a 256-channel correlation vector, expanded to a 15x15 map by ``deconv``;
    - three backbone feature maps with 64, 256 and 512 channels.

    The path runs 15x15 -> 31x31 -> 61x61 -> 127x127 and produces one
    flattened 127x127 single-channel map per sample.
    """

    def __init__(self):
        super(Refine, self).__init__()
        # v*: compress the backbone crops (64/256/512 -> 4/16/32 channels).
        self.v0 = nn.Sequential(nn.Conv2d(64, 16, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(16, 4, 3, padding=1), nn.ReLU())
        self.v1 = nn.Sequential(nn.Conv2d(256, 64, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(64, 16, 3, padding=1), nn.ReLU())
        self.v2 = nn.Sequential(nn.Conv2d(512, 128, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(128, 32, 3, padding=1), nn.ReLU())
        # h*: process the top-down signal at a constant width (32/16/4).
        self.h2 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())
        self.h1 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.h0 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())
        # kernel 15 / stride 15: a 1x1 input becomes a 15x15 map.
        self.deconv = nn.ConvTranspose2d(256, 32, 15, 15)
        self.post0 = nn.Conv2d(32, 16, 3, padding=1)
        self.post1 = nn.Conv2d(16, 4, 3, padding=1)
        self.post2 = nn.Conv2d(4, 1, 3, padding=1)
        # Only nn.Conv2d layers get Kaiming init. The ConvTranspose2d in
        # self.deconv is not a Conv2d subclass, so it keeps PyTorch's default.
        for modules in [self.v0, self.v1, self.v2, self.h2, self.h1, self.h0,
                        self.deconv, self.post0, self.post1, self.post2]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    nn.init.kaiming_uniform_(layer.weight, a=1)

    def forward(self, f, corr_feature, pos=None, test=False):
        """Refine correlation features into flattened 127x127 maps.

        Args:
            f: sequence of three feature maps: ``f[0]`` with 64 channels,
                ``f[1]`` with 256 and ``f[2]`` with 512.
            corr_feature: 4-D map with 256 channels. Each spatial location
                yields one 256-vector.
            pos: with ``test=True``, a ``(row, col)`` location in
                ``corr_feature``'s grid. Windows of 61/31/15 at scales 4/2/1
                are cropped around it from the zero-padded feature maps.
                With ``test=False``, an optional index tensor passed to
                ``torch.index_select`` to keep a subset of the unfolded windows.
            test: True for the single-location crop path; False for the
                all-windows ``F.unfold`` path.

        Returns:
            Tensor of shape ``(N, 127 * 127)``.
        """
        if test:
            p0 = F.pad(f[0], [16, 16, 16, 16])[:, :, 4 * pos[0]:4 * pos[0] + 61, 4 * pos[1]:4 * pos[1] + 61]
            p1 = F.pad(f[1], [8, 8, 8, 8])[:, :, 2 * pos[0]:2 * pos[0] + 31, 2 * pos[1]:2 * pos[1] + 31]
            p2 = F.pad(f[2], [4, 4, 4, 4])[:, :, pos[0]:pos[0] + 15, pos[1]:pos[1] + 15]
        else:
            # Every sliding window becomes its own batch entry.
            p0 = F.unfold(f[0], (61, 61), padding=0, stride=4).permute(0, 2, 1).contiguous().view(-1, 64, 61, 61)
            p1 = F.unfold(f[1], (31, 31), padding=0, stride=2).permute(0, 2, 1).contiguous().view(-1, 256, 31, 31)
            p2 = F.unfold(f[2], (15, 15), padding=0, stride=1).permute(0, 2, 1).contiguous().view(-1, 512, 15, 15)
            if pos is not None:
                p0 = torch.index_select(p0, 0, pos)
                p1 = torch.index_select(p1, 0, pos)
                p2 = torch.index_select(p2, 0, pos)

        if pos is not None:
            # NOTE(review): with test=False and pos an index tensor, this uses
            # pos[0] / pos[1] as spatial coordinates. That does not match the
            # index_select above; confirm whether that path is ever used.
            p3 = corr_feature[:, :, pos[0], pos[1]].view(-1, 256, 1, 1)
        else:
            p3 = corr_feature.permute(0, 2, 3, 1).contiguous().view(-1, 256, 1, 1)

        # F.interpolate replaces the deprecated F.upsample. The defaults are the
        # same (mode='nearest'), so the outputs are identical.
        out = self.deconv(p3)
        out = self.post0(F.interpolate(self.h2(out) + self.v2(p2), size=(31, 31)))
        out = self.post1(F.interpolate(self.h1(out) + self.v1(p1), size=(61, 61)))
        out = self.post2(F.interpolate(self.h0(out) + self.v0(p0), size=(127, 127)))
        return out.view(-1, 127 * 127)

    def param_groups(self, start_lr, feature_mult=1):
        """Return one optimizer group with all trainable parameters.

        The group's learning rate is ``start_lr * feature_mult``.
        """
        # Materialise the list: a bare ``filter`` iterator is exhausted after
        # one pass, so any second read of the group would see no parameters.
        params = [p for p in self.parameters() if p.requires_grad]
        return [{'params': params, 'lr': start_lr * feature_mult}]
class Custom(SiamMask):
    """SiamMask model assembled from the modules defined in this file.

    It combines ResDown features, the UP RPN head, the MaskCorr mask head and
    the Refine module.

    Args:
        pretrain: forwarded to ``ResDown`` to load backbone weights.
        **kwargs: forwarded to ``SiamMask.__init__``.
    """

    def __init__(self, pretrain=False, **kwargs):
        super(Custom, self).__init__(**kwargs)
        self.features = ResDown(pretrain=pretrain)
        self.rpn_model = UP(anchor_num=self.anchor_num, feature_in=256, feature_out=256)
        self.mask_model = MaskCorr()
        self.refine_model = Refine()

    def refine(self, f, pos=None):
        """Call the refinement module with ``f`` and the second argument."""
        # NOTE(review): ``pos`` lands in Refine.forward's ``corr_feature``
        # positional slot. Refine therefore runs with its own pos=None and
        # test=False. The parameter name looks misleading; confirm intended
        # callers before changing the binding.
        return self.refine_model(f, pos)

    def template(self, template):
        """Encode the template image and cache the result in ``self.zf``."""
        self.zf = self.features(template)

    def track(self, search):
        """Match a search image against the cached template.

        Returns:
            ``(rpn_pred_cls, rpn_pred_loc)`` unpacked from ``self.rpn``.
        """
        search_feature = self.features(search)
        cls_out, loc_out = self.rpn(self.zf, search_feature)
        return cls_out, loc_out

    def track_mask(self, search):
        """Track and also predict a coarse mask.

        Side effects: caches ``self.feature`` and ``self.search`` (from
        ``features.forward_all``) and ``self.corr_feature`` for a later
        ``track_refine`` call.

        Returns:
            ``(rpn_pred_cls, rpn_pred_loc, pred_mask)``.
        """
        self.feature, self.search = self.features.forward_all(search)
        cls_out, loc_out = self.rpn(self.zf, self.search)
        self.corr_feature = self.mask_model.mask.forward_corr(self.zf, self.search)
        pred_mask = self.mask_model.mask.head(self.corr_feature)
        return cls_out, loc_out, pred_mask

    def track_refine(self, pos):
        """Refine the mask at location ``pos``.

        Uses the features cached by ``track_mask``, with ``test=True``.
        """
        return self.refine_model(self.feature, self.corr_feature, pos=pos, test=True)
|