Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch | |
from torch.nn.functional import mse_loss | |
# for NEW: losses when calculated on keypoint locations | |
# see https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/subpix/dsnt.html | |
# from kornia.geometry import dsnt # old kornia version | |
from kornia.geometry.subpix import dsnt # kornia 0.4.0 | |
def joints_mse_loss_orig(output, target, target_weight=None): | |
batch_size = output.size(0) | |
num_joints = output.size(1) | |
heatmaps_pred = output.view((batch_size, num_joints, -1)).split(1, 1) | |
heatmaps_gt = target.view((batch_size, num_joints, -1)).split(1, 1) | |
loss = 0 | |
for idx in range(num_joints): | |
heatmap_pred = heatmaps_pred[idx] | |
heatmap_gt = heatmaps_gt[idx] | |
if target_weight is None: | |
loss += 0.5 * mse_loss(heatmap_pred, heatmap_gt, reduction='mean') | |
else: | |
loss += 0.5 * mse_loss( | |
heatmap_pred.mul(target_weight[:, idx]), | |
heatmap_gt.mul(target_weight[:, idx]), | |
reduction='mean' | |
) | |
return loss / num_joints | |
class JointsMSELoss(nn.Module): | |
def __init__(self, use_target_weight=True): | |
super().__init__() | |
self.use_target_weight = use_target_weight | |
raise NotImplementedError | |
def forward(self, output, target, target_weight): | |
if not self.use_target_weight: | |
target_weight = None | |
return joints_mse_loss_orig(output, target, target_weight) | |
# ----- NEW: losses when calculated on keypoint locations instead of keypoint heatmaps ----- | |
def joints_mse_loss_onKPloc(output, target, meta, target_weight=None): | |
# debugging: | |
# for old kornia version | |
# output_softmax_2d = dsnt.spatial_softmax_2d(target, temperature=torch.tensor(100)) | |
# output_kp = dsnt.spatial_softargmax_2d(output_softmax_2d, normalized_coordinates=False) + 1 | |
# print(output_kp[0]) | |
# print(meta['tpts'][0]) | |
# render gaussian | |
# dsnt.render_gaussian_2d(meta['tpts'][0][0, :2].to('cpu'), torch.tensor(([5., 5.])).to('cpu'), [256, 256], False) | |
# output_softmax_2d = dsnt.spatial_softmax_2d(output, temperature=torch.tensor(100)) | |
# target_norm = target / target.sum(axis=3).sum(axis=2)[:, :, None, None] | |
# output_softmax_2d = dsnt.spatial_softmax_2d(output*10) # (target, temperature=torch.tensor(10)) | |
# output_kp = dsnt.spatial_softargmax_2d(target_norm, normalized_coordinates=False) + 1 | |
# normalize target heatmap | |
'''target_sum = target.sum(axis=3).sum(axis=2)[:, :, None, None] | |
target_sum[target_sum==0] = 1e-2 | |
target_norm = target / target_sum''' | |
target_norm = target # now we have normalized heatmaps | |
# normalize predictions -> from logits to probability distribution | |
output_norm = dsnt.spatial_softmax2d(output, temperature=torch.tensor(1)) | |
# heatmap loss (for normalization) | |
heatmap_loss = joints_mse_loss_orig(output_norm, target_norm, target_weight) | |
# keypoint distance loss (average distance in pixels) | |
output_kp = dsnt.spatial_expectation2d(output_norm, normalized_coordinates=False) + 1 # (bs, 20, 2) | |
target_kp = meta['tpts'].to(output_kp.device) # (bs, 20, 3) | |
output_kp_resh = output_kp.reshape((-1, 2)) | |
target_kp_resh = target_kp[:, :, :2].reshape((-1, 2)) | |
weights_resh = target_kp[:, :, 2].reshape((-1)) | |
# dist_loss = (((output_kp_resh - target_kp_resh)**2).sum(axis=1).sqrt()*weights_resh)[weights_resh>0].sum() / min(weights_resh[weights_resh>0].sum(), 1e-5) | |
dist_loss = (((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0]).sum() / max(weights_resh[weights_resh>0].sum(), 1e-5) | |
# return heatmap_loss*100 # + 0.0001*dist_loss | |
# import pdb; pdb.set_trace() | |
'''import matplotlib as mpl | |
mpl.use('Agg') | |
import matplotlib.pyplot as plt | |
img_np = output_norm[0, :, :, :].detach().cpu().numpy().transpose(1, 2, 0)[:, :, :3] | |
img_np = img_np * 255./ img_np.max() | |
# plot image | |
plt.imshow(img_np) | |
plt.savefig('./debugging_output/test_output.png') | |
plt.close() | |
img_np = target_norm[0, :, :, :].detach().cpu().numpy().transpose(1, 2, 0)[:, :, :3] | |
img_np = img_np * 255./ img_np.max() | |
# plot image | |
plt.imshow(img_np) | |
plt.savefig('./debugging_output/test_gt.png') | |
plt.close()''' | |
# print(heatmap_loss*100) | |
# print(dist_loss * 1e-4) | |
# distlossonly: return dist_loss * 1e-4 | |
# both: return dist_loss * 1e-4 + heatmap_loss*100 | |
return dist_loss * 1e-4 + heatmap_loss*100 | |
class JointsMSELoss_onKPloc(nn.Module): | |
def __init__(self, use_target_weight=True): | |
super().__init__() | |
self.use_target_weight = use_target_weight | |
def forward(self, output, target, target_weight): | |
if not self.use_target_weight: | |
target_weight = None | |
return joints_mse_loss_onKPloc(output, target, meta, target_weight) | |
# ----- NEW: lsegmentation loss ----- | |
import torch.nn.functional as F | |
'''def resize2d(img, size): | |
return (F.adaptive_avg_pool2d(Variable(img,volatile=True), size)).data | |
# F.adaptive_avg_pool2d(meta['silh'], (64,64))).data''' | |
def segmentation_loss(output, meta): | |
# output: (6, 2, 64, 64) | |
# meta.keys(): ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'silh'] | |
# prepare target silhouettes | |
target_silh = meta['silh'] | |
target_silh_l = target_silh.to(torch.long) | |
criterion_ce = nn.CrossEntropyLoss() | |
if output.shape[2] == 64: | |
target_silh_64 = F.adaptive_avg_pool2d(target_silh, (64,64)) | |
target_silh_64[target_silh_64>0.5] = 1 | |
target_silh_64[target_silh_64<=0.5] = 0 | |
target_silh_64_l = target_silh_64.to(torch.long) | |
loss_silh_64 = criterion_ce(output, target_silh_64_l) # 0.7 | |
return loss_silh_64 | |
else: | |
loss_silh_l = criterion_ce(output, target_silh_l) # 0.7 | |
return loss_silh_l | |