rsrao1729 commited on
Commit
7d61c0e
·
1 Parent(s): 1e53d3d

Uploaded 2 files

Browse files
Files changed (2) hide show
  1. automatic_mask_generator.py +173 -0
  2. requirements.txt +9 -0
automatic_mask_generator.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ from streamlit_image_coordinates import streamlit_image_coordinates
5
+ import streamlit as st
6
+ from PIL import Image
7
+ from transformers import SamModel, SamProcessor
8
+ import cv2
9
+
10
+
11
+
12
+ # Define global constants
13
+ MAX_WIDTH = 700
14
+
15
+
16
+
17
+ # Define helpful functions
18
+ def show_anns(anns):
19
+ if len(anns) == 0:
20
+ return
21
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
22
+ ax = plt.gca()
23
+ ax.set_autoscale_on(False)
24
+
25
+ img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
26
+ img[:,:,3] = 0
27
+ for ann in sorted_anns:
28
+ m = ann['segmentation']
29
+ color_mask = np.concatenate([np.random.random(3), [0.35]])
30
+ img[m] = color_mask
31
+ ax.imshow(img)
32
+
33
+ def show_mask(mask, ax, random_color=False):
34
+ if random_color:
35
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
36
+ else:
37
+ color = np.array([30/255, 144/255, 255/255, 0.6])
38
+ h, w = mask.shape[-2:]
39
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
40
+ ax.imshow(mask_image)
41
+
42
+ def show_points(coords, labels, ax, marker_size=20):
43
+ pos_points = coords[labels==1]
44
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='.', s=marker_size, edgecolor='white', linewidth=0.2)
45
+
46
+ def show_masks_on_image(raw_image, masks, scores):
47
+ if len(masks.shape) == 4:
48
+ masks = masks.squeeze()
49
+ if scores.shape[0] == 1:
50
+ scores = scores.squeeze()
51
+
52
+ nb_predictions = scores.shape[-1]
53
+ fig, ax = plt.subplots(1, nb_predictions)
54
+
55
+ for i, (mask, score) in enumerate(zip(masks, scores)):
56
+ mask = mask.cpu().detach()
57
+ ax[i].imshow(np.array(raw_image))
58
+ show_mask(mask, ax[i])
59
+ ax[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
60
+ ax[i].axis("off")
61
+
62
+ def show_points_on_image(raw_image, input_point, ax, input_labels=None):
63
+ ax.imshow(raw_image)
64
+ input_point = np.array(input_point)
65
+ if input_labels is None:
66
+ labels = np.ones_like(input_point[:, 0])
67
+ else:
68
+ labels = np.array(input_labels)
69
+ show_points(input_point, labels, ax)
70
+ ax.axis('on')
71
+
72
+
73
+
74
+
75
+ # Get SAM
76
+ if torch.cuda.is_available():
77
+ device = 'cuda'
78
+ else:
79
+ device = 'cpu'
80
+ model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
81
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
82
+
83
+
84
+
85
+ # Get uploaded files from user
86
+ scale = st.file_uploader('Upload Scale Image')
87
+ image = st.file_uploader('Upload Particle Image')
88
+
89
+
90
+
91
+ # Runs when scale image is uploaded
92
+ if scale:
93
+ scale_np = np.asarray(bytearray(scale.read()), dtype=np.uint8)
94
+ scale_np = cv2.imdecode(scale_np, 1)
95
+
96
+ #inputs = processor(raw_image, return_tensors="pt").to(device)
97
+ inputs = processor(scale_np, return_tensors="pt").to(device)
98
+ image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
99
+
100
+ scale_factor = scale_np.shape[1] / MAX_WIDTH # how many times larger scale_np is than the image shown for each dimension
101
+ clicked_point = streamlit_image_coordinates(Image.open(scale.name), height=scale_np.shape[0] // scale_factor, width=MAX_WIDTH)
102
+ if clicked_point:
103
+ input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
104
+ input_point_list = [input_point_np.astype(int).tolist()]
105
+
106
+ #inputs = processor(raw_image, input_points=input_point, return_tensors="pt").to(device)
107
+ inputs = processor(scale_np, input_points=input_point_list, return_tensors="pt").to(device)
108
+ inputs.pop("pixel_values", None)
109
+ inputs.update({"image_embeddings": image_embeddings})
110
+ with torch.no_grad():
111
+ outputs = model(**inputs)
112
+ masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
113
+ mask = torch.squeeze(masks[0])[0] # mask.shape: (1,x,y) --> (x,y)
114
+
115
+ mask = mask.to(torch.int)
116
+ input_label = np.array([1])
117
+
118
+ fig, ax = plt.subplots()
119
+ ax.imshow(scale_np)
120
+ show_mask(mask, ax)
121
+ #show_points_on_image(scale_np, input_point, input_label, ax)
122
+ show_points(input_point_np, input_label, ax)
123
+ ax.axis('off')
124
+ st.pyplot(fig)
125
+
126
+
127
+
128
+ # Get pixels per millimeter
129
+ pixels_per_unit = torch.sum(mask, axis=1)
130
+ pixels_per_unit = pixels_per_unit[pixels_per_unit > 0]
131
+ pixels_per_unit = torch.mean(pixels_per_unit, dtype=torch.float).item()
132
+
133
+
134
+
135
+ # Runs when image is uploaded
136
+ if image:
137
+ image_np = np.asarray(bytearray(image.read()), dtype=np.uint8)
138
+ image_np = cv2.imdecode(image_np, 1)
139
+
140
+ #inputs = processor(raw_image, return_tensors="pt").to(device)
141
+ inputs = processor(image_np, return_tensors="pt").to(device)
142
+ image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
143
+
144
+ scale_factor = image_np.shape[1] / MAX_WIDTH # how many times larger scale_np is than the image shown for each dimension
145
+ clicked_point = streamlit_image_coordinates(Image.open(image.name), height=image_np.shape[0] // scale_factor, width=MAX_WIDTH)
146
+ if clicked_point:
147
+ input_point_np = np.array([[clicked_point['x'], clicked_point['y']]]) * scale_factor
148
+ input_point_list = [input_point_np.astype(int).tolist()]
149
+
150
+ #inputs = processor(raw_image, input_points=input_point, return_tensors="pt").to(device)
151
+ inputs = processor(image_np, input_points=input_point_list, return_tensors="pt").to(device)
152
+ inputs.pop("pixel_values", None)
153
+ inputs.update({"image_embeddings": image_embeddings})
154
+ with torch.no_grad():
155
+ outputs = model(**inputs)
156
+ masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
157
+ mask = torch.squeeze(masks[0])[0] # mask.shape: (1,x,y) --> (x,y)
158
+
159
+ mask = mask.to(torch.int)
160
+ input_label = np.array([1])
161
+
162
+ fig, ax = plt.subplots()
163
+ ax.imshow(image_np)
164
+ show_mask(mask, ax)
165
+ #show_points_on_image(scale_np, input_point, input_label, ax)
166
+ show_points(input_point_np, input_label, ax)
167
+ ax.axis('off')
168
+ st.pyplot(fig)
169
+
170
+
171
+
172
+ # Get the area in square millimeters
173
+ st.write(f'Area: {torch.sum(mask, dtype=torch.float).item() / pixels_per_unit ** 2} mm^2')
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.24.2
2
+ torch==1.13.1
3
+ torchvision==0.14.1
4
+ matplotlib==3.7.0
5
+ streamlit_image_coordinates==0.1.3
6
+ streamlit==1.22.0
7
+ Pillow==9.4.0
8
+ transformers==4.29.1
9
+ opencv-python-headless==4.5.4.60