incrl's picture
Initial Upload (attempt 2)
5b557cf verified
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from torch.autograd import Variable
import torchvision.utils as vutils
class CNN(nn.Module):
    """Predict one flattened transformation matrix per mask region.

    The input feature map is mean-centred region by region, passed through a
    three-layer convolutional stack that reduces it to ``matrixSize``
    channels, and for every region the channel covariance of the reduced
    features is flattened and fed to a fully connected layer.

    Args:
        layer: ``'r31'`` (256 input channels) or ``'r41'`` (512 input
            channels).
        matrixSize: number of channels produced by the convolutional stack;
            the fully connected layer maps ``matrixSize * matrixSize`` values
            to ``matrixSize * matrixSize`` values.

    Raises:
        ValueError: if ``layer`` is neither ``'r31'`` nor ``'r41'``.
    """

    # Number of entries read from ``masks`` on every forward pass.
    color_code_number = 9
    # A region is ignored when its resized mask sums to less than this.
    min_region_sum = 10

    def __init__(self, layer, matrixSize=32):
        super(CNN, self).__init__()
        if layer == 'r31':
            # original note: input is 256x64x64
            self.convs = nn.Sequential(
                nn.Conv2d(256, 128, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 64, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, matrixSize, 3, 1, 1),
            )
        elif layer == 'r41':
            # original note: input is 512x32x32
            self.convs = nn.Sequential(
                nn.Conv2d(512, 256, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 128, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, matrixSize, 3, 1, 1),
            )
        else:
            # Fail at construction instead of leaving ``self.convs`` undefined
            # and crashing later inside ``forward``.
            raise ValueError(
                "unsupported layer %r; expected 'r31' or 'r41'" % (layer,)
            )
        # The covariance fed to this layer is matrixSize x matrixSize, so the
        # layer width has to follow matrixSize (1024 for the default of 32).
        self.fc = nn.Linear(matrixSize * matrixSize, matrixSize * matrixSize)

    @staticmethod
    def _resize_mask(mask, width, height):
        """Resize one mask to ``height`` x ``width`` (nearest neighbour).

        Returns a CPU ``long`` tensor. The mask goes through NumPy/OpenCV, so
        it is detached and copied to the CPU first.
        """
        resized = cv2.resize(
            mask.detach().squeeze(0).cpu().numpy(),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        )
        return torch.FloatTensor(resized).long()

    def forward(self, x, masks, style=False):
        """Compute per-region matrices and the region-centred features.

        Args:
            x: 4-D feature tensor. ``x.view(xc, -1)`` folds every non-channel
                axis into one, so the column indices taken from the masks only
                line up with pixels when the batch size is 1.
            masks: indexable with ``0 .. color_code_number - 1``; each entry
                must be a tensor that is a 2-D array after ``squeeze(0)``
                (presumably 1xHxW binary maps - confirm against the caller).
            style: accepted for interface compatibility; not used.

        Returns:
            ``(transMatrices, feature_sub_mean)`` where ``transMatrices`` maps
            a region index to a flat tensor of ``fc`` outputs (only regions
            that pass the size check are present) and ``feature_sub_mean`` is
            the ``(channels, -1)`` view of ``x`` with each region's
            per-channel mean removed.
        """
        xb, xc, xh, xw = x.size()
        x = x.view(xc, -1)
        feature_sub_mean = x.clone()

        for i in range(self.color_code_number):
            mask = self._resize_mask(masks[i], xw, xh)
            if torch.sum(mask) < self.min_region_sum:
                continue
            # Follow the device of the features instead of assuming CUDA.
            fgmask = (mask.view(-1) > 0).nonzero().squeeze(1).to(x.device)
            selectFeature = torch.index_select(x, 1, fgmask)
            # Per-channel mean over the region; broadcasting does the expand.
            f_mean = torch.mean(selectFeature, 1, keepdim=True)
            feature_sub_mean.index_copy_(1, fgmask, selectFeature - f_mean)

        feature = self.convs(feature_sub_mean.view(xb, xc, xh, xw))
        _, c, h, w = feature.size()
        feature = feature.view(c, -1)

        transMatrices = {}
        for i in range(self.color_code_number):
            mask = self._resize_mask(masks[i], w, h)
            if torch.sum(mask) < self.min_region_sum:
                continue
            # NOTE(review): the centring loop above selects ``mask > 0`` while
            # this one selects ``mask == 1``; they agree only for 0/1 masks.
            fgmask = (mask.view(-1) == 1).nonzero().squeeze(1).to(feature.device)
            if fgmask.numel() == 0:
                # No pixel equals 1: dividing by zero below would give a NaN
                # covariance, so treat the region as absent.
                continue
            selectFeature = torch.index_select(feature, 1, fgmask)
            pixel_count = selectFeature.size(1)
            covMatrix = torch.mm(
                selectFeature, selectFeature.transpose(0, 1)
            ).div(pixel_count)
            transMatrices[i] = self.fc(covMatrix.view(-1))

        return transMatrices, feature_sub_mean
class MulLayer(nn.Module):
    """Apply per-region linear feature transforms from style to content.

    Two ``CNN`` sub-networks predict one matrix per mask region for the style
    and the content features. For every region present in both, the content
    features (compressed to ``matrixSize`` channels by a 1x1 convolution) are
    multiplied by ``style_matrix @ content_matrix``, expanded back with a
    second 1x1 convolution, and the style region's per-channel mean is added.

    Args:
        layer: ``'r31'`` (256 channels) or ``'r41'`` (512 channels).
        matrixSize: side length of the predicted square matrices and the
            channel count of the compressed content features.

    Raises:
        ValueError: if ``layer`` is neither ``'r31'`` nor ``'r41'``.
    """

    # Number of entries read from ``cmasks`` / ``smasks`` on every call.
    color_code_number = 9
    # A region is ignored when either resized mask sums to less than this.
    min_region_sum = 10

    def __init__(self, layer, matrixSize=32):
        super(MulLayer, self).__init__()
        # Forward matrixSize so the predicted matrices really have the side
        # length that ``forward`` reshapes them to.
        self.snet = CNN(layer, matrixSize)
        self.cnet = CNN(layer, matrixSize)
        self.matrixSize = matrixSize

        if layer == 'r41':
            self.compress = nn.Conv2d(512, matrixSize, 1, 1, 0)
            self.unzip = nn.Conv2d(matrixSize, 512, 1, 1, 0)
        elif layer == 'r31':
            self.compress = nn.Conv2d(256, matrixSize, 1, 1, 0)
            self.unzip = nn.Conv2d(matrixSize, 256, 1, 1, 0)
        else:
            # Fail here instead of leaving compress/unzip undefined.
            raise ValueError(
                "unsupported layer %r; expected 'r31' or 'r41'" % (layer,)
            )

    @staticmethod
    def _resize_mask(mask, width, height):
        """Resize one mask to ``height`` x ``width`` (nearest neighbour).

        Returns a CPU ``long`` tensor. The mask goes through NumPy/OpenCV, so
        it is detached and copied to the CPU first.
        """
        resized = cv2.resize(
            mask.detach().squeeze(0).cpu().numpy(),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        )
        return torch.FloatTensor(resized).long()

    def forward(self, cF, sF, cmasks, smasks):
        """Transfer style statistics to the content features region by region.

        Args:
            cF: 4-D content feature tensor.
            sF: 4-D style feature tensor; its channel count is also used to
                reshape a ``cF``-sized buffer, so both inputs need the same
                number of channels.
            cmasks: content masks, indexable with
                ``0 .. color_code_number - 1``.
            smasks: style masks, indexed the same way.

        The flattening used here (``view(channels, -1)``) folds the batch axis
        into the pixel axis, so mask indices line up with pixels only for a
        batch size of 1.

        Returns:
            Tensor shaped like the output of ``self.unzip``: transformed
            content features plus the per-region style means. Regions that
            fail the size check, or are missing from either sub-network's
            result, keep the untransformed compressed features and a zero
            mean.
        """
        sc, sh, sw = sF.size()[1:]

        sMatrices, _ = self.snet(sF, smasks, style=True)
        cMatrices, cF_sub_mean = self.cnet(cF, cmasks, style=False)

        compress_content = self.compress(cF_sub_mean.view(cF.size()))
        cb, cc, ch, cw = compress_content.size()
        compress_content = compress_content.view(cc, -1)
        transfeature = compress_content.clone()

        # Same dtype/device as the content features instead of a hard-coded
        # ``.cuda(0)``.
        finalSMean = cF.new_zeros(cF.size()).view(sc, -1)
        # Loop-invariant flattened style features.
        sFF = sF.view(sc, -1)
        size = self.matrixSize

        for i in range(self.color_code_number):
            if i not in sMatrices or i not in cMatrices:
                continue
            cmask = self._resize_mask(cmasks[i], cw, ch)
            smask = self._resize_mask(smasks[i], sw, sh)
            if (torch.sum(cmask) < self.min_region_sum
                    or torch.sum(smask) < self.min_region_sum):
                continue

            fgcmask = (cmask.view(-1) == 1).nonzero().squeeze(1)
            fgcmask = fgcmask.to(compress_content.device)
            fgsmask = (smask.view(-1) == 1).nonzero().squeeze(1)
            fgsmask = fgsmask.to(sFF.device)
            if fgcmask.numel() == 0 or fgsmask.numel() == 0:
                # The sum check passed but no pixel equals 1; an empty style
                # selection would yield a NaN mean, so skip the region.
                continue

            # Per-channel mean of the style features inside the region.
            sMean = torch.mean(
                torch.index_select(sFF, 1, fgsmask), dim=1, keepdim=True
            )

            transmatrix = torch.mm(
                sMatrices[i].view(size, size),
                cMatrices[i].view(size, size),
            )
            transfeatureFG = torch.mm(
                transmatrix, torch.index_select(compress_content, 1, fgcmask)
            )
            transfeature.index_copy_(1, fgcmask, transfeatureFG)

            # Every selected content pixel receives the same per-channel style
            # mean; expand only to the region size, not to all of cF.
            finalSMean.index_copy_(
                1, fgcmask, sMean.expand(sc, fgcmask.numel()).contiguous()
            )

        out = self.unzip(transfeature.view(cb, cc, ch, cw))
        return out + finalSMean.view(out.size())