Nadine Rueegg
initial commit with code and data
753fd9a
raw
history blame
5.99 kB
import torch.nn as nn
import torch
from torch.nn.functional import mse_loss
# for NEW: losses when calculated on keypoint locations
# see https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/subpix/dsnt.html
# from kornia.geometry import dsnt # old kornia version
from kornia.geometry.subpix import dsnt # kornia 0.4.0
def joints_mse_loss_orig(output, target, target_weight=None):
    """Return the per-joint heatmap MSE, averaged over joints.

    ``output`` and ``target`` are flattened to ``(batch, joints, -1)`` and
    compared one joint at a time. Each joint contributes
    ``0.5 * mean((pred - gt) ** 2)``; the sum is divided by the number of
    joints (``output.size(1)``).

    Args:
        output: predicted heatmaps; dim 0 is the batch, dim 1 the joints.
        target: ground-truth heatmaps, viewable to the same
            ``(batch, joints, -1)`` layout as ``output``.
        target_weight: optional per-joint weights. When given,
            ``target_weight[:, j]`` is multiplied into both the prediction
            and the target of joint ``j`` before the MSE is taken.
            NOTE(review): the slice must broadcast against a
            ``(batch, 1, H*W)`` tensor -- confirm the shape callers pass.

    Returns:
        The averaged loss (a 0-dim tensor for a non-empty joint axis).
    """
    n_batch, n_joints = output.size(0), output.size(1)
    # One (batch, 1, H*W) chunk per joint.
    pred_per_joint = output.view((n_batch, n_joints, -1)).split(1, 1)
    gt_per_joint = target.view((n_batch, n_joints, -1)).split(1, 1)

    total = 0
    for joint, (pred, gt) in enumerate(zip(pred_per_joint, gt_per_joint)):
        if target_weight is not None:
            weight = target_weight[:, joint]
            pred, gt = pred.mul(weight), gt.mul(weight)
        total = total + 0.5 * mse_loss(pred, gt, reduction='mean')
    return total / n_joints
class JointsMSELoss(nn.Module):
    """Module wrapper around ``joints_mse_loss_orig``.

    NOTE(review): ``__init__`` unconditionally raises ``NotImplementedError``,
    so this class cannot currently be instantiated. Presumably it was
    disabled on purpose -- confirm before removing the ``raise``.
    """

    def __init__(self, use_target_weight=True):
        super().__init__()
        self.use_target_weight = use_target_weight
        # Construction always fails here; see the class docstring.
        raise NotImplementedError

    def forward(self, output, target, target_weight):
        """Compute the heatmap MSE, ignoring the weights if they are disabled."""
        weight = target_weight if self.use_target_weight else None
        return joints_mse_loss_orig(output, target, weight)
# ----- NEW: losses when calculated on keypoint locations instead of keypoint heatmaps -----
def joints_mse_loss_onKPloc(output, target, meta, target_weight=None):
    """Return a combined heatmap loss and keypoint-location loss.

    The predicted logits are turned into per-joint spatial probability maps
    with a spatial softmax. Two terms are then computed:

    * heatmap term: ``joints_mse_loss_orig`` between the softmaxed
      prediction and ``target``;
    * distance term: the weighted mean Euclidean distance between the
      expected keypoint location of each predicted map and the ground-truth
      location in ``meta['tpts']``, taken over keypoints with weight > 0.

    The result is ``dist_loss * 1e-4 + heatmap_loss * 100``. (Earlier
    experiments used only one of the two terms; both are active here.)

    Args:
        output: predicted heatmap logits, one map per joint.
        target: ground-truth heatmaps, used as-is (expected to be
            normalized already).
        meta: dict containing ``'tpts'``, a ``(bs, n_joints, 3)`` tensor
            whose last axis is read as ``(coord, coord, weight)``.
        target_weight: optional weights forwarded to the heatmap term.

    Returns:
        The combined loss.
    """
    # Targets are no longer re-normalized here; they arrive normalized.
    target_norm = target
    # Logits -> probability distribution over each joint's heatmap.
    output_norm = dsnt.spatial_softmax2d(output, temperature=torch.tensor(1))

    # Heatmap term.
    heatmap_loss = joints_mse_loss_orig(output_norm, target_norm, target_weight)

    # Expected keypoint location per joint, in pixels: (bs, n_joints, 2).
    # NOTE(review): the "+ 1" presumably aligns with the coordinate
    # convention of meta['tpts'] -- confirm against the dataset code.
    output_kp = dsnt.spatial_expectation2d(output_norm, normalized_coordinates=False) + 1
    target_kp = meta['tpts'].to(output_kp.device)

    pred_xy = output_kp.reshape((-1, 2))
    gt_xy = target_kp[:, :, :2].reshape((-1, 2))
    weights = target_kp[:, :, 2].reshape((-1))

    # Only keypoints with a positive weight take part in the distance term.
    visible = weights > 0
    visible_weights = weights[visible]
    distances = ((pred_xy - gt_xy)[visible] ** 2).sum(axis=1).sqrt()
    # The floor on the denominator keeps the term finite (and zero) when no
    # keypoint has a positive weight.
    dist_loss = (distances * visible_weights).sum() / max(visible_weights.sum(), 1e-5)

    return dist_loss * 1e-4 + heatmap_loss * 100
class JointsMSELoss_onKPloc(nn.Module):
    """Module wrapper around ``joints_mse_loss_onKPloc``.

    Args:
        use_target_weight: when False, ``target_weight`` passed to
            ``forward`` is ignored and the heatmap term is unweighted.
    """

    def __init__(self, use_target_weight=True):
        super().__init__()
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight, meta=None):
        """Compute the combined heatmap / keypoint-location loss.

        Args:
            output: predicted heatmap logits.
            target: ground-truth heatmaps.
            target_weight: per-joint weights (ignored when
                ``use_target_weight`` is False).
            meta: batch metadata dict; must contain ``'tpts'``. It is a
                trailing keyword so the original three-argument call shape
                still parses, but it is required in practice.

        Raises:
            ValueError: if ``meta`` is not supplied.
        """
        # Previously ``meta`` was referenced without being defined, so every
        # call failed with a NameError; it is now an explicit argument.
        if meta is None:
            raise ValueError(
                "JointsMSELoss_onKPloc.forward requires `meta` (with key 'tpts')"
            )
        if not self.use_target_weight:
            target_weight = None
        return joints_mse_loss_onKPloc(output, target, meta, target_weight)
# ----- NEW: segmentation loss -----
import torch.nn.functional as F
'''def resize2d(img, size):
return (F.adaptive_avg_pool2d(Variable(img,volatile=True), size)).data
# F.adaptive_avg_pool2d(meta['silh'], (64,64))).data'''
def segmentation_loss(output, meta):
    """Return the cross-entropy between segmentation logits and the target silhouette.

    The target silhouette ``meta['silh']`` is moved to the device of
    ``output``. If the output height is 64, or the silhouette's spatial size
    differs from the output's, the silhouette is average-pooled to the
    output's spatial size and binarized at 0.5. Otherwise it is cast to
    integer labels directly.

    Args:
        output: class logits, e.g. ``(bs, 2, 64, 64)``.
        meta: dict containing ``'silh'``, the target silhouette(s) with the
            same leading dimension as ``output`` and no channel axis.

    Returns:
        Scalar ``nn.CrossEntropyLoss`` value.
    """
    # Keep target and prediction on the same device.
    target_silh = meta['silh'].to(output.device)
    out_size = tuple(output.shape[-2:])

    # The height-64 case always went through pooling + thresholding; keep
    # that, and extend it to any resolution mismatch instead of failing
    # inside the loss.
    needs_resize = output.shape[2] == 64 or tuple(target_silh.shape[-2:]) != out_size
    if needs_resize:
        if not target_silh.is_floating_point():
            # Average pooling needs a floating-point input.
            target_silh = target_silh.float()
        pooled = F.adaptive_avg_pool2d(target_silh, out_size)
        # A pixel is foreground when more than half of its area is.
        target_labels = (pooled > 0.5).to(torch.long)
    else:
        target_labels = target_silh.to(torch.long)

    return nn.CrossEntropyLoss()(output, target_labels)