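"""PerceptNet: Gradio demo for brain tumor segmentation with a trained Attention U-Net.

The app downloads pretrained weights from the Hugging Face Hub and a sample
dataset from Kaggle, then serves an interactive UI: upload (or sample) a brain
MRI, view the predicted mask, tumor-only view, and confidence heatmap, and see
Dice/IoU scores when a ground-truth mask is available.
"""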
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import io
from torchvision import transforms
import torchvision.transforms.functional as TF
import urllib.request
import os
import random
import kagglehub

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = None
# Download dataset
dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
image_path = os.path.join(dataset_path, 'images')
mask_path = os.path.join(dataset_path, 'masks')
test_imgs = sorted([f for f in os.listdir(image_path) if f.endswith(('.jpg', '.png'))])
test_masks = sorted([f for f in os.listdir(mask_path) if f.endswith(('.jpg', '.png'))])
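# Note: images and masks are paired purely by sorted filename order, which
# assumes matching names in the dataset's images/ and masks/ directories.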

# Define the Attention U-Net architecture (matching the training code)
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi
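
# The gate above implements additive attention: the decoder's gating signal g
# and the skip-connection features x are each projected to F_int channels,
# summed, then passed through ReLU -> 1x1 conv -> sigmoid to yield a spatial
# attention map psi in [0, 1] that rescales x before it is concatenated with
# the upsampled decoder features.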

class AttentionUNET(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
        super(AttentionUNET, self).__init__()
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # reverse list
        for idx in range(0, len(self.ups), 2):  # upsample, then double conv
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            skip_connection = self.attentions[idx // 2](skip_connection, x)
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)
        return self.final_conv(x)
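
# A minimal shape sanity check for the architecture (an illustrative sketch,
# not called anywhere by the app). Assumes the 256x256 single-channel input
# produced by preprocess_image() below.
def _sanity_check():
    x = torch.randn(1, 1, 256, 256)
    with torch.no_grad():
        y = AttentionUNET()(x)
    assert y.shape == (1, 1, 256, 256), f"unexpected output shape: {y.shape}"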

def download_model():
    """Download trained model from HuggingFace"""
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"
    if not os.path.exists(model_path):
        print("Downloading trained model...")
        try:
            urllib.request.urlretrieve(model_url, model_path)
            print("Model downloaded successfully!")
        except Exception as e:
            print(f"Failed to download model: {e}")
            return None
    else:
        print("Model already exists!")
    return model_path

def load_attention_model():
    """Load trained Attention U-Net model (cached in the module-level global)."""
    global model
    if model is None:
        try:
            print("Loading trained Attention U-Net model...")
            # Download model if needed
            model_path = download_model()
            if model_path is None:
                return None
            # Initialize model architecture
            model = AttentionUNET(in_channels=1, out_channels=1).to(device)
            # Load trained weights
            checkpoint = torch.load(model_path, map_location=device, weights_only=True)
            model.load_state_dict(checkpoint["state_dict"])
            model.eval()
            print("Attention U-Net model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            model = None
    return model

def preprocess_image(image):
    """Preprocessing for model input"""
    # Convert to grayscale
    if image.mode != 'L':
        image = image.convert('L')
    # Apply transforms
    val_test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    return val_test_transform(image).unsqueeze(0)  # Add batch dimension
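
# Note: ToTensor() scales pixel values to [0, 1]; preprocess_image() returns a
# (1, 1, 256, 256) float tensor, the single-channel input the model expects.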

def predict_tumor(image, mask=None):
    """Segment the tumor in a brain MRI and build the visualization and report."""
    current_model = load_attention_model()
    if current_model is None:
        return None, "Failed to load trained model."
    if image is None:
        return None, "Please upload an image first."
    try:
        print("Processing with PerceptNet Attention U-Net...")
        # Preprocess image
        input_tensor = preprocess_image(image).to(device)
        # Model prediction
        with torch.no_grad():
            pred_mask = torch.sigmoid(current_model(input_tensor))
            pred_mask_binary = (pred_mask > 0.5).float()
        # Convert to numpy
        pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
        prob_mask_np = pred_mask.cpu().squeeze().numpy()  # Probability map for heatmap
        original_np = np.array(image.convert('L').resize((256, 256)))
        # Create inverted mask for visualization (tumor shown dark on white)
        inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
        # Create tumor-only image
        tumor_only = np.where(pred_mask_np == 1, original_np, 255)
        # Handle ground truth if provided
        mask_np = None
        dice_score = None
        iou_score = None
        if mask is not None:
            mask_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()
            ])
            mask_tensor = mask_transform(mask).squeeze().numpy()
            mask_np = (mask_tensor > 0.5).astype(float)
            intersection = np.logical_and(pred_mask_np, mask_np).sum()
            union = np.logical_or(pred_mask_np, mask_np).sum()
            iou_score = intersection / (union + 1e-7)
            dice_score = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + 1e-7)
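            # (Dice = 2|A∩B| / (|A| + |B|); IoU = |A∩B| / |A∪B|. The 1e-7
            # epsilon guards against division by zero when both masks are empty.)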
        # Create visualization (5-panel layout)
        fig, axes = plt.subplots(1, 5, figsize=(25, 5))
        fig.suptitle('PerceptNet Analysis Results', fontsize=16, fontweight='bold')
        titles = ["Original Image", "Ground Truth", "Predicted Mask", "Tumor Only", "Heatmap"]
        images = [original_np,
                  mask_np if mask_np is not None else np.zeros_like(original_np),
                  inv_pred_mask_np, tumor_only, prob_mask_np]
        cmaps = ['gray', 'gray', 'gray', 'gray', 'hot']
        for i, ax in enumerate(axes):
            ax.imshow(images[i], cmap=cmaps[i])
            ax.set_title(titles[i], fontsize=12, fontweight='bold')
            ax.axis('off')
        plt.tight_layout()
        # Save result
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
        buf.seek(0)
        plt.close()
        result_image = Image.open(buf)
        # Calculate statistics
        tumor_pixels = np.sum(pred_mask_np)
        total_pixels = pred_mask_np.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100
        # Calculate confidence metrics
        max_confidence = torch.max(pred_mask).item()
        mean_confidence = torch.mean(pred_mask).item()
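        # "Confidence" here is the raw sigmoid probability; the mean is taken
        # over all pixels, so it tends to be low on mostly-background images.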
| analysis_text = f""" | |
| ## PerceptNet Analysis Results | |
| ### Detection Summary: | |
| - **Status**: {'TUMOR DETECTED' if tumor_pixels > 50 else 'NO SIGNIFICANT TUMOR'} | |
| - **Tumor Area**: {tumor_percentage:.2f}% of brain region | |
| - **Tumor Pixels**: {tumor_pixels:,} pixels | |
| - **Max Confidence**: {max_confidence:.4f} | |
| - **Mean Confidence**: {mean_confidence:.4f} | |
| """ | |
| if dice_score is not None and iou_score is not None: | |
| analysis_text += f""" | |
| - **Dice Score**: {dice_score:.4f} | |
| - **IoU Score**: {iou_score:.4f} | |
| """ | |
| analysis_text += f""" | |
| ### Model Information: | |
| - **Architecture**: PerceptNet Attention U-Net | |
| - **Training Performance**: Dice: 0.8420, IoU: 0.7297 | |
| - **Input**: Grayscale (single channel) | |
| - **Output**: Binary segmentation mask | |
| - **Device**: {device.type.upper()} | |
| ### Processing Details: | |
| - **Preprocessing**: Resize(256×256) + ToTensor | |
| - **Threshold**: 0.5 (sigmoid > 0.5) | |
| - **Architecture**: Attention gates + Skip connections | |
| - **Features**: [32, 64, 128, 256] channels | |
| ### Medical Disclaimer: | |
| This AI model is for **research and educational purposes only**. | |
| Results should be validated by medical professionals. Not for clinical diagnosis. | |
| """ | |
| print(f"Model analysis completed! Tumor area: {tumor_percentage:.2f}%") | |
| return result_image, analysis_text | |
| except Exception as e: | |
| error_msg = f"Error with model: {str(e)}" | |
| print(error_msg) | |
| return None, error_msg | |

def load_random_sample():
    """Load a random image/mask pair from the downloaded dataset."""
    if not test_imgs:
        return None, None, "Dataset not available."
    rand_idx = random.randint(0, len(test_imgs) - 1)
    img_path = os.path.join(image_path, test_imgs[rand_idx])
    msk_path = os.path.join(mask_path, test_masks[rand_idx])
    image = Image.open(img_path).convert('L')
    mask = Image.open(msk_path).convert('L')
    return image, mask, "Loaded random sample from dataset."

def clear_all():
    """Reset inputs, outputs, and the stored ground-truth mask."""
    return None, None, "Upload a brain MRI image to test PerceptNet model", None

# Professional CSS styling
css = """
.gradio-container {
    max-width: 1600px !important;
    margin: auto !important;
    background-color: #ffffff !important;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
}
#title-header {
    background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%);
    color: white;
    padding: 40px 30px;
    border-radius: 12px;
    margin-bottom: 30px;
    box-shadow: 0 4px 20px rgba(37, 99, 235, 0.15);
    text-align: center;
}
.main-container {
    background-color: #ffffff;
    border-radius: 12px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
    padding: 30px;
    margin-bottom: 20px;
}
.input-section {
    background-color: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 8px;
    padding: 25px;
}
.info-panel {
    background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
    border: 1px solid #0ea5e9;
    border-radius: 8px;
    padding: 20px;
    margin-top: 20px;
}
.footer-section {
    background-color: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 12px;
    padding: 30px;
    margin-top: 30px;
}
.stat-grid {
    display: grid;
    grid-template-columns: 1fr 1fr;
    gap: 30px;
    margin: 20px 0;
}
.disclaimer-text {
    color: #dc2626;
    font-weight: 600;
    line-height: 1.5;
    background-color: #fef2f2;
    padding: 15px;
    border-radius: 6px;
    border: 1px solid #fecaca;
}
h1, h2, h3, h4 {
    color: #1e293b !important;
}
.gr-button-primary {
    background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important;
    border: none !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
    transition: all 0.2s ease !important;
}
.gr-button-primary:hover {
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 12px rgba(37, 99, 235, 0.3) !important;
}
.gr-button-secondary {
    background: #6b7280 !important;
    border: none !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
}
"""

# Create Gradio interface
with gr.Blocks(css=css, title="PerceptNet - Brain Tumor Segmentation", theme=gr.themes.Default()) as app:
    gr.HTML("""
    <div id="title-header">
        <h1 style="margin: 0; font-size: 2.5rem; font-weight: 700;">PerceptNet</h1>
        <p style="font-size: 1.2rem; margin: 15px 0 5px 0; opacity: 0.95;">
            Advanced Brain Tumor Segmentation System
        </p>
        <p style="font-size: 1rem; margin: 5px 0 0 0; opacity: 0.8;">
            Attention U-Net Architecture • Dice: 0.8420 • IoU: 0.7297
        </p>
    </div>
    """)
    mask_state = gr.State(None)
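    # mask_state holds the ground-truth mask for dataset samples; it stays None
    # for user uploads, in which case Dice/IoU are skipped in predict_tumor().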
    with gr.Row(elem_classes="main-container"):
        with gr.Column(scale=1, elem_classes="input-section"):
            gr.Markdown("### Upload Brain MRI Scan", elem_classes="section-title")
            image_input = gr.Image(
                label="Brain MRI Image",
                type="pil",
                sources=["upload", "webcam"],
                height=380
            )
            with gr.Row():
                analyze_btn = gr.Button(
                    "Analyze Image",
                    variant="primary",
                    scale=2,
                    size="lg"
                )
                random_btn = gr.Button(
                    "Load Sample",
                    variant="secondary",
                    scale=1,
                    size="lg"
                )
                clear_btn = gr.Button(
                    "Clear",
                    variant="secondary",
                    scale=1
                )
            gr.HTML("""
            <div class="info-panel">
                <h4 style="color: #0ea5e9; margin-bottom: 15px; font-size: 1.1rem;">Model Specifications</h4>
                <div style="line-height: 1.8; font-size: 0.95rem;">
                    <div><strong>Architecture:</strong> Attention U-Net with Skip Connections</div>
                    <div><strong>Performance:</strong> 84.2% Dice Score, 72.97% IoU</div>
                    <div><strong>Input Format:</strong> Grayscale MRI Scans (256×256)</div>
                    <div><strong>Output:</strong> Binary Segmentation + Confidence Heatmap</div>
                    <div><strong>Features:</strong> Attention Mechanisms, Multi-scale Analysis</div>
                </div>
            </div>
            """)
        with gr.Column(scale=2):
            gr.Markdown("### Analysis Results", elem_classes="section-title")
            output_image = gr.Image(
                label="PerceptNet Analysis Output",
                type="pil",
                height=520
            )
            analysis_output = gr.Markdown(
                value="Upload a brain MRI image to begin analysis with PerceptNet.",
                elem_id="analysis-results"
            )
    # Footer section
    gr.HTML("""
    <div class="footer-section">
        <div class="stat-grid">
            <div>
                <h4 style="color: #2563eb; margin-bottom: 15px;">Technical Specifications</h4>
                <div style="line-height: 1.6;">
                    <p><strong>Model Architecture:</strong> Attention U-Net with Gating Mechanisms</p>
                    <p><strong>Training Dataset:</strong> Brain Tumor Segmentation Dataset</p>
                    <p><strong>Image Processing:</strong> 256×256 Grayscale Normalization</p>
                    <p><strong>Inference Speed:</strong> Real-time Processing on GPU/CPU</p>
                    <p><strong>Output Formats:</strong> Binary Masks, Probability Maps, Heatmaps</p>
                </div>
            </div>
            <div>
                <h4 style="color: #dc2626; margin-bottom: 15px;">Important Disclaimer</h4>
                <div class="disclaimer-text">
                    PerceptNet is an AI research tool designed for <strong>educational and research purposes only</strong>.
                    This system is not intended for clinical diagnosis or medical decision-making.
                    All results must be validated by qualified medical professionals before any medical application.
                </div>
            </div>
        </div>
        <hr style="margin: 25px 0; border: none; border-top: 1px solid #e2e8f0;">
        <p style="text-align: center; color: #64748b; margin: 15px 0; font-weight: 500;">
            PerceptNet v1.0 • Advanced Medical Image Analysis • Research Grade Performance
        </p>
    </div>
    """)
    # Event handlers
    analyze_btn.click(
        fn=predict_tumor,
        inputs=[image_input, mask_state],
        outputs=[output_image, analysis_output],
        show_progress=True
    )
    random_btn.click(
        fn=load_random_sample,
        inputs=[],
        outputs=[image_input, mask_state, analysis_output]
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[image_input, output_image, analysis_output, mask_state]
    )

if __name__ == "__main__":
    print("Starting PerceptNet Brain Tumor Segmentation System...")
    print("Loading Attention U-Net architecture...")
    print("Auto-downloading model weights...")
    print("Expected performance: Dice 0.8420, IoU 0.7297")
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False
    )