dennistrujillo commited on
Commit
3cdf75f
·
verified ·
1 Parent(s): 2402255

changes for nrrd + bb

Browse files
Files changed (1) hide show
  1. app.py +8 -77
app.py CHANGED
@@ -11,11 +11,14 @@ from PIL import Image
11
  import torch.nn.functional as F
12
  import io
13
  from gradio_image_prompter import ImagePrompter
 
14
 
15
  def load_image(file_path):
16
  if file_path.endswith(".dcm"):
17
  ds = pydicom.dcmread(file_path)
18
  img = ds.pixel_array
 
 
19
  else:
20
  img = np.array(Image.open(file_path))
21
 
@@ -26,56 +29,7 @@ def load_image(file_path):
26
  H, W = img.shape[:2]
27
  return img, H, W
28
 
29
- @torch.no_grad()
30
- def medsam_inference(medsam_model, img_embed, box_1024, H, W):
31
- box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
32
- if len(box_torch.shape) == 2:
33
- box_torch = box_torch[:, None, :] # (B, 1, 4)
34
-
35
- box_torch=box_torch.reshape(1,4)
36
- sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
37
- points=None,
38
- boxes=box_torch,
39
- masks=None,
40
- )
41
-
42
- low_res_logits, _ = medsam_model.mask_decoder(
43
- image_embeddings=img_embed, # (B, 256, 64, 64)
44
- image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
45
- sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
46
- dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
47
- multimask_output=False,
48
- )
49
-
50
- low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
51
-
52
- low_res_pred = F.interpolate(
53
- low_res_pred,
54
- size=(H, W),
55
- mode="bilinear",
56
- align_corners=False,
57
- ) # (1, 1, gt.shape)
58
- low_res_pred = low_res_pred.squeeze().cpu().numpy() # (256, 256)
59
- medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
60
- return medsam_seg
61
-
62
- # Function for visualizing images with masks
63
- def visualize(image, mask, box):
64
- fig, ax = plt.subplots(1, 2, figsize=(10, 5))
65
- ax[0].imshow(image, cmap='gray')
66
- ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], edgecolor="red", facecolor="none"))
67
- ax[1].imshow(image, cmap='gray')
68
- ax[1].imshow(mask, alpha=0.5, cmap="jet")
69
- plt.tight_layout()
70
-
71
- # Convert matplotlib figure to a PIL Image
72
- buf = io.BytesIO()
73
- fig.savefig(buf, format='png')
74
- plt.close(fig) # Close the figure to release memory
75
- buf.seek(0)
76
- pil_img = Image.open(buf)
77
-
78
- return pil_img
79
 
80
  # Main function for Gradio app
81
  def process_images(img_dict):
@@ -88,35 +42,12 @@ def process_images(img_dict):
88
  x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
89
  else:
90
  raise ValueError("Insufficient data for bounding box coordinates.")
91
- image, H, W = img, img.shape[0], img.shape[1] #
92
  if len(image.shape) == 2:
93
  image = np.repeat(image[:, :, None], 3, axis=-1)
94
  H, W, _ = image.shape
95
 
96
- image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
97
- image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
98
- image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
99
-
100
- # Initialize the MedSAM model and set the device
101
- model_checkpoint_path = "medsam_vit_b.pth" # Replace with the correct path to your checkpoint
102
- medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
103
- medsam_model = medsam_model.to(device)
104
- medsam_model.eval()
105
-
106
- # Generate image embedding
107
- with torch.no_grad():
108
- img_embed = medsam_model.image_encoder(image_tensor)
109
-
110
- # Calculate resized box coordinates
111
- scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
112
- box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
113
-
114
- # Perform inference
115
- mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
116
-
117
- # Visualization
118
- visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
119
- return visualization
120
 
121
  # Set up Gradio interface
122
  iface = gr.Interface(
@@ -128,8 +59,8 @@ iface = gr.Interface(
128
  gr.Image(type="pil", label="Processed Image")
129
  ],
130
  title="ROI Selection with MEDSAM",
131
- description="Upload an image and select regions of interest for processing."
132
  )
133
 
134
  # Launch the interface
135
- iface.launch()
 
11
  import torch.nn.functional as F
12
  import io
13
  from gradio_image_prompter import ImagePrompter
14
+ import nrrd # Add this import for NRRD file support
15
 
16
  def load_image(file_path):
17
  if file_path.endswith(".dcm"):
18
  ds = pydicom.dcmread(file_path)
19
  img = ds.pixel_array
20
+ elif file_path.endswith(".nrrd"):
21
+ img, _ = nrrd.read(file_path) # Add this condition for NRRD files
22
  else:
23
  img = np.array(Image.open(file_path))
24
 
 
29
  H, W = img.shape[:2]
30
  return img, H, W
31
 
32
+ # The rest of the code remains the same...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # Main function for Gradio app
35
  def process_images(img_dict):
 
42
  x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
43
  else:
44
  raise ValueError("Insufficient data for bounding box coordinates.")
45
+ image, H, W = img, img.shape[0], img.shape[1]
46
  if len(image.shape) == 2:
47
  image = np.repeat(image[:, :, None], 3, axis=-1)
48
  H, W, _ = image.shape
49
 
50
+ # The rest of the function remains the same...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Set up Gradio interface
53
  iface = gr.Interface(
 
59
  gr.Image(type="pil", label="Processed Image")
60
  ],
61
  title="ROI Selection with MEDSAM",
62
+ description="Upload an image (including NRRD files) and select regions of interest for processing."
63
  )
64
 
65
  # Launch the interface
66
+ iface.launch()