import contextlib
import io
import os
import random
import urllib.request

import cv2
import gradio as gr
import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = None

# File extensions treated as dataset images/masks.
_IMAGE_EXTENSIONS = ('.jpg', '.png')

# Shared transform for both the model input and the ground-truth mask:
# resize to 256x256 and convert to a float tensor.
_RESIZE_TO_TENSOR = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


def _list_image_files(directory):
    """Return the sorted .jpg/.png file names found directly in *directory*."""
    return sorted(
        name for name in os.listdir(directory)
        if name.endswith(_IMAGE_EXTENSIONS)
    )


# Download dataset.
# A failed download (no network, missing credentials, unexpected layout) must
# not take the whole app down: inference on uploaded images still works, and
# load_random_sample() already reports "Dataset not available." when the
# image list is empty.
try:
    dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
    image_path = os.path.join(dataset_path, 'images')
    mask_path = os.path.join(dataset_path, 'masks')
    test_imgs = _list_image_files(image_path)
    test_masks = _list_image_files(mask_path)
except Exception as e:  # top-level boundary: log and degrade gracefully
    print(f"Dataset not available: {e}")
    dataset_path = None
    image_path = ''
    mask_path = ''
    test_imgs = []
    test_masks = []


# Define your Attention U-Net architecture (from your training code)
class DoubleConv(nn.Module):
    """Two consecutive 3x3 Conv -> BatchNorm -> ReLU stages."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class AttentionBlock(nn.Module):
    """Additive attention gate.

    Projects ``g`` and ``x`` to ``F_int`` channels with 1x1 convolutions,
    sums them, applies ReLU, reduces to a single-channel sigmoid map and
    returns ``x`` scaled by that map.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi


class AttentionUNET(nn.Module):
    """U-Net with an attention gate on every decoder stage.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output logit map.
        features: Channel width of each encoder stage, shallowest first.
            Any sequence is accepted; the default is an immutable tuple so
            that it cannot be shared and mutated between instances.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        super().__init__()
        features = list(features)
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Up part of UNET: self.ups alternates [upsample, double conv, ...],
        # with one attention gate per pair.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return raw (pre-sigmoid) output of the final 1x1 convolution."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first

        for idx in range(0, len(self.ups), 2):  # upsample, then double conv
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip; resize to make them concatenable.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            # NOTE(review): the encoder skip is passed as `g` and the
            # upsampled tensor as `x`, so the gate scales the decoder tensor
            # rather than the skip. This looks reversed relative to the usual
            # Attention U-Net wiring, but it is kept unchanged because the
            # checkpoint was presumably trained with exactly this graph -
            # confirm against the training code before touching it.
            skip_connection = self.attentions[idx // 2](skip_connection, x)
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)


def download_model():
    """Download trained model from HuggingFace.

    Returns:
        The local checkpoint path, or ``None`` if the download failed.
    """
    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"

    if os.path.exists(model_path):
        print("Model already exists!")
        return model_path

    print("Downloading trained model...")
    # Download to a temporary name and rename on success, so an interrupted
    # transfer never leaves a truncated file that later calls would mistake
    # for a valid checkpoint.
    partial_path = model_path + ".part"
    try:
        urllib.request.urlretrieve(model_url, partial_path)
        os.replace(partial_path, model_path)
    except Exception as e:
        print(f"Failed to download model: {e}")
        with contextlib.suppress(OSError):
            os.remove(partial_path)
        return None

    print("Model downloaded successfully!")
    return model_path


def load_attention_model():
    """Load trained Attention U-Net model.

    The model is cached in the module-level ``model`` global after the first
    successful load.

    Returns:
        The model in eval mode, or ``None`` if downloading or loading failed.
    """
    global model
    if model is not None:
        return model

    try:
        print("Loading trained Attention U-Net model...")

        # Download model if needed
        model_path = download_model()
        if model_path is None:
            return None

        # Build and load into a local first; the global is only published
        # once the weights are in place.
        net = AttentionUNET(in_channels=1, out_channels=1).to(device)
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        net.load_state_dict(checkpoint["state_dict"])
        net.eval()
        model = net
        print("Attention U-Net model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        model = None
    return model


def preprocess_image(image):
    """Preprocessing for model input.

    Converts a PIL image to grayscale, resizes it to 256x256 and returns a
    tensor with a leading batch dimension.
    """
    # Convert to grayscale
    if image.mode != 'L':
        image = image.convert('L')
    return _RESIZE_TO_TENSOR(image).unsqueeze(0)  # Add batch dimension


def _overlap_metrics(pred_mask_np, mask_np):
    """Return ``(dice, iou)`` for two binary masks of equal shape."""
    intersection = np.logical_and(pred_mask_np, mask_np).sum()
    union = np.logical_or(pred_mask_np, mask_np).sum()
    iou_score = intersection / (union + 1e-7)
    dice_score = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + 1e-7)
    return dice_score, iou_score


def _render_panels(titles, images, cmaps):
    """Draw the panels side by side and return the figure as a PIL image."""
    fig, axes = plt.subplots(1, 5, figsize=(25, 5))
    try:
        fig.suptitle('PerceptNet Analysis Results', fontsize=16, fontweight='bold')
        for ax, title, data, cmap in zip(axes, titles, images, cmaps):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.axis('off')
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        # Close this specific figure even if drawing fails; pyplot keeps
        # figures alive until they are closed explicitly.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _build_report(tumor_pixels, tumor_percentage, max_confidence, mean_confidence,
                  dice_score, iou_score):
    """Assemble the Markdown analysis report shown next to the result image."""
    status = 'TUMOR DETECTED' if tumor_pixels > 50 else 'NO SIGNIFICANT TUMOR'
    lines = [
        "",
        "## PerceptNet Analysis Results",
        "",
        "### Detection Summary:",
        f"- **Status**: {status}",
        f"- **Tumor Area**: {tumor_percentage:.2f}% of brain region",
        f"- **Tumor Pixels**: {tumor_pixels:,} pixels",
        f"- **Max Confidence**: {max_confidence:.4f}",
        f"- **Mean Confidence**: {mean_confidence:.4f}",
    ]
    if dice_score is not None and iou_score is not None:
        lines += [
            f"- **Dice Score**: {dice_score:.4f}",
            f"- **IoU Score**: {iou_score:.4f}",
        ]
    lines += [
        "",
        "### Model Information:",
        "- **Architecture**: PerceptNet Attention U-Net",
        "- **Training Performance**: Dice: 0.8420, IoU: 0.7297",
        "- **Input**: Grayscale (single channel)",
        "- **Output**: Binary segmentation mask",
        f"- **Device**: {device.type.upper()}",
        "",
        "### Processing Details:",
        "- **Preprocessing**: Resize(256×256) + ToTensor",
        "- **Threshold**: 0.5 (sigmoid > 0.5)",
        "- **Architecture**: Attention gates + Skip connections",
        "- **Features**: [32, 64, 128, 256] channels",
        "",
        "### Medical Disclaimer:",
        "This AI model is for **research and educational purposes only**. "
        "Results should be validated by medical professionals. "
        "Not for clinical diagnosis.",
        "",
    ]
    return "\n".join(lines)


def predict_tumor(image, mask=None):
    """Segment *image* and return ``(result_image, analysis_markdown)``.

    Args:
        image: PIL image to analyse, or ``None``.
        mask: Optional PIL ground-truth mask; when given, Dice and IoU
            against the prediction are added to the report.

    Returns:
        A 5-panel visualisation as a PIL image plus a Markdown report. On
        failure the image is ``None`` and the second element is an error
        message.
    """
    # Validate the input before paying for a model download/load.
    if image is None:
        return None, "Please upload an image first."

    current_model = load_attention_model()
    if current_model is None:
        return None, "Failed to load trained model."

    try:
        print("Processing with PerceptNet Attention U-Net...")

        # Preprocess image
        input_tensor = preprocess_image(image).to(device)

        # Model prediction
        with torch.no_grad():
            pred_mask = torch.sigmoid(current_model(input_tensor))
            pred_mask_binary = (pred_mask > 0.5).float()

        # Convert to numpy
        pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
        prob_mask_np = pred_mask.cpu().squeeze().numpy()  # Probability for heatmap
        original_np = np.array(image.convert('L').resize((256, 256)))

        # Create inverted mask for visualization (tumor drawn black on white)
        inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)

        # Create tumor-only image (everything outside the mask set to white)
        tumor_only = np.where(pred_mask_np == 1, original_np, 255)

        # Handle ground truth if provided
        mask_np = None
        dice_score = None
        iou_score = None
        if mask is not None:
            # Force a single channel: an RGB/RGBA mask would otherwise become
            # a (3, 256, 256) array that broadcasts against the (256, 256)
            # prediction and inflates the overlap counts.
            if mask.mode != 'L':
                mask = mask.convert('L')
            mask_tensor = _RESIZE_TO_TENSOR(mask).squeeze().numpy()
            mask_np = (mask_tensor > 0.5).astype(float)
            dice_score, iou_score = _overlap_metrics(pred_mask_np, mask_np)

        # Create visualization (5-panel layout)
        titles = ["Original Image", "Ground Truth", "Predicted Mask", "Tumor Only", "Heatmap"]
        ground_truth_panel = mask_np if mask_np is not None else np.zeros_like(original_np)
        images = [original_np, ground_truth_panel, inv_pred_mask_np, tumor_only, prob_mask_np]
        cmaps = ['gray', 'gray', 'gray', 'gray', 'hot']
        result_image = _render_panels(titles, images, cmaps)

        # Calculate statistics. The mask is float32, so cast the count to int
        # to report whole pixels ("1,234" rather than "1,234.0").
        tumor_pixels = int(pred_mask_np.sum())
        total_pixels = pred_mask_np.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100

        # Calculate confidence metrics
        max_confidence = torch.max(pred_mask).item()
        mean_confidence = torch.mean(pred_mask).item()

        analysis_text = _build_report(
            tumor_pixels, tumor_percentage, max_confidence, mean_confidence,
            dice_score, iou_score,
        )

        print(f"Model analysis completed! Tumor area: {tumor_percentage:.2f}%")
        return result_image, analysis_text

    except Exception as e:
        error_msg = f"Error with model: {str(e)}"
        print(error_msg)
        return None, error_msg


def load_random_sample():
    """Return ``(image, mask, status_message)`` for a random dataset sample.

    Returns ``(None, None, "Dataset not available.")`` when no images were
    found. The mask is ``None`` if no mask can be paired with the image.
    """
    if not test_imgs:
        return None, None, "Dataset not available."

    rand_idx = random.randint(0, len(test_imgs) - 1)
    img_name = test_imgs[rand_idx]

    # Prefer the mask with the same file stem. The two lists are sorted
    # independently, so pairing purely by index is only trustworthy when
    # they have the same length; that is kept as the fallback.
    stem = os.path.splitext(img_name)[0]
    mask_name = next(
        (m for m in test_masks if os.path.splitext(m)[0] == stem),
        None,
    )
    if mask_name is None and len(test_masks) == len(test_imgs):
        mask_name = test_masks[rand_idx]

    with Image.open(os.path.join(image_path, img_name)) as img_file:
        image = img_file.convert('L')

    if mask_name is None:
        return image, None, "Loaded random sample from dataset (no matching mask found)."

    with Image.open(os.path.join(mask_path, mask_name)) as mask_file:
        mask = mask_file.convert('L')

    return image, mask, "Loaded random sample from dataset."

def clear_all(): return None, None, "Upload a brain MRI image to test PerceptNet model", None # Professional CSS styling css = """ .gradio-container { max-width: 1600px !important; margin: auto !important; background-color: #ffffff !important; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important; } .gr-markdown p, .gr-markdown div, .gr-markdown span, .gr-markdown li { color: #1e293b !important; } .gr-markdown h1, .gr-markdown h2, .gr-markdown h3, .gr-markdown h4, .gr-markdown h5, .gr-markdown h6 { color: #1e293b !important; } .gr-markdown strong { color: #374151 !important; } #analysis-results * { color: #1e293b !important; } .info-panel * { color: #1e293b !important; } .footer-section * { color: #1e293b !important; } .footer-section h4 { color: #2563eb !important; } .footer-section p, .footer-section div { color: #374151 !important; } .info-panel h4 { color: #0ea5e9 !important; } .info-panel div { color: #374151 !important; } #title-header { background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%); color: white; padding: 40px 30px; border-radius: 12px; margin-bottom: 30px; box-shadow: 0 4px 20px rgba(37, 99, 235, 0.15); text-align: center; } .main-container { background-color: #ffffff; border-radius: 12px; box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); padding: 30px; margin-bottom: 20px; } .input-section { background-color: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 25px; } .info-panel { background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%); border: 1px solid #0ea5e9; border-radius: 8px; padding: 20px; margin-top: 20px; } .footer-section { background-color: #f8fafc; border: 1px solid #e2e8f0; border-radius: 12px; padding: 30px; margin-top: 30px; } .stat-grid { display: grid; grid-template-columns: 1fr 1fr; gap: 30px; margin: 20px 0; } .disclaimer-text { color: #dc2626; font-weight: 600; line-height: 1.5; background-color: #fef2f2; padding: 15px; border-radius: 6px; border: 1px solid #fecaca; } h1, h2, h3, h4 { 
color: #1e293b !important; } .gr-button-primary { background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important; border: none !important; color: white !important; font-weight: 600 !important; padding: 12px 24px !important; border-radius: 8px !important; transition: all 0.2s ease !important; } .gr-button-primary:hover { transform: translateY(-1px) !important; box-shadow: 0 4px 12px rgba(37, 99, 235, 0.3) !important; } .gr-button-secondary { background: #6b7280 !important; border: none !important; color: white !important; font-weight: 600 !important; padding: 12px 24px !important; border-radius: 8px !important; } """ # Create Gradio interface with gr.Blocks(css=css, title="PerceptNet - Brain Tumor Segmentation", theme=gr.themes.Default()) as app: gr.HTML("""
Advanced Brain Tumor Segmentation System
Attention U-Net Architecture • Dice: 0.8420 • IoU: 0.7297