|
import math |
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from torchvision.models import resnet50, ResNet50_Weights |
|
|
|
import myutils |
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block: two 3x3 convs with an identity (or projected) skip.

    A 1x1 projection is inserted on the skip path only when the channel
    count or spatial stride differs between input and output.
    """

    def __init__(self, indim, outdim=None, stride=1):
        super(ResBlock, self).__init__()
        # Default the output width to the input width.
        outdim = outdim or indim
        self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)
        if indim == outdim and stride == 1:
            self.downsample = None
        else:
            # Match the skip path to the main path's shape.
            self.downsample = nn.Conv2d(indim, outdim, kernel_size=1, stride=stride)

    def forward(self, x):
        skip = x if self.downsample is None else self.downsample(x)
        h = self.conv2(F.relu(self.conv1(x)))
        return F.relu(h + skip)
|
|
|
|
|
class EncoderM(nn.Module):
    """Memory encoder.

    Fuses an RGB frame with a single-channel object mask and a
    single-channel "other objects" mask, then runs a ResNet-50 trunk
    through layer3.  Returns the layer3 feature and the stem feature.
    """

    def __init__(self, load_imagenet_params):
        super(EncoderM, self).__init__()
        # Single-channel stems whose outputs are summed into the RGB stem.
        self.conv1_m = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv1_o = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        backbone = resnet50(
            weights=ResNet50_Weights.IMAGENET1K_V1 if load_imagenet_params else None)
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        self.res2 = backbone.layer1
        self.res3 = backbone.layer2
        self.res4 = backbone.layer3

        # ImageNet normalization constants; buffers so they follow the
        # module across devices without being trained.
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, in_f, in_m, in_o):
        normed = (in_f - self.mean) / self.std

        # Sum the three stem activations before BN/ReLU.
        stem = self.conv1(normed) + self.conv1_m(in_m) + self.conv1_o(in_o)
        r1 = self.relu(self.bn1(stem))
        r2 = self.res2(self.maxpool(r1))
        r4 = self.res4(self.res3(r2))

        return r4, r1
|
|
|
|
|
class EncoderQ(nn.Module):
    """Query encoder: a plain ResNet-50 trunk through layer3.

    Returns the features of layer3/layer2/layer1 and the stem, so the
    decoder can use them as skip connections at multiple scales.
    """

    def __init__(self, load_imagenet_params):
        super(EncoderQ, self).__init__()
        backbone = resnet50(
            weights=ResNet50_Weights.IMAGENET1K_V1 if load_imagenet_params else None)
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        self.res2 = backbone.layer1
        self.res3 = backbone.layer2
        self.res4 = backbone.layer3

        # ImageNet normalization constants; buffers so they follow the
        # module across devices without being trained.
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, in_f):
        normed = (in_f - self.mean) / self.std

        r1 = self.relu(self.bn1(self.conv1(normed)))
        r2 = self.res2(self.maxpool(r1))
        r3 = self.res3(r2)
        r4 = self.res4(r3)

        return r4, r3, r2, r1
|
|
|
|
|
class KeyValue(nn.Module):
    """Projects a feature map into flattened key and value embeddings.

    Each projection is a 3x3 conv; the spatial dimensions are collapsed
    so the outputs are (N, keydim, H*W) and (N, valdim, H*W).
    """

    def __init__(self, indim, keydim, valdim):
        super(KeyValue, self).__init__()
        self.keydim = keydim
        self.valdim = valdim
        self.Key = nn.Conv2d(indim, keydim, kernel_size=(3, 3), padding=(1, 1), stride=1)
        self.Value = nn.Conv2d(indim, valdim, kernel_size=(3, 3), padding=(1, 1), stride=1)

    def forward(self, x):
        # Collapse H and W into a single spatial axis.
        k = self.Key(x).flatten(start_dim=2)
        v = self.Value(x).flatten(start_dim=2)
        return k, v
|
|
|
|
|
class Refine(nn.Module):
    """Upsample-and-refine module.

    Fuses a skip feature map with a coarser prediction: the skip feature
    is projected and residual-refined, the coarse map is bilinearly
    upsampled, and their sum is refined again.

    Args:
        inplanes: channel count of the skip feature `f`.
        planes: working channel count of the refined output.
        scale_factor: upsampling factor applied to `pm`.  Defaults to 2
            (the stride gap between consecutive ResNet stages), which
            preserves the original behavior; parameterized so the module
            can bridge other resolution gaps.
    """

    def __init__(self, inplanes, planes, scale_factor=2):
        super(Refine, self).__init__()
        self.convFS = nn.Conv2d(inplanes, planes, kernel_size=(3, 3), padding=(1, 1), stride=1)
        self.ResFS = ResBlock(planes, planes)
        self.ResMM = ResBlock(planes, planes)
        self.scale_factor = scale_factor

    def forward(self, f, pm):
        """Refine `pm` up to `f`'s resolution.

        Args:
            f: skip feature, (N, inplanes, H, W).
            pm: coarser map, (N, planes, H/scale_factor, W/scale_factor).

        Returns:
            Refined map of shape (N, planes, H, W).
        """
        s = self.ResFS(self.convFS(f))
        m = s + F.interpolate(pm, scale_factor=self.scale_factor,
                              mode='bilinear', align_corners=False)
        m = self.ResMM(m)

        return m
|
|
|
|
|
class Matcher(nn.Module):
    """Global matcher: soft-attends query features over a feature bank.

    For each object tracked in the bank, computes a scaled dot-product
    affinity between the bank keys and the query key, softmaxes it over
    the bank entries, and reads out an attention-weighted sum of the bank
    values.  Optionally accumulates per-entry usage statistics back into
    the bank.  Both heavy matmuls have a CPU fallback for GPU OOM.
    """

    def __init__(self, thres_valid=1e-3, update_bank=False):
        super(Matcher, self).__init__()
        # Affinity weight above which a bank entry counts as "used".
        self.thres_valid = thres_valid
        # When True, accumulate usage counts into feature_bank.info.
        self.update_bank = update_bank

    def forward(self, feature_bank, q_in, q_out):
        # feature_bank: project-defined bank exposing obj_n, keys[i]
        # (d_key, bank_n), values[i], and info[i].  q_in / q_out are the
        # query key and value — presumably (bs, d_key, hw) and
        # (bs, d_val, hw) as produced by KeyValue.forward; TODO confirm
        # against AFB_URR.segment.

        mem_out_list = []

        for i in range(0, feature_bank.obj_n):
            d_key, bank_n = feature_bank.keys[i].size()

            try:
                # Scaled dot-product attention: (bank_n, d_key) @ q_in,
                # normalized over the bank_n axis (dim=1).
                p = torch.matmul(feature_bank.keys[i].transpose(0, 1), q_in) / math.sqrt(d_key)
                p = F.softmax(p, dim=1)
                mem = torch.matmul(feature_bank.values[i], p)
            except RuntimeError as e:
                # GPU out of memory: redo the attention on the CPU, then
                # move the results back to the original device.
                device = feature_bank.keys[i].device
                key_cpu = feature_bank.keys[i].cpu()
                value_cpu = feature_bank.values[i].cpu()
                q_in_cpu = q_in.cpu()

                p = torch.matmul(key_cpu.transpose(0, 1), q_in_cpu) / math.sqrt(d_key)
                p = F.softmax(p, dim=1)
                mem = torch.matmul(value_cpu, p).to(device)
                p = p.to(device)
                print('\tLine 158. GPU out of memory, use CPU', f'p size: {p.shape}')

            # Concatenate the memory read-out with the query value along
            # the channel dimension.
            mem_out_list.append(torch.cat([mem, q_out], dim=1))

            if self.update_bank:
                try:
                    # Per bank entry: count how many query positions attend
                    # to it with weight above the validity threshold
                    # ([0] takes the first batch element).
                    ones = torch.ones_like(p)
                    zeros = torch.zeros_like(p)
                    bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0]
                except RuntimeError as e:
                    # Same CPU fallback for the counting step.
                    device = p.device
                    p = p.cpu()
                    ones = torch.ones_like(p)
                    zeros = torch.zeros_like(p)
                    bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0].to(device)
                    print('\tLine 170. GPU out of memory, use CPU', f'p size: {p.shape}')

                # Log-scaled usage accumulator; presumably column 1 of
                # info holds the usage score — verify against the bank.
                feature_bank.info[i][:, 1] += torch.log(bank_cnt + 1)

        # (obj_n, bs, C, hw) -> (bs, obj_n, C, hw)
        mem_out_tensor = torch.stack(mem_out_list, dim=0).transpose(0, 1)

        return mem_out_tensor
|
|
|
|
|
class Decoder(nn.Module):
    """Two-path decoder producing per-object foreground probabilities.

    A global path refines the matched features through two Refine stages,
    and a local path sharpens uncertain regions using stem-level features
    pooled in a small neighborhood.
    """

    def __init__(self, device):
        super(Decoder, self).__init__()

        self.device = device
        mdim_global = 256   # channel width of the global refinement path
        mdim_local = 32     # channel width of the local refinement path
        local_size = 7      # neighborhood size for local avg/max pooling

        # Global path: compress matched features, then refine with
        # layer2/layer1 skips and predict a 2-class (bg/fg) map.
        self.convFM = nn.Conv2d(1024, mdim_global, kernel_size=3, padding=1, stride=1)
        self.ResMM = ResBlock(mdim_global, mdim_global)
        self.RF3 = Refine(512, mdim_global)
        self.RF2 = Refine(256, mdim_global)
        self.pred2 = nn.Conv2d(mdim_global, 2, kernel_size=3, padding=1, stride=1)

        # Local path: neighborhood statistics over stem features,
        # followed by a small refinement head.
        self.local_avg = nn.AvgPool2d(local_size, stride=1, padding=local_size // 2)
        self.local_max = nn.MaxPool2d(local_size, stride=1, padding=local_size // 2)
        self.local_convFM = nn.Conv2d(128, mdim_local, kernel_size=3, padding=1, stride=1)
        self.local_ResMM = ResBlock(mdim_local, mdim_local)
        self.local_pred2 = nn.Conv2d(mdim_local, 2, kernel_size=3, padding=1, stride=1)

        # Kaiming-init every Conv2d, including those inside the ResBlock
        # and Refine submodules (self.modules() recurses).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, patch_match, r3, r2, r1=None, feature_shape=None):
        # patch_match: matched features — presumably
        # (bs*obj_n, 1024, H/16, W/16); r3/r2/r1 are encoder skips already
        # expanded to bs*obj_n in the batch dim — TODO confirm vs. segment().
        p = self.ResMM(self.convFM(patch_match))
        p = self.RF3(r3, p)  # -> H/8 resolution
        p = self.RF2(r2, p)  # -> H/4 resolution
        p = self.pred2(F.relu(p))

        # Upsample the 2-class logits to r1 (stem) resolution.
        p = F.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)

        bs, obj_n, h, w = feature_shape
        # Foreground probability per object, then renormalize across objects.
        rough_seg = F.softmax(p, dim=1)[:, 1]
        rough_seg = rough_seg.view(bs, obj_n, h, w)
        rough_seg = F.softmax(rough_seg, dim=1)  # object-level normalization

        # Uncertainty of the rough segmentation gates the local correction.
        uncertainty = myutils.calc_uncertainty(rough_seg)
        uncertainty = uncertainty.expand(-1, obj_n, -1, -1).reshape(bs * obj_n, 1, h, w)

        rough_seg = rough_seg.view(bs * obj_n, 1, h, w)
        # Foreground-weighted local average of stem features, normalized by
        # local mask mass (eps avoids division by zero in empty regions).
        r1_weighted = r1 * rough_seg
        r1_local = self.local_avg(r1_weighted)
        r1_local = r1_local / (self.local_avg(rough_seg) + 1e-8)
        # Local confidence: max mask response in the neighborhood.
        r1_conf = self.local_max(rough_seg)

        # Local head sees the stem feature and its foreground-pooled
        # counterpart (64 + 64 = 128 channels).
        local_match = torch.cat([r1, r1_local], dim=1)
        q = self.local_ResMM(self.local_convFM(local_match))
        q = r1_conf * self.local_pred2(F.relu(q))

        # Blend the local correction into the global logits where the
        # global prediction is uncertain, then upsample to full resolution.
        p = p + uncertainty * q
        p = F.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)
        p = F.softmax(p, dim=1)[:, 1]  # final per-object foreground prob

        return p
|
|
|
|
|
class AFB_URR(nn.Module):
    """Top-level memory-based video object segmentation model.

    `memorize` encodes a frame+mask pair into key/value features for the
    external feature bank; `segment` matches a new frame against the bank
    and decodes per-object score maps.
    """

    def __init__(self, device, update_bank, load_imagenet_params=False):
        super(AFB_URR, self).__init__()

        self.device = device
        self.encoder_m = EncoderM(load_imagenet_params)
        self.encoder_q = EncoderQ(load_imagenet_params)

        # Key/value projection over the layer3 feature (1024 channels).
        self.keyval_r4 = KeyValue(1024, keydim=128, valdim=512)

        self.global_matcher = Matcher(update_bank=update_bank)
        self.decoder = Decoder(device)

    def memorize(self, frame, mask):
        """Encode a frame and its K-object mask into per-object key/value
        lists for the feature bank.

        frame: (1, 3, H, W); mask: (1, K, H, W) — one channel per object.
        Returns (k4_list, v4_list), each a list of K flattened tensors.
        """
        _, K, H, W = mask.shape

        # Pad H and W up to multiples of 16 so the encoder strides divide
        # evenly; pad semantics come from myutils.pad_divide_by.
        (frame, mask), pad = myutils.pad_divide_by([frame, mask], 16, (frame.size()[2], frame.size()[3]))

        # Replicate the frame per object; build each object's mask and its
        # complement ("everything else") as the encoder's extra inputs.
        frame = frame.expand(K, -1, -1, -1)
        mask = mask[0].unsqueeze(1).float()
        mask_ones = torch.ones_like(mask)
        mask_inv = (mask_ones - mask).clamp(0, 1)

        r4, r1 = self.encoder_m(frame, mask, mask_inv)

        k4, v4 = self.keyval_r4(r4)
        k4_list = [k4[i] for i in range(K)]
        v4_list = [v4[i] for i in range(K)]

        return k4_list, v4_list

    def segment(self, frame, fb_global):
        """Segment `frame` against the global feature bank.

        Returns (score, uncertainty): score is per-object logits
        (bs, obj_n, H, W); uncertainty is a scalar regularizer during
        training, None at inference.
        """
        obj_n = fb_global.obj_n

        # At inference, pad to a multiple of 16 (training inputs are
        # presumably already aligned — TODO confirm with the data loader).
        if not self.training:
            [frame], pad = myutils.pad_divide_by([frame], 16, (frame.size()[2], frame.size()[3]))

        r4, r3, r2, r1 = self.encoder_q(frame)
        bs, _, global_match_h, global_match_w = r4.shape
        _, _, local_match_h, local_match_w = r1.shape

        # Global matching: read memory per object, then fold the object
        # axis into the batch axis for the decoder.
        k4, v4 = self.keyval_r4(r4)
        res_global = self.global_matcher(fb_global, k4, v4)
        res_global = res_global.reshape(bs * obj_n, v4.shape[1] * 2, global_match_h, global_match_w)

        # Expand the skip features to one copy per object (bs*obj_n batch).
        r3_size = r3.shape
        r2_size = r2.shape
        r3 = r3.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r3_size[1:])
        r2 = r2.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r2_size[1:])

        r1_size = r1.shape
        r1 = r1.unsqueeze(1).expand(-1, obj_n, -1, -1, -1).reshape(bs * obj_n, *r1_size[1:])
        feature_size = (bs, obj_n, r1_size[2], r1_size[3])
        score = self.decoder(res_global, r3, r2, r1, feature_size)

        score = score.view(bs, obj_n, *frame.shape[-2:])

        if self.training:
            # Uncertainty regularizer: L2 norm of the per-pixel uncertainty,
            # normalized by image size, averaged over the batch.
            uncertainty = myutils.calc_uncertainty(F.softmax(score, dim=1))
            uncertainty = uncertainty.view(bs, -1).norm(p=2, dim=1) / math.sqrt(frame.shape[-2] * frame.shape[-1])
            uncertainty = uncertainty.mean()
        else:
            uncertainty = None

        # Convert probabilities to logits via the logit function; clamp
        # keeps log() finite at the extremes.
        score = torch.clamp(score, 1e-7, 1 - 1e-7)
        score = torch.log((score / (1 - score)))

        # Undo the inference-time padding.  pad is presumably
        # (left, right, top, bottom); NOTE(review): the -pad[3] / -pad[1]
        # slices assume the right/bottom pad is nonzero whenever the pair
        # sum is positive (a -0 end index would empty the slice) — confirm
        # against myutils.pad_divide_by.
        if not self.training:
            if pad[2] + pad[3] > 0:
                score = score[:, :, pad[2]:-pad[3], :]
            if pad[0] + pad[1] > 0:
                score = score[:, :, :, pad[0]:-pad[1]]

        return score, uncertainty

    def forward(self, x):
        # Intentionally unused: callers invoke memorize() / segment()
        # directly rather than the nn.Module call protocol.
        pass
|
|