File size: 597 Bytes
a80d6bb
 
 
 
c74a070
 
 
a80d6bb
 
 
c74a070
 
 
 
a80d6bb
c74a070
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch.nn.functional as F
import torch


def l1_loss(output, target_rgb, target_raw, weight=1.0):
    """Combine L1 reconstruction errors for the raw and RGB outputs.

    Args:
        output: Mapping that is indexed with the keys ``"reconstruct_raw"``
            and ``"reconstruct_rgb"`` (presumably the model's forward
            result; confirm against the caller).
        target_rgb: Reference compared against ``output["reconstruct_rgb"]``.
        target_raw: Reference compared against ``output["reconstruct_raw"]``.
        weight: Multiplier applied to the RGB term only. Defaults to 1.0.

    Returns:
        A 3-tuple ``(total, raw, rgb)`` where ``raw`` and ``rgb`` are the
        results of ``F.l1_loss`` called with its default arguments and
        ``total`` is ``raw + weight * rgb``.
    """
    # Raw term first, then RGB: same lookup/evaluation order as callers
    # have always observed.
    raw_term = F.l1_loss(output["reconstruct_raw"], target_raw)
    rgb_term = F.l1_loss(output["reconstruct_rgb"], target_rgb)

    # Individual terms are returned alongside the weighted sum.
    return (raw_term + weight * rgb_term, raw_term, rgb_term)


def l2_loss(output, target_rgb, target_raw, weight=1.0):
    """Combine mean-squared reconstruction errors for raw and RGB outputs.

    Args:
        output: Mapping that is indexed with the keys ``"reconstruct_raw"``
            and ``"reconstruct_rgb"`` (presumably the model's forward
            result; confirm against the caller).
        target_rgb: Reference compared against ``output["reconstruct_rgb"]``.
        target_raw: Reference compared against ``output["reconstruct_raw"]``.
        weight: Multiplier applied to the RGB term only. Defaults to 1.0.

    Returns:
        A 3-tuple ``(total, raw, rgb)`` where ``raw`` and ``rgb`` are the
        results of ``F.mse_loss`` called with its default arguments and
        ``total`` is ``raw + weight * rgb``.
    """
    # Raw term first, then RGB: same lookup/evaluation order as callers
    # have always observed.
    raw_term = F.mse_loss(output["reconstruct_raw"], target_raw)
    rgb_term = F.mse_loss(output["reconstruct_rgb"], target_rgb)

    # Individual terms are returned alongside the weighted sum.
    return (raw_term + weight * rgb_term, raw_term, rgb_term)