Spaces:
Sleeping
Sleeping
import html
import io

import albumentations as A
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
from PIL import Image
# Your UNET Model Definition
class UNET(nn.Module):
    """Four-level U-Net that maps a 3-channel image to a 1-channel logit map.

    Args:
        dropout_rate: Drop probability used by every ``nn.Dropout2d`` layer.
        ch: Channel width of the first encoder stage. Each deeper stage
            doubles it, reaching ``16 * ch`` at the bottleneck.

    Input height and width must be divisible by 16 (four 2x max-pools
    followed by four 2x transposed convolutions); otherwise the skip
    concatenations receive mismatched spatial sizes.

    The submodule attribute names (``encoder1``..``encoder4``,
    ``bottle_neck``, ``upsample1``..``upsample4``, ``decoder1``..``decoder4``,
    ``final``) become the ``state_dict`` keys, so renaming them would break
    loading of saved checkpoints.
    """

    def __init__(self, dropout_rate=0.1, ch=32):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        def double_conv(c_in, c_out):
            # (Conv3x3 -> BatchNorm -> ReLU -> Dropout2d) twice; only the
            # first convolution changes the channel count.
            layers = []
            for src_channels in (c_in, c_out):
                layers.extend([
                    nn.Conv2d(src_channels, c_out, 3, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True),
                    nn.Dropout2d(p=dropout_rate),
                ])
            return nn.Sequential(*layers)

        # widths == [ch, 2*ch, 4*ch, 8*ch, 16*ch]
        widths = [ch * 2 ** level for level in range(5)]

        self.encoder1 = double_conv(3, widths[0])
        self.encoder2 = double_conv(widths[0], widths[1])
        self.encoder3 = double_conv(widths[1], widths[2])
        self.encoder4 = double_conv(widths[2], widths[3])
        self.bottle_neck = double_conv(widths[3], widths[4])

        # Decoder stage k upsamples widths[5-k] -> widths[4-k]. After the
        # matching skip (also widths[4-k] channels) is concatenated, the
        # double conv sees 2 * widths[4-k] == widths[5-k] input channels.
        for stage in range(1, 5):
            wide, narrow = widths[5 - stage], widths[4 - stage]
            setattr(self, f"upsample{stage}",
                    nn.ConvTranspose2d(wide, narrow, kernel_size=2, stride=2))
            setattr(self, f"decoder{stage}", double_conv(wide, narrow))

        self.final = nn.Conv2d(widths[0], 1, kernel_size=1)

    def forward(self, x):
        # Contracting path: keep each stage's pre-pool output as a skip.
        skips = []
        features = x
        for encoder in (self.encoder1, self.encoder2,
                        self.encoder3, self.encoder4):
            features = encoder(features)
            skips.append(features)
            features = self.pool(features)
        features = self.bottle_neck(features)

        # Expanding path: deepest skip first, skip placed before the
        # upsampled features in the channel concatenation.
        stages = (
            (self.upsample1, self.decoder1),
            (self.upsample2, self.decoder2),
            (self.upsample3, self.decoder3),
            (self.upsample4, self.decoder4),
        )
        for (upsample, decoder), skip in zip(stages, reversed(skips)):
            features = decoder(torch.cat([skip, upsample(features)], dim=1))
        return self.final(features)
# Global variables
# ``model`` is populated by load_model(); predict_polyp refuses to run while
# it is still None.
model = None
device = torch.device('cpu')  # HF Spaces use CPU

# Preprocessing applied to every uploaded image:
#   1. resize to the fixed 384x384 model input size;
#   2. Normalize with mean 0, std 1 and max_pixel_value 255, which reduces
#      to dividing pixel values by 255;
#   3. convert the HWC numpy array into a CHW torch tensor.
_PREPROCESS_STEPS = [
    A.Resize(384, 384),
    A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255),
    ToTensorV2(),
]
transform = A.Compose(_PREPROCESS_STEPS)
def load_model():
    """Download the U-Net checkpoint from the Hugging Face Hub and load it.

    On success the module-level ``model`` is replaced by a ``UNET(ch=32)``
    holding the downloaded weights, in eval mode. On any failure ``model``
    is left exactly as it was (``None`` at startup). This keeps
    ``predict_polyp``'s "model not loaded" guard effective instead of serving
    predictions from a randomly initialised, train-mode network.

    Returns:
        str: Human-readable status message (also shown in the UI).
    """
    global model
    try:
        print("π₯ Downloading model from Hugging Face...")
        # Download your model from HF
        model_path = hf_hub_download(
            repo_id="ibrahim313/unet-adam-diceloss",
            filename="pytorch_model.bin",
        )
        # Build and populate a *local* instance; publish it to the global
        # only after the weights loaded, so a failed load cannot leave a
        # half-initialised model behind.
        net = UNET(ch=32)
        # NOTE(review): torch.load unpickles the file, which can execute
        # code. Only load checkpoints from a trusted repo; consider
        # weights_only=True (torch >= 1.13) since this is a plain state_dict.
        net.load_state_dict(torch.load(model_path, map_location=device))
        net.eval()
        model = net
        print("β Model loaded successfully!")
        return "β Model loaded from ibrahim313/unet-adam-diceloss"
    except Exception as e:
        # Top-level boundary: report the failure as a status string for the UI.
        print(f"β Error loading model: {e}")
        return f"β Error: {e}"
def _to_rgb_array(image):
    """Return *image* as an ``(H, W, 3)`` numpy array for preprocessing.

    PIL images are converted to RGB. Array-like inputs are copied. A 2-D
    (grayscale) array is replicated into three channels, and a 4-channel
    (RGBA) array loses its alpha channel. Both are needed because the
    U-Net's first convolution takes exactly 3 input channels.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert('RGB'))
    rgb = np.array(image)
    if rgb.ndim == 2:
        rgb = np.stack([rgb] * 3, axis=-1)
    elif rgb.ndim == 3 and rgb.shape[-1] == 4:
        rgb = rgb[..., :3]
    return rgb


def _render_figure(original_image, pred_mask, detected, polyp_percentage):
    """Draw original / mask / overlay side by side and return it as a PIL image.

    The figure is always closed, even if drawing fails, so repeated requests
    do not accumulate open pyplot figures. All drawing calls go through
    ``fig``/``axes`` rather than pyplot's implicit "current figure".
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # Original image
        axes[0].imshow(original_image)
        axes[0].set_title('πΌοΈ Original Image', fontsize=14)
        axes[0].axis('off')

        # Predicted mask. vmin/vmax are pinned so a mask that is all 0 or
        # all 1 is not auto-normalised into one misleading colour.
        axes[1].imshow(pred_mask, cmap='gray', vmin=0, vmax=1)
        axes[1].set_title('π Predicted Mask', fontsize=14)
        axes[1].axis('off')

        # Overlay. The mask is at model resolution while the photo keeps its
        # own size, so the mask is stretched over the photo's pixel extent
        # to line up with it. Background (0) pixels are masked out so they
        # do not tint the whole photo.
        height, width = original_image.shape[:2]
        axes[2].imshow(original_image)
        axes[2].imshow(
            np.ma.masked_equal(pred_mask, 0),
            cmap='Reds',
            alpha=0.6,
            vmin=0,
            vmax=1,
            extent=(-0.5, width - 0.5, height - 0.5, -0.5),
            interpolation='nearest',
        )
        axes[2].set_title('π Detection Overlay', fontsize=14)
        axes[2].axis('off')

        # Main title with results
        if detected:
            main_title = f"π¨ POLYP DETECTED! Coverage: {polyp_percentage:.2f}%"
            title_color = 'red'
        else:
            main_title = f"β No Polyp Detected - Coverage: {polyp_percentage:.2f}%"
            title_color = 'green'
        fig.suptitle(main_title, fontsize=16, fontweight='bold', color=title_color)
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    finally:
        plt.close(fig)


def _format_report(detected, polyp_percentage, polyp_pixels, total_pixels, threshold):
    """Build the Markdown analysis report shown next to the figure."""
    # NOTE(review): the emoji / multiplication-sign characters in these
    # strings look mis-encoded (mojibake); restore the intended glyphs.
    if detected:
        status_emoji = "π¨"
        status_text = "POLYP DETECTED"
        recommendation = "β οΈ **Recommendation:** Medical review recommended"
    else:
        status_emoji = "β "
        status_text = "NO POLYP DETECTED"
        recommendation = "β **Recommendation:** Continue routine monitoring"
    # Lines are joined explicitly (leading and trailing newline kept) so the
    # Markdown never picks up source indentation, which would turn it into
    # a code block.
    lines = [
        "",
        f"## {status_emoji} **{status_text}**",
        "### π **Analysis Results:**",
        f"- **Polyp Coverage:** {polyp_percentage:.3f}%",
        f"- **Detected Pixels:** {int(polyp_pixels):,} / {total_pixels:,}",
        f"- **Detection Threshold:** {threshold}",
        "### π₯ **Clinical Assessment:**",
        recommendation,
        "### π¬ **Technical Details:**",
        "- **Model:** U-Net (32 channels)",
        "- **Input Size:** 384Γ384 pixels",
        "- **Architecture:** Encoder-Decoder with skip connections",
        "",
    ]
    return "\n".join(lines)


def predict_polyp(image, threshold=0.5):
    """Segment polyps in *image* and build the three UI outputs.

    Args:
        image: The uploaded image. This is a PIL image when it comes from the
            ``gr.Image(type="pil")`` input; array-likes are also accepted.
        threshold: Probability cut-off applied to the sigmoid output. Pixels
            strictly above it count as polyp, so a higher value flags fewer
            pixels.

    Returns:
        tuple: ``(result_image, results_markdown, pred_mask)``.
        ``pred_mask`` is a 2-D 0/1 float array at model input resolution.
        If the model is not loaded, no image was given, or processing fails,
        the image and mask are ``None`` and the Markdown carries the message.
    """
    if model is None:
        return None, "β Model not loaded! Please wait for model to load.", None
    if image is None:
        return None, "β Please upload an image first!", None
    try:
        original_image = _to_rgb_array(image)

        # Preprocess -> (1, 3, 384, 384) float tensor
        input_tensor = transform(image=original_image)['image'].unsqueeze(0).float()

        with torch.no_grad():
            probabilities = torch.sigmoid(model(input_tensor))
        pred_mask = (probabilities > threshold).float().squeeze().cpu().numpy()

        # Metrics are computed at model resolution.
        polyp_pixels = np.sum(pred_mask)
        total_pixels = pred_mask.shape[0] * pred_mask.shape[1]
        polyp_percentage = (polyp_pixels / total_pixels) * 100
        # A single 100-pixel cut-off drives both the plot title and the report.
        detected = polyp_pixels > 100

        result_image = _render_figure(original_image, pred_mask, detected, polyp_percentage)
        results_text = _format_report(
            detected, polyp_percentage, polyp_pixels, total_pixels, threshold
        )
        return result_image, results_text, pred_mask
    except Exception as e:
        # UI boundary: surface the error in the results panel.
        error_msg = f"β **Error processing image:** {str(e)}"
        return None, error_msg, None
def load_example_image(image_num):
    """Download one of the two sample colonoscopy images from the HF Space.

    Args:
        image_num: ``1`` selects the first sample; any other value selects
            the second, matching the two "Test Image" buttons.

    Returns:
        PIL.Image.Image | None: The fully decoded image, or ``None`` if the
        download or decoding failed (the error is printed).
    """
    try:
        # Image 1: cju0qoxqj9q6s0835b43399p4.jpg
        # Image 2: cju0roawvklrq0799vmjorwfv.jpg
        filename = (
            "cju0qoxqj9q6s0835b43399p4.jpg"
            if image_num == 1
            else "cju0roawvklrq0799vmjorwfv.jpg"
        )
        image_path = hf_hub_download(
            repo_id="ibrahim313/unet-adam-diceloss",
            filename=filename,
            repo_type="space",
        )
        # Image.open is lazy and keeps the file open. Copy to force a full
        # decode into memory, and let the context manager close the file.
        with Image.open(image_path) as img:
            return img.copy()
    except Exception as e:
        print(f"Error loading example image {image_num}: {e}")
        return None
# Load model when app starts; the status string is also shown in the UI.
print("π Starting Polyp Detection App...")
load_status = load_model()
print(load_status)

# Create Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="π₯ Polyp Detection AI") as demo:
    # Header
    gr.HTML("""
    <div style="text-align: center; padding: 30px; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
    <h1 style="margin: 0; font-size: 2.5em;">π₯ AI Polyp Detection System</h1>
    <p style="margin: 10px 0 0 0; font-size: 1.2em;">Advanced Medical Imaging with Deep Learning</p>
    <p style="margin: 5px 0 0 0; opacity: 0.9;">Upload colonoscopy images for intelligent polyp detection</p>
    </div>
    """)

    # Model info. The panel has a black background, so `color: white` is set
    # explicitly; without it, light mode's dark default text is unreadable.
    # `load_status` can embed an arbitrary exception message, so it is
    # HTML-escaped before being placed into markup.
    gr.HTML(f"""
    <div style="background: black; color: white; padding: 15px; border-radius: 8px; border-left: 4px solid #0ea5e9; margin-bottom: 20px;">
    <strong>π¬ Model:</strong> ibrahim313/unet-adam-diceloss<br>
    <strong>π Architecture:</strong> U-Net with 32 base channels<br>
    <strong>π― Dataset:</strong> Trained on Kvasir-SEG (1000 polyp images)<br>
    <strong>πΈ Examples:</strong> 2 test colonoscopy images included<br>
    <strong>β‘ Status:</strong> {html.escape(load_status)}
    </div>
    """)

    # Main interface
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>π€ Upload Image</h3>")
            input_image = gr.Image(
                label="Drop colonoscopy image here",
                type="pil",
                height=300,
            )
            # predict_polyp flags a pixel when its probability exceeds this
            # value, so raising it makes detection stricter, not more sensitive.
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.1,
                label="π― Detection Threshold",
                info="Higher = stricter detection (fewer pixels flagged)",
            )
            analyze_btn = gr.Button(
                "π Analyze for Polyps",
                variant="primary",
                size="lg",
            )
            gr.HTML("<br>")

            # Quick examples
            gr.HTML("<h4>πΈ Try Sample Images:</h4>")
            gr.HTML("<p style='font-size: 0.9em; color: #666; margin: 5px 0;'>Click to load colonoscopy test images</p>")
            with gr.Row():
                example1_btn = gr.Button("πΌοΈ Test Image 1", size="sm", variant="secondary")
                example2_btn = gr.Button("πΌοΈ Test Image 2", size="sm", variant="secondary")

        with gr.Column(scale=2):
            gr.HTML("<h3>π Detection Results</h3>")
            output_image = gr.Image(
                label="Analysis Results",
                height=400,
            )
            results_text = gr.Markdown(
                value="Upload an image and click 'Analyze for Polyps' to see results.",
                label="Detailed Analysis",
            )

    # Event handlers. The third output of predict_polyp (the binary mask) is
    # kept in an invisible State component.
    mask_state = gr.State()
    analyze_btn.click(
        fn=predict_polyp,
        inputs=[input_image, threshold_slider],
        outputs=[output_image, results_text, mask_state],
    )

    # Example button handlers
    example1_btn.click(
        fn=lambda: load_example_image(1),
        inputs=[],
        outputs=[input_image],
    )
    example2_btn.click(
        fn=lambda: load_example_image(2),
        inputs=[],
        outputs=[input_image],
    )

    # Footer
    gr.HTML("""
    <div style="text-align: center; padding: 20px; margin-top: 40px; border-top: 2px solid #e5e7eb; background: #f9fafb;">
    <p style="margin: 0; color: #dc2626; font-weight: bold;">
    β οΈ MEDICAL DISCLAIMER
    </p>
    <p style="margin: 5px 0; color: #4b5563;">
    This AI system is for research and educational purposes only.<br>
    Always consult qualified medical professionals for clinical decisions.
    </p>
    <p style="margin: 10px 0 0 0; color: #6b7280; font-size: 0.9em;">
    π¬ Powered by PyTorch | π€ Hosted on Hugging Face | π Gradio Interface
    </p>
    </div>
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()