import gradio as gr
import torch
from unet import EnhancedUNet
import numpy as np
from PIL import Image
import math


def initialize_model(model_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, device


def extract_patches(image, patch_size=256):
    """Extract non-overlapping patches from the input image."""
    width, height = image.size
    patches = []
    positions = []

    # Calculate number of patches in each dimension
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)

    # Pad image to a multiple of patch_size if necessary
    padded_width = n_cols * patch_size
    padded_height = n_rows * patch_size
    padded_image = Image.new('L', (padded_width, padded_height), 0)
    padded_image.paste(image, (0, 0))

    # Extract patches row by row
    for i in range(n_rows):
        for j in range(n_cols):
            left = j * patch_size
            top = i * patch_size
            right = left + patch_size
            bottom = top + patch_size
            patch = padded_image.crop((left, top, right, bottom))
            patches.append(patch)
            positions.append((left, top, right, bottom))

    return patches, positions, (padded_width, padded_height), (width, height)


def stitch_patches(patches, positions, padded_size, original_size):
    """Stitch predicted patches back together into a complete mask."""
    full_mask = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)

    for patch, (left, top, right, bottom) in zip(patches, positions):
        full_mask[top:bottom, left:right] = patch

    # Crop back to the original (unpadded) size
    full_mask = full_mask[:original_size[1], :original_size[0]]
    return full_mask


def process_patch(patch, device):
    # Convert to grayscale if it's not already
    patch_gray = patch.convert("L")
    # Convert to numpy array and normalize to [0, 1]
    patch_np = np.array(patch_gray, dtype=np.float32) / 255.0
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W)
    patch_tensor = torch.from_numpy(patch_np).unsqueeze(0).unsqueeze(0)
    return patch_tensor.to(device)


def create_overlay(original_image, mask, alpha=0.5):
    # One RGB color per class index
    colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
    mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for i, color in enumerate(colors):
        mask_rgb[mask == i] = color

    # Resize original image to match mask size
    original_image = original_image.resize((mask.shape[1], mask.shape[0]))
    original_array = np.array(original_image.convert("RGB"))

    # Alpha-blend the colorized mask over the original image
    overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
    return Image.fromarray(overlay)


def predict(input_image, model_choice):
    if input_image is None:
        return None, None

    model = models[model_choice]
    patch_size = 256

    # Extract patches
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)

    # Run inference on each patch
    predicted_patches = []
    for patch in patches:
        patch_tensor = process_patch(patch, device)

        with torch.no_grad():
            output = model(patch_tensor)

        # Per-pixel class prediction for this patch
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)

    # Stitch patches back together
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)

    # Create mask image; scale class indices 0-3 to 0-189 for visibility
    mask_image = Image.fromarray((full_mask * 63).astype(np.uint8))

    # Create overlay image
    overlay_image = create_overlay(input_image, full_mask)

    return mask_image, overlay_image
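
# Optional batched variant of the per-patch loop in predict(): a minimal
# sketch, not wired into the Gradio app. It assumes all patches fit in
# device memory at once; for very large images, split the batch into
# chunks instead. `predict_batched` is a hypothetical helper, not part of
# the original interface.
def predict_batched(input_image, model, patch_size=256):
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)
    # Stack the (1, 1, H, W) patch tensors into a single (N, 1, H, W) batch
    batch = torch.cat([process_patch(p, device) for p in patches], dim=0)
    with torch.no_grad():
        output = model(batch)  # (N, n_classes, H, W)
    # One argmax over the class dimension yields every patch mask at once
    masks = torch.argmax(output, dim=1).cpu().numpy()  # (N, H, W)
    return stitch_patches(list(masks), positions, padded_size, original_size)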
"./models/best_model_w_noise.pth" wo_noise_model_path = "./models/best_model_wo_noise.pth" w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth" w_noise_model, device = initialize_model(w_noise_model_path) wo_noise_model, device = initialize_model(wo_noise_model_path) w_noise_model_v2, device = initialize_model(w_noise_model_v2_path) models = { "Without Noise": wo_noise_model, "With Noise": w_noise_model, "With Noise V2": w_noise_model_v2 } # Create Gradio interface iface = gr.Interface( fn=predict, inputs=[ gr.Image(type="pil"), gr.Dropdown(choices=["Without Noise", "With Noise", "With Noise V2"], value="With Noise V2"), ], outputs=[ gr.Image(type="pil", label="Segmentation Mask"), gr.Image(type="pil", label="Overlay"), ], title="MoS2 Image Segmentation", description="Upload an image to get the segmentation mask and overlay visualization.", examples=[["./examples/image_000003.png", "With Noise"], ["./examples/image_000005.png", "Without Noise"]], ) # Launch the interface iface.launch(share=True)