Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 | |
| from torch.autograd import Variable | |
| import torchvision.utils as vutils | |
class CNN(nn.Module):
    """Predict one flattened transformation matrix per mask region.

    The input feature map is mean-centred independently inside every mask
    region, passed through a small convolutional stack that reduces it to
    ``matrixSize`` channels, and the per-region covariance of the result is
    mapped through a fully connected layer.

    Args:
        layer: Either ``'r31'`` (256 input channels) or ``'r41'`` (512 input
            channels).
        matrixSize: Number of output channels of the conv stack; each
            predicted matrix has ``matrixSize * matrixSize`` entries.

    Raises:
        ValueError: If ``layer`` is not ``'r31'`` or ``'r41'``.
    """

    def __init__(self, layer, matrixSize=32):
        super(CNN, self).__init__()
        if layer == 'r31':
            # Original note: input is 256x64x64.
            self.convs = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(128, 64, 3, 1, 1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(64, matrixSize, 3, 1, 1))
        elif layer == 'r41':
            # Original note: input is 512x32x32.
            self.convs = nn.Sequential(nn.Conv2d(512, 256, 3, 1, 1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(256, 128, 3, 1, 1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(128, matrixSize, 3, 1, 1))
        else:
            raise ValueError(
                "unsupported layer %r; expected 'r31' or 'r41'" % (layer,))
        # The covariance fed to fc is matrixSize x matrixSize, so the layer
        # must be sized from matrixSize rather than a hard-coded 32.
        self.fc = nn.Linear(matrixSize * matrixSize, matrixSize * matrixSize)

    @staticmethod
    def _resize_mask(mask, height, width):
        """Nearest-neighbour resize of ``mask`` to (height, width), flattened to 1-D long."""
        resized = cv2.resize(mask.squeeze(0).cpu().numpy(), (width, height),
                             interpolation=cv2.INTER_NEAREST)
        return torch.FloatTensor(resized).long().view(-1)

    def forward(self, x, masks, style=False):
        """Compute per-region transformation matrices.

        Args:
            x: Feature map of size (B, C, H, W). The tensor is flattened to
                (C, -1) and indexed with positions of an H*W mask, so this
                assumes B == 1 -- TODO confirm against callers.
            masks: Indexable collection holding at least 9 mask tensors; each
                is squeezed on dim 0 and resized with OpenCV. Presumably
                binary (0/1) region masks -- verify against caller.
            style: Unused; kept for interface compatibility.

        Returns:
            A tuple ``(transMatrices, feature_sub_mean)`` where
            ``transMatrices`` maps a mask index to the flattened output of
            ``self.fc`` for that region (regions whose resized mask sums to
            less than 10 are absent), and ``feature_sub_mean`` is the (C, -1)
            feature matrix with the per-region mean removed.
        """
        color_code_number = 9
        xb, xc, xh, xw = x.size()
        x = x.view(xc, -1)
        feature_sub_mean = x.clone()

        # Pass 1: subtract the channel-wise mean inside every region.
        for i in range(color_code_number):
            mask = self._resize_mask(masks[i], xh, xw)
            if torch.sum(mask) < 10:
                # Region too small at this resolution; leave features as-is.
                continue
            # NOTE: an optional 5x5 dilation of the mask used to live here
            # as disabled code in the original implementation.
            # Follow the feature tensor's device instead of forcing CUDA.
            fgmask = (mask > 0).nonzero().squeeze(1).to(x.device)
            selectFeature = torch.index_select(x, 1, fgmask)
            f_mean = torch.mean(selectFeature, 1, keepdim=True)
            feature_sub_mean.index_copy_(1, fgmask, selectFeature - f_mean)

        feature = self.convs(feature_sub_mean.view(xb, xc, xh, xw))
        _, c, h, w = feature.size()
        feature = feature.view(c, -1)

        # Pass 2: one covariance -> fc prediction per region.
        transMatrices = {}
        for i in range(color_code_number):
            mask = self._resize_mask(masks[i], h, w)
            if torch.sum(mask) < 10:
                continue
            fgmask = (mask == 1).nonzero().squeeze(1).to(feature.device)
            selectFeature = torch.index_select(feature, 1, fgmask)
            tN = selectFeature.size(1)
            covMatrix = torch.mm(selectFeature,
                                 selectFeature.transpose(0, 1)).div(tN)
            transMatrices[i] = self.fc(covMatrix.view(-1))
        return transMatrices, feature_sub_mean
class MulLayer(nn.Module):
    """Region-wise linear feature transformation between content and style.

    For every mask index present in both images, the content features inside
    the region are compressed to ``matrixSize`` channels, multiplied by the
    product of the style and content matrices predicted by two ``CNN``
    branches, expanded back to the original channel count, and offset by the
    mean of the style features inside the matching style region.

    Args:
        layer: Either ``'r31'`` (256 channels) or ``'r41'`` (512 channels).
        matrixSize: Side length of the predicted transformation matrices and
            channel count of the compressed features.

    Raises:
        ValueError: If ``layer`` is not ``'r31'`` or ``'r41'``.
    """

    def __init__(self, layer, matrixSize=32):
        super(MulLayer, self).__init__()
        if layer == 'r41':
            channels = 512
        elif layer == 'r31':
            channels = 256
        else:
            raise ValueError(
                "unsupported layer %r; expected 'r31' or 'r41'" % (layer,))
        # Forward matrixSize so the predicted matrices match the
        # (matrixSize, matrixSize) views taken in forward().
        self.snet = CNN(layer, matrixSize)
        self.cnet = CNN(layer, matrixSize)
        self.matrixSize = matrixSize
        self.compress = nn.Conv2d(channels, matrixSize, 1, 1, 0)
        self.unzip = nn.Conv2d(matrixSize, channels, 1, 1, 0)

    @staticmethod
    def _resize_mask(mask, height, width):
        """Nearest-neighbour resize of ``mask`` to (height, width), flattened to 1-D long."""
        resized = cv2.resize(mask.squeeze(0).cpu().numpy(), (width, height),
                             interpolation=cv2.INTER_NEAREST)
        return torch.FloatTensor(resized).long().view(-1)

    def forward(self, cF, sF, cmasks, smasks):
        """Transfer style statistics onto content features, region by region.

        Args:
            cF: Content feature map of size (B, C, H, W).
            sF: Style feature map of size (B, C, H', W'). Both maps are
                flattened to (C, -1) and indexed with single-image mask
                positions, so this assumes B == 1 -- TODO confirm.
            cmasks: Indexable collection with at least 9 content mask
                tensors; presumably binary region masks -- verify against
                caller.
            smasks: Same as ``cmasks`` for the style image.

        Returns:
            Tensor with the size of ``self.unzip``'s output: the transformed
            content features plus the per-region style mean. Positions not
            covered by any processed region keep the compressed, mean-centred
            content features and receive a zero style mean.
        """
        sb, sc, sh, sw = sF.size()
        sMatrices, sF_sub_mean = self.snet(sF, smasks, style=True)
        cMatrices, cF_sub_mean = self.cnet(cF, cmasks, style=False)

        compress_content = self.compress(cF_sub_mean.view(cF.size()))
        cb, cc, ch, cw = compress_content.size()
        compress_content = compress_content.view(cc, -1)
        transfeature = compress_content.clone()

        color_code_number = 9
        # Allocate next to cF instead of pinning the buffer to GPU 0.
        finalSMean = torch.zeros_like(cF).view(sc, -1)
        sFF = sF.view(sc, -1)

        for i in range(color_code_number):
            if i not in sMatrices or i not in cMatrices:
                continue
            cmask = self._resize_mask(cmasks[i], ch, cw)
            smask = self._resize_mask(smasks[i], sh, sw)
            if torch.sum(cmask) < 10 or torch.sum(smask) < 10:
                # Region too small in either image at this resolution.
                continue

            fgcmask = (cmask == 1).nonzero().squeeze(1).to(
                compress_content.device)
            fgsmask = (smask == 1).nonzero().squeeze(1).to(sF.device)

            # Channel-wise mean of the style features inside the region.
            sFF_select = torch.index_select(sFF, 1, fgsmask)
            sMean = torch.mean(sFF_select, dim=1, keepdim=True)

            sMatrix = sMatrices[i].view(self.matrixSize, self.matrixSize)
            cMatrix = cMatrices[i].view(self.matrixSize, self.matrixSize)
            transmatrix = torch.mm(sMatrix, cMatrix)

            compress_content_select = torch.index_select(
                compress_content, 1, fgcmask)
            transfeatureFG = torch.mm(transmatrix, compress_content_select)
            transfeature.index_copy_(1, fgcmask, transfeatureFG)

            # Every content position of the region gets the same style mean.
            sMean_select = sMean.expand(sc, fgcmask.size(0))
            finalSMean.index_copy_(1, fgcmask, sMean_select)

        out = self.unzip(transfeature.view(cb, cc, ch, cw))
        return out + finalSMean.view(out.size())