# Modified from:
# https://github.com/anibali/pytorch-stacked-hourglass
# https://github.com/bearpaw/pytorch-pose
import io

import matplotlib as mpl
mpl.use('Agg')  # non-interactive backend; must be selected before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
# import stacked_hourglass.datasets.utils_stanext as utils_stanext
# COLORS, labels = utils_stanext.load_keypoint_labels_and_colours()
# Fixed per-keypoint plotting colours (originally loaded via utils_stanext,
# see the commented-out import above); one hex colour per keypoint index.
COLORS = ['#d82400', '#d82400', '#d82400', '#fcfc00', '#fcfc00', '#fcfc00', '#48b455', '#48b455', '#48b455', '#0090aa', '#0090aa', '#0090aa', '#d848ff', '#d848ff', '#fc90aa', '#006caa', '#d89000', '#d89000', '#fc90aa', '#006caa', '#ededed', '#ededed', '#a9d08e', '#a9d08e']
# Per-channel RGB statistics used by transforms.color_normalize(); the
# save_* helpers below use them to undo that normalization for display.
RGB_MEAN = [0.4404, 0.4440, 0.4327]
RGB_STD = [0.2458, 0.2410, 0.2468]
def get_img_from_fig(fig, dpi=180):
    """Rasterize a matplotlib figure into an RGB uint8 image array.

    Args:
        fig: matplotlib figure to render.
        dpi: resolution used when rendering the figure.

    Returns:
        np.ndarray image in RGB channel order (shape (H, W, 3)).
    """
    # BUGFIX: `io` was used here but never imported anywhere in the file
    # (NameError at runtime); the import is now at the top of the file.
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=dpi)
        buf.seek(0)
        img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    img = cv2.imdecode(img_arr, 1)
    # cv2 decodes to BGR; convert to RGB for consistency with matplotlib.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
def save_input_image_with_keypoints(img, tpts, out_path='./test_input.png', colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD, ratio_in_out=4., threshold=0.3, print_scores=False):
    """Save `img` with the keypoints in `tpts` drawn on top.

    img: torch tensor of shape (3, 256, 256), color-normalized (modified
        in place: the mean is added back).
    tpts: torch tensor of shape (n_kpts, 3) with rows (x, y, score); points
        with score <= `threshold` are skipped. Coordinates are scaled by
        `ratio_in_out` (heatmap -> image resolution).
    (tested with the mpii dataset; results look ok)
    """
    # undo the color normalization in place (inverse of transforms.color_normalize())
    for channel, mean in zip(img, rgb_mean):
        channel.add_(mean)
    image_hwc = img.detach().cpu().numpy().transpose(1, 2, 0)
    # draw the image without any axes or margins
    fig, ax = plt.subplots()
    plt.imshow(image_hwc)
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    # overlay every keypoint whose score exceeds the threshold
    for kp_idx, (kp_x, kp_y, kp_score) in enumerate(tpts):
        if not (kp_score > threshold):
            continue
        px = int(kp_x * ratio_in_out)  # heatmap -> image coordinates
        py = int(kp_y * ratio_in_out)
        plt.scatter([px], [py], c=[colors[kp_idx]], marker="x", s=50)
        if print_scores:
            txt = '{:2.2f}'.format(kp_score.item())
            plt.annotate(txt, (px, py))  # , c=colors[kp_idx])
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
    plt.close()
    return
def save_input_image(img, out_path, colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD):
    """Undo color normalization on a (3, H, W) torch tensor and save it.

    NOTE(review): `colors` and `rgb_std` are accepted for signature symmetry
    with save_input_image_with_keypoints but are unused. `img` is modified
    in place (the mean is added back).
    """
    # inverse of transforms.color_normalize(), which subtracts the mean
    for channel, mean in zip(img, rgb_mean):
        channel.add_(mean)
    image_hwc = img.detach().cpu().numpy().transpose(1, 2, 0)
    plt.imsave(out_path, image_hwc)
    return
###################################################################### | |
def get_bodypart_colors():
    """Build the RGBA colour tables used to visualise part segmentations.

    Returns:
        (body_cols, head_cols, torso_cols): lists of RGBA tuples with
        8 full-body colours (gist_rainbow), 5 head colours (Blues) and
        2 torso colours (Greens).
    """
    def _colour_ramp(n, cmap, pad):
        # Sample n colours from `cmap`, normalized over [1 - pad, n + pad].
        values = np.arange(1, n + 1)
        norm = mpl.colors.Normalize(vmin=values.min() - pad, vmax=values.max() + pad)
        mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        mappable.set_array([])
        return [mappable.to_rgba(i + 1) for i in range(n)]

    # body colors
    body_cols = _colour_ramp(8, mpl.cm.gist_rainbow, 0)
    # BUGFIX: the original iterated range(n_body) (8 entries) for the head
    # and torso ramps as well (copy-paste), appending clamped surplus
    # colours beyond index 4 / 1 that were never used by any caller here.
    # head colors
    head_cols = _colour_ramp(5, mpl.cm.Blues, 1)
    # torso colors
    torso_cols = _colour_ramp(2, mpl.cm.Greens, 1)
    return body_cols, head_cols, torso_cols
# Colour tables shared by the visualisation helpers below.
body_cols, head_cols, torso_cols = get_bodypart_colors()

# Channel ranges [start, end) of each top-level body part group inside the
# stacked part-segmentation prediction: 8 full-body channels, then 5 head
# sub-part channels, then 2 torso sub-part channels.
tbp_dict = {'full_body': [0, 8],
            'head': [8, 13],
            'torso': [13, 15]}
def save_image_with_part_segmentation(partseg_big, seg_big, input_image_np, ind_img, out_path_seg=None, out_path_seg_overlay=None, thr=0.3):
    """Colourise a hierarchical part-segmentation prediction and save it.

    Args:
        partseg_big: part-segmentation logits; assumed (bs, 15, 256, 256)
            with channels grouped as in tbp_dict -- TODO confirm with caller.
        seg_big: foreground/background logits, assumed (bs, 2, 256, 256)
            with channel 1 = foreground.
        input_image_np: (256, 256, 3) image used for the overlay.
        ind_img: batch index to visualise.
        out_path_seg: if given, path where the plain segmentation is saved.
        out_path_seg_overlay: if given, path where the overlay is saved.
        thr: foreground probability threshold; below it a pixel is background.
    """
    soft_max = torch.nn.Softmax(dim=0)
    # per top-level body part: softmax over its channels + per-pixel argmax
    tbp_dict_res = {}
    for part in ['full_body', 'head', 'torso']:
        start, end = tbp_dict[part]
        seg_probs = soft_max(partseg_big[ind_img, start:end, :, :])
        max_vals, max_inds = seg_probs.max(axis=0)
        tbp_dict_res[part] = {
            'inds': tbp_dict[part],
            'seg_probs': seg_probs,
            'seg_max_inds': max_inds,
            'seg_max_values': max_vals}
    # compose the colour image: head/torso sub-part colours where the
    # full-body argmax is head (1) or torso (2), full-body colours elsewhere
    partseg_image = np.zeros((256, 256, 3))
    full_body_inds = tbp_dict_res['full_body']['seg_max_inds']
    for ind_sp in range(0, 5):
        mask = (full_body_inds == 1) & (tbp_dict_res['head']['seg_max_inds'] == ind_sp)
        partseg_image[mask, :] = head_cols[ind_sp][0:3]
    for ind_sp in range(0, 2):
        mask = (full_body_inds == 2) & (tbp_dict_res['torso']['seg_max_inds'] == ind_sp)
        partseg_image[mask, :] = torso_cols[ind_sp][0:3]
    for ind_sp in range(0, 8):
        if ind_sp not in (1, 2):  # 1: head, 2: torso (painted above)
            partseg_image[full_body_inds == ind_sp, :] = body_cols[ind_sp][0:3]
    # hoisted: the original recomputed this softmax/threshold three times
    background = soft_max(seg_big[ind_img, :, :, :])[1, :, :] < thr
    partseg_image[background, :] = 0
    # save images
    if out_path_seg is not None:
        plt.imsave(out_path_seg, partseg_image)
    if out_path_seg_overlay is not None:
        # keep the raw input in background pixels, then blend 50/50
        partseg_image[background, :] = input_image_np[background, :]
        im_masked_partseg = cv2.addWeighted(input_image_np.astype(np.float32), 0.5, partseg_image.astype(np.float32), 0.5, 0)
        plt.imsave(out_path_seg_overlay, im_masked_partseg)
    return
def save_image_with_part_segmentation_from_gt_annotation(partseg_annots, out_path, ind_img=0):
    """Colourise ground-truth part-segmentation annotations and save them.

    partseg_annots: (bs, 3, 256, 256); presumably channel 0 holds full-body
    part indices, channel 1 head sub-part indices and channel 2 torso
    sub-part indices (inferred from the colour tables -- verify with caller).
    """
    annots = partseg_annots[ind_img, :, :, :]
    partseg_image = np.zeros((256, 256, 3))
    # paint channel by channel; later channels overwrite earlier ones
    # wherever their index masks overlap (same order as the original)
    paint_plan = ((0, body_cols, 8), (1, head_cols, 5), (2, torso_cols, 2))
    for channel, colour_table, n_parts in paint_plan:
        for part_idx in range(n_parts):
            partseg_image[annots[channel, :, :] == part_idx, :] = colour_table[part_idx][0:3]
    plt.imsave(out_path, partseg_image.astype(np.float32))
    return
def save_image_from_prepared_partseg(partseg_init, out_path):
    """Save a colour visualisation of a prepared part-segmentation map.

    Args:
        partseg_init: (H, W, 11) numpy array of per-part scores; channel 0
            is the head, channels 1..6 map to body colours 2..7 and
            channels 7..10 to head sub-part colours 1..4 (presumably --
            inferred from the colour-table indexing below).
        out_path: destination image path.
    """
    partseg = np.argmax(partseg_init, axis=2)
    # generalized: take the output size from the input instead of the
    # hard-coded 256x256 of the original (identical for 256x256 inputs)
    height, width = partseg_init.shape[0], partseg_init.shape[1]
    partseg_image = np.zeros((height, width, 3))
    for ind in range(partseg_init.shape[2]):
        if ind == 0:  # head
            partseg_image[partseg == ind, :] = np.asarray(head_cols[0][0:3])
        elif ind < 7:
            partseg_image[partseg == ind, :] = np.asarray(body_cols[ind + 1][0:3])
        else:  # 7 to 10
            partseg_image[partseg == ind, :] = np.asarray(head_cols[ind - 6][0:3])
    # pixels with no activation in any channel stay black
    partseg_image[partseg_init.sum(axis=2) == 0, :] = 0
    plt.imsave(out_path, partseg_image)
    return