File size: 1,781 Bytes
3f3da6d
 
c7b6585
3f3da6d
3a3fa0c
 
3f3da6d
c7b6585
3f3da6d
5ecfb97
 
38629bc
5ecfb97
 
 
3f3da6d
 
 
 
5ecfb97
 
 
 
 
3f3da6d
5ecfb97
 
 
 
 
 
 
3f3da6d
5ecfb97
 
 
 
3a3fa0c
5ecfb97
 
 
 
 
 
 
 
 
3f3da6d
5ecfb97
3f3da6d
5ecfb97
 
 
3f3da6d
5ecfb97
 
3f3da6d
5ecfb97
3f3da6d
 
 
 
e2167bf
3f3da6d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import numpy as np
from PIL import Image

from tensorflow.keras import models 
from tensorflow.keras.preprocessing.image import load_img
import tensorflow as tf
from hugsvision.inference.TorchVisionClassifierInference import TorchVisionClassifierInference

# Model choices offered in the UI. predict_image serves "DenseNet" from a
# Keras .h5 file and loads every other name through hugsvision's
# TorchVisionClassifierInference.
models_name = [
    "VGG16",
    "DenseNet121",
    "DenseNet",
]

# categories.txt holds the class labels on its first line, separated by ";".
# The context manager closes the file (the original leaked the handle).
# rstrip drops the line terminator so the last label is not "name\n".
with open("categories.txt", "r", encoding="utf-8") as categories:
    labels = categories.readline().rstrip("\r\n").split(";")

# Radio selector for the model; defaults to the Keras DenseNet.
radio = gr.inputs.Radio(models_name, default="DenseNet", type="value")

def _predict_keras(image, model_name):
    """Classify ``image`` with the Keras model at ./models/<model_name>/model.h5.

    Returns a dict mapping each entry of the module-level ``labels`` list to
    a float score taken from the model's first (only) batch row.
    """
    # Scale pixel values by 1/255 and add a leading batch axis of size 1.
    batch = np.expand_dims(np.array(image) / 255, axis=0)

    # NOTE(review): the model is reloaded from disk on every request. Caching
    # it would be much faster, but verify it against the TF/Gradio session
    # handling (capture_session=True) before changing.
    model = models.load_model("./models/" + model_name + "/model.h5")
    scores = model.predict(batch)[0]

    # Floats (not "%.2f" strings) so this branch returns the same value type
    # as the TorchVision branch. Indexing by label position keeps the
    # original IndexError if the model emits fewer scores than labels.
    return {label: float(scores[i]) for i, label in enumerate(labels)}


def _predict_torchvision(image, model_name):
    """Classify ``image`` with the hugsvision checkpoint in ./models/<model_name>.

    Returns a dict mapping label -> score divided by 100.
    """
    pil_image = Image.fromarray(np.uint8(image)).convert("RGB")

    # NOTE(review): same per-request reload as the Keras path.
    classifier = TorchVisionClassifierInference(
        model_path="./models/" + model_name
    )
    scores = classifier.predict_image(img=pil_image, return_str=False)

    # Presumably hugsvision reports percentages (0-100); dividing by 100
    # puts them on the 0-1 scale of the Keras branch. TODO confirm.
    return {key: value / 100 for key, value in scores.items()}


def predict_image(image, model_name):
    """Classify ``image`` with the model selected in the UI.

    Args:
        image: the picture delivered by the ``gr.inputs.Image`` component
            (assumed to be an HxWx3 pixel array — TODO confirm).
        model_name: one of ``models_name``. "DenseNet" selects the Keras
            model; any other name is loaded with hugsvision.

    Returns:
        dict mapping label -> float confidence, for ``gr.outputs.Label``.
    """
    import logging  # local import keeps this block self-contained

    logger = logging.getLogger(__name__)
    # Logged at DEBUG instead of printed: the original print()s dumped the
    # full pixel array to stdout on every request.
    logger.debug("predict_image: image type=%s model_name=%r", type(image), model_name)

    if model_name == "DenseNet":
        pred = _predict_keras(image, model_name)
    else:
        pred = _predict_torchvision(image, model_name)

    logger.debug("predict_image: prediction=%s", pred)
    return pred

# Upload widget. shape=(300, 300) is passed to Gradio, which presumably
# resizes the picture before it reaches predict_image (TODO confirm that
# this matches the models' expected input size).
image = gr.inputs.Image(label="Upload Your Image Here", shape=(300, 300))

# Result widget: ask for as many top classes as there are known labels,
# so every category's confidence is shown.
label = gr.outputs.Label(num_top_classes=len(labels))

# Wire the image + model selector into predict_image and start the server.
interface = gr.Interface(
    fn=predict_image,
    inputs=[image, radio],
    outputs=label,
    allow_flagging=False,
    capture_session=True,
)
interface.launch()