File size: 3,272 Bytes
e5b70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8adcfd
e5b70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8adcfd
e5b70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import glob
import io
import os

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import wget
from torchvision.transforms import Compose, ToTensor

from model import decoder, encoder

# Default on-disk location of the pretrained checkpoint.
WEIGHT_PATH = './weights/best_weight.pth'
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Model(object):
    """Encoder/decoder pair used to super-resolve a single image.

    Both networks are created on ``DEVICE`` and their weights are loaded
    during construction, so a constructed instance is ready for ``predict``.
    """

    # Source of the pretrained checkpoint fetched when no local copy exists.
    WEIGHT_URL = "https://raw.githubusercontent.com/hungnguyen2611/super-resolution/master/weights/best_weight.pth"

    def __init__(self) -> None:
        """Build the encoder and SR decoder, then load their weights."""
        self.model_Enc = encoder.Encoder_RRDB(num_feat=64).to(device=DEVICE)
        self.model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=64).to(device=DEVICE)
        self.preprocess = Compose([ToTensor()])
        self.load_model()

    def load_model(self, weight_path=WEIGHT_PATH):
        """Load encoder/decoder weights from ``weight_path``.

        If ``weight_path`` does not exist, the pretrained checkpoint is
        downloaded from ``WEIGHT_URL`` to that path first (creating the
        parent directory when needed). Both networks are switched to eval
        mode afterwards.

        Args:
            weight_path: Path of a checkpoint holding the ``'model_Enc'``
                and ``'model_Dec_SR'`` state dicts.
        """
        # Check and download against the path that is actually loaded below,
        # not a hard-coded one, so a custom ``weight_path`` is honoured.
        if not os.path.isfile(weight_path):
            weight_dir = os.path.dirname(weight_path)
            if weight_dir:
                # The download cannot write into a directory that is missing.
                os.makedirs(weight_dir, exist_ok=True)
            wget.download(self.WEIGHT_URL, weight_path)
        weight = torch.load(weight_path, map_location=DEVICE)
        print("[LOADING] Loading encoder...")
        self.model_Enc.load_state_dict(weight['model_Enc'])
        print("[LOADING] Loading decoder...")
        self.model_Dec_SR.load_state_dict(weight['model_Dec_SR'])
        print("[LOADING] Loading done!")
        self.model_Enc.eval()
        self.model_Dec_SR.eval()

    def predict(self, img):
        """Run the encoder and SR decoder on one image.

        Args:
            img: Image array as produced by ``cv2.imread`` (BGR channel
                order); it is converted to RGB before entering the network.

        Returns:
            ``numpy.uint8`` array in height x width x channel layout with the
            channel order reversed back for OpenCV.
        """
        with torch.no_grad():
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # ToTensor yields a CHW tensor; add the batch axis the nets expect.
            batch = self.preprocess(rgb).unsqueeze(0).to(DEVICE)

            feat = self.model_Enc(batch)
            out = self.model_Dec_SR(feat)

            # Take the single batch item, move it to the CPU as float and
            # clamp to [0, 1]; with that range no further rescaling is needed.
            out = out.detach()[0].float().cpu().squeeze().clamp_(0, 1).numpy()
            # Reverse the channel order (assumes the decoder emits RGB, to be
            # confirmed against the model) and go from CHW to HWC.
            out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))

            return (out * 255.0).round().astype(np.uint8)

# Built once at import time: constructing Model loads the weights, and the
# module-level predict() below reuses this single instance for every call.
model = Model()

def predict(img):
    """Super-resolve a PIL image and return the path of the saved result.

    The input is written to ``test/1.png``, read back through OpenCV, passed
    to the shared ``model`` and the output is written to
    ``images_uploaded/1.png``.

    Args:
        img: PIL image (anything exposing ``save(path, format)``).

    Returns:
        The relative path ``"images_uploaded/1.png"``.

    Raises:
        ValueError: If the saved input cannot be read back by OpenCV.
        OSError: If the result image cannot be written.
    """
    in_path = "test/1.png"
    out_path = "images_uploaded/1.png"

    # Neither folder is guaranteed to exist on a fresh checkout, and the
    # writes below do not create missing parent directories.
    for path in (in_path, out_path):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    img.save(in_path, "PNG")
    image = cv2.imread(in_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"could not read input image from {in_path}")

    out = model.predict(img=image)

    # cv2.imwrite reports failure through its return value; without this
    # check a stale file from an earlier call could be returned.
    if not cv2.imwrite(out_path, out):
        raise OSError(f"could not write output image to {out_path}")
    return out_path




if __name__ == '__main__':
    # Assemble the Gradio demo page and start serving it.
    # The emoji were previously stored as mis-decoded UTF-8 bytes.
    title = "Super-Resolution Demo USR-DA Unofficial 🚀🚀🔥"
    description = ''' 
<br>
**This Demo expects low-quality and low-resolution images, better visual on real-world images**
</br>
'''
    article = "<p style='text-align: center'><a href='https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Unsupervised_Real-World_Super-Resolution_A_Domain_Adaptation_Perspective_ICCV_2021_paper.pdf' target='_blank'>Unsupervised Real-World Super-Resolution: A Domain Adaptation Perspective</a> | <a href='https://github.com/hungnguyen2611/super-resolution.git' target='_blank'>Github Repo</a></p>"
    # glob returns files in arbitrary order; sort so the example gallery is
    # laid out the same way on every start.
    examples = sorted(glob.glob("testsets/*.png"))
    gr.Interface(
        predict,
        gr.inputs.Image(type="pil", label="Input").style(height=260),
        gr.inputs.Image(type="pil", label="Output").style(height=240),
        title=title,
        description=description,
        article=article,
        examples=examples,
    ).launch(enable_queue=True)