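# Gradio web demo for image enhancement with learned 3D LUTs.
# Two pretrained model sets are loaded (sRGB and XYZ color spaces); each one
# fuses three basis LUTs with weights predicted by a small classifier network
# and applies the resulting LUT to the input image via trilinear interpolation.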
import gradio as gr
from PIL import Image
import torch
from torchvision import transforms
from models.models_x import *
import torchvision_x_functional as TF_x
import torchvision.transforms.functional as TF
import cv2
from timm.models.hub import download_cached_file
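
# Run on GPU when available; Tensor is the default tensor type used for inputs.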
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
trans = transforms.ToTensor()
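
# sRGB model: one identity LUT plus two zero-initialized LUTs, combined at
# inference time by the weights the classifier predicts.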
LUT0 = Generator3DLUT_identity()
LUT1 = Generator3DLUT_zero()
LUT2 = Generator3DLUT_zero()
classifier = Classifier()
trilinear_ = Tritri()
if cuda:
    LUT0 = LUT0.cuda()
    LUT1 = LUT1.cuda()
    LUT2 = LUT2.cuda()
    classifier = classifier.cuda()
# Load pretrained models
cache = download_cached_file('https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/sRGB/LUTs.pth',
check_hash=False, progress=True)
LUTs = torch.load(cache, map_location=torch.device('cpu'))
LUT0.load_state_dict(LUTs["0"])
LUT1.load_state_dict(LUTs["1"])
LUT2.load_state_dict(LUTs["2"])
LUT0.eval()
LUT1.eval()
LUT2.eval()
cache = download_cached_file('https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/sRGB/classifier.pth',
check_hash=False, progress=True)
classifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
classifier.eval()
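
# XYZ model: same architecture, with weights trained for the XYZ color space.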
XLUT0 = Generator3DLUT_identity()
XLUT1 = Generator3DLUT_zero()
XLUT2 = Generator3DLUT_zero()
Xclassifier = Classifier()
Xtrilinear_ = Tritri()
if cuda:
    XLUT0 = XLUT0.cuda()
    XLUT1 = XLUT1.cuda()
    XLUT2 = XLUT2.cuda()
    Xclassifier = Xclassifier.cuda()
# Load pretrained models
cache = download_cached_file('https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/XYZ/LUTs.pth',
check_hash=False, progress=True)
XLUTs = torch.load(cache, map_location=torch.device('cpu'))
XLUT0.load_state_dict(XLUTs["0"])
XLUT1.load_state_dict(XLUTs["1"])
XLUT2.load_state_dict(XLUTs["2"])
XLUT0.eval()
XLUT1.eval()
XLUT2.eval()
cache = download_cached_file('https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/XYZ/classifier.pth',
check_hash=False, progress=True)
Xclassifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
Xclassifier.eval()
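
# Fuse the basis LUTs into a single image-adaptive LUT using the weights
# predicted by the classifier.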
def generate_LUT(img):
    pred = classifier(img).squeeze()
    LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT
    return LUT
def generate_XLUT(img):
    pred = Xclassifier(img).squeeze()
    XLUT = pred[0] * XLUT0.LUT + pred[1] * XLUT1.LUT + pred[2] * XLUT2.LUT
    return XLUT
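
# Map the uploaded image through the adaptive LUT of the selected color space
# and return the enhanced result as a PIL image.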
def inference(ori_image, models_n):
    with torch.no_grad():
        # PIL image -> (1, 3, H, W) float tensor on the same device as the models
        img = trans(ori_image).type(Tensor)
        img = img.unsqueeze(0)
        if models_n == 'sRGB':
            LUT = generate_LUT(img)
            result = trilinear_(LUT, img)
        else:  # 'XYZ'
            XLUT = generate_XLUT(img)
            result = Xtrilinear_(XLUT, img)
        result = result.permute(0, 3, 1, 2)
        # Scale to [0, 255], round, and convert back to a PIL image
        ndarr = result.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
    return im
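
# Gradio UI: an image input and a color-space selector; the enhanced image is the output.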
inputs = [gr.inputs.Image(type='pil', label='Image to enhance'),
          gr.inputs.Radio(choices=['sRGB', 'XYZ'], type="value", default="sRGB", label="Input color space")]
outputs = [gr.outputs.Image(type='pil', label='Enhanced image')]
title = 'LUT-based image enhancement demo'
gr.Interface(inference, inputs, outputs, title=title, allow_flagging='never',
             examples=[['./examples/example.jpg', 'sRGB']]).launch(enable_queue=True)