yfan07 committed on
Commit 51e0ad4 · verified · 1 Parent(s): 34283aa

Add files using upload-large-folder tool
data/audio_embed.tar.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7cfb23c805a62df237904736403226b60329075a3147d2f93fcc78f2a163843
+ size 17678459
data/gt_mask.tar.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2fd1d8d37ceea45d0b8b7714c06afa0218710061a7da158652617d9107e197a
+ size 203344455
data/image_embed.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b0f5c8ae133bbddbfa558b2052b3aeb757492ffe310650988103d07e24135bb
+ size 167486740480
data/media.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e4d2f70333148331b736f2e65eea4d0d8e38a030fc70b49bee90b63ed5ba7f5
+ size 19026513920
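
Each data entry above is a Git LFS pointer rather than the archive itself: `oid` is the SHA-256 digest of the real file and `size` is its byte length. As a hedged sketch (digest and size taken from the pointer above; the local path is an assumption), a downloaded archive can be checked against its pointer:

```python
import hashlib

def verify_lfs_object(path: str, expected_oid: str, expected_size: int) -> bool:
    """Stream the file and compare its SHA-256 digest and size to the LFS pointer."""
    h = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks
            h.update(chunk)
            size += len(chunk)
    return h.hexdigest() == expected_oid and size == expected_size

# Example with the audio_embed.tar.gz pointer above (the path is hypothetical):
print(verify_lfs_object(
    "data/audio_embed.tar.gz",
    "d7cfb23c805a62df237904736403226b60329075a3147d2f93fcc78f2a163843",
    17678459,
))
```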
models/segment_anything/utils/amg.py ADDED
@@ -0,0 +1,346 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from copy import deepcopy
+ from itertools import product
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+ import numpy as np
+ import torch
+
+
+ class MaskData:
+     """
+     A structure for storing masks and their related data in batched format.
+     Implements basic filtering and concatenation.
+     """
+
+     def __init__(self, **kwargs) -> None:
+         for v in kwargs.values():
+             assert isinstance(
+                 v, (list, np.ndarray, torch.Tensor)
+             ), "MaskData only supports list, numpy arrays, and torch tensors."
+         self._stats = dict(**kwargs)
+
+     def __setitem__(self, key: str, item: Any) -> None:
+         assert isinstance(
+             item, (list, np.ndarray, torch.Tensor)
+         ), "MaskData only supports list, numpy arrays, and torch tensors."
+         self._stats[key] = item
+
+     def __delitem__(self, key: str) -> None:
+         del self._stats[key]
+
+     def __getitem__(self, key: str) -> Any:
+         return self._stats[key]
+
+     def items(self) -> ItemsView[str, Any]:
+         return self._stats.items()
+
+     def filter(self, keep: torch.Tensor) -> None:
+         for k, v in self._stats.items():
+             if v is None:
+                 self._stats[k] = None
+             elif isinstance(v, torch.Tensor):
+                 self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+             elif isinstance(v, np.ndarray):
+                 self._stats[k] = v[keep.detach().cpu().numpy()]
+             elif isinstance(v, list) and keep.dtype == torch.bool:
+                 self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+             elif isinstance(v, list):
+                 self._stats[k] = [v[i] for i in keep]
+             else:
+                 raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+     def cat(self, new_stats: "MaskData") -> None:
+         for k, v in new_stats.items():
+             if k not in self._stats or self._stats[k] is None:
+                 self._stats[k] = deepcopy(v)
+             elif isinstance(v, torch.Tensor):
+                 self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+             elif isinstance(v, np.ndarray):
+                 self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+             elif isinstance(v, list):
+                 self._stats[k] = self._stats[k] + deepcopy(v)
+             else:
+                 raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+     def to_numpy(self) -> None:
+         for k, v in self._stats.items():
+             if isinstance(v, torch.Tensor):
+                 self._stats[k] = v.detach().cpu().numpy()
+
+
+ def is_box_near_crop_edge(
+     boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+ ) -> torch.Tensor:
+     """Filter masks at the edge of a crop, but not at the edge of the original image."""
+     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+     boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+     near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+     near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+     near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+     return torch.any(near_crop_edge, dim=1)
+
+
+ def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+     box_xywh = deepcopy(box_xyxy)
+     box_xywh[2] = box_xywh[2] - box_xywh[0]
+     box_xywh[3] = box_xywh[3] - box_xywh[1]
+     return box_xywh
+
+
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+     assert len(args) > 0 and all(
+         len(a) == len(args[0]) for a in args
+     ), "Batched iteration must have inputs of all the same size."
+     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+     for b in range(n_batches):
+         yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+     """
+     Encodes masks to an uncompressed RLE, in the format expected by
+     pycoco tools.
+     """
+     # Put in fortran order and flatten h,w
+     b, h, w = tensor.shape
+     tensor = tensor.permute(0, 2, 1).flatten(1)
+
+     # Compute change indices
+     diff = tensor[:, 1:] ^ tensor[:, :-1]
+     change_indices = diff.nonzero()
+
+     # Encode run length
+     out = []
+     for i in range(b):
+         cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+         cur_idxs = torch.cat(
+             [
+                 torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+                 cur_idxs + 1,
+                 torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+             ]
+         )
+         btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+         counts = [] if tensor[i, 0] == 0 else [0]
+         counts.extend(btw_idxs.detach().cpu().tolist())
+         out.append({"size": [h, w], "counts": counts})
+     return out
+
+
+ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+     """Compute a binary mask from an uncompressed RLE."""
+     h, w = rle["size"]
+     mask = np.empty(h * w, dtype=bool)
+     idx = 0
+     parity = False
+     for count in rle["counts"]:
+         mask[idx : idx + count] = parity
+         idx += count
+         parity ^= True
+     mask = mask.reshape(w, h)
+     return mask.transpose()  # Put in C order
+
+
+ def area_from_rle(rle: Dict[str, Any]) -> int:
+     return sum(rle["counts"][1::2])
+
+
+ def calculate_stability_score(
+     masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+ ) -> torch.Tensor:
+     """
+     Computes the stability score for a batch of masks. The stability
+     score is the IoU between the binary masks obtained by thresholding
+     the predicted mask logits at high and low values.
+     """
+     # One mask is always contained inside the other.
+     # Save memory by preventing unnecessary cast to torch.int64
+     intersections = (
+         (masks > (mask_threshold + threshold_offset))
+         .sum(-1, dtype=torch.int16)
+         .sum(-1, dtype=torch.int32)
+     )
+     unions = (
+         (masks > (mask_threshold - threshold_offset))
+         .sum(-1, dtype=torch.int16)
+         .sum(-1, dtype=torch.int32)
+     )
+     return intersections / unions
+
+
+ def build_point_grid(n_per_side: int) -> np.ndarray:
+     """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+     offset = 1 / (2 * n_per_side)
+     points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+     points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+     points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+     points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+     return points
+
+
+ def build_all_layer_point_grids(
+     n_per_side: int, n_layers: int, scale_per_layer: int
+ ) -> List[np.ndarray]:
+     """Generates point grids for all crop layers."""
+     points_by_layer = []
+     for i in range(n_layers + 1):
+         n_points = int(n_per_side / (scale_per_layer**i))
+         points_by_layer.append(build_point_grid(n_points))
+     return points_by_layer
+
+
+ def generate_crop_boxes(
+     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+ ) -> Tuple[List[List[int]], List[int]]:
+     """
+     Generates a list of crop boxes of different sizes. Each layer
+     has (2**i)**2 boxes for the ith layer.
+     """
+     crop_boxes, layer_idxs = [], []
+     im_h, im_w = im_size
+     short_side = min(im_h, im_w)
+
+     # Original image
+     crop_boxes.append([0, 0, im_w, im_h])
+     layer_idxs.append(0)
+
+     def crop_len(orig_len, n_crops, overlap):
+         return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+     for i_layer in range(n_layers):
+         n_crops_per_side = 2 ** (i_layer + 1)
+         overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+         crop_w = crop_len(im_w, n_crops_per_side, overlap)
+         crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+         crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+         crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+         # Crops in XYXY format
+         for x0, y0 in product(crop_box_x0, crop_box_y0):
+             box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+             crop_boxes.append(box)
+             layer_idxs.append(i_layer + 1)
+
+     return crop_boxes, layer_idxs
+
+
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+     x0, y0, _, _ = crop_box
+     offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+     # Check if boxes has a channel dimension
+     if len(boxes.shape) == 3:
+         offset = offset.unsqueeze(1)
+     return boxes + offset
+
+
+ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+     x0, y0, _, _ = crop_box
+     offset = torch.tensor([[x0, y0]], device=points.device)
+     # Check if points has a channel dimension
+     if len(points.shape) == 3:
+         offset = offset.unsqueeze(1)
+     return points + offset
+
+
+ def uncrop_masks(
+     masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+ ) -> torch.Tensor:
+     x0, y0, x1, y1 = crop_box
+     if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+         return masks
+     # Coordinate transform masks
+     pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+     pad = (x0, pad_x - x0, y0, pad_y - y0)
+     return torch.nn.functional.pad(masks, pad, value=0)
+
+
+ def remove_small_regions(
+     mask: np.ndarray, area_thresh: float, mode: str
+ ) -> Tuple[np.ndarray, bool]:
+     """
+     Removes small disconnected regions and holes in a mask. Returns the
+     mask and an indicator of if the mask has been modified.
+     """
+     import cv2  # type: ignore
+
+     assert mode in ["holes", "islands"]
+     correct_holes = mode == "holes"
+     working_mask = (correct_holes ^ mask).astype(np.uint8)
+     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+     sizes = stats[:, -1][1:]  # Row 0 is background label
+     small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+     if len(small_regions) == 0:
+         return mask, False
+     fill_labels = [0] + small_regions
+     if not correct_holes:
+         fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+         # If every region is below threshold, keep largest
+         if len(fill_labels) == 0:
+             fill_labels = [int(np.argmax(sizes)) + 1]
+     mask = np.isin(regions, fill_labels)
+     return mask, True
+
+
+ def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+     from pycocotools import mask as mask_utils  # type: ignore
+
+     h, w = uncompressed_rle["size"]
+     rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+     rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
+     return rle
+
+
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+     """
+     Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+     an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+     """
+     # torch.max below raises an error on empty inputs, just skip in this case
+     if torch.numel(masks) == 0:
+         return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+     # Normalize shape to CxHxW
+     shape = masks.shape
+     h, w = shape[-2:]
+     if len(shape) > 2:
+         masks = masks.flatten(0, -3)
+     else:
+         masks = masks.unsqueeze(0)
+
+     # Get top and bottom edges
+     in_height, _ = torch.max(masks, dim=-1)
+     in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+     bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+     in_height_coords = in_height_coords + h * (~in_height)
+     top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+     # Get left and right edges
+     in_width, _ = torch.max(masks, dim=-2)
+     in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+     right_edges, _ = torch.max(in_width_coords, dim=-1)
+     in_width_coords = in_width_coords + w * (~in_width)
+     left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+     # If the mask is empty the right edge will be to the left of the left edge.
+     # Replace these boxes with [0, 0, 0, 0]
+     empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+     out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+     out = out * (~empty_filter).unsqueeze(-1)
+
+     # Return to original shape
+     if len(shape) > 2:
+         out = out.reshape(*shape[:-2], 4)
+     else:
+         out = out[0]
+
+     return out
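
For orientation, here is a minimal round trip through the RLE helpers above (a sketch, assuming the repo root is on `PYTHONPATH`; the mask contents are arbitrary):

```python
import numpy as np
import torch

from models.segment_anything.utils.amg import (
    area_from_rle, mask_to_rle_pytorch, rle_to_mask,
)

# A batch of one 3x4 boolean mask with a 2x2 foreground square.
mask = torch.zeros(1, 3, 4, dtype=torch.bool)
mask[0, 1:3, 1:3] = True

rles = mask_to_rle_pytorch(mask)   # uncompressed, column-major RLE
print(area_from_rle(rles[0]))      # 4 foreground pixels
recovered = rle_to_mask(rles[0])   # back to an HxW numpy array
assert np.array_equal(recovered, mask[0].numpy())
```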
models/segment_anything/utils/onnx.py ADDED
@@ -0,0 +1,157 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from ..modeling import Sam
+ from .amg import calculate_stability_score
+
+
+ class SamOnnxModel(nn.Module):
+     """
+     This model should not be called directly, but is used in ONNX export.
+     It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+     with some functions modified to enable model tracing. Also supports extra
+     options controlling what information is returned. See the ONNX export script
+     for details.
+     """
+
+     def __init__(
+         self,
+         model: Sam,
+         return_single_mask: bool,
+         use_stability_score: bool = False,
+         return_extra_metrics: bool = False,
+     ) -> None:
+         super().__init__()
+         self.mask_decoder = model.mask_decoder
+         self.model = model
+         self.img_size = model.image_encoder.img_size
+         self.return_single_mask = return_single_mask
+         self.use_stability_score = use_stability_score
+         self.stability_score_offset = 1.0
+         self.return_extra_metrics = return_extra_metrics
+
+     @staticmethod
+     def resize_longest_image_size(
+         input_image_size: torch.Tensor, longest_side: int
+     ) -> torch.Tensor:
+         input_image_size = input_image_size.to(torch.float32)
+         scale = longest_side / torch.max(input_image_size)
+         transformed_size = scale * input_image_size
+         transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+         return transformed_size
+
+     def _embed_points(
+         self, point_coords: torch.Tensor, point_labels: torch.Tensor
+     ) -> torch.Tensor:
+         point_coords = point_coords + 0.5
+         point_coords = point_coords / self.img_size
+         point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+         point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+         point_embedding = point_embedding * (point_labels != -1)
+         point_embedding = (
+             point_embedding
+             + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
+         )
+
+         for i in range(self.model.prompt_encoder.num_point_embeddings):
+             point_embedding = (
+                 point_embedding
+                 + self.model.prompt_encoder.point_embeddings[i].weight
+                 * (point_labels == i)
+             )
+
+         return point_embedding
+
+     def _embed_masks(
+         self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
+     ) -> torch.Tensor:
+         mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
+             input_mask
+         )
+         mask_embedding = mask_embedding + (
+             1 - has_mask_input
+         ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+         return mask_embedding
+
+     def mask_postprocessing(
+         self, masks: torch.Tensor, orig_im_size: torch.Tensor
+     ) -> torch.Tensor:
+         masks = F.interpolate(
+             masks,
+             size=(self.img_size, self.img_size),
+             mode="bilinear",
+             align_corners=False,
+         )
+
+         prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
+             torch.int64
+         )
+         masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore
+
+         orig_im_size = orig_im_size.to(torch.int64)
+         h, w = orig_im_size[0], orig_im_size[1]
+         masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+         return masks
+
+     def select_masks(
+         self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Determine if we should return the multiclick mask or not from the number of points.
+         # The reweighting is used to avoid control flow.
+         score_reweight = torch.tensor(
+             [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
+         ).to(iou_preds.device)
+         score = iou_preds + (num_points - 2.5) * score_reweight
+         best_idx = torch.argmax(score, dim=1)
+         masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
+         iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
+
+         return masks, iou_preds
+
+     @torch.no_grad()
+     def forward(
+         self,
+         image_embeddings: torch.Tensor,
+         point_coords: torch.Tensor,
+         point_labels: torch.Tensor,
+         mask_input: torch.Tensor,
+         has_mask_input: torch.Tensor,
+         orig_im_size: torch.Tensor,
+     ):
+         sparse_embedding = self._embed_points(point_coords, point_labels)
+         dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+         masks, scores = self.model.mask_decoder.predict_masks(
+             image_embeddings=image_embeddings,
+             image_pe=self.model.prompt_encoder.get_dense_pe(),
+             sparse_prompt_embeddings=sparse_embedding,
+             dense_prompt_embeddings=dense_embedding,
+         )
+
+         if self.use_stability_score:
+             scores = calculate_stability_score(
+                 masks, self.model.mask_threshold, self.stability_score_offset
+             )
+
+         if self.return_single_mask:
+             masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+
+         upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+         if self.return_extra_metrics:
+             stability_scores = calculate_stability_score(
+                 upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+             )
+             areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+             return upscaled_masks, scores, stability_scores, areas, masks
+
+         return upscaled_masks, scores, masks
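
For context, this wrapper is consumed by a `torch.onnx.export` call roughly like the sketch below. This is a hedged outline, not the repo's actual export script: the `sam_model_registry` entry point, the checkpoint path, and the dummy shapes are assumptions based on the upstream segment-anything layout.

```python
import torch
from models.segment_anything import sam_model_registry  # assumed package path
from models.segment_anything.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")  # hypothetical path
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size),
    "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 256, 256),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1080, 1920], dtype=torch.float),
}
torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_decoder.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
)
```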
models/segment_anything/utils/transforms.py ADDED
@@ -0,0 +1,113 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from copy import deepcopy
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+ from torchvision.transforms.functional import resize  # type: ignore
+ from torchvision.transforms.functional import to_pil_image
+
+
+ class ResizeLongestSide:
+     """
+     Resizes images to the longest side 'target_length', as well as provides
+     methods for resizing coordinates and boxes. Provides methods for
+     transforming both numpy arrays and batched torch tensors.
+     """
+
+     def __init__(self, target_length: int) -> None:
+         self.target_length = target_length
+
+     def apply_image(self, image: np.ndarray) -> np.ndarray:
+         """
+         Expects a numpy array with shape HxWxC in uint8 format.
+         """
+         target_size = self.get_preprocess_shape(
+             image.shape[0], image.shape[1], self.target_length
+         )
+         return np.array(resize(to_pil_image(image), target_size))
+
+     def apply_coords(
+         self, coords: np.ndarray, original_size: Tuple[int, ...]
+     ) -> np.ndarray:
+         """
+         Expects a numpy array of length 2 in the final dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], self.target_length
+         )
+         coords = deepcopy(coords).astype(float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+
+     def apply_boxes(
+         self, boxes: np.ndarray, original_size: Tuple[int, ...]
+     ) -> np.ndarray:
+         """
+         Expects a numpy array shape Bx4. Requires the original image size
+         in (H, W) format.
+         """
+         boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+         return boxes.reshape(-1, 4)
+
+     def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+         """
+         Expects batched images with shape BxCxHxW and float format. This
+         transformation may not exactly match apply_image. apply_image is
+         the transformation expected by the model.
+         """
+         # Expects an image in BCHW format; H and W are the last two dims.
+         target_size = self.get_preprocess_shape(
+             image.shape[2], image.shape[3], self.target_length
+         )
+         return F.interpolate(
+             image, target_size, mode="bilinear", align_corners=False, antialias=True
+         )
+
+     def apply_coords_torch(
+         self, coords: torch.Tensor, original_size: Tuple[int, ...]
+     ) -> torch.Tensor:
+         """
+         Expects a torch tensor with length 2 in the last dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], self.target_length
+         )
+         coords = deepcopy(coords).to(torch.float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+
+     def apply_boxes_torch(
+         self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+     ) -> torch.Tensor:
+         """
+         Expects a torch tensor with shape Bx4. Requires the original image
+         size in (H, W) format.
+         """
+         boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+         return boxes.reshape(-1, 4)
+
+     @staticmethod
+     def get_preprocess_shape(
+         oldh: int, oldw: int, long_side_length: int
+     ) -> Tuple[int, int]:
+         """
+         Compute the output size given input size and target long side length.
+         """
+         scale = long_side_length * 1.0 / max(oldh, oldw)
+         newh, neww = oldh * scale, oldw * scale
+         neww = int(neww + 0.5)
+         newh = int(newh + 0.5)
+         return (newh, neww)
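
A quick worked example of `get_preprocess_shape` and the coordinate rescaling (values chosen purely for illustration):

```python
import numpy as np
from models.segment_anything.utils.transforms import ResizeLongestSide

t = ResizeLongestSide(target_length=1024)

# A 480x640 (HxW) image: scale = 1024/640 = 1.6, so the output is 768x1024.
print(t.get_preprocess_shape(480, 640, 1024))   # (768, 1024)

# Point coordinates scale by the same factors.
coords = np.array([[320.0, 240.0]])             # (x, y) at the image center
print(t.apply_coords(coords, (480, 640)))       # [[512. 384.]]
```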
models/tf/modeling_outputs.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ import warnings
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Dict, List
+ from transformers.utils import ModelOutput
+
+ @dataclass
+ class CausalLMOutputWithPastAndLabel(ModelOutput):
+     """
+     Base class for causal language model (or autoregressive) outputs.
+
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Language modeling loss (for next-token prediction).
+         labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, returned when `labels` is provided):
+             The label tensor passed to the model, returned for convenience.
+         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+             Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+             `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+             Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+             `past_key_values` input) to speed up sequential decoding.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`.
+
+             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+             heads.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     labels: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+     bs2imgs_token_list: List[List[int]] = None
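
A minimal construction example for the dataclass (field values are dummies; the meaning of `bs2imgs_token_list` is inferred from its name as per-sample image-token bookkeeping):

```python
import torch
from models.tf.modeling_outputs import CausalLMOutputWithPastAndLabel

out = CausalLMOutputWithPastAndLabel(
    loss=torch.tensor(1.23),
    logits=torch.randn(2, 8, 32000),      # (batch, seq_len, vocab)
    bs2imgs_token_list=[[5], [3, 7]],     # assumed per-sample image-token counts
)
# ModelOutput subclasses support both attribute and key access:
print(out.loss, out["logits"].shape)
```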
setup_simtoken.md ADDED
@@ -0,0 +1,152 @@
+ # SimToken Setup
+
+ ---
+
+ ## 1. Create Environment
+
+ ```bash
+ conda create -n simtoken python=3.10 -y
+ conda activate simtoken
+
+ python -m pip install --upgrade pip wheel "setuptools<81"
+
+ pip install \
+     torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 \
+     --index-url https://download.pytorch.org/whl/cu121
+
+ pip install \
+     transformers==4.30.2 \
+     peft==0.2.0 \
+     accelerate==0.21.0 \
+     sentencepiece \
+     protobuf \
+     safetensors \
+     numpy==1.26.4 \
+     pandas \
+     matplotlib \
+     opencv-python \
+     pillow \
+     tqdm \
+     einops \
+     timm \
+     requests \
+     towhee \
+     huggingface_hub
+ ```
+
+ ---
+
+ ## 2. Download from HuggingFace (fresh machine setup)
+
+ Log in to HuggingFace (generate a token at https://huggingface.co/settings/tokens):
+
+ ```bash
+ huggingface-cli login
+ ```
+
+ Download the full repo (code + weights + compressed data archives, ~190 GB total):
+
+ ```bash
+ mkdir -p /workspace/SimToken
+ cd /workspace/SimToken
+
+ huggingface-cli download yfan07/SimToken \
+     --repo-type model \
+     --local-dir . \
+     --local-dir-use-symlinks False
+ ```
+
+ Once the download finishes, extract the data archives:
+
+ ```bash
+ cd /workspace/SimToken/data
+
+ tar -xf image_embed.tar    # ~5–10 minutes
+ tar -xzf gt_mask.tar.gz
+ tar -xzf audio_embed.tar.gz
+ tar -xf media.tar
+ ```
+
+ ---
+
+ ## 3. Pre-download Model Weights (required on first use)
+
+ `transformers==4.30.2` has an API incompatibility with recent `huggingface_hub` releases (`use_auth_token` was removed).
+ Workaround: download the models into the local cache with the CLI first, then run experiments with `TRANSFORMERS_OFFLINE=1` so that all network requests are skipped.
+
+ ```bash
+ # Chat-UniVi-7B (~14 GB)
+ huggingface-cli download Chat-UniVi/Chat-UniVi-7B-v1.5
+
+ # CLIP ViT-L (~1.6 GB)
+ huggingface-cli download openai/clip-vit-large-patch14
+ ```
+
+ Once downloaded, the models stay cached permanently; new sessions do not need to download them again.
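+
+ As a quick offline sanity check (a minimal sketch, assuming the default Hugging Face cache location and the two repo ids above), the following should load without any network access:
+
+ ```python
+ import os
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"  # set before transformers is imported
+
+ from transformers import AutoTokenizer, CLIPImageProcessor
+
+ # Both calls should resolve entirely from the local cache; a network
+ # error here means the pre-download step above did not complete.
+ tok = AutoTokenizer.from_pretrained("Chat-UniVi/Chat-UniVi-7B-v1.5")
+ proc = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ print("cache OK:", type(tok).__name__, type(proc).__name__)
+ ```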
+
+ ---
+
+ ## 4. Example Evaluation
+
+ Prefix every evaluation command with `TRANSFORMERS_OFFLINE=1`:
+
+ ```bash
+ cd /workspace/SimToken
+
+ # Unseen split (all 1656 samples)
+ TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u
+
+ # Seen split
+ TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_s
+
+ # Null split (S metric; lower is better)
+ TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_n
+
+ # Limit the sample count (quick verification)
+ TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u --max_eval_rows 50
+
+ # Stage 0 gradient-connectivity + bypass-equivalence check (diagnostics only)
+ TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u --max_eval_rows 0
+ ```
+
+ Each evaluation prints, in order: the Baseline and q-LTPO Stage 1 result sets, plus diagnostic statistics.
+
+ ---
+
+ ## 5. Upload to HuggingFace (after experiments finish)
+
+ The data directories are stored as archives, which drastically reduces the file count and avoids HuggingFace commit rate limits.
+
+ **Step 1: compress the data directories into archives (if not already compressed)**
+
+ ```bash
+ cd /workspace/SimToken/data
+
+ tar -cf image_embed.tar image_embed/   # no compression (.pt files are already binary)
+ tar -czf gt_mask.tar.gz gt_mask/
+ tar -czf audio_embed.tar.gz audio_embed/
+ tar -cf media.tar media/
+
+ # Confirm the archives exist, then delete the original directories
+ ls -lh *.tar*
+ rm -rf image_embed/ gt_mask/ audio_embed/ media/
+ ```
+
+ **Step 2: clean caches and upload**
+
+ ```bash
+ find /workspace/SimToken -name "__pycache__" -exec rm -rf {} + 2>/dev/null
+ find /workspace/SimToken -name "*.pyc" -delete
+
+ huggingface-cli login   # generate a token at https://huggingface.co/settings/tokens (needs Write permission)
+
+ cd /workspace/SimToken
+ python upload_hf.py --repo yfan07/SimToken
+ ```
+
+ **Notes:**
+ - Run inside `tmux` to survive SSH disconnects: `tmux new -s upload`, then `Ctrl+B D` to detach when done
+ - Resumable: re-running the same command after an interruption automatically skips already-uploaded files
+ - On a rate limit (HTTP 429), the script automatically waits about an hour and retries
+ - Monitor progress: `tail -f /workspace/SimToken/upload.log`
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .metric import pyutils
+ from .metric import utility
utils/metric/pyutils.py ADDED
@@ -0,0 +1,160 @@
+ import numpy as np
+ import time
+ import sys
+
+ class Logger(object):
+     """Tee stdout to both the terminal and a log file."""
+     def __init__(self, outfile):
+         self.terminal = sys.stdout
+         self.log = open(outfile, "w")
+         sys.stdout = self
+
+     def write(self, message):
+         self.terminal.write(message)
+         self.log.write(message)
+
+     def flush(self):
+         self.terminal.flush()
+
+
+ class AverageMeter:
+     """Tracks a running sum and count per named key, exposing the mean."""
+     def __init__(self, *keys):
+         self.__data = dict()
+         for k in keys:
+             self.__data[k] = [0.0, 0]
+
+     def add(self, dict):
+         for k, v in dict.items():
+             self.__data[k][0] += v
+             self.__data[k][1] += 1
+
+     def get(self, *keys):
+         if len(keys) == 1:
+             return self.__data[keys[0]][0] / self.__data[keys[0]][1]
+         else:
+             v_list = [self.__data[k][0] / self.__data[k][1] for k in keys]
+             return tuple(v_list)
+
+     def pop(self, key=None):
+         # With no key, reset every counter; with a key, return its mean and reset it.
+         if key is None:
+             for k in self.__data.keys():
+                 self.__data[k] = [0.0, 0]
+         else:
+             v = self.get(key)
+             self.__data[key] = [0.0, 0]
+             return v
+
+
+ class Timer:
+     """Wall-clock timer with progress-based completion estimates."""
+     def __init__(self, starting_msg=None):
+         self.start = time.time()
+         self.stage_start = self.start
+
+         if starting_msg is not None:
+             print(starting_msg, time.ctime(time.time()))
+
+     def update_progress(self, progress):
+         # progress is a fraction in (0, 1]; extrapolates total and remaining time.
+         self.elapsed = time.time() - self.start
+         self.est_total = self.elapsed / progress
+         self.est_remaining = self.est_total - self.elapsed
+         self.est_finish = int(self.start + self.est_total)
+
+     def str_est_finish(self):
+         return str(time.ctime(self.est_finish))
+
+     def get_stage_elapsed(self):
+         return time.time() - self.stage_start
+
+     def reset_stage(self):
+         self.stage_start = time.time()
+
+
+ from multiprocessing.pool import ThreadPool
+
+ class BatchThreader:
+     """Runs `func` over `args_list` with a thread pool, prefetching ahead of
+     consumption and returning results in batches via pop_results()."""
+
+     def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12):
+         self.batch_size = batch_size
+         self.prefetch_size = prefetch_size
+
+         self.pool = ThreadPool(processes=processes)
+         self.async_result = []
+
+         self.func = func
+         self.left_args_list = args_list
+         self.n_tasks = len(args_list)
+
+         # initial work
+         self.__start_works(self.__get_n_pending_works())
+
+     def __start_works(self, times):
+         for _ in range(times):
+             args = self.left_args_list.pop(0)
+             self.async_result.append(
+                 self.pool.apply_async(self.func, args))
+
+     def __get_n_pending_works(self):
+         # Keep (prefetch_size + 1) batches in flight, bounded by remaining work.
+         return min((self.prefetch_size + 1) * self.batch_size - len(self.async_result),
+                    len(self.left_args_list))
+
+     def pop_results(self):
+         n_inwork = len(self.async_result)
+
+         n_fetch = min(n_inwork, self.batch_size)
+         rtn = [self.async_result.pop(0).get()
+                for _ in range(n_fetch)]
+
+         to_fill = self.__get_n_pending_works()
+         if to_fill == 0:
+             self.pool.close()
+         else:
+             self.__start_works(to_fill)
+
+         return rtn
+
+
+ def get_indices_of_pairs(radius, size):
+     """For an HxW grid, returns flat indices of source pixels and of all
+     neighbors within `radius` (used for pairwise-affinity computations)."""
+     search_dist = []
+
+     for x in range(1, radius):
+         search_dist.append((0, x))
+
+     for y in range(1, radius):
+         for x in range(-radius + 1, radius):
+             if x * x + y * y < radius * radius:
+                 search_dist.append((y, x))
+
+     radius_floor = radius - 1
+
+     full_indices = np.reshape(np.arange(0, size[0] * size[1], dtype=np.int64),
+                               (size[0], size[1]))
+
+     cropped_height = size[0] - radius_floor
+     cropped_width = size[1] - 2 * radius_floor
+
+     indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor],
+                               [-1])
+
+     indices_to_list = []
+
+     for dy, dx in search_dist:
+         indices_to = full_indices[dy:dy + cropped_height,
+                                   radius_floor + dx:radius_floor + dx + cropped_width]
+         indices_to = np.reshape(indices_to, [-1])
+         indices_to_list.append(indices_to)
+
+     concat_indices_to = np.concatenate(indices_to_list, axis=0)
+
+     return indices_from, concat_indices_to
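
A quick illustration of how `AverageMeter` accumulates and resets per-key running means (the key names are hypothetical):

```python
from utils.metric.pyutils import AverageMeter

meter = AverageMeter("miou", "fscore")      # example key names
meter.add({"miou": 0.6, "fscore": 0.7})
meter.add({"miou": 0.8, "fscore": 0.9})
print(meter.get("miou"))     # 0.7  (mean of 0.6 and 0.8)
print(meter.pop("fscore"))   # 0.8, then the "fscore" counter resets to zero
```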
utils/metric/utility.py ADDED
@@ -0,0 +1,125 @@
+ import matplotlib.pyplot as plt
+ import torch
+ from torch.nn import functional as F
+
+ import os
+ import shutil
+ # import logging
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+ import sys
+ import time
+ import pandas as pd
+ import pdb
+ from torchvision import transforms
+
+ def metric_s_for_null(pred):
+     """S metric for the null split: square root of the fraction of pixels
+     predicted as foreground (after sigmoid and a 0.5 threshold). Lower is better."""
+     num_seg, T, H, W = pred.shape
+     pred = pred.view(num_seg * T, H, W)
+     assert len(pred.shape) == 3
+
+     N = pred.size(0)
+     num_pixels = pred.view(-1).shape[0]
+
+     temp_pred = torch.sigmoid(pred)
+     pred = (temp_pred > 0.5).int()
+
+     x = torch.sum(pred.view(-1))
+     s = torch.sqrt(x / num_pixels)
+
+     return s
+
+ def mask_iou(pred, target, eps=1e-7, size_average=True):
+     r"""
+     param:
+         pred: size [N x T x H x W], raw logits
+         target: size [N x T x H x W], binary ground truth
+     output:
+         iou: scalar, averaged over the N*T frames
+     """
+     # return mask_iou_224(pred, target, eps=1e-7)
+     num_ref, T, H, W = pred.shape
+     pred = pred.view(num_ref * T, H, W)
+     target = target.view(num_ref * T, H, W)
+     assert len(pred.shape) == 3 and pred.shape == target.shape
+
+     N = pred.size(0)
+     # pixels per frame
+     num_pixels = pred.size(-1) * pred.size(-2)
+     # flag frames whose ground truth is entirely background
+     no_obj_flag = (target.sum(2).sum(1) == 0)
+
+     # sigmoid maps logits to probabilities in [0, 1]
+     temp_pred = torch.sigmoid(pred)
+     # threshold into a binary mask
+     pred = (temp_pred > 0.4).int()
+     # intersection
+     inter = (pred * target).sum(2).sum(1)
+     # union
+     union = torch.max(pred, target).sum(2).sum(1)
+
+     # for frames with no object, score agreement on the background instead
+     inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1)
+     inter[no_obj_flag] = inter_no_obj[no_obj_flag]
+     union[no_obj_flag] = num_pixels
+
+     iou = torch.sum(inter / (union + eps)) / N
+
+     return iou.item()
+
+
+ def _eval_pr(y_pred, y, num, device='cuda'):
+     if device.startswith('cuda'):
+         prec, recall = torch.zeros(num).to(y_pred.device), torch.zeros(num).to(y_pred.device)
+         # generate num thresholds between 0 and 1
+         thlist = torch.linspace(0, 1 - 1e-10, num).to(y_pred.device)
+     else:
+         prec, recall = torch.zeros(num), torch.zeros(num)
+         thlist = torch.linspace(0, 1 - 1e-10, num)
+
+     for i in range(num):
+         y_temp = (y_pred >= thlist[i]).float()
+
+         # true positives (TP)
+         tp = (y_temp * y).sum()
+         # precision = TP / predicted area; recall = TP / ground-truth area
+         prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
+
+     return prec, recall
+
+
+ def Eval_Fmeasure(pred, gt, measure_path, pr_num=255, device='cuda'):
+     r"""
+     param:
+         pred: size [N x T x H x W], raw logits
+         gt: size [N x T x H x W], binary ground truth
+     output:
+         the maximum F-measure over pr_num thresholds
+     """
+     num_ref, T, H, W = pred.shape
+     pred = pred.view(num_ref * T, H, W)
+     gt = gt.view(num_ref * T, H, W)
+     assert len(pred.shape) == 3
+
+     # sigmoid maps logits to probabilities in [0, 1]
+     pred = torch.sigmoid(pred)
+     N = pred.size(0)
+     beta2 = 0.3
+     avg_f, img_num = 0.0, 0
+     score = torch.zeros(pr_num)
+
+     for img_id in range(N):
+         # skip frames with an all-background ground truth
+         if torch.mean(gt[img_id]) == 0.0:
+             continue
+         prec, recall = _eval_pr(pred[img_id], gt[img_id], pr_num, device=device)
+         f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
+         f_score[f_score != f_score] = 0  # for NaN
+         avg_f += f_score
+         img_num += 1
+     score = avg_f / img_num
+
+     return score.max().item()
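
A minimal sanity check of `mask_iou` on synthetic tensors (shapes and values chosen arbitrarily for illustration):

```python
import torch
from utils.metric.utility import mask_iou

# One reference, two frames, 4x4 masks. Logits > 0 become foreground
# after the sigmoid + 0.4 threshold inside mask_iou.
pred = torch.full((1, 2, 4, 4), -10.0)   # all background
pred[0, :, :2, :] = 10.0                 # top half predicted foreground
target = torch.zeros(1, 2, 4, 4)
target[0, :, :2, :] = 1.0                # top half is ground truth

print(mask_iou(pred, target))            # ~1.0: prediction matches exactly
```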