Spaces:
Runtime error
Runtime error
Uploaded 2 files
Browse files- automatic_mask_generator.py +173 -0
- requirements.txt +9 -0
automatic_mask_generator.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from streamlit_image_coordinates import streamlit_image_coordinates
|
5 |
+
import streamlit as st
|
6 |
+
from PIL import Image
|
7 |
+
from transformers import SamModel, SamProcessor
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
# Define global constants
|
13 |
+
MAX_WIDTH = 700
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
# Define helpful functions
|
18 |
+
def show_anns(anns):
|
19 |
+
if len(anns) == 0:
|
20 |
+
return
|
21 |
+
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
22 |
+
ax = plt.gca()
|
23 |
+
ax.set_autoscale_on(False)
|
24 |
+
|
25 |
+
img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
|
26 |
+
img[:,:,3] = 0
|
27 |
+
for ann in sorted_anns:
|
28 |
+
m = ann['segmentation']
|
29 |
+
color_mask = np.concatenate([np.random.random(3), [0.35]])
|
30 |
+
img[m] = color_mask
|
31 |
+
ax.imshow(img)
|
32 |
+
|
33 |
+
def show_mask(mask, ax, random_color=False):
|
34 |
+
if random_color:
|
35 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
36 |
+
else:
|
37 |
+
color = np.array([30/255, 144/255, 255/255, 0.6])
|
38 |
+
h, w = mask.shape[-2:]
|
39 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
40 |
+
ax.imshow(mask_image)
|
41 |
+
|
42 |
+
def show_points(coords, labels, ax, marker_size=20):
|
43 |
+
pos_points = coords[labels==1]
|
44 |
+
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='.', s=marker_size, edgecolor='white', linewidth=0.2)
|
45 |
+
|
46 |
+
def show_masks_on_image(raw_image, masks, scores):
|
47 |
+
if len(masks.shape) == 4:
|
48 |
+
masks = masks.squeeze()
|
49 |
+
if scores.shape[0] == 1:
|
50 |
+
scores = scores.squeeze()
|
51 |
+
|
52 |
+
nb_predictions = scores.shape[-1]
|
53 |
+
fig, ax = plt.subplots(1, nb_predictions)
|
54 |
+
|
55 |
+
for i, (mask, score) in enumerate(zip(masks, scores)):
|
56 |
+
mask = mask.cpu().detach()
|
57 |
+
ax[i].imshow(np.array(raw_image))
|
58 |
+
show_mask(mask, ax[i])
|
59 |
+
ax[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
|
60 |
+
ax[i].axis("off")
|
61 |
+
|
62 |
+
def show_points_on_image(raw_image, input_point, ax, input_labels=None):
|
63 |
+
ax.imshow(raw_image)
|
64 |
+
input_point = np.array(input_point)
|
65 |
+
if input_labels is None:
|
66 |
+
labels = np.ones_like(input_point[:, 0])
|
67 |
+
else:
|
68 |
+
labels = np.array(input_labels)
|
69 |
+
show_points(input_point, labels, ax)
|
70 |
+
ax.axis('on')
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
# Get SAM
|
76 |
+
if torch.cuda.is_available():
|
77 |
+
device = 'cuda'
|
78 |
+
else:
|
79 |
+
device = 'cpu'
|
80 |
+
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
81 |
+
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
# Get uploaded files from user
|
86 |
+
scale = st.file_uploader('Upload Scale Image')
|
87 |
+
image = st.file_uploader('Upload Particle Image')
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
# Runs when scale image is uploaded
|
92 |
+
if scale:
|
93 |
+
scale_np = np.asarray(bytearray(scale.read()), dtype=np.uint8)
|
94 |
+
scale_np = cv2.imdecode(scale_np, 1)
|
95 |
+
|
96 |
+
#inputs = processor(raw_image, return_tensors="pt").to(device)
|
97 |
+
inputs = processor(scale_np, return_tensors="pt").to(device)
|
98 |
+
image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
|
99 |
+
|
100 |
+
scale_factor = scale_np.shape[1] / MAX_WIDTH # how many times larger scale_np is than the image shown for each dimension
|
101 |
+
clicked_point = streamlit_image_coordinates(Image.open(scale.name), height=scale_np.shape[0] // scale_factor, width=MAX_WIDTH)
|
102 |
+
if clicked_point:
|
103 |
+
input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
|
104 |
+
input_point_list = [input_point_np.astype(int).tolist()]
|
105 |
+
|
106 |
+
#inputs = processor(raw_image, input_points=input_point, return_tensors="pt").to(device)
|
107 |
+
inputs = processor(scale_np, input_points=input_point_list, return_tensors="pt").to(device)
|
108 |
+
inputs.pop("pixel_values", None)
|
109 |
+
inputs.update({"image_embeddings": image_embeddings})
|
110 |
+
with torch.no_grad():
|
111 |
+
outputs = model(**inputs)
|
112 |
+
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
|
113 |
+
mask = torch.squeeze(masks[0])[0] # mask.shape: (1,x,y) --> (x,y)
|
114 |
+
|
115 |
+
mask = mask.to(torch.int)
|
116 |
+
input_label = np.array([1])
|
117 |
+
|
118 |
+
fig, ax = plt.subplots()
|
119 |
+
ax.imshow(scale_np)
|
120 |
+
show_mask(mask, ax)
|
121 |
+
#show_points_on_image(scale_np, input_point, input_label, ax)
|
122 |
+
show_points(input_point_np, input_label, ax)
|
123 |
+
ax.axis('off')
|
124 |
+
st.pyplot(fig)
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
# Get pixels per millimeter
|
129 |
+
pixels_per_unit = torch.sum(mask, axis=1)
|
130 |
+
pixels_per_unit = pixels_per_unit[pixels_per_unit > 0]
|
131 |
+
pixels_per_unit = torch.mean(pixels_per_unit, dtype=torch.float).item()
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
# Runs when image is uploaded
|
136 |
+
if image:
|
137 |
+
image_np = np.asarray(bytearray(image.read()), dtype=np.uint8)
|
138 |
+
image_np = cv2.imdecode(image_np, 1)
|
139 |
+
|
140 |
+
#inputs = processor(raw_image, return_tensors="pt").to(device)
|
141 |
+
inputs = processor(image_np, return_tensors="pt").to(device)
|
142 |
+
image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
|
143 |
+
|
144 |
+
scale_factor = image_np.shape[1] / MAX_WIDTH # how many times larger scale_np is than the image shown for each dimension
|
145 |
+
clicked_point = streamlit_image_coordinates(Image.open(image.name), height=image_np.shape[0] // scale_factor, width=MAX_WIDTH)
|
146 |
+
if clicked_point:
|
147 |
+
input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
|
148 |
+
input_point_list = [input_point_np.astype(int).tolist()]
|
149 |
+
|
150 |
+
#inputs = processor(raw_image, input_points=input_point, return_tensors="pt").to(device)
|
151 |
+
inputs = processor(image_np, input_points=input_point_list, return_tensors="pt").to(device)
|
152 |
+
inputs.pop("pixel_values", None)
|
153 |
+
inputs.update({"image_embeddings": image_embeddings})
|
154 |
+
with torch.no_grad():
|
155 |
+
outputs = model(**inputs)
|
156 |
+
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
|
157 |
+
mask = torch.squeeze(masks[0])[0] # mask.shape: (1,x,y) --> (x,y)
|
158 |
+
|
159 |
+
mask = mask.to(torch.int)
|
160 |
+
input_label = np.array([1])
|
161 |
+
|
162 |
+
fig, ax = plt.subplots()
|
163 |
+
ax.imshow(image_np)
|
164 |
+
show_mask(mask, ax)
|
165 |
+
#show_points_on_image(scale_np, input_point, input_label, ax)
|
166 |
+
show_points(input_point_np, input_label, ax)
|
167 |
+
ax.axis('off')
|
168 |
+
st.pyplot(fig)
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
# Get the area in square millimeters
|
173 |
+
st.write(f'Area: {torch.sum(mask, dtype=torch.float).item() / pixels_per_unit ** 2} mm^2')
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.24.2
|
2 |
+
torch==1.13.1
|
3 |
+
torchvision==0.14.1
|
4 |
+
matplotlib==3.7.0
|
5 |
+
streamlit_image_coordinates==0.1.3
|
6 |
+
streamlit==1.22.0
|
7 |
+
Pillow==9.4.0
|
8 |
+
transformers==4.29.1
|
9 |
+
opencv-python-headless==4.5.4.60
|