"""Gradio demo that reports an image-quality score predicted by MoNet."""
import argparse
import functools
import os
import subprocess

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision

from models import monet as MoNet
from utils.dataset.process import ToTensor, Normalize

# Checkpoint fetched at start-up and the local file it is stored in.
MODEL_URL = "https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl"
MODEL_PATH = "best_model.pkl"
# Spatial size (width, height) every input is resized to before inference.
INPUT_SIZE = (224, 224)


def load_image(img_path):
    """Load an image and return it as a channel-first float32 array.

    Args:
        img_path: Either a filesystem path (``str``) or an in-memory RGB
            image accepted by ``np.asarray`` (e.g. the array Gradio passes
            to the callback).

    Returns:
        A ``float32`` array of shape ``(3, 224, 224)`` in RGB channel order
        whose values are the resized pixel values divided by 255.

    Raises:
        ValueError: If no image is given or the path cannot be decoded.
    """
    if img_path is None:
        raise ValueError("No image provided.")

    if isinstance(img_path, str):
        d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising.
        if d_img is None:
            raise ValueError("Could not read image: {}".format(img_path))
    else:
        # In-memory images are RGB; convert to OpenCV's BGR so both branches
        # share the same downstream processing.
        d_img = cv2.cvtColor(np.asarray(img_path), cv2.COLOR_RGB2BGR)

    d_img = cv2.resize(d_img, INPUT_SIZE, interpolation=cv2.INTER_CUBIC)
    d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
    d_img = np.array(d_img).astype('float32') / 255
    # HWC -> CHW.
    return np.transpose(d_img, (2, 0, 1))


def _build_config():
    """Return the namespace of model hyper-parameters passed to MoNet."""
    parser = argparse.ArgumentParser()
    # model related
    parser.add_argument('--backbone', dest='backbone', type=str,
                        default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3,
                        help='The number of the MAL modules.')
    # parse_known_args still honours --backbone/--mal_num on the command
    # line but does not abort the process when the host passes other flags.
    config, _unknown = parser.parse_known_args()
    return config


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build MoNet, load the checkpoint once, and return ``(model, device)``.

    The result is cached so that the weights are not re-read from disk on
    every request.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MoNet.MoNet(_build_config()).to(device)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code; only point MODEL_URL at a trusted checkpoint.
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model, device


def predict(image):
    """Run a single prediction on the model.

    Args:
        image: A filesystem path or an in-memory RGB image, as accepted by
            ``load_image``.

    Returns:
        A human-readable string containing the predicted quality score
        rounded to four decimal places.
    """
    model, device = _load_model()
    # Normalize/ToTensor are project transforms; presumably they map the
    # array to a normalised torch tensor -- confirm in utils.dataset.process.
    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

    img = load_image(image)
    img_tensor = trans(img).unsqueeze(0).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        iq = model(img_tensor).cpu().detach().numpy().tolist()[0]
    return "The image quality of the image is: {}".format(round(iq, 4))


def _download_weights(url=MODEL_URL, dest=MODEL_PATH):
    """Download the checkpoint with wget, raising if the download fails."""
    # Argument list (no shell) and check=True so a failed download stops the
    # app here rather than surfacing later as an unreadable checkpoint.
    subprocess.run(["wget", "-O", dest, url], check=True)


_download_weights()

interface = gr.Interface(fn=predict, inputs="image", outputs="text")
interface.launch()