from typing import List

import torch
import torch.nn.functional as F


def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
    # Per-pixel squared error, weighted differently inside and outside the mask
    # (mask == 1 marks missing pixels, mask == 0 marks known pixels).
    per_pixel_l2 = F.mse_loss(pred, target, reduction='none')
    pixel_weights = mask * weight_missing + (1 - mask) * weight_known
    return (pixel_weights * per_pixel_l2).mean()


def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
    # Same weighting scheme as masked_l2_loss, but with absolute error.
    per_pixel_l1 = F.l1_loss(pred, target, reduction='none')
    pixel_weights = mask * weight_missing + (1 - mask) * weight_known
    return (pixel_weights * per_pixel_l1).mean()
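
A minimal usage sketch for the masked pixel losses (not part of the original listing), assuming NCHW image tensors and a binary mask with 1 over the missing region; the weight values are illustrative only:

    pred = torch.rand(2, 3, 64, 64)
    target = torch.rand(2, 3, 64, 64)
    mask = (torch.rand(2, 1, 64, 64) > 0.5).float()  # broadcasts over channels

    l1 = masked_l1_loss(pred, target, mask, weight_known=10.0, weight_missing=0.0)
    l2 = masked_l2_loss(pred, target, mask, weight_known=1.0, weight_missing=1.0)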

def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
    if mask is None:
        # Plain feature matching: average MSE between corresponding feature maps.
        res = torch.stack([F.mse_loss(fake_feat, target_feat)
                           for fake_feat, target_feat in zip(fake_features, target_features)]).mean()
    else:
        # Masked variant: downscale the mask to each feature map's resolution and
        # weight errors by (1 - mask), i.e. only penalize the known region.
        res = 0
        norm = 0
        for fake_feat, target_feat in zip(fake_features, target_features):
            cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False)
            error_weights = 1 - cur_mask
            cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean()
            res = res + cur_val
            norm += 1
        res = res / norm
    return res
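
A short sketch of how feature_matching_loss might be called, assuming the two lists hold matching intermediate feature maps (e.g. from a discriminator run on generated and real images); the shapes and layer count below are made up for illustration:

    fake_features = [torch.rand(2, 64, 32, 32), torch.rand(2, 128, 16, 16)]
    target_features = [torch.rand(2, 64, 32, 32), torch.rand(2, 128, 16, 16)]
    mask = (torch.rand(2, 1, 64, 64) > 0.5).float()  # resized per feature level inside the loss

    fm_full = feature_matching_loss(fake_features, target_features)         # unmasked variant
    fm_known = feature_matching_loss(fake_features, target_features, mask)  # errors weighted by (1 - mask)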