import math

import gradio as gr
import numpy as np
import torch
from PIL import Image

from unet import EnhancedUNet

def initialize_model(model_path):
    """Load an EnhancedUNet checkpoint onto the best available device."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, device

def extract_patches(image, patch_size=256):
    """Split the input image into a grid of patch_size x patch_size tiles."""
    width, height = image.size
    patches = []
    positions = []

    # Calculate the number of patches in each dimension
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)

    # Zero-pad the image so it divides evenly into patches
    padded_width = n_cols * patch_size
    padded_height = n_rows * patch_size
    padded_image = Image.new("L", (padded_width, padded_height), 0)
    padded_image.paste(image, (0, 0))

    # Crop out each patch and record its position for stitching later
    for i in range(n_rows):
        for j in range(n_cols):
            left = j * patch_size
            top = i * patch_size
            right = left + patch_size
            bottom = top + patch_size
            patch = padded_image.crop((left, top, right, bottom))
            patches.append(patch)
            positions.append((left, top, right, bottom))

    return patches, positions, (padded_width, padded_height), (width, height)
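
# A quick worked example of the patch math above (illustrative; the `_demo_*`
# names are hypothetical, and the lines are commented out so the app's
# behavior is unchanged): a 300x300 image needs a 2x2 grid of 256x256
# patches and is padded to 512x512.
# _demo_img = Image.new("L", (300, 300), 128)
# _demo_patches, _demo_positions, _demo_padded, _demo_orig = extract_patches(_demo_img)
# assert len(_demo_patches) == 4
# assert _demo_padded == (512, 512)
# assert _demo_orig == (300, 300)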

def stitch_patches(patches, positions, padded_size, original_size):
    """Stitch patches back together into a complete mask."""
    full_mask = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
    for patch, (left, top, right, bottom) in zip(patches, positions):
        full_mask[top:bottom, left:right] = patch
    # Crop away the padding to return to the original image size
    full_mask = full_mask[:original_size[1], :original_size[0]]
    return full_mask
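
# Round-trip sanity check (illustrative, commented out; the `_rt_*` names are
# hypothetical): stitching the extracted patches of a constant image should
# reproduce that image at its original size.
# _rt_img = Image.new("L", (300, 300), 7)
# _rt_patches, _rt_pos, _rt_padded, _rt_orig = extract_patches(_rt_img)
# _rt_mask = stitch_patches([np.array(p) for p in _rt_patches], _rt_pos, _rt_padded, _rt_orig)
# assert _rt_mask.shape == (300, 300) and (_rt_mask == 7).all()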

def process_patch(patch, device):
    """Convert a PIL patch to a normalized float tensor on the target device."""
    # Convert to grayscale if it is not already
    patch_gray = patch.convert("L")
    # Convert to a numpy array and scale pixel values to [0, 1]
    patch_np = np.array(patch_gray, dtype=np.float32) / 255.0
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W)
    patch_tensor = torch.from_numpy(patch_np).unsqueeze(0).unsqueeze(0)
    return patch_tensor.to(device)
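
# Shape check (illustrative, commented out; `_pp` is hypothetical): a 256x256
# patch becomes a (1, 1, 256, 256) tensor with values in [0, 1].
# _pp = process_patch(Image.new("L", (256, 256), 255), torch.device("cpu"))
# assert _pp.shape == (1, 1, 256, 256) and _pp.max() <= 1.0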

def create_overlay(original_image, mask, alpha=0.5):
    """Blend a color-coded class mask over the original image."""
    # One RGB color per class: background, then classes 1-3
    colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
    mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for i, color in enumerate(colors):
        mask_rgb[mask == i] = color
    # Resize the original image to match the mask size
    original_image = original_image.resize((mask.shape[1], mask.shape[0]))
    original_array = np.array(original_image.convert("RGB"))
    # Alpha-blend the colored mask with the original image
    overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
    return Image.fromarray(overlay)
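
# With the default alpha=0.5, each overlay pixel is the average of the class
# color and the original pixel; e.g. a red class-1 pixel (255, 0, 0) over
# mid-gray (128, 128, 128) blends to (191, 64, 64). Illustrative and
# commented out; `_ov` is hypothetical.
# _ov = create_overlay(Image.new("RGB", (4, 4), (128, 128, 128)), np.ones((4, 4), dtype=np.uint8))
# assert np.array(_ov)[0, 0].tolist() == [191, 64, 64]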

def predict(input_image, model_choice):
    """Run patch-wise segmentation and return the mask and overlay images."""
    if input_image is None:
        return None, None

    model = models[model_choice]
    patch_size = 256

    # Split the input into model-sized patches
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)

    # Run inference on each patch
    predicted_patches = []
    for patch in patches:
        patch_tensor = process_patch(patch, device)
        with torch.no_grad():
            output = model(patch_tensor)
        # Per-pixel class prediction for this patch
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)

    # Stitch the patch predictions back into a full-size mask
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)

    # Scale class indices 0-3 to gray levels 0/63/126/189 for visibility
    mask_image = Image.fromarray((full_mask * 63).astype(np.uint8))

    # Color-coded overlay on the original image
    overlay_image = create_overlay(input_image, full_mask)

    return mask_image, overlay_image
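
# Local smoke test (illustrative, commented out; only valid once the models
# below have been loaded): a blank image should produce a mask and an overlay
# at the input size.
# _mask_img, _overlay_img = predict(Image.new("L", (256, 256), 0), "With Noise V2")
# assert _mask_img.size == (256, 256) and _overlay_img.size == (256, 256)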

# Load the models once at import time (outside the inference function) so
# each request does not pay the checkpoint-loading cost.
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"

w_noise_model, device = initialize_model(w_noise_model_path)
wo_noise_model, device = initialize_model(wo_noise_model_path)
w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)

models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2,
}

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            choices=["Without Noise", "With Noise", "With Noise V2"],
            value="With Noise V2",
            label="Model",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Mask"),
        gr.Image(type="pil", label="Overlay"),
    ],
    title="MoS2 Image Segmentation",
    description="Upload an image to get the segmentation mask and overlay visualization.",
    examples=[
        ["./examples/image_000003.png", "With Noise"],
        ["./examples/image_000005.png", "Without Noise"],
    ],
)

# Launch the interface
iface.launch(share=True)
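
# Note: share=True asks Gradio for a temporary public URL in addition to the
# local server; call iface.launch() without it to keep the demo local-only.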