dennistrujillo commited on
Commit
34cc7b2
·
1 Parent(s): 16fa719

changed interface to allow for bb selection

Browse files
Files changed (1) hide show
  1. app.py +53 -11
app.py CHANGED
@@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
10
  from PIL import Image
11
  import torch.nn.functional as F
12
  import io
 
13
 
14
  def load_image(file_path):
15
  if file_path.endswith(".dcm"):
@@ -38,10 +39,6 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
38
  masks=None,
39
  )
40
 
41
- #print shapes of tensors
42
- print("Shape of sparse_embeddings:", sparse_embeddings.shape)
43
- print("Shape of dense_embeddings:", dense_embeddings.shape)
44
-
45
  low_res_logits, _ = medsam_model.mask_decoder(
46
  image_embeddings=img_embed, # (B, 256, 64, 64)
47
  image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
@@ -70,14 +67,28 @@ def visualize(image, mask, box):
70
  ax[1].imshow(image, cmap='gray')
71
  ax[1].imshow(mask, alpha=0.5, cmap="jet")
72
  plt.tight_layout()
73
- return fig
 
 
 
 
 
 
 
 
74
 
75
  # Main function for Gradio app
76
- def process_images(file, x_min, y_min, x_max, y_max):
77
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
78
 
79
  # Load and preprocess image
80
- image, H, W = load_image(file)
 
 
 
 
 
 
81
  if len(image.shape) == 2:
82
  image = np.repeat(image[:, :, None], 3, axis=-1)
83
  H, W, _ = image.shape
@@ -105,11 +116,44 @@ def process_images(file, x_min, y_min, x_max, y_max):
105
 
106
  # Visualization
107
  visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
108
- return visualization #.getvalue()
 
 
 
109
 
110
 
111
  # Set up Gradio interface
112
  iface = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  fn=process_images,
114
  inputs=[
115
  gr.File(label="MRI Slice (DICOM, PNG, etc.)"),
@@ -119,6 +163,4 @@ iface = gr.Interface(
119
  gr.Number(label="Y max")
120
  ],
121
  outputs="plot"
122
- )
123
-
124
- iface.launch()
 
10
  from PIL import Image
11
  import torch.nn.functional as F
12
  import io
13
+ from gradio_image_prompter import ImagePrompter
14
 
15
  def load_image(file_path):
16
  if file_path.endswith(".dcm"):
 
39
  masks=None,
40
  )
41
 
 
 
 
 
42
  low_res_logits, _ = medsam_model.mask_decoder(
43
  image_embeddings=img_embed, # (B, 256, 64, 64)
44
  image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
 
67
  ax[1].imshow(image, cmap='gray')
68
  ax[1].imshow(mask, alpha=0.5, cmap="jet")
69
  plt.tight_layout()
70
+
71
+ # Convert matplotlib figure to a PIL Image
72
+ buf = io.BytesIO()
73
+ fig.savefig(buf, format='png')
74
+ plt.close(fig) # Close the figure to release memory
75
+ buf.seek(0)
76
+ pil_img = Image.open(buf)
77
+
78
+ return pil_img
79
 
80
  # Main function for Gradio app
81
+ def process_images(img_dict):
82
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
83
 
84
  # Load and preprocess image
85
+ img = img_dict['image']
86
+ points = img_dict['points'][0] # Accessing the first (and possibly only) set of points
87
+ if len(points) >= 6:
88
+ x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
89
+ else:
90
+ raise ValueError("Insufficient data for bounding box coordinates.")
91
+ image, H, W = img, img.shape[0], img.shape[1] #
92
  if len(image.shape) == 2:
93
  image = np.repeat(image[:, :, None], 3, axis=-1)
94
  H, W, _ = image.shape
 
116
 
117
  # Visualization
118
  visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
119
+ return visualization
120
+
121
+ def echo(x_min, y_min, x_max, y_max):
122
+ print(x_min, y_min, x_max, y_max)
123
 
124
 
125
  # Set up Gradio interface
126
  iface = gr.Interface(
127
+ fn=process_images,
128
+ inputs=[
129
+ ImagePrompter(label="Select ROIs") # Custom image prompter for selecting regions of interest
130
+ ],
131
+ outputs=[
132
+ gr.Image(type="pil", label="Processed Image"), # Image output
133
+ ],
134
+ title="Image Processing with Custom Prompts",
135
+ description="Upload an image and select regions of interest for processing."
136
+ )
137
+
138
+ # Launch the interface
139
+ iface.launch()
140
+
141
+ '''iface= gr.Interface(fn=process_images,
142
+ inputs=[lambda prompts: (prompts["image"], prompts["points"]),
143
+ ImagePrompter(show_label=False)],
144
+ outputs="plot")'''
145
+
146
+
147
+
148
+ '''iface = gr.Interface(
149
+ lambda prompts: (prompts["image"], prompts["points"]),
150
+ ImagePrompter(show_label=False),
151
+ [gr.Image(show_label=False), gr.Dataframe(label="Points")],
152
+ )
153
+ '''
154
+
155
+
156
+ '''gr.Interface(
157
  fn=process_images,
158
  inputs=[
159
  gr.File(label="MRI Slice (DICOM, PNG, etc.)"),
 
163
  gr.Number(label="Y max")
164
  ],
165
  outputs="plot"
166
+ )'''