Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image, ImageDraw | |
import numpy as np | |
import cv2 # Add OpenCV to handle image processing | |
import sys | |
import os | |
import time | |
import threading | |
import logging | |
# Set up logging configuration | |
logging.basicConfig(level=logging.INFO) | |
# Import models | |
from yolo import FractureDetector | |
from llama import generate_response_based_on_yolo | |
# Resolve the model and output paths against this file's directory instead of
# the process's current working directory, so the app finds its files no
# matter where it is launched from.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Trained YOLO weights shipped next to this script; change the file name here
# if you export a different model.
MODEL_PATH = os.path.join(_APP_DIR, 'yolov8n_custom_exported.pt')
# Temporary uploads are written here; the folder is also handed to
# FractureDetector (presumably for its own outputs — confirm in yolo.py).
OUTPUT_FOLDER = os.path.join(_APP_DIR, 'output_images')

# Create the output folder up front; exist_ok makes restarts harmless.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Load the detector once at import time so every request reuses it.
detector = FractureDetector(MODEL_PATH, OUTPUT_FOLDER)
def delete_file_after_delay(file_path, delay):
    """Delete *file_path* in the background once *delay* seconds have passed.

    The deletion runs on a daemon ``threading.Timer``, so a pending cleanup
    never keeps the interpreter alive at shutdown. The trade-off is that a
    cleanup still pending when the process exits is simply skipped.
    Failures such as the file already being gone are logged, not raised.

    Args:
        file_path: Path of the file to remove.
        delay: Seconds to wait before removing it.

    Returns:
        The started ``threading.Timer``. Callers may ``cancel()`` or
        ``join()`` it; callers that ignore the return value are unaffected.
    """
    def delete_file():
        try:
            os.remove(file_path)
        except OSError as e:
            logging.error(f"Error deleting temporary file: {e}")
        else:
            logging.info(f"Temporary file {file_path} has been deleted.")

    timer = threading.Timer(delay, delete_file)
    # Daemon: an unexpired timer must not block process exit. A plain
    # non-daemon sleeping thread would hold shutdown for the full delay.
    timer.daemon = True
    timer.start()
    return timer
def mark_fracture_area(image, detections, *, color="red", width=3):
    """Draw a rectangle around every detected fracture.

    Args:
        image: PIL image to annotate.
            - RGB/RGBA images are drawn on in place and returned.
            - Any other mode (e.g. grayscale ``"L"``) is first converted to
              RGB so the outline colour is actually visible; the converted
              copy is returned instead.
        detections: Iterable of dicts, each with a ``'coordinates'`` entry
            holding four numbers ``(x1, y1, x2, y2)``.
        color: Outline colour passed to ``ImageDraw.rectangle``.
        width: Outline width in pixels.

    Returns:
        The annotated image.
    """
    if image.mode not in ("RGB", "RGBA"):
        # In single-channel modes a "red" outline is drawn as a grey level.
        image = image.convert("RGB")
    draw = ImageDraw.Draw(image)
    for detection in detections:
        x1, y1, x2, y2 = map(int, detection['coordinates'])
        # Recent Pillow versions raise ValueError when the second corner
        # precedes the first, so normalise the corner order before drawing.
        left, right = sorted((x1, x2))
        top, bottom = sorted((y1, y2))
        draw.rectangle([left, top, right, bottom], outline=color, width=width)
    return image
def analyze_image(input_image):
    """Detect fractures in an uploaded X-ray and build the annotated result.

    Args:
        input_image: PIL image from the upload widget. May be ``None`` if
            nothing was uploaded; this is handled defensively.

    Returns:
        ``(marked_image, analysis_report)``: the image with fracture boxes
        drawn on it, and the text report from ``generate_response_based_on_yolo``.
        If analysis fails, the original image is returned together with an
        error message, so the UI always receives two outputs.
    """
    import uuid  # local import: only needed for per-request temp file names

    if input_image is None:
        return None, "Please upload an X-ray image to analyze."

    logging.info("Starting analysis on uploaded image.")
    # A unique name per request stops concurrent uploads from overwriting each
    # other's temp file, and stops one request's cleanup deleting another's.
    temp_image_path = os.path.join(
        OUTPUT_FOLDER, f"temp_uploaded_image_{uuid.uuid4().hex}.jpg"
    )
    try:
        # JPEG cannot store alpha or palette modes, so save an RGB copy.
        # The same copy feeds the detector and is what the boxes are drawn on.
        rgb_image = input_image.convert("RGB")
        rgb_image.save(temp_image_path)

        logging.info("Performing fracture detection.")
        # np.array copies the pixels; the detector receives a BGR array,
        # matching the original cv2 conversion.
        img_array = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
        detections = detector.detect_fractures(img_array, conf_threshold=0.25)

        # rgb_image is a fresh object (convert() always copies), so drawing on
        # it never mutates the caller's image.
        marked_image = mark_fracture_area(rgb_image, detections)
        analysis_report = generate_response_based_on_yolo(detections)

        logging.info("Analysis completed successfully.")
        return marked_image, analysis_report
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the log, report to user.
        logging.exception("An error occurred during analysis: %s", e)
        return input_image, f"An error occurred during analysis: {e}"
    finally:
        # Schedule cleanup on failure as well as success so temp files never
        # pile up in OUTPUT_FOLDER.
        if os.path.exists(temp_image_path):
            delete_file_after_delay(temp_image_path, 120)
# Assemble the web UI: a single PIL image upload goes in; the annotated image
# and the text report produced by analyze_image come out.
upload_input = gr.Image(type="pil", label="Upload X-ray Image")
annotated_output = gr.Image(type="pil", label="Output Image with Marked Fractures")
report_output = gr.Textbox(label="Fracture Analysis Report")

iface = gr.Interface(
    fn=analyze_image,
    inputs=upload_input,
    outputs=[annotated_output, report_output],
    title="Fracture Detection and Analysis System",
    description=(
        "Upload an X-ray image to detect fractures, view marked areas, "
        "and receive a detailed analysis report."
    ),
)

# Launch the Gradio app only when this file is executed directly.
if __name__ == "__main__":
    iface.launch()