"""Gradio app that detects bone fractures in X-ray images with a YOLO model.

The module exposes:

* ``FractureDetector`` - loads the YOLO weights, preprocesses images, runs
  inference, saves an annotated copy of each positive result and
  periodically deletes old files from the output folder.
* ``detect_fractures_in_image`` - the Gradio callback.
* ``iface`` - the Gradio interface, launched when the file is run as a script.
"""

import logging
import os
import threading
import time
from datetime import datetime, timedelta
from io import BytesIO  # noqa: F401  (kept: other code may import it from here)
from typing import Any, Dict, List, Optional, Tuple

import cv2
import gradio as gr  # Correctly import Gradio here
import numpy as np
import torch
from gradio import Interface  # Import Interface here
from ultralytics import YOLO

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='fracture_detection.log'
)

# Model path on Hugging Face Spaces
MODEL_PATH = 'yolov8n_custom_exported.pt'
# Output folder structure
OUTPUT_FOLDER = '/app/output_images'
# Seconds between two passes of the cleanup thread
CLEANUP_INTERVAL_SECONDS = 60


class FractureDetector:
    """YOLO-based fracture detector with on-disk result visualizations.

    Creating an instance loads the model onto the GPU when CUDA is available
    (CPU otherwise), creates the output folder, and starts a daemon thread
    that removes files older than ``file_lifetime_minutes`` from that folder.
    """

    def __init__(self, model_path: str, output_folder: str,
                 file_lifetime_minutes: int = 3):
        """Load the model and start the background cleanup thread.

        Args:
            model_path: Path to the YOLO weights file.
            output_folder: Directory where annotated images are written.
                It is created if it does not exist.
            file_lifetime_minutes: Age after which files in
                ``output_folder`` are deleted.

        Raises:
            Exception: Whatever ``YOLO(model_path)`` or ``model.to`` raises;
                the error is logged and re-raised unchanged.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.output_folder = output_folder
        self.file_lifetime = timedelta(minutes=file_lifetime_minutes)

        # Ensure the output folder exists
        os.makedirs(self.output_folder, exist_ok=True)

        try:
            self.model = YOLO(model_path)
            self.model.to(self.device)
            logging.info(f"Model loaded successfully from {model_path}")
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise

        # Start the cleanup thread; daemon so it never blocks interpreter exit.
        self.cleanup_thread = threading.Thread(
            target=self.cleanup_old_files, daemon=True
        )
        self.cleanup_thread.start()

    def preprocess_image(self, image) -> Tuple[np.ndarray, np.ndarray]:
        """Preprocess the input image for fracture detection.

        The image is converted to BGR, reduced to grayscale, contrast-enhanced
        with CLAHE, lightly blurred, and replicated onto three channels.

        Args:
            image: Array-like image. A 2-D array is treated as grayscale, a
                4-channel array as RGBA, anything else as RGB.

        Returns:
            A ``(processed, original_bgr)`` tuple: the 3-channel preprocessed
            image fed to the model and the BGR copy used for visualization.
        """
        pixels = np.asarray(image)

        # Pick the colour conversion that matches the channel layout instead
        # of assuming RGB, so grayscale and RGBA inputs do not crash cvtColor.
        if pixels.ndim == 2:
            img = cv2.cvtColor(pixels, cv2.COLOR_GRAY2BGR)
        elif pixels.ndim == 3 and pixels.shape[2] == 4:
            img = cv2.cvtColor(pixels, cv2.COLOR_RGBA2BGR)
        else:
            img = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_clahe = clahe.apply(img_gray)

        # Slight Gaussian blur to reduce noise
        img_blurred = cv2.GaussianBlur(img_clahe, (3, 3), 0)

        return cv2.merge([img_blurred, img_blurred, img_blurred]), img

    def detect_fractures(self, image,
                         conf_threshold: float = 0.25) -> List[Dict[str, Any]]:
        """Detect fractures in the given image.

        When at least one detection is found, an annotated copy of the image
        is also written to the output folder.

        Args:
            image: Array-like image accepted by ``preprocess_image``.
            conf_threshold: Confidence threshold passed to the model.

        Returns:
            One dict per detected box with the keys ``'coordinates'``
            (``[x1, y1, x2, y2]`` floats), ``'class'``, ``'confidence'`` and
            ``'name'``.

        Raises:
            Exception: Any error raised during preprocessing, inference or
                visualization is logged and re-raised.
        """
        try:
            processed_image, original_image = self.preprocess_image(image)
            results = self.model(
                processed_image, device=self.device, conf=conf_threshold
            )

            detections = []
            for r in results:
                for box in r.boxes:
                    x1, y1, x2, y2 = (float(v) for v in box.xyxy[0])
                    class_name = self.model.names[int(box.cls[0])]
                    detections.append({
                        'coordinates': [x1, y1, x2, y2],
                        'class': class_name,
                        'confidence': float(box.conf[0]),
                        'name': f"{class_name} fracture"
                    })

            if detections:
                self._save_visualization(original_image, detections)

            return detections
        except Exception as e:
            # Same message as before, plus the traceback for diagnosis.
            logging.exception(f"Error in fracture detection: {e}")
            raise

    def _save_visualization(self, image, detections) -> bytes:
        """Save visualization of the detection results.

        Draws a green box and a ``"<name> (<confidence>)"`` label for every
        detection on a copy of ``image`` and writes it as a timestamped JPEG
        in the output folder.

        Args:
            image: BGR image to annotate; it is not modified.
            detections: Detection dicts as produced by ``detect_fractures``.

        Returns:
            The annotated image encoded as JPEG bytes (quality 90).

        Raises:
            RuntimeError: If OpenCV fails to encode the annotated image.
        """
        img_viz = image.copy()

        for det in detections:
            x1, y1, x2, y2 = map(int, det['coordinates'])
            cv2.rectangle(img_viz, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{det['name']} ({det['confidence']:.2f})"
            # Clamp the baseline so labels of boxes touching the top edge
            # stay inside the image instead of being drawn off-canvas.
            label_y = max(y1 - 10, 15)
            cv2.putText(img_viz, label, (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        save_path = os.path.join(
            self.output_folder, f"detection_result_{timestamp}.jpg"
        )
        # cv2.imwrite reports failure through its return value, not an
        # exception, so check it explicitly.
        if cv2.imwrite(save_path, img_viz):
            logging.info(f"Visualization saved to {save_path}")
        else:
            logging.warning(f"Could not save visualization to {save_path}")

        # Return the encoded image bytes for Gradio
        ok, img_encoded = cv2.imencode(
            '.jpg', img_viz, [int(cv2.IMWRITE_JPEG_QUALITY), 90]
        )
        if not ok:
            raise RuntimeError("Failed to encode visualization as JPEG")
        return img_encoded.tobytes()

    def _remove_expired_files(self) -> None:
        """Delete every regular file in the output folder past its lifetime."""
        current_time = datetime.now()
        for filename in os.listdir(self.output_folder):
            file_path = os.path.join(self.output_folder, filename)
            try:
                if not os.path.isfile(file_path):
                    continue
                file_mod_time = datetime.fromtimestamp(
                    os.path.getmtime(file_path)
                )
                if current_time - file_mod_time > self.file_lifetime:
                    os.remove(file_path)
                    logging.info(f"Removed old file: {filename}")
            except OSError as e:
                # The file may have vanished or be locked; skip it and keep
                # going so one bad entry does not stop the whole pass.
                logging.warning(f"Could not clean up {filename}: {e}")

    def cleanup_old_files(self) -> None:
        """Periodically remove files older than the specified lifetime.

        Runs forever (it is the target of the daemon thread started in
        ``__init__``), doing one pass over the output folder every minute.
        Filesystem errors are logged and never terminate the loop.
        """
        while True:
            try:
                self._remove_expired_files()
            except OSError as e:
                logging.warning(f"Cleanup pass failed: {e}")
            time.sleep(CLEANUP_INTERVAL_SECONDS)  # Check every minute


# One shared detector for the whole process: building a FractureDetector
# reloads the model and starts a cleanup thread, so it must not happen per
# request. The lock guards both lazy creation and inference on the shared
# model.
_detector: Optional[FractureDetector] = None
_detector_lock = threading.Lock()


def _get_detector() -> FractureDetector:
    """Return the shared detector, creating it on first use.

    Must be called with ``_detector_lock`` held.
    """
    global _detector
    if _detector is None:
        _detector = FractureDetector(MODEL_PATH, OUTPUT_FOLDER)
    return _detector


def detect_fractures_in_image(image):
    """Function to handle the detection and return results.

    Args:
        image: Image supplied by the Gradio ``Image`` input, or ``None`` when
            the user submitted without uploading anything.

    Returns:
        The list of detection dicts produced by
        ``FractureDetector.detect_fractures``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an X-ray image before submitting.")

    with _detector_lock:
        return _get_detector().detect_fractures(image)


# Define Gradio interface
iface = Interface(
    fn=detect_fractures_in_image,
    inputs=gr.Image(type="numpy"),  # Use 'gr' to access Image class
    outputs=gr.JSON(),  # Use 'gr' to access JSON class
    title="Fracture Detection using YOLO",
    description="Upload an X-ray image to detect fractures."
)

if __name__ == "__main__":
    iface.launch()