Spaces:
Sleeping
Sleeping
File size: 9,247 Bytes
6477883 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 |
import base64
import io

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from PIL import Image
from torchvision import transforms
# Load the pretrained model
# NOTE(review): confirm `gr.utils.cache` exists in the pinned Gradio version;
# functools.lru_cache would be a stdlib alternative for this one-shot loader.
@gr.utils.cache
def load_model():
    """Fetch the pretrained brain-segmentation U-Net from PyTorch Hub.

    Returns:
        The hub model switched to evaluation mode, or None if anything in
        the download/build/eval step raises (the error is printed, not
        re-raised, so the app can still start and report the failure).
    """
    hub_options = dict(
        in_channels=3,
        out_channels=1,
        init_features=32,
        pretrained=True,
        force_reload=False,
    )
    try:
        net = torch.hub.load(
            'mateuszbuda/brain-segmentation-pytorch', 'unet', **hub_options
        )
        net.eval()
    except Exception as exc:
        print(f"Error loading model: {exc}")
        return None
    return net
# Initialize model
# load_model() returns None on failure; predict_tumor checks for that.
model = load_model()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Explicit None check: truthiness of an nn.Module is not a reliable
# "was it loaded" test (container modules with no children are falsy).
if model is not None:
    # Move weights once at startup; each input tensor is moved to the same
    # device at prediction time.
    model = model.to(device)
def preprocess_image(image, normalization="per_image"):
    """Resize an input scan and convert it into a normalized model input.

    Args:
        image: PIL image, or a numpy array accepted by ``Image.fromarray``.
        normalization: ``"per_image"`` (default) standardizes each RGB
            channel of this image to zero mean and unit variance.
            NOTE(review): intended to match the per-volume, per-channel
            z-score used by the mateuszbuda/brain-segmentation-pytorch
            training code — confirm against the pinned model version.
            ``"imagenet"`` keeps the former ImageNet mean/std normalization.

    Returns:
        ``(tensor, resized_image)``: a float tensor of shape
        (1, 3, 256, 256) and the 256x256 RGB PIL image it was built from.

    Raises:
        ValueError: if ``normalization`` is not one of the values above.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The model is built with in_channels=3, so single-channel scans are
    # converted (replicated) to RGB.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Fixed 256x256 working resolution used throughout the app.
    image = image.resize((256, 256), Image.Resampling.LANCZOS)

    # (3, 256, 256) float tensor with values scaled to [0, 1].
    tensor = transforms.ToTensor()(image)
    if normalization == "per_image":
        mean = tensor.mean(dim=(1, 2), keepdim=True)
        std = (tensor - mean).pow(2).mean(dim=(1, 2), keepdim=True).sqrt()
        # Constant channels (e.g. a blank image) would otherwise divide by 0.
        tensor = (tensor - mean) / std.clamp(min=1e-6)
    elif normalization == "imagenet":
        tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])(tensor)
    else:
        raise ValueError(f"Unknown normalization: {normalization!r}")
    return tensor.unsqueeze(0), image  # add batch dimension
def create_overlay_visualization(original_img, mask, alpha=0.6, color=(255, 0, 0)):
    """Tint the masked pixels of an image with a highlight colour.

    Args:
        original_img: PIL image or array of shape (H, W), (H, W, 3) or
            (H, W, 4) holding 0-255 values.
        mask: (H, W) array; 1/True marks a pixel to tint, 0 leaves it
            unchanged. Fractional values give a partial tint; values are
            clipped to [0, 1].
        alpha: opacity of the tint on fully-masked pixels.
        color: RGB tint colour (red by default).

    Returns:
        uint8 array of shape (H, W, C); C is 3 for grayscale or RGB input.
        Channels beyond the first three (e.g. alpha) are passed through.
    """
    base = np.array(original_img)
    if base.ndim == 2:
        # Grayscale input: replicate into RGB so it can carry a colour tint.
        base = np.stack([base] * 3, axis=-1)
    base = base.astype(np.float32)

    # Per-pixel blend weight: alpha inside the mask, 0 outside. Blending the
    # whole image against a mostly-black mask image (cv2.addWeighted style)
    # would also scale every unmasked pixel by (1 - alpha) and dim the scan.
    weight = alpha * np.clip(np.asarray(mask, dtype=np.float32), 0.0, 1.0)
    weight = weight[..., np.newaxis]
    tint = np.asarray(color, dtype=np.float32)

    blended = base.copy()
    blended[..., :3] = base[..., :3] * (1.0 - weight) + tint * weight
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
def _run_segmentation(image, threshold):
    """Run the U-Net on one input and threshold its output.

    Returns:
        ``(resized_image, binary_mask)``: the RGB PIL image produced by
        ``preprocess_image`` and a uint8 (H, W) array holding 1 where the
        predicted probability exceeds ``threshold``.
    """
    input_tensor, resized_image = preprocess_image(image)
    input_tensor = input_tensor.to(device)
    with torch.no_grad():
        output = model(input_tensor)
    # The hub U-Net already ends with a sigmoid (its PyTorch Hub page
    # describes the output as a probability map). Applying sigmoid again
    # squeezed every value into (0.5, 0.73], so a 0.5 threshold marked nearly
    # the whole image as tumor. Only treat the output as logits if it lies
    # outside [0, 1], e.g. if a different model is swapped in.
    if output.min().item() < 0.0 or output.max().item() > 1.0:
        output = torch.sigmoid(output)
    probabilities = output.squeeze().cpu().numpy()
    binary_mask = (probabilities > threshold).astype(np.uint8)
    return resized_image, binary_mask


def _render_comparison(original_array, mask_colored, overlay):
    """Render original / mask / overlay panels side by side as a PIL image.

    Uses a standalone ``Figure`` rather than pyplot: pyplot keeps global
    "current figure" state that can be clobbered when handlers run in
    parallel threads (a bare ``plt.close()`` may close another request's
    figure), and a pyplot figure leaks if an exception occurs before it is
    closed. A plain Figure is released with its last reference.
    """
    fig = Figure(figsize=(15, 5))
    axes = fig.subplots(1, 3)
    panels = (
        (original_array, 'Original Image'),
        (mask_colored, 'Tumor Segmentation'),
        (overlay, 'Overlay (Red = Tumor)'),
    )
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axis('off')
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    rendered = Image.open(buf)
    rendered.load()  # decode now so any error surfaces inside the caller's try
    return rendered


def _build_report(binary_mask, threshold):
    """Format the Markdown analysis report for a binary tumor mask."""
    height, width = binary_mask.shape
    total_pixels = binary_mask.size
    tumor_pixels = int(binary_mask.sum())
    tumor_percentage = (tumor_pixels / total_pixels) * 100
    return f"""
## π§ Brain Tumor Segmentation Analysis
**π Tumor Statistics:**
- Total pixels analyzed: {total_pixels:,}
- Tumor pixels detected: {tumor_pixels:,}
- Tumor area percentage: {tumor_percentage:.2f}%
**π― Model Performance:**
- Model: U-Net (mateuszbuda/brain-segmentation-pytorch)
- Input resolution: {height}Γ{width} pixels
- Detection threshold: {threshold}
**β οΈ Medical Disclaimer:**
This is an AI tool for research purposes only.
Always consult qualified medical professionals for diagnosis.
"""


def predict_tumor(image, threshold=0.5):
    """Segment tumor regions in one uploaded scan (main Gradio handler).

    Args:
        image: image from the Gradio component (PIL image or numpy array),
            or None when nothing has been provided.
        threshold: probability above which a pixel is counted as tumor.

    Returns:
        ``(result_image, markdown)``: a side-by-side comparison figure as a
        PIL image plus the Markdown report; ``(None, message)`` when the
        model is unavailable, no image was given, or prediction fails.
    """
    if model is None:
        return None, "β Model failed to load. Please try again."
    if image is None:
        return None, "β οΈ Please upload an image first."
    try:
        resized_image, binary_mask = _run_segmentation(image, threshold)
        original_array = np.array(resized_image)
        # Mask panel: tumor pixels in pure red on black.
        mask_colored = np.zeros((*binary_mask.shape, 3), dtype=np.uint8)
        mask_colored[..., 0] = binary_mask * 255
        overlay = create_overlay_visualization(resized_image, binary_mask, alpha=0.4)
        result_image = _render_comparison(original_array, mask_colored, overlay)
        analysis_text = _build_report(binary_mask, threshold)
    except Exception as e:
        # UI boundary: show the failure in the output panel instead of crashing.
        error_msg = f"β Error during prediction: {str(e)}"
        return None, error_msg
    return result_image, analysis_text
def clear_all():
    """Reset the UI.

    Returns values for (image input, result image, analysis report):
    both images emptied and the report set to an empty string.
    """
    blank_input = None
    blank_result = None
    blank_report = ""
    return blank_input, blank_result, blank_report
# Custom CSS for better styling, passed to gr.Blocks(css=...).
# '#title' styles the header <div id="title"> and '.output-image' matches the
# result image's elem_classes=["output-image"].
# NOTE(review): no component in this file sets elem_id "main-container" or
# "upload-box", so those two rules appear unused — confirm before removing.
css = """
#main-container {
max-width: 1200px;
margin: 0 auto;
}
#title {
text-align: center;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
}
#upload-box {
border: 2px dashed #ccc;
border-radius: 10px;
padding: 20px;
text-align: center;
margin: 10px 0;
}
.output-image {
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
"""
def _predict_on_change(image):
    """Auto-run predict_tumor when the input image changes, unless cleared.

    Clearing the input (the "Clear All" button or the image's own clear
    control) also fires ``.change`` with ``None``. Forwarding that to
    predict_tumor would immediately replace the just-cleared report with
    the "please upload" warning, so a cleared input yields empty outputs.
    """
    if image is None:
        return None, ""
    return predict_tumor(image)


# Create Gradio interface
with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
    # Header (styled by the '#title' CSS rule)
    gr.HTML("""
<div id="title">
<h1>π§ Brain Tumor Segmentation AI</h1>
<p>Upload an MRI brain scan to detect and visualize tumor regions using deep learning</p>
</div>
""")
    with gr.Row():
        # Left column: input image, action buttons, instructions.
        with gr.Column(scale=1):
            gr.HTML("<h3>π€ Input Image</h3>")
            # Image input with camera option
            image_input = gr.Image(
                label="Upload Brain MRI Scan",
                type="pil",
                sources=["upload", "webcam"],  # Allow both upload and camera
                height=300,
            )
            with gr.Row():
                predict_btn = gr.Button("π Analyze Image", variant="primary", size="lg")
                clear_btn = gr.Button("ποΈ Clear All", variant="secondary")
            gr.HTML("""
<div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 8px;">
<h4>π Instructions:</h4>
<ul>
<li>Upload a brain MRI scan image</li>
<li>Supported formats: PNG, JPG, JPEG</li>
<li>For best results, use clear, high-contrast MRI images</li>
<li>You can also use the camera to capture an image from your device</li>
</ul>
</div>
""")
        # Right column: rendered comparison figure and Markdown report.
        with gr.Column(scale=2):
            gr.HTML("<h3>π Segmentation Results</h3>")
            # Output image
            output_image = gr.Image(
                label="Segmentation Results",
                type="pil",
                height=400,
                elem_classes=["output-image"],
            )
            # Analysis text
            analysis_output = gr.Markdown(
                label="Analysis Report",
                value="Upload an image and click 'Analyze Image' to see results.",
            )
    # Add footer with information
    gr.HTML("""
<div style="margin-top: 30px; padding: 20px; background-color: #f9f9f9; border-radius: 10px;">
<h4>π¬ About This Tool</h4>
<p><strong>Model:</strong> Pre-trained U-Net architecture optimized for brain tumor segmentation</p>
<p><strong>Technology:</strong> PyTorch, Deep Learning, Computer Vision</p>
<p><strong>Dataset:</strong> Trained on medical MRI brain scans</p>
<h4>β οΈ Important Medical Disclaimer</h4>
<p style="color: #d73027; font-weight: bold;">
This AI tool is for research and educational purposes only. It should NOT be used for medical diagnosis.
Always consult qualified healthcare professionals for medical advice and diagnosis.
</p>
<p style="text-align: center; margin-top: 20px; color: #666;">
Made with β€οΈ using Gradio β’ Powered by PyTorch β’ Hosted on π€ Hugging Face Spaces
</p>
</div>
""")
    # Event handlers
    predict_btn.click(
        fn=predict_tumor,
        inputs=[image_input],
        outputs=[output_image, analysis_output],
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, output_image, analysis_output],
    )
    # Auto-predict when an image is provided; a cleared input (None) just
    # blanks the outputs instead of showing the upload warning.
    image_input.change(
        fn=_predict_on_change,
        inputs=[image_input],
        outputs=[output_image, analysis_output],
    )
# Launch the app when run as a script (not on import).
# NOTE(review): Gradio ignores share=True on Hugging Face Spaces (it logs a
# warning); when run locally it requests a public share link — confirm intended.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,
        "show_error": True,
    }
    app.launch(**launch_options)
|