File size: 2,066 Bytes
07e1105
 
 
 
 
 
 
 
 
7072d7b
07e1105
 
2153e8d
 
 
 
07e1105
 
 
 
 
 
 
63ccb58
07e1105
a32d7be
07e1105
 
 
 
63ccb58
 
 
 
 
 
 
 
 
07e1105
 
 
a32d7be
 
 
 
 
 
 
 
63ccb58
 
 
 
 
 
 
a32d7be
07e1105
 
2fcfba4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torchvision

import cv2
import numpy as np
from models import monet as MoNet
import argparse
from utils.dataset.process import ToTensor, Normalize
import gradio as gr
import os

def load_image(img_path, size=(224, 224)):
    """Load an image and convert it to a float32 CHW array scaled to [0, 1].

    Args:
        img_path: Either a filesystem path (``str``), read with OpenCV, or an
            array-like image in RGB channel order (e.g. the numpy array that
            Gradio's ``"image"`` input passes). Grayscale (H, W) / (H, W, 1)
            and RGBA (H, W, 4) arrays are converted to RGB.
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults
            to 224x224, the size this script has always used.

    Returns:
        ``np.ndarray`` of dtype float32 with shape (3, height, width), RGB
        channel order, pixel values divided by 255.

    Raises:
        FileNotFoundError: If ``img_path`` is a path OpenCV cannot read.
        ValueError: If an array input does not have 1, 3 or 4 channels.
    """
    if isinstance(img_path, str):
        # cv2.imread signals failure by returning None rather than raising.
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError("Could not read image: {}".format(img_path))
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        # Contiguous copy so OpenCV accepts views/slices of larger arrays.
        rgb = np.ascontiguousarray(np.asarray(img_path))
        if rgb.ndim == 3 and rgb.shape[2] == 1:
            rgb = rgb[:, :, 0]
        if rgb.ndim == 2:
            rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
        elif rgb.ndim == 3 and rgb.shape[2] == 4:
            rgb = cv2.cvtColor(rgb, cv2.COLOR_RGBA2RGB)
        elif not (rgb.ndim == 3 and rgb.shape[2] == 3):
            raise ValueError(
                "Expected a grayscale, RGB or RGBA image, got array of shape "
                "{}".format(rgb.shape))

    # Resizing is per-channel, so doing it after the colour conversion gives
    # the same pixels as the original convert/resize/convert sequence.
    rgb = cv2.resize(rgb, size, interpolation=cv2.INTER_CUBIC)
    scaled = rgb.astype(np.float32) / 255
    # HWC -> CHW, the layout torch models expect.
    return np.transpose(scaled, (2, 0, 1))

def predict(image):
    """Run a single quality prediction on one image with the global model.

    Reads the module-level ``model`` and ``is_gpu`` set up at import time.

    Args:
        image: Anything ``load_image`` accepts — a file path or an RGB array
            such as the one Gradio's ``"image"`` input provides. ``None``
            (no image submitted) yields a prompt message instead of an error.

    Returns:
        A human-readable string with the predicted score rounded to 4
        decimal places.
    """
    if image is None:
        return "Please upload an image."

    # NOTE(review): Normalize(0.5, 0.5) presumably mirrors the training
    # pipeline — confirm against the training code.
    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
    img_tensor = trans(load_image(image)).unsqueeze(0)  # batch of one
    if is_gpu:
        img_tensor = img_tensor.cuda()

    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        output = model(img_tensor)

    # Assumes the model emits one score for the single image in the batch;
    # flattening takes that first score whatever the trailing dims are.
    iq = output.cpu().reshape(-1)[0].item()
    return "The image quality of the image is: {}".format(round(iq, 4))

# Trained weights are loaded from MODEL_PATH. A checkpoint is published at
# https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl
# NOTE(review): that file name differs from MODEL_PATH; presumably it is
# downloaded and renamed to best_model.pkl — confirm.
MODEL_PATH = 'best_model.pkl'

parser = argparse.ArgumentParser()
# model related
parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224', help='The backbone for MoNet.')
parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
config = parser.parse_args()

# `predict` reads `model` and `is_gpu` as module-level globals.
is_gpu = torch.cuda.is_available()
device = torch.device('cuda' if is_gpu else 'cpu')
model = MoNet.MoNet(config, is_gpu=is_gpu).to(device)
# map_location puts the checkpoint tensors on the active device, so loading
# works regardless of which device the checkpoint was saved from.
# torch.load unpickles the file: only load trusted checkpoints.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()  # switch dropout/batch-norm style layers to inference behaviour

interface = gr.Interface(fn=predict, inputs="image", outputs="text")
interface.launch()