Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from scipy.io import loadmat | |
def init_spixel_grid(args, b_train=True, ratio=1, downsize=16, device='cuda'):
    """Build a per-pixel (x, y) coordinate grid for a square crop.

    Args:
        args: object with a ``crop_size`` attribute; the grid is
            ``crop_size`` x ``crop_size``.
        b_train: unused; kept for call-site compatibility.
        ratio: unused; kept for call-site compatibility.
        downsize: unused; kept for call-site compatibility.
        device: device the result is moved to. Defaults to ``'cuda'`` to
            match the previous hard-coded ``.cuda()`` behaviour; pass
            ``'cpu'`` on machines without a GPU.

    Returns:
        float32 tensor of shape (1, 2, crop_size, crop_size). Channel 0 holds
        each pixel's column index (x) and channel 1 its row index (y).
    """
    height = width = args.crop_size
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    # x (column) first, then y (row), with a leading batch axis of size 1.
    coords = np.stack([cols, rows])[np.newaxis].astype(np.float32)
    return torch.from_numpy(coords).to(device)
def label2one_hot_torch(labels, C=14):
    """Expand an integer label map into a one-hot float32 tensor.

    Args:
        labels: tensor of shape (B, 1, H, W) holding class indices in
            ``[0, C)``; any dtype is accepted and is cast to long for
            indexing.
        C: number of classes, i.e. the size of the output channel axis.

    Returns:
        float32 tensor of shape (B, C, H, W) with a 1 in the channel named by
        each pixel's label and 0 elsewhere.
    """
    batch, _, height, width = labels.shape
    # scatter_ requires a long index tensor.
    index = labels.type(torch.long).data
    # .to(labels) matches the dtype and device of `labels`.
    canvas = torch.zeros(batch, C, height, width, dtype=torch.long).to(labels)
    return canvas.scatter_(1, index, 1).type(torch.float32)
# Color palette read from data/color150.mat (path is relative to the current
# working directory), stacked four times along axis 0 so that colors[k] for
# k < 4 * N equals the original palette's row k % N (N = palette length).
colors = np.concatenate([loadmat('data/color150.mat')['colors']] * 4)
def unique(ar, return_index=False, return_inverse=False, return_counts=False):
    """Return the sorted unique elements of ``ar`` (flattened).

    This mirrors the legacy ``numpy.unique`` implementation. The input is
    flattened into a copy, so the caller's array is never modified.

    Args:
        ar: array-like input; it is flattened before processing.
        return_index: also return the index of the first occurrence of each
            unique value in the flattened input.
        return_inverse: also return indices that reconstruct the flattened
            input from the unique values.
        return_counts: also return how many times each unique value occurs.

    Returns:
        The unique values as an array when no optional output is requested.
        Otherwise, a tuple ``(unique, [index], [inverse], [counts])`` that
        contains only the requested extras, in that order.
    """
    ar = np.asanyarray(ar).flatten()
    optional_indices = return_index or return_inverse
    optional_returns = optional_indices or return_counts

    if ar.size == 0:
        if not optional_returns:
            return ar
        ret = (ar,)
        # Index-type outputs must be np.intp. The former np.bool alias was
        # both the wrong dtype and removed in NumPy 1.24 (AttributeError).
        if return_index:
            ret += (np.empty(0, np.intp),)
        if return_inverse:
            ret += (np.empty(0, np.intp),)
        if return_counts:
            ret += (np.empty(0, np.intp),)
        return ret

    if optional_indices:
        # A stable sort is needed for return_index so the first occurrence
        # of each value is the one reported.
        perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
        aux = ar[perm]
    else:
        ar.sort()  # sorts the flattened copy, not the caller's array
        aux = ar
    # True at the first element of every run of equal values.
    flag = np.concatenate(([True], aux[1:] != aux[:-1]))

    if not optional_returns:
        return aux[flag]

    ret = (aux[flag],)
    if return_index:
        ret += (perm[flag],)
    if return_inverse:
        iflag = np.cumsum(flag) - 1
        inv_idx = np.empty(ar.shape, dtype=np.intp)
        inv_idx[perm] = iflag
        ret += (inv_idx,)
    if return_counts:
        idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
        ret += (np.diff(idx),)
    return ret
def colorEncode(labelmap, mode='RGB', palette=None):
    """Render a 2-D label map as a uint8 color image.

    Pixel (i, j) with label k >= 0 receives ``palette[k]``. Pixels with
    negative labels stay black.

    Args:
        labelmap: 2-D array of labels; values are cast to ``int``.
        mode: ``'BGR'`` returns the channels reversed; any other value
            returns RGB order.
        palette: (N, 3) color table indexed by label. Defaults to the
            module-level ``colors`` table.

    Returns:
        uint8 array of shape (H, W, 3). In BGR mode this is a reversed view
        of the RGB array.

    Raises:
        IndexError: if any label is >= ``len(palette)``.
    """
    if palette is None:
        palette = colors
    palette = np.asarray(palette)
    labelmap = np.asarray(labelmap).astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)
    # One fancy-indexed assignment replaces the former per-label loop, which
    # tiled a full H x W x 3 array for every distinct label.
    valid = labelmap >= 0
    labelmap_rgb[valid] = palette[labelmap[valid]]
    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    return labelmap_rgb
def _neighbour_pairs(label_map):
    """Return (a, b) view pairs of a 2-D map covering each 8-connected pixel pair once."""
    return (
        (label_map[:-1, :], label_map[1:, :]),      # vertical
        (label_map[:, :-1], label_map[:, 1:]),      # horizontal
        (label_map[:-1, :-1], label_map[1:, 1:]),   # diagonal, down-right
        (label_map[:-1, 1:], label_map[1:, :-1]),   # diagonal, down-left
    )


def get_edges(sp_label, sp_num, device='cuda'):
    """Build an undirected superpixel adjacency graph for each image in a batch.

    Two superpixels are neighbours when any of their pixels touch
    horizontally, vertically or diagonally. No self loops are produced.

    Args:
        sp_label: label tensor indexed as (batch, channel, H, W). Only
            channel 0 is read. Labels must be integral values in
            ``[0, sp_num)`` (they are cast with ``.long()``).
        sp_num: number of superpixels, i.e. the side length of the
            adjacency matrix.
        device: device for the adjacency matrix and returned edge indices.
            Defaults to ``'cuda'`` to match the previous hard-coded
            ``.cuda()`` behaviour.

    Returns:
        List with one int64 tensor of shape (2, E) per image. Column k is a
        directed edge (src, dst). Both directions of every undirected edge
        are present, ordered row-major by (src, dst).
    """
    edge_indices = []
    for sample in sp_label:
        label_map = sample[0]
        adjacency = torch.zeros(sp_num, sp_num, device=device)
        for first, second in _neighbour_pairs(label_map):
            differs = first != second
            src = first[differs].long().to(device)
            dst = second[differs].long().to(device)
            adjacency[src, dst] = 1
            adjacency[dst, src] = 1
        edge_indices.append(torch.stack(torch.nonzero(adjacency, as_tuple=True)))
    return edge_indices
def draw_color_seg(seg):
    """Colorize a batch of label maps for visualization.

    Args:
        seg: tensor whose first axis is the batch. Each item is squeezed and
            passed to ``colorEncode``, so it must squeeze to a 2-D label map.

    Returns:
        float32 tensor of shape (B, 3, H, W) with values in [0, 1].
    """
    batch = seg.detach().cpu().numpy()
    frames = [
        torch.from_numpy(colorEncode(item.squeeze()) / 255.0).float().permute(2, 0, 1)
        for item in batch
    ]
    return torch.stack(frames)