|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class JointsMSELoss(nn.Module):
    """Mean-squared-error loss between predicted and ground-truth heatmaps.

    The returned value is ``0.5 * mean((pred - gt) ** 2)`` taken over every
    batch element, joint and (flattened) spatial location. Optionally, the
    prediction and target of each (sample, joint) pair are first multiplied by
    a per-joint weight.

    Args:
        use_target_weight: If true, ``forward`` scales both the prediction and
            the target of joint ``j`` in sample ``b`` by ``target_weight[b, j]``
            before computing the error.
    """

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        """Return the scalar heatmap loss.

        Args:
            output: Predicted heatmaps. Dim 0 is the batch, dim 1 the joints;
                all remaining dims are flattened into one spatial axis.
            target: Ground-truth heatmaps with the same shape as ``output``.
            target_weight: Per-joint weights holding ``batch_size * num_joints``
                elements, e.g. shape ``(B, J, 1)`` or ``(B, J)``. Only read
                when ``use_target_weight`` is true.

        Returns:
            0-dim tensor with the loss.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)

        # (B, J, HW). Work on the whole tensor instead of splitting per joint
        # and calling .squeeze(), which also dropped the batch dim when B == 1.
        heatmaps_pred = output.reshape((batch_size, num_joints, -1))
        heatmaps_gt = target.reshape((batch_size, num_joints, -1))

        if self.use_target_weight:
            # One weight per (sample, joint), broadcast over the spatial axis.
            weight = target_weight.reshape((batch_size, num_joints, 1))
            heatmaps_pred = heatmaps_pred.mul(weight)
            heatmaps_gt = heatmaps_gt.mul(weight)

        # Every joint contributes the same number of elements (B * HW), so the
        # global mean equals the average over joints of the per-joint means.
        return 0.5 * self.criterion(heatmaps_pred, heatmaps_gt)
|
|
|
|
|
class JointsOHKMMSELoss(nn.Module):
    """Heatmap MSE loss with online hard keypoint mining (OHKM).

    For every sample, the per-joint loss ``0.5 * mean((pred - gt) ** 2)``
    (mean over the flattened spatial axis) is computed. Only the ``topk``
    largest per-joint losses are kept and averaged, and that value is then
    averaged over the batch.

    Args:
        use_target_weight: If true, ``forward`` scales both the prediction and
            the target of joint ``j`` in sample ``b`` by ``target_weight[b, j]``
            before computing the error.
        topk: Number of hardest joints kept per sample. If a sample has fewer
            joints than this, all of its joints are kept.
    """

    def __init__(self, use_target_weight, topk=8):
        super(JointsOHKMMSELoss, self).__init__()
        # 'none' keeps the element-wise errors so they can be reduced per joint.
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk

    def ohkm(self, loss):
        """Average each sample's hardest joint losses, then average over the batch.

        Args:
            loss: 2-D tensor of per-joint losses; dim 0 is the batch and dim 1
                the joints.

        Returns:
            0-dim tensor.
        """
        # torch.topk raises if k exceeds the dim size, so clamp to the number
        # of joints (keeping all of them in that case).
        k = min(self.topk, loss.size(1))
        hardest, _ = torch.topk(loss, k=k, dim=1, sorted=False)
        # Mean over the k kept values (== their sum / k), then over the batch.
        return hardest.mean(dim=1).mean()

    def forward(self, output, target, target_weight):
        """Return the scalar OHKM heatmap loss.

        Args:
            output: Predicted heatmaps. Dim 0 is the batch, dim 1 the joints;
                all remaining dims are flattened into one spatial axis.
            target: Ground-truth heatmaps with the same shape as ``output``.
            target_weight: Per-joint weights holding ``batch_size * num_joints``
                elements, e.g. shape ``(B, J, 1)`` or ``(B, J)``. Only read
                when ``use_target_weight`` is true.

        Returns:
            0-dim tensor with the loss.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)

        # (B, J, HW). Avoid per-joint .squeeze(), which also dropped the batch
        # dim when B == 1 and broke the later per-joint reduction.
        heatmaps_pred = output.reshape((batch_size, num_joints, -1))
        heatmaps_gt = target.reshape((batch_size, num_joints, -1))

        if self.use_target_weight:
            # One weight per (sample, joint), broadcast over the spatial axis.
            weight = target_weight.reshape((batch_size, num_joints, 1))
            heatmaps_pred = heatmaps_pred.mul(weight)
            heatmaps_gt = heatmaps_gt.mul(weight)

        # Element-wise errors averaged over the spatial axis -> (B, J).
        loss = 0.5 * self.criterion(heatmaps_pred, heatmaps_gt).mean(dim=2)
        return self.ohkm(loss)
|
|