drkareemkamal commited on
Commit
5333123
1 Parent(s): a3fbc07

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from timeit import default_timer as timer
4
+ from typing import Tuple , Dict
5
+ import tensorflow as tf
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ import os
10
+
11
+ # 1.Import and class names setup
12
+ class_names = ['CNV','DME','DRUSEN','NORMAL']
13
+
14
+
15
+ # 2. Model annd transforms prepration
16
+ # model = tf.keras.models.load_model(
17
+ # 'oct_classification_final_model_lg.keras', custom_objects=None, compile=True, safe_mode=True
18
+ # )
19
+ model = tf.keras.models.load_model(
20
+ 'oct_classification_final_model_lg.keras', custom_objects=None, compile=True, safe_mode=False
21
+ )
22
+
23
+
24
+ # Load save weights
25
+
26
+ # 3.prediction function (predict())
27
+
28
+ def load_and_prep_imgg(img : Image.Image, img_shape=224, scale=True):
29
+ # if not isinstance(filename, str):
30
+ # raise ValueError("The filename must be a string representing the file path.")
31
+ # img = tf.io.read_file(filename)
32
+ # img = tf.io.decode_image(img, channels=3)
33
+ # img = tf.image.resize(img, size=[img_shape, img_shape])
34
+ # if scale:
35
+ # return img / 255
36
+ # else:
37
+ # return img
38
+ img = img.resize((img_shape, img_shape))
39
+ img = np.array(img)
40
+ if img.shape[-1] == 1: # If the image is grayscale
41
+ img = np.stack([img] * 3, axis=-1)
42
+ img = tf.convert_to_tensor(img, dtype=tf.float32)
43
+ if scale:
44
+ return img / 255.0
45
+ else:
46
+ return img
47
+
48
+ def predict(img) -> Tuple[Dict,float] :
49
+
50
+ start_time = timer()
51
+
52
+ image = load_and_prep_imgg(img)
53
+ #image = Image.open(image)
54
+
55
+ pred_img = model.predict(tf.expand_dims(image, axis=0))
56
+ pred_class = class_names[pred_img.argmax()]
57
+ print(f"Predicted macular diseases is: {pred_class} with probability: {pred_img.max():.2f}")
58
+
59
+
60
+
61
+
62
+ end_time = timer()
63
+ pred_time = round(end_time - start_time , 4)
64
+
65
+ return pred_class , pred_time
66
+
67
+ ### 4. Gradio app - our Gradio interface + launch command
68
+
69
+ title = 'Macular Disease Classification'
70
+ description = 'Feature Extraction VGG model to classify Macular Diseases by OCT'
71
+ article = 'Created with TensorFlow Model Deployment'
72
+ # Create example list
73
+
74
+ example_list = [['examples/'+ example] for example in os.listdir('examples')]
75
+ example_list
76
+
77
+ # create a gradio demo
78
+ demo = gr.Interface(fn=predict ,
79
+ inputs=gr.Image(type='pil'),
80
+ outputs=[gr.Label(num_top_classes = 3 , label= 'prediction'),
81
+ gr.Number(label= 'Prediction time (s)')],
82
+ examples = example_list,
83
+ title = title,
84
+ description = description,
85
+ article= article)
86
+
87
+ # Launch the demo
88
+ demo.launch(debug= False)