# MedSAMTest / app.py
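"""Gradio demo: interactive MedSAM segmentation.

The user uploads an image and draws a bounding box with ImagePrompter; the
app runs MedSAM (SAM ViT-B fine-tuned for medical imaging) to segment the
region inside the box and returns a side-by-side visualization.
"""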
import gradio as gr
import numpy as np
import pydicom
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
from gradio_image_prompter import ImagePrompter


def load_image(file_path):
    """Load a DICOM or standard image file as an (H, W, 3) array."""
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    else:
        img = np.array(Image.open(file_path))
    # Convert grayscale to 3-channel RGB by replicating the single channel
    if len(img.shape) == 2:  # Grayscale image (height, width)
        img = np.stack((img,) * 3, axis=-1)  # (height, width, 3)
    H, W = img.shape[:2]
    return img, H, W
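# load_image is a standalone helper (it handles DICOM via pydicom) and is not
# called by the Gradio pipeline below, which receives the image directly from
# ImagePrompter. A hypothetical direct use:
#   img, H, W = load_image("slice.dcm")  # "slice.dcm" is an example path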


@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Run the MedSAM prompt encoder + mask decoder for one bounding box."""
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)
    box_torch = box_torch.reshape(1, 4)  # single box prompt

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )
    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
    # Upsample the low-resolution prediction back to the original image size
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)  # binarize at 0.5
    return medsam_seg
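# Sketch of a standalone call, assuming `medsam_model`, an embedding
# `img_embed` produced from a (1, 3, 1024, 1024) input, and box corners in
# original image coordinates (all names as used in process_images below):
#   scale = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
#   box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale
#   mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)  # (H, W) uint8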


# Function for visualizing images with masks
def visualize(image, mask, box):
    """Show the prompt box (left) and the predicted mask overlay (right)."""
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap='gray')
    ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                                  edgecolor="red", facecolor="none"))
    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()

    # Convert the matplotlib figure to a PIL Image
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # Close the figure to release memory
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img
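# Example use (mirrors the call in process_images): the returned PIL image can
# be displayed by a gr.Image output or saved to disk:
#   overlay = visualize(image, mask, [x_min, y_min, x_max, y_max])
#   overlay.save("overlay.png")  # hypothetical output path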


# Main function for the Gradio app
def process_images(img_dict):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the image and the box prompt from the ImagePrompter dict
    img = img_dict['image']
    points = img_dict['points'][0]  # First (and possibly only) prompt
    # ImagePrompter encodes a box as [x1, y1, type, x2, y2, type]
    if len(points) >= 6:
        x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
    else:
        raise ValueError("Insufficient data for bounding box coordinates.")

    image = img
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, None], 3, axis=-1)
    H, W, _ = image.shape

    # Resize to the 1024x1024 encoder input and normalize to [0, 1]
    image_resized = transform.resize(image, (1024, 1024), order=3,
                                     preserve_range=True, anti_aliasing=True).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(
        image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Initialize the MedSAM model and set the device (note: the checkpoint is
    # reloaded on every call; loading it once at module level would be faster)
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    # Generate the image embedding
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)

    # Scale the box from original image coordinates to the 1024x1024 input
    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors

    # Perform inference
    mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)

    # Visualization
    visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
    return visualization


def echo(x_min, y_min, x_max, y_max):
    # Debug helper; not wired into the interface below
    print(x_min, y_min, x_max, y_max)


# Set up the Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Image")
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image")
    ],
    title="ROI Selection with MEDSAM",
    description="Upload an image and select regions of interest for processing.",
)

# Launch the interface
iface.launch()
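# launch() serves the app locally (and is started automatically on Hugging
# Face Spaces). For a temporary public link when running elsewhere, Gradio's
# standard share option could be used instead: iface.launch(share=True)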