File size: 3,680 Bytes
5911ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc4c6d5
5911ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24120e7
 
ee18541
24120e7
 
 
 
 
 
 
 
 
 
83d0585
 
6ec9c73
 
 
 
 
 
 
 
 
 
 
 
17ff198
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import numpy as np
import tensorflow as tf

def softmax(vector):
    """Return the softmax of a score vector: non-negative values summing to 1.

    The maximum score is subtracted before exponentiating. Softmax is invariant
    to adding a constant to every score, so the result is mathematically
    unchanged, but np.exp can no longer overflow to inf (which would turn the
    output into NaN) for large scores. Multi-dimensional input is normalized
    over all of its elements, as before.

    Args:
        vector: array-like of scores.

    Returns:
        numpy array of the same shape holding the normalized probabilities
        (an empty array for empty input).
    """
    scores = np.asarray(vector)
    if scores.size == 0:
        # max() of an empty array raises; there is nothing to normalize.
        return np.exp(scores)
    e = np.exp(scores - scores.max())
    return e / e.sum()

def image_to_output(input_img):
    """Classify one image and return per-class confidences for the gradio Label.

    Args:
        input_img: the image delivered by the gradio Image input (presumably an
            HxWxC uint8 numpy array — TODO confirm against the gradio version).

    Returns:
        dict mapping each flower label (surrounding whitespace stripped) to its
        softmax score as a Python float, one entry per model output.

    Raises:
        IndexError: if the module-level ``labels`` list has fewer entries than
            the model has outputs.
    """
    # Add a batch dimension of 1, scale pixel values by 1/255 and resize to
    # 224x224 before feeding the model.
    batch = tf.expand_dims(tf.cast(input_img, tf.float32), axis=0) / 255.
    x_test = np.asarray(tf.image.resize(batch, [224, 224]))

    prediction = model2.predict(x_test, batch_size=1).flatten()
    # NOTE(review): softmax is applied to the raw model output; if the saved
    # model's last layer already applies softmax this double-normalizes and
    # flattens the scores — confirm against the model definition.
    prediction = softmax(prediction)

    # Size the mapping from the model's actual output instead of a hard-coded
    # 102, and strip() the newline readlines() leaves on each label.
    return {labels[i].strip(): float(score) for i, score in enumerate(prediction)}

# Download the model checkpoint
import os
import requests
pretrained_repo = 'pretrained_model'
model_repo_link = 'https://huggingface.co/qmjnh/flowerClassification_2/resolve/main/'
for item in (
    'variables.data-00000-of-00001',
    'variables.index',
    'keras_metadata.pb',
    'saved_model.pb',
):
    url = model_repo_link + item
    # Files named variables.* go into a "variables" sub-directory; the rest
    # sit at the top level of pretrained_repo.
    if item.startswith('variables'):
        output_file = os.path.join(pretrained_repo, 'variables', item)
    else:
        output_file = os.path.join(pretrained_repo, item)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    print(f'Downloading from {url} to {output_file}')
    # raise_for_status(): otherwise a 404/5xx error page would be saved as if
    # it were model data and only fail later, inside load_model.
    # stream=True + iter_content: write the (possibly large) weights file in
    # chunks instead of holding the whole body in memory.
    # timeout: don't hang forever on a stalled connection.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(output_file, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    

# Load the model from the directory populated by the download loop.
model2 = tf.keras.models.load_model(pretrained_repo)

# Read the labels, one flower name per line. image_to_output pairs labels[i]
# with model output i, so the line order must match the model's class order.
# strip() removes the trailing newline that readlines() used to keep, and the
# explicit encoding avoids depending on the platform's locale default.
with open('flower_names.txt', encoding='utf-8') as f:
    labels = [line.strip() for line in f]

# Run gradio
from gradio.components import Image as gradio_image
from gradio.components import Label as gradio_label
# Link to the list of flower classes the model was trained on; shown in both
# the description and the article.
FLOWER_LIST_URL = "https://huggingface.co/spaces/qmjnh/flowerClassification_2/blob/main/flower_names.txt"

description = (
    "This model was trained to recognize 102 types of flowers. For the model to work "
    f"with high-accuracy, refer to the trained flowers [here]({FLOWER_LIST_URL})"
)
article1 = (
    "This is an AI model trained to predict the name of the flower in the input picture. "
    "To try out the model, simply drop/upload a picture into the '*input box*' and press 'Submit'. "
    "The predictions will show up in the '*output box*'.\n"
    "Since the model was only trained to classify 102 types of flowers "
    f"(flowers list can be found [here]({FLOWER_LIST_URL}) ), "
    "the prediction might be incorrect, but chances are if you try googling the names predicted by the model, "
    "the resulting flower will be very similar to the one in your picture. "
)
article2 = "\n *built by qmjnh*"

# Build the one Interface that is actually launched: 224x224 image in, top-5
# class confidences out.
# NOTE(review): interpretation="none" is a string, not None; depending on the
# gradio 3.x version this may still render an (inactive) Interpret button —
# confirm whether None was intended.
UI = gr.Interface(
    fn=image_to_output,
    inputs=gradio_image(shape=(224, 224)),
    outputs=gradio_label(num_top_classes=5),
    interpretation="none",
    description=description,
    title="Flower Classifier",
    article=article1 + article2,
)

# NOTE(review): share=True asks gradio for a public share link; confirm that
# is intended wherever this app is deployed.
UI.launch(share=True)