Spaces:
Sleeping
Sleeping
from PIL import Image, ImageDraw, ImageFont | |
import numpy as np | |
import cv2 | |
import tensorflow as tf | |
import gradio as gr | |
import io | |
def load_model(model_path):
    """Load a Keras model from *model_path* and compile it for binary classification.

    The saved training configuration is skipped (``compile=False``) because the
    model is re-compiled right below with Adam + binary cross-entropy. Restoring
    the saved optimizer/loss first is wasted work, and it can fail for models saved
    with custom losses or metrics that are not registered in this app.

    Args:
        model_path: Filesystem path to the uploaded model file (the UI asks for
            ``.h5`` or ``.keras``).

    Returns:
        The loaded ``tf.keras`` model, compiled with Adam, BinaryCrossentropy and
        an ``accuracy`` metric.
    """
    model = tf.keras.models.load_model(model_path, compile=False)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=["accuracy"],
    )
    return model
def get_model_summary(model):
    """Capture the text printed by ``model.summary()`` and return it as one string.

    Every chunk handed to ``print_fn`` is written followed by a newline. The
    callback accepts and ignores extra positional/keyword arguments because some
    Keras versions (Keras 3) invoke it as ``print_fn(text, line_break=False)``.
    A strictly one-argument callback would raise ``TypeError`` there.

    Args:
        model: Object exposing a Keras-style ``summary(print_fn=...)`` method.

    Returns:
        str: The captured summary text.
    """
    with io.StringIO() as stream:

        def _write(text, *_args, **_kwargs):
            stream.write(text + "\n")

        model.summary(print_fn=_write)
        # Read the buffer before the ``with`` block closes it.
        return stream.getvalue()
def get_input_shape(model):
    """Return ``model.input_shape`` with its leading (batch) entry removed.

    Args:
        model: Object with an ``input_shape`` sequence whose first entry is the
            batch dimension.

    Returns:
        The remaining per-sample dimensions, as a slice of the same sequence type.
    """
    full_shape = model.input_shape
    per_sample_shape = full_shape[1:]
    return per_sample_shape
def _match_channels(img, num_channels):
    """Convert an (H, W) or (H, W, C) array toward *num_channels* channels.

    Only 1- and 3-channel targets are converted; any other target is returned
    unchanged. Alpha channels (LA / RGBA uploads) are dropped first. For a
    1-channel target the result is 2-D (H, W); the caller restores the channel axis.
    """
    if num_channels not in (1, 3):
        return img
    if img.ndim == 3:
        if img.shape[2] in (1, 2):
            # Gray or gray+alpha: keep only the luminance plane.
            img = np.ascontiguousarray(img[..., 0])
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    if num_channels == 1:
        if img.ndim == 3:
            # PIL arrays are RGB-ordered, not OpenCV's default BGR.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return img
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    return img


def preprocess_image(image, input_shape):
    """Turn an uploaded image into a normalized single-item batch for the model.

    Args:
        image: PIL image (or array-like) as delivered by the Gradio image input.
        input_shape: Per-sample model input shape ``(height, width, channels)``,
            as returned by ``get_input_shape``. ``None`` spatial dimensions skip
            resizing.

    Returns:
        numpy.ndarray of shape ``(1, height, width, channels)`` with values
        scaled from ``[0, 255]`` to ``[0, 1]``.
    """
    img = _match_channels(np.array(image), input_shape[-1])
    height, width = input_shape[0], input_shape[1]
    if height is not None and width is not None:
        # cv2.resize takes dsize as (width, height); Keras shapes are (height, width, ...).
        img = cv2.resize(img, (int(width), int(height)))
    if img.ndim == 2:
        # Grayscale path: cv2.resize/cvtColor return 2-D arrays, so re-add the channel axis.
        img = img[..., np.newaxis]
    img_normalized = img / 255.0
    return np.expand_dims(img_normalized, axis=0)
def _to_display_rgb(image):
    """Return *image* as a numpy array, expanding grayscale inputs to 3-channel RGB."""
    img_display = np.array(image)
    if img_display.ndim == 3 and img_display.shape[2] == 1:
        img_display = np.ascontiguousarray(img_display[..., 0])
    if img_display.ndim == 2:
        # Expand so the red overlay is actually drawn in colour.
        img_display = cv2.cvtColor(img_display, cv2.COLOR_GRAY2RGB)
    return img_display


def _draw_centered_label(image_pil, text, rect_width=200, rect_height=100):
    """Draw, in place, a red box centred on *image_pil* with *text* centred inside it."""
    draw = ImageDraw.Draw(image_pil)
    font = ImageFont.load_default()
    rect_x = (image_pil.width - rect_width) // 2
    rect_y = (image_pil.height - rect_height) // 2
    draw.rectangle(
        [rect_x, rect_y, rect_x + rect_width, rect_y + rect_height],
        outline="red",
        width=3,
    )
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    # textbbox may report a non-zero left/top origin; subtract it so the glyphs are centred.
    text_x = rect_x + (rect_width - (right - left)) // 2 - left
    text_y = rect_y + (rect_height - (bottom - top)) // 2 - top
    draw.text((text_x, text_y), text, fill="red", font=font)


def diagnose_image(image, model, input_shape):
    """Predict the glaucoma probability for *image* and return an annotated copy.

    Args:
        image: Uploaded image (PIL image or array-like).
        model: Loaded Keras model; its first output value for the first sample is
            reported as the probability (assumes a single sigmoid-style output,
            matching the BinaryCrossentropy compile in ``load_model`` — confirm
            for other models).
        input_shape: Per-sample model input shape passed to ``preprocess_image``.

    Returns:
        tuple: ``(annotated PIL image, "Probability of glaucoma: NN.NN%")``.
    """
    img_batch = preprocess_image(image, input_shape)
    # verbose=0 keeps Keras from printing a progress bar on every request.
    prediction = model.predict(img_batch, verbose=0)
    glaucoma_probability = float(prediction[0][0])
    result_text = f"Probability of glaucoma: {glaucoma_probability:.2%}"

    image_pil = Image.fromarray(_to_display_rgb(image))
    _draw_centered_label(image_pil, f"{glaucoma_probability:.2%}")
    return image_pil, result_text
def main():
    """Build the Gradio UI (model upload, image upload, diagnosis) and launch it."""
    with gr.Blocks() as demo:
        gr.Markdown("# Glaucoma Detection App")
        gr.Markdown("Upload a fundus eye image to detect the probability of glaucoma.")
        with gr.Row():
            model_file = gr.File(label="Upload Model (.h5 or .keras)")
            load_model_btn = gr.Button("Load Model")
        model_info = gr.Markdown()
        image = gr.Image(type="pil", label="Upload Image")
        submit_btn = gr.Button("Diagnose")
        # Separate output so the annotated image is never fed back into the model
        # on a repeated "Diagnose" click.
        annotated_image = gr.Image(type="pil", label="Annotated Image")
        result = gr.Textbox(label="Diagnosis Result")

        def load_and_display_model_info(file):
            """Load the uploaded model; return (model, summary markdown, input shape)."""
            if file is None:
                raise gr.Error("Please upload a model file first.")
            # Gradio 4+ passes a filepath string; older versions pass an object with ``.name``.
            model_path = file if isinstance(file, str) else file.name
            loaded_model = load_model(model_path)
            # Fence the plain-text summary so Markdown does not turn its rule lines into headings.
            summary_md = f"```\n{get_model_summary(loaded_model)}\n```"
            return loaded_model, summary_md, get_input_shape(loaded_model)

        model = gr.State(None)
        input_shape = gr.State(None)

        def diagnose_and_display(image, model, input_shape):
            """Validate session state, then delegate to ``diagnose_image``."""
            if model is None or input_shape is None:
                raise gr.Error("Please load a model before diagnosing.")
            if image is None:
                raise gr.Error("Please upload an image to diagnose.")
            return diagnose_image(image, model, input_shape)

        load_model_btn.click(
            fn=load_and_display_model_info,
            inputs=model_file,
            outputs=[model, model_info, input_shape],
        )
        submit_btn.click(
            fn=diagnose_and_display,
            inputs=[image, model, input_shape],
            outputs=[annotated_image, result],
        )
        gr.Markdown("### Glaucoma Analyzer V.1.0.0 by Thariq Arian")
    demo.launch()


if __name__ == "__main__":
    main()