# Author: ibrahim313
# Last change: "Update app.py" (commit 5280a25, verified)
import html
import io

import albumentations as A
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
from PIL import Image
# U-Net model definition
class UNET(nn.Module):
    """U-Net that maps a 3-channel image batch to a 1-channel logit map.

    Four encoder stages (each followed by 2x2 max-pooling) feed a bottleneck;
    four decoder stages then upsample with stride-2 transposed convolutions
    and concatenate the encoder output of the same level (skip connection).

    Args:
        dropout_rate: ``p`` for every ``Dropout2d`` layer.
        ch: Channels in the first encoder stage; each level down doubles it,
            reaching ``ch * 16`` in the bottleneck.

    NOTE: attribute names and the layer order inside each ``nn.Sequential``
    determine the ``state_dict`` keys, so they must stay as they are for saved
    checkpoints to keep loading.
    """

    def __init__(self, dropout_rate=0.1, ch=32):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        def conv_block(in_channels, out_channels):
            # Two (3x3 conv -> BatchNorm -> ReLU -> Dropout2d) stages;
            # padding=1 keeps the spatial size unchanged.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout_rate),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout_rate),
            )

        # Contracting path.
        self.encoder1 = conv_block(3, ch)
        self.encoder2 = conv_block(ch, ch * 2)
        self.encoder3 = conv_block(ch * 2, ch * 4)
        self.encoder4 = conv_block(ch * 4, ch * 8)
        self.bottle_neck = conv_block(ch * 8, ch * 16)

        # Expanding path: each decoder receives upsampled + skip channels,
        # hence twice the channel count it outputs.
        self.upsample1 = nn.ConvTranspose2d(ch * 16, ch * 8, kernel_size=2, stride=2)
        self.decoder1 = conv_block(ch * 16, ch * 8)
        self.upsample2 = nn.ConvTranspose2d(ch * 8, ch * 4, kernel_size=2, stride=2)
        self.decoder2 = conv_block(ch * 8, ch * 4)
        self.upsample3 = nn.ConvTranspose2d(ch * 4, ch * 2, kernel_size=2, stride=2)
        self.decoder3 = conv_block(ch * 4, ch * 2)
        self.upsample4 = nn.ConvTranspose2d(ch * 2, ch, kernel_size=2, stride=2)
        self.decoder4 = conv_block(ch * 2, ch)
        self.final = nn.Conv2d(ch, 1, kernel_size=1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad *upsampled* so its last two dims equal those of *skip*.

        Max-pooling floors odd sizes, so after upsampling a level can be one
        pixel smaller than its skip tensor; when sizes already match (inputs
        divisible by 16) the tensor is returned unchanged.
        """
        diff_y = skip.size(2) - upsampled.size(2)
        diff_x = skip.size(3) - upsampled.size(3)
        if diff_y == 0 and diff_x == 0:
            return upsampled
        # F.pad order for the last two dims: [left, right, top, bottom].
        return nn.functional.pad(
            upsampled,
            [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
        )

    def forward(self, x):
        """Return raw logits (no sigmoid) of shape (N, 1, H, W) for x of shape (N, 3, H, W).

        H and W do not need to be multiples of 16, but must be at least 16 so
        that four rounds of 2x2 pooling leave a non-empty feature map.
        """
        c1 = self.encoder1(x)
        c2 = self.encoder2(self.pool(c1))
        c3 = self.encoder3(self.pool(c2))
        c4 = self.encoder4(self.pool(c3))
        c5 = self.bottle_neck(self.pool(c4))

        u6 = torch.cat([c4, self._match_size(self.upsample1(c5), c4)], dim=1)
        c6 = self.decoder1(u6)
        u7 = torch.cat([c3, self._match_size(self.upsample2(c6), c3)], dim=1)
        c7 = self.decoder2(u7)
        u8 = torch.cat([c2, self._match_size(self.upsample3(c7), c2)], dim=1)
        c8 = self.decoder3(u8)
        u9 = torch.cat([c1, self._match_size(self.upsample4(c8), c1)], dim=1)
        c9 = self.decoder4(u9)
        return self.final(c9)
# Global variables
# `model` stays None until load_model() succeeds; `device` is the target used
# when mapping weights and inputs (CPU-only here: the original note says
# HF Spaces run this app on CPU).
model = None
device = torch.device("cpu")

# Preprocessing pipeline: resize to 384x384, scale pixels by 1/255
# (mean 0, std 1, max_pixel_value 255), then convert HWC numpy -> CHW tensor.
transform = A.Compose(
    [
        A.Resize(height=384, width=384),
        A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255),
        ToTensorV2(),
    ]
)
def load_model(repo_id="ibrahim313/unet-adam-diceloss", filename="pytorch_model.bin"):
    """Download U-Net weights from the Hugging Face Hub and publish them as `model`.

    Args:
        repo_id: Hub model repository holding the checkpoint.
        filename: State-dict file inside that repository.

    Returns:
        A human-readable status string (success or the error message); errors
        are caught here so app start-up never crashes.
    """
    global model
    try:
        print('πŸ"₯ Downloading model from Hugging Face...')
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)

        # Build and fill a local instance first and only publish it once the
        # weights are in. Assigning the global before load_state_dict would,
        # on failure, leave an untrained network behind that predict_polyp's
        # `model is None` guard can no longer catch.
        loaded = UNET(ch=32)
        # weights_only=True restricts unpickling to tensors/plain containers,
        # which is all a state_dict needs.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        loaded.load_state_dict(state_dict)
        loaded.eval()
        model = loaded

        print("βœ… Model loaded successfully!")
        return f"βœ… Model loaded from {repo_id}"
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return f"❌ Error: {e}"
# Minimum number of positive mask pixels (counted at the model's 384x384
# working resolution) before the UI reports a polyp.
_MIN_POLYP_PIXELS = 100


def _to_rgb_array(image):
    """Convert a PIL image or array-like into an RGB numpy array.

    PIL images go through ``convert('RGB')``. For arrays, a 2-D or
    single-channel input is replicated to three channels and an alpha channel
    is dropped, because the model's first convolution takes exactly 3 channels.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert('RGB'))
    array = np.array(image)
    if array.ndim == 2:
        array = np.stack([array, array, array], axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 1:
        array = np.repeat(array, 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:
        array = array[..., :3]
    return array


def _render_results_figure(original_image, pred_mask, detected, polyp_percentage):
    """Draw original / mask / overlay panels side by side and return them as a PIL PNG image."""
    # A bare Figure is not registered with pyplot: nothing needs closing (no
    # leak on exceptions) and concurrent requests don't share pyplot's global
    # "current figure" state.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(15, 5))
    axes = fig.subplots(1, 3)

    # Original image
    axes[0].imshow(original_image)
    axes[0].set_title('πŸ–ΌοΈ Original Image', fontsize=14)
    axes[0].axis('off')

    # Predicted mask
    axes[1].imshow(pred_mask, cmap='gray')
    axes[1].set_title('🎭 Predicted Mask', fontsize=14)
    axes[1].axis('off')

    # Overlay: the mask is at model resolution while the photo keeps its
    # native size, so stretch the mask over the photo's pixel extent.
    # Without this, the mask's own (smaller) extent resets the view and the
    # panel shows a misaligned top-left crop.
    height, width = original_image.shape[:2]
    axes[2].imshow(original_image)
    axes[2].imshow(
        pred_mask,
        cmap='Reds',
        alpha=0.6,
        extent=(-0.5, width - 0.5, height - 0.5, -0.5),
    )
    axes[2].set_title('πŸ" Detection Overlay', fontsize=14)
    axes[2].axis('off')

    if detected:
        main_title = f"🚨 POLYP DETECTED! Coverage: {polyp_percentage:.2f}%"
        title_color = 'red'
    else:
        main_title = f"βœ… No Polyp Detected - Coverage: {polyp_percentage:.2f}%"
        title_color = 'green'
    fig.suptitle(main_title, fontsize=16, fontweight='bold', color=title_color)
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)


def _format_results_text(detected, polyp_percentage, polyp_pixels, total_pixels, threshold):
    """Build the Markdown report shown next to the result figure."""
    if detected:
        status_emoji = "🚨"
        status_text = "POLYP DETECTED"
        recommendation = "⚠️ **Recommendation:** Medical review recommended"
    else:
        status_emoji = "βœ…"
        status_text = "NO POLYP DETECTED"
        recommendation = "βœ… **Recommendation:** Continue routine monitoring"
    lines = [
        "",
        f"## {status_emoji} **{status_text}**",
        '### πŸ"Š **Analysis Results:**',
        f"- **Polyp Coverage:** {polyp_percentage:.3f}%",
        f"- **Detected Pixels:** {int(polyp_pixels):,} / {total_pixels:,}",
        f"- **Detection Threshold:** {threshold}",
        "### πŸ₯ **Clinical Assessment:**",
        recommendation,
        '### πŸ"¬ **Technical Details:**',
        "- **Model:** U-Net (32 channels)",
        "- **Input Size:** 384Γ—384 pixels",
        "- **Architecture:** Encoder-Decoder with skip connections",
        "",
    ]
    return "\n".join(lines)


def predict_polyp(image, threshold=0.5):
    """Segment polyps in *image* and build the figure and report shown in the UI.

    Args:
        image: Uploaded image (PIL image or array-like), or ``None`` when
            nothing was uploaded.
        threshold: Cut-off on the sigmoid output; pixels strictly above it
            are marked as polyp.

    Returns:
        ``(result_image, results_text, pred_mask)``: a PIL image with three
        panels, a Markdown report, and the 0/1 float mask at model
        resolution. On any failure the image and mask are ``None`` and the
        text carries the error message.
    """
    if model is None:
        return None, "❌ Model not loaded! Please wait for model to load.", None
    if image is None:
        return None, "❌ Please upload an image first!", None

    try:
        original_image = _to_rgb_array(image)

        # Preprocess and add a batch dimension.
        transformed = transform(image=original_image)
        input_tensor = transformed['image'].unsqueeze(0).float().to(device)

        with torch.no_grad():
            probabilities = torch.sigmoid(model(input_tensor))
            prediction = (probabilities > threshold).float()
        pred_mask = prediction.squeeze().cpu().numpy()

        polyp_pixels = np.sum(pred_mask)
        total_pixels = pred_mask.shape[0] * pred_mask.shape[1]
        polyp_percentage = (polyp_pixels / total_pixels) * 100
        detected = polyp_pixels > _MIN_POLYP_PIXELS

        result_image = _render_results_figure(
            original_image, pred_mask, detected, polyp_percentage
        )
        report = _format_results_text(
            detected, polyp_percentage, polyp_pixels, total_pixels, threshold
        )
        return result_image, report, pred_mask
    except Exception as e:
        # UI boundary: report the error to the user, but also log it server-side.
        print(f"❌ Error processing image: {e}")
        error_msg = f"❌ **Error processing image:** {str(e)}"
        return None, error_msg, None
def load_example_image(image_num):
    """Download one of the two bundled example images from the Space repo.

    Args:
        image_num: ``1`` selects the first example; any other value selects
            the second.

    Returns:
        The decoded PIL image, or ``None`` if downloading or decoding fails
        (the error is printed).
    """
    # Any value other than 1 falls back to the second example.
    filename = (
        "cju0qoxqj9q6s0835b43399p4.jpg"
        if image_num == 1
        else "cju0roawvklrq0799vmjorwfv.jpg"
    )
    try:
        image_path = hf_hub_download(
            repo_id="ibrahim313/unet-adam-diceloss",
            filename=filename,
            repo_type="space",
        )
        image = Image.open(image_path)
        # Image.open only parses the header. Decode now so that a corrupt or
        # truncated file is handled by the except below instead of failing
        # later inside Gradio.
        image.load()
        return image
    except Exception as e:
        print(f"Error loading example image {image_num}: {e}")
        return None
# Load the model once at import time; `load_status` is shown in the UI banner.
print("πŸš€ Starting Polyp Detection App...")
print(load_status := load_model())
# Create Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="πŸ₯ Polyp Detection AI") as demo:
    # Header
    gr.HTML("""
<div style="text-align: center; padding: 30px; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
<h1 style="margin: 0; font-size: 2.5em;">πŸ₯ AI Polyp Detection System</h1>
<p style="margin: 10px 0 0 0; font-size: 1.2em;">Advanced Medical Imaging with Deep Learning</p>
<p style="margin: 5px 0 0 0; opacity: 0.9;">Upload colonoscopy images for intelligent polyp detection</p>
</div>
""")
    # Model info. load_status can carry raw exception text (e.g. "<...>"),
    # so escape it before interpolating into HTML.
    gr.HTML(f"""
<div style="background: black; padding: 15px; border-radius: 8px; border-left: 4px solid #0ea5e9; margin-bottom: 20px;">
<strong>πŸ"¬ Model:</strong> ibrahim313/unet-adam-diceloss<br>
<strong>πŸ"  Architecture:</strong> U-Net with 32 base channels<br>
<strong>🎯 Dataset:</strong> Trained on Kvasir-SEG (1000 polyp images)<br>
<strong>πŸ"Έ Examples:</strong> 2 test colonoscopy images included<br>
<strong>⚑ Status:</strong> {html.escape(load_status)}
</div>
""")
    # Main interface
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML('<h3>πŸ"€ Upload Image</h3>')
            input_image = gr.Image(
                label="Drop colonoscopy image here",
                type="pil",
                height=300,
            )
            # predict_polyp marks pixels whose probability exceeds this value,
            # so raising it makes detection stricter, not more sensitive.
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.1,
                label="🎯 Detection Threshold",
                info="Higher = stricter detection (fewer pixels flagged as polyp)",
            )
            analyze_btn = gr.Button(
                'πŸ" Analyze for Polyps',
                variant="primary",
                size="lg",
            )
            gr.HTML("<br>")
            # Quick examples
            gr.HTML('<h4>πŸ"Έ Try Sample Images:</h4>')
            gr.HTML("<p style='font-size: 0.9em; color: #666; margin: 5px 0;'>Click to load colonoscopy test images</p>")
            with gr.Row():
                example1_btn = gr.Button("πŸ–ΌοΈ Test Image 1", size="sm", variant="secondary")
                example2_btn = gr.Button("πŸ–ΌοΈ Test Image 2", size="sm", variant="secondary")
        with gr.Column(scale=2):
            gr.HTML('<h3>πŸ"Š Detection Results</h3>')
            output_image = gr.Image(
                label="Analysis Results",
                height=400,
            )
            results_text = gr.Markdown(
                value="Upload an image and click 'Analyze for Polyps' to see results.",
                label="Detailed Analysis",
            )

    # Event handlers. predict_polyp's third return value (the mask) goes into
    # per-session state rather than a visible component.
    mask_state = gr.State()
    analyze_btn.click(
        fn=predict_polyp,
        inputs=[input_image, threshold_slider],
        outputs=[output_image, results_text, mask_state],
    )
    # Example button handlers
    example1_btn.click(
        fn=lambda: load_example_image(1),
        inputs=[],
        outputs=[input_image],
    )
    example2_btn.click(
        fn=lambda: load_example_image(2),
        inputs=[],
        outputs=[input_image],
    )
    # Footer
    gr.HTML("""
<div style="text-align: center; padding: 20px; margin-top: 40px; border-top: 2px solid #e5e7eb; background: #f9fafb;">
<p style="margin: 0; color: #dc2626; font-weight: bold;">
⚠️ MEDICAL DISCLAIMER
</p>
<p style="margin: 5px 0; color: #4b5563;">
This AI system is for research and educational purposes only.<br>
Always consult qualified medical professionals for clinical decisions.
</p>
<p style="margin: 10px 0 0 0; color: #6b7280; font-size: 0.9em;">
πŸ"¬ Powered by PyTorch | πŸ€— Hosted on Hugging Face | πŸ"Š Gradio Interface
</p>
</div>
""")

# Launch the app
if __name__ == "__main__":
    demo.launch()