|
import gradio as gr |
|
from ultralytics import YOLO |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image, ImageDraw, ImageFont |
|
import sqlite3 |
|
import base64 |
|
from io import BytesIO |
|
import tempfile |
|
import pandas as pd |
|
|
|
|
|
# Path of the trained YOLO weights. The model is built once when the module
# is imported and is then reused by predict_image for every request.
_MODEL_WEIGHTS = "best.pt"
model = YOLO(_MODEL_WEIGHTS)
|
|
|
def _class_style(class_index):
    """Map a YOLO class index to its (display label, box colour) pair.

    Index 0 is "Immature" and 1 is "Mature"; every other index is reported as
    "Normal". Colours are triples in the channel order of the RGB array they
    are drawn onto in ``predict_image``.
    """
    if class_index == 0:
        return "Immature", (0, 255, 255)
    if class_index == 1:
        return "Mature", (255, 0, 0)
    return "Normal", (0, 255, 0)


def _draw_labelled_box(image, box, text, color, font_scale=1.0, thickness=2):
    """Draw *box* and a white-on-black *text* caption onto *image* in place.

    The caption sits directly above the box. When the box touches the top
    edge, the caption is clamped to row 0 so it stays visible.
    """
    xmin, ymin, xmax, ymax = box
    cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness)

    (text_width, text_height), baseline = cv2.getTextSize(
        text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
    )
    # Clamp the caption background so it never starts above the image.
    top = max(ymin - text_height - baseline, 0)
    bottom = top + text_height + baseline
    cv2.rectangle(image, (xmin, top), (xmin + text_width, bottom), (0, 0, 0), cv2.FILLED)
    # putText's origin is the text baseline, so offset by `baseline` to keep
    # the glyphs (including descenders) inside the background.
    cv2.putText(
        image,
        text,
        (xmin, bottom - baseline),
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        (255, 255, 255),
        thickness,
    )


def predict_image(input_image, name, patient_id):
    """Run the cataract detector on one image and return an annotated copy.

    Only the single highest-confidence detection is drawn. The annotated image
    is then passed to ``add_text_and_watermark`` together with the patient
    details and the predicted label.

    Args:
        input_image: PIL image, as supplied by ``gr.Image(type="pil")``. A
            uint8 array that is HxW, HxWx3 (RGB) or HxWx4 (RGBA) is also
            accepted.
        name: Patient name passed through to the header text.
        patient_id: Patient ID passed through to the header text.

    Returns:
        ``(annotated PIL image, prediction description)``. The description is
        an empty string when nothing is detected. Returns
        ``(None, "Please Input The Image")`` when *input_image* is None.
    """
    if input_image is None:
        return None, "Please Input The Image"

    # Normalise every PIL mode (L, LA, P, RGBA, ...) to 3-channel RGB; the
    # shape checks below then only need to handle raw arrays.
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert("RGB")
    image_np = np.array(input_image)
    if image_np.ndim == 2:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    # Ultralytics interprets ndarray input as BGR (OpenCV order), so convert
    # explicitly instead of handing it the RGB array with swapped channels.
    results = model(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
    boxes = results[0].boxes

    image_with_boxes = image_np.copy()
    raw_predictions = []
    # Header label when the model finds nothing. Previously this path left
    # `label` unbound and raised UnboundLocalError.
    label = "No detection"

    if boxes:
        best = max(boxes, key=lambda box: box.conf.item())
        label, color = _class_style(int(best.cls.item()))
        confidence = best.conf.item()
        xmin, ymin, xmax, ymax = map(int, best.xyxy[0])

        _draw_labelled_box(
            image_with_boxes, (xmin, ymin, xmax, ymax), f'{label} {confidence:.2f}', color
        )
        raw_predictions.append(
            f"Label: {label}, Confidence: {confidence:.2f}, Box: [{xmin}, {ymin}, {xmax}, {ymax}]"
        )

    raw_predictions_str = "\n".join(raw_predictions)

    pil_image_with_boxes = Image.fromarray(image_with_boxes)
    pil_image_with_boxes = add_text_and_watermark(pil_image_with_boxes, name, patient_id, label)
    return pil_image_with_boxes, raw_predictions_str
|
|
|
|
|
def add_watermark(image, logo_path='image-logo.png', logo_width=100, margin=10):
    """Paste a logo into the bottom-right corner of *image*.

    This is best-effort. Any failure, such as a missing logo file, is printed
    and the original *image* is returned unchanged.

    Args:
        image: PIL image to watermark. It is not modified; a converted copy
            is drawn on.
        logo_path: Logo file to paste. It is converted to RGBA, and its alpha
            channel is used as the paste mask.
        logo_width: Target logo width in pixels. The height keeps the logo's
            aspect ratio.
        margin: Gap in pixels between the logo and the right/bottom edges.

    Returns:
        A new RGB image with the logo, or *image* itself on failure.
    """
    try:
        # Context manager closes the underlying file once the pixels are loaded.
        with Image.open(logo_path) as logo_file:
            logo = logo_file.convert("RGBA")

        scale = logo_width / float(logo.width)
        # Guard against a zero-height resize for extremely wide logos.
        logo_height = max(1, int(scale * logo.height))
        logo = logo.resize((logo_width, logo_height), Image.LANCZOS)

        # convert() returns a copy, so the caller's image is never mutated and
        # remains a valid fallback in the except branch below.
        canvas = image.convert("RGBA")
        position = (canvas.width - logo.width - margin, canvas.height - logo.height - margin)
        canvas.paste(logo, position, mask=logo)
        return canvas.convert("RGB")
    except Exception as e:
        # Deliberate best-effort boundary: a watermark problem should not
        # prevent the annotated image from being returned.
        print(f"Error adding watermark: {e}")
        return image
|
|
|
|
|
def add_text_and_watermark(image, name, patient_id, label, font_path="font.ttf", font_size=48):
    """Draw a patient/result header on *image*, then apply the logo watermark.

    The header ("Name: ..., ID: ..., Result: ...") is drawn in place on
    *image* as white text on a padded black rectangle near the top-left corner.

    Args:
        image: PIL image to annotate. It is modified in place by the header
            drawing.
        name: Patient name shown in the header.
        patient_id: Patient ID shown in the header.
        label: Prediction label shown in the header.
        font_path: TrueType font file. Falls back to Pillow's default font if
            it cannot be opened.
        font_size: Font size used with *font_path*.

    Returns:
        The result of ``add_watermark`` applied to the annotated image.
    """
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    text = f"Name: {name}, ID: {patient_id}, Result: {label}"
    text_x, text_y = 20, 40
    padding = 10

    # Measure at the actual drawing origin. The glyph box normally starts
    # below the origin (non-zero top offset), so measuring at (0, 0) and adding
    # only width/height misplaced the background relative to the text.
    left, top, right, bottom = draw.textbbox((text_x, text_y), text, font=font)
    draw.rectangle(
        [left - padding, top - padding, right + padding, bottom + padding],
        fill="black",
    )
    draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)

    return add_watermark(image)
|
|
|
|
|
def init_db(db_path="results.db"):
    """Create the ``results`` table in the SQLite database if it does not exist.

    Safe to call repeatedly (``CREATE TABLE IF NOT EXISTS``).

    Args:
        db_path: SQLite database file. Defaults to ``results.db``, the same
            file the other database helpers use.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS results "
            "(id INTEGER PRIMARY KEY, name TEXT, patient_id TEXT, "
            "input_image BLOB, predicted_image BLOB, result TEXT)"
        )
        conn.commit()
    finally:
        # Close even if the statement fails so the handle is not leaked.
        conn.close()
|
|
|
|
|
def _to_png_bytes(image):
    """Encode a PIL image (or a uint8 array) as RGB PNG bytes."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    buffer = BytesIO()
    # Converting to RGB first makes grayscale, palette and alpha inputs all
    # encodable. The old cv2 RGB->BGR path crashed on 1- and 2-channel images.
    image.convert("RGB").save(buffer, format="PNG")
    return buffer.getvalue()


def submit_result(name, patient_id, input_image, predicted_image, result, db_path="results.db"):
    """Store one prediction (both images as PNG bytes) in the ``results`` table.

    Args:
        name: Patient name.
        patient_id: Patient ID.
        input_image: Original image (PIL image or uint8 RGB array).
        predicted_image: Annotated image (PIL image or uint8 RGB array).
        result: Prediction description text.
        db_path: SQLite database file. Defaults to ``results.db``.

    Returns:
        The status message "Result submitted to database.".
    """
    # Encode before opening the connection so a bad image never leaves a
    # connection open. PIL raises on failure instead of returning a flag.
    input_image_bytes = _to_png_bytes(input_image)
    predicted_image_bytes = _to_png_bytes(predicted_image)

    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            "INSERT INTO results (name, patient_id, input_image, predicted_image, result) "
            "VALUES (?, ?, ?, ?, ?)",
            (name, patient_id, input_image_bytes, predicted_image_bytes, result),
        )
        conn.commit()
    finally:
        conn.close()

    return "Result submitted to database."
|
|
|
|
|
def view_database(db_path="results.db"):
    """Return the stored patient names and IDs as a DataFrame.

    Only the ``name`` and ``patient_id`` columns are selected. Rows are ordered
    by ``id`` so the listing follows insertion order deterministically.

    Args:
        db_path: SQLite database file. Defaults to ``results.db``.

    Returns:
        DataFrame with columns ``["Name", "Patient ID"]``.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute("SELECT name, patient_id FROM results ORDER BY id").fetchall()
    finally:
        conn.close()

    return pd.DataFrame(rows, columns=["Name", "Patient ID"])
|
|
|
|
|
def download_file(choice):
    """Return the path of the file to offer for the selected download *choice*.

    Args:
        choice: One of "Database (.db)", "Database (.html)" or "Image (.png)".

    Returns:
        Path to ``results.db`` itself, or to a newly written temporary
        ``.html`` / ``.png`` file. Temporary files are created with
        ``delete=False`` and are not removed here.

    Raises:
        FileNotFoundError: "Image (.png)" was chosen but no stored image exists.
        ValueError: *choice* is empty or unrecognised. Previously this case
            silently fell through to the image download.
    """
    if choice == "Database (.db)":
        return 'results.db'

    if choice == "Database (.html)":
        df = view_database()
        # Write through the open text handle rather than re-opening the file
        # by name, which fails on Windows while the temp file is still open.
        with tempfile.NamedTemporaryFile(
            "w", delete=False, suffix='.html', encoding="utf-8"
        ) as temp_file:
            df.to_html(temp_file)
        return temp_file.name

    if choice == "Image (.png)":
        conn = sqlite3.connect('results.db')
        try:
            row = conn.execute(
                "SELECT predicted_image FROM results ORDER BY id DESC LIMIT 1"
            ).fetchone()
        finally:
            conn.close()

        # A NULL blob is treated the same as "no image" instead of crashing
        # on write(None).
        if row is None or row[0] is None:
            raise FileNotFoundError("No images found in the database.")

        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_file.write(row[0])
        return temp_file.name

    raise ValueError(f"Unknown download type: {choice!r}")
|
|
|
|
|
# Create the results table at import time, before the UI below is built and
# launched. submit_result, view_database and download_file all query it.
init_db()
|
|
|
|
|
def interface(name, patient_id, input_image):
    """Gradio callback for the prediction tab: predict, then store the result.

    Args:
        name: Patient name from the "Name" textbox.
        patient_id: Patient ID from the "Patient ID" textbox.
        input_image: PIL image from the input ``gr.Image``, or None if absent.

    Returns:
        ``(annotated image or None, status message)``. This is one value per
        output component (output image, status textbox).
    """
    if input_image is None:
        # Must return one value per output component. Returning a bare string
        # here previously broke Gradio's output mapping.
        return None, "Please upload an image."

    output_image, raw_result = predict_image(input_image, name, patient_id)
    submit_status = submit_result(name, patient_id, input_image, output_image, raw_result)
    return output_image, submit_status
|
|
|
# Prediction tab: patient details and an image in; annotated image and the
# database status message out (matching interface's two return values).
inputs = [
    gr.Textbox(label="Name"),
    gr.Textbox(label="Patient ID"),
    gr.Image(type="pil", label="Input Image"),
]

outputs = [
    gr.Image(label="Output Image"),
    gr.Textbox(label="Status"),
]

# Download tab: the radio choices must match the strings download_file checks.
download_inputs = gr.Radio(["Database (.db)", "Database (.html)", "Image (.png)"], label="Download Type")
download_output = gr.File(label="Download File")

app = gr.Interface(
    fn=interface,
    inputs=inputs,
    outputs=outputs,
    title="AI Cataract Detector",
    description="Upload an image, enter the patient's name and ID, and receive a prediction."
)

download_app = gr.Interface(
    fn=download_file,
    inputs=download_inputs,
    outputs=download_output,
    title="Download Results"
)

demo = gr.TabbedInterface([app, download_app], ["Prediction", "Download"])

# Start the server only when run as a script, so importing this module does
# not block on launch().
if __name__ == "__main__":
    demo.launch()