import math

import gradio as gr
import numpy as np
import torch
from PIL import Image

from unet import EnhancedUNet

def initialize_model(model_path):
    """Load an EnhancedUNet checkpoint onto the available device and set eval mode."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, device
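
# Note: on recent PyTorch releases, torch.load(model_path, map_location=device,
# weights_only=True) is the safer way to load a bare state dict; older versions
# do not accept the weights_only argument, so it is left out above.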

def extract_patches(image, patch_size=256):
    """Extract non-overlapping patch_size x patch_size patches from the input image."""
    width, height = image.size
    patches = []
    positions = []
    # Calculate the number of patches in each dimension
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)
    # Pad the image with black so both dimensions are multiples of patch_size
    padded_width = n_cols * patch_size
    padded_height = n_rows * patch_size
    padded_image = Image.new("L", (padded_width, padded_height), 0)
    padded_image.paste(image, (0, 0))
    # Extract patches in row-major order
    for i in range(n_rows):
        for j in range(n_cols):
            left = j * patch_size
            top = i * patch_size
            right = left + patch_size
            bottom = top + patch_size
            patch = padded_image.crop((left, top, right, bottom))
            patches.append(patch)
            positions.append((left, top, right, bottom))
    return patches, positions, (padded_width, padded_height), (width, height)
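
# Worked example of the tiling above: a 600x400 input with patch_size=256 gives
# n_cols = ceil(600/256) = 3 and n_rows = ceil(400/256) = 2, so the image is
# padded to 768x512 and split into 6 patches, returned in row-major order.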

def stitch_patches(patches, positions, padded_size, original_size):
    """Stitch predicted patch masks back together into a full-size mask."""
    full_mask = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
    for patch, (left, top, right, bottom) in zip(patches, positions):
        full_mask[top:bottom, left:right] = patch
    # Crop away the padding to recover the original size
    full_mask = full_mask[:original_size[1], :original_size[0]]
    return full_mask

def process_patch(patch, device):
    """Convert a PIL patch to a normalized (1, 1, H, W) float tensor on `device`."""
    # Convert to grayscale if it's not already
    patch_gray = patch.convert("L")
    # Convert to a numpy array and normalize to [0, 1]
    patch_np = np.array(patch_gray, dtype=np.float32) / 255.0
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W)
    patch_tensor = torch.from_numpy(patch_np).unsqueeze(0).unsqueeze(0)
    return patch_tensor.to(device)

def create_overlay(original_image, mask, alpha=0.5):
    """Alpha-blend a class-colored version of the mask over the original image."""
    colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]  # One RGB color per class
    mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for i, color in enumerate(colors):
        mask_rgb[mask == i] = color
    # Resize the original image to match the mask (PIL sizes are (width, height))
    original_image = original_image.resize((mask.shape[1], mask.shape[0]))
    original_array = np.array(original_image.convert("RGB"))
    # Blend: overlay = alpha * mask + (1 - alpha) * image
    overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
    return Image.fromarray(overlay)

def predict(input_image, model_choice):
    if input_image is None:
        return None, None
    model = models[model_choice]
    patch_size = 256
    # Extract patches
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)
    # Process each patch
    predicted_patches = []
    for patch in patches:
        # Process patch
        patch_tensor = process_patch(patch, device)
        # Perform inference
        with torch.no_grad():
            output = model(patch_tensor)
        # Get prediction mask for patch
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)
    # Stitch patches back together
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)
    # Create mask image
    mask_image = Image.fromarray((full_mask * 63).astype(np.uint8))  # Scale for better visibility
    # Create overlay image
    overlay_image = create_overlay(input_image, full_mask)
    return mask_image, overlay_image

# Initialize the models once at module load (outside the inference function) for better performance
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"

w_noise_model, device = initialize_model(w_noise_model_path)
wo_noise_model, device = initialize_model(wo_noise_model_path)
w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)

models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2,
}
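
# A minimal sketch of driving the pipeline directly, without the Gradio UI
# (assumes the bundled example images below are present):
#
#     img = Image.open("./examples/image_000003.png")
#     mask, overlay = predict(img, "With Noise V2")
#     mask.save("mask.png")
#     overlay.save("overlay.png")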

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=["Without Noise", "With Noise", "With Noise V2"], value="With Noise V2"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Mask"),
        gr.Image(type="pil", label="Overlay"),
    ],
    title="MoS2 Image Segmentation",
    description="Upload an image to get the segmentation mask and overlay visualization.",
    examples=[
        ["./examples/image_000003.png", "With Noise"],
        ["./examples/image_000005.png", "Without Noise"],
    ],
)

# Launch the interface
iface.launch(share=True)
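
# Note: share=True additionally requests a temporary public gradio.live URL;
# plain iface.launch() is enough for purely local use.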