Spaces:
Runtime error
Runtime error
File size: 5,994 Bytes
753fd9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import torch.nn as nn
import torch
from torch.nn.functional import mse_loss
# for NEW: losses when calculated on keypoint locations
# see https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/subpix/dsnt.html
# from kornia.geometry import dsnt # old kornia version
from kornia.geometry.subpix import dsnt # kornia 0.4.0
def joints_mse_loss_orig(output, target, target_weight=None):
    """Return half the mean squared error between predicted and target heatmaps.

    The result equals the average over joints of ``0.5 * mse`` computed per
    joint, because every joint contributes the same number of elements.

    Args:
        output: Predicted heatmaps. Dim 0 is the batch, dim 1 the joints; all
            remaining dims are flattened.
        target: Ground-truth heatmaps with the same layout as ``output``.
        target_weight: Optional weights whose leading dims are batch and
            joints, e.g. ``(batch, joints, 1)`` or ``(batch, joints)``.
            Prediction and target are both multiplied by the weight before
            the squared error, so a sample's contribution scales with the
            square of its weight.

    Returns:
        Scalar tensor holding the loss.
    """
    batch_size = output.size(0)
    num_joints = output.size(1)
    # reshape (rather than view) also accepts non-contiguous inputs.
    pred = output.reshape(batch_size, num_joints, -1)
    gt = target.reshape(batch_size, num_joints, -1)

    if target_weight is not None:
        # Force the weights to (batch, joints, k) so they line up with the
        # batch and joint axes. Multiplying a (batch, 1, HW) heatmap by a
        # (batch, 1) weight slice would broadcast to (batch, batch, HW) and
        # apply every sample's weight to every other sample.
        weight = target_weight.reshape(batch_size, num_joints, -1)
        pred = pred * weight
        gt = gt * weight

    return 0.5 * mse_loss(pred, gt, reduction='mean')
class JointsMSELoss(nn.Module):
    """Module wrapper around ``joints_mse_loss_orig``.

    Construction always fails with ``NotImplementedError`` (raised at the end
    of ``__init__``), so this class cannot currently be instantiated.
    """

    def __init__(self, use_target_weight=True):
        super().__init__()
        self.use_target_weight = use_target_weight
        # NOTE(review): the constructor is disabled on purpose, presumably to
        # steer callers to another loss -- confirm before removing.
        raise NotImplementedError

    def forward(self, output, target, target_weight):
        """Compute the heatmap loss, dropping the weights when disabled."""
        weight = target_weight if self.use_target_weight else None
        return joints_mse_loss_orig(output, target, weight)
# ----- NEW: losses when calculated on keypoint locations instead of keypoint heatmaps -----
def joints_mse_loss_onKPloc(output, target, meta, target_weight=None):
    """Combine a heatmap MSE term with a keypoint-distance term.

    The predicted logits are turned into spatial probability maps; a heatmap
    loss is computed against ``target`` and a weighted mean Euclidean distance
    is computed between the expected keypoint locations and ``meta['tpts']``.

    Args:
        output: Predicted heatmap logits.
        target: Target heatmaps, used unchanged (the original code notes they
            are already normalized).
        meta: Mapping with key ``'tpts'``; the last axis holds two coordinates
            followed by a weight. The original comments give its shape as
            (bs, 20, 3) -- TODO confirm.
        target_weight: Optional weights forwarded to ``joints_mse_loss_orig``.

    Returns:
        ``dist_loss * 1e-4 + heatmap_loss * 100``.
    """
    gt_heatmaps = target

    # Logits -> per-joint probability distribution over pixel locations.
    pred_heatmaps = dsnt.spatial_softmax2d(output, temperature=torch.tensor(1))

    # Heatmap term, evaluated on the normalized prediction.
    heatmap_loss = joints_mse_loss_orig(pred_heatmaps, gt_heatmaps, target_weight)

    # Expected keypoint location in pixels. The +1 offset presumably matches
    # the coordinate convention of meta['tpts'] -- TODO confirm.
    pred_kp = dsnt.spatial_expectation2d(pred_heatmaps, normalized_coordinates=False) + 1
    gt_kp = meta['tpts'].to(pred_kp.device)

    pred_xy = pred_kp.reshape((-1, 2))
    gt_xy = gt_kp[:, :, :2].reshape((-1, 2))
    kp_weights = gt_kp[:, :, 2].reshape((-1))

    # Only keypoints with a strictly positive weight take part in the
    # distance term.
    has_weight = kp_weights > 0
    active_weights = kp_weights[has_weight]
    distances = ((pred_xy - gt_xy)[has_weight] ** 2).sum(axis=1).sqrt()
    # The 1e-5 floor keeps the division defined when no keypoint is weighted.
    dist_loss = (distances * active_weights).sum() / max(active_weights.sum(), 1e-5)

    # Both terms are used; a distance-only variant would return just
    # dist_loss * 1e-4.
    return dist_loss * 1e-4 + heatmap_loss * 100
class JointsMSELoss_onKPloc(nn.Module):
    """Module wrapper around ``joints_mse_loss_onKPloc``.

    Args:
        use_target_weight: When false, ``target_weight`` is replaced by
            ``None`` before the loss function is called.
    """

    def __init__(self, use_target_weight=True):
        super().__init__()
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight, meta=None):
        """Compute the keypoint-location loss.

        Args:
            output: Predicted heatmap logits.
            target: Target heatmaps.
            target_weight: Weights for the heatmap term; ignored when
                ``use_target_weight`` is false.
            meta: Mapping passed through to ``joints_mse_loss_onKPloc``, which
                reads ``meta['tpts']``. It is a trailing optional parameter
                so the existing positional signature stays valid, but a
                value is required.

        Raises:
            ValueError: If ``meta`` is not supplied.
        """
        # The previous body referenced an undefined name `meta`, which raised
        # NameError on every call; it is now an explicit argument.
        if meta is None:
            raise ValueError(
                "JointsMSELoss_onKPloc.forward requires `meta` (with key 'tpts')"
            )
        if not self.use_target_weight:
            target_weight = None
        return joints_mse_loss_onKPloc(output, target, meta, target_weight)
# ----- NEW: segmentation loss -----
import torch.nn.functional as F
'''def resize2d(img, size):
return (F.adaptive_avg_pool2d(Variable(img,volatile=True), size)).data
# F.adaptive_avg_pool2d(meta['silh'], (64,64))).data'''
def segmentation_loss(output, meta):
    """Cross-entropy loss between predicted logits and the target silhouette.

    Args:
        output: Segmentation logits. The original comment gives an example
            shape of (6, 2, 64, 64) -- class axis at dim 1.
        meta: Mapping with key ``'silh'`` holding the target silhouette.

    Returns:
        Scalar cross-entropy loss. When ``output.shape[2] == 64`` the
        silhouette is average-pooled to 64x64 and binarized at 0.5 first;
        otherwise it is cast to long and used at its own resolution.
    """
    silhouette = meta['silh']

    if output.shape[2] != 64:
        # Target is used directly as class indices.
        return F.cross_entropy(output, silhouette.to(torch.long))

    # Downsample, then label a cell foreground only when strictly more than
    # half of the pooled area is covered.
    pooled = F.adaptive_avg_pool2d(silhouette, (64, 64))
    labels_64 = (pooled > 0.5).to(torch.long)
    return F.cross_entropy(output, labels_64)
|