# FcF-Inpainting/training/losses/high_receptive_pl.py
# Last commit: ":zap: Build App" (9eae6e7) by praeclarumjj3. File size: 1.64 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
from training.losses.ade20k import ModelBuilder
# Per-channel ImageNet normalisation statistics, stored as float32 tensors of
# shape (1, 3, 1, 1) so they broadcast against a tensor whose dim 1 has size 3.
IMAGENET_MEAN = torch.tensor(
    [0.485, 0.456, 0.406], dtype=torch.float32
).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor(
    [0.229, 0.224, 0.225], dtype=torch.float32
).view(1, 3, 1, 1)
class HRFPL(nn.Module):
    """High receptive field perceptual loss.

    Runs ``pred`` and ``target`` through a frozen encoder obtained from
    ``ModelBuilder.get_encoder`` and returns the weighted sum of the mean
    squared errors between corresponding feature maps.
    """

    def __init__(self, weight=1,
                 weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
        """Build the frozen feature extractor.

        Args:
            weight: Scalar multiplier applied to the summed loss.
            weights_path: Forwarded to ``ModelBuilder.get_encoder``.
            arch_encoder: Encoder architecture name forwarded to the builder.
            segmentation: Forwarded to ``ModelBuilder.get_encoder``.
        """
        super().__init__()
        self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
                                             arch_encoder=arch_encoder,
                                             arch_decoder='ppm_deepsup',
                                             fc_dim=2048,
                                             segmentation=segmentation)
        # The encoder is a fixed feature extractor: evaluation mode, and its
        # weights never receive gradients.
        self.impl.eval()
        for w in self.impl.parameters():
            w.requires_grad_(False)
        self.weight = weight

    def train(self, mode=True):
        """Set the training flag on this module while keeping the encoder in eval mode.

        ``nn.Module.train`` recurses into child modules, so without this
        override a ``.train()`` call on this loss (or on any parent module that
        contains it) would silently undo the ``eval()`` done in ``__init__``.

        Args:
            mode: Training flag to set on this module itself.

        Returns:
            This module.
        """
        super().train(mode)
        self.impl.eval()
        return self

    def forward(self, pred, target):
        """Compute the perceptual loss between ``pred`` and ``target``.

        Args:
            pred: Predicted image batch. Presumably valued in [-1, 1] with
                three channels on dim 1 -- TODO confirm against callers.
            target: Reference image batch, presumably in the same layout and
                range as ``pred`` -- TODO confirm against callers.

        Returns:
            Tensor holding the sum over feature maps of
            ``F.mse_loss(pred_feat, target_feat)``, multiplied by
            ``self.weight``.
        """
        # Rescale with (x + 1) / 2; only the prediction is clamped to [0, 1].
        target = (target + 1) / 2
        pred = (pred + 1) / 2
        pred = torch.clamp(pred, 0., 1.)

        # Normalise with the ImageNet statistics, matched to each input's
        # device and dtype.
        pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
        target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)

        # Module.to() moves parameters in place, so the submodule does not
        # need to be re-assigned on every call.
        self.impl.to(pred.device)

        pred_feats = self.impl(pred, return_feature_maps=True)
        target_feats = self.impl(target, return_feature_maps=True)

        per_level = [F.mse_loss(cur_pred, cur_target)
                     for cur_pred, cur_target in zip(pred_feats, target_feats)]
        return torch.stack(per_level).sum() * self.weight