import torch
import torch.nn as nn
import torch.nn.functional as F

from training.losses.ade20k import ModelBuilder

# Per-channel normalization constants, shaped (1, 3, 1, 1) so they broadcast
# over a tensor whose dimension 1 holds three channels.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)[None, :, None, None]
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)[None, :, None, None]


class HRFPL(nn.Module):
    """Feature-matching loss computed on a frozen encoder from ``ModelBuilder``.

    Both inputs are rescaled, normalized with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD`` and fed through the encoder; the loss is the sum over the
    returned feature maps of the MSE between prediction and target features,
    multiplied by ``weight``.
    """

    def __init__(self, weight=1, weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
        """Build the encoder and freeze it.

        Args:
            weight: Scalar multiplier applied to the summed loss.
            weights_path: Forwarded unchanged to ``ModelBuilder.get_encoder``.
            arch_encoder: Encoder architecture name, forwarded to
                ``ModelBuilder.get_encoder``.
            segmentation: Forwarded unchanged to ``ModelBuilder.get_encoder``.
        """
        super().__init__()
        self.impl = ModelBuilder.get_encoder(
            weights_path=weights_path,
            arch_encoder=arch_encoder,
            arch_decoder='ppm_deepsup',
            fc_dim=2048,
            segmentation=segmentation,
        )
        # The encoder is a fixed feature extractor: inference mode, no
        # trainable parameters.
        self.impl.eval()
        for param in self.impl.parameters():
            param.requires_grad_(False)

        self.weight = weight

    def train(self, mode: bool = True):
        """Set the module's training flag while keeping the encoder in eval mode.

        ``nn.Module.train`` recurses into child modules, so without this
        override a ``.train()`` call on this module (or on any parent holding
        it) would undo the ``eval()`` done in ``__init__`` and let
        mode-dependent layers inside the encoder (e.g. batch norm or dropout,
        if present) change behavior between calls.
        """
        super().train(mode)
        self.impl.eval()
        return self

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of per-feature-map MSE losses.

        Args:
            pred: Predicted images. Presumably in ``[-1, 1]`` with channels on
                dimension 1 (verify against the caller); values are mapped via
                ``(x + 1) / 2`` and then clamped to ``[0, 1]``.
            target: Reference images, mapped via ``(x + 1) / 2`` like ``pred``
                but NOT clamped.

        Returns:
            A scalar tensor: ``weight * sum_i mse(pred_feats[i], target_feats[i])``.
        """
        target = (target + 1) / 2
        # Only the prediction is clamped; the target is used as given.
        pred = torch.clamp((pred + 1) / 2, 0., 1.)

        # ``.to(tensor)`` matches both the device and the dtype of the input.
        pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
        target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)

        # ``nn.Module.to`` moves the module in place, so there is no need to
        # re-assign (and thereby re-register) the submodule on every call.
        self.impl.to(pred.device)

        pred_feats = self.impl(pred, return_feature_maps=True)
        target_feats = self.impl(target, return_feature_maps=True)

        per_level_losses = [
            F.mse_loss(cur_pred, cur_target)
            for cur_pred, cur_target in zip(pred_feats, target_feats)
        ]
        return torch.stack(per_level_losses).sum() * self.weight