Add files using upload-large-folder tool
- data/audio_embed.tar.gz +3 -0
- data/gt_mask.tar.gz +3 -0
- data/image_embed.tar +3 -0
- data/media.tar +3 -0
- models/segment_anything/utils/amg.py +346 -0
- models/segment_anything/utils/onnx.py +157 -0
- models/segment_anything/utils/transforms.py +113 -0
- models/tf/modeling_outputs.py +43 -0
- setup_simtoken.md +152 -0
- utils/__init__.py +2 -0
- utils/metric/pyutils.py +160 -0
- utils/metric/utility.py +125 -0
data/audio_embed.tar.gz
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7cfb23c805a62df237904736403226b60329075a3147d2f93fcc78f2a163843
+size 17678459
data/gt_mask.tar.gz
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2fd1d8d37ceea45d0b8b7714c06afa0218710061a7da158652617d9107e197a
+size 203344455
data/image_embed.tar
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b0f5c8ae133bbddbfa558b2052b3aeb757492ffe310650988103d07e24135bb
+size 167486740480
data/media.tar
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e4d2f70333148331b736f2e65eea4d0d8e38a030fc70b49bee90b63ed5ba7f5
+size 19026513920
models/segment_anything/utils/amg.py
ADDED
@@ -0,0 +1,346 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+import numpy as np
+import torch
+
+
+class MaskData:
+    """
+    A structure for storing masks and their related data in batched format.
+    Implements basic filtering and concatenation.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        for v in kwargs.values():
+            assert isinstance(
+                v, (list, np.ndarray, torch.Tensor)
+            ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats = dict(**kwargs)
+
+    def __setitem__(self, key: str, item: Any) -> None:
+        assert isinstance(
+            item, (list, np.ndarray, torch.Tensor)
+        ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats[key] = item
+
+    def __delitem__(self, key: str) -> None:
+        del self._stats[key]
+
+    def __getitem__(self, key: str) -> Any:
+        return self._stats[key]
+
+    def items(self) -> ItemsView[str, Any]:
+        return self._stats.items()
+
+    def filter(self, keep: torch.Tensor) -> None:
+        for k, v in self._stats.items():
+            if v is None:
+                self._stats[k] = None
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = v[keep.detach().cpu().numpy()]
+            elif isinstance(v, list) and keep.dtype == torch.bool:
+                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+            elif isinstance(v, list):
+                self._stats[k] = [v[i] for i in keep]
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def cat(self, new_stats: "MaskData") -> None:
+        for k, v in new_stats.items():
+            if k not in self._stats or self._stats[k] is None:
+                self._stats[k] = deepcopy(v)
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+            elif isinstance(v, list):
+                self._stats[k] = self._stats[k] + deepcopy(v)
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def to_numpy(self) -> None:
+        for k, v in self._stats.items():
+            if isinstance(v, torch.Tensor):
+                self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+    """Filter masks at the edge of a crop, but not at the edge of the original image."""
+    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+    box_xywh = deepcopy(box_xyxy)
+    box_xywh[2] = box_xywh[2] - box_xywh[0]
+    box_xywh[3] = box_xywh[3] - box_xywh[1]
+    return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+    assert len(args) > 0 and all(
+        len(a) == len(args[0]) for a in args
+    ), "Batched iteration must have inputs of all the same size."
+    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+    for b in range(n_batches):
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+    """
+    Encodes masks to an uncompressed RLE, in the format expected by
+    pycoco tools.
+    """
+    # Put in fortran order and flatten h,w
+    b, h, w = tensor.shape
+    tensor = tensor.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = tensor[:, 1:] ^ tensor[:, :-1]
+    change_indices = diff.nonzero()
+
+    # Encode run length
+    out = []
+    for i in range(b):
+        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+        cur_idxs = torch.cat(
+            [
+                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+                cur_idxs + 1,
+                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+            ]
+        )
+        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+        counts = [] if tensor[i, 0] == 0 else [0]
+        counts.extend(btw_idxs.detach().cpu().tolist())
+        out.append({"size": [h, w], "counts": counts})
+    return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+    """Compute a binary mask from an uncompressed RLE."""
+    h, w = rle["size"]
+    mask = np.empty(h * w, dtype=bool)
+    idx = 0
+    parity = False
+    for count in rle["counts"]:
+        mask[idx : idx + count] = parity
+        idx += count
+        parity ^= True
+    mask = mask.reshape(w, h)
+    return mask.transpose()  # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+    return sum(rle["counts"][1::2])
+
+
+def calculate_stability_score(
+    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+    """
+    Computes the stability score for a batch of masks. The stability
+    score is the IoU between the binary masks obtained by thresholding
+    the predicted mask logits at high and low values.
+    """
+    # One mask is always contained inside the other.
+    # Save memory by preventing unnecessary cast to torch.int64
+    intersections = (
+        (masks > (mask_threshold + threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    unions = (
+        (masks > (mask_threshold - threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+    offset = 1 / (2 * n_per_side)
+    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+    return points
+
+
+def build_all_layer_point_grids(
+    n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+    """Generates point grids for all crop layers."""
+    points_by_layer = []
+    for i in range(n_layers + 1):
+        n_points = int(n_per_side / (scale_per_layer**i))
+        points_by_layer.append(build_point_grid(n_points))
+    return points_by_layer
+
+
+def generate_crop_boxes(
+    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+    """
+    Generates a list of crop boxes of different sizes. Each layer
+    has (2**i)**2 boxes for the ith layer.
+    """
+    crop_boxes, layer_idxs = [], []
+    im_h, im_w = im_size
+    short_side = min(im_h, im_w)
+
+    # Original image
+    crop_boxes.append([0, 0, im_w, im_h])
+    layer_idxs.append(0)
+
+    def crop_len(orig_len, n_crops, overlap):
+        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+    for i_layer in range(n_layers):
+        n_crops_per_side = 2 ** (i_layer + 1)
+        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+        crop_w = crop_len(im_w, n_crops_per_side, overlap)
+        crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYWH format
+        for x0, y0 in product(crop_box_x0, crop_box_y0):
+            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+            crop_boxes.append(box)
+            layer_idxs.append(i_layer + 1)
+
+    return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0]], device=points.device)
+    # Check if points has a channel dimension
+    if len(points.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return points + offset
+
+
+def uncrop_masks(
+    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+    x0, y0, x1, y1 = crop_box
+    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+    pad = (x0, pad_x - x0, y0, pad_y - y0)
+    return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+    mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+    """
+    Removes small disconnected regions and holes in a mask. Returns the
+    mask and an indicator of if the mask has been modified.
+    """
+    import cv2  # type: ignore
+
+    assert mode in ["holes", "islands"]
+    correct_holes = mode == "holes"
+    working_mask = (correct_holes ^ mask).astype(np.uint8)
+    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+    sizes = stats[:, -1][1:]  # Row 0 is background label
+    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+    if len(small_regions) == 0:
+        return mask, False
+    fill_labels = [0] + small_regions
+    if not correct_holes:
+        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+        # If every region is below threshold, keep largest
+        if len(fill_labels) == 0:
+            fill_labels = [int(np.argmax(sizes)) + 1]
+    mask = np.isin(regions, fill_labels)
+    return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+    from pycocotools import mask as mask_utils  # type: ignore
+
+    h, w = uncompressed_rle["size"]
+    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
+    return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+    """
+    # torch.max below raises an error on empty inputs, just skip in this case
+    if torch.numel(masks) == 0:
+        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+    # Normalize shape to CxHxW
+    shape = masks.shape
+    h, w = shape[-2:]
+    if len(shape) > 2:
+        masks = masks.flatten(0, -3)
+    else:
+        masks = masks.unsqueeze(0)
+
+    # Get top and bottom edges
+    in_height, _ = torch.max(masks, dim=-1)
+    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+    in_height_coords = in_height_coords + h * (~in_height)
+    top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+    # Get left and right edges
+    in_width, _ = torch.max(masks, dim=-2)
+    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+    right_edges, _ = torch.max(in_width_coords, dim=-1)
+    in_width_coords = in_width_coords + w * (~in_width)
+    left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    if len(shape) > 2:
+        out = out.reshape(*shape[:-2], 4)
+    else:
+        out = out[0]
+
+    return out
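A quick round-trip sketch of the RLE helpers added above: encode a toy binary mask to uncompressed RLE, read its area, and decode it back. The import path assumes this repo's `models/segment_anything/utils` layout; the mask is a dummy.

```python
import torch

from models.segment_anything.utils.amg import (
    area_from_rle,
    mask_to_rle_pytorch,
    rle_to_mask,
)

# One 4x4 boolean mask with a 2x2 foreground square.
mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, 1:3, 1:3] = True

rles = mask_to_rle_pytorch(mask)   # uncompressed, column-major (Fortran-order) RLE
print(rles[0]["counts"])           # [5, 2, 2, 2, 5]: alternating background/foreground runs
print(area_from_rle(rles[0]))      # 4, the foreground pixel count

recovered = rle_to_mask(rles[0])   # back to an HxW numpy bool array
assert (recovered == mask[0].numpy()).all()
```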
models/segment_anything/utils/onnx.py
ADDED
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from ..modeling import Sam
+from .amg import calculate_stability_score
+
+
+class SamOnnxModel(nn.Module):
+    """
+    This model should not be called directly, but is used in ONNX export.
+    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+    with some functions modified to enable model tracing. Also supports extra
+    options controlling what information is returned. See the ONNX export script for details.
+    """
+
+    def __init__(
+        self,
+        model: Sam,
+        return_single_mask: bool,
+        use_stability_score: bool = False,
+        return_extra_metrics: bool = False,
+    ) -> None:
+        super().__init__()
+        self.mask_decoder = model.mask_decoder
+        self.model = model
+        self.img_size = model.image_encoder.img_size
+        self.return_single_mask = return_single_mask
+        self.use_stability_score = use_stability_score
+        self.stability_score_offset = 1.0
+        self.return_extra_metrics = return_extra_metrics
+
+    @staticmethod
+    def resize_longest_image_size(
+        input_image_size: torch.Tensor, longest_side: int
+    ) -> torch.Tensor:
+        input_image_size = input_image_size.to(torch.float32)
+        scale = longest_side / torch.max(input_image_size)
+        transformed_size = scale * input_image_size
+        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+        return transformed_size
+
+    def _embed_points(
+        self, point_coords: torch.Tensor, point_labels: torch.Tensor
+    ) -> torch.Tensor:
+        point_coords = point_coords + 0.5
+        point_coords = point_coords / self.img_size
+        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+        point_embedding = point_embedding * (point_labels != -1)
+        point_embedding = (
+            point_embedding
+            + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
+        )
+
+        for i in range(self.model.prompt_encoder.num_point_embeddings):
+            point_embedding = (
+                point_embedding
+                + self.model.prompt_encoder.point_embeddings[i].weight
+                * (point_labels == i)
+            )
+
+        return point_embedding
+
+    def _embed_masks(
+        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
+    ) -> torch.Tensor:
+        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
+            input_mask
+        )
+        mask_embedding = mask_embedding + (
+            1 - has_mask_input
+        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+        return mask_embedding
+
+    def mask_postprocessing(
+        self, masks: torch.Tensor, orig_im_size: torch.Tensor
+    ) -> torch.Tensor:
+        masks = F.interpolate(
+            masks,
+            size=(self.img_size, self.img_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
+            torch.int64
+        )
+        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore
+
+        orig_im_size = orig_im_size.to(torch.int64)
+        h, w = orig_im_size[0], orig_im_size[1]
+        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+        return masks
+
+    def select_masks(
+        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Determine if we should return the multiclick mask or not from the number of points.
+        # The reweighting is used to avoid control flow.
+        score_reweight = torch.tensor(
+            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
+        ).to(iou_preds.device)
+        score = iou_preds + (num_points - 2.5) * score_reweight
+        best_idx = torch.argmax(score, dim=1)
+        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
+        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
+
+        return masks, iou_preds
+
+    @torch.no_grad()
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+        mask_input: torch.Tensor,
+        has_mask_input: torch.Tensor,
+        orig_im_size: torch.Tensor,
+    ):
+        sparse_embedding = self._embed_points(point_coords, point_labels)
+        dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+        masks, scores = self.model.mask_decoder.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embedding,
+            dense_prompt_embeddings=dense_embedding,
+        )
+
+        if self.use_stability_score:
+            scores = calculate_stability_score(
+                masks, self.model.mask_threshold, self.stability_score_offset
+            )
+
+        if self.return_single_mask:
+            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+
+        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+        if self.return_extra_metrics:
+            stability_scores = calculate_stability_score(
+                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+            )
+            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+            return upscaled_masks, scores, stability_scores, areas, masks
+
+        return upscaled_masks, scores, masks
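For context, a hedged export sketch showing how `SamOnnxModel` is typically traced, modeled on the upstream segment-anything export script. It assumes the vendored package mirrors the upstream layout (`sam_model_registry`); the checkpoint path and opset version are placeholders.

```python
import torch

from models.segment_anything import sam_model_registry  # assumed upstream-style export
from models.segment_anything.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")  # placeholder checkpoint
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size),
    "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1]),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_decoder.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    # Prompt counts vary at runtime, so mark that axis dynamic.
    dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    opset_version=17,
)
```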
models/segment_anything/utils/transforms.py
ADDED
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize  # type: ignore
+from torchvision.transforms.functional import to_pil_image
+
+
+class ResizeLongestSide:
+    """
+    Resizes images to the longest side 'target_length', as well as provides
+    methods for resizing coordinates and boxes. Provides methods for
+    transforming both numpy arrays and batched torch tensors.
+    """
+
+    def __init__(self, target_length: int) -> None:
+        self.target_length = target_length
+
+    def apply_image(self, image: np.ndarray) -> np.ndarray:
+        """
+        Expects a numpy array with shape HxWxC in uint8 format.
+        """
+        target_size = self.get_preprocess_shape(
+            image.shape[0], image.shape[1], self.target_length
+        )
+        return np.array(resize(to_pil_image(image), target_size))
+
+    def apply_coords(
+        self, coords: np.ndarray, original_size: Tuple[int, ...]
+    ) -> np.ndarray:
+        """
+        Expects a numpy array of length 2 in the final dimension. Requires the
+        original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.get_preprocess_shape(
+            original_size[0], original_size[1], self.target_length
+        )
+        coords = deepcopy(coords).astype(float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes(
+        self, boxes: np.ndarray, original_size: Tuple[int, ...]
+    ) -> np.ndarray:
+        """
+        Expects a numpy array shape Bx4. Requires the original image size
+        in (H, W) format.
+        """
+        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+        return boxes.reshape(-1, 4)
+
+    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Expects batched images with shape BxCxHxW and float format. This
+        transformation may not exactly match apply_image. apply_image is
+        the transformation expected by the model.
+        """
+        # Expects an image in BCHW format, so H and W are dims 2 and 3.
+        target_size = self.get_preprocess_shape(
+            image.shape[2], image.shape[3], self.target_length
+        )
+        return F.interpolate(
+            image, target_size, mode="bilinear", align_corners=False, antialias=True
+        )
+
+    def apply_coords_torch(
+        self, coords: torch.Tensor, original_size: Tuple[int, ...]
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with length 2 in the last dimension. Requires the
+        original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.get_preprocess_shape(
+            original_size[0], original_size[1], self.target_length
+        )
+        coords = deepcopy(coords).to(torch.float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes_torch(
+        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with shape Bx4. Requires the original image
+        size in (H, W) format.
+        """
+        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+        return boxes.reshape(-1, 4)
+
+    @staticmethod
+    def get_preprocess_shape(
+        oldh: int, oldw: int, long_side_length: int
+    ) -> Tuple[int, int]:
+        """
+        Compute the output size given input size and target long side length.
+        """
+        scale = long_side_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
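A short usage sketch for `ResizeLongestSide`: resize to a 1024-pixel longest side and map point prompts into the resized frame. The random image is illustrative; note that `apply_coords` takes the original (H, W), not the resized one.

```python
import numpy as np

from models.segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # HxWxC uint8
resized = transform.apply_image(image)
print(resized.shape)  # (768, 1024, 3): the 640-px side is scaled up to 1024

points = np.array([[320.0, 240.0]])  # (x, y) in original-image coordinates
print(transform.apply_coords(points, original_size=image.shape[:2]))  # [[512., 384.]]
```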
models/tf/modeling_outputs.py
ADDED
@@ -0,0 +1,43 @@
+import torch
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict, List
+from transformers.utils import ModelOutput
+
+@dataclass
+class CausalLMOutputWithPastAndLabel(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, returned when `labels` is provided):
+            The labels that were passed to the model, returned so downstream code can align predictions with targets.
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    labels: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    bs2imgs_token_list: List[List[int]] = None
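A minimal sketch of how a forward pass might populate this output class; the shapes, the loss value, and the meaning attached to `bs2imgs_token_list` are illustrative assumptions, not taken from the repo's model code.

```python
import torch

from models.tf.modeling_outputs import CausalLMOutputWithPastAndLabel

batch_size, seq_len, vocab_size = 2, 8, 32000
output = CausalLMOutputWithPastAndLabel(
    loss=torch.tensor(1.23),
    labels=torch.randint(0, vocab_size, (batch_size, seq_len)).float(),
    logits=torch.randn(batch_size, seq_len, vocab_size),
    bs2imgs_token_list=[[3, 3], [5]],  # assumed: per-sample image-token counts
)

# ModelOutput allows both attribute and key access and drops fields left as None.
print(output.logits.shape)  # torch.Size([2, 8, 32000])
print(output["loss"])       # tensor(1.2300)
print(list(output.keys()))  # only the populated fields appear
```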
setup_simtoken.md
ADDED
@@ -0,0 +1,152 @@
+# SimToken Setup
+
+---
+
+## 1. Create Environment
+
+```bash
+conda create -n simtoken python=3.10 -y
+conda activate simtoken
+
+python -m pip install --upgrade pip wheel "setuptools<81"
+
+pip install \
+    torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 \
+    --index-url https://download.pytorch.org/whl/cu121
+
+pip install \
+    transformers==4.30.2 \
+    peft==0.2.0 \
+    accelerate==0.21.0 \
+    sentencepiece \
+    protobuf \
+    safetensors \
+    numpy==1.26.4 \
+    pandas \
+    matplotlib \
+    opencv-python \
+    pillow \
+    tqdm \
+    einops \
+    timm \
+    requests \
+    towhee \
+    huggingface_hub
+```
+
+---
+
+## 2. Download from HuggingFace (fresh-machine initialization)
+
+Log in to HuggingFace (generate a token at https://huggingface.co/settings/tokens):
+
+```bash
+huggingface-cli login
+```
+
+Download the full repo (code + weights + compressed data archives, ~190 GB in total):
+
+```bash
+mkdir -p /workspace/SimToken
+cd /workspace/SimToken
+
+huggingface-cli download yfan07/SimToken \
+    --repo-type model \
+    --local-dir . \
+    --local-dir-use-symlinks False
+```
+
+After the download finishes, extract the data archives:
+
+```bash
+cd /workspace/SimToken/data
+
+tar -xf image_embed.tar      # ~5-10 minutes
+tar -xzf gt_mask.tar.gz
+tar -xzf audio_embed.tar.gz
+tar -xf media.tar
+```
+
+---
+
+## 3. Pre-download Model Weights (required on first use)
+
+`transformers==4.30.2` is API-incompatible with recent `huggingface_hub` releases (`use_auth_token` has been removed).
+Workaround: download the models into the local cache with the CLI first, then run experiments with `TRANSFORMERS_OFFLINE=1` to skip all network requests.
+
+```bash
+# Chat-UniVi-7B (~14 GB)
+huggingface-cli download Chat-UniVi/Chat-UniVi-7B-v1.5
+
+# CLIP ViT-L (~1.6 GB)
+huggingface-cli download openai/clip-vit-large-patch14
+```
+
+Once downloaded, the weights stay cached permanently; new sessions do not need to download them again.
+
+---
+
+## 4. Example Evaluation
+
+Prefix every evaluation command with `TRANSFORMERS_OFFLINE=1`:
+
+```bash
+cd /workspace/SimToken
+
+# Unseen split (all 1656 samples)
+TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u
+
+# Seen split
+TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_s
+
+# Null split (S metric, lower is better)
+TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_n
+
+# Limit the number of samples (quick sanity check)
+TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u --max_eval_rows 50
+
+# Stage 0 gradient-connectivity + bypass-equivalence check (diagnostics only)
+TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u --max_eval_rows 0
+```
+
+Each evaluation prints, in order: Baseline and q-LTPO Stage 1 results, plus diagnostic statistics.
+
+---
+
+## 5. Upload to HuggingFace (after experiments finish)
+
+The data directories are stored as archives, which greatly reduces the file count and avoids HuggingFace commit-rate limits.
+
+**Step 1: compress the data directories into archives (if not already compressed)**
+
+```bash
+cd /workspace/SimToken/data
+
+tar -cf image_embed.tar image_embed/     # no compression (.pt files are already binary)
+tar -czf gt_mask.tar.gz gt_mask/
+tar -czf audio_embed.tar.gz audio_embed/
+tar -cf media.tar media/
+
+# verify the archives exist, then delete the original directories
+ls -lh *.tar*
+rm -rf image_embed/ gt_mask/ audio_embed/ media/
+```
+
+**Step 2: clean caches and upload**
+
+```bash
+find /workspace/SimToken -name "__pycache__" -exec rm -rf {} + 2>/dev/null
+find /workspace/SimToken -name "*.pyc" -delete
+
+huggingface-cli login   # generate a token at https://huggingface.co/settings/tokens (Write access required)
+
+cd /workspace/SimToken
+python upload_hf.py --repo yfan07/SimToken
+```
+
+**Notes:**
+- Run inside `tmux` so an SSH disconnect does not kill the upload: `tmux new -s upload`, then `Ctrl+B D` to detach when done
+- Resumable: re-running the same command after an interruption automatically skips files that were already uploaded
+- On rate limits (HTTP 429) the script automatically waits about an hour and retries
+- Monitor progress: `tail -f /workspace/SimToken/upload.log`
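The offline pattern from section 3, sketched in Python: with `TRANSFORMERS_OFFLINE=1` set, `from_pretrained` resolves entirely from the local cache that `huggingface-cli download` populated. This assumes the CLIP weights were pre-downloaded as above.

```python
import os

# Must be set before transformers is imported, or exported in the shell.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from transformers import CLIPImageProcessor, CLIPVisionModel

# Served from the local cache; raises instead of touching the network.
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
print(vision_tower.config.hidden_size)  # 1024 for ViT-L/14
```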
utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .metric import pyutils
+from .metric import utility
utils/metric/pyutils.py
ADDED
@@ -0,0 +1,160 @@
+
+import numpy as np
+import time
+import sys
+
+class Logger(object):
+    def __init__(self, outfile):
+        self.terminal = sys.stdout
+        self.log = open(outfile, "w")
+        sys.stdout = self
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        self.terminal.flush()
+
+
+class AverageMeter:
+    def __init__(self, *keys):
+        self.__data = dict()
+        for k in keys:
+            self.__data[k] = [0.0, 0]
+
+    def add(self, dict):
+        for k, v in dict.items():
+            self.__data[k][0] += v
+            self.__data[k][1] += 1
+
+    def get(self, *keys):
+        if len(keys) == 1:
+            return self.__data[keys[0]][0] / self.__data[keys[0]][1]
+        else:
+            v_list = [self.__data[k][0] / self.__data[k][1] for k in keys]
+            return tuple(v_list)
+
+    def pop(self, key=None):
+        if key is None:
+            for k in self.__data.keys():
+                self.__data[k] = [0.0, 0]
+        else:
+            v = self.get(key)
+            self.__data[key] = [0.0, 0]
+            return v
+
+
+class Timer:
+    def __init__(self, starting_msg=None):
+        self.start = time.time()
+        self.stage_start = self.start
+
+        if starting_msg is not None:
+            print(starting_msg, time.ctime(time.time()))
+
+    def update_progress(self, progress):
+        self.elapsed = time.time() - self.start
+        self.est_total = self.elapsed / progress
+        self.est_remaining = self.est_total - self.elapsed
+        self.est_finish = int(self.start + self.est_total)
+
+    def str_est_finish(self):
+        return str(time.ctime(self.est_finish))
+
+    def get_stage_elapsed(self):
+        return time.time() - self.stage_start
+
+    def reset_stage(self):
+        self.stage_start = time.time()
+
+
+from multiprocessing.pool import ThreadPool
+
+class BatchThreader:
+
+    def __init__(self, func, args_list, batch_size, prefetch_size=4, processes=12):
+        self.batch_size = batch_size
+        self.prefetch_size = prefetch_size
+
+        self.pool = ThreadPool(processes=processes)
+        self.async_result = []
+
+        self.func = func
+        self.left_args_list = args_list
+        self.n_tasks = len(args_list)
+
+        # initial work
+        self.__start_works(self.__get_n_pending_works())
+
+    def __start_works(self, times):
+        for _ in range(times):
+            args = self.left_args_list.pop(0)
+            self.async_result.append(
+                self.pool.apply_async(self.func, args))
+
+    def __get_n_pending_works(self):
+        return min(
+            (self.prefetch_size + 1) * self.batch_size - len(self.async_result),
+            len(self.left_args_list),
+        )
+
+    def pop_results(self):
+        n_inwork = len(self.async_result)
+        n_fetch = min(n_inwork, self.batch_size)
+        rtn = [self.async_result.pop(0).get()
+               for _ in range(n_fetch)]
+
+        to_fill = self.__get_n_pending_works()
+        if to_fill == 0:
+            self.pool.close()
+        else:
+            self.__start_works(to_fill)
+
+        return rtn
+
+
+def get_indices_of_pairs(radius, size):
+
+    search_dist = []
+
+    for x in range(1, radius):
+        search_dist.append((0, x))
+
+    for y in range(1, radius):
+        for x in range(-radius + 1, radius):
+            if x * x + y * y < radius * radius:
+                search_dist.append((y, x))
+
+    radius_floor = radius - 1
+
+    full_indices = np.reshape(np.arange(0, size[0]*size[1], dtype=np.int64),
+                              (size[0], size[1]))
+
+    cropped_height = size[0] - radius_floor
+    cropped_width = size[1] - 2 * radius_floor
+
+    indices_from = np.reshape(full_indices[:-radius_floor, radius_floor:-radius_floor],
+                              [-1])
+
+    indices_to_list = []
+
+    for dy, dx in search_dist:
+        indices_to = full_indices[dy:dy + cropped_height,
+                                  radius_floor + dx:radius_floor + dx + cropped_width]
+        indices_to = np.reshape(indices_to, [-1])
+
+        indices_to_list.append(indices_to)
+
+    concat_indices_to = np.concatenate(indices_to_list, axis=0)
+
+    return indices_from, concat_indices_to
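A quick usage sketch for the bookkeeping helpers above; the loop and metric values are dummies.

```python
import time

from utils.metric.pyutils import AverageMeter, Timer

meter = AverageMeter("loss", "miou")
timer = Timer("evaluation started at")

for step in range(1, 6):
    meter.add({"loss": 1.0 / step, "miou": 0.5 + 0.01 * step})
    timer.update_progress(step / 5)   # updates elapsed / remaining estimates
    time.sleep(0.01)

print(meter.get("loss", "miou"))  # running means, returned as a tuple
print(timer.str_est_finish())     # estimated wall-clock finish time
print(meter.pop("loss"))          # read the mean, then reset that key
```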
utils/metric/utility.py
ADDED
@@ -0,0 +1,125 @@
+import matplotlib.pyplot as plt
+import torch
+from torch.nn import functional as F
+
+import os
+import shutil
+# import logging
+import cv2
+import numpy as np
+from PIL import Image
+
+import sys
+import time
+import pandas as pd
+import pdb
+from torchvision import transforms
+
+def metric_s_for_null(pred):
+    num_seg, T, H, W = pred.shape
+    pred = pred.view(num_seg*T, H, W)
+    assert len(pred.shape) == 3
+
+    N = pred.size(0)
+    num_pixels = pred.view(-1).shape[0]
+
+    temp_pred = torch.sigmoid(pred)
+    pred = (temp_pred > 0.5).int()
+
+    x = torch.sum(pred.view(-1))
+    s = torch.sqrt(x / num_pixels)
+
+    return s
+
+def mask_iou(pred, target, eps=1e-7, size_average=True):
+    r"""
+    param:
+        pred: size [N x T x H x W]
+        target: size [N x T x H x W]
+    output:
+        iou: size [1] (size_average=True) or [N] (size_average=False)
+    """
+    # return mask_iou_224(pred, target, eps=1e-7)
+    num_ref, T, H, W = pred.shape
+    pred = pred.view(num_ref*T, H, W)
+    target = target.view(num_ref*T, H, W)
+    assert len(pred.shape) == 3 and pred.shape == target.shape
+
+    N = pred.size(0)
+    # number of pixels per frame
+    num_pixels = pred.size(-1) * pred.size(-2)
+    # whether the GT mask is entirely empty (all black)
+    no_obj_flag = (target.sum(2).sum(1) == 0)
+
+    # sigmoid maps the predictions to probabilities in [0, 1]
+    temp_pred = torch.sigmoid(pred)
+    # threshold into a binary matrix
+    pred = (temp_pred > 0.4).int()
+    # intersection
+    inter = (pred * target).sum(2).sum(1)
+    # union
+    union = torch.max(pred, target).sum(2).sum(1)
+
+    inter_no_obj = ((1 - target) * (1 - pred)).sum(2).sum(1)
+    inter[no_obj_flag] = inter_no_obj[no_obj_flag]
+    union[no_obj_flag] = num_pixels
+
+    iou = torch.sum(inter / (union + eps)) / N
+
+    return iou.item()
+
+
+def _eval_pr(y_pred, y, num, device='cuda'):
+    if device.startswith('cuda'):
+        prec, recall = torch.zeros(num).to(y_pred.device), torch.zeros(num).to(y_pred.device)
+        # generate num thresholds between 0 and 1
+        thlist = torch.linspace(0, 1 - 1e-10, num).to(y_pred.device)
+    else:
+        prec, recall = torch.zeros(num), torch.zeros(num)
+        thlist = torch.linspace(0, 1 - 1e-10, num)
+
+    for i in range(num):
+        y_temp = (y_pred >= thlist[i]).float()
+
+        # count true positives (TP)
+        tp = (y_temp * y).sum()
+        # precision = TP / predicted area, recall = TP / ground-truth area
+        prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
+
+    return prec, recall
+
+
+def Eval_Fmeasure(pred, gt, measure_path, pr_num=255, device='cuda'):
+    r"""
+    param:
+        pred: size [N x T x H x W]
+        gt: size [N x T x H x W]
+    output:
+        the maximum F-measure over pr_num thresholds (scalar)
+    """
+    num_ref, T, H, W = pred.shape
+    pred = pred.view(num_ref*T, H, W)
+    gt = gt.view(num_ref*T, H, W)
+    assert len(pred.shape) == 3
+
+    # sigmoid maps the predictions to [0, 1]
+    pred = torch.sigmoid(pred)
+    N = pred.size(0)
+    beta2 = 0.3
+    avg_f, img_num = 0.0, 0
+    score = torch.zeros(pr_num)
+
+    for img_id in range(N):
+        if torch.mean(gt[img_id]) == 0.0:
+            continue
+        prec, recall = _eval_pr(pred[img_id], gt[img_id], pr_num, device=device)
+        f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
+        f_score[f_score != f_score] = 0  # for NaN
+        avg_f += f_score
+        img_num += 1
+    score = avg_f / img_num
+
+    return score.max().item()
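A small sketch of the metric helpers: predictions are raw logits of shape [num_refs, T, H, W] (sigmoided internally), ground truth is a binary mask of the same shape. The random tensors are dummies.

```python
import torch

from utils.metric.utility import Eval_Fmeasure, mask_iou, metric_s_for_null

pred_logits = torch.randn(2, 5, 224, 224)           # 2 references, 5 frames each
gt = (torch.rand(2, 5, 224, 224) > 0.5).float()     # dummy binary ground truth

print(mask_iou(pred_logits, gt))                          # mean IoU over all N*T frames
print(Eval_Fmeasure(pred_logits, gt, measure_path=None))  # max F-measure over 255 thresholds
print(metric_s_for_null(pred_logits))                     # S metric for the null split; lower is better
```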