# Fracture_AI / yolo.py
# Samanta Das
# Update yolo.py
# add3d03 verified
import logging
import numpy as np
import os
import time
from datetime import datetime, timedelta
from ultralytics import YOLO
import cv2
import torch
import gradio as gr # Correctly import Gradio here
from gradio import Interface # Import Interface here
from io import BytesIO
import threading
# Route INFO-and-above log records to fracture_detection.log, each one
# prefixed with its timestamp and severity level.
logging.basicConfig(
    filename='fracture_detection.log',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class FractureDetector:
    """Detect fractures in images with a YOLO model and save annotated results.

    Each instance loads the model once, writes an annotated JPEG into
    ``output_folder`` whenever a detection is found, and runs a daemon thread
    that deletes files in that folder once they are older than the configured
    lifetime.  Call :meth:`stop_cleanup` to end that thread.
    """

    # Seconds the background thread waits between two cleanup passes.
    CLEANUP_INTERVAL_SECONDS = 60

    def __init__(self, model_path: str, output_folder: str, file_lifetime_minutes: int = 3):
        """Load the model and start the background cleanup thread.

        Args:
            model_path: Path of the YOLO weights file to load.
            output_folder: Folder for annotated images; created if missing.
            file_lifetime_minutes: Age after which files in ``output_folder``
                are deleted by the cleanup thread.

        Raises:
            Exception: Whatever the model loader raises; it is logged first.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.output_folder = output_folder
        self.file_lifetime = timedelta(minutes=file_lifetime_minutes)
        os.makedirs(self.output_folder, exist_ok=True)

        try:
            self.model = YOLO(model_path)
            self.model.to(self.device)
            logging.info("Model loaded successfully from %s", model_path)
        except Exception as e:
            logging.error("Error loading model: %s", e)
            raise

        # The event lets stop_cleanup() end the otherwise endless cleanup loop,
        # so a discarded detector does not keep a thread (and itself) alive.
        self._stop_event = threading.Event()
        self.cleanup_thread = threading.Thread(target=self.cleanup_old_files, daemon=True)
        self.cleanup_thread.start()

    def preprocess_image(self, image):
        """Build the contrast-enhanced model input and a BGR copy of the image.

        Args:
            image: Array-like image that is grayscale (H, W), single channel
                (H, W, 1), RGB (H, W, 3) or RGBA (H, W, 4).

        Returns:
            A ``(processed, original)`` tuple: ``processed`` is the CLAHE
            equalised, lightly blurred grayscale image replicated into three
            channels; ``original`` is the input converted to BGR.

        Raises:
            ValueError: If ``image`` is None or has an unsupported shape.
        """
        if image is None:
            raise ValueError("No image provided for fracture detection")

        pixels = np.asarray(image)
        if pixels.ndim == 2:
            channels = 1
        elif pixels.ndim == 3:
            channels = pixels.shape[2]
        else:
            raise ValueError(f"Unsupported image shape: {pixels.shape}")

        # Pick the colour conversion that matches the channel count instead of
        # assuming RGB, so grayscale and RGBA inputs are accepted too.
        conversions = {
            1: cv2.COLOR_GRAY2BGR,
            3: cv2.COLOR_RGB2BGR,
            4: cv2.COLOR_RGBA2BGR,
        }
        if channels not in conversions:
            raise ValueError(f"Unsupported image shape: {pixels.shape}")
        img = cv2.cvtColor(pixels, conversions[channels])
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # CLAHE (Contrast Limited Adaptive Histogram Equalization)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_clahe = clahe.apply(img_gray)

        # Slight Gaussian blur to reduce noise
        img_blurred = cv2.GaussianBlur(img_clahe, (3, 3), 0)
        return cv2.merge([img_blurred, img_blurred, img_blurred]), img

    def detect_fractures(self, image, conf_threshold: float = 0.25):
        """Run the model on ``image`` and return the detections.

        When at least one detection is found, an annotated copy of the image
        is also written to the output folder.

        Args:
            image: Image accepted by :meth:`preprocess_image`.
            conf_threshold: Confidence threshold passed to the model.

        Returns:
            A list of dicts with the keys ``coordinates`` ([x1, y1, x2, y2] as
            floats), ``class``, ``confidence`` and ``name``.

        Raises:
            Exception: Any error from preprocessing, inference or saving; it
                is logged with its traceback and re-raised.
        """
        try:
            processed_image, original_image = self.preprocess_image(image)
            results = self.model(processed_image, device=self.device, conf=conf_threshold)

            detections = []
            for r in results:
                for box in r.boxes:
                    class_name = self.model.names[int(box.cls[0])]
                    detections.append({
                        'coordinates': [float(v) for v in box.xyxy[0]],
                        'class': class_name,
                        'confidence': float(box.conf[0]),
                        'name': f"{class_name} fracture",
                    })

            if detections:
                self._save_visualization(original_image, detections)
            return detections
        except Exception as e:
            logging.exception("Error in fracture detection: %s", e)
            raise

    def _save_visualization(self, image, detections):
        """Draw the detections on a copy of ``image``, save it, return JPEG bytes.

        Args:
            image: BGR image to annotate; it is not modified.
            detections: Dicts as produced by :meth:`detect_fractures`.

        Returns:
            The annotated image encoded as JPEG (quality 90) bytes.

        Raises:
            RuntimeError: If the annotated image cannot be JPEG-encoded.
        """
        img_viz = image.copy()
        for det in detections:
            x1, y1, x2, y2 = map(int, det['coordinates'])
            cv2.rectangle(img_viz, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{det['name']} ({det['confidence']:.2f})"
            # Put the label above the box, or just inside it when the box
            # touches the top edge and the text would fall off the image.
            label_y = y1 - 10 if y1 - 10 > 10 else y1 + 20
            cv2.putText(img_viz, label, (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # Microseconds keep two results produced within the same second from
        # overwriting each other.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        save_path = os.path.join(self.output_folder, f"detection_result_{timestamp}.jpg")
        # cv2.imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(save_path, img_viz):
            logging.info("Visualization saved to %s", save_path)
        else:
            logging.error("Failed to save visualization to %s", save_path)

        encoded_ok, img_encoded = cv2.imencode(
            '.jpg', img_viz, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
        if not encoded_ok:
            raise RuntimeError("Failed to encode visualization as JPEG")
        return img_encoded.tobytes()

    def cleanup_old_files(self):
        """Remove expired files once per interval until cleanup is stopped.

        This is the target of the background thread started in ``__init__``.
        It runs a cleanup pass, waits ``CLEANUP_INTERVAL_SECONDS`` and repeats;
        it returns as soon as :meth:`stop_cleanup` has been called.
        """
        while True:
            self._remove_expired_files()
            # wait() returns True only when the stop event is set, so this
            # doubles as the sleep between passes and as the exit check.
            if self._stop_event.wait(self.CLEANUP_INTERVAL_SECONDS):
                return

    def stop_cleanup(self, timeout=None):
        """Ask the cleanup thread to exit and wait up to ``timeout`` seconds."""
        self._stop_event.set()
        if self.cleanup_thread.is_alive():
            self.cleanup_thread.join(timeout)

    def _remove_expired_files(self):
        """Delete regular files in the output folder older than the lifetime.

        File-system errors are logged and skipped so one bad entry (or a file
        deleted concurrently) cannot kill the cleanup thread.

        Returns:
            The number of files removed in this pass.
        """
        try:
            filenames = os.listdir(self.output_folder)
        except OSError as e:
            logging.error("Error listing output folder %s: %s", self.output_folder, e)
            return 0

        current_time = datetime.now()
        removed = 0
        for filename in filenames:
            file_path = os.path.join(self.output_folder, filename)
            try:
                if not os.path.isfile(file_path):
                    continue
                file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
                if current_time - file_mod_time <= self.file_lifetime:
                    continue
                os.remove(file_path)
            except FileNotFoundError:
                # Already gone (removed by someone else); nothing left to do.
                continue
            except OSError as e:
                logging.warning("Could not remove old file %s: %s", filename, e)
                continue
            removed += 1
            logging.info("Removed old file: %s", filename)
        return removed
# One detector is shared by every request: constructing a FractureDetector
# loads the model and starts a background cleanup thread, so building a new
# one per call would repeat that work and pile up threads.
_detector = None
_detector_lock = threading.Lock()


def _get_detector():
    """Return the shared FractureDetector, creating it on first use."""
    global _detector
    # The lock keeps concurrent first calls from loading the model twice.
    with _detector_lock:
        if _detector is None:
            model_path = 'yolov8n_custom_exported.pt'  # Model path on Hugging Face Spaces
            output_folder = '/app/output_images'  # Output folder structure
            _detector = FractureDetector(model_path, output_folder)
        return _detector


def detect_fractures_in_image(image):
    """Run fracture detection on ``image`` and return the detections.

    Args:
        image: Image handed over by the Gradio interface.

    Returns:
        The list of detection dicts produced by
        ``FractureDetector.detect_fractures``.

    Raises:
        Exception: Any error raised while loading the model or detecting.
    """
    return _get_detector().detect_fractures(image)
# Gradio front end: takes an uploaded image as a numpy array, passes it to
# detect_fractures_in_image and shows the returned detections as JSON.
iface = gr.Interface(
    title="Fracture Detection using YOLO",
    description="Upload an X-ray image to detect fractures.",
    fn=detect_fractures_in_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.JSON(),
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()