|
import math
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
|
CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', |
|
'has_instances', 'ignore_in_eval', 'color']) |
|
|
|
classes = [ |
|
CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), |
|
CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), |
|
CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), |
|
CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), |
|
CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), |
|
CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), |
|
CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), |
|
CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), |
|
CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), |
|
CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), |
|
CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), |
|
CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), |
|
CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), |
|
CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), |
|
CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), |
|
CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), |
|
CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), |
|
CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), |
|
CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), |
|
CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), |
|
CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), |
|
CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), |
|
CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), |
|
CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), |
|
CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), |
|
CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), |
|
CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), |
|
CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), |
|
CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), |
|
CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), |
|
CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), |
|
CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), |
|
CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), |
|
CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), |
|
CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), |
|
] |
|
|
|
# Lookup tables between raw Cityscapes label IDs, the 19 train IDs, colors and names.
cityscapes_id_to_trainID = {cls.id: cls.train_id for cls in classes}
cityscapes_trainID_to_testID = {cls.train_id: cls.id for cls in classes}
cityscapes_trainID_to_color = {cls.train_id: cls.color for cls in classes}
cityscapes_trainID_to_color[255] = (0, 0, 0)
cityscapes_trainID_to_name = {cls.train_id: cls.name for cls in classes}
cityscapes_trainID_to_name[255] = 'ignore'
cityscapes_trainID_to_name[19] = 'ignore'
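

# Illustrative example (a sketch, not used elsewhere in this module): converting a
# raw Cityscapes label map to train IDs with the lookup above.
#
#   label = np.array([[7, 8], [26, 0]])                        # raw Cityscapes IDs
#   train = np.vectorize(cityscapes_id_to_trainID.get)(label)  # -> [[0, 1], [13, 255]]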
|
|
|
|
|
def map2cs(seg):
    """Map a 2D array of Cityscapes train IDs to an RGB color image."""
    # Squeeze away leading batch/channel dimensions until a 2D map remains.
    while len(seg.shape) > 2:
        seg = seg[0]
    colors = cityscapes_trainID_to_color

    rgb = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for l in np.unique(seg):
        rgb[seg == l, :] = colors[l]
    return rgb
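

# Example usage (illustrative; the input shape is an assumption):
#
#   pred = np.random.randint(0, 19, size=(128, 256))  # train IDs in [0, 18]
#   vis = map2cs(pred)                                 # (128, 256, 3) uint8 RGB image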
|
|
|
|
|
def get_colors(num_colors):
    """Return up to `num_colors` distinct RGB tuples drawn from a fixed hex palette."""
    from PIL import ImageColor
    import matplotlib
    hex_colors = [
|
|
|
"#FFFF00", "#1CE6FF", "#FF34FF", "#FF4A46", "#008941", "#006FA6", "#A30059", |
|
"#FFDBE5", "#7A4900", "#0000A6", "#63FFAC", "#B79762", "#004D43", "#8FB0FF", "#997D87", |
|
"#5A0007", "#809693", "#FEFFE6", "#1B4400", "#4FC601", "#3B5DFF", "#4A3B53", "#FF2F80", |
|
"#61615A", "#BA0900", "#6B7900", "#00C2A0", "#FFAA92", "#FF90C9", "#B903AA", "#D16100", |
|
"#DDEFFF", "#000035", "#7B4F4B", "#A1C299", "#300018", "#0AA6D8", "#013349", "#00846F", |
|
"#372101", "#FFB500", "#C2FFED", "#A079BF", "#CC0744", "#C0B9B2", "#C2FF99", "#001E09", |
|
"#00489C", "#6F0062", "#0CBD66", "#EEC3FF", "#456D75", "#B77B68", "#7A87A1", "#788D66", |
|
"#885578", "#FAD09F", "#FF8A9A", "#D157A0", "#BEC459", "#456648", "#0086ED", "#886F4C", |
|
"#34362D", "#B4A8BD", "#00A6AA", "#452C2C", "#636375", "#A3C8C9", "#FF913F", "#938A81", |
|
"#575329", "#00FECF", "#B05B6F", "#8CD0FF", "#3B9700", "#04F757", "#C8A1A1", "#1E6E00", |
|
"#7900D7", "#A77500", "#6367A9", "#A05837", "#6B002C", "#772600", "#D790FF", "#9B9700", |
|
"#549E79", "#FFF69F", "#201625", "#72418F", "#BC23FF", "#99ADC0", "#3A2465", "#922329", |
|
"#5B4534", "#FDE8DC", "#404E55", "#0089A3", "#CB7E98", "#A4E804", "#324E72", "#6A3A4C", |
|
"#83AB58", "#001C1E", "#D1F7CE", "#004B28", "#C8D0F6", "#A3A489", "#806C66", "#222800", |
|
"#BF5650", "#E83000", "#66796D", "#DA007C", "#FF1A59", "#8ADBB4", "#1E0200", "#5B4E51", |
|
"#C895C5", "#320033", "#FF6832", "#66E1D3", "#CFCDAC", "#D0AC94", "#7ED379", "#012C58", |
|
] |
|
    # Extend the palette with matplotlib's named colors that are not already included.
    hex_colors_mlib = list(matplotlib.colors.cnames.values())
    for hcm in hex_colors_mlib:
        if hcm not in hex_colors:
            hex_colors.append(hcm)
    colors = [ImageColor.getrgb(hc) for hc in hex_colors]
    return colors[:num_colors]
|
|
|
|
|
def colorize_one(seg, ignore=255, colors=None, ncolors=32):
    """Colorize a 2D label map with a distinct color per label; `ignore` pixels stay black."""
    unq = np.unique(seg)
    # Labels index directly into the palette, so at least max(unq) + 1 colors are needed.
    if ncolors is not None:
        ncolors = max(ncolors, int(max(unq)) + 1)
    else:
        ncolors = int(max(unq)) + 1
    colors = get_colors(ncolors) if colors is None else colors
    h, w = seg.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    for l in unq:
        if ignore is not None and l == ignore:
            continue
        try:
            rgb[seg == l, :] = colors[l]
        except IndexError:
            raise IndexError(f"No color available for label {l} (palette size {len(colors)})")
    return rgb
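

# Example usage (illustrative sketch):
#
#   seg = np.array([[0, 1], [2, 255]])     # arbitrary labels, 255 = ignore
#   rgb = colorize_one(seg, ignore=255)    # (2, 2, 3) uint8; ignored pixels stay black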
|
|
|
|
|
def init_weights(m): |
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=0.02) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, nn.LayerNorm): |
|
nn.init.constant_(m.bias, 0) |
|
nn.init.constant_(m.weight, 1.0) |
|
|
|
|
|
def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    """Bilinearly resize ViT position embeddings to a new patch grid,
    keeping the extra (cls/dist) token embeddings untouched."""
    posemb_tok, posemb_grid = (
        posemb[:, :num_extra_tokens],
        posemb[0, num_extra_tokens:],
    )
    if grid_old_shape is None:
        # Assume the old patch grid was square.
        gs_old_h = int(math.sqrt(len(posemb_grid)))
        gs_old_w = gs_old_h
    else:
        gs_old_h, gs_old_w = grid_old_shape

    gs_h, gs_w = grid_new_shape
    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb
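

# Example (a sketch; the embedding shapes are assumptions): resizing a position
# embedding trained on a 14x14 patch grid to a 32x32 grid, keeping one cls token.
#
#   posemb = torch.randn(1, 1 + 14 * 14, 768)
#   posemb_new = resize_pos_embed(posemb, None, (32, 32), num_extra_tokens=1)
#   assert posemb_new.shape == (1, 1 + 32 * 32, 768)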
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model): |
|
""" convert patch embedding weight from manual patchify + linear proj to conv""" |
|
out_dict = {} |
|
if "model" in state_dict: |
|
|
|
state_dict = state_dict["model"] |
|
num_extra_tokens = 1 + ("dist_token" in state_dict.keys()) |
|
patch_size = model.patch_size |
|
image_size = model.patch_embed.image_size |
|
for k, v in state_dict.items(): |
|
if k == "pos_embed" and v.shape != model.pos_embed.shape: |
|
|
|
v = resize_pos_embed( |
|
v, |
|
None, |
|
(image_size[0] // patch_size, image_size[1] // patch_size), |
|
num_extra_tokens, |
|
) |
|
out_dict[k] = v |
|
return out_dict |
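

# Example (a sketch; the checkpoint path is an assumption): filtering a ViT checkpoint
# before loading it into a model built for a different input resolution.
#
#   state_dict = torch.load("path/to/vit_checkpoint.pth", map_location="cpu")
#   state_dict = checkpoint_filter_fn(state_dict, model)
#   model.load_state_dict(state_dict, strict=False)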
|
|
|
|
|
def padding(im, patch_size, fill_value=0):
    """Pad `im` (B, C, H, W) on the bottom/right so H and W become multiples of `patch_size`."""
    H, W = im.size(2), im.size(3)
    pad_h, pad_w = 0, 0
    if H % patch_size > 0:
        pad_h = patch_size - (H % patch_size)
    if W % patch_size > 0:
        pad_w = patch_size - (W % patch_size)
    im_padded = im
    if pad_h > 0 or pad_w > 0:
        im_padded = F.pad(im, (0, pad_w, 0, pad_h), value=fill_value)
    return im_padded
|
|
|
|
|
def unpadding(y, target_size):
    """Crop away the bottom/right padding added by `padding`, back to `target_size` = (H, W)."""
    H, W = target_size
    H_pad, W_pad = y.size(2), y.size(3)

    extra_h = H_pad - H
    extra_w = W_pad - W
    if extra_h > 0:
        y = y[:, :, :-extra_h]
    if extra_w > 0:
        y = y[:, :, :, :-extra_w]
    return y
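

# Example round trip (illustrative): pad to a multiple of the patch size, then crop
# an output of the padded resolution back to the original size.
#
#   x = torch.randn(1, 3, 500, 375)
#   x_pad = padding(x, patch_size=16)    # -> (1, 3, 512, 384)
#   y = unpadding(x_pad, (500, 375))     # -> (1, 3, 500, 375)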
|
|
|
|
|
def resize(im, smaller_size):
    """Upscale `im` (B, C, H, W) so its smaller side equals `smaller_size`, preserving the
    aspect ratio. Images whose smaller side is already >= `smaller_size` are returned unchanged."""
    h, w = im.shape[2:]
    if h < w:
        ratio = w / h
        h_res, w_res = smaller_size, ratio * smaller_size
    else:
        ratio = h / w
        h_res, w_res = ratio * smaller_size, smaller_size
    if min(h, w) < smaller_size:
        im_res = F.interpolate(im, (int(h_res), int(w_res)), mode="bilinear")
    else:
        im_res = im
    return im_res
|
|
|
|
|
def sliding_window(im, flip, window_size, window_stride, channels_first=True):
    """Split `im` into overlapping square crops of side `window_size`.

    Returns a dict with the crops, their top-left anchors, the flip flag and the original
    spatial shape. Assumes window_size <= min(H, W).
    """
    if channels_first:
        B, C, H, W = im.shape
    else:
        B, H, W, C = im.shape
    ws = window_size

    windows = {"crop": [], "anchors": []}
    h_anchors = torch.arange(0, H, window_stride)
    w_anchors = torch.arange(0, W, window_stride)
    # Keep anchors whose window fits entirely inside the image and always include the last valid one.
    h_anchors = [h.item() for h in h_anchors if h < H - ws] + [H - ws]
    w_anchors = [w.item() for w in w_anchors if w < W - ws] + [W - ws]
    for ha in h_anchors:
        for wa in w_anchors:
            if channels_first:
                window = im[:, :, ha: ha + ws, wa: wa + ws]
            else:
                window = im[:, ha: ha + ws, wa: wa + ws]
            windows["crop"].append(window)
            windows["anchors"].append((ha, wa))
    windows["flip"] = flip
    windows["shape"] = (H, W)
    return windows
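

# Example (illustrative): a 512x1024 image with 256-pixel windows and a 128-pixel
# stride yields a 3x7 grid of overlapping crops.
#
#   im = torch.randn(1, 3, 512, 1024)
#   win = sliding_window(im, flip=False, window_size=256, window_stride=128)
#   len(win["crop"])  # -> 21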
|
|
|
|
|
def merge_windows(windows, window_size, ori_shape, no_softmax=False, no_upsample=False, patch_size=None):
    """Average per-window predictions back onto the full image canvas, then optionally
    upsample to `ori_shape` and apply a softmax over the class dimension."""
    ws = window_size
    im_windows = windows["seg_maps"]
    anchors = windows["anchors"]
    C = im_windows[0].shape[0]
    H, W = windows["shape"]
    flip = windows["flip"]

    if no_upsample:
        # Predictions live on the patch grid, so the canvas and anchors shrink accordingly.
        H, W = H // patch_size, W // patch_size

    logit = torch.zeros((C, H, W), device=im_windows.device)
    count = torch.zeros((1, H, W), device=im_windows.device)
    for window, (ha, wa) in zip(im_windows, anchors):
        if no_upsample:
            ha = ha // patch_size
            wa = wa // patch_size
        logit[:, ha: ha + ws, wa: wa + ws] += window
        count[:, ha: ha + ws, wa: wa + ws] += 1
    # Average the overlapping regions.
    logit /= count

    if not no_upsample:
        logit = F.interpolate(
            logit.unsqueeze(0),
            ori_shape,
            mode="bilinear",
        )[0]
    if flip:
        logit = torch.flip(logit, (2,))
    if not no_softmax:
        result = F.softmax(logit, 0)
    else:
        result = logit
    return result
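

# Example (a sketch; the per-window class maps would normally come from a model):
#
#   im = torch.randn(1, 3, 512, 1024)
#   win = sliding_window(im, flip=False, window_size=256, window_stride=128)
#   crops = torch.stack(win.pop("crop"))[:, 0]
#   win["seg_maps"] = torch.randn(len(crops), 19, 256, 256)  # fake logits, 19 classes
#   probs = merge_windows(win, 256, (512, 1024))             # (19, 512, 1024) softmax map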
|
|
|
|
|
def debug_windows(windows, debug_file):
    # Debugging/visualization hook; intentionally a no-op placeholder.
    pass
|
|
|
|
|
def inference_picie(
    model,
    classifier,
    metric_test,
    ims,
    ori_shape,
    window_size,
    window_stride,
    batch_size,
    decoder_features=False,
    no_upsample=False,
    debug_file=None,
    im_rgb=None,
    channel_first=False
):
    """Sliding-window inference where class scores come from an external `classifier`
    applied to `model` features (PiCIE-style evaluation)."""
    try:
        C = model.n_cls
    except AttributeError:
        C = classifier.module.bias.shape[0]

    for im in ims:
        im = im.to('cuda')
        if len(im.shape) == 3:
            im = im.unsqueeze(0)
        flip = False
        windows = sliding_window(im, flip, window_size, window_stride)
        crops = torch.stack(windows.pop("crop"))[:, 0]
        num_crops = len(crops)

        # Process the crops in mini-batches of size WB.
        WB = batch_size if batch_size > 0 else num_crops
        if no_upsample:
            window_size = window_size // model.patch_size
        seg_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
        with torch.no_grad():
            for i in range(0, num_crops, WB):
                feats = model.forward(crops[i: i + WB])
                if metric_test == 'cosine':
                    feats = F.normalize(feats, dim=1, p=2)
                probs = classifier(feats)
                probs = F.interpolate(probs, crops[i: i + WB].shape[-2:], mode='bilinear', align_corners=False)
                seg_maps[i: i + WB] = probs
        windows["seg_maps"] = seg_maps

        im_seg_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
                                   no_upsample=no_upsample,
                                   patch_size=model.patch_size if no_upsample else None)

        seg_map = im_seg_map
        if no_upsample and not decoder_features:
            pass
        else:
            seg_map = F.interpolate(
                seg_map.unsqueeze(0),
                ori_shape,
                mode="bilinear",
            )

    # Only the map of the last image in `ims` is returned.
    return seg_map
|
|
|
|
|
def inference(
    model,
    ims,
    ori_shape,
    window_size,
    window_stride,
    batch_size,
    decoder_features=False,
    encoder_features=False,
    no_upsample=False,
):
    """Sliding-window semantic segmentation inference for a model exposing `n_cls` and `patch_size`."""
    C = model.n_cls

    for im in ims:
        if len(im.shape) == 3:
            im = im.unsqueeze(0)

        flip = False

        windows = sliding_window(im, flip, window_size, window_stride)
        crops = torch.stack(windows.pop("crop"))[:, 0]
        num_crops = len(crops)

        # Process the crops in mini-batches of size WB.
        WB = batch_size if batch_size > 0 else num_crops
        if no_upsample:
            window_size = window_size // model.patch_size

        seg_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)

        with torch.no_grad():
            for i in range(0, num_crops, WB):
                seg_maps[i: i + WB] = model.forward(crops[i: i + WB], decoder_features=decoder_features,
                                                    encoder_features=encoder_features,
                                                    no_upsample=no_upsample)
        windows["seg_maps"] = seg_maps

        im_seg_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
                                   no_upsample=no_upsample, patch_size=model.patch_size)

        seg_map = im_seg_map
        if no_upsample and not decoder_features:
            pass
        else:
            seg_map = F.interpolate(
                seg_map.unsqueeze(0),
                ori_shape,
                mode="bilinear",
            )

    # Only the map of the last image in `ims` is returned.
    return seg_map
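

# Example call (a sketch; `model` is assumed to expose `n_cls` and `patch_size` and to
# accept the keyword arguments used above):
#
#   seg_map = inference(model, [im], ori_shape=im.shape[-2:], window_size=768,
#                       window_stride=512, batch_size=2)
#   pred = seg_map.argmax(1)  # per-pixel class indices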
|
|
|
|
|
def inference_features(
    model,
    ims,
    ori_shape,
    window_size,
    window_stride,
    batch_size,
    decoder_features=False,
    encoder_features=False,
    save2cpu=False,
    no_upsample=True,
    encoder_only=False
):
    """Sliding-window feature extraction; returns encoder feature maps and, when
    `decoder_features` is set, decoder feature maps as well."""
    # Feature dimensionality: decoder features have n_cls channels, encoder features d_model channels.
    C = model.n_cls if decoder_features else model.encoder.d_model

    for im in ims:
        im = im.to('cuda')
        if len(im.shape) == 3:
            im = im.unsqueeze(0)

        flip = False

        windows = sliding_window(im, flip, window_size, window_stride)
        crops = torch.stack(windows.pop("crop"))[:, 0]
        num_crops = len(crops)

        # Process the crops in mini-batches of size WB.
        WB = batch_size if batch_size > 0 else num_crops
        if no_upsample:
            window_size = window_size // model.patch_size

        enc_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
        if decoder_features:
            dec_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)

        with torch.no_grad():
            for i in range(0, num_crops, WB):
                enc_fts = model.forward(crops[i: i + WB], decoder_features=decoder_features,
                                        encoder_features=True,
                                        no_upsample=no_upsample, encoder_only=encoder_only)
                if decoder_features:
                    enc_fts, dec_fts = enc_fts
                    dec_maps[i: i + WB] = dec_fts
                elif isinstance(enc_fts, tuple):
                    enc_fts = enc_fts[0]
                enc_maps[i: i + WB] = enc_fts

        windows["seg_maps"] = enc_maps
        im_enc_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
                                   no_upsample=no_upsample, patch_size=model.patch_size)

        if decoder_features:
            windows["seg_maps"] = dec_maps
            im_dec_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
                                       no_upsample=no_upsample, patch_size=model.patch_size)

        if no_upsample:
            pass
        else:
            im_enc_map = F.interpolate(
                im_enc_map.unsqueeze(0),
                ori_shape,
                mode="bilinear",
            )
            if decoder_features:
                im_dec_map = F.interpolate(
                    im_dec_map.unsqueeze(0),
                    ori_shape,
                    mode="bilinear",
                )

        im_enc_map = im_enc_map.cpu().numpy()
        if decoder_features:
            im_dec_map = im_dec_map.cpu().numpy()
            # When decoder features are requested, results for the first image are returned directly.
            return im_enc_map, im_dec_map

    return im_enc_map
|
|
|
|
|
def inference_conv(
    model,
    ims,
    ims_metas,
    ori_shape
):
    """Whole-image inference for a convolutional model; returns a softmax probability
    map at `ori_shape` resolution."""
    assert len(ims) == 1
    for im, im_metas in zip(ims, ims_metas):
        im = im.to('cuda')
        if len(im.shape) < 4:
            im = im.unsqueeze(0)
        logits = model(im)
        if ori_shape[:2] != logits.shape[-2:]:
            logits = F.interpolate(
                logits,
                ori_shape[-2:],
                mode="bilinear",
            )

        result = F.softmax(logits.squeeze(), 0)

    return result
|
|
|
|
|
def num_params(model):
    """Count the trainable parameters of `model`."""
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    n_params = sum(int(torch.prod(torch.tensor(p.size()))) for p in model_parameters)
    return n_params
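

# Example (illustrative):
#
#   mlp = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
#   num_params(mlp)  # -> 325 trainable parameters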
|
|