qmjnh commited on
Commit
5911ce1
1 Parent(s): 737a245

Create new file

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+
5
+ def softmax(vector):
6
+ e = np.exp(vector)
7
+ return e / e.sum()
8
+
9
+ def image_to_output (input_img):
10
+ gr_img = []
11
+ gr_img.append(input_img)
12
+ img2 = tf.image.resize(tf.cast(gr_img, tf.float32)/255. , [224, 224])
13
+
14
+ # print(img2)
15
+
16
+ x_test = np.asarray(img2)
17
+
18
+ prediction = model2.predict(x_test,batch_size=1).flatten()
19
+ prediction = softmax(prediction)
20
+
21
+ confidences = {labels[i]: float(prediction[i]) for i in range(102)}
22
+ # confidences = {labels[i]:float(top[i]) for i in range(num_predictions)}
23
+
24
+ return confidences
25
+
26
+ # Download the model checkpoint
27
+ import os
28
+ import requests
29
+ pretrained_repo = 'pretrained_model'
30
+ model_repo_link = 'https://huggingface.co/qmjnh/flowerClassification_2/resolve/main'
31
+ for item in [
32
+ 'variables.data-00000-of-00001',
33
+ 'variables.index',
34
+ 'keras_metadata.pb',
35
+ 'saved_model.pb',
36
+ ]:
37
+ params = requests.get(model_repo_link+item)
38
+ if item.startswith('variables'):
39
+ output_file = os.path.join(pretrained_repo, 'variables', item)
40
+ else:
41
+ output_file = os.path.join(pretrained_repo, item)
42
+ if not os.path.exists(os.path.dirname(output_file)):
43
+ os.makedirs(os.path.dirname(output_file))
44
+ with open(output_file, 'wb') as f:
45
+ print(f'Downloading from {model_repo_link+item} to {output_file}')
46
+ f.write(params.content)
47
+
48
+
49
+ # Load the model
50
+ model2=tf.keras.models.load_model(pretrained_repo)
51
+
52
+ # Read the labels
53
+ with open('flower_names.txt') as f:
54
+ labels = f.readlines()
55
+
56
+ # Run gradio
57
+ from gradio.components import Image as gradio_image
58
+ from gradio.components import Label as gradio_label
59
+ UI=gr.Interface(fn=image_to_output,
60
+ inputs=gradio_image(shape=(224,224)),
61
+ outputs=gradio_label(num_top_classes=5),
62
+ interpretation="default"
63
+ )
64
+ UI.launch()