Fracture_AI / app.py
Samanta Das
Update app.py
5bd43cd verified
import logging
import os
import sys
import threading
import time
import uuid

import cv2 # Add OpenCV to handle image processing
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
# Configure the root logger at start-up with an INFO threshold so the
# progress messages this module emits via logging.info() are shown.
logging.basicConfig(level="INFO")
# Import models
from yolo import FractureDetector
from llama import generate_response_based_on_yolo
# Directory containing this script. The model weights ship alongside app.py,
# so they are located relative to this file rather than to whatever directory
# the process happens to be started from.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(APP_DIR, 'yolov8n_custom_exported.pt')  # Update the filename if your exported model differs
# Output folder is intentionally relative to the current working directory,
# which is expected to be writable (avoids permission issues in read-only app dirs).
OUTPUT_FOLDER = 'output_images'
# Ensure the output folder exists before anything writes into it.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# Initialize the fracture detector with the model weights and output folder.
detector = FractureDetector(MODEL_PATH, OUTPUT_FOLDER)
def delete_file_after_delay(file_path, delay):
    """Delete a file on a background timer thread after a delay.

    Args:
        file_path: Path of the file to remove.
        delay: Number of seconds to wait before removing it.

    Returns:
        The started ``threading.Timer``. Callers may ignore it, or call
        ``cancel()`` on it before it fires to keep the file.
    """
    def delete_file():
        try:
            os.remove(file_path)
        except OSError as e:
            # Best effort: log the failure (e.g. file already removed) rather
            # than letting the exception escape the timer thread.
            logging.error(f"Error deleting temporary file: {e}")
        else:
            logging.info(f"Temporary file {file_path} has been deleted.")

    # threading.Timer waits on an internal Event instead of sleeping in a bare
    # Thread, and exposes cancel() so a pending deletion can be aborted.
    timer = threading.Timer(delay, delete_file)
    timer.start()
    return timer
def mark_fracture_area(image, detections, outline="red", width=3):
    """Draw a bounding box around every detected fracture.

    Args:
        image: PIL image to annotate. RGB/RGBA images are drawn on in place;
            any other mode (e.g. grayscale "L") is first converted to a new
            RGB image, because drawing a colour such as "red" on a grayscale
            image only produces a grey line.
        detections: Iterable of dicts whose ``'coordinates'`` entry holds the
            box corners ``(x1, y1, x2, y2)``. ``None`` is treated as empty.
        outline: Box colour in any form PIL accepts (default ``"red"``).
        width: Box line width in pixels (default ``3``).

    Returns:
        The annotated PIL image.
    """
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")
    draw = ImageDraw.Draw(image)
    for detection in detections or ():
        x1, y1, x2, y2 = map(int, detection['coordinates'])
        # Pillow raises ValueError when x1 < x0 or y1 < y0; normalise corner order.
        left, right = sorted((x1, x2))
        top, bottom = sorted((y1, y2))
        draw.rectangle([left, top, right, bottom], outline=outline, width=width)
    return image
def _to_bgr_array(image):
    """Convert a PIL image to a 3-channel BGR NumPy array for the detector."""
    if image.mode not in ("RGB", "RGBA", "L"):
        # Palette ("P"), bilevel ("1"), "LA", "CMYK", ... would otherwise yield
        # arrays the channel-count branches below cannot convert correctly.
        image = image.convert("RGB")
    img_array = np.array(image)
    if img_array.ndim == 2:  # Grayscale image
        return cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
    if img_array.shape[2] == 4:  # RGBA image
        return cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)


def _save_temp_copy(image):
    """Save image as a uniquely named JPEG inside OUTPUT_FOLDER and return its path."""
    # A per-request name keeps concurrent uploads from overwriting each other
    # and keeps one request's deletion timer from removing another's file.
    temp_image_path = os.path.join(
        OUTPUT_FOLDER, f"temp_uploaded_image_{uuid.uuid4().hex}.jpg"
    )
    # JPEG cannot store alpha or palette modes (saving RGBA raises OSError).
    to_save = image if image.mode in ("RGB", "L") else image.convert("RGB")
    to_save.save(temp_image_path)
    return temp_image_path


def analyze_image(input_image):
    """Analyze an uploaded image for fractures and generate a report.

    Args:
        input_image: PIL image supplied by the Gradio upload component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(image, report)`` tuple: the image with detected fractures boxed
        and the generated report text. On failure the original image and an
        error message are returned instead; with no upload, ``(None, message)``.
    """
    if input_image is None:
        logging.warning("analyze_image called without an uploaded image.")
        return None, "Please upload an X-ray image to analyze."

    logging.info("Starting analysis on uploaded image.")
    temp_image_path = None
    try:
        # Keep a temporary copy of the upload; it is removed by a timer below.
        temp_image_path = _save_temp_copy(input_image)

        logging.info("Performing fracture detection.")
        img_array = _to_bgr_array(input_image)
        detections = detector.detect_fractures(img_array, conf_threshold=0.25)

        # Draw on a copy so the caller's image is left untouched.
        marked_image = mark_fracture_area(input_image.copy(), detections)
        analysis_report = generate_response_based_on_yolo(detections)

        logging.info("Analysis completed successfully.")
        return marked_image, analysis_report
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the handler.
        logging.exception("An error occurred during analysis: %s", e)
        return input_image, f"An error occurred during analysis: {str(e)}"
    finally:
        # Schedule cleanup after 2 minutes whether or not the analysis succeeded.
        if temp_image_path is not None:
            delete_file_after_delay(temp_image_path, 120)
# Gradio UI: one X-ray upload goes in; analyze_image's two results (the
# annotated image and the report text) come out, in that order.
_xray_input = gr.Image(type="pil", label="Upload X-ray Image")
_marked_output = gr.Image(type="pil", label="Output Image with Marked Fractures")
_report_output = gr.Textbox(label="Fracture Analysis Report")

iface = gr.Interface(
    fn=analyze_image,
    inputs=_xray_input,
    outputs=[_marked_output, _report_output],
    title="Fracture Detection and Analysis System",
    description=(
        "Upload an X-ray image to detect fractures, view marked areas, "
        "and receive a detailed analysis report."
    ),
)

# Serve the UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()