import gradio as gr
from PIL import Image
import torch
from torchvision import transforms
from models.models_x import *
import torchvision_x_functional as TF_x
import torchvision.transforms.functional as TF
import cv2
from timm.models.hub import download_cached_file

cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
trans = transforms.ToTensor()

# sRGB branch: three basis 3D LUTs fused by per-image classifier weights
LUT0 = Generator3DLUT_identity()
LUT1 = Generator3DLUT_zero()
LUT2 = Generator3DLUT_zero()
classifier = Classifier()
trilinear_ = Tritri()
if cuda:
    LUT0 = LUT0.cuda()
    LUT1 = LUT1.cuda()
    LUT2 = LUT2.cuda()
    classifier = classifier.cuda()

# Load pretrained sRGB weights
cache = download_cached_file('https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/sRGB/LUTs.pth',
                             check_hash=False, progress=True)
LUTs = torch.load(cache, map_location=torch.device('cpu'))
LUT0.load_state_dict(LUTs["0"])
LUT1.load_state_dict(LUTs["1"])
LUT2.load_state_dict(LUTs["2"])
LUT0.eval()
LUT1.eval()
LUT2.eval()

cache = download_cached_file('https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/sRGB/classifier.pth',
                             check_hash=False, progress=True)
classifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
classifier.eval()

# XYZ branch: same architecture, separate weights
XLUT0 = Generator3DLUT_identity()
XLUT1 = Generator3DLUT_zero()
XLUT2 = Generator3DLUT_zero()
Xclassifier = Classifier()
Xtrilinear_ = Tritri()
if cuda:
    XLUT0 = XLUT0.cuda()
    XLUT1 = XLUT1.cuda()
    XLUT2 = XLUT2.cuda()
    Xclassifier = Xclassifier.cuda()

# Load pretrained XYZ weights
cache = download_cached_file('https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/XYZ/LUTs.pth',
                             check_hash=False, progress=True)
XLUTs = torch.load(cache, map_location=torch.device('cpu'))
XLUT0.load_state_dict(XLUTs["0"])
XLUT1.load_state_dict(XLUTs["1"])
XLUT2.load_state_dict(XLUTs["2"])
XLUT0.eval()
XLUT1.eval()
XLUT2.eval()

cache = download_cached_file('https://czc-checkpoint.oss-cn-hangzhou.aliyuncs.com/bing/XYZ/classifier.pth',
                             check_hash=False, progress=True)
Xclassifier.load_state_dict(torch.load(cache, map_location=torch.device('cpu')))
Xclassifier.eval()


def generate_LUT(img):
    # Predict per-image weights and fuse the basis LUTs into a single LUT
    pred = classifier(img).squeeze()
    LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT
    return LUT


def generate_XLUT(img):
    pred = Xclassifier(img).squeeze()
    XLUT = pred[0] * XLUT0.LUT + pred[1] * XLUT1.LUT + pred[2] * XLUT2.LUT
    return XLUT


def inference(ori_image, models_n):
    with torch.no_grad():
        # .type(Tensor) moves the image to the GPU when available, matching the model device
        img = trans(ori_image).type(Tensor)
        img = img.unsqueeze(0)
        if models_n == 'sRGB':
            LUT = generate_LUT(img)
            result = trilinear_(LUT, img)
        elif models_n == 'XYZ':
            XLUT = generate_XLUT(img)
            result = Xtrilinear_(XLUT, img)
        result = result.permute(0, 3, 1, 2)
        ndarr = result.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
    return im
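
# --- Illustrative sketch (not used by the app) -------------------------------
# `Tritri` / `trilinear_` comes from models.models_x and its internals are not
# shown here; it applies the fused 3D LUT to the image with trilinear
# interpolation. `apply_lut_reference` below is a hypothetical pure-PyTorch
# equivalent built on torch.nn.functional.grid_sample, assuming the fused LUT
# has shape (3, D, D, D) and the image is a (1, 3, H, W) tensor in [0, 1].
# The LUT axis ordering is an assumption, and the real op's output is permuted
# later in inference(), so its layout may differ; treat this as a sketch only.
def apply_lut_reference(lut, img):
    # Map pixel values from [0, 1] to grid_sample's [-1, 1] coordinate range;
    # each RGB triple becomes a 3D lookup coordinate into the LUT volume.
    grid = img.permute(0, 2, 3, 1).unsqueeze(1) * 2.0 - 1.0      # (1, 1, H, W, 3)
    volume = lut.unsqueeze(0)                                    # (1, 3, D, D, D)
    # For 5D inputs, mode='bilinear' performs trilinear interpolation
    out = torch.nn.functional.grid_sample(volume, grid, mode='bilinear',
                                          padding_mode='border', align_corners=True)
    return out.squeeze(2)                                        # (1, 3, H, W)
# ------------------------------------------------------------------------------
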
inputs = [gr.inputs.Image(type='pil', label='Image to enhance'),
          gr.inputs.Radio(choices=['sRGB', 'XYZ'], type="value", default="sRGB", label="Input color space")]
outputs = [gr.outputs.Image(type='pil', label='Enhanced image')]
title = 'LUT-based image enhancement demo'

gr.Interface(inference, inputs, outputs, title=title, allow_flagging='never',
             examples=[['./examples/example.jpg', 'sRGB']]).launch(enable_queue=True)
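
# The enhancement pipeline can also be exercised without the web UI, e.g. with
# the bundled example image referenced above (shown here as comments so it
# does not run after the blocking launch() call):
#
#   out = inference(Image.open('./examples/example.jpg').convert('RGB'), 'sRGB')
#   out.save('enhanced.jpg')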