Spaces:
Runtime error
Runtime error
File size: 3,537 Bytes
9080570 33e2863 9080570 a8463d2 dae1a1c 9080570 0f6fd48 9080570 0f6fd48 9080570 0f6fd48 b63ae53 9080570 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import os
import yaml
import torch
import argparse
import numpy as np
import gradio as gr
from PIL import Image
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel
from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider
## local code
from models import seemore
def dict2namespace(config):
    """Recursively turn a (possibly nested) mapping into an ``argparse.Namespace``.

    Every value that is itself a ``dict`` is converted into a nested
    ``Namespace``, so parsed YAML config entries can be read with attribute
    access (for example ``cfg.model.scale``). Non-dict values are stored as-is.
    """
    ns = argparse.Namespace()
    for name, raw in config.items():
        setattr(ns, name, dict2namespace(raw) if isinstance(raw, dict) else raw)
    return ns
def load_img(filename, norm=True):
    """Read an image file into an RGB ``float32`` numpy array.

    Images wider than 1920 px or taller than 1080 px are downscaled by a
    factor of 4 (bicubic) before being returned.

    Args:
        filename: anything ``PIL.Image.open`` accepts (path or file object).
        norm: when True, scale pixel values from [0, 255] to [0, 1].

    Returns:
        ``np.ndarray`` of shape (H, W, 3) with dtype ``float32``.
    """
    # Context manager releases the underlying file handle right away
    # instead of leaving it open until garbage collection.
    with Image.open(filename) as pil_img:
        img = np.array(pil_img.convert("RGB"))
    h, w = img.shape[:2]
    if w > 1920 or h > 1080:
        # Clamp to 1 so very thin images (e.g. 4000x3) do not yield a
        # zero-sized target, which PIL's resize rejects.
        new_h, new_w = max(1, h // 4), max(1, w // 4)
        img = np.array(Image.fromarray(img).resize((new_w, new_h), Image.BICUBIC))
    if norm:
        img = img / 255.
    return img.astype(np.float32)
def process_img(image):
    """Super-resolve one image with the module-level ``model`` on ``device``.

    Args:
        image: the picture delivered by the Gradio ``gr.Image(type="pil")``
            input (a PIL image). Any object ``np.array`` turns into an
            (H, W, 3) uint8 array is also accepted.

    Returns:
        ``(image, restored)``: the untouched input and the model output as a
        PIL image. This pair is the before/after value an ``ImageSlider``
        output component expects.
    """
    # Uploads may be RGBA, palette or grayscale. Grayscale gives a 2-D array
    # that breaks permute(2, 0, 1), and RGBA gives 4 channels. Normalise PIL
    # inputs to 3-channel RGB, matching what load_img produces.
    rgb = image.convert("RGB") if isinstance(image, Image.Image) else image
    img = (np.array(rgb) / 255.).astype(np.float32)
    # HWC -> NCHW with a batch of one.
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        x_hat = model(y)
    # Drop only the batch dim (a bare squeeze() could also remove a singleton
    # spatial dim), then CHW -> HWC. clamp_ already bounds values to [0, 1].
    restored = x_hat.squeeze(0).permute(1, 2, 0).clamp_(0, 1).cpu().numpy()
    restored = (restored * 255.0).round().astype(np.uint8)  # float32 -> uint8
    return (image, Image.fromarray(restored))
def load_network(net, load_path, strict=True, param_key='params'):
    """Load a checkpoint's weights into ``net`` in place.

    Args:
        net: the ``torch.nn.Module`` to populate. ``DataParallel`` and
            ``DistributedDataParallel`` wrappers are unwrapped first, so the
            weights land in the inner module.
        load_path: path to a checkpoint file readable by ``torch.load``.
        strict: forwarded to ``load_state_dict``.
        param_key: key under which the state dict is stored in the checkpoint.
            If it is absent but ``'params'`` exists, ``'params'`` is used
            instead. Pass ``None`` when the checkpoint is itself the state dict.

    Raises:
        KeyError: if neither ``param_key`` nor the ``'params'`` fallback is
            present in the checkpoint.
        RuntimeError: raised by ``load_state_dict`` on mismatched keys or
            shapes (missing/unexpected keys only when ``strict`` is True).
    """
    if isinstance(net, (DataParallel, DistributedDataParallel)):
        net = net.module
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # a trusted source (here, a Hugging Face Hub download).
    load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
    if param_key is not None:
        if param_key not in load_net and 'params' in load_net:
            param_key = 'params'
        load_net = load_net[param_key]
    # Strip the 'module.' prefix left by saving a (Distributed)DataParallel
    # model. Snapshotting the keys is enough to allow mutation while
    # iterating; deep-copying every tensor is unnecessary.
    prefix = 'module.'
    for k in list(load_net):
        if k.startswith(prefix):
            load_net[k[len(prefix):]] = load_net.pop(k)
    net.load_state_dict(load_net, strict=strict)
# --- Model setup (runs at import time) ---
CONFIG = "configs/eval_seemore_t_x4.yml"
MODEL_NAME = "SeemoRe_T_X4.pth"
# Download the x4 checkpoint into the working directory so that the relative
# MODEL_NAME path below resolves to it.
hf_hub_download(repo_id="eduardzamfir/SeemoRe-T", filename=MODEL_NAME, local_dir="./")

# parse config file
with open(CONFIG, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
cfg = dict2namespace(config)

device = torch.device("cpu")
model = seemore.SeemoRe(scale=cfg.model.scale, in_chans=cfg.model.in_chans,
                        num_experts=cfg.model.num_experts, num_layers=cfg.model.num_layers,
                        embedding_dim=cfg.model.embedding_dim, img_range=cfg.model.img_range,
                        use_shuffle=cfg.model.use_shuffle,
                        global_kernel_size=cfg.model.global_kernel_size,
                        recursive=cfg.model.recursive, lr_space=cfg.model.lr_space,
                        topk=cfg.model.topk)
model = model.to(device)
print("IMAGE MODEL CKPT:", MODEL_NAME)
load_network(model, MODEL_NAME, strict=True, param_key='params')
# The model is only used for inference. Switch to eval mode so any
# training-only layer behaviour is disabled.
model.eval()
css = """
/* .image-frame img, .image-container img {
width: auto;
height: auto;
max-width: none;
}*/
footer {visibility: hidden !important;}
"""

# process_img returns an (input, restored) pair. A single gr.Image output
# cannot postprocess a tuple, so every request would fail. ImageSlider takes
# exactly such a pair and shows it as a before/after comparison.
demo = gr.Interface(
    fn=process_img,
    inputs=[gr.Image(type="pil", label="Изображение"),],
    outputs=[ImageSlider(type="pil", label="Расширеное изображение", min_width=500),],
    css=css,
)

if __name__ == "__main__":
    demo.launch()