yomna-ashraf's picture
Update app.py
e299395 verified
import io
import logging
import os
import tempfile

import fitz  # PyMuPDF
from flask import Flask, request, jsonify
from PIL import Image
from transformers import pipeline
from werkzeug.utils import secure_filename
app = Flask(__name__)

# Hugging Face model id of the diabetic-retinopathy image classifier.
model_name = "AsmaaElnagger/Diabetic_RetinoPathy_detection"
# The pipeline is built once at import time; the request handlers below all
# share this single module-level instance.
classifier = pipeline(task="image-classification", model=model_name)
# PDF to image conversion
def pdf_to_images_pymupdf(pdf_data):
    """Render every page of an in-memory PDF to JPEG bytes.

    Args:
        pdf_data: Raw bytes of a PDF document.

    Returns:
        A list with one JPEG-encoded ``bytes`` object per page, in page
        order (empty for a zero-page document), or ``None`` if the PDF
        could not be opened or rendered.
    """
    try:
        # The context manager closes the document (releasing MuPDF's native
        # resources) even if rendering a page fails part-way through.
        with fitz.open(stream=pdf_data, filetype="pdf") as pdf_document:
            # Iterating a Document yields its pages in order.
            return [page.get_pixmap().tobytes("jpeg") for page in pdf_document]
    except Exception:
        # Best-effort conversion: record the failure with its traceback and
        # signal it to the caller with None rather than raising.
        logging.getLogger(__name__).exception("Error converting PDF")
        return None
# File classification function (modified for API)
def _predict(image):
    """Run the module-level classifier on a PIL image; return its top result.

    Returns a dict with ``prediction`` (the label) and ``confidence`` (the
    score scaled by 100 to a percentage).
    """
    result = classifier(image)[0]  # Get the top prediction
    return {
        "prediction": result["label"],
        "confidence": result["score"] * 100,
    }


def classify_file(file_path):
    """Classify an image or PDF file on disk.

    Images (.jpg/.jpeg/.png/.gif) are classified directly; for a PDF only the
    first page is rendered and classified. The extension check is
    case-insensitive.

    Args:
        file_path: Path to the file to classify.

    Returns:
        ``{"prediction": ..., "confidence": ...}`` on success, otherwise a
        dict with a single ``"error"`` message. Never raises ``Exception``
        subclasses; they are reported through the ``"error"`` key.
    """
    try:
        file_ext = os.path.splitext(file_path)[-1].lower()
        if file_ext in ('.jpg', '.jpeg', '.png', '.gif'):
            # ``with`` closes the underlying file handle, which PIL may
            # otherwise keep open after convert() (e.g. multi-frame GIFs).
            with Image.open(file_path) as img:
                image = img.convert("RGB")
            return _predict(image)

        if file_ext == '.pdf':
            with open(file_path, "rb") as f:
                pdf_data = f.read()
            images = pdf_to_images_pymupdf(pdf_data)
            # None (conversion error) and [] (no pages) both mean failure.
            if not images:
                return {"error": "PDF conversion failed."}
            # Only the first page is classified.
            with Image.open(io.BytesIO(images[0])) as img:
                image = img.convert("RGB")
            return _predict(image)

        return {"error": "Unsupported file type."}
    except Exception as e:
        return {"error": f"An error occurred: {e}"}
# API endpoint for file classification
@app.route('/classify', methods=['POST'])
def classify():
    """Classify an image or PDF uploaded as the multipart field ``file``.

    Returns:
        ``(json, 200)`` with ``prediction``/``confidence`` on success, or
        ``(json, 400)`` with an ``error`` message when the request carries no
        file or the file could not be classified.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No file selected"}), 400

    # Only the client's (sanitised) extension is kept, because classify_file
    # dispatches on it. Prefixing an ASCII stem stops secure_filename from
    # collapsing names like "صورة.jpg" to "jpg" and losing the extension, and
    # it means the saved name can never be empty.
    suffix = os.path.splitext(
        secure_filename("upload" + os.path.splitext(file.filename)[1])
    )[1]

    # A private per-request directory prevents concurrent uploads with the
    # same name from overwriting each other. It is removed on exit from the
    # ``with`` block, even if saving or classification raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        filepath = os.path.join(tmp_dir, "upload" + suffix)
        file.save(filepath)
        result = classify_file(filepath)

    # Report failures with a client-error status, consistent with the
    # validation errors above, instead of HTTP 200.
    status = 400 if "error" in result else 200
    return jsonify(result), status
if __name__ == '__main__':
    # Run Flask's built-in server on all interfaces. The port defaults to
    # 5000 and can be overridden through the PORT environment variable.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', '5000')))