PeterYoung777's picture
Create app.py
e2c3f43
raw
history blame
No virus
1.66 kB
import functools
import json
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from model import efficientnetv2_m as create_model
# Model variant whose (train_size, val_size) entry drives preprocessing.
# NOTE(review): the network imported above is efficientnetv2_m, yet the "s"
# sizes are selected here (val_size 384 rather than the "m" entry's 480);
# confirm which resolution the checkpoint was validated at before changing.
_NUM_MODEL = "s"
_IMG_SIZE = {"s": [300, 384],  # train_size, val_size
             "m": [384, 480],
             "l": [384, 480]}
_NUM_CLASSES = 5
_JSON_PATH = './class_indices.json'
_MODEL_WEIGHT_PATH = "./weights/model-20.pth"

# Evaluation preprocessing; built once instead of on every request.
_DATA_TRANSFORM = transforms.Compose(
    [transforms.Resize(_IMG_SIZE[_NUM_MODEL][1]),
     transforms.CenterCrop(_IMG_SIZE[_NUM_MODEL][1]),
     transforms.ToTensor(),
     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


@functools.lru_cache(maxsize=1)
def _load_class_indices():
    """Return the class-index -> class-name mapping read from ``_JSON_PATH``.

    Read once and cached; the ``with`` block guarantees the file is closed.
    """
    with open(_JSON_PATH, "r") as json_file:
        return json.load(json_file)


@functools.lru_cache(maxsize=None)
def _load_model(device):
    """Build the classifier on *device*, load its checkpoint, and cache it.

    Cached per device so repeated requests do not re-read the weights file.
    """
    model = create_model(num_classes=_NUM_CLASSES).to(device)
    model.load_state_dict(torch.load(_MODEL_WEIGHT_PATH, map_location=device))
    model.eval()
    return model


def predict(img):
    """Classify one image and return a human-readable result string.

    Args:
        img: a PIL image (the Gradio input is configured with
            ``type="pil"``). Non-RGB images (e.g. RGBA, grayscale) are
            converted to RGB so the 3-channel normalisation applies.

    Returns:
        ``"class: <name> prob: <p>"`` for the highest-scoring class, with the
        softmax probability formatted to 3 significant digits.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if img.mode != "RGB":
        img = img.convert("RGB")
    # Add the batch dimension expected by the network: (C, H, W) -> (1, C, H, W).
    batch = torch.unsqueeze(_DATA_TRANSFORM(img), dim=0)

    class_indict = _load_class_indices()
    model = _load_model(device)

    with torch.no_grad():
        # Drop the batch dimension so softmax runs over the class axis.
        output = torch.squeeze(model(batch.to(device))).cpu()
        probs = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(probs).item()

    # class_indices.json is keyed by the stringified class index.
    return "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                          probs[predict_cla].item())
import gradio as gr

# Web UI: a single PIL image in, the string from ``predict`` shown as a label.
# NOTE(review): ``gr.inputs`` / ``gr.outputs`` are Gradio's legacy component
# namespaces (deprecated in 3.x, removed in 4.x); migrate to ``gr.Image`` /
# ``gr.Label`` once the pinned Gradio version allows it. Because ``predict``
# returns a plain string, ``num_top_classes`` appears to have no visible
# effect; a {label: probability} dict return would be needed to populate it.
demo = gr.Interface(fn=predict,
                    inputs=gr.inputs.Image(type="pil"),
                    outputs=gr.outputs.Label(num_top_classes=5),
                    theme="default")

# Launch only when run as a script (e.g. ``python app.py``), so importing this
# module does not start a blocking server.
if __name__ == "__main__":
    demo.launch()