# Hugging Face Space status when this file was captured: Runtime error
import gradio as gr | |
from PIL import Image | |
import torch | |
from torchvision import transforms | |
from models.models_x import * | |
import torchvision_x_functional as TF_x | |
import torchvision.transforms.functional as TF | |
from torchvision import transforms | |
import cv2 | |
from timm.models.hub import download_cached_file | |
# Converter applied to every incoming image before it is fed to the models.
trans = transforms.ToTensor()

# Use the GPU whenever CUDA is usable; everything below branches on this flag.
cuda = torch.cuda.is_available()

# Float tensor type that matches the selected device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# ---- sRGB pipeline ---------------------------------------------------------
# Three basis LUT modules, the classifier whose output weights them, and the
# module that applies a LUT to an image (called as trilinear_(LUT, img)).
LUT0 = Generator3DLUT_identity()
LUT1 = Generator3DLUT_zero()
LUT2 = Generator3DLUT_zero()
classifier = Classifier()
trilinear_ = Tritri()

if cuda:
    LUT0, LUT1, LUT2 = LUT0.cuda(), LUT1.cuda(), LUT2.cuda()
    classifier = classifier.cuda()

# Load pretrained models
# NOTE(review): torch.load unpickles the downloaded file, so this is only safe
# while the remote checkpoint host is trusted.
cache = download_cached_file(
    'https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/sRGB/LUTs.pth',
    check_hash=False,
    progress=True,
)
LUTs = torch.load(cache, map_location=torch.device('cpu'))
# The checkpoint is indexed by the string keys "0", "1" and "2", one per LUT.
for _key, _lut in (("0", LUT0), ("1", LUT1), ("2", LUT2)):
    _lut.load_state_dict(LUTs[_key])
    _lut.eval()

cache = download_cached_file(
    'https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/sRGB/classifier.pth',
    check_hash=False,
    progress=True,
)
classifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
classifier.eval()
# ---- XYZ pipeline ----------------------------------------------------------
# Same layout as the sRGB pipeline, with its own basis LUT modules, classifier
# and LUT-application module, restored from the XYZ checkpoints.
XLUT0 = Generator3DLUT_identity()
XLUT1 = Generator3DLUT_zero()
XLUT2 = Generator3DLUT_zero()
Xclassifier = Classifier()
Xtrilinear_ = Tritri()

if cuda:
    XLUT0, XLUT1, XLUT2 = XLUT0.cuda(), XLUT1.cuda(), XLUT2.cuda()
    Xclassifier = Xclassifier.cuda()

# Load pretrained models
# NOTE(review): torch.load unpickles the downloaded file, so this is only safe
# while the remote checkpoint host is trusted.
cache = download_cached_file(
    'https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/XYZ/LUTs.pth',
    check_hash=False,
    progress=True,
)
XLUTs = torch.load(cache, map_location=torch.device('cpu'))
# The checkpoint is indexed by the string keys "0", "1" and "2", one per LUT.
for _key, _lut in (("0", XLUT0), ("1", XLUT1), ("2", XLUT2)):
    _lut.load_state_dict(XLUTs[_key])
    _lut.eval()

cache = download_cached_file(
    'https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/XYZ/classifier.pth',
    check_hash=False,
    progress=True,
)
Xclassifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
Xclassifier.eval()
def generate_LUT(img):
    """Blend the three sRGB basis LUTs with weights predicted for ``img``.

    Args:
        img: Input passed straight to the module-level ``classifier``.

    Returns:
        ``w[0] * LUT0.LUT + w[1] * LUT1.LUT + w[2] * LUT2.LUT`` where ``w`` is
        the squeezed classifier output; only its first three entries are used.
    """
    weights = classifier(img).squeeze()
    bases = (LUT0.LUT, LUT1.LUT, LUT2.LUT)
    # Accumulate left to right so the summation order matches a plain a + b + c.
    blended = weights[0] * bases[0]
    for idx in (1, 2):
        blended = blended + weights[idx] * bases[idx]
    return blended
def generate_XLUT(img):
    """Blend the three XYZ basis LUTs with weights predicted for ``img``.

    Args:
        img: Input passed straight to the module-level ``Xclassifier``.

    Returns:
        ``w[0] * XLUT0.LUT + w[1] * XLUT1.LUT + w[2] * XLUT2.LUT`` where ``w``
        is the squeezed classifier output; only its first three entries are
        used.
    """
    weights = Xclassifier(img).squeeze()
    bases = (XLUT0.LUT, XLUT1.LUT, XLUT2.LUT)
    # Accumulate left to right so the summation order matches a plain a + b + c.
    blended = weights[0] * bases[0]
    for idx in (1, 2):
        blended = blended + weights[idx] * bases[idx]
    return blended
def inference(ori_image, models_n):
    """Enhance an image with the LUT pipeline selected by ``models_n``.

    Args:
        ori_image: Image to enhance. The Gradio input is declared with
            ``type='pil'``; the value is handed to the module-level ``trans``
            converter.
        models_n: ``'sRGB'`` or ``'XYZ'``. Selects which LUT generator
            (``generate_LUT`` / ``generate_XLUT``) and which LUT-application
            module (``trilinear_`` / ``Xtrilinear_``) is used.

    Returns:
        The enhanced image as a ``PIL.Image.Image`` built from a uint8 array.

    Raises:
        ValueError: If ``models_n`` is not a supported name, or if no image
            was supplied.
    """
    # Validate first: previously an unknown name fell through both branches
    # and surfaced as an UnboundLocalError on the return statement.
    if models_n not in ('sRGB', 'XYZ'):
        raise ValueError(
            "Unsupported colour space %r; expected 'sRGB' or 'XYZ'." % (models_n,)
        )
    if ori_image is None:
        raise ValueError("No input image was provided.")

    # Both colour spaces share one code path; only the callables differ.
    if models_n == 'sRGB':
        build_lut, apply_lut = generate_LUT, trilinear_
    else:
        build_lut, apply_lut = generate_XLUT, Xtrilinear_

    with torch.no_grad():
        img = trans(ori_image).unsqueeze(0)
        if cuda:
            # The models are moved to the GPU at start-up when CUDA is
            # available, so the input has to live on the same device.
            img = img.cuda()

        LUT = build_lut(img)
        result = apply_lut(LUT, img).permute(0, 3, 1, 2)

        # squeeze(0) removes only the batch axis added above; a bare squeeze()
        # would also drop a spatial axis of size 1 and break the permute.
        ndarr = (
            result.squeeze(0)
            .mul_(255)
            .add_(0.5)  # +0.5 so the uint8 truncation rounds to nearest
            .clamp_(0, 255)
            .permute(1, 2, 0)
            .to('cpu', torch.uint8)
            .numpy()
        )

    return Image.fromarray(ndarr)
# Gradio front end: an image upload plus a colour-space selector, both passed
# to `inference`, with a single image as the result.
# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces and
# `launch(enable_queue=...)` are legacy Gradio API; confirm the Gradio version
# pinned for this app still provides them.
title = '基于LUT的图像增强演示'

inputs = [
    gr.inputs.Image(type='pil', label='待增强图片'),
    gr.inputs.Radio(
        choices=['sRGB', 'XYZ'],
        type="value",
        default="sRGB",
        label="图片色彩空间",
    ),
]
outputs = [
    gr.outputs.Image(type='pil', label='增强后图片'),
]

demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    allow_flagging='never',
    examples=[['./examples/example.jpg', 'sRGB']],
)
demo.launch(enable_queue=True)