USR-DA / main.py
NMHung's picture
Update main.py
f8adcfd
import glob
import io
import os
import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import wget
from torchvision.transforms import Compose, ToTensor
from model import decoder, encoder
# Local checkpoint file holding the pretrained encoder/decoder state dicts.
WEIGHT_PATH = './weights/best_weight.pth'
# Inference device: CUDA when PyTorch reports it is available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Model:
    """Inference wrapper around the USR-DA RRDB encoder and SR decoder.

    On construction both networks are instantiated on ``DEVICE`` and their
    weights are loaded via :meth:`load_model`; :meth:`predict` then runs a
    single image through encoder -> decoder.
    """

    # Source of the pretrained checkpoint, fetched when the local file is missing.
    WEIGHT_URL = (
        "https://raw.githubusercontent.com/hungnguyen2611/"
        "super-resolution/master/weights/best_weight.pth"
    )

    def __init__(self) -> None:
        self.model_Enc = encoder.Encoder_RRDB(num_feat=64).to(device=DEVICE)
        self.model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=64).to(device=DEVICE)
        # HWC uint8 ndarray -> CHW float tensor scaled to [0, 1].
        self.preprocess = Compose([ToTensor()])
        self.load_model()

    def load_model(self, weight_path=WEIGHT_PATH):
        """Load encoder/decoder weights and switch both networks to eval mode.

        Args:
            weight_path: Checkpoint file holding ``'model_Enc'`` and
                ``'model_Dec_SR'`` state dicts. If it does not exist, the
                published checkpoint is downloaded to this path first.
        """
        if not os.path.isfile(weight_path):
            # Check and download the path that is actually loaded below, so a
            # caller-supplied ``weight_path`` is honoured (not just the default).
            parent = os.path.dirname(weight_path)
            if parent:
                # Make sure the target directory exists before downloading into it.
                os.makedirs(parent, exist_ok=True)
            wget.download(self.WEIGHT_URL, weight_path)
        weight = torch.load(weight_path, map_location=DEVICE)
        print("[LOADING] Loading encoder...")
        self.model_Enc.load_state_dict(weight['model_Enc'])
        print("[LOADING] Loading decoder...")
        self.model_Dec_SR.load_state_dict(weight['model_Dec_SR'])
        print("[LOADING] Loading done!")
        self.model_Enc.eval()
        self.model_Dec_SR.eval()

    def predict(self, img):
        """Super-resolve one image.

        Args:
            img: 3-channel uint8 image in BGR order, e.g. as returned by
                ``cv2.imread(path, cv2.IMREAD_COLOR)``.

        Returns:
            ``(H', W', 3)`` uint8 array in BGR order, ready for ``cv2.imwrite``.
        """
        with torch.no_grad():
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            batch = self.preprocess(rgb).unsqueeze(0).to(DEVICE)
            sr = self.model_Dec_SR(self.model_Enc(batch))
            # Index away only the batch dimension: a blanket squeeze() would also
            # drop a size-1 height/width and break the CHW -> HWC transpose below.
            sr = sr[0].float().cpu().clamp_(0, 1).numpy()
        # CHW RGB -> HWC BGR so the result matches OpenCV's channel order.
        bgr = np.transpose(sr[[2, 1, 0], :, :], (1, 2, 0))
        return (bgr * 255.0).round().astype(np.uint8)
# Single shared model instance, built (and its weights loaded) at import time.
model = Model()


def predict(img):
    """Gradio handler: super-resolve a PIL image and return the saved file's path.

    Args:
        img: PIL image supplied by the ``gr.inputs.Image(type="pil")`` input.

    Returns:
        Path of the PNG the super-resolved image was written to.

    Raises:
        ValueError: if no image was supplied.
        OSError: if the result could not be written to disk.
    """
    if img is None:
        raise ValueError("No input image was provided.")
    # Convert in memory instead of round-tripping through "test/1.png" (which
    # failed when test/ was missing). PIL yields RGB; Model.predict expects BGR
    # like cv2.imread(..., IMREAD_COLOR). OpenCV needs a contiguous array, so
    # materialise the reversed-channel view.
    image = np.ascontiguousarray(np.asarray(img.convert("RGB"))[:, :, ::-1])
    out = model.predict(img=image)
    out_path = "images_uploaded/1.png"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(out_path, out):
        raise OSError(f"Failed to write super-resolved image to {out_path}")
    return out_path
if __name__ == '__main__':
    # Rocket, rocket, fire emoji (escaped so the source survives re-encoding).
    title = "Super-Resolution Demo USR-DA Unofficial \U0001F680\U0001F680\U0001F525"
    description = '''
<br>
**This Demo expects low-quality and low-resolution images, better visual on real-world images**
</br>
'''
    article = "<p style='text-align: center'><a href='https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Unsupervised_Real-World_Super-Resolution_A_Domain_Adaptation_Perspective_ICCV_2021_paper.pdf' target='_blank'>Unsupervised Real-World Super-Resolution: A Domain Adaptation Perspective</a> | <a href='https://github.com/hungnguyen2611/super-resolution.git' target='_blank'>Github Repo</a></p>"
    # glob order is filesystem-dependent; sort so examples appear in a stable order.
    examples = sorted(glob.glob("testsets/*.png"))
    gr.Interface(
        predict,
        gr.inputs.Image(type="pil", label="Input").style(height=260),
        gr.inputs.Image(type="pil", label="Output").style(height=240),
        title=title,
        description=description,
        article=article,
        examples=examples,
    ).launch(enable_queue=True)