# ImageEnhancement / models / models_x.py
# Author: chenzhicun — "initialize web demo" (commit ec08fea)
from doctest import OutputChecker
from turtle import forward
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch
import numpy as np
import math
from models.trilinear_test import bing_lut_trilinearInterplt,Tritri
from re import I
import time
from PIL import Image
###########################################
# use this module for pytorch 1.x,together with trilinear_cpp
###########################################
def weights_init_normal_classifier(m):
    """Weight initializer intended for ``nn.Module.apply``.

    Conv layers get Xavier-normal weights; BatchNorm2d/InstanceNorm2d
    layers get N(1.0, 0.02) weights and zero bias.
    """
    layer_type = m.__class__.__name__
    is_conv = "Conv" in layer_type
    is_norm = "BatchNorm2d" in layer_type or "InstanceNorm2d" in layer_type
    if is_conv:
        torch.nn.init.xavier_normal_(m.weight.data)
    elif is_norm:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
class resnet18_224(nn.Module):
    """ResNet-18 head operating on inputs bilinearly resized to 224x224.

    Args:
        out_dim: size of the replacement fully-connected layer (default 5).
        aug_test: when True, a horizontally-flipped copy of the batch is
            concatenated along the batch dimension before the forward pass.
    """

    def __init__(self, out_dim=5, aug_test=False):
        super(resnet18_224, self).__init__()
        self.aug_test = aug_test
        backbone = models.resnet18(pretrained=True)
        # NOTE(review): inputs are not normalized with the ImageNet
        # mean/std the pretrained backbone expects (commented out below) —
        # confirm this is intentional.
        # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
        # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
        self.upsample = nn.Upsample(size=(224, 224), mode='bilinear')
        backbone.fc = nn.Linear(512, out_dim)
        self.model = backbone

    def forward(self, x):
        x = self.upsample(x)
        if self.aug_test:
            # x = torch.cat((x, torch.rot90(x, 1, [2, 3]), torch.rot90(x, 3, [2, 3])), 0)
            x = torch.cat((x, torch.flip(x, [3])), 0)
        return self.model(x)
##############################
# Discriminator
##############################
def discriminator_block(in_filters, out_filters, normalization=False):
    """Return the layers of one discriminator downsampling block.

    A stride-2 3x3 convolution followed by LeakyReLU(0.2), optionally
    followed by an affine InstanceNorm2d when ``normalization`` is True.
    """
    block = [
        nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1),
        nn.LeakyReLU(0.2),
    ]
    if normalization:
        block.append(nn.InstanceNorm2d(out_filters, affine=True))
        # block.append(nn.BatchNorm2d(out_filters))
    return block
class Discriminator(nn.Module):
    """PatchGAN-style discriminator.

    Resizes the input to 256x256, applies a stack of stride-2 conv blocks,
    and maps the features to a single-channel score.
    """

    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        stem = [
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
        ]
        body = (
            discriminator_block(16, 32)
            + discriminator_block(32, 64)
            + discriminator_block(64, 128)
            + discriminator_block(128, 128)
            # + discriminator_block(128, 128)
        )
        head = [nn.Conv2d(128, 1, 8, padding=0)]
        self.model = nn.Sequential(*(stem + body + head))

    def forward(self, img_input):
        return self.model(img_input)
class Classifier(nn.Module):
    """Classifier that predicts the per-LUT blending weights.

    The input is resized to 256x256, passed through stride-2 conv blocks
    (middle blocks use instance normalization), and reduced by dropout plus
    an 8x8 convolution to a 3-channel output.
    """

    def __init__(self, in_channels=3):
        super(Classifier, self).__init__()
        layers = [
            # nn.Downsample(size=(256,256),mode='bilinear'),
            nn.Upsample(size=(256, 256), mode='bilinear'),  # original
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
        ]
        layers += discriminator_block(16, 32, normalization=True)
        layers += discriminator_block(32, 64, normalization=True)
        layers += discriminator_block(64, 128, normalization=True)
        layers += discriminator_block(128, 128)
        # layers += discriminator_block(128, 128, normalization=True)
        layers += [nn.Dropout(p=0.5), nn.Conv2d(128, 3, 8, padding=0)]
        self.model = nn.Sequential(*layers)

    def forward(self, img_input):
        return self.model(img_input)
class Classifier_unpaired(nn.Module):
    """Classifier variant for unpaired training.

    Same topology as ``Classifier`` but without instance normalization in
    the middle blocks and without dropout before the final convolution.
    """

    def __init__(self, in_channels=3):
        super(Classifier_unpaired, self).__init__()
        layers = [
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
        ]
        for cin, cout in ((16, 32), (32, 64), (64, 128), (128, 128)):
            layers += discriminator_block(cin, cout)
        # layers += discriminator_block(128, 128)
        layers.append(nn.Conv2d(128, 3, 8, padding=0))
        self.model = nn.Sequential(*layers)

    def forward(self, img_input):
        return self.model(img_input)
class Generator3DLUT_identity(nn.Module):
    """Learnable 3D LUT initialized from an identity LUT stored on disk.

    Reads ``IdentityLUT{dim}.txt`` (one "r g b" triplet per line, last
    index fastest: line n = i*dim*dim + j*dim + k) into a
    (3, dim, dim, dim) parameter tensor.

    Args:
        dim: LUT resolution per axis; identity files exist for 33 and 64.

    Raises:
        ValueError: if ``dim`` is neither 33 nor 64 (the original code fell
            through with the file variable unbound, raising NameError).
    """

    def __init__(self, dim=33):
        super(Generator3DLUT_identity, self).__init__()
        if dim == 33:
            path = "IdentityLUT33.txt"
        elif dim == 64:
            path = "IdentityLUT64.txt"
        else:
            raise ValueError(
                "identity LUT is only available for dim 33 or 64, got %d" % dim)
        # `with` guarantees the handle is closed (the original leaked it).
        with open(path, 'r') as f:
            lines = f.readlines()
        buffer = np.zeros((3, dim, dim, dim), dtype=np.float32)
        for i in range(dim):
            for j in range(dim):
                for k in range(dim):
                    n = i * dim * dim + j * dim + k
                    r, g, b = lines[n].split()[:3]
                    buffer[0, i, j, k] = float(r)
                    buffer[1, i, j, k] = float(g)
                    buffer[2, i, j, k] = float(b)
        self.LUT = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
        self.TrilinearInterpolation = Tritri()
        # self.trilinearItp = bing_lut_trilinearInterplt()

    def forward(self, x):
        """Apply the LUT to the image batch via trilinear interpolation."""
        _, output = self.TrilinearInterpolation(self.LUT, x)
        # output = self.trilinearItp(self.LUT, x)
        # self.LUT, output = self.TrilinearInterpolation(self.LUT, x)
        return output
class Generator3DLUT_zero(nn.Module):
    """Learnable 3D LUT whose entries start at zero.

    Acts as a residual LUT that the classifier blends with the identity LUT.
    """

    def __init__(self, dim=33):
        super(Generator3DLUT_zero, self).__init__()
        # The original wrapped an existing tensor with torch.tensor(), which
        # copies needlessly and emits a UserWarning; build the Parameter once.
        self.LUT = nn.Parameter(torch.zeros(3, dim, dim, dim, dtype=torch.float))
        self.TrilinearInterpolation = Tritri()
        # self.trilinearItp = bing_lut_trilinearInterplt()

    def forward(self, x):
        """Apply the LUT to the image batch via trilinear interpolation."""
        _, output = self.TrilinearInterpolation(self.LUT, x)
        # output = self.trilinearItp(self.LUT, x)
        return output
class LUT_all(nn.Module):
    """Inference wrapper: classifier plus three 3D LUTs from checkpoints.

    The classifier predicts three scalar weights; the final LUT is their
    weighted combination of the identity LUT and two residual LUTs, applied
    to the input image by trilinear interpolation.

    Args:
        path_LUT: checkpoint holding the three LUT state dicts keyed "0"-"2".
        path_classifier: checkpoint holding the classifier state dict.
    """

    def __init__(self,
                 path_LUT="saved_models/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB/LUTs_399.pth",
                 path_classifier="saved_models/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB/classifier_399.pth") -> None:
        super(LUT_all, self).__init__()
        self.classifier = Classifier()
        self.classifier.load_state_dict(torch.load(path_classifier))
        self.LUT0 = Generator3DLUT_identity()
        self.LUT1 = Generator3DLUT_zero()
        self.LUT2 = Generator3DLUT_zero()
        checkpoint = torch.load(path_LUT)
        for idx, lut in enumerate((self.LUT0, self.LUT1, self.LUT2)):
            lut.load_state_dict(checkpoint[str(idx)])
        # self.trilinear_ = TrilinearInterpolation()
        # self.trilinear_ = bing_lut_trilinearInterplt()
        self.trilinear_ = Tritri()

    def forward(self, img):
        # squeeze() drops the singleton dims of the classifier output so the
        # three blending weights can be indexed directly.
        weights = self.classifier(img).squeeze()
        blended = (weights[0] * self.LUT0.LUT
                   + weights[1] * self.LUT1.LUT
                   + weights[2] * self.LUT2.LUT)
        output = self.trilinear_(blended, img)
        # _, output = self.trilinear_(blended, img)
        return output
# return LUT
# class TrilinearInterpolationFunction(torch.autograd.Function):
# @staticmethod
# def forward(ctx, lut, x):
# x = x.contiguous()
# output = x.new(x.size())
# dim = lut.size()[-1]
# shift = dim ** 3
# binsize = 1.000001 / (dim-1)
# W = x.size(2)
# H = x.size(3)
# batch = x.size(0)
# #trilinear这个包是作者自己实现的
# assert 1 == trilinear.forward(lut,
# x,
# output,
# dim,
# shift,
# binsize,
# W,
# H,
# batch)
# int_package = torch.IntTensor([dim, shift, W, H, batch])
# float_package = torch.FloatTensor([binsize])
# variables = [lut, x, int_package, float_package]
# ctx.save_for_backward(*variables)
# return lut, output
# @staticmethod
# def backward(ctx, lut_grad, x_grad):
# lut, x, int_package, float_package = ctx.saved_variables
# dim, shift, W, H, batch = int_package
# dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
# binsize = float(float_package[0])
# assert 1 == trilinear.backward(x,
# x_grad,
# lut_grad,
# dim,
# shift,
# binsize,
# W,
# H,
# batch)
# return lut_grad, x_grad
# class TrilinearInterpolation(torch.nn.Module):
# def __init__(self):
# super(TrilinearInterpolation, self).__init__()
# def forward(self, lut, x):
# return TrilinearInterpolationFunction.apply(lut, x)
class TV_3D(nn.Module):
    """Total-variation and monotonicity regularizers for a 3D LUT.

    ``forward`` returns ``(tv, mn)``: ``tv`` penalizes squared differences
    between neighbouring LUT entries along each color axis (boundary
    differences weighted 2x), and ``mn`` penalizes decreasing entries.

    Args:
        dim: LUT resolution per axis (default 33).
    """

    def __init__(self, dim=33):
        super(TV_3D, self).__init__()
        weight_r = torch.ones(3, dim, dim, dim - 1, dtype=torch.float)
        weight_r[:, :, :, (0, dim - 2)] *= 2.0
        weight_g = torch.ones(3, dim, dim - 1, dim, dtype=torch.float)
        weight_g[:, :, (0, dim - 2), :] *= 2.0
        weight_b = torch.ones(3, dim - 1, dim, dim, dtype=torch.float)
        weight_b[:, (0, dim - 2), :, :] *= 2.0
        # Register as non-persistent buffers so the weights follow the module
        # across .to()/.cuda(); the plain tensor attributes of the original
        # stayed on the CPU and broke GPU training with a device mismatch.
        # persistent=False keeps them out of the state_dict, so checkpoints
        # remain compatible.
        self.register_buffer("weight_r", weight_r, persistent=False)
        self.register_buffer("weight_g", weight_g, persistent=False)
        self.register_buffer("weight_b", weight_b, persistent=False)
        self.relu = torch.nn.ReLU()

    def forward(self, LUT):
        """Compute (tv, mn) for any object exposing a ``.LUT`` tensor."""
        dif_r = LUT.LUT[:, :, :, :-1] - LUT.LUT[:, :, :, 1:]
        dif_g = LUT.LUT[:, :, :-1, :] - LUT.LUT[:, :, 1:, :]
        dif_b = LUT.LUT[:, :-1, :, :] - LUT.LUT[:, 1:, :, :]
        tv = (torch.mean(torch.mul(dif_r ** 2, self.weight_r))
              + torch.mean(torch.mul(dif_g ** 2, self.weight_g))
              + torch.mean(torch.mul(dif_b ** 2, self.weight_b)))
        mn = (torch.mean(self.relu(dif_r))
              + torch.mean(self.relu(dif_g))
              + torch.mean(self.relu(dif_b)))
        return tv, mn
##new by bing##
if __name__ == '__main__':
    # Smoke test: load one 4K image, run it through the full LUT pipeline.
    def img_process(img, size=None):
        """Convert a PIL RGB image into a normalized NCHW float tensor.

        Merges the original duplicated ``img_process_256``/``img_process_4k``
        helpers: if ``size`` (a (w, h) tuple) is given the image is resized
        first. Pixels are scaled to [0, 1] by ToTensor and then normalized
        to [-1, 1] per channel; a batch dimension (N=1) is prepended.
        """
        if size is not None:
            img = img.resize(size)
        trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        tensor = trans(img).unsqueeze(0)  # add the batch dimension
        print("img", tensor.size())
        return tensor

    img_ori = Image.open(
        "/home/elle/bing/proj/code/download-4k-img/picture/%s"
        % ("X4_Animal2_BIC_g_03.png"))
    img = img_process(img_ori, size=(256, 256))  # 256x256 preview tensor
    img_4k = img_process(img_ori)                # full-resolution tensor
    model = LUT_all()
    out = model(img_4k)
    print(out)