Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import pydicom | |
| import os | |
| from skimage import transform | |
| import torch | |
| from segment_anything import sam_model_registry | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| import io | |
| from gradio_image_prompter import ImagePrompter | |
| import nrrd # Add this import for NRRD file support | |
def load_image(file_path):
    """Load a DICOM, NRRD, or PIL-readable image file as a NumPy array.

    The reader is chosen by file extension, matched case-insensitively:
    ``.dcm`` uses pydicom, ``.nrrd`` uses pynrrd, and anything else uses PIL.

    Args:
        file_path: path to the image file (``str`` or ``os.PathLike``).

    Returns:
        ``(img, H, W)``. ``img`` is a NumPy array whose first two axes are
        (rows, columns). 2-D arrays and ``(H, W, 1)`` arrays are replicated to
        three channels. ``H`` and ``W`` are ``img.shape[0]`` and
        ``img.shape[1]``.
    """
    path = os.fspath(file_path)
    lowered = path.lower()
    if lowered.endswith(".dcm"):
        img = pydicom.dcmread(path).pixel_array
    elif lowered.endswith(".nrrd"):
        # pynrrd defaults to Fortran index order (x, y, ...). Request C order so
        # a 2-D image comes back as (rows, columns) like the other readers.
        img, _ = nrrd.read(path, index_order="C")
    else:
        img = np.array(Image.open(path))

    # NOTE(review): multi-frame DICOM / 3-D NRRD volumes are not reduced to a
    # 2-D slice here, so shape[:2] would not be (H, W) for them. Confirm that
    # callers only pass 2-D images.
    if img.ndim == 2:
        # Grayscale (H, W): replicate to (H, W, 3).
        img = np.stack((img,) * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 1:
        # Single-channel (H, W, 1): replicate to (H, W, 3).
        img = np.repeat(img, 3, axis=-1)

    H, W = img.shape[:2]
    return img, H, W
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Predict a binary mask for one or more box prompts with MedSAM.

    Runs ``medsam_model.prompt_encoder`` on the box(es) and
    ``medsam_model.mask_decoder`` on the image embedding. It then applies a
    sigmoid, resizes the low-resolution logits bilinearly to ``(H, W)``, and
    thresholds at 0.5.

    Args:
        medsam_model: model exposing ``prompt_encoder`` (callable, with
            ``get_dense_pe()``) and ``mask_decoder``.
        img_embed: image embedding tensor from the model's image encoder. Its
            device is used for the box tensor.
        box_1024: box(es) as ``[x_min, y_min, x_max, y_max]`` in the
            1024x1024 encoder frame. Accepts a flat 4-vector or an array whose
            total size is a multiple of 4 (e.g. ``(B, 4)`` or ``(B, 1, 4)``).
        H: output mask height.
        W: output mask width.

    Returns:
        ``np.uint8`` array of 0/1 values: shape ``(H, W)`` for a single box,
        ``(B, H, W)`` for ``B > 1`` boxes.
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    # One row per box: (4,) -> (1, 4); (B, 4) or (B, 1, 4) -> (B, 4).
    # The original hard-coded reshape(1, 4) crashed for more than one box.
    box_torch = box_torch.reshape(-1, 4)

    # Inference only. Without no_grad the decoder output is attached to the
    # parameters' autograd graph, and .numpy() below raises a RuntimeError.
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_logits, _ = medsam_model.mask_decoder(
            image_embeddings=img_embed,
            image_pe=medsam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        low_res_pred = torch.sigmoid(low_res_logits)
        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )  # (B, 1, H, W)

    # Drop only the channel axis, so an H or W of 1 is preserved.
    masks = low_res_pred[:, 0].cpu().numpy()  # (B, H, W)
    medsam_seg = (masks > 0.5).astype(np.uint8)
    # Keep the historical (H, W) result for the single-box case.
    return medsam_seg[0] if medsam_seg.shape[0] == 1 else medsam_seg
# Function for visualizing images with masks
def visualize(image, mask, box):
    """Render the prompt box and the predicted mask side by side as a PIL image.

    Left panel: ``image`` with ``box`` drawn as an unfilled red rectangle.
    Right panel: ``image`` with the non-zero pixels of ``mask`` overlaid at 50%
    opacity using the ``jet`` colormap.

    Args:
        image: image array to display (grayscale or RGB).
        mask: 2-D array of the same height/width; non-zero marks the
            segmented region.
        box: ``[x_min, y_min, x_max, y_max]`` in image pixel coordinates.

    Returns:
        ``PIL.Image.Image`` decoded from an in-memory PNG of the figure.
    """
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps global
    # figure state that is not safe to share between concurrent requests in a
    # web server, and nothing needs to be registered with a GUI backend.
    from matplotlib.figure import Figure
    from matplotlib.patches import Rectangle

    fig = Figure(figsize=(10, 5), tight_layout=True)
    ax = fig.subplots(1, 2)

    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    ax[0].imshow(image, cmap='gray')
    ax[0].add_patch(Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                              edgecolor="red", facecolor="none"))

    ax[1].imshow(image, cmap='gray')
    # Hide background (0) pixels so only the segmented region is tinted.
    # Otherwise every 0 pixel is painted with the colormap's low-end colour.
    overlay = np.ma.masked_where(np.asarray(mask) == 0, mask)
    ax[1].imshow(overlay, alpha=0.5, cmap="jet", vmin=0, vmax=1)

    # Encode the figure as PNG in memory and reopen it as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img
# Main function for Gradio app

# Path to the MedSAM ViT-B checkpoint. Replace with the correct path to your checkpoint.
MODEL_CHECKPOINT_PATH = "medsam_vit_b.pth"

# Loaded models keyed by device string. The checkpoint is read from disk once
# per device instead of on every request.
_MEDSAM_MODELS = {}


def _get_medsam_model(device):
    """Return the MedSAM ViT-B model in eval mode on ``device``, loading it on first use."""
    model = _MEDSAM_MODELS.get(device)
    if model is None:
        model = sam_model_registry['vit_b'](checkpoint=MODEL_CHECKPOINT_PATH)
        model = model.to(device)
        model.eval()
        _MEDSAM_MODELS[device] = model
    return model


def _extract_box(img_dict):
    """Return ``(x_min, y_min, x_max, y_max)`` from the first prompt in ``img_dict['points']``.

    Reads indices 0, 1, 3 and 4 of the first prompt as the two box corners.
    The corners are sorted, so a box dragged in any direction yields
    ``x_min <= x_max`` and ``y_min <= y_max``.
    """
    # NOTE(review): assumes the ImagePrompter prompt layout
    # [x1, y1, _, x2, y2, _] for boxes. Confirm against gradio_image_prompter.
    prompts = img_dict.get('points') or []
    if not prompts or len(prompts[0]) < 6:
        raise ValueError("Insufficient data for bounding box coordinates.")
    prompt = prompts[0]
    x1, y1, x2, y2 = prompt[0], prompt[1], prompt[3], prompt[4]
    return min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)


def _to_rgb(image):
    """Return ``image`` as an ``(H, W, 3)`` array (replicating gray, dropping alpha)."""
    if image.ndim == 2:
        return np.repeat(image[:, :, None], 3, axis=-1)
    if image.shape[-1] == 1:
        return np.repeat(image, 3, axis=-1)
    if image.shape[-1] > 3:
        return image[..., :3]  # e.g. RGBA -> RGB; the encoder takes 3 channels
    return image


def process_images(img_dict):
    """Segment the boxed region of an ImagePrompter submission with MedSAM.

    Args:
        img_dict: ImagePrompter value. ``img_dict['image']`` holds the image and
            ``img_dict['points']`` the list of prompts; only the first prompt
            is used as the box.

    Returns:
        The PIL image produced by ``visualize``: box on the left, mask overlay
        on the right.

    Raises:
        ValueError: if no image is supplied, or if there is no prompt with at
            least 6 values.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    image = img_dict['image']
    if image is None:
        raise ValueError("No image provided.")
    x_min, y_min, x_max, y_max = _extract_box(img_dict)

    image = _to_rgb(np.asarray(image))
    H, W = image.shape[:2]

    # Resize to the 1024x1024 encoder input, then min-max normalise to [0, 1].
    # NOTE(review): astype(np.uint8) assumes 8-bit input; higher bit-depth
    # images would wrap here.
    image_resized = transform.resize(
        image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(
        image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None
    )
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    medsam_model = _get_medsam_model(device)

    # Map the box from original-image pixels to the 1024x1024 frame.
    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors

    # Inference only: no autograd graph is needed for the encoder or decoder.
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)
        mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)

    return visualize(image, mask, [x_min, y_min, x_max, y_max])
# Set up Gradio interface.
# One ImagePrompter input (an image plus drawn prompts) feeds process_images,
# and its PIL result is shown in a single output image component.
# NOTE(review): the description mentions NRRD uploads, but the input is an
# ImagePrompter and load_image() is not called from process_images. Confirm
# that NRRD files are actually accepted.
prompt_input = ImagePrompter(label="Image")
result_output = gr.Image(type="pil", label="Processed Image")

iface = gr.Interface(
    fn=process_images,
    inputs=[prompt_input],
    outputs=[result_output],
    title="ROI Selection with MEDSAM",
    description="Upload an image (including NRRD files) and select regions of interest for processing.",
)

# Launch the interface.
iface.launch()