Blazer007 commited on
Commit
7dbf08d
1 Parent(s): 49c126a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import gradio as gr
4
+ from huggingface_hub import from_pretrained_keras
5
+
6
+ teacher_model = from_pretrained_keras("keras-io/consistency_training_with_supervision_teacher_model")
7
+
8
+ student_model = from_pretrained_keras("keras-io/consistency_training_with_supervision_student_model")
9
+
10
+ class_names = [
11
+ "Airplane",
12
+ "Automobile",
13
+ "Bird",
14
+ "Cat",
15
+ "Deer",
16
+ "Dog",
17
+ "Frog",
18
+ "Horse",
19
+ "Ship",
20
+ "Truck",
21
+ ]
22
+
23
+ examples = [
24
+ ['./aeroplane.png'],
25
+ ['./horse.png'],
26
+ ['./ship.png'],
27
+ ['./truck.png']
28
+ ]
29
+
30
+ IMG_SIZE = 72
31
+
32
+ def teacher_model_output(image_tensor):
33
+ predictions = teacher_model.predict(np.expand_dims((image_tensor), axis=0))
34
+ predictions = np.squeeze(predictions)
35
+ predictions = np.argmax(predictions)
36
+ predicted_label = class_names[predictions.item()]
37
+ return str(predicted_label)
38
+
39
+ def student_model_output(image_tensor):
40
+ predictions = student_model.predict(np.expand_dims((image_tensor), axis=0))
41
+ predictions = np.squeeze(predictions)
42
+ predictions = np.argmax(predictions)
43
+ predicted_label = class_names[predictions.item()]
44
+ return str(predicted_label)
45
+
46
+ def infer(input_image):
47
+ image_tensor = tf.convert_to_tensor(input_image)
48
+ image_tensor.set_shape([None, None, 3])
49
+ image_tensor = tf.image.resize(image_tensor, (IMG_SIZE, IMG_SIZE))
50
+ return teacher_model_output(image_tensor), student_model_output(image_tensor)
51
+
52
+ input = gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE))
53
+ output = [gr.outputs.Label(label = "Teacher Model Output"), gr.outputs.Label(label = "Student Model Output")]
54
+
55
+ title = "Image Classification using Consistency training with supervision"
56
+ description = "Upload an image or select from examples to classify it.<br>The allowed classes are - Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck.<br><p><b>Teacher Model Repo - https://huggingface.co/keras-io/consistency_training_with_supervision_teacher_model</b> <br><b> Student Model Repo - https://huggingface.co/keras-io/consistency_training_with_supervision_student_model </b><br><b>Keras Example - https://keras.io/examples/vision/consistency_training/</b></p>"
57
+
58
+
59
+ article = "<div style='text-align: center;'><a href='https://twitter.com/_Blazer_007' target='_blank'>Space by Vivek Rai</a><br><a href='https://twitter.com/RisingSayak' target='_blank'>Keras example by Sayak Paul</a></div>"
60
+
61
+ gr_interface = gr.Interface(
62
+ infer,
63
+ input,
64
+ output,
65
+ examples=examples,
66
+ allow_flagging=False,
67
+ analytics_enabled=False,
68
+ title=title,
69
+ description=description,
70
+ article=article).launch(enable_queue=True, debug=True)