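"""Gradio demo for MoS2 image segmentation.

The app splits an input image into 256x256 patches, runs each patch through
an EnhancedUNet (1 input channel, 4 classes), stitches the per-patch
predictions back into a full mask, and returns the mask plus a color overlay.
"""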
import math

import gradio as gr
import numpy as np
import torch
from PIL import Image

from unet import EnhancedUNet

def initialize_model(model_path):
    """Load an EnhancedUNet checkpoint onto GPU if available, otherwise CPU."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, device

def extract_patches(image, patch_size=256):
    """Split the input image into non-overlapping patch_size x patch_size tiles.

    The image is zero-padded on the right and bottom so that both dimensions
    are exact multiples of patch_size.
    """
    width, height = image.size
    patches = []
    positions = []
    # Number of patches in each dimension
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)
    # Pad the image with black so it divides evenly into patches
    padded_width = n_cols * patch_size
    padded_height = n_rows * patch_size
    padded_image = Image.new("L", (padded_width, padded_height), 0)
    padded_image.paste(image, (0, 0))
    # Extract patches row by row, recording each patch's bounding box
    for i in range(n_rows):
        for j in range(n_cols):
            left = j * patch_size
            top = i * patch_size
            right = left + patch_size
            bottom = top + patch_size
            patch = padded_image.crop((left, top, right, bottom))
            patches.append(patch)
            positions.append((left, top, right, bottom))
    return patches, positions, (padded_width, padded_height), (width, height)
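
# Worked example: a 600x400 input with patch_size=256 is padded to 768x512
# and split into 2 rows x 3 cols = 6 patches; positions[0] == (0, 0, 256, 256).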

def stitch_patches(patches, positions, padded_size, original_size):
    """Reassemble predicted patches into a full mask and crop off the padding."""
    full_mask = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
    for patch, (left, top, right, bottom) in zip(patches, positions):
        full_mask[top:bottom, left:right] = patch
    # Crop back to the original (unpadded) size
    full_mask = full_mask[:original_size[1], :original_size[0]]
    return full_mask

def process_patch(patch, device):
    """Convert a PIL patch to a normalized float tensor on the target device."""
    # Convert to grayscale if it is not already
    patch_gray = patch.convert("L")
    # Convert to a numpy array and normalize pixel values to [0, 1]
    patch_np = np.array(patch_gray, dtype=np.float32) / 255.0
    # Add batch and channel dimensions
    patch_tensor = torch.from_numpy(patch_np).float().unsqueeze(0).unsqueeze(0)
    return patch_tensor.to(device)
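
# For a 256x256 patch, process_patch returns a float32 tensor of shape
# (1, 1, 256, 256) with values in [0, 1], ready for the model's forward pass.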

def create_overlay(original_image, mask, alpha=0.5):
    """Blend a color-coded class mask with the original image."""
    # One RGB color per class id
    colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
    mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for i, color in enumerate(colors):
        mask_rgb[mask == i] = color
    # Resize the original image to match the mask size
    original_image = original_image.resize((mask.shape[1], mask.shape[0]))
    original_array = np.array(original_image.convert("RGB"))
    # Alpha-blend the colored mask over the image
    overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
    return Image.fromarray(overlay)
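
# Color legend (class id -> overlay color): 0 = black, 1 = red, 2 = green,
# 3 = blue. Treating class 0 as background is an assumption here; the actual
# class semantics depend on how the checkpoints were trained.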

def predict(input_image, model_choice):
    """Run patch-wise segmentation and return (mask image, overlay image)."""
    if input_image is None:
        return None, None
    # `models` and `device` are module-level globals initialized below
    model = models[model_choice]
    patch_size = 256
    # Split the input into patches
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)
    # Run inference on each patch
    predicted_patches = []
    for patch in patches:
        patch_tensor = process_patch(patch, device)
        with torch.no_grad():
            output = model(patch_tensor)
        # Per-pixel class ids for this patch
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)
    # Stitch the patch predictions back together
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)
    # Scale class ids (0-3) to 0-189 so they are visible as grayscale
    mask_image = Image.fromarray((full_mask * 63).astype(np.uint8))
    overlay_image = create_overlay(input_image, full_mask)
    return mask_image, overlay_image

# Initialize the models once at import time (outside the inference function)
# so each request does not reload the weights.
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"
w_noise_model, device = initialize_model(w_noise_model_path)
wo_noise_model, device = initialize_model(wo_noise_model_path)
w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)
models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2,
}

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            choices=["Without Noise", "With Noise", "With Noise V2"],
            value="With Noise V2",
            label="Model",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Mask"),
        gr.Image(type="pil", label="Overlay"),
    ],
    title="MoS2 Image Segmentation",
    description="Upload an image to get the segmentation mask and overlay visualization.",
    examples=[
        ["./examples/image_000003.png", "With Noise"],
        ["./examples/image_000005.png", "Without Noise"],
    ],
)

# Launch the interface. share=True requests a temporary public link when the
# script is run locally; Hugging Face Spaces ignores it and serves the app directly.
iface.launch(share=True)
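
# To try the app locally (assuming the checkpoints and unet.py are present):
#   pip install gradio torch numpy pillow
#   python app.py
# launch() then prints a local URL (http://127.0.0.1:7860 by default).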