Update app.py
Browse files
app.py
CHANGED
@@ -22,17 +22,7 @@ def load_image(img_path):
|
|
22 |
return d_img
|
23 |
|
24 |
def predict(image):
|
25 |
-
|
26 |
-
# model related
|
27 |
-
parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
|
28 |
-
help='The backbone for MoNet.')
|
29 |
-
parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
|
30 |
-
config = parser.parse_args()
|
31 |
-
|
32 |
-
model = MoNet.MoNet(config).cuda()
|
33 |
-
model.load_state_dict(torch.load('best_model.pkl'))
|
34 |
-
model.eval()
|
35 |
-
|
36 |
trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
|
37 |
|
38 |
"""Run a single prediction on the model"""
|
@@ -42,7 +32,17 @@ def predict(image):
|
|
42 |
|
43 |
return "The image quality of the image is: {}".format(round(iq, 4))
|
44 |
|
45 |
-
os.system("wget
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
interface = gr.Interface(fn=predict, inputs="image", outputs="text")
|
48 |
interface.launch()
|
|
|
22 |
return d_img
|
23 |
|
24 |
def predict(image):
|
25 |
+
global model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
|
27 |
|
28 |
"""Run a single prediction on the model"""
|
|
|
32 |
|
33 |
return "The image quality of the image is: {}".format(round(iq, 4))
|
34 |
|
35 |
+
# os.system("wget https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl")
|
36 |
+
|
37 |
+
parser = argparse.ArgumentParser()
|
38 |
+
# model related
|
39 |
+
parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224', help='The backbone for MoNet.')
|
40 |
+
parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
|
41 |
+
config = parser.parse_args()
|
42 |
+
|
43 |
+
model = MoNet.MoNet(config).cuda()
|
44 |
+
model.load_state_dict(torch.load('best_model.pkl'))
|
45 |
+
model.eval()
|
46 |
|
47 |
interface = gr.Interface(fn=predict, inputs="image", outputs="text")
|
48 |
interface.launch()
|