# File size: 1,893 Bytes
# 5911ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
import numpy as np
import tensorflow as tf

def softmax(vector):
    """Return the softmax of *vector* as a NumPy array summing to 1.

    The maximum element is subtracted before exponentiating. Softmax is
    invariant to adding a constant to every input, so the result is
    unchanged mathematically. This stops ``np.exp`` from overflowing to
    ``inf`` on large logits, which would otherwise produce NaNs.

    Args:
        vector: array-like of scores; normalisation is over all elements.

    Returns:
        NumPy array of the same shape with non-negative entries summing to 1
        (an empty array for empty input).
    """
    values = np.asarray(vector)
    if values.size == 0:
        # np.max would raise on empty input; the original returned an empty array.
        return values.astype(float)
    exps = np.exp(values - values.max())
    return exps / exps.sum()

def image_to_output(input_img):
    """Classify a single image and return a label -> confidence mapping.

    Used as the Gradio ``fn``. It reads the module-level ``model2`` and
    ``labels`` globals.

    Args:
        input_img: image supplied by the Gradio ``Image`` input. Presumably
            an (H, W, 3) uint8 array (TODO confirm). It may be ``None`` when
            the user submits without an image.

    Returns:
        dict mapping each label to its softmax score, or an empty dict when
        no image was supplied.
    """
    if input_img is None:
        return {}

    # Add a batch axis, scale pixels to [0, 1], then resize to 224x224.
    batch = tf.cast(tf.expand_dims(input_img, 0), tf.float32) / 255.0
    batch = tf.image.resize(batch, [224, 224])
    x_test = np.asarray(batch)

    logits = model2.predict(x_test, batch_size=1).flatten()
    # NOTE(review): if the model's last layer already applies softmax, this
    # second softmax flattens the distribution. Confirm against the model.
    prediction = softmax(logits)

    # zip pairs labels with scores without hard-coding the number of classes.
    return {label: float(score) for label, score in zip(labels, prediction)}

# Download the model checkpoint files from the Hugging Face repo into a local
# SavedModel directory layout (variables/* go in a "variables" subfolder).
import os
import requests
pretrained_repo = 'pretrained_model'
model_repo_link = 'https://huggingface.co/qmjnh/flowerClassification_2/resolve/main'
for item in [
        'variables.data-00000-of-00001',
        'variables.index',
        'keras_metadata.pb',
        'saved_model.pb',
    ]:
    # The resolve URL needs a '/' between the revision and the file path.
    # NOTE(review): if the repo stores the variables files under a
    # "variables/" folder, the remote path must include it too. Confirm.
    url = f'{model_repo_link}/{item}'
    if item.startswith('variables'):
        output_file = os.path.join(pretrained_repo, 'variables', item)
    else:
        output_file = os.path.join(pretrained_repo, item)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    print(f'Downloading from {url} to {output_file}')
    response = requests.get(url, timeout=60)
    # Fail loudly rather than saving an HTML error page as a model file.
    response.raise_for_status()
    with open(output_file, 'wb') as f:
        f.write(response.content)

# Load the downloaded SavedModel directory.
model2 = tf.keras.models.load_model(pretrained_repo)

# Read the class labels, one per line. Strip each line so the trailing newline
# does not end up in the names shown in the UI; keeping one entry per line
# preserves index alignment with the model outputs.
with open('flower_names.txt', encoding='utf-8') as f:
    labels = [line.strip() for line in f]

# Build and launch the Gradio web UI around image_to_output.
from gradio.components import Image as gradio_image
from gradio.components import Label as gradio_label

# Input: an Image component with shape=(224, 224). Output: a Label showing the top 5 classes.
image_input = gradio_image(shape=(224, 224))
label_output = gradio_label(num_top_classes=5)
UI = gr.Interface(
    fn=image_to_output,
    inputs=image_input,
    outputs=label_output,
    interpretation="default",
)
UI.launch()