# Hugging Face Spaces page status at capture time: Sleeping
import logging | |
import numpy as np | |
import os | |
import time | |
from datetime import datetime, timedelta | |
from ultralytics import YOLO | |
import cv2 | |
import torch | |
import gradio as gr # Correctly import Gradio here | |
from gradio import Interface # Import Interface here | |
from io import BytesIO | |
import threading | |
# Route INFO-and-above records to a local log file, each line stamped with
# time and severity. (basicConfig is a no-op if the root logger already has
# handlers.)
logging.basicConfig(
    filename='fracture_detection.log',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class FractureDetector:
    """Run a YOLO fracture-detection model and keep short-lived result images.

    Annotated JPEGs are written to ``output_folder``. A daemon thread started
    in ``__init__`` deletes regular files in that folder once their
    modification time is older than ``file_lifetime_minutes``.
    """

    def __init__(self, model_path: str, output_folder: str, file_lifetime_minutes: int = 3):
        """Load the model and start the background cleanup thread.

        Args:
            model_path: Weights file passed to ``YOLO``.
            output_folder: Directory for saved visualizations; created if missing.
            file_lifetime_minutes: Age after which files in ``output_folder``
                are deleted by the cleanup thread.

        Raises:
            Exception: Anything raised while loading the model or moving it to
                the device is logged (with traceback) and re-raised; in that
                case no cleanup thread is started.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.output_folder = output_folder
        self.file_lifetime = timedelta(minutes=file_lifetime_minutes)
        os.makedirs(self.output_folder, exist_ok=True)
        try:
            self.model = YOLO(model_path)
            self.model.to(self.device)
            logging.info(f"Model loaded successfully from {model_path}")
        except Exception as e:
            # logging.exception keeps the traceback that logging.error dropped.
            logging.exception(f"Error loading model: {e}")
            raise
        # Daemon so the sweeper never keeps the process alive on shutdown.
        self.cleanup_thread = threading.Thread(target=self.cleanup_old_files, daemon=True)
        self.cleanup_thread.start()

    def preprocess_image(self, image):
        """Prepare an image for the model.

        Args:
            image: Array-like image, e.g. what Gradio delivers for
                ``gr.Image(type="numpy")``. RGB ``(H, W, 3)`` is the expected
                layout; grayscale ``(H, W)`` / ``(H, W, 1)`` and RGBA
                ``(H, W, 4)`` are converted as well.

        Returns:
            ``(processed, original_bgr)``: ``processed`` is the CLAHE-equalized,
            3x3 Gaussian-blurred grayscale image replicated into three channels;
            ``original_bgr`` is the input in OpenCV's BGR order, used for drawing.

        Raises:
            ValueError: If ``image`` is None or has an unsupported shape.
        """
        if image is None:
            raise ValueError("No image provided")
        arr = np.asarray(image)
        if arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] == 1):
            img = cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        elif arr.ndim == 3 and arr.shape[2] == 4:
            img = cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
        elif arr.ndim == 3 and arr.shape[2] == 3:
            img = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
        else:
            raise ValueError(f"Unsupported image shape: {arr.shape}")
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # CLAHE (Contrast Limited Adaptive Histogram Equalization) boosts local contrast.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_clahe = clahe.apply(img_gray)
        # Slight Gaussian blur to reduce noise.
        img_blurred = cv2.GaussianBlur(img_clahe, (3, 3), 0)
        return cv2.merge([img_blurred, img_blurred, img_blurred]), img

    def detect_fractures(self, image, conf_threshold: float = 0.25):
        """Detect fractures in ``image``.

        Args:
            image: Input image (see ``preprocess_image``).
            conf_threshold: Minimum confidence passed to the model as ``conf``.

        Returns:
            A list of dicts with keys ``'coordinates'`` (``[x1, y1, x2, y2]``
            floats), ``'class'``, ``'confidence'`` and ``'name'``. When the list
            is non-empty, an annotated copy of the image is also saved.

        Raises:
            Exception: Any error is logged with traceback and re-raised.
        """
        try:
            processed_image, original_image = self.preprocess_image(image)
            results = self.model(processed_image, device=self.device, conf=conf_threshold)
            detections = []
            for r in results:
                for box in r.boxes:
                    x1, y1, x2, y2 = box.xyxy[0]
                    conf = float(box.conf[0])
                    class_name = self.model.names[int(box.cls[0])]
                    detections.append({
                        'coordinates': [float(x1), float(y1), float(x2), float(y2)],
                        'class': class_name,
                        'confidence': conf,
                        'name': f"{class_name} fracture",
                    })
            if detections:
                self._save_visualization(original_image, detections)
            return detections
        except Exception as e:
            logging.exception(f"Error in fracture detection: {e}")
            raise

    def _save_visualization(self, image, detections):
        """Draw boxes and labels on a copy of ``image``, save it, and return JPEG bytes.

        Args:
            image: BGR image to annotate (not modified).
            detections: Detection dicts as produced by ``detect_fractures``.

        Returns:
            The annotated image encoded as JPEG (quality 90), or ``b""`` if
            encoding fails.
        """
        img_viz = image.copy()
        for det in detections:
            x1, y1, x2, y2 = map(int, det['coordinates'])
            cv2.rectangle(img_viz, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{det['name']} ({det['confidence']:.2f})"
            # Put the label above the box, or just inside it when the box is
            # at the top edge (otherwise the text would be drawn off-image).
            text_y = y1 - 10 if y1 - 10 > 10 else y1 + 15
            cv2.putText(img_viz, label, (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        # Microseconds prevent results produced within the same second from
        # overwriting each other.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        save_path = os.path.join(self.output_folder, f"detection_result_{timestamp}.jpg")
        if cv2.imwrite(save_path, img_viz):
            logging.info(f"Visualization saved to {save_path}")
        else:
            logging.warning(f"Failed to save visualization to {save_path}")
        ok, img_encoded = cv2.imencode('.jpg', img_viz, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
        return img_encoded.tobytes() if ok else b""

    def _remove_expired_files(self, now=None):
        """Delete regular files in ``output_folder`` older than ``file_lifetime``.

        Args:
            now: Reference time as a naive local ``datetime`` (the same kind
                ``datetime.fromtimestamp`` returns); defaults to ``datetime.now()``.

        Returns:
            Names of the files removed, in sorted order. A missing folder
            yields an empty list.
        """
        now = datetime.now() if now is None else now
        removed = []
        try:
            names = os.listdir(self.output_folder)
        except FileNotFoundError:
            return removed
        for filename in sorted(names):
            file_path = os.path.join(self.output_folder, filename)
            try:
                if not os.path.isfile(file_path):
                    continue
                file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
                if now - file_mod_time > self.file_lifetime:
                    os.remove(file_path)
                    removed.append(filename)
                    logging.info(f"Removed old file: {filename}")
            except OSError as e:
                # The file may vanish between listing and removal, or be
                # undeletable; skip it instead of aborting the whole sweep.
                logging.warning(f"Could not clean up {filename}: {e}")
        return removed

    def cleanup_old_files(self):
        """Sweep expired files forever, once every 60 seconds.

        Target of the daemon thread started in ``__init__``. Errors from a
        sweep are logged so the thread keeps running instead of dying silently.
        """
        while True:
            try:
                self._remove_expired_files()
            except Exception:
                logging.exception("Error while cleaning up old files")
            time.sleep(60)  # Check every minute
# Process-wide detector shared by all requests. Building one per call would
# reload the model weights and start another never-ending cleanup thread
# every time.
_detector = None
_detector_lock = threading.Lock()


def _get_detector():
    """Return the shared FractureDetector, creating it on first use.

    If construction raises, nothing is cached and the next call retries.
    """
    global _detector
    with _detector_lock:
        if _detector is None:
            model_path = 'yolov8n_custom_exported.pt'  # Model path on Hugging Face Spaces
            # NOTE(review): assumes /app is writable in the deployment container — confirm.
            output_folder = '/app/output_images'
            _detector = FractureDetector(model_path, output_folder)
    return _detector


def detect_fractures_in_image(image):
    """Gradio handler: run fracture detection on an uploaded image.

    Args:
        image: NumPy image from ``gr.Image(type="numpy")``; None when nothing
            was uploaded.

    Returns:
        The list of detection dicts from ``FractureDetector.detect_fractures``.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an X-ray image.")
    return _get_detector().detect_fractures(image)
# Web UI: a NumPy image goes in, and the detection list comes back rendered as JSON.
iface = gr.Interface(
    fn=detect_fractures_in_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.JSON(),
    title="Fracture Detection using YOLO",
    description="Upload an X-ray image to detect fractures.",
)

if __name__ == "__main__":
    # Serve the UI only when run as a script, not when imported.
    iface.launch()