LHMPP / engine /BiRefNet /loss.py
Lingteng Qiu (邱陵腾)
rm assets & wheels
434b0b0
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
from config import Config
class ContourLoss(torch.nn.Module):
    """Active-contour style loss: a length (smoothness) term plus two region terms.

    ``pred`` and ``target`` are (B, C, H, W) tensors where ``target`` is 1 inside
    the contour and 0 outside it. Equation numbers refer to the paper the
    implementation follows; means are used in place of the paper's sums.
    """

    def __init__(self):
        super(ContourLoss, self).__init__()

    def forward(self, pred, target, weight=10):
        """Return ``weight * length + (region_in + region_out)`` as a scalar tensor.

        Args:
            pred: predicted map of shape (B, C, H, W); presumably probabilities
                in [0, 1] -- confirm with the caller.
            target: ground-truth map, 1 inside the contour and 0 outside.
            weight: scalar multiplier applied to the length term.
        """
        # Forward differences along the H axis and along the W axis.
        diff_h = pred[:, :, 1:, :] - pred[:, :, :-1, :]  # (B, C, H-1, W)
        diff_w = pred[:, :, :, 1:] - pred[:, :, :, :-1]  # (B, C, H, W-1)

        # Crop both difference maps onto a shared (H-2, W-2) grid, then square.
        sq_h = diff_h[:, :, 1:, :-2] ** 2
        sq_w = diff_w[:, :, :-2, 1:] ** 2
        grad_sq = torch.abs(sq_h + sq_w)

        # eps keeps sqrt away from 0, where its derivative blows up.
        eps = 1e-8
        length = torch.sqrt(grad_sq + eps).mean()  # eq.(11), mean instead of sum

        inside_level = torch.ones_like(pred)
        outside_level = torch.zeros_like(pred)
        # eq.(12), mean instead of sum.
        region_in = torch.mean(pred * (target - inside_level) ** 2)
        region_out = torch.mean((1 - pred) * (target - outside_level) ** 2)

        return weight * length + (region_in + region_out)
class IoULoss(torch.nn.Module):
    """Soft (differentiable) IoU loss, summed over the batch.

    For each sample ``i`` the soft IoU is computed over all channels and pixels
    as ``sum(t*p) / (sum(t) + sum(p) - sum(t*p))``. The returned loss is
    ``sum_i (1 - IoU_i)``, a sum rather than a mean over the batch.
    """

    def __init__(self):
        super(IoULoss, self).__init__()

    def forward(self, pred, target):
        """Return the batch-summed ``1 - IoU`` as a scalar tensor.

        Args:
            pred: predicted map indexed as (B, ...); presumably sigmoid
                probabilities of shape (B, C, H, W) -- confirm with the caller.
            target: ground-truth map with the same shape as ``pred``.
        """
        b = pred.shape[0]
        # flatten(1) (not reshape(b, -1)) so an empty batch (b == 0) still works.
        pred_flat = pred.flatten(1)
        target_flat = target.flatten(1)

        inter = (target_flat * pred_flat).sum(dim=1)  # (b,)
        union = target_flat.sum(dim=1) + pred_flat.sum(dim=1) - inter

        # If prediction and target are both empty, union == 0 and the plain
        # division would be 0/0 = NaN. That NaN would poison the loss and its
        # gradients, so treat this case as a perfect match (IoU = 1). The
        # denominator is made safe *before* dividing, so the unselected branch
        # of torch.where cannot leak inf/NaN into the backward pass.
        valid = union != 0
        ones = torch.ones_like(union)
        safe_union = torch.where(valid, union, ones)
        iou = torch.where(valid, inter / safe_union, ones)

        assert iou.shape[0] == b
        return (1 - iou).sum()
class StructureLoss(torch.nn.Module):
    """Boundary-weighted BCE plus boundary-weighted soft IoU.

    ``pred`` is expected to be raw logits: the block applies
    ``binary_cross_entropy_with_logits`` and ``torch.sigmoid`` itself.
    Pixels whose 31x31 neighbourhood mean differs from their own target value
    (i.e. near boundaries) receive up to 6x weight.
    """

    def __init__(self):
        super(StructureLoss, self).__init__()

    def forward(self, pred, target):
        """Return the mean over (B, C) of weighted BCE + weighted IoU loss."""
        local_mean = F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)
        weights = 1 + 5 * torch.abs(local_mean - target)
        spatial = (2, 3)

        pixel_bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
        weighted_bce = (weights * pixel_bce).sum(dim=spatial) / weights.sum(dim=spatial)

        prob = torch.sigmoid(pred)
        inter = (prob * target * weights).sum(dim=spatial)
        union = ((prob + target) * weights).sum(dim=spatial)
        # +1 smoothing keeps the ratio defined for empty masks.
        weighted_iou = 1 - (inter + 1) / (union - inter + 1)

        return (weighted_bce + weighted_iou).mean()
class PatchIoULoss(torch.nn.Module):
    """IoU loss summed over non-overlapping spatial tiles.

    The (H, W) plane is tiled with ``win_y x win_x`` windows starting at the
    top-left corner; tiles on the bottom/right edges may be smaller. Each tile
    is scored with IoULoss and the per-tile losses are summed.
    """

    def __init__(self, win_y=64, win_x=64):
        """
        Args:
            win_y: tile height in pixels (default 64, the previous hard-coded value).
            win_x: tile width in pixels (default 64, the previous hard-coded value).
        """
        super(PatchIoULoss, self).__init__()
        self.win_y = win_y
        self.win_x = win_x
        self.iou_loss = IoULoss()

    def forward(self, pred, target):
        """Return the sum of IoULoss over every spatial tile.

        Args:
            pred: predicted map of shape (B, C, H, W).
            target: ground-truth map with the same shape as ``pred``.
        """
        win_y, win_x = self.win_y, self.win_x
        # Tile the spatial dims (2 = H, 3 = W). Ranging over dims 0/1 would walk
        # batch/channel sizes instead and, for typical inputs, visit only the
        # first tile.
        height, width = target.shape[2], target.shape[3]
        iou_loss = 0.0
        for anchor_y in range(0, height, win_y):
            for anchor_x in range(0, width, win_x):
                patch_pred = pred[
                    :, :, anchor_y : anchor_y + win_y, anchor_x : anchor_x + win_x
                ]
                patch_target = target[
                    :, :, anchor_y : anchor_y + win_y, anchor_x : anchor_x + win_x
                ]
                iou_loss += self.iou_loss(patch_pred, patch_target)
        return iou_loss
class ThrReg_loss(torch.nn.Module):
    """Threshold regulariser that pushes predictions towards 0 or 1.

    Per element it computes ``1 - (p**2 + (p - 1)**2)``, which equals
    ``2 * p * (1 - p)``: zero at p = 0 or p = 1 and largest (0.5) at p = 0.5.
    """

    def __init__(self):
        super(ThrReg_loss, self).__init__()

    def forward(self, pred, gt=None):
        """Return the mean regulariser value; ``gt`` is accepted but ignored."""
        sq_dist_to_zero = pred ** 2
        sq_dist_to_one = (pred - 1) ** 2
        return torch.mean(1 - (sq_dist_to_zero + sq_dist_to_one))
class ClsLoss(nn.Module):
    """
    Auxiliary classification loss for each refined class output.

    Each criterion in ``criterions_last`` is applied to every non-None
    prediction level. Its result is scaled by the matching weight from
    ``Config().lambdas_cls``, and all terms are summed.
    """

    def __init__(self):
        super(ClsLoss, self).__init__()
        self.config = Config()
        # Presumably a mapping from criterion name to weight -- confirm in config.
        self.lambdas_cls = self.config.lambdas_cls
        self.criterions_last = {"ce": nn.CrossEntropyLoss()}

    def forward(self, preds, gt):
        """Return the weighted criterion sum over the non-None entries of ``preds``.

        Returns the float ``0.0`` when every entry of ``preds`` is None.
        """
        total = 0.0
        present_levels = (level for level in preds if level is not None)
        for level_logits in present_levels:
            for name, criterion in self.criterions_last.items():
                total += criterion(level_logits, gt) * self.lambdas_cls[name]
        return total
class PixLoss(nn.Module):
    """
    Pixel loss for each refined map output.

    Every criterion whose weight in ``Config().lambdas_pix_last`` is present
    and truthy is applied to every prediction level. The weighted results are
    summed. Levels whose shape differs from ``gt`` are bilinearly resized to
    ``gt``'s spatial size first.
    """

    # Criteria that apply their own sigmoid internally and must therefore
    # receive raw logits: StructureLoss calls binary_cross_entropy_with_logits
    # and torch.sigmoid on its input, so feeding it probabilities would squash
    # the prediction twice.
    _LOGIT_CRITERIA = frozenset({"structure"})

    def __init__(self):
        super(PixLoss, self).__init__()
        self.config = Config()
        # Presumably a dict of criterion name -> weight -- confirm in config.
        self.lambdas_pix_last = self.config.lambdas_pix_last
        # Insertion order here is the order criteria are evaluated in forward().
        factories = {
            "bce": nn.BCELoss,
            "iou": IoULoss,
            "iou_patch": PatchIoULoss,
            "ssim": SSIMLoss,
            "mae": nn.L1Loss,
            "mse": nn.MSELoss,
            "reg": ThrReg_loss,
            "cnt": ContourLoss,
            "structure": StructureLoss,
        }
        lambdas = self.lambdas_pix_last
        self.criterions_last = {
            name: factory()
            for name, factory in factories.items()
            if name in lambdas and lambdas[name]
        }

    def forward(self, scaled_preds, gt):
        """Return the weighted sum of all enabled criteria over all levels.

        Args:
            scaled_preds: iterable of logit maps, each indexed as (B, C, H, W).
            gt: ground-truth map; its ``shape[2:]`` is the resize target.

        Returns:
            The accumulated loss (the float ``0.0`` if nothing was added).
        """
        loss = 0.0
        for pred_lvl in scaled_preds:
            if pred_lvl.shape != gt.shape:
                pred_lvl = nn.functional.interpolate(
                    pred_lvl, size=gt.shape[2:], mode="bilinear", align_corners=True
                )
            # Computed once per level and shared by every probability-space criterion.
            prob_lvl = pred_lvl.sigmoid()
            for criterion_name, criterion in self.criterions_last.items():
                if criterion_name in self._LOGIT_CRITERIA:
                    inputs = pred_lvl
                else:
                    inputs = prob_lvl
                loss += criterion(inputs, gt) * self.lambdas_pix_last[criterion_name]
        return loss
class SSIMLoss(torch.nn.Module):
    """SSIM-based loss: ``1 - (1 + SSIM(img1, img2)) / 2``.

    Uses a Gaussian window of ``window_size`` built by ``create_window``. The
    window is cached and rebuilt whenever the input's channel count, dtype or
    device changes.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        # Plain tensor attribute (not a registered buffer), so Module.to() does
        # not move it; forward() re-validates dtype/device on every call.
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM loss between ``img1`` and ``img2``.

        Args:
            img1: tensor of shape (B, C, H, W).
            img2: tensor with the same shape as ``img1``.
        """
        (_, channel, _, _) = img1.size()
        window = self.window
        # Tensor.type() strings distinguish CPU from CUDA but not the CUDA
        # device index, so compare dtype and device explicitly; otherwise a
        # window cached on one GPU could be reused for inputs on another.
        if (
            channel != self.channel
            or window.dtype != img1.dtype
            or window.device != img1.device
        ):
            window = create_window(self.window_size, channel).to(
                device=img1.device, dtype=img1.dtype
            )
            self.window = window
            self.channel = channel
        ssim_value = _ssim(
            img1, img2, window, self.window_size, channel, self.size_average
        )
        return 1 - (1 + ssim_value) / 2
def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel of length ``window_size``.

    The kernel is centred on index ``window_size // 2`` and sums to 1.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma**2)
    weights = torch.Tensor(
        [exp(-((idx - center) ** 2) / two_sigma_sq) for idx in range(window_size)]
    )
    return weights / weights.sum()
def create_window(window_size, channel):
    """Build a depthwise 2-D Gaussian filter for SSIM.

    Returns a tensor of shape (channel, 1, window_size, window_size). Every
    channel holds the same kernel: the outer product of a normalised 1-D
    Gaussian (sigma = 1.5) with itself, so each kernel sums to 1. The shape
    fits ``F.conv2d(..., groups=channel)``.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)  # (window_size, 1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable is deprecated; a plain tensor (which does not
    # require grad) is what the old wrapper produced anyway.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = (
F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
)
sigma2_sq = (
F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
)
sigma12 = (
F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
- mu1_mu2
)
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
def SSIM(x, y):
    """Per-pixel SSIM *dissimilarity* ``clamp((1 - ssim) / 2, 0, 1)``.

    Local statistics use a 3x3 mean filter (stride 1, padding 1, padded zeros
    counted), so the output has the same shape as the inputs. The result is 0
    where the inputs are structurally identical.
    """
    C1 = 0.01**2
    C2 = 0.03**2

    def box_mean(t):
        # Same as nn.AvgPool2d(3, 1, 1)(t), without building a module per call.
        return F.avg_pool2d(t, 3, 1, 1)

    mu_x = box_mean(x)
    mu_y = box_mean(y)
    mu_xy = mu_x * mu_y
    mu_x_sq = mu_x.pow(2)
    mu_y_sq = mu_y.pow(2)

    var_x = box_mean(x * x) - mu_x_sq
    var_y = box_mean(y * y) - mu_y_sq
    cov_xy = box_mean(x * y) - mu_xy

    numerator = (2 * mu_xy + C1) * (2 * cov_xy + C2)
    denominator = (mu_x_sq + mu_y_sq + C1) * (var_x + var_y + C2)
    similarity = numerator / denominator
    return torch.clamp((1 - similarity) / 2, 0, 1)
def saliency_structure_consistency(x, y):
    """Return the mean of ``SSIM(x, y)`` over all elements.

    ``SSIM`` here yields the clamped dissimilarity ``(1 - ssim) / 2``, so this
    value is 0 for identical inputs and grows as their structure diverges.
    """
    return SSIM(x, y).mean()