File size: 2,066 Bytes
07e1105 7072d7b 07e1105 2153e8d 07e1105 63ccb58 07e1105 a32d7be 07e1105 63ccb58 07e1105 a32d7be 63ccb58 a32d7be 07e1105 2fcfba4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import torch
import torchvision
import cv2
import numpy as np
from models import monet as MoNet
import argparse
from utils.dataset.process import ToTensor, Normalize
import gradio as gr
import os
def load_image(img_path):
    """Load an image and convert it to a 224x224 float32 CHW array.

    Args:
        img_path: Either a filesystem path (``str``), which is read with
            ``cv2.imread``, or an in-memory image (anything ``np.asarray``
            accepts) that is treated as being in RGB channel order.

    Returns:
        A ``float32`` NumPy array of shape ``(3, 224, 224)`` in RGB channel
        order, with the pixel values divided by 255.

    Raises:
        OSError: If ``img_path`` is a string and the file cannot be opened
            or decoded as an image.
    """
    if isinstance(img_path, str):
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than inside cv2.resize.
        if bgr is None:
            raise OSError(
                "Could not open or decode image file: {!r}".format(img_path)
            )
    else:
        # In-memory images are assumed to be RGB; bring them into OpenCV's
        # BGR convention so both input kinds share the same pipeline below.
        bgr = cv2.cvtColor(np.asarray(img_path), cv2.COLOR_RGB2BGR)

    bgr = cv2.resize(bgr, (224, 224), interpolation=cv2.INTER_CUBIC)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Scale by 1/255 and reorder HWC -> CHW.
    scaled = rgb.astype('float32') / 255
    return np.transpose(scaled, (2, 0, 1))
# import time
def predict(image):
    """Run a single prediction on the model.

    Reads the module-level ``model`` and ``is_gpu`` set up by the script.

    Args:
        image: The image to score; either a file path or an in-memory RGB
            image, as accepted by ``load_image``. ``None`` (no image
            supplied by the UI) is handled gracefully.

    Returns:
        A human-readable string containing the predicted quality score
        rounded to four decimal places, or a short prompt when no image
        was provided.
    """
    # The UI can invoke the callback without an image; answer with a message
    # instead of crashing inside the preprocessing code.
    if image is None:
        return "No image provided. Please upload an image."

    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
    # unsqueeze(0) adds the leading batch dimension for a single image.
    img_tensor = trans(load_image(image)).unsqueeze(0)
    if is_gpu:
        img_tensor = img_tensor.cuda()

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(img_tensor)

    # .cpu() is a no-op when the tensor already lives on the CPU, so one
    # code path serves both devices.
    iq = output.cpu().detach().numpy().tolist()[0]
    return "The image quality of the image is: {}".format(round(iq, 4))
# os.system("wget https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl")
# ---- Command-line configuration ----
parser = argparse.ArgumentParser()
# model related
parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224', help='The backbone for MoNet.')
parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
config = parser.parse_args()

# ---- Model setup ----
# Use the GPU when one is available, otherwise fall back to the CPU.
is_gpu = torch.cuda.is_available()
model = MoNet.MoNet(config, is_gpu=is_gpu)
if is_gpu:
    model = model.cuda()

# map_location makes the checkpoint load onto the device actually in use,
# regardless of which device it was saved from.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load 'best_model.pkl' from a trusted source.
model.load_state_dict(
    torch.load('best_model.pkl', map_location="cuda" if is_gpu else "cpu")
)
# Put the model in evaluation mode for inference.
model.eval()

# ---- Web UI ----
# Image in, text out; predict() does the actual scoring.
interface = gr.Interface(fn=predict, inputs="image", outputs="text")
interface.launch()