Commit 34cc7b2 · Parent(s): 16fa719
changed interface to allow for bb selection
app.py CHANGED
@@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
 from PIL import Image
 import torch.nn.functional as F
 import io
+from gradio_image_prompter import ImagePrompter
 
 def load_image(file_path):
     if file_path.endswith(".dcm"):
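A quick note on the new import: ImagePrompter comes from a separate package, not from Gradio itself. Assuming the PyPI name mirrors the import path (worth confirming against the Space's requirements.txt), the dependency would be pulled in as:

    # Assumed install step; the exact package name should be verified
    # against requirements.txt:
    #   pip install gradio-image-prompter
    from gradio_image_prompter import ImagePrompter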
@@ -38,10 +39,6 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
         masks=None,
     )
 
-    #print shapes of tensors
-    print("Shape of sparse_embeddings:", sparse_embeddings.shape)
-    print("Shape of dense_embeddings:", dense_embeddings.shape)
-
     low_res_logits, _ = medsam_model.mask_decoder(
         image_embeddings=img_embed,  # (B, 256, 64, 64)
         image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
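For orientation: this hunk sits inside medsam_inference, where mask_decoder returns low-resolution (256x256) logits. In the reference MedSAM demo those logits are upsampled to the original image size and thresholded; a minimal sketch of that post-processing step, with the function name and exact details assumed rather than taken from this file:

    import torch
    import torch.nn.functional as F

    def logits_to_mask(low_res_logits, H, W):
        # (B, 1, 256, 256) logits -> per-pixel probabilities
        probs = torch.sigmoid(low_res_logits)
        # Upsample to the original image size
        probs = F.interpolate(probs, size=(H, W), mode="bilinear", align_corners=False)
        # Threshold at 0.5 to get a binary (H, W) mask
        return (probs > 0.5).squeeze().cpu().numpy()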
@@ -70,14 +67,28 @@ def visualize(image, mask, box):
     ax[1].imshow(image, cmap='gray')
     ax[1].imshow(mask, alpha=0.5, cmap="jet")
     plt.tight_layout()
-
+
+    # Convert matplotlib figure to a PIL Image
+    buf = io.BytesIO()
+    fig.savefig(buf, format='png')
+    plt.close(fig)  # Close the figure to release memory
+    buf.seek(0)
+    pil_img = Image.open(buf)
+
+    return pil_img
 
 # Main function for Gradio app
-def process_images(file, x_min, y_min, x_max, y_max):
+def process_images(img_dict):
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
     # Load and preprocess image
-
+    img = img_dict['image']
+    points = img_dict['points'][0]  # Accessing the first (and possibly only) set of points
+    if len(points) >= 6:
+        x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
+    else:
+        raise ValueError("Insufficient data for bounding box coordinates.")
+    image, H, W = img, img.shape[0], img.shape[1]
     if len(image.shape) == 2:
         image = np.repeat(image[:, :, None], 3, axis=-1)
     H, W, _ = image.shape
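The index juggling on points follows gradio_image_prompter's payload: each row carries six numbers, and for a drawn box the corners appear to sit at indices 0-1 and 3-4 (indices 2 and 5 look like event/type flags). A self-contained sketch of that parsing; prompt_to_box and demo_dict are illustrative names, not part of the app:

    import numpy as np

    def prompt_to_box(points_row):
        # One ImagePrompter row -> [x_min, y_min, x_max, y_max]
        if len(points_row) < 6:
            raise ValueError("Insufficient data for bounding box coordinates.")
        return np.array([points_row[0], points_row[1],
                         points_row[3], points_row[4]], dtype=float)

    # Shaped like the dict Gradio hands to process_images
    demo_dict = {
        "image": np.zeros((256, 256, 3), dtype=np.uint8),
        "points": [[40.0, 50.0, 2.0, 180.0, 200.0, 3.0]],
    }
    print(prompt_to_box(demo_dict["points"][0]))  # [ 40.  50. 180. 200.]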
@@ -105,11 +116,44 @@ def process_images(file, x_min, y_min, x_max, y_max):
 
     # Visualization
     visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
-    return visualization
+    return visualization
+
+def echo(x_min, y_min, x_max, y_max):
+    print(x_min, y_min, x_max, y_max)
 
 
 # Set up Gradio interface
 iface = gr.Interface(
+    fn=process_images,
+    inputs=[
+        ImagePrompter(label="Select ROIs")  # Custom image prompter for selecting regions of interest
+    ],
+    outputs=[
+        gr.Image(type="pil", label="Processed Image"),  # Image output
+    ],
+    title="Image Processing with Custom Prompts",
+    description="Upload an image and select regions of interest for processing."
+)
+
+# Launch the interface
+iface.launch()
+
+'''iface= gr.Interface(fn=process_images,
+    inputs=[lambda prompts: (prompts["image"], prompts["points"]),
+    ImagePrompter(show_label=False)],
+    outputs="plot")'''
+
+
+
+'''iface = gr.Interface(
+    lambda prompts: (prompts["image"], prompts["points"]),
+    ImagePrompter(show_label=False),
+    [gr.Image(show_label=False), gr.Dataframe(label="Points")],
+)
+'''
+
+
+'''gr.Interface(
     fn=process_images,
     inputs=[
         gr.File(label="MRI Slice (DICOM, PNG, etc.)"),
@@ -119,6 +163,4 @@ iface = gr.Interface(
         gr.Number(label="Y max")
     ],
     outputs="plot"
-)
-
-iface.launch()
+)'''
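The triple-quoted '''...''' blocks added at the end of the file preserve earlier wiring experiments. The variant that echoes the prompter payload back as an image plus a points table is useful on its own for inspecting box rows before running inference; a self-contained sketch of it, with inspect_prompts as a stand-in name:

    import gradio as gr
    from gradio_image_prompter import ImagePrompter

    def inspect_prompts(prompts):
        # Echo the payload so the six-value box/point rows can be inspected
        return prompts["image"], prompts["points"]

    demo = gr.Interface(
        fn=inspect_prompts,
        inputs=ImagePrompter(show_label=False),
        outputs=[gr.Image(show_label=False), gr.Dataframe(label="Points")],
    )

    if __name__ == "__main__":
        demo.launch()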