File size: 2,019 Bytes
e2c3f43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488377e
e2c3f43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0c8de1
15f2b4c
e2c3f43
 
 
27b0208
63b55ce
e2c3f43
 
63b55ce
4bbf391
63b55ce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import functools
import json
import os

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import efficientnetv2_m as create_model



@functools.lru_cache(maxsize=4)
def _load_class_indices(json_path):
    """Read the class-index JSON once per path; the file handle is closed on exit."""
    with open(json_path, "r") as json_file:
        return json.load(json_file)


@functools.lru_cache(maxsize=4)
def _load_model(model_weight_path, num_classes, device):
    """Build the model, load its weights onto `device` once, and cache it in eval mode."""
    model = create_model(num_classes=num_classes).to(device)
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    return model


def predict(img):
    """Classify a flower image with the EfficientNetV2-M checkpoint.

    The image is resized and center-cropped to the "m" validation size (480),
    normalized with mean/std 0.5 per channel, and run through the model built
    by ``create_model(num_classes=5)`` with weights from
    ``./weights/model-20.pth``. Class names come from ``./class_indices.json``,
    keyed by the stringified class index.

    The JSON mapping and the model are loaded once and reused across calls,
    instead of being re-read from disk on every request.

    Args:
        img: a PIL image. It is converted to RGB first, so grayscale or RGBA
            uploads also match the 3-channel normalization.

    Returns:
        A string of the form ``"class: <name>    \\n prob: <p>"``, where ``<p>``
        is the softmax probability of the top class to 3 significant digits.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = {"s": [300, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "m"  # must match the imported efficientnetv2_m architecture

    data_transform = transforms.Compose(
        [transforms.Resize(img_size[num_model][1]),
         transforms.CenterCrop(img_size[num_model][1]),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    # Without convert("RGB"), a grayscale/RGBA image yields a 1- or 4-channel
    # tensor that the 3-channel Normalize above rejects.
    img = data_transform(img.convert("RGB"))
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    class_indict = _load_class_indices('./class_indices.json')
    model = _load_model("./weights/model-20.pth", 5, device)

    with torch.no_grad():
        # squeeze drops the batch dimension of size 1, leaving per-class logits
        output = torch.squeeze(model(img.to(device))).cpu()
        probs = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(probs))
    return "class: {}    \n prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 probs[predict_cla].item())

import gradio as gr

# Sample images offered below the UI. gr.Interface expects one inner list per
# example (one entry per input component), so each path is wrapped in a list.
# NOTE(review): assumes these image files sit next to this script — confirm
# before deploying, since missing example files can break startup.
examples = ['d.jpg', 'rose.jpg', 'rose2.jpg', 'images.jpg']
inter = gr.Interface(fn=predict,
                     inputs=gr.inputs.Image(type="pil"),
                     outputs=gr.outputs.Label(num_top_classes=5),
                     examples=[[path] for path in examples],
                     title='Five types of flower Detection',
                     description='This program can be used to detect five types of flowers: "daisy", "dandelion", "roses", "sunflowers", "tulips", and the program will give the classification results along with a confidence score.',
                     theme='huggingface')
inter.launch(inline=False, debug=True)