AI-RESEARCHER-2024 commited on
Commit
5916110
·
verified ·
1 Parent(s): 5d764ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +315 -0
app.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import tensorflow as tf
4
+ from tensorflow.keras.preprocessing import image as image_processor
5
+ import numpy as np
6
+ from tensorflow.keras.applications.vgg16 import preprocess_input
7
+ from tensorflow.keras.models import load_model
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ from ultralytics import YOLO
10
+ import cv2
11
+ from huggingface_hub import from_pretrained_keras
12
+
13
+ class Config:
14
+ ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
15
+ MODELS_DIR = os.path.join(ASSETS_DIR, 'models')
16
+ FONT_DIR = os.path.join(ASSETS_DIR, 'arial.ttf')
17
+ MODELS = {
18
+ "Calculus and Caries Classification": "classification.h5",
19
+ "Caries Detection": "detection.pt",
20
+ "Dental X-Ray Segmentation": "dental_xray_seg.h5"
21
+ }
22
+ EXAMPLES = {
23
+ "Calculus and Caries Classification": os.path.join(ASSETS_DIR, 'classification'),
24
+ "Caries Detection": os.path.join(ASSETS_DIR, 'detection'),
25
+ "Dental X-Ray Segmentation": os.path.join(ASSETS_DIR, 'segmentation')
26
+ }
27
+
28
+ class ModelManager:
29
+ @staticmethod
30
+ def load_model(model_name: str):
31
+ model_path = os.path.join(Config.MODELS_DIR, Config.MODELS[model_name])
32
+ if model_name == "Dental X-Ray Segmentation":
33
+ try:
34
+ return from_pretrained_keras("SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net")
35
+ except:
36
+ return tf.keras.models.load_model(model_path)
37
+ elif model_name == "Caries Detection":
38
+ return YOLO(model_path)
39
+ else:
40
+ return load_model(model_path)
41
+
42
+
43
+ class ImageProcessor:
44
+
45
+ def process_image(self, image: Image.Image, model_name: str):
46
+ if model_name == "Calculus and Caries Classification":
47
+ return self.classify_image(image, model_name)
48
+ elif model_name == "Caries Detection":
49
+ return self.detect_caries(image)
50
+ elif model_name == "Dental X-Ray Segmentation":
51
+ return self.segment_dental_xray(image)
52
+
53
+ def classify_image(self, image: Image.Image, model_name: str):
54
+ model = ModelManager.load_model(model_name)
55
+ img = image.resize((224, 224))
56
+ x = image_processor.img_to_array(img)
57
+ x = np.expand_dims(x, axis=0)
58
+ img_data = preprocess_input(x)
59
+ result = model.predict(img_data)
60
+ if result[0][0] > result[0][1]:
61
+ prediction = 'Calculus'
62
+ else:
63
+ prediction = 'Caries'
64
+
65
+ # Draw the classification result on the image
66
+ draw = ImageDraw.Draw(image)
67
+ font = ImageFont.truetype(Config.FONT_DIR, 20)
68
+ text = f"Classified as: {prediction}"
69
+ text_width, text_height = draw.textsize(text, font=font)
70
+ draw.rectangle([(0, 0), (text_width, text_height)], fill="black")
71
+ draw.text((0, 0), text, fill="white", font=font)
72
+
73
+ return image
74
+
75
+ def detect_caries(self, image: Image.Image):
76
+ model = ModelManager.load_model("Caries Detection")
77
+ results = model.predict(image)
78
+ result = results[0]
79
+ draw = ImageDraw.Draw(image)
80
+ font = ImageFont.truetype(Config.FONT_DIR, 20)
81
+
82
+ for box in result.boxes:
83
+ x1, y1, x2, y2 = [round(x) for x in box.xyxy[0].tolist()]
84
+ class_id = box.cls[0].item()
85
+ prob = round(box.conf[0].item(), 2)
86
+ label = f"{result.names[class_id]}: {prob}"
87
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
88
+ text_width, text_height = draw.textsize(label, font=font)
89
+ draw.rectangle([(x1, y1 - text_height), (x1 + text_width, y1)], fill="red")
90
+ draw.text((x1, y1 - text_height), label, fill="white", font=font)
91
+
92
+ return image
93
+
94
+ def segment_dental_xray(self, image: Image.Image):
95
+ model = ModelManager.load_model("Dental X-Ray Segmentation")
96
+ img = np.asarray(image)
97
+ img_cv = self.convert_one_channel(img)
98
+ img_cv = cv2.resize(img_cv, (512, 512), interpolation=cv2.INTER_LANCZOS4)
99
+ img_cv = np.float32(img_cv / 255)
100
+ img_cv = np.reshape(img_cv, (1, 512, 512, 1))
101
+ prediction = model.predict(img_cv)
102
+ predicted = prediction[0]
103
+ predicted = cv2.resize(predicted, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LANCZOS4)
104
+ mask = np.uint8(predicted * 255)
105
+ _, mask = cv2.threshold(mask, thresh=0, maxval=255, type=cv2.THRESH_BINARY + cv2.THRESH_OTSU)
106
+ kernel = np.ones((5, 5), dtype=np.float32)
107
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
108
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)
109
+ cnts, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
110
+
111
+ # Make a writable copy of the image
112
+ img_writable = self.convert_rgb(img).copy()
113
+ output = cv2.drawContours(img_writable, cnts, -1, (255, 0, 0), 3)
114
+ return Image.fromarray(output)
115
+
116
+ def convert_one_channel(self, img):
117
+ if len(img.shape) > 2:
118
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
119
+ return img
120
+
121
+ def convert_rgb(self, img):
122
+ if len(img.shape) == 2:
123
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
124
+ return img
125
+
126
+
127
+ class GradioInterface:
128
+ def __init__(self):
129
+ self.image_processor = ImageProcessor()
130
+ self.preloaded_examples = self.preload_examples()
131
+
132
+ def preload_examples(self):
133
+ preloaded = {}
134
+ for model_name, example_dir in Config.EXAMPLES.items():
135
+ examples = [os.path.join(example_dir, img) for img in os.listdir(example_dir)]
136
+ preloaded[model_name] = examples
137
+ return preloaded
138
+
139
+ def create_interface(self):
140
+ app_styles = """
141
+ <style>
142
+ /* Global Styles */
143
+ body, #root {
144
+ font-family: Helvetica, Arial, sans-serif;
145
+ background-color: #1a1a1a;
146
+ color: #fafafa;
147
+ }
148
+ /* Header Styles */
149
+ .app-header {
150
+ background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
151
+ padding: 24px;
152
+ border-radius: 8px;
153
+ margin-bottom: 24px;
154
+ text-align: center;
155
+ }
156
+ .app-title {
157
+ font-size: 48px;
158
+ margin: 0;
159
+ color: #fafafa;
160
+ }
161
+ .app-subtitle {
162
+ font-size: 24px;
163
+ margin: 8px 0 16px;
164
+ color: #fafafa;
165
+ }
166
+ .app-description {
167
+ font-size: 16px;
168
+ line-height: 1.6;
169
+ opacity: 0.8;
170
+ margin-bottom: 24px;
171
+ }
172
+ /* Button Styles */
173
+ .publication-links {
174
+ display: flex;
175
+ justify-content: center;
176
+ flex-wrap: wrap;
177
+ gap: 8px;
178
+ margin-bottom: 16px;
179
+ }
180
+ .publication-link {
181
+ display: inline-flex;
182
+ align-items: center;
183
+ padding: 8px 16px;
184
+ background-color: #333;
185
+ color: #fff !important;
186
+ text-decoration: none !important;
187
+ border-radius: 20px;
188
+ font-size: 14px;
189
+ transition: background-color 0.3s;
190
+ }
191
+ .publication-link:hover {
192
+ background-color: #555;
193
+ }
194
+ .publication-link i {
195
+ margin-right: 8px;
196
+ }
197
+ /* Content Styles */
198
+ .content-container {
199
+ background-color: #2a2a2a;
200
+ border-radius: 8px;
201
+ padding: 24px;
202
+ margin-bottom: 24px;
203
+ }
204
+ /* Image Styles */
205
+ .image-preview img {
206
+ max-width: 512px;
207
+ max-height: 512px;
208
+ margin: 0 auto;
209
+ border-radius: 4px;
210
+ display: block;
211
+ object-fit: contain;
212
+ }
213
+ /* Control Styles */
214
+ .control-panel {
215
+ background-color: #333;
216
+ padding: 16px;
217
+ border-radius: 8px;
218
+ margin-top: 16px;
219
+ }
220
+ /* Gradio Component Overrides */
221
+ .gr-button {
222
+ background-color: #4a4a4a;
223
+ color: #fff;
224
+ border: none;
225
+ border-radius: 4px;
226
+ padding: 8px 16px;
227
+ cursor: pointer;
228
+ transition: background-color 0.3s;
229
+ }
230
+ .gr-button:hover {
231
+ background-color: #5a5a5a;
232
+ }
233
+ .gr-input, .gr-dropdown {
234
+ background-color: #3a3a3a;
235
+ color: #fff;
236
+ border: 1px solid #4a4a4a;
237
+ border-radius: 4px;
238
+ padding: 8px;
239
+ }
240
+ .gr-form {
241
+ background-color: transparent;
242
+ }
243
+ .gr-panel {
244
+ border: none;
245
+ background-color: transparent;
246
+ }
247
+ /* Override any conflicting styles from Bulma */
248
+ .button.is-normal.is-rounded.is-dark {
249
+ color: #fff !important;
250
+ text-decoration: none !important;
251
+ }
252
+ </style>
253
+ """
254
+
255
+ header_html = f"""
256
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
257
+ <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
258
+ {app_styles}
259
+ <div class="app-header">
260
+ <h1 class="app-title">AI in Dentistry</h1>
261
+ <h2 class="app-subtitle"> Advancing Imaging and Clinical Transcription</h2>
262
+ <p class="app-description">
263
+ This application demonstrates the use of AI in dentistry for tasks such as classification, detection, and segmentation.
264
+ </p>
265
+ </div>
266
+ """
267
+
268
+ def process_image(image, model_name):
269
+ result = self.image_processor.process_image(image, model_name)
270
+ return result
271
+
272
+ def update_examples(model_name):
273
+ examples = self.preloaded_examples[model_name]
274
+ return gr.Dataset(samples=[[example] for example in examples])
275
+
276
+ with gr.Blocks() as demo:
277
+ gr.HTML(header_html)
278
+ with gr.Row(elem_classes="content-container"):
279
+ with gr.Column():
280
+ input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
281
+ with gr.Row(elem_classes="control-panel"):
282
+ model_name = gr.Dropdown(
283
+ label="Model",
284
+ choices=list(Config.MODELS.keys()),
285
+ value="Calculus and Caries Classification",
286
+ )
287
+ examples = gr.Examples(
288
+ inputs=input_image,
289
+ examples=self.preloaded_examples["Calculus and Caries Classification"],
290
+ )
291
+ with gr.Column():
292
+ result = gr.Image(label="Result", elem_classes="image-preview")
293
+ run_button = gr.Button("Run", elem_classes="gr-button")
294
+
295
+ model_name.change(
296
+ fn=update_examples,
297
+ inputs=model_name,
298
+ outputs=examples.dataset,
299
+ )
300
+
301
+ run_button.click(
302
+ fn=process_image,
303
+ inputs=[input_image, model_name],
304
+ outputs=result,
305
+ )
306
+
307
+ return demo
308
+
309
+ def main():
310
+ interface = GradioInterface()
311
+ demo = interface.create_interface()
312
+ demo.launch(share=False)
313
+
314
+ if __name__ == "__main__":
315
+ main()