import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import math
import time
from PIL import Image
from models.trilinear_test import bing_lut_trilinearInterplt, Tritri
###########################################
# Use this module with PyTorch 1.x, together with trilinear_cpp.
###########################################
def weights_init_normal_classifier(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.xavier_normal_(m.weight.data)
    elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
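# Hedged usage note (not from this file): such initializers are typically applied
# to a freshly constructed network via Module.apply, e.g.
#   classifier = Classifier()
#   classifier.apply(weights_init_normal_classifier)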
class resnet18_224(nn.Module):
    def __init__(self, out_dim=5, aug_test=False):
        super(resnet18_224, self).__init__()
        self.aug_test = aug_test
        net = models.resnet18(pretrained=True)
        # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
        # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
        self.upsample = nn.Upsample(size=(224, 224), mode='bilinear')
        net.fc = nn.Linear(512, out_dim)
        self.model = net

    def forward(self, x):
        x = self.upsample(x)
        if self.aug_test:
            # x = torch.cat((x, torch.rot90(x, 1, [2, 3]), torch.rot90(x, 3, [2, 3])), 0)
            x = torch.cat((x, torch.flip(x, [3])), 0)
        f = self.model(x)
        return f
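# Hedged usage sketch (assumption, not from this file): with aug_test=True the batch is
# doubled with a horizontally flipped copy, so the caller would average the two halves:
#   pred = backbone(img)                                  # shape (2N, out_dim)
#   pred = (pred[:pred.size(0) // 2] + pred[pred.size(0) // 2:]) / 2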
##############################
#       Discriminator
##############################
def discriminator_block(in_filters, out_filters, normalization=False):
    """Returns downsampling layers of each discriminator block"""
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
    layers.append(nn.LeakyReLU(0.2))
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
        # layers.append(nn.BatchNorm2d(out_filters))
    return layers
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 128),
            # *discriminator_block(128, 128),
            nn.Conv2d(128, 1, 8, padding=0)
        )

    def forward(self, img_input):
        return self.model(img_input)
class Classifier(nn.Module):
    def __init__(self, in_channels=3):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            # nn.Downsample(size=(256,256),mode='bilinear'),
            nn.Upsample(size=(256, 256), mode='bilinear'),  # original
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *discriminator_block(16, 32, normalization=True),
            *discriminator_block(32, 64, normalization=True),
            *discriminator_block(64, 128, normalization=True),
            *discriminator_block(128, 128),
            # *discriminator_block(128, 128, normalization=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(128, 3, 8, padding=0),
        )

    def forward(self, img_input):
        return self.model(img_input)
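# The Classifier outputs one scalar per LUT (3 here); LUT_all below uses these scalars
# as blending weights for LUT0, LUT1 and LUT2.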
class Classifier_unpaired(nn.Module):
    def __init__(self, in_channels=3):
        super(Classifier_unpaired, self).__init__()
        self.model = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 128),
            # *discriminator_block(128, 128),
            nn.Conv2d(128, 3, 8, padding=0),
        )

    def forward(self, img_input):
        return self.model(img_input)
class Generator3DLUT_identity(nn.Module):
    def __init__(self, dim=33):
        super(Generator3DLUT_identity, self).__init__()
        if dim == 33:
            file = open("IdentityLUT33.txt", 'r')
        elif dim == 64:
            file = open("IdentityLUT64.txt", 'r')
        lines = file.readlines()
        file.close()
        buffer = np.zeros((3, dim, dim, dim), dtype=np.float32)
        for i in range(0, dim):
            for j in range(0, dim):
                for k in range(0, dim):
                    n = i * dim * dim + j * dim + k
                    x = lines[n].split()
                    buffer[0, i, j, k] = float(x[0])
                    buffer[1, i, j, k] = float(x[1])
                    buffer[2, i, j, k] = float(x[2])
        self.LUT = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
        self.TrilinearInterpolation = Tritri()
        # self.trilinearItp = bing_lut_trilinearInterplt()

    def forward(self, x):
        _, output = self.TrilinearInterpolation(self.LUT, x)
        # output = self.trilinearItp(self.LUT, x)
        # self.LUT, output = self.TrilinearInterpolation(self.LUT, x)
        return output
class Generator3DLUT_zero(nn.Module):
    def __init__(self, dim=33):
        super(Generator3DLUT_zero, self).__init__()
        self.LUT = nn.Parameter(torch.zeros(3, dim, dim, dim, dtype=torch.float))
        self.TrilinearInterpolation = Tritri()
        # self.trilinearItp = bing_lut_trilinearInterplt()

    def forward(self, x):
        _, output = self.TrilinearInterpolation(self.LUT, x)
        # output = self.trilinearItp(self.LUT, x)
        return output
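# Hedged fallback sketch (not part of the original module): if the trilinear_cpp / Tritri
# extension is unavailable, a 3D LUT can be applied in pure PyTorch via F.grid_sample.
# Assumptions: `lut` has shape (3, d, d, d) indexed as lut[channel, b, g, r] and `img` has
# shape (N, 3, H, W) with values in [0, 1]; if the stored axis order differs, permute the
# grid channels accordingly.
def apply_lut_grid_sample(lut, img):
    n, _, h, w = img.shape
    lut = lut.unsqueeze(0).expand(n, -1, -1, -1, -1)          # (N, 3, d, d, d)
    # grid_sample wants coordinates in [-1, 1], ordered (x, y, z) = (W, H, D) of the input,
    # i.e. (r, g, b) under the indexing assumption above.
    grid = img.permute(0, 2, 3, 1).unsqueeze(1) * 2.0 - 1.0   # (N, 1, H, W, 3)
    out = F.grid_sample(lut, grid, mode='bilinear', padding_mode='border', align_corners=True)
    return out.squeeze(2)                                      # (N, 3, H, W)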
class LUT_all(nn.Module):
    def __init__(self,
                 path_LUT="saved_models/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB/LUTs_399.pth",
                 path_classifier="saved_models/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB/classifier_399.pth") -> None:
        super(LUT_all, self).__init__()
        self.classifier = Classifier()
        self.classifier.load_state_dict(torch.load(path_classifier))
        self.LUT0 = Generator3DLUT_identity()
        self.LUT1 = Generator3DLUT_zero()
        self.LUT2 = Generator3DLUT_zero()
        LUTs = torch.load(path_LUT)
        self.LUT0.load_state_dict(LUTs["0"])
        self.LUT1.load_state_dict(LUTs["1"])
        self.LUT2.load_state_dict(LUTs["2"])
        # self.trilinear_ = TrilinearInterpolation()
        # self.trilinear_ = bing_lut_trilinearInterplt()
        self.trilinear_ = Tritri()

    def forward(self, img):
        pred = self.classifier(img).squeeze()
        # squeeze() removes the size-1 dimensions (same semantics as numpy's squeeze),
        # leaving one weight per LUT
        # LUT = pred[0] * self.LUT0.LUT
        LUT = pred[0] * self.LUT0.LUT + pred[1] * self.LUT1.LUT + pred[2] * self.LUT2.LUT
        output = self.trilinear_(LUT, img)
        # _, output = self.trilinear_(LUT, img)
        return output
        # return LUT
# class TrilinearInterpolationFunction(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, lut, x):
#         x = x.contiguous()
#         output = x.new(x.size())
#         dim = lut.size()[-1]
#         shift = dim ** 3
#         binsize = 1.000001 / (dim - 1)
#         W = x.size(2)
#         H = x.size(3)
#         batch = x.size(0)
#         # the trilinear package is the authors' own C++/CUDA extension
#         assert 1 == trilinear.forward(lut,
#                                       x,
#                                       output,
#                                       dim,
#                                       shift,
#                                       binsize,
#                                       W,
#                                       H,
#                                       batch)
#         int_package = torch.IntTensor([dim, shift, W, H, batch])
#         float_package = torch.FloatTensor([binsize])
#         variables = [lut, x, int_package, float_package]
#         ctx.save_for_backward(*variables)
#         return lut, output
#
#     @staticmethod
#     def backward(ctx, lut_grad, x_grad):
#         lut, x, int_package, float_package = ctx.saved_variables
#         dim, shift, W, H, batch = int_package
#         dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
#         binsize = float(float_package[0])
#         assert 1 == trilinear.backward(x,
#                                        x_grad,
#                                        lut_grad,
#                                        dim,
#                                        shift,
#                                        binsize,
#                                        W,
#                                        H,
#                                        batch)
#         return lut_grad, x_grad
#
# class TrilinearInterpolation(torch.nn.Module):
#     def __init__(self):
#         super(TrilinearInterpolation, self).__init__()
#
#     def forward(self, lut, x):
#         return TrilinearInterpolationFunction.apply(lut, x)
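# Note: the autograd Function kept (commented out) above is the original trilinear_cpp
# binding; in this module it is replaced by Tritri from models.trilinear_test.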
class TV_3D(nn.Module):
    def __init__(self, dim=33):
        super(TV_3D, self).__init__()
        self.weight_r = torch.ones(3, dim, dim, dim - 1, dtype=torch.float)
        self.weight_r[:, :, :, (0, dim - 2)] *= 2.0
        self.weight_g = torch.ones(3, dim, dim - 1, dim, dtype=torch.float)
        self.weight_g[:, :, (0, dim - 2), :] *= 2.0
        self.weight_b = torch.ones(3, dim - 1, dim, dim, dtype=torch.float)
        self.weight_b[:, (0, dim - 2), :, :] *= 2.0
        self.relu = torch.nn.ReLU()

    def forward(self, LUT):
        dif_r = LUT.LUT[:, :, :, :-1] - LUT.LUT[:, :, :, 1:]
        dif_g = LUT.LUT[:, :, :-1, :] - LUT.LUT[:, :, 1:, :]
        dif_b = LUT.LUT[:, :-1, :, :] - LUT.LUT[:, 1:, :, :]
        tv = torch.mean(torch.mul(dif_r ** 2, self.weight_r)) \
            + torch.mean(torch.mul(dif_g ** 2, self.weight_g)) \
            + torch.mean(torch.mul(dif_b ** 2, self.weight_b))
        mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))
        return tv, mn
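# Hedged usage sketch (not from this file): the two outputs are regularizers, a smoothness
# (total-variation) term and a monotonicity term. The checkpoint path above suggests weights
# of roughly 1e-4 and 10 ("sm_1e-4_mn_10"), so a training loss could look like:
#   tv_loss = TV_3D()
#   tv0, mn0 = tv_loss(LUT0)
#   loss = mse_loss + 1e-4 * tv0 + 10 * mn0   # weights are an assumption read off the path name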
## new by bing ##
if __name__ == '__main__':
    def img_process_256(img):
        # Convert a PIL image (mode=RGB, size 3840x2160, 3 channels) to a tensor of
        # shape [N, C, H, W] (here [1, 3, 256, 256]).
        img = img.resize((256, 256))
        trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        img = trans(img)
        img = torch.unsqueeze(img, 0)  # add the batch dimension
        print("img", img.size())
        # # convert from HWC to NCHW with N=1
        # img = np.array(img)
        return img
    def img_process_4k(img):
        # Convert a PIL image (mode=RGB, size 3840x2160, 3 channels) to a tensor of
        # shape [N, C, H, W] at full resolution ([1, 3, 2160, 3840]).
        trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        img = trans(img)
        img = torch.unsqueeze(img, 0)  # add the batch dimension
        print("img", img.size())
        # # convert from HWC to NCHW with N=1
        # img = np.array(img)
        return img
    img_ori = Image.open("/home/elle/bing/proj/code/download-4k-img/picture/%s" % ("X4_Animal2_BIC_g_03.png"))
    img = img_process_256(img_ori)
    img_4k = img_process_4k(img_ori)
    model = LUT_all()
    out = model(img_4k)
    print(out)
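    # Hedged follow-up (assumption: the output lives in the same [-1, 1] normalized space
    # as the input, so it is mapped back to [0, 1] before saving):
    # out_img = out.squeeze(0).clamp(-1, 1) * 0.5 + 0.5
    # transforms.ToPILImage()(out_img.cpu()).save("lut_result.png")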