|
import gradio as gr |
|
from transformers import AutoModelForImageClassification, AutoProcessor |
|
from PIL import Image |
|
import io |
|
import fitz |
|
import os |
|
|
|
|
|
# Checkpoint identifier handed to both ``from_pretrained`` calls below.
model_name = "AsmaaElnagger/Diabetic_RetinoPathy_detection"

# Both objects are created once at import time and reused by every
# classification request: the processor turns a PIL image into model inputs,
# the model produces the class logits.
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
|
|
|
|
|
def pdf_to_images_pymupdf(pdf_data):
    """Render every page of an in-memory PDF to JPEG bytes.

    Args:
        pdf_data: Raw contents of a PDF file (bytes), passed to PyMuPDF as
            a stream.

    Returns:
        A list with one JPEG-encoded ``bytes`` object per page, in page
        order, or ``None`` when the document cannot be opened or rendered.
        The error is printed rather than raised so the caller can report a
        conversion failure.
    """
    try:
        # The context manager closes the document on success and on error;
        # previously it was left open until garbage collection.
        with fitz.open(stream=pdf_data, filetype="pdf") as pdf_document:
            return [
                pdf_document.load_page(page_num).get_pixmap().tobytes("jpeg")
                for page_num in range(pdf_document.page_count)
            ]
    except Exception as e:
        # Best-effort boundary: any failure is reported as "no images".
        print(f"Error converting PDF: {e}")
        return None
|
|
|
|
|
def _predict_label(image):
    """Return the predicted class label for one RGB PIL image.

    Uses the module-level ``processor`` and ``model``.
    """
    # Function-scope import: the file's import header does not bring torch
    # into scope, but the "pt" tensors requested below require it anyway.
    import torch

    inputs = processor(images=image, return_tensors="pt")
    # Inference only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def classify_file(file_path):
    """Classify an uploaded image or PDF and return the result as text.

    Args:
        file_path: Path of the uploaded file (``str`` or ``os.PathLike``),
            an upload object exposing the path as ``.name``, or ``None``
            when nothing was uploaded.

    Returns:
        The predicted label from ``model.config.id2label`` for supported
        files, otherwise a human-readable message: no upload, unsupported
        extension, failed PDF conversion, or the text of any exception
        raised while processing. For PDFs only the first page is classified.
    """
    if file_path is None:
        return "No file uploaded."

    try:
        if isinstance(file_path, (str, os.PathLike)):
            path = file_path
        else:
            # Presumably a temp-file wrapper, as some Gradio versions pass
            # for gr.File; its on-disk path is in ``.name``.
            path = getattr(file_path, "name", file_path)

        file_ext = os.path.splitext(path)[-1].lower()

        if file_ext in ('.jpg', '.jpeg', '.png', '.gif'):
            # ``convert`` returns an independent image, so the file handle
            # can be closed before inference.
            with Image.open(path) as img:
                image = img.convert("RGB")
            return _predict_label(image)

        if file_ext == '.pdf':
            with open(path, "rb") as f:
                pdf_data = f.read()
            images = pdf_to_images_pymupdf(pdf_data)

            # None (conversion error) and [] (no pages) both count as failure.
            if not images:
                return "PDF conversion failed."

            with Image.open(io.BytesIO(images[0])) as img:
                image = img.convert("RGB")
            return _predict_label(image)

        return "Unsupported file type."

    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing.
        return f"An error occurred: {e}"
|
|
|
|
|
# Single-input, text-output UI wired to classify_file.
demo = gr.Interface(
    fn=classify_file,
    inputs=gr.File(label="Upload PDF or Image"),
    outputs="text",
    title="Diabetic Retinopathy Detection",
    description="Upload a fundus scan (image or PDF) to detect diabetic retinopathy.",
)

# Start the server only when run as a script, so importing this module
# (e.g. to reuse ``demo`` or ``classify_file``) does not block on launch().
if __name__ == "__main__":
    demo.launch()
|
|