File size: 1,609 Bytes
39ab112
 
 
 
6e51710
39ab112
 
 
6e51710
39ab112
 
 
 
 
 
 
 
 
 
 
9cf1cd2
39ab112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7444ba5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import numpy as np
import gradio
from huggingface_hub import hf_hub_download
from basicsr.archs.srresnet_arch import MSRResNet

class SimpleRealUpscaler:
    """Image upscaler backed by an ``MSRResNet`` generator running on CPU.

    The network is built with ``upscale=4`` and its weights are fetched from
    the ``xiongjie/realtime-SRGAN-for-anime`` repository on the Hugging Face
    Hub when the object is constructed.
    """

    def __init__(self):
        """Download the checkpoint, load it on CPU and put the model in eval mode."""
        self.device = torch.device('cpu')
        model = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=6, upscale=4)
        path = hf_hub_download("xiongjie/realtime-SRGAN-for-anime", filename="SRGAN_x4plus_anime.pth")
        # map_location keeps loading working on CPU-only hosts even if the
        # checkpoint was saved from CUDA tensors.
        # NOTE(security): torch.load unpickles the file; only point this at a
        # checkpoint repository you trust.
        loadnet = torch.load(path, map_location=self.device)
        # Prefer the EMA weights when the checkpoint provides them.
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)

    @torch.no_grad()
    def upscale(self, np_image_rgb):
        """Run the model on one image and return the result as ``uint8``.

        Args:
            np_image_rgb: Array of shape (height, width, 3) with values on a
                0-255 scale. The channel axis is reversed before the model is
                called and reversed back afterwards (presumably RGB <-> BGR,
                i.e. the network expects BGR input -- confirm against the
                training code).

        Returns:
            A contiguous ``numpy.uint8`` array of shape
            (out_height, out_width, channels) in the input's channel order.

        Raises:
            ValueError: If the input is not a 3-D array with 3 channels.
        """
        image = np.asarray(np_image_rgb)
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(
                "expected an image array of shape (H, W, 3), got shape %s" % (image.shape,)
            )

        # Reverse the channel axis and copy into a fresh contiguous float32
        # buffer, so the in-place scaling below never touches the caller's data.
        flipped = np.ascontiguousarray(image[:, :, ::-1], dtype=np.float32)
        batch = torch.from_numpy(flipped).to(self.device)
        batch /= 255
        # HWC -> CHW, then add the batch axis: (1, 3, H, W).
        batch = batch.permute(2, 0, 1).unsqueeze(0)

        prediction = self.model(batch)
        # Drop only the batch axis; a bare squeeze() would also remove any
        # spatial axis of size 1 and break the permute below.
        prediction = prediction.squeeze(0).float().clamp_(0, 1)
        prediction = prediction.permute((1, 2, 0))
        output = (prediction * 255.0).round().cpu().numpy().astype(np.uint8)
        # Restore the caller's channel order; return a contiguous array so
        # consumers that reject negative strides can use it directly.
        return np.ascontiguousarray(output[:, :, ::-1])


# One shared model instance, created (and its checkpoint loaded) at import time.
upscaler = SimpleRealUpscaler()


def upscale(np_image_rgb):
    """Module-level callback that forwards one image to the shared upscaler."""
    result = upscaler.upscale(np_image_rgb)
    return result

# CSS override that makes the output image fill its container.
css = ".output_image {height: 100% !important; width: 100% !important;}"
inputs = gradio.inputs.Image()
outputs = gradio.outputs.Image()
# Build the single-function web UI, then start serving it.
demo = gradio.Interface(
    fn=upscale,
    inputs=inputs,
    outputs=outputs,
    css=css,
)
demo.launch()