File size: 5,776 Bytes
61db09b
9396a3e
61db09b
9396a3e
 
 
 
 
b283c7d
908304b
9396a3e
 
61db09b
 
 
 
 
 
 
 
 
9396a3e
61db09b
 
9396a3e
61db09b
 
 
 
 
 
 
 
 
 
 
 
9396a3e
 
 
 
 
61db09b
9396a3e
 
61db09b
 
 
9396a3e
61db09b
 
 
 
 
 
 
 
9396a3e
61db09b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9396a3e
61db09b
 
 
 
 
 
9396a3e
 
61db09b
 
 
 
 
 
9396a3e
 
7ff4307
 
9396a3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61db09b
9396a3e
add3d03
9396a3e
61db09b
 
 
 
 
 
 
7ff4307
 
61db09b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import logging
import numpy as np
import os
import time
from datetime import datetime, timedelta
from ultralytics import YOLO
import cv2
import torch
import gradio as gr  # Correctly import Gradio here
from gradio import Interface  # Import Interface here
from io import BytesIO
import threading

# Root logger setup: records at INFO and above are written to
# fracture_detection.log, resolved relative to the current working directory.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(filename='fracture_detection.log', level=logging.INFO, format=_LOG_FORMAT)

class FractureDetector:
    """Run a YOLO fracture-detection model and manage its visualization output.

    On construction the model is loaded onto CUDA when available (otherwise CPU),
    ``output_folder`` is created if missing, and a daemon thread is started that
    periodically deletes files in ``output_folder`` older than the configured
    lifetime. Call :meth:`close` to stop that thread.
    """

    def __init__(self, model_path: str, output_folder: str, file_lifetime_minutes: int = 3,
                 cleanup_interval_seconds: float = 60.0):
        """Load the model and start the background cleanup thread.

        Args:
            model_path: Path to the weights file passed to ``YOLO``.
            output_folder: Directory where annotated detection images are saved.
            file_lifetime_minutes: Age after which saved files are deleted.
            cleanup_interval_seconds: Delay between two cleanup passes.

        Raises:
            Exception: Whatever ``YOLO(model_path)`` or ``.to(device)`` raises;
                it is logged and re-raised.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.output_folder = output_folder
        self.file_lifetime = timedelta(minutes=file_lifetime_minutes)
        self.cleanup_interval_seconds = cleanup_interval_seconds

        # Ensure the output folder exists
        os.makedirs(self.output_folder, exist_ok=True)

        try:
            self.model = YOLO(model_path)
            self.model.to(self.device)
            logging.info(f"Model loaded successfully from {model_path}")
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise

        # The event lets close() end the cleanup loop instead of leaving a
        # thread that sleeps forever. Daemon so it never blocks interpreter exit.
        self._stop_event = threading.Event()
        self.cleanup_thread = threading.Thread(
            target=self.cleanup_old_files,
            name="fracture-output-cleanup",
            daemon=True,
        )
        self.cleanup_thread.start()

    def close(self, timeout: float = 5.0) -> None:
        """Stop the background cleanup thread and wait up to ``timeout`` seconds for it.

        Safe to call more than once.
        """
        self._stop_event.set()
        thread = getattr(self, "cleanup_thread", None)
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout)

    def preprocess_image(self, image):
        """Preprocess the input image for fracture detection.

        Accepts RGB (H, W, 3), RGBA (H, W, 4) and single-channel (H, W) or
        (H, W, 1) arrays. The image is converted to BGR, then to grayscale,
        contrast-enhanced with CLAHE and lightly blurred.

        Returns:
            Tuple ``(processed, original)``. ``processed`` is the enhanced
            grayscale image replicated into 3 channels (model input).
            ``original`` is the BGR image used for drawing the results.

        Raises:
            ValueError: If ``image`` is None.
        """
        if image is None:
            raise ValueError("No image provided for fracture detection")

        img = np.asarray(image)
        # Collapse an explicit single channel so the grayscale branch handles it.
        if img.ndim == 3 and img.shape[2] == 1:
            img = img.reshape(img.shape[0], img.shape[1])

        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_clahe = clahe.apply(img_gray)

        # Slight Gaussian blur to reduce noise
        img_blurred = cv2.GaussianBlur(img_clahe, (3, 3), 0)

        return cv2.merge([img_blurred, img_blurred, img_blurred]), img

    def detect_fractures(self, image, conf_threshold: float = 0.25):
        """Detect fractures in ``image`` and save an annotated copy if any are found.

        Args:
            image: Array-like image in any layout accepted by ``preprocess_image``.
            conf_threshold: Minimum confidence, passed to the model as ``conf``.

        Returns:
            List of dicts with keys ``'coordinates'`` (``[x1, y1, x2, y2]`` floats
            taken from ``box.xyxy``), ``'class'``, ``'confidence'`` and ``'name'``.

        Raises:
            Exception: Any preprocessing or inference error is logged with its
                traceback and re-raised.
        """
        try:
            processed_image, original_image = self.preprocess_image(image)
            results = self.model(processed_image, device=self.device, conf=conf_threshold)
            detections = []

            for r in results:
                for box in r.boxes:
                    x1, y1, x2, y2 = box.xyxy[0]
                    conf = float(box.conf[0])
                    class_name = self.model.names[int(box.cls[0])]
                    detections.append({
                        'coordinates': [float(x1), float(y1), float(x2), float(y2)],
                        'class': class_name,
                        'confidence': conf,
                        'name': f"{class_name} fracture",
                    })

            if detections:
                self._save_visualization(original_image, detections)

            return detections

        except Exception as e:
            logging.exception(f"Error in fracture detection: {e}")
            raise

    def _save_visualization(self, image, detections):
        """Draw the detections on a copy of ``image`` and save it as a JPEG.

        Returns:
            The JPEG-encoded (quality 90) annotated image as ``bytes``, or ``b""``
            if encoding fails. ``detect_fractures`` currently ignores this value.
        """
        img_viz = image.copy()

        for det in detections:
            x1, y1, x2, y2 = map(int, det['coordinates'])
            cv2.rectangle(img_viz, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{det['name']} ({det['confidence']:.2f})"
            # Keep the label baseline inside the image when the box touches the top edge.
            text_y = max(y1 - 10, 15)
            cv2.putText(img_viz, label, (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # Microsecond resolution so detections within the same second don't overwrite each other.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        save_path = os.path.join(self.output_folder, f"detection_result_{timestamp}.jpg")
        if cv2.imwrite(save_path, img_viz):
            logging.info(f"Visualization saved to {save_path}")
        else:
            logging.warning(f"Failed to save visualization to {save_path}")

        ok, img_encoded = cv2.imencode('.jpg', img_viz, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
        if not ok:
            logging.warning("Failed to JPEG-encode visualization")
            return b""
        return img_encoded.tobytes()

    def _remove_expired_files(self) -> int:
        """Run one cleanup pass over ``output_folder``.

        Deletes regular files whose modification time is older than
        ``file_lifetime``. Files that disappear mid-pass and per-file OS errors
        are tolerated, so a single bad entry cannot abort the pass.

        Returns:
            The number of files removed.
        """
        cutoff = datetime.now() - self.file_lifetime
        try:
            with os.scandir(self.output_folder) as it:
                entries = list(it)
        except FileNotFoundError:
            logging.warning(f"Output folder {self.output_folder} is missing; nothing to clean up")
            return 0

        removed = 0
        for entry in entries:
            try:
                if not entry.is_file():
                    continue
                if datetime.fromtimestamp(entry.stat().st_mtime) < cutoff:
                    os.remove(entry.path)
                    removed += 1
                    logging.info(f"Removed old file: {entry.name}")
            except FileNotFoundError:
                # Removed by someone else between listing and deletion; nothing to do.
                continue
            except OSError as e:
                logging.warning(f"Could not remove {entry.path}: {e}")
        return removed

    def cleanup_old_files(self):
        """Periodically remove files older than the configured lifetime.

        Background loop used as the cleanup thread's target. It runs one pass
        immediately, then one every ``cleanup_interval_seconds``, until
        ``close()`` is called. Unexpected errors are logged so the thread
        keeps running.
        """
        while not self._stop_event.is_set():
            try:
                self._remove_expired_files()
            except Exception:
                logging.exception("Unexpected error while cleaning up output files")
            # Returns early as soon as close() sets the event.
            self._stop_event.wait(self.cleanup_interval_seconds)

# Deployment defaults (model weights file and output folder used by the Space).
MODEL_PATH = 'yolov8n_custom_exported.pt'
OUTPUT_FOLDER = '/app/output_images'

# Single shared detector. Building one per request reloaded the weights and
# started a new, never-ending cleanup thread on every call.
_detector = None
# Guards lazy creation and serializes inference on the shared model.
_detector_lock = threading.Lock()


def detect_fractures_in_image(image):
    """Gradio callback: run fracture detection on one uploaded image.

    The detector is created on the first call and reused afterwards. Calls are
    serialized because the model instance is shared between requests.

    Args:
        image: The value delivered by ``gr.Image(type="numpy")``, or None when
            nothing was uploaded.

    Returns:
        The list of detection dicts produced by ``FractureDetector.detect_fractures``.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    global _detector
    if image is None:
        raise gr.Error("Please upload an X-ray image.")
    with _detector_lock:
        if _detector is None:
            _detector = FractureDetector(MODEL_PATH, OUTPUT_FOLDER)
        return _detector.detect_fractures(image)

# Gradio front end: one uploaded image in, the list of detections out as JSON.
iface = Interface(
    detect_fractures_in_image,
    gr.Image(type="numpy"),
    gr.JSON(),
    title="Fracture Detection using YOLO",
    description="Upload an X-ray image to detect fractures.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()