# WSCL/losses/multi_view_consistency_loss.py
# Repository web-viewer residue, kept for provenance:
#   uploader: yhzhai · commit 482ab8a ("release code") · size 4.91 kB
from typing import Dict, List
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from skimage import segmentation
def get_multi_view_consistency_loss(opt):
    """Build a ``MultiViewConsistencyLoss`` from an options object.

    Reads ``mvc_soft``, ``mvc_zeros_on_au``, ``mvc_single_weight``,
    ``modality``, ``mvc_spixel`` and ``mvc_num_spixel`` from ``opt`` and
    forwards them to the constructor; ``eps`` keeps its default.
    """
    return MultiViewConsistencyLoss(
        soft=opt.mvc_soft,
        zeros_on_au=opt.mvc_zeros_on_au,
        single_weight=opt.mvc_single_weight,
        modality=opt.modality,
        spixel=opt.mvc_spixel,
        num_spixel=opt.mvc_num_spixel,
    )
class MultiViewConsistencyLoss(nn.Module):
    """MSE loss between each modality's map and a fused target map.

    The target is the weighted sum of the per-modality ``out_map`` tensors,
    optionally averaged inside superpixels and optionally binarised at 0.5.
    It is built without gradient tracking, so only the individual modality
    outputs receive gradients from the loss.
    """

    def __init__(
        self,
        soft: bool,
        zeros_on_au: bool,
        single_weight: Dict,
        modality: List,
        spixel: bool = False,
        num_spixel: int = 100,
        eps: float = 1e-4,
    ):
        """Store the configuration.

        Args:
            soft: if True the fused map is used as is; if False it is
                binarised at 0.5 (see ``forward``).
            zeros_on_au: if True, the target of every sample whose label
                equals 0 is set to all zeros ("au" presumably means
                "authentic" -- confirm with the data pipeline).
            single_weight: fusion weight per modality, looked up by the
                lower-cased modality name.
            modality: keys of ``output`` that are fused and supervised.
            spixel: if True, the fused map is averaged inside each
                superpixel before any thresholding.
            num_spixel: stored only; not read anywhere in this class.
            eps: stored only; not read anywhere in this class.
        """
        super().__init__()
        self.soft = soft
        self.zeros_on_au = zeros_on_au
        self.single_weight = single_weight
        self.modality = modality
        self.spixel = spixel
        self.num_spixel = num_spixel
        self.eps = eps
        self.mse_loss = nn.MSELoss(reduction="mean")

    def forward(self, output: Dict, label, spixel=None, image=None, mask=None):
        """Compute the per-modality and total consistency losses.

        Args:
            output: mapping such that ``output[m]["out_map"]`` is the map
                predicted for modality ``m``; the batch is the first axis.
            label: per-sample labels, compared against ``1.0`` and ``0.0``.
            spixel: superpixel index map; required when the loss was built
                with ``spixel=True``, ignored otherwise.
            image: not used by this method.
            mask: not used by this method.

        Returns:
            A dict holding ``"multi_view_consistency_loss_<modality>"`` for
            every modality, the fused ``"tgt_map"`` and their ``"total_loss"``.

        Raises:
            ValueError: if superpixel averaging is enabled but ``spixel`` is
                ``None``.
        """
        if self.spixel and spixel is None:
            # Fail here with a clear message instead of deep inside the
            # superpixel helper with an opaque indexing error.
            raise ValueError(
                "spixel must be provided when the loss is built with spixel=True"
            )

        with torch.no_grad():
            tgt_map = torch.zeros_like(
                output[self.modality[0]]["out_map"], requires_grad=False
            )
            for modality in self.modality:
                weight = self.single_weight[modality.lower()]
                # Out-of-place add: the modality outputs are never modified.
                tgt_map = tgt_map + weight * output[modality]["out_map"]

            if self.spixel:
                tgt_map = get_spixel_tgt_map(tgt_map, spixel)

            if not self.soft:
                # A sample labelled 1.0 must keep at least one positive
                # pixel: if nothing exceeds the threshold, promote its
                # maximum-valued pixel(s) before binarising.
                for b in range(tgt_map.shape[0]):
                    sample = tgt_map[b, ...]
                    peak = sample.max()
                    if peak <= 0.5 and label[b] == 1.0:
                        sample[torch.where(sample == peak)] = 1.0
                tgt_map[torch.where(tgt_map > 0.5)] = 1
                tgt_map[torch.where(tgt_map <= 0.5)] = 0

            # Hard targets always zero the label-0 samples; soft targets do
            # so only when ``zeros_on_au`` is set.
            if not self.soft or self.zeros_on_au:
                tgt_map[torch.where(label == 0.0)[0], ...] = 0.0

        total_loss = 0.0
        loss_dict = {}
        for modality in self.modality:
            loss = self.mse_loss(output[modality]["out_map"], tgt_map)
            loss_dict[f"multi_view_consistency_loss_{modality}"] = loss
            total_loss = total_loss + loss

        return {**loss_dict, "tgt_map": tgt_map, "total_loss": total_loss}

    def _save(
        self,
        spixel: torch.Tensor,
        image: torch.Tensor,
        mask: torch.Tensor,
        tgt_map: torch.Tensor,
        raw_tgt_map: torch.Tensor,
        out_path: str = "tmp/spixel_tgt_map.png",
    ):
        """Debug helper: plot one row of five panels per sample and save it.

        Each row shows image, mask, superpixel map, raw target map and
        target map.  All tensors are permuted with ``(0, 2, 3, 1)`` before
        plotting, i.e. they are expected to have four axes.  The parent
        directory of ``out_path`` is created when missing.
        """
        import os  # local import: only needed by this debug helper

        spixel = spixel.permute(0, 2, 3, 1).detach().cpu().numpy()
        image = image.permute(0, 2, 3, 1).detach().cpu().numpy()
        mask = mask.permute(0, 2, 3, 1).detach().cpu().numpy() * 255.0
        tgt_map = tgt_map.permute(0, 2, 3, 1).squeeze(3).detach().cpu().numpy() * 255.0
        raw_tgt_map = (
            raw_tgt_map.permute(0, 2, 3, 1).squeeze(3).detach().cpu().numpy() * 255.0
        )

        # (title, data) for each column, in display order.
        panels = (
            ("image", image),
            ("mask", mask),
            ("superpixel", spixel),
            ("raw target map", raw_tgt_map),
            ("target map", tgt_map),
        )
        bn = spixel.shape[0]
        for b in range(bn):
            for col, (title, data) in enumerate(panels):
                # Subplot indices are 1-based and run row by row.
                plt.subplot(bn, len(panels), b * len(panels) + col + 1)
                plt.imshow(data[b])
                plt.axis("off")
                plt.title(title)
        plt.tight_layout()

        # savefig does not create missing directories.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(out_path, dpi=300)
        plt.close()
def get_spixel_tgt_map(weighted_sum, spixel):
    """Average ``weighted_sum`` inside every superpixel, sample by sample.

    Args:
        weighted_sum: fused map; the first axis is the batch.
        spixel: superpixel index map; ``spixel[b]`` must index
            ``weighted_sum[b]`` position for position.

    Returns:
        A new tensor like ``weighted_sum`` in which every position holds the
        mean of ``weighted_sum`` over the superpixel it belongs to.  The
        result is still a soft map: thresholding, if any, is left to the
        caller.  The inputs are not modified.
    """
    spixel_tgt_map = torch.zeros_like(weighted_sum, requires_grad=False)
    for bidx in range(weighted_sum.shape[0]):
        sample_spixel = spixel[bidx, ...]
        sample_sum = weighted_sum[bidx, ...]
        for spixel_idx in sample_spixel.unique().tolist():
            # Build the membership mask and its indices once per superpixel
            # and reuse them for the area, the sum and the write-back.
            member = sample_spixel == spixel_idx
            region = torch.where(member)
            area = member.sum()
            spixel_tgt_map[bidx][region] = sample_sum[region].sum() / area
    return spixel_tgt_map
if __name__ == "__main__":
    # Smoke test: construct the loss.  The constructor requires both a
    # weight mapping keyed by lower-cased modality name and the modality
    # list itself; the names below are placeholders.
    mvc_loss = MultiViewConsistencyLoss(
        True,
        True,
        {"a": 1, "b": 1, "c": 2},
        ["a", "b", "c"],
    )
    print("a")