# stylematte / test.py
# shir0ha's picture
# Duplicate from befozg/stylematte
# cff9442
#modified from Github repo: https://github.com/JizhiziLi/P3M
#added inference code for other networks
import torch
import cv2
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
from skimage.transform import resize
from torchvision import transforms,models
import os
from models import *
import torch.nn.functional as F
import torch
import torch.nn as nn
import math
from torch.autograd import Variable
import torch.nn.functional as fnn
import glob
import tqdm
from torch.autograd import Variable
from typing import Type, Any, Callable, Union, List, Optional
import logging
import time
from omegaconf import OmegaConf
# Evaluation settings (dataset paths under config['datasets'], checkpoint names,
# etc.) loaded from base.yaml in the current working directory at import time;
# the __main__ block reloads this file and merges command-line overrides.
config = OmegaConf.load("base.yaml")
# Device string that tensors and the model are moved to throughout this script.
device = "cpu"
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Note: a later ``conv3x3`` definition in this module (with ``groups`` and
    ``dilation`` parameters) rebinds this name at import time.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class TFI(nn.Module):
    """Fuse three feature maps that each have ``planes`` channels.

    One shared 1x1 conv projects each input to ``int(planes / 2)`` channels.
    The three projections are concatenated along the channel axis, then fused
    back to ``planes`` channels by a 3x3 conv, BatchNorm and ReLU.
    """
    expansion = 1

    def __init__(self, planes, stride=1):
        super().__init__()
        half_planes = int(planes / 2)
        # A single projection whose weights are shared by all three inputs.
        self.transform = conv1x1(planes, half_planes)
        self.conv1 = conv3x3(half_planes * 3, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, input_s_guidance, input_m_decoder, input_m_encoder):
        projected = [
            self.transform(feature)
            for feature in (input_s_guidance, input_m_decoder, input_m_encoder)
        ]
        fused = self.conv1(torch.cat(projected, 1))
        return self.relu(self.bn1(fused))
class SBFI(nn.Module):
    """Refine a local-decoder feature map using the stem features ``e0``.

    ``e0`` is max-pooled (kernel 2, stride ``stride``) and projected from 64 to
    ``int(planes / 2)`` channels. It is concatenated with a half-width projection
    of the decoder features, then passed through 3x3 conv, BatchNorm and ReLU.
    The result is added residually to the decoder input.
    """

    def __init__(self, planes, stride=1):
        super().__init__()
        half_planes = int(planes / 2)
        self.stride = stride
        self.transform1 = conv1x1(planes, half_planes)
        self.transform2 = conv1x1(64, half_planes)
        self.maxpool = nn.MaxPool2d(2, stride=stride)
        self.conv1 = conv3x3(planes, planes, 1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input_m_decoder, e0):
        decoder_half = self.transform1(input_m_decoder)
        stem_half = self.transform2(self.maxpool(e0))
        merged = torch.cat((decoder_half, stem_half), 1)
        refined = self.relu(self.bn1(self.conv1(merged)))
        return refined + input_m_decoder
class DBFI(nn.Module):
    """Refine a global-decoder feature map using the deepest encoder features ``e4``.

    ``e4`` (512 channels) is projected to ``int(planes / 2)`` channels and
    bilinearly upsampled by ``stride``. It is concatenated with a half-width
    projection of the decoder features, then fused by 3x3 conv, BatchNorm and
    ReLU, and added residually to the decoder input.

    Returns:
        ``(refined, side)``. ``side`` is a 3-channel conv of ``refined``,
        upsampled by ``int(32 / stride)``.
    """

    def __init__(self, planes, stride=1):
        super().__init__()
        half_planes = int(planes / 2)
        self.stride = stride
        self.transform1 = conv1x1(planes, half_planes)
        self.transform2 = conv1x1(512, half_planes)
        self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
        self.conv1 = conv3x3(planes, planes, 1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, 3, 1)
        self.upsample2 = nn.Upsample(scale_factor=int(32 / stride), mode='bilinear')

    def forward(self, input_s_decoder, e4):
        decoder_half = self.transform1(input_s_decoder)
        encoder_half = self.upsample(self.transform2(e4))
        merged = torch.cat((decoder_half, encoder_half), 1)
        refined = self.relu(self.bn1(self.conv1(merged))) + input_s_decoder
        side = self.upsample2(self.conv2(refined))
        return refined, side
class P3mNet(nn.Module):
    """P3M-Net matting model: one shared encoder feeding two decoders.

    The encoder reuses the stem, the residual layers and the five
    index-returning max-pools of ``resnet34_mp()``. Building this module
    therefore loads ``checkpoints/r34mp_pretrained_imagenet.pth.tar`` from disk.

    The two decoders are:
    - a "global" decoder that ends in a 3-channel map;
    - a "local" decoder that ends in a 1-channel sigmoid map.

    ``get_masked_local_from_global`` fuses the two. The TFI/SBFI/DBFI modules
    exchange features between the decoders.
    """
    def __init__(self):
        super().__init__()
        self.resnet = resnet34_mp()
        ############################
        ### Encoder part - RESNETMP
        ############################
        # Stem only (conv1 -> bn1 -> relu); downsampling is done by the mp* pools.
        self.encoder0 = nn.Sequential(
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu,
            )
        # Each mp* returns (pooled, indices). The indices are consumed by
        # F.max_unpool2d in the local decoder (see forward).
        self.mp0 = self.resnet.maxpool1
        self.encoder1 = nn.Sequential(
            self.resnet.layer1)
        self.mp1 = self.resnet.maxpool2
        self.encoder2 = self.resnet.layer2
        self.mp2 = self.resnet.maxpool3
        self.encoder3 = self.resnet.layer3
        self.mp3 = self.resnet.maxpool4
        self.encoder4 = self.resnet.layer4
        self.mp4 = self.resnet.maxpool5
        # Cross-decoder fusion modules. The argument is the channel width; for
        # SBFI/DBFI the second argument is the pooling/upsampling stride.
        self.tfi_3 = TFI(256)
        self.tfi_2 = TFI(128)
        self.tfi_1 = TFI(64)
        self.tfi_0 = TFI(64)
        self.sbfi_2 = SBFI(128, 8)
        self.sbfi_1 = SBFI(64, 4)
        self.sbfi_0 = SBFI(64, 2)
        self.dbfi_2 = DBFI(128, 4)
        self.dbfi_1 = DBFI(64, 8)
        self.dbfi_0 = DBFI(64, 16)
        ##########################
        ### Decoder part - GLOBAL
        ##########################
        # Each global stage ends with a 2x bilinear upsample.
        self.decoder4_g = nn.Sequential(
            nn.Conv2d(512,512,3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512,512,3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear') )
        self.decoder3_g = nn.Sequential(
            nn.Conv2d(256,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,128,3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear') )
        self.decoder2_g = nn.Sequential(
            nn.Conv2d(128,128,3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128,128,3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear'))
        self.decoder1_g = nn.Sequential(
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear'))
        # Final global stage: 3 output channels, no normalization or activation.
        self.decoder0_g = nn.Sequential(
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,3,3,padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'))
        ##########################
        ### Decoder part - LOCAL
        ##########################
        # Local stages do not upsample themselves; forward uses max_unpool2d
        # with the encoder's pooling indices instead.
        self.decoder4_l = nn.Sequential(
            nn.Conv2d(512,512,3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512,512,3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        self.decoder3_l = nn.Sequential(
            nn.Conv2d(256,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,128,3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True))
        self.decoder2_l = nn.Sequential(
            nn.Conv2d(128,128,3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128,128,3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.decoder1_l = nn.Sequential(
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.decoder0_l = nn.Sequential(
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        # Single-channel head; sigmoid is applied in forward.
        self.decoder_final_l = nn.Conv2d(64,1,3,padding=1)
    def forward(self, input):
        """Run both decoders and fuse them.

        Args:
            input: image batch. Presumably (N, 3, H, W) with ImageNet
                normalization, since conv1 takes 3 input channels
                (TODO confirm the spatial-size constraints).

        Returns:
            A 6-tuple ``(global_sigmoid, local_sigmoid, fusion_sigmoid,
            global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0)``.
            Note that ``global_sigmoid`` is the raw 3-channel decoder output;
            no sigmoid is applied to it here.
        """
        ##########################
        ### Encoder part - RESNET
        ##########################
        # Keep every pooling index (id0..id4) for max_unpool2d in the local decoder.
        e0 = self.encoder0(input)
        e0p, id0 = self.mp0(e0)
        e1p, id1 = self.mp1(e0p)
        e1 = self.encoder1(e1p)
        e2p, id2 = self.mp2(e1)
        e2 = self.encoder2(e2p)
        e3p, id3 = self.mp3(e2)
        e3 = self.encoder3(e3p)
        e4p, id4 = self.mp4(e3)
        e4 = self.encoder4(e4p)
        ###########################
        ### Decoder part - Global
        ###########################
        # DBFI stages inject e4 and also emit 3-channel side outputs.
        d4_g = self.decoder4_g(e4)
        d3_g = self.decoder3_g(d4_g)
        d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, e4)
        d2_g = self.decoder2_g(d2_g)
        d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, e4)
        d1_g = self.decoder1_g(d1_g)
        d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, e4)
        d0_g = self.decoder0_g(d0_g)
        global_sigmoid = d0_g
        ###########################
        ### Decoder part - Local
        ###########################
        # Each stage: TFI fuses (global feature, local feature, encoder skip);
        # SBFI (from stage 2 down) injects stem features e0; max_unpool2d
        # upsamples with the matching encoder pooling indices.
        d4_l = self.decoder4_l(e4)
        d4_l = F.max_unpool2d(d4_l, id4, kernel_size=2, stride=2)
        d3_l = self.tfi_3(d4_g, d4_l, e3)
        d3_l = self.decoder3_l(d3_l)
        d3_l = F.max_unpool2d(d3_l, id3, kernel_size=2, stride=2)
        d2_l = self.tfi_2(d3_g, d3_l, e2)
        d2_l = self.sbfi_2(d2_l, e0)
        d2_l = self.decoder2_l(d2_l)
        d2_l = F.max_unpool2d(d2_l, id2, kernel_size=2, stride=2)
        d1_l = self.tfi_1(d2_g, d2_l, e1)
        d1_l = self.sbfi_1(d1_l, e0)
        d1_l = self.decoder1_l(d1_l)
        d1_l = F.max_unpool2d(d1_l, id1, kernel_size=2, stride=2)
        d0_l = self.tfi_0(d1_g, d1_l, e0p)
        d0_l = self.sbfi_0(d0_l, e0)
        d0_l = self.decoder0_l(d0_l)
        d0_l = F.max_unpool2d(d0_l, id0, kernel_size=2, stride=2)
        d0_l = self.decoder_final_l(d0_l)
        local_sigmoid = F.sigmoid(d0_l)
        ##########################
        ### Fusion net - G/L
        ##########################
        # Argmax over the 3 global channels picks, per pixel: 0 (fusion 0),
        # 1 (use local_sigmoid) or 2 (fusion 1).
        fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid)
        return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.

    With ``stride=1`` this preserves the spatial size for any dilation.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) ``nn.Conv2d``."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Two-conv ResNet residual block: (3x3 conv, norm, ReLU, 3x3 conv, norm) + shortcut, then ReLU.

    The shortcut is the input itself, or ``downsample(x)`` when a
    ``downsample`` module is given. Only ``groups=1``, ``base_width=64`` and
    ``dilation=1`` are supported.
    """
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Submodules are created in the same order as torchvision's BasicBlock.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, 3x3, 1x1 expand (x``expansion``), plus a shortcut.

    The shortcut is the input itself, or ``downsample(x)`` when a
    ``downsample`` module is given. ``groups`` and ``base_width`` set the inner
    width as ``int(planes * base_width / 64) * groups``.

    ``forward`` previously called ``self.attention``, which this class never
    defines. That made every forward pass raise ``AttributeError``. The call is
    gone, so the block now matches the standard torchvision bottleneck.
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    """ResNet variant with stride-1 convolutions and pooling-only downsampling.

    Every convolution (``conv1`` and all residual layers) uses stride 1. All
    spatial downsampling comes from five standalone ``MaxPool2d`` layers
    (``maxpool1``..``maxpool5``), each created with ``return_indices=True``.

    ``forward`` runs the full classifier (global average pool + ``fc``).
    ``P3mNet`` instead reuses the individual sub-modules as an encoder.

    Args:
        block: residual block class (e.g. ``BasicBlock``) exposing ``expansion``.
        layers: number of blocks in each of the four stages.
        zero_init_residual: zero the last BN weight of each residual block.
        groups, width_per_group: forwarded to the blocks.
        replace_stride_with_dilation: 3 flags, one per stage 2-4.
        norm_layer: normalization factory (default ``nn.BatchNorm2d``).
    """
    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stride-1 stem (torchvision uses stride 2 here).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Five independent 2x pools; their indices are used for max-unpooling
        # by P3mNet's local decoder.
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.maxpool5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        #pdb.set_trace()
        # All stages use stride 1; resolution is reduced only by the pools above.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, 1000)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks with ``planes`` base channels.

        A 1x1 conv + norm shortcut is added to the first block when the stride
        or the channel count changes. Updates ``self.inplanes`` (and
        ``self.dilation`` when ``dilate`` is set).
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes,stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes,groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        return nn.Sequential(*layers)
    def _forward_impl(self, x):
        """Classification forward pass; returns ``fc`` logits (1000 outputs).

        The pooling indices are computed but discarded here.
        """
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x1 = self.relu(x1)
        x1, idx1 = self.maxpool1(x1)
        x2, idx2 = self.maxpool2(x1)
        x2 = self.layer1(x2)
        x3, idx3 = self.maxpool3(x2)
        x3 = self.layer2(x3)
        x4, idx4 = self.maxpool4(x3)
        x4 = self.layer3(x4)
        x5, idx5 = self.maxpool5(x4)
        x5 = self.layer4(x5)
        x_cls = self.avgpool(x5)
        x_cls = torch.flatten(x_cls, 1)
        x_cls = self.fc(x_cls)
        return x_cls
    def forward(self, x):
        return self._forward_impl(x)
def resnet34_mp(**kwargs):
    r"""Build the max-pool ResNet-34 variant (``BasicBlock``, ``[3, 4, 6, 3]``) and load weights.

    Weights are read from ``checkpoints/r34mp_pretrained_imagenet.pth.tar``,
    relative to the current working directory. Tensors are mapped onto the
    module-level ``device``, so a checkpoint saved from GPU tensors still loads
    in this CPU-only script.

    Architecture reference: `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`

    Args:
        **kwargs: forwarded to ``ResNet``.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    # torch.load unpickles the file: only load trusted checkpoints.
    checkpoint = torch.load("checkpoints/r34mp_pretrained_imagenet.pth.tar",
                            map_location=device)
    model.load_state_dict(checkpoint)
    return model
##############################
### Training loses for P3M-NET
##############################
def get_crossentropy_loss(gt,pre):
    """Cross-entropy between 3-class logits ``pre`` and a trimap-coded ``gt``.

    ``gt`` values are remapped to class ids: 255 -> 2, any value > 2 (other
    than 255) -> 1, everything else is kept, so 0 stays 0. The result is cast
    to long, and channel 0 of ``gt`` is used as the target. ``gt`` itself is
    not modified.
    """
    target = gt.clone()
    is_foreground = target == 255
    target[target > 2] = 1
    target[is_foreground] = 2
    class_ids = target[:, 0, :, :].long()
    return nn.CrossEntropyLoss()(pre, class_ids)
def get_alpha_loss(predict, alpha, trimap):
    """Alpha loss restricted to the unknown trimap region (pixels equal to 128).

    Uses the Charbonnier penalty sqrt(d^2 + 1e-12) on ``predict - alpha / 255``.
    The penalty is summed over unknown pixels and divided by
    (number of unknown pixels + 1).
    """
    unknown_mask = torch.zeros(trimap.shape).to(device)
    unknown_mask[trimap == 128] = 1.
    target = (alpha / 255.).to(device)
    masked_residual = (predict - target) * unknown_mask
    penalty = torch.sqrt(masked_residual ** 2 + 1e-12)
    return penalty.sum() / (unknown_mask.sum() + 1.)
def get_alpha_loss_whole_img(predict, alpha):
    """Mean Charbonnier penalty sqrt(d^2 + 1e-12) between ``predict`` and ``alpha / 255``.

    The sum is divided by the element count of ``alpha``.
    """
    normalizer = torch.ones(alpha.shape).to(device).sum()
    target = (alpha / 255.).to(device)
    penalty = torch.sqrt((predict - target) ** 2 + 1e-12)
    return penalty.sum() / normalizer
## Laplacian loss is refer to
## https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270
def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
    """Build a normalized ``(n_channels, 1, size, size)`` smoothing kernel on ``device``.

    NOTE: per pixel, this sums (it does not multiply) the squared 1-D Gaussian
    weights of the two coordinates, following
    https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270 .
    The result is therefore not a true separable 2-D Gaussian.

    Args:
        size: odd kernel side length.
        sigma: Gaussian width.
        n_channels: number of stacked copies (one per channel for grouped conv).
        cuda: ignored; placement is controlled by the module-level ``device``.

    Raises:
        ValueError: if ``size`` is even.
    """
    if size % 2 != 1:
        raise ValueError("kernel size must be uneven")
    center = size // 2
    coords = np.float32(np.mgrid[0:size, 0:size].T)
    per_axis = np.exp((coords - center) ** 2 / (-2 * sigma ** 2)) ** 2
    kernel = per_axis.sum(axis=2)
    kernel /= kernel.sum()
    stacked = np.tile(kernel, (n_channels, 1, 1))[:, None, :, :]
    return Variable(torch.FloatTensor(stacked).to(device), requires_grad=False)
def conv_gauss(img, kernel):
    """Convolve ``img`` depthwise with ``kernel`` using replicate padding.

    Args:
        img: image batch, assumed (N, C, H, W).
        kernel: (C, 1, kH, kW) filter bank, e.g. from ``build_gauss_kernel``.
            It is applied with ``groups=C`` (one filter per channel).

    Returns:
        Filtered tensor; same spatial size as ``img`` for odd kernel sizes.
    """
    n_channels, _, k_h, k_w = kernel.shape
    # F.pad lists padding from the last dimension backwards:
    # (left, right, top, bottom). Width padding therefore comes first, then
    # height. The old tuple mixed them up and broke non-square kernels.
    img = fnn.pad(img, (k_w // 2, k_w // 2, k_h // 2, k_h // 2), mode='replicate')
    return fnn.conv2d(img, kernel, groups=n_channels)
def laplacian_pyramid(img, kernel, max_levels=5):
    """Return a Laplacian pyramid of ``img`` with ``max_levels`` band-pass levels.

    Each level is ``current - blur(current)``. The blurred image is then
    2x average-pooled to form the next ``current``. The final low-pass residual
    is appended last, so the list has ``max_levels + 1`` entries.
    """
    levels = []
    current = img
    for _ in range(max_levels):
        blurred = conv_gauss(current, kernel)
        levels.append(current - blurred)
        current = fnn.avg_pool2d(blurred, 2)
    levels.append(current)
    return levels
def get_laplacian_loss(predict, alpha, trimap):
    """Laplacian-pyramid L1 loss on the unknown trimap region (pixels equal to 128).

    Both ``alpha / 255`` and ``predict`` are zeroed outside the unknown region.
    The loss sums the mean L1 over the 6 entries of a 5-level pyramid.
    """
    unknown_mask = torch.zeros(trimap.shape).to(device)
    unknown_mask[trimap == 128] = 1.
    target = (alpha / 255.).to(device).clone() * unknown_mask
    masked_predict = predict.clone() * unknown_mask
    kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
    target_levels = laplacian_pyramid(target, kernel, 5)
    predict_levels = laplacian_pyramid(masked_predict, kernel, 5)
    return sum(fnn.l1_loss(t, p) for t, p in zip(target_levels, predict_levels))
def get_laplacian_loss_whole_img(predict, alpha):
    """Laplacian-pyramid L1 loss between ``predict`` and ``alpha / 255`` over the whole image.

    The loss sums the mean L1 over the 6 entries of a 5-level pyramid.
    """
    target = (alpha / 255.).to(device)
    kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
    target_levels = laplacian_pyramid(target, kernel, 5)
    predict_levels = laplacian_pyramid(predict, kernel, 5)
    return sum(fnn.l1_loss(t, p) for t, p in zip(target_levels, predict_levels))
def get_composition_loss_whole_img(img, alpha, fg, bg, predict):
    """Composition loss: compare ``predict*fg + (1-predict)*bg`` against ``img``.

    ``predict`` is repeated to 3 channels along dim 1. The Charbonnier penalty
    sqrt(d^2 + 1e-12) is summed and divided by the element count of ``alpha``.
    """
    normalizer = torch.ones(alpha.shape).to(device).sum()
    matte = torch.cat([predict, predict, predict], 1)
    composite = matte * fg + (1. - matte) * bg
    penalty = torch.sqrt((composite - img) ** 2 + 1e-12)
    return penalty.sum() / normalizer
##############################
### Test loss for matting
##############################
def calculate_sad_mse_mad(predict_old,alpha,trimap):
predict = np.copy(predict_old)
pixel = float((trimap == 128).sum())
predict[trimap == 255] = 1.
predict[trimap == 0 ] = 0.
sad_diff = np.sum(np.abs(predict - alpha))/1000
if pixel==0:
pixel = trimap.shape[0]*trimap.shape[1]-float((trimap==255).sum())-float((trimap==0).sum())
mse_diff = np.sum((predict - alpha) ** 2)/pixel
mad_diff = np.sum(np.abs(predict - alpha))/pixel
return sad_diff, mse_diff, mad_diff
def calculate_sad_mse_mad_whole_img(predict, alpha):
    """Whole-image SAD (sum of absolute errors / 1000), MSE and MAD.

    MSE and MAD are normalized by ``predict.shape[0] * predict.shape[1]``.

    Returns:
        (sad, mse, mad)
    """
    residual = predict - alpha
    n_pixels = predict.shape[0] * predict.shape[1]
    abs_total = np.sum(np.abs(residual))
    return abs_total / 1000, np.sum(residual ** 2) / n_pixels, abs_total / n_pixels
def calculate_sad_fgbg(predict, alpha, trimap):
    """SAD (sum of absolute errors / 1000) inside the trimap foreground and background.

    Foreground is ``trimap == 255`` and background is ``trimap == 0``.
    Selecting with boolean masks, rather than multiplying by 0/1 weights, means
    NaNs outside a region do not leak into its sum. The old unused
    unknown-region SAD is no longer computed.

    Returns:
        (sad_fg, sad_bg)
    """
    abs_err = np.abs(predict - alpha)
    sad_fg = np.sum(abs_err[trimap == 255]) / 1000
    sad_bg = np.sum(abs_err[trimap == 0]) / 1000
    return sad_fg, sad_bg
def compute_gradient_whole_image(pd, gt):
    """Gradient error: squared difference of Gaussian-derivative magnitudes, summed and divided by 10.

    Derivatives are first-order ``scipy.ndimage.gaussian_filter`` responses
    (sigma 1.4) along each of the two axes, computed in float32.
    """
    from scipy.ndimage import gaussian_filter

    def _magnitude(image):
        d_axis0 = gaussian_filter(image, sigma=1.4, order=[1, 0], output=np.float32)
        d_axis1 = gaussian_filter(image, sigma=1.4, order=[0, 1], output=np.float32)
        return np.sqrt(d_axis0 ** 2 + d_axis1 ** 2)

    squared_gap = np.square(_magnitude(pd) - _magnitude(gt))
    return np.sum(squared_gap) / 10
def compute_connectivity_loss_whole_image(pd, gt, step=0.1):
    """Connectivity error between a predicted and a ground-truth alpha matte.

    For each threshold ``t = thresh_steps[i]`` (i >= 1, with
    ``thresh_steps = np.arange(0, 1.1, step)``):
    1. Find the largest 4-connected component of ``(pd >= t) & (gt >= t)``.
    2. Pixels not yet assigned and outside that component get the level
       ``thresh_steps[i - 1]``.
    Pixels never assigned get level 1. With ``d = x - level``, the per-pixel
    connectivity is ``phi = 1 - d * (d >= 0.15)``. The error compares ``phi``
    for ``pd`` and ``gt``.

    Uses ``scipy.ndimage.label``. Its default 2-D structuring element is the
    4-neighbourhood cross, the same connectivity as skimage's ``connectivity=1``.
    This also avoids the deprecated ``scipy.ndimage.morphology`` module and the
    distance-map computation whose result was never used.

    Args:
        pd, gt: 2-D alpha arrays of the same shape (values expected in [0, 1]).
        step: threshold step.

    Returns:
        ``sum(|phi_pd - phi_gt|) / 1000``.
    """
    from scipy.ndimage import label

    h, w = pd.shape
    thresh_steps = np.arange(0, 1.1, step)
    l_map = -1 * np.ones((h, w), dtype=np.float32)
    for i in range(1, thresh_steps.size):
        pd_th = pd >= thresh_steps[i]
        gt_th = gt >= thresh_steps[i]
        label_image, num_components = label(pd_th & gt_th)
        if num_components == 0:
            continue
        # Pixel count per label; drop label 0 (background). argmax keeps the
        # lowest label on ties, like argmax over regionprops areas.
        sizes = np.bincount(label_image.ravel())[1:]
        largest = np.argmax(sizes) + 1
        omega = label_image == largest
        flag = (l_map == -1) & ~omega
        l_map[flag] = thresh_steps[i - 1]
    l_map[l_map == -1] = 1
    d_pd = pd - l_map
    d_gt = gt - l_map
    phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32)
    phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32)
    loss = np.sum(np.abs(phi_pd - phi_gt)) / 1000
    return loss
def gen_trimap_from_segmap_e2e(segmap):
    """Turn class scores into a uint8 trimap for the first batch item.

    Argmax over axis 1 gives a class per pixel: class 1 becomes 128, class 2
    becomes 255, and other class ids are kept as-is (0 stays 0).
    """
    classes = np.argmax(segmap, axis=1)[0].astype(np.int64)
    trimap = np.where(classes == 1, 128, np.where(classes == 2, 255, classes))
    return trimap.astype(np.uint8)
def get_masked_local_from_global(global_sigmoid, local_sigmoid):
    """Fuse a 3-channel global prediction with a 1-channel local alpha.

    The argmax over channel dim 1 of ``global_sigmoid`` picks a class per pixel:
    0 -> 0 (background), 1 -> ``local_sigmoid`` (unknown region),
    2 -> 1 (foreground).

    Args:
        global_sigmoid: (N, 3, H, W) scores.
        local_sigmoid: (N, 1, H, W) alpha in [0, 1].

    Returns:
        (N, 1, H, W) float tensor.
    """
    _, index = torch.max(global_sigmoid, 1)
    index = index[:, None, :, :]
    trimap_mask = (index == 1).float()
    fg_mask = (index == 2).float()
    return local_sigmoid * trimap_mask + fg_mask
def get_masked_local_from_global_test(global_result, local_result):
    """Numpy fusion of a 0/128/255-style global map with a local alpha map.

    Pixels where ``global_result`` is 0 or 255 take ``global_result / 255``;
    every other pixel takes ``local_result``.
    """
    is_known = (global_result == 255) | (global_result == 0)
    local_weight = np.where(is_known, 0.0, 1.0)
    return global_result * (1. - local_weight) / 255 + local_result * local_weight
def inference_once(model, scale_img, scale_trimap=None):
    """Run ``model`` once on an H x W x 3 image scaled to [0, 255].

    The image is divided by 255, ImageNet-normalized, batched and converted to
    float32. Then ``model(input)[:3]`` is evaluated. That slice is used as
    global, local and fusion output alike. NOTE(review): this presumes the
    model returns a single tensor (slicing then keeps at most 3 batch items);
    a tuple output would fail at ``.data``. ``scale_trimap`` is unused.

    Returns:
        (trimap from ``gen_trimap_from_segmap_e2e``, map [0, 0], map [0, 0])
    """
    image_chw = torch.from_numpy(scale_img[:, :, :]).permute(2, 0, 1).to(device)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    batch = normalize(image_chw / 255.0).unsqueeze(0).float()
    prediction = model(batch)[:3]
    pred_global = gen_trimap_from_segmap_e2e(prediction.data.cpu().numpy())
    pred_local = prediction.data.cpu().numpy()[0, 0, :, :]
    pred_fusion = prediction.data.cpu().numpy()[0, 0, :, :]
    return pred_global, pred_local, pred_fusion
# def inference_img( test_choice,model, img):
# h, w, c = img.shape
# new_h = min(config['datasets'].MAX_SIZE_H, h - (h % 32))
# new_w = min(config['datasets'].MAX_SIZE_W, w - (w % 32))
# if test_choice=='HYBRID':
# global_ratio = 1/2
# local_ratio = 1
# resize_h = int(h*global_ratio)
# resize_w = int(w*global_ratio)
# new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
# new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
# scale_img = resize(img,(new_h,new_w))*255.0
# pred_coutour_1, pred_retouching_1, pred_fusion_1 = inference_once( model, scale_img)
# pred_coutour_1 = resize(pred_coutour_1,(h,w))*255.0
# resize_h = int(h*local_ratio)
# resize_w = int(w*local_ratio)
# new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
# new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
# scale_img = resize(img,(new_h,new_w))*255.0
# pred_coutour_2, pred_retouching_2, pred_fusion_2 = inference_once( model, scale_img)
# pred_retouching_2 = resize(pred_retouching_2,(h,w))
# pred_fusion = get_masked_local_from_global_test(pred_coutour_1, pred_retouching_2)
# return pred_fusion
# else:
# resize_h = int(h/2)
# resize_w = int(w/2)
# new_h = min(config['datasets'].MAX_SIZE_H, resize_h - (resize_h % 32))
# new_w = min(config['datasets'].MAX_SIZE_W, resize_w - (resize_w % 32))
# scale_img = resize(img,(new_h,new_w))*255.0
# pred_global, pred_local, pred_fusion = inference_once( model, scale_img)
# pred_local = resize(pred_local,(h,w))
# pred_global = resize(pred_global,(h,w))*255.0
# pred_fusion = resize(pred_fusion,(h,w))
# return pred_fusion
def inference_img(model, img):
    """Predict an alpha matte for one image, padding it to multiples of 8 first.

    Args:
        model: network called as ``model(input_t)`` under ``torch.no_grad()``.
            Its output is indexed as ``out[0][:, -h:, -w:]``; presumably it is
            an (N, 1, H, W) tensor (confirm for other models).
        img: H x W x 3 array with values in [0, 255] (uint8 from PIL in this file).

    Returns:
        (h, w) numpy array aligned with the unpadded input.
    """
    h, w, _ = img.shape
    # Reflect-pad only the top/left and only by what each axis lacks; an axis
    # that is already a multiple of 8 gets no border. Keeping the last h rows /
    # w columns below removes exactly this padding.
    pad_h = (-h) % 8
    pad_w = (-w) % 8
    if pad_h or pad_w:
        img = cv2.copyMakeBorder(img, pad_h, 0, pad_w, 0, cv2.BORDER_REFLECT)
    input_t = torch.from_numpy(img).permute(2, 0, 1).to(device)
    input_t = input_t / 255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    with torch.no_grad():
        out = model(input_t)
    result = out[0][:, -h:, -w:].cpu().numpy()
    return result[0]
def test_am2k(model):
    """Evaluate ``model`` on the AM-2k validation set and log matting metrics.

    Directories come from ``config['datasets']['am2k']``. Each ``*.jpg``
    image is matched to a ``<name>.png`` alpha and trimap. Per-image metrics
    are logged via ``log`` and every prediction is saved to ``test/<name>.png``.

    Returns:
        (mean SAD, mean MSE, mean GRAD) over the evaluated images.

    Raises:
        FileNotFoundError: if no ``*.jpg`` image is found.
    """
    ############################
    # Some initial setting for paths
    ############################
    ORIGINAL_PATH = config['datasets']['am2k']['validation_original']
    MASK_PATH = config['datasets']['am2k']['validation_mask']
    TRIMAP_PATH = config['datasets']['am2k']['validation_trimap']
    img_paths = glob.glob(ORIGINAL_PATH+"/*.jpg")
    total_number = len(img_paths)
    if total_number == 0:
        # Fail clearly instead of dividing by zero when averaging below.
        raise FileNotFoundError(f"No .jpg images found in {ORIGINAL_PATH}")
    # Predictions are written to test/; make sure it exists.
    os.makedirs("test", exist_ok=True)
    # Inference-mode BatchNorm/Dropout, as test_p3m10k also does.
    model.eval()
    ############################
    # Start testing
    ############################
    sad_diffs = 0.
    mse_diffs = 0.
    mad_diffs = 0.
    grad_diffs = 0.
    conn_diffs = 0.
    sad_trimap_diffs = 0.
    mse_trimap_diffs = 0.
    mad_trimap_diffs = 0.
    sad_fg_diffs = 0.
    sad_bg_diffs = 0.
    log("===============================")
    log(f'====> Start Testing\n\t--Dataset: AM2k\n\t-\n\t--Number: {total_number}')
    # enumerate replaces img_paths.index(), which made the loop O(n^2).
    for idx, img_path in enumerate(tqdm.tqdm(img_paths)):
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        alpha_path = os.path.join(MASK_PATH, img_name + '.png')
        trimap_path = os.path.join(TRIMAP_PATH, img_name + '.png')
        pil_img = Image.open(img_path)
        img = np.array(pil_img)
        trimap = np.array(Image.open(trimap_path))
        alpha = np.array(Image.open(alpha_path))/255.
        img = img[:,:,:3] if img.ndim>2 else img
        trimap = trimap[:,:,0] if trimap.ndim>2 else trimap
        alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
        with torch.no_grad():
            predict = inference_img(model, img)
            sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
            sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)
            sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
            conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
            grad_diff = compute_gradient_whole_image(predict, alpha)
            log(f"[{idx}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nsad_trimap:{sad_trimap_diff}\nmse_trimap:{mse_trimap_diff}\nmad_trimap:{mad_trimap_diff}\nsad_fg:{sad_fg_diff}\nsad_bg:{sad_bg_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\n-----------")
            sad_diffs += sad_diff
            mse_diffs += mse_diff
            mad_diffs += mad_diff
            mse_trimap_diffs += mse_trimap_diff
            sad_trimap_diffs += sad_trimap_diff
            mad_trimap_diffs += mad_trimap_diff
            sad_fg_diffs += sad_fg_diff
            sad_bg_diffs += sad_bg_diff
            conn_diffs += conn_diff
            grad_diffs += grad_diff
            Image.fromarray(np.uint8(predict*255)).save(f"test/{img_name}.png")
    log("===============================")
    log(f"Testing numbers: {total_number}")
    log("SAD: {}".format(sad_diffs / total_number))
    log("MSE: {}".format(mse_diffs / total_number))
    log("MAD: {}".format(mad_diffs / total_number))
    log("GRAD: {}".format(grad_diffs / total_number))
    log("CONN: {}".format(conn_diffs / total_number))
    log("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
    log("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
    log("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
    log("SAD FG: {}".format(sad_fg_diffs / total_number))
    log("SAD BG: {}".format(sad_bg_diffs / total_number))
    return sad_diffs/total_number,mse_diffs/total_number,grad_diffs/total_number
def test_p3m10k(model,dataset_choice, max_image=-1):
    """Evaluate ``model`` on a P3M-10k validation split and log matting metrics.

    Args:
        model: matting network; set to eval mode here and run via ``inference_img``.
        dataset_choice: ``'P3M_500_P'`` selects ``VAL500P``; any other value
            selects ``VAL500NP``.
        max_image: if positive, only the first ``max_image`` globbed images are
            evaluated (``1`` now means one image; it used to be ignored).
            Non-positive means all images.

    Returns:
        (mean SAD, mean MSE, mean GRAD) over the evaluated images.

    Raises:
        FileNotFoundError: if no ``*.jpg`` image is found.
    """
    ############################
    # Some initial setting for paths
    ############################
    if dataset_choice == 'P3M_500_P':
        val_option = 'VAL500P'
    else:
        val_option = 'VAL500NP'
    ORIGINAL_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['ORIGINAL_PATH']
    MASK_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['MASK_PATH']
    TRIMAP_PATH = config['datasets']['p3m10k']+"/validation/"+config['datasets']['p3m10k_test'][val_option]['TRIMAP_PATH']
    ############################
    # Start testing
    ############################
    sad_diffs = 0.
    mse_diffs = 0.
    mad_diffs = 0.
    sad_trimap_diffs = 0.
    mse_trimap_diffs = 0.
    mad_trimap_diffs = 0.
    sad_fg_diffs = 0.
    sad_bg_diffs = 0.
    conn_diffs = 0.
    grad_diffs = 0.
    model.eval()
    img_paths = glob.glob(ORIGINAL_PATH+"/*.jpg")
    if max_image > 0:
        img_paths = img_paths[:max_image]
    total_number = len(img_paths)
    if total_number == 0:
        # Fail clearly instead of dividing by zero when averaging below.
        raise FileNotFoundError(f"No .jpg images found in {ORIGINAL_PATH}")
    # Predictions are written to test/; make sure it exists.
    os.makedirs("test", exist_ok=True)
    log("===============================")
    log(f'====> Start Testing\n\t----Test: {dataset_choice}\n\t--Number: {total_number}')
    # enumerate replaces img_paths.index(), which made the loop O(n^2).
    for idx, img_path in enumerate(tqdm.tqdm(img_paths)):
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        alpha_path = os.path.join(MASK_PATH, img_name + '.png')
        trimap_path = os.path.join(TRIMAP_PATH, img_name + '.png')
        pil_img = Image.open(img_path)
        img = np.array(pil_img)
        trimap = np.array(Image.open(trimap_path))
        alpha = np.array(Image.open(alpha_path))/255.
        img = img[:,:,:3] if img.ndim>2 else img
        trimap = trimap[:,:,0] if trimap.ndim>2 else trimap
        alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
        with torch.no_grad():
            predict = inference_img(model, img)  # HYBRID shows lower accuracy
            sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
            sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)
            sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
            conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
            grad_diff = compute_gradient_whole_image(predict, alpha)
            log(f"[{idx}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\n-----------")
            sad_diffs += sad_diff
            mse_diffs += mse_diff
            mad_diffs += mad_diff
            mse_trimap_diffs += mse_trimap_diff
            sad_trimap_diffs += sad_trimap_diff
            mad_trimap_diffs += mad_trimap_diff
            sad_fg_diffs += sad_fg_diff
            sad_bg_diffs += sad_bg_diff
            conn_diffs += conn_diff
            grad_diffs += grad_diff
            Image.fromarray(np.uint8(predict*255)).save(f"test/{img_name}.png")
    log("===============================")
    log(f"Testing numbers: {total_number}")
    log("SAD: {}".format(sad_diffs / total_number))
    log("MSE: {}".format(mse_diffs / total_number))
    log("MAD: {}".format(mad_diffs / total_number))
    log("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
    log("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
    log("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
    log("SAD FG: {}".format(sad_fg_diffs / total_number))
    log("SAD BG: {}".format(sad_bg_diffs / total_number))
    log("CONN: {}".format(conn_diffs / total_number))
    log("GRAD: {}".format(grad_diffs / total_number))
    return sad_diffs/total_number,mse_diffs/total_number,grad_diffs/total_number
def log(str):
    """Echo a message to stdout and record it at INFO level on the root logger.

    (The parameter name shadows the builtin ``str``; it is kept for call
    compatibility.)
    """
    for emit in (print, logging.info):
        emit(str)
if __name__ == '__main__':
    # Entry point: load base.yaml plus CLI overrides, build MaskForm, load the
    # checkpoint at <checkpoint_dir>/<checkpoint>, then evaluate on AM-2k or on
    # both P3M-500 splits, depending on config.dataset_to_use.
    print('*********************************')
    config = OmegaConf.load("base.yaml")
    config = OmegaConf.merge(config, OmegaConf.from_cli())
    print(config)
    model = MaskForm()
    model = model.to(device)
    checkpoint = f"{config.checkpoint_dir}/{config.checkpoint}"
    state_dict = torch.load(checkpoint, map_location=f'{device}')
    print("loaded", checkpoint)
    model.load_state_dict(state_dict)
    model.eval()
    # logging.basicConfig(filename=...) raises if the directory is missing.
    os.makedirs('report', exist_ok=True)
    logging.basicConfig(filename=f'report/{config.checkpoint.replace("/","--")}.report', encoding='utf-8', filemode='w', level=logging.INFO)
    if config.dataset_to_use == "AM2K":
        test_am2k(model)
    else:
        for dataset_choice in ['P3M_500_P', 'P3M_500_NP']:
            test_p3m10k(model, dataset_choice)