File size: 1,839 Bytes
3f3da6d
 
c7b6585
3f3da6d
3a3fa0c
 
3f3da6d
c7b6585
3f3da6d
5ecfb97
 
38629bc
5ecfb97
 
 
3f3da6d
 
 
 
5ecfb97
7ea38a3
5ecfb97
 
295ab92
5ecfb97
 
 
 
3a3fa0c
5ecfb97
 
 
 
295ab92
 
5ecfb97
 
 
 
3f3da6d
5ecfb97
3f3da6d
5ecfb97
 
 
3f3da6d
5ecfb97
 
3f3da6d
5ecfb97
3f3da6d
f447626
0044283
3f3da6d
 
 
e2167bf
3f3da6d
 
57c2b2d
6d0987e
f447626
0aa401d
d136b79
3f221c4
3f3da6d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import numpy as np
from PIL import Image

from tensorflow.keras import models 
from tensorflow.keras.preprocessing.image import load_img
import tensorflow as tf
from hugsvision.inference.TorchVisionClassifierInference import TorchVisionClassifierInference

# Models selectable in the UI; each is loaded from ./models/<name>.
# "DenseNet" is served by Keras, every other entry by HugsVision (see predict_image).
models_name = [
    "VGG16",
    "DenseNet121",
    "DenseNet",
]

# categories.txt holds the class names on its first line, separated by ";".
# A context manager guarantees the handle is closed, and rstrip removes the
# newline that readline() keeps -- otherwise the last label would end in "\n",
# breaking its displayed name and its "samples/<label>.jpg" example path.
with open("categories.txt", "r", encoding="utf-8") as categories:
    labels = categories.readline().rstrip("\r\n").split(";")

# Model selector shown next to the image input; its value is the model name.
radio = gr.inputs.Radio(models_name, default="DenseNet121", type="value")

def _predict_keras(image, model_name):
    """Score *image* with the Keras model stored at ./models/<model_name>/model.h5.

    Returns a dict mapping each entry of the module-level ``labels`` list to a
    float confidence rounded to two decimals.
    """
    # Scale pixel values by 1/255 (assumes 8-bit input -- TODO confirm) and
    # add a leading batch axis of size 1.
    batch = np.expand_dims(np.array(image) / 255, axis=0)

    model = models.load_model("./models/" + model_name + "/model.h5")
    scores = model.predict(batch)[0]

    # gr.outputs.Label expects numeric confidences: keep the previous 2-decimal
    # precision but return floats instead of "%.2f" strings, matching the
    # HugsVision branch.
    return {label: round(float(scores[i]), 2) for i, label in enumerate(labels)}


def _predict_hugsvision(image, model_name):
    """Score *image* with the HugsVision TorchVision classifier in ./models/<model_name>.

    Returns the classifier's label -> score dict with every score divided by 100.
    """
    pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
    classifier = TorchVisionClassifierInference(
        model_path="./models/" + model_name
    )
    pred = classifier.predict_image(img=pil_image, return_str=False)

    # Scores are divided by 100, presumably percentages rescaled to [0, 1]
    # for gr.Label -- confirm against HugsVision's predict_image contract.
    return {key: value / 100 for key, value in pred.items()}


def predict_image(image, model_name):
    """Classify *image* with the model chosen in the UI.

    Args:
        image: the image array delivered by the Gradio image input.
        model_name: one of ``models_name``; "DenseNet" selects the Keras model,
            any other value a HugsVision model under ./models/<model_name>.

    Returns:
        dict mapping class label -> float confidence, as consumed by
        ``gr.outputs.Label``.
    """
    if model_name == "DenseNet":
        pred = _predict_keras(image, model_name)
    else:
        pred = _predict_hugsvision(image, model_name)

    print(pred)
    return pred

# Input component: the uploaded image, configured with shape=(300, 300).
image = gr.inputs.Image(shape=(300, 300), label="Upload Your Image Here")
# Output component: report a confidence for every known category.
label = gr.outputs.Label(num_top_classes=len(labels))

# Example gallery: one single-input row per category, "samples/<category>.jpg".
samples = [[f"samples/{category}.jpg"] for category in labels]

# Wire the classifier to the UI: image + model selector in, label scores out.
interface = gr.Interface(
    title="🦈 Shark image classifier",
    description="Made with HugsVision & ❤️",
    fn=predict_image,
    inputs=[image, radio],
    outputs=label,
    examples=samples,
    capture_session=True,
    allow_flagging=False,
    theme=None,
)
interface.launch()