|
|
|
|
|
|
|
|
|
import pdb |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from nets.sampler import * |
|
from nets.repeatability_loss import * |
|
from nets.reliability_loss import * |
|
|
|
|
|
class MultiLoss(nn.Module):
    """Combine several loss modules into a single weighted sum.

    Positional arguments alternate between a weight and a loss module:
    ``weight_1, loss_1, weight_2, loss_2, ...``. Each weight is converted
    with ``float()`` and each loss must be an ``nn.Module``.

    Example:
        loss = MultiLoss(1, MyFirstLoss(), 0.5, MySecondLoss())

    Args:
        *args: Flat sequence of ``(weight, loss)`` pairs; its length must be
            even.
        dbg: Accepted for interface compatibility; it is not used by this
            class.
    """

    def __init__(self, *args, dbg=()):
        nn.Module.__init__(self)
        assert len(args) % 2 == 0, "args must be a list of (float, loss)"
        # Weights are plain Python floats; the losses live in a ModuleList so
        # that their parameters are registered with this module.
        self.weights = []
        self.losses = nn.ModuleList()
        # Even positions hold the weights, odd positions hold the losses.
        for weight, loss in zip(args[0::2], args[1::2]):
            weight = float(weight)
            # Wrap in a 1-tuple so that a tuple-valued `loss` cannot break the
            # %-formatting of the error message itself.
            assert isinstance(loss, nn.Module), "%s is not a loss!" % (loss,)
            self.weights.append(weight)
            self.losses.append(loss)

    def forward(self, select=None, **variables):
        """Evaluate the (selected) losses and return their weighted sum.

        Args:
            select: Optional collection of 1-based loss indices to evaluate.
                ``None`` evaluates every loss. Indices must lie in
                ``[1, len(self.losses)]``.
            **variables: Keyword arguments forwarded unchanged to every
                evaluated loss.

        Returns:
            A pair ``(cum_loss, details)``. ``cum_loss`` is the weighted sum
            of the evaluated losses (the integer ``0`` when none was
            evaluated). ``details`` maps ``"loss_<key>"`` to a float for each
            reported component and ``"loss"`` to ``float(cum_loss)``.

        A loss may return either a single value or a ``(value, dict)`` pair.
        In the first case the value is reported under the loss's ``name``
        attribute, or under its class name when it has no such attribute.
        In the second case the dict supplies the reported components. When
        two losses report the same key, the later one overwrites the earlier.
        """
        assert not select or all(1 <= n <= len(self.losses) for n in select)
        details = {}
        cum_loss = 0
        for num, (weight, loss_func) in enumerate(zip(self.weights, self.losses), 1):
            if select is not None and num not in select:
                continue
            result = loss_func(**variables)
            if isinstance(result, tuple):
                assert len(result) == 2 and isinstance(result[1], dict)
                value, parts = result
            else:
                # Fall back to the class name so that a loss without a `name`
                # attribute is still reported instead of raising.
                label = getattr(loss_func, "name", type(loss_func).__name__)
                value, parts = result, {label: result}
            cum_loss = cum_loss + weight * value
            for key, val in parts.items():
                details["loss_" + key] = float(val)
        details["loss"] = float(cum_loss)
        return cum_loss, details
|
|