'''
  Adapted from here: https://github.com/ZitongYu/PhysFormer/TorchLossComputer.py
  Modified based on the HR-CNN here: https://github.com/radimspetlik/hr-cnn
'''
import math
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from evaluation.post_process import calculate_metric_per_video

def normal_sampling(mean, label_k, std):
    # Value of a Gaussian pdf with the given mean/std, evaluated at label_k.
    # Used below to build soft (distribution) targets centred on a heart rate.
    return math.exp(-(label_k - mean) ** 2 / (2 * std ** 2)) / (math.sqrt(2 * math.pi) * std)
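
# For example, with std = 3.0 the peak value of this Gaussian is
#   normal_sampling(72, 72, 3.0) = 1 / (sqrt(2 * pi) * 3) ≈ 0.133,
# and it falls off symmetrically for label_k values away from the mean.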

def kl_loss(inputs, labels):
    # KL divergence between a predicted distribution (`inputs`, probabilities)
    # and a target distribution (`labels`), summed over all bins.
    criterion = nn.KLDivLoss(reduction='none')
    outputs = torch.log(inputs)
    loss = criterion(outputs, labels)
    # loss = loss.sum() / loss.shape[0]
    loss = loss.sum()
    return loss
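
# Illustrative sketch (not part of the original file) of how the two helpers
# above are combined by the losses further down. The names `bins`,
# `target_dist` and `pred_dist` are hypothetical:
#
#   bins = list(range(45, 150))                               # candidate HRs in bpm
#   target_dist = torch.tensor([max(normal_sampling(72, b, 3.0), 1e-15) for b in bins])
#   pred_dist = torch.softmax(torch.randn(len(bins)), dim=0)  # stand-in prediction
#   loss = kl_loss(pred_dist, target_dist)                    # summed KL over the bins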

class Neg_Pearson(nn.Module):   # Pearson range [-1, 1] so if < 0, abs|loss| ; if > 0, 1 - loss
    def __init__(self):
        super(Neg_Pearson, self).__init__()

    def forward(self, preds, labels):   # all variable operation
        loss = 0
        for i in range(preds.shape[0]):
            sum_x = torch.sum(preds[i])                 # x
            sum_y = torch.sum(labels[i])                # y
            sum_xy = torch.sum(preds[i] * labels[i])    # xy
            sum_x2 = torch.sum(torch.pow(preds[i], 2))  # x^2
            sum_y2 = torch.sum(torch.pow(labels[i], 2)) # y^2
            N = preds.shape[1]
            pearson = (N * sum_xy - sum_x * sum_y) / (
                torch.sqrt((N * sum_x2 - torch.pow(sum_x, 2)) * (N * sum_y2 - torch.pow(sum_y, 2))))
            loss += 1 - pearson
        loss = loss / preds.shape[0]
        return loss
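
# Illustrative usage (hypothetical tensors, not part of the original file):
#
#   criterion = Neg_Pearson()
#   preds = torch.randn(4, 160)      # (batch, time) rPPG predictions
#   labels = torch.randn(4, 160)     # (batch, time) ground-truth PPG
#   loss = criterion(preds, labels)  # mean of (1 - Pearson r) over the batch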

class RhythmFormer_Loss(nn.Module):
    def __init__(self):
        super(RhythmFormer_Loss, self).__init__()
        self.criterion_Pearson = Neg_Pearson()

    def forward(self, pred_ppg, labels, epoch, FS, diff_flag):
        loss_time = self.criterion_Pearson(pred_ppg.view(1, -1), labels.view(1, -1))
        loss_CE, loss_distribution_kl = TorchLossComputer.Frequency_loss(
            pred_ppg.squeeze(-1), labels.squeeze(-1), diff_flag=diff_flag, Fs=FS, std=3.0)
        loss_hr = TorchLossComputer.HR_loss(
            pred_ppg.squeeze(-1), labels.squeeze(-1), diff_flag=diff_flag, Fs=FS, std=3.0)
        if torch.isnan(loss_time):
            loss_time = 0
        loss = 0.2 * loss_time + 1.0 * loss_CE + 1.0 * loss_hr
        return loss
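
# Illustrative usage (hypothetical shapes; requires a CUDA device and the
# evaluation.post_process module, since the frequency losses below move data
# to the GPU and call calculate_metric_per_video):
#
#   loss_fn = RhythmFormer_Loss()
#   pred_ppg = model(frames).squeeze()   # 1-D rPPG waveform for one clip
#   loss = loss_fn(pred_ppg, labels, epoch=0, FS=30, diff_flag=False)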

class TorchLossComputer(object):
    @staticmethod
    def compute_complex_absolute_given_k(output, k, N):
        # Squared magnitude of the Hann-windowed DFT of `output`, evaluated at
        # the (fractional) frequency indices `k`.
        two_pi_n_over_N = Variable(2 * math.pi * torch.arange(0, N, dtype=torch.float), requires_grad=True) / N
        hanning = Variable(torch.from_numpy(np.hanning(N)).type(torch.FloatTensor), requires_grad=True).view(1, -1)

        k = k.type(torch.FloatTensor).cuda()
        two_pi_n_over_N = two_pi_n_over_N.cuda()
        hanning = hanning.cuda()

        output = output.view(1, -1) * hanning
        output = output.view(1, 1, -1).type(torch.cuda.FloatTensor)
        k = k.view(1, -1, 1)
        two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1)
        complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \
                           + torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2
        return complex_absolute

    @staticmethod
    def complex_absolute(output, Fs, bpm_range=None):
        output = output.view(1, -1)
        N = output.size()[1]

        unit_per_hz = Fs / N
        feasible_bpm = bpm_range / 60.0
        k = feasible_bpm / unit_per_hz

        # only calculate feasible PSD range [0.7, 4] Hz
        complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N)

        return (1.0 / complex_absolute.sum()) * complex_absolute  # Analogous Softmax operator
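
    # Worked example of the bpm-to-bin mapping above: with Fs = 30 Hz and
    # N = 900 samples, unit_per_hz = 30 / 900 Hz per DFT bin, so a candidate
    # rate of 72 bpm (1.2 Hz) maps to fractional bin k = 1.2 / (30 / 900) = 36.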

    @staticmethod
    def cross_entropy_power_spectrum_loss(inputs, target, Fs):
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        # bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        # pdb.set_trace()
        # return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1), (target.item() - whole_max_idx.item()) ** 2
        return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)

    @staticmethod
    def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma):
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        # bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        # pdb.set_trace()
        # NOTE: FocalLoss is not defined in this file; it is assumed to be
        # imported/defined elsewhere in the surrounding codebase.
        criterion = FocalLoss(gamma=gamma)
        # return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1), (target.item() - whole_max_idx.item()) ** 2
        return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)

    @staticmethod
    def cross_entropy_power_spectrum_forward_pred(inputs, Fs):
        inputs = inputs.view(1, -1)
        bpm_range = torch.arange(40, 190, dtype=torch.float).cuda()
        # bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        # bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)
        return whole_max_idx

    @staticmethod
    def Frequency_loss(inputs, target, diff_flag, Fs, std):
        # Ground-truth HR (bpm) from the label signal, via FFT-based post-processing.
        hr_gt, pred_hr_peak, SNR, macc = calculate_metric_per_video(
            inputs.detach().cpu(), target.detach().cpu(), diff_flag=diff_flag, fs=Fs, hr_method='FFT')
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(45, 150, dtype=torch.float).to(torch.device('cuda'))
        ca = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
        sa = ca / torch.sum(ca)

        # Soft target: Gaussian (std in bpm) centred on the ground-truth HR over 45-149 bpm.
        target_distribution = [normal_sampling(int(hr_gt), i, std) for i in range(45, 150)]
        target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
        target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))

        hr_gt = torch.tensor(hr_gt - 45).view(1).type(torch.long).to(torch.device('cuda'))
        # Cross-entropy over the bpm classes, plus KL between the normalized
        # spectrum and the Gaussian target distribution.
        return F.cross_entropy(ca, hr_gt), kl_loss(sa, target_distribution)
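
    # For example, with bpm_range = 45..149, a ground-truth HR of 72 bpm becomes
    # class index 72 - 45 = 27 for the cross-entropy term above.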

    @staticmethod
    def HR_loss(inputs, target, diff_flag, Fs, std):
        # Peak-based post-processing; the returned values are treated as power
        # spectra (arrays), and each is turned into a Gaussian distribution
        # centred on its spectral peak before comparing the two with KL.
        psd_gt, psd_pred, SNR, macc = calculate_metric_per_video(
            inputs.detach().cpu(), target.detach().cpu(), diff_flag=diff_flag, fs=Fs, hr_method='Peak')
        pred_distribution = [normal_sampling(np.argmax(psd_pred), i, std) for i in range(psd_pred.size)]
        pred_distribution = [i if i > 1e-15 else 1e-15 for i in pred_distribution]
        pred_distribution = torch.Tensor(pred_distribution).to(torch.device('cuda'))

        target_distribution = [normal_sampling(np.argmax(psd_gt), i, std) for i in range(psd_gt.size)]
        target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
        target_distribution = torch.Tensor(target_distribution).to(torch.device('cuda'))
        return kl_loss(pred_distribution, target_distribution)
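

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original file. Assumptions:
    # a CUDA device is available (the losses call .cuda()/.to('cuda') internally),
    # and evaluation.post_process.calculate_metric_per_video behaves as this
    # module expects (including returning power spectra for hr_method='Peak').
    torch.manual_seed(0)
    FS = 30                                      # assumed sampling rate in Hz
    T = 300                                      # 10 s clip at 30 fps
    t = torch.arange(T, dtype=torch.float) / FS
    labels = torch.sin(2 * math.pi * 1.2 * t)    # synthetic 72 bpm "PPG"
    pred_ppg = labels + 0.1 * torch.randn(T)     # noisy stand-in "prediction"
    labels = labels.cuda()
    pred_ppg = pred_ppg.cuda()

    loss_fn = RhythmFormer_Loss()
    loss = loss_fn(pred_ppg, labels, epoch=0, FS=FS, diff_flag=False)
    print('total loss:', float(loss))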