PeterYoung777's picture
Update app.py
488377e
raw
history blame contribute delete
No virus
2.02 kB
import functools
import json
import os

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import efficientnetv2_m as create_model
# Input resolution per EfficientNetV2 variant: [train_size, val_size].
_IMG_SIZE = {"s": [300, 384],
             "m": [384, 480],
             "l": [384, 480]}
# Variant whose sizes are used; matches ``create_model`` (efficientnetv2_m).
_NUM_MODEL = "m"
_JSON_PATH = './class_indices.json'
_MODEL_WEIGHT_PATH = "./weights/model-20.pth"
_NUM_CLASSES = 5


@functools.lru_cache(maxsize=None)
def _get_transform(num_model=_NUM_MODEL):
    """Build (once) the preprocessing pipeline for the given model variant.

    Resizes to the variant's validation size, center-crops a square of that
    size, converts to a tensor and normalizes each channel with mean 0.5 and
    std 0.5.
    """
    size = _IMG_SIZE[num_model][1]
    return transforms.Compose(
        [transforms.Resize(size),
         transforms.CenterCrop(size),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


@functools.lru_cache(maxsize=None)
def _load_class_indices(json_path=_JSON_PATH):
    """Read (once) the JSON mapping from class-index strings to class names.

    The ``with`` block guarantees the file handle is closed.
    """
    with open(json_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


@functools.lru_cache(maxsize=None)
def _load_model(device_name):
    """Create the classifier on ``device_name`` and load its weights (once).

    Cached per device string, so the weight file is read from disk only on
    the first request for that device instead of on every prediction.
    """
    device = torch.device(device_name)
    model = create_model(num_classes=_NUM_CLASSES).to(device)
    model.load_state_dict(torch.load(_MODEL_WEIGHT_PATH, map_location=device))
    model.eval()
    return model


def predict(img):
    """Classify a flower image and return the top class with its probability.

    Args:
        img: a PIL image (the Gradio input is declared with ``type="pil"``).

    Returns:
        A string ``"class: <name> \\n prob: <p>"``, where ``<name>`` is looked
        up in ``class_indices.json`` and ``<p>`` is the softmax probability of
        the predicted class, formatted with 3 significant digits.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Uploads may be RGBA, palette or grayscale; Normalize above is
    # configured for exactly 3 channels, so force RGB first.
    img = img.convert("RGB")
    # expand batch dimension
    batch = torch.unsqueeze(_get_transform()(img), dim=0)

    class_indict = _load_class_indices()
    model = _load_model(str(device))

    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(batch.to(device))).cpu()
        probs = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(probs))
    return "class: {} \n prob: {:.3}".format(class_indict[str(predict_cla)],
                                              probs[predict_cla].item())
import gradio as gr

# NOTE(review): these example images are defined but never passed to the
# Interface (there is no ``examples=`` argument), so they are not shown in
# the UI. Confirm the files exist before wiring them in.
examples = ['d.jpg', 'rose.jpg', 'rose2.jpg', 'images.jpg']

# NOTE(review): ``predict`` returns a plain string, so ``num_top_classes``
# has no visible effect; a ``{label: confidence}`` dict would be needed to
# show a ranked list — confirm intended output.
inter = gr.Interface(fn=predict,
                     inputs=gr.inputs.Image(type="pil"),
                     outputs=gr.outputs.Label(num_top_classes=5),
                     title='Five types of flower Detection',
                     description='This program can be used to detect five types of flowers: "daisy", "dandelion", "roses", "sunflowers", "tulips", and the program will give the classification results along with a confidence score.',
                     theme='huggingface')

# Launch only when run as a script, so importing this module (e.g. to reuse
# ``predict``) does not start the web server.
if __name__ == "__main__":
    inter.launch(inline=False, debug=True)