Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import io | |
import base64 | |
from torchvision import transforms | |
import torch.nn.functional as F | |
# Load the pretrained model | |
def load_model(): | |
"""Load the pretrained brain segmentation model""" | |
try: | |
model = torch.hub.load( | |
'mateuszbuda/brain-segmentation-pytorch', | |
'unet', | |
in_channels=3, | |
out_channels=1, | |
init_features=32, | |
pretrained=True, | |
force_reload=False | |
) | |
model.eval() | |
return model | |
except Exception as e: | |
print(f"Error loading model: {e}") | |
return None | |
# Initialize model | |
model = load_model() | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
if model: | |
model = model.to(device) | |
def preprocess_image(image): | |
"""Preprocess the input image for the model""" | |
if isinstance(image, np.ndarray): | |
image = Image.fromarray(image) | |
# Convert to RGB if not already | |
if image.mode != 'RGB': | |
image = image.convert('RGB') | |
# Resize to 256x256 (model's expected input size) | |
image = image.resize((256, 256), Image.Resampling.LANCZOS) | |
# Convert to tensor and normalize | |
transform = transforms.Compose([ | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
std=[0.229, 0.224, 0.225]) | |
]) | |
image_tensor = transform(image).unsqueeze(0) # Add batch dimension | |
return image_tensor, image | |
def create_overlay_visualization(original_img, mask, alpha=0.6): | |
"""Create an overlay visualization of the segmentation""" | |
# Convert original image to numpy array | |
original_np = np.array(original_img) | |
# Create colored mask (red for tumor regions) | |
colored_mask = np.zeros_like(original_np) | |
colored_mask[:, :, 0] = mask * 255 # Red channel for tumor | |
# Create overlay | |
overlay = cv2.addWeighted(original_np, 1-alpha, colored_mask, alpha, 0) | |
return overlay | |
def predict_tumor(image): | |
"""Main prediction function""" | |
if model is None: | |
return None, "β Model failed to load. Please try again." | |
if image is None: | |
return None, "β οΈ Please upload an image first." | |
try: | |
# Preprocess the image | |
input_tensor, original_img = preprocess_image(image) | |
input_tensor = input_tensor.to(device) | |
# Make prediction | |
with torch.no_grad(): | |
prediction = model(input_tensor) | |
# Apply sigmoid to get probability map | |
prediction = torch.sigmoid(prediction) | |
# Convert to numpy | |
prediction = prediction.squeeze().cpu().numpy() | |
# Threshold the prediction (you can adjust this threshold) | |
threshold = 0.5 | |
binary_mask = (prediction > threshold).astype(np.uint8) | |
# Create visualizations | |
# 1. Original image | |
original_array = np.array(original_img) | |
# 2. Segmentation mask | |
mask_colored = np.zeros((256, 256, 3), dtype=np.uint8) | |
mask_colored[:, :, 0] = binary_mask * 255 # Red channel | |
# 3. Overlay | |
overlay = create_overlay_visualization(original_img, binary_mask, alpha=0.4) | |
# 4. Side-by-side comparison | |
fig, axes = plt.subplots(1, 3, figsize=(15, 5)) | |
axes[0].imshow(original_array) | |
axes[0].set_title('Original Image', fontsize=14, fontweight='bold') | |
axes[0].axis('off') | |
axes[1].imshow(mask_colored) | |
axes[1].set_title('Tumor Segmentation', fontsize=14, fontweight='bold') | |
axes[1].axis('off') | |
axes[2].imshow(overlay) | |
axes[2].set_title('Overlay (Red = Tumor)', fontsize=14, fontweight='bold') | |
axes[2].axis('off') | |
plt.tight_layout() | |
# Save plot to bytes | |
buf = io.BytesIO() | |
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight') | |
buf.seek(0) | |
plt.close() | |
# Convert to PIL Image | |
result_image = Image.open(buf) | |
# Calculate tumor statistics | |
total_pixels = 256 * 256 | |
tumor_pixels = np.sum(binary_mask) | |
tumor_percentage = (tumor_pixels / total_pixels) * 100 | |
# Create analysis report | |
analysis_text = f""" | |
## π§ Brain Tumor Segmentation Analysis | |
**π Tumor Statistics:** | |
- Total pixels analyzed: {total_pixels:,} | |
- Tumor pixels detected: {tumor_pixels:,} | |
- Tumor area percentage: {tumor_percentage:.2f}% | |
**π― Model Performance:** | |
- Model: U-Net with attention mechanism | |
- Input resolution: 256Γ256 pixels | |
- Detection threshold: {threshold} | |
**β οΈ Medical Disclaimer:** | |
This is an AI tool for research purposes only. | |
Always consult qualified medical professionals for diagnosis. | |
""" | |
return result_image, analysis_text | |
except Exception as e: | |
error_msg = f"β Error during prediction: {str(e)}" | |
return None, error_msg | |
def clear_all(): | |
"""Clear all inputs and outputs""" | |
return None, None, "" | |
# Custom CSS for better styling | |
css = """ | |
#main-container { | |
max-width: 1200px; | |
margin: 0 auto; | |
} | |
#title { | |
text-align: center; | |
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); | |
color: white; | |
padding: 20px; | |
border-radius: 10px; | |
margin-bottom: 20px; | |
} | |
#upload-box { | |
border: 2px dashed #ccc; | |
border-radius: 10px; | |
padding: 20px; | |
text-align: center; | |
margin: 10px 0; | |
} | |
.output-image { | |
border-radius: 10px; | |
box-shadow: 0 4px 8px rgba(0,0,0,0.1); | |
} | |
""" | |
# Create Gradio interface | |
with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app: | |
# Header | |
gr.HTML(""" | |
<div id="title"> | |
<h1>π§ Brain Tumor Segmentation AI</h1> | |
<p>Upload an MRI brain scan to detect and visualize tumor regions using deep learning</p> | |
</div> | |
""") | |
with gr.Row(): | |
with gr.Column(scale=1): | |
gr.HTML("<h3>π€ Input Image</h3>") | |
# Image input with camera option | |
image_input = gr.Image( | |
label="Upload Brain MRI Scan", | |
type="pil", | |
sources=["upload", "webcam"], # Allow both upload and camera | |
height=300 | |
) | |
with gr.Row(): | |
predict_btn = gr.Button("π Analyze Image", variant="primary", size="lg") | |
clear_btn = gr.Button("ποΈ Clear All", variant="secondary") | |
gr.HTML(""" | |
<div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 8px;"> | |
<h4>π Instructions:</h4> | |
<ul> | |
<li>Upload a brain MRI scan image</li> | |
<li>Supported formats: PNG, JPG, JPEG</li> | |
<li>For best results, use clear, high-contrast MRI images</li> | |
<li>You can also use the camera to capture an image from your device</li> | |
</ul> | |
</div> | |
""") | |
with gr.Column(scale=2): | |
gr.HTML("<h3>π Segmentation Results</h3>") | |
# Output image | |
output_image = gr.Image( | |
label="Segmentation Results", | |
type="pil", | |
height=400, | |
elem_classes=["output-image"] | |
) | |
# Analysis text | |
analysis_output = gr.Markdown( | |
label="Analysis Report", | |
value="Upload an image and click 'Analyze Image' to see results." | |
) | |
# Add footer with information | |
gr.HTML(""" | |
<div style="margin-top: 30px; padding: 20px; background-color: #f9f9f9; border-radius: 10px;"> | |
<h4>π¬ About This Tool</h4> | |
<p><strong>Model:</strong> Pre-trained U-Net architecture optimized for brain tumor segmentation</p> | |
<p><strong>Technology:</strong> PyTorch, Deep Learning, Computer Vision</p> | |
<p><strong>Dataset:</strong> Trained on medical MRI brain scans</p> | |
<h4>β οΈ Important Medical Disclaimer</h4> | |
<p style="color: #d73027; font-weight: bold;"> | |
This AI tool is for research and educational purposes only. It should NOT be used for medical diagnosis. | |
Always consult qualified healthcare professionals for medical advice and diagnosis. | |
</p> | |
<p style="text-align: center; margin-top: 20px; color: #666;"> | |
Made with β€οΈ using Gradio β’ Powered by PyTorch β’ Hosted on π€ Hugging Face Spaces | |
</p> | |
</div> | |
""") | |
# Event handlers | |
predict_btn.click( | |
fn=predict_tumor, | |
inputs=[image_input], | |
outputs=[output_image, analysis_output] | |
) | |
clear_btn.click( | |
fn=clear_all, | |
outputs=[image_input, output_image, analysis_output] | |
) | |
# Auto-predict when image is uploaded | |
image_input.change( | |
fn=predict_tumor, | |
inputs=[image_input], | |
outputs=[output_image, analysis_output] | |
) | |
# Launch the app | |
if __name__ == "__main__": | |
app.launch( | |
share=True, | |
server_name="0.0.0.0", | |
server_port=7860, | |
show_error=True | |
) | |