|
import os |
|
import gradio as gr |
|
import tensorflow as tf |
|
from tensorflow.keras.preprocessing import image as image_processor |
|
import numpy as np |
|
from tensorflow.keras.applications.vgg16 import preprocess_input |
|
from tensorflow.keras.models import load_model |
|
from PIL import Image, ImageDraw, ImageFont |
|
from ultralytics import YOLO |
|
import cv2 |
|
from huggingface_hub import from_pretrained_keras |
|
|
|
class Config:
    """Static application settings: where assets live and which files back each task."""

    # All bundled assets sit in an ``assets`` folder next to this source file.
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    MODELS_DIR = os.path.join(ASSETS_DIR, "models")
    # NOTE: despite the name, this is the path to a font *file*, not a directory.
    FONT_DIR = os.path.join(ASSETS_DIR, "arial.ttf")

    # Task name (as shown in the UI dropdown) -> weights file inside MODELS_DIR.
    MODELS = {
        "Calculus and Caries Classification": "classification.h5",
        "Caries Detection": "detection.pt",
        "Dental X-Ray Segmentation": "dental_xray_seg.h5",
    }

    # Task name -> folder of example images offered for that task.
    EXAMPLES = {
        "Calculus and Caries Classification": os.path.join(ASSETS_DIR, "classification"),
        "Caries Detection": os.path.join(ASSETS_DIR, "detection"),
        "Dental X-Ray Segmentation": os.path.join(ASSETS_DIR, "segmentation"),
    }
|
|
|
class ModelManager:
    """Load, and keep in memory, the model backing each task offered by the app."""

    # Loaded models keyed by task name. Building a Keras/YOLO model from disk
    # (or downloading it from the Hub) is slow, so each model is loaded once
    # and reused for every later request instead of on every click.
    _cache = {}

    @staticmethod
    def load_model(model_name: str):
        """Return the model for ``model_name``, loading it on first use.

        Args:
            model_name: One of the keys of ``Config.MODELS``.

        Returns:
            A Keras model for the classification and segmentation tasks, or an
            ``ultralytics.YOLO`` instance for "Caries Detection".

        Raises:
            KeyError: if ``model_name`` is not a key of ``Config.MODELS``.
        """
        cached = ModelManager._cache.get(model_name)
        if cached is not None:
            return cached
        model = ModelManager._load_uncached(model_name)
        ModelManager._cache[model_name] = model
        return model

    @staticmethod
    def _load_uncached(model_name: str):
        """Build a fresh model instance for ``model_name`` without consulting the cache."""
        # Looked up first so an unknown task name raises KeyError for every branch.
        model_path = os.path.join(Config.MODELS_DIR, Config.MODELS[model_name])

        if model_name == "Dental X-Ray Segmentation":
            try:
                return from_pretrained_keras(
                    "SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net"
                )
            except Exception:
                # Hub download failed (offline, rate limit, ...): fall back to the
                # bundled weights. Narrowed from a bare ``except:`` so Ctrl-C and
                # interpreter shutdown are no longer swallowed.
                return tf.keras.models.load_model(model_path)

        if model_name == "Caries Detection":
            return YOLO(model_path)

        # Inside a method body the bare name resolves to the module-level
        # ``tensorflow.keras.models.load_model`` import, not to this class's method.
        return load_model(model_path)
|
|
|
|
|
class ImageProcessor:
    """Run one of the dental models on a PIL image and draw its result on a copy."""

    def process_image(self, image: Image.Image, model_name: str):
        """Route ``image`` to the handler for the task named ``model_name``.

        Args:
            image: Input image as delivered by the Gradio ``Image`` component.
            model_name: One of the task names in ``Config.MODELS``.

        Returns:
            A new PIL image with the task's result drawn on it.

        Raises:
            ValueError: if ``image`` is None or ``model_name`` is not a known
                task (an unknown task previously returned None silently).
        """
        if image is None:
            raise ValueError("No input image was provided.")
        if model_name == "Calculus and Caries Classification":
            return self.classify_image(image, model_name)
        if model_name == "Caries Detection":
            return self.detect_caries(image)
        if model_name == "Dental X-Ray Segmentation":
            return self.segment_dental_xray(image)
        raise ValueError(f"Unknown model: {model_name!r}")

    def classify_image(self, image: Image.Image, model_name: str):
        """Label ``image`` as 'Calculus' or 'Caries' and stamp the label top-left.

        The decision compares columns 0 and 1 of the model's first output row:
        column 0 winning means 'Calculus', otherwise 'Caries'.
        """
        model = ModelManager.load_model(model_name)

        # VGG16 preprocessing expects 3 colour channels; converting first keeps
        # grayscale ("L") or RGBA uploads from crashing. For RGB input this is
        # just a copy, so the result is unchanged.
        resized = image.convert("RGB").resize((224, 224))
        batch = np.expand_dims(image_processor.img_to_array(resized), axis=0)
        scores = model.predict(preprocess_input(batch))
        prediction = "Calculus" if scores[0][0] > scores[0][1] else "Caries"

        # Draw on an RGB copy so the caller's image is left untouched.
        annotated = image.convert("RGB")
        draw = ImageDraw.Draw(annotated)
        font = self._load_font()
        text = f"Classified as: {prediction}"
        text_width, text_height = self._text_size(draw, text, font)
        draw.rectangle([(0, 0), (text_width, text_height)], fill="black")
        draw.text((0, 0), text, fill="white", font=font)
        return annotated

    def detect_caries(self, image: Image.Image):
        """Draw a red box and a "<class>: <confidence>" label for each YOLO detection."""
        model = ModelManager.load_model("Caries Detection")
        result = model.predict(image)[0]

        annotated = image.convert("RGB")
        draw = ImageDraw.Draw(annotated)
        font = self._load_font()

        for box in result.boxes:
            x1, y1, x2, y2 = (round(v) for v in box.xyxy[0].tolist())
            # ``cls`` holds the class index as a float; ``names`` is keyed by int.
            class_id = int(box.cls[0].item())
            confidence = round(box.conf[0].item(), 2)
            label = f"{result.names[class_id]}: {confidence}"

            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            text_width, text_height = self._text_size(draw, label, font)
            # Put the label just above the box, but never above the image's top
            # edge (it would be cut off for boxes touching the border).
            label_top = max(y1 - text_height, 0)
            draw.rectangle(
                [(x1, label_top), (x1 + text_width, label_top + text_height)], fill="red"
            )
            draw.text((x1, label_top), label, fill="white", font=font)

        return annotated

    def segment_dental_xray(self, image: Image.Image):
        """Outline the segmented regions of a panoramic X-ray in red.

        The input is reduced to one channel, resized to 512x512 and scaled to
        [0, 1] for the U-Net. The predicted map is resized back to the input
        size, binarised with Otsu's threshold, cleaned with a 5x5 opening and
        closing, and its contours are drawn on an RGB copy of the input.
        """
        model = ModelManager.load_model("Dental X-Ray Segmentation")
        img = np.asarray(image)
        height, width = img.shape[:2]

        gray = self.convert_one_channel(img)
        gray = cv2.resize(gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
        batch = np.float32(gray / 255).reshape(1, 512, 512, 1)

        predicted = model.predict(batch)[0]
        predicted = cv2.resize(predicted, (width, height), interpolation=cv2.INTER_LANCZOS4)
        # Lanczos resampling overshoots slightly outside [0, 1]. Casting such
        # floats straight to uint8 is out of range and can wrap around (a
        # confident 1.01 becomes ~1), punching holes in the mask — clip first.
        mask = np.uint8(np.clip(predicted, 0.0, 1.0) * 255)
        _, mask = cv2.threshold(mask, thresh=0, maxval=255, type=cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        kernel = np.ones((5, 5), dtype=np.float32)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)

        # ``[-2]`` selects the contour list under both the OpenCV 3 (3-tuple)
        # and OpenCV 4 (2-tuple) return conventions.
        contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

        # ``np.asarray`` of a PIL image may be read-only, so draw on a copy;
        # drawContours mutates its input in place.
        canvas = self.convert_rgb(img).copy()
        cv2.drawContours(canvas, contours, -1, (255, 0, 0), 3)
        return Image.fromarray(canvas)

    def convert_one_channel(self, img):
        """Return a 2-D grayscale version of ``img``.

        Arrays here come from PIL, so colour input is RGB/RGBA (the previous
        ``COLOR_BGR2GRAY`` swapped the red and blue weights). 2-D input is
        returned unchanged; ``(H, W, 1)`` input is squeezed.
        """
        if img.ndim == 2:
            return img
        channels = img.shape[2]
        if channels == 1:
            return img[:, :, 0]
        if channels == 4:
            return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    def convert_rgb(self, img):
        """Return a 3-channel RGB version of ``img`` (grayscale and RGBA are converted)."""
        if img.ndim == 2 or img.shape[2] == 1:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        if img.shape[2] == 4:
            # Drawing a 3-component colour on RGBA would leave alpha at 0, so
            # the contours would be invisible; drop the alpha channel instead.
            return cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
        return img

    @staticmethod
    def _load_font():
        """Load the bundled label font, falling back to Pillow's default font."""
        try:
            return ImageFont.truetype(Config.FONT_DIR, 20)
        except OSError:
            # Missing/unreadable font file: still render labels instead of failing.
            return ImageFont.load_default()

    @staticmethod
    def _text_size(draw, text, font):
        """Return the (width, height) extent of ``text`` when drawn at the origin."""
        if hasattr(draw, "textbbox"):
            # ``ImageDraw.textsize`` was removed in Pillow 10; ``textbbox``
            # (Pillow >= 8) is its replacement.
            _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            return right, bottom
        return draw.textsize(text, font=font)
|
|
|
|
|
class GradioInterface:
    """Build the Gradio Blocks UI around :class:`ImageProcessor`."""

    # File extensions accepted as example images; anything else found in an
    # example folder (e.g. ``.DS_Store``, README files) is ignored.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp")

    # Task selected in the dropdown (and whose examples are shown) at start-up.
    DEFAULT_MODEL = "Calculus and Caries Classification"

    def __init__(self):
        self.image_processor = ImageProcessor()
        self.preloaded_examples = self.preload_examples()

    def preload_examples(self, example_dirs=None):
        """Collect example image paths for every task.

        Args:
            example_dirs: Mapping of task name -> folder of example images.
                Defaults to ``Config.EXAMPLES``.

        Returns:
            Dict of task name -> list of image file paths, sorted by file name
            so the order is stable across platforms (``os.listdir`` order is
            arbitrary). A missing folder yields an empty list instead of
            aborting start-up.
        """
        if example_dirs is None:
            example_dirs = Config.EXAMPLES

        preloaded = {}
        for model_name, example_dir in example_dirs.items():
            if not os.path.isdir(example_dir):
                preloaded[model_name] = []
                continue
            preloaded[model_name] = [
                os.path.join(example_dir, name)
                for name in sorted(os.listdir(example_dir))
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
                and os.path.isfile(os.path.join(example_dir, name))
            ]
        return preloaded

    def create_interface(self):
        """Assemble and return the ``gr.Blocks`` app (not yet launched)."""
        app_styles = """
<style>
/* Global Styles */
body, #root {
    font-family: Helvetica, Arial, sans-serif;
    background-color: #1a1a1a;
    color: #fafafa;
}
/* Header Styles */
.app-header {
    background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
    padding: 24px;
    border-radius: 8px;
    margin-bottom: 24px;
    text-align: center;
}
.app-title {
    font-size: 48px;
    margin: 0;
    color: #fafafa;
}
.app-subtitle {
    font-size: 24px;
    margin: 8px 0 16px;
    color: #fafafa;
}
.app-description {
    font-size: 16px;
    line-height: 1.6;
    opacity: 0.8;
    margin-bottom: 24px;
}
/* Button Styles */
.publication-links {
    display: flex;
    justify-content: center;
    flex-wrap: wrap;
    gap: 8px;
    margin-bottom: 16px;
}
.publication-link {
    display: inline-flex;
    align-items: center;
    padding: 8px 16px;
    background-color: #333;
    color: #fff !important;
    text-decoration: none !important;
    border-radius: 20px;
    font-size: 14px;
    transition: background-color 0.3s;
}
.publication-link:hover {
    background-color: #555;
}
.publication-link i {
    margin-right: 8px;
}
/* Content Styles */
.content-container {
    background-color: #2a2a2a;
    border-radius: 8px;
    padding: 24px;
    margin-bottom: 24px;
}
/* Image Styles */
.image-preview img {
    max-width: 512px;
    max-height: 512px;
    margin: 0 auto;
    border-radius: 4px;
    display: block;
    object-fit: contain;
}
/* Control Styles */
.control-panel {
    background-color: #333;
    padding: 16px;
    border-radius: 8px;
    margin-top: 16px;
}
/* Gradio Component Overrides */
.gr-button {
    background-color: #4a4a4a;
    color: #fff;
    border: none;
    border-radius: 4px;
    padding: 8px 16px;
    cursor: pointer;
    transition: background-color 0.3s;
}
.gr-button:hover {
    background-color: #5a5a5a;
}
.gr-input, .gr-dropdown {
    background-color: #3a3a3a;
    color: #fff;
    border: 1px solid #4a4a4a;
    border-radius: 4px;
    padding: 8px;
}
.gr-form {
    background-color: transparent;
}
.gr-panel {
    border: none;
    background-color: transparent;
}
/* Override any conflicting styles from Bulma */
.button.is-normal.is-rounded.is-dark {
    color: #fff !important;
    text-decoration: none !important;
}
</style>
"""

        header_html = f"""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
{app_styles}
<div class="app-header">
    <h1 class="app-title">AI in Dentistry</h1>
    <h2 class="app-subtitle"> Advancing Imaging and Clinical Transcription</h2>
    <p class="app-description">
        This application demonstrates the use of AI in dentistry for tasks such as classification, detection, and segmentation.
    </p>
</div>
"""

        def process_image(image, model_name):
            # Clicking "Run" with an empty input used to surface an opaque
            # AttributeError; show a friendly message in the UI instead.
            if image is None:
                raise gr.Error("Please upload an image or pick one of the examples first.")
            return self.image_processor.process_image(image, model_name)

        def update_examples(model_name):
            # Swap the example gallery to the newly selected task's images.
            examples = self.preloaded_examples.get(model_name, [])
            return gr.Dataset(samples=[[example] for example in examples])

        with gr.Blocks() as demo:
            gr.HTML(header_html)

            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        model_name = gr.Dropdown(
                            label="Model",
                            choices=list(Config.MODELS.keys()),
                            value=self.DEFAULT_MODEL,
                        )
                    examples = gr.Examples(
                        inputs=input_image,
                        examples=self.preloaded_examples.get(self.DEFAULT_MODEL, []),
                    )

                with gr.Column():
                    result = gr.Image(label="Result", elem_classes="image-preview")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            model_name.change(
                fn=update_examples,
                inputs=model_name,
                outputs=examples.dataset,
            )

            run_button.click(
                fn=process_image,
                inputs=[input_image, model_name],
                outputs=result,
            )

        return demo
|
|
|
def main():
    """Build the Gradio UI and serve it locally, without a public share link."""
    app = GradioInterface().create_interface()
    app.launch(share=False)


if __name__ == "__main__":
    main()