Guess-What-Moves / eval_utils.py
subhc's picture
Code Commit
5e88f62
import functools
import random
from collections import defaultdict
import einops
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import SpectralClustering
from tqdm import tqdm
import flow_reconstruction
from utils import visualisation, log, grid
from utils.vit_extractor import ViTExtractor
# Colour lookup table; the visualisation functions below slice it by label
# count and use it to paint one-hot segmentation maps.
label_colors = visualisation.create_label_colormap()
# Project logger named 'gwm'; the functions in this file call its info() and
# info_once() methods.
logger = log.getLogger('gwm')
def __default_font(fontsize):
    """Load the DejaVu TrueType font used for header text at ``fontsize``.

    The monospaced face is tried first. If loading it raises ``OSError`` the
    regular DejaVu Sans face is loaded instead; an error from that second
    attempt propagates to the caller.
    """
    try:
        return ImageFont.truetype("dejavu/DejaVuSansMono.ttf", fontsize)
    except OSError:
        return ImageFont.truetype("dejavu/DejaVuSans.ttf", fontsize)
@functools.lru_cache(None)  # cache the result
def autosized_default_font(size_limit: float) -> ImageFont.ImageFont:
    """Return the largest default font whose text height is below ``size_limit``.

    Font sizes are tried upwards from 1 until the measured height of the probe
    string ``'test123'`` reaches ``size_limit``; the previous size is returned.
    Results are memoised per ``size_limit``.

    Args:
        size_limit: Height bound (in pixels) for the rendered probe string.

    Returns:
        The font loaded by ``__default_font`` at the selected size. The size is
        never smaller than 1, even if the size-1 font already reaches the limit.
    """

    def _probe_height(font):
        # ``getsize`` only exists on older Pillow releases; newer ones expose
        # ``getbbox`` instead, whose bottom edge is used as the replacement
        # height measurement.
        if hasattr(font, 'getsize'):
            return font.getsize('test123')[1]
        return font.getbbox('test123')[3]

    fontsize = 1  # starting font size
    while _probe_height(__default_font(fontsize)) < size_limit:
        fontsize += 1
    # Step back to the last size that was still under the limit, but never
    # request size 0 (happens when the very first probe already hits the limit).
    fontsize = max(fontsize - 1, 1)
    return __default_font(fontsize)
def iou(masks, gt, thres=0.5):
    """Intersection-over-union of thresholded masks against one target map.

    Args:
        masks: Tensor whose last two dimensions are spatial; values above
            ``thres`` count as foreground.
        gt: 2-D target map matching the spatial size of ``masks``; it is used
            as-is (not thresholded).
        thres: Binarisation threshold applied to ``masks``.

    Returns:
        Tensor with the leading (non-spatial) dimensions of ``masks`` holding
        one IoU value per mask. The union is clamped to at least 1e-12 so an
        empty prediction against an empty target yields 0 rather than NaN.
    """
    binary = (masks > thres).float()
    # Contract the two spatial axes of the masks with the two axes of gt.
    overlap = torch.tensordot(binary, gt, dims=([-2, -1], [0, 1]))
    pred_area = binary.sum(dim=[-2, -1])
    gt_area = gt.sum(dim=[-2, -1])
    union = pred_area + gt_area - overlap
    return overlap / union.clamp(min=1e-12)
def get_unsup_image_viz(model, cfg, sample, criterion):
    """Run the model in eval mode on ``sample`` and build its visualisation.

    If the model is in training mode it is switched to eval mode for the
    forward pass and switched back afterwards, including when the forward pass
    raises, so a failed visualisation cannot leave the model in eval mode.

    Args:
        model: Object exposing ``training``, ``eval()``, ``train()`` and
            ``forward_base(sample, keys=..., get_eval=True)``.
        cfg: Config providing ``cfg.GWM.SAMPLE_KEYS``.
        sample: Batch passed through to the model and to ``get_image_vis``.
        criterion: Passed through to ``get_image_vis``.

    Returns:
        Whatever ``get_image_vis`` returns for the predictions.
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
    finally:
        # Restore the caller's mode even when the forward pass fails.
        if was_training:
            model.train()
    return get_image_vis(model, cfg, sample, preds, criterion)
def get_vis_header(header_size, image_size, header_texts, header_height=20):
    """Render a white strip with one centred black caption per column.

    Args:
        header_size: Total width in pixels of the returned strip.
        image_size: Width in pixels of each caption cell.
        header_texts: Captions, one per cell, laid out left to right.
        header_height: Height in pixels of the strip.

    Returns:
        ``torch.uint8`` tensor of shape ``(3, header_height, header_size)``.
        Cells start at the left edge; remaining width stays white. Captions
        that would extend past ``header_size`` are cropped, and an empty
        ``header_texts`` yields a blank strip.
    """
    cell_w, cell_h = image_size, header_height
    font = autosized_default_font(0.8 * cell_h)

    cells = []
    for text in header_texts:
        im = Image.new("RGB", (cell_w, cell_h), "white")
        draw = ImageDraw.Draw(im)
        # ``textsize`` only exists on older Pillow releases; on newer ones the
        # right/bottom edges of ``textbbox`` anchored at the origin are used.
        if hasattr(draw, 'textsize'):
            w, h = draw.textsize(text, font=font)
        else:
            _, _, w, h = draw.textbbox((0, 0), text, font=font)
        draw.text(((cell_w - w) / 2, (cell_h - h) / 2), text, fill="black", font=font)
        cells.append(torch.from_numpy(np.array(im)))

    ret = torch.ones((header_height, header_size, 3)) * 255
    if cells:
        header_labels = torch.cat(cells, dim=1)
        # Crop instead of failing when the captions are wider than the strip.
        usable = min(header_size, header_labels.size(1))
        ret[:, :usable] = header_labels[:, :usable]
    return ret.permute(2, 0, 1).clip(0, 255).to(torch.uint8)
def get_image_vis(model, cfg, sample, preds, criterion):
    """Assemble a side-by-side visualisation image for a batch.

    For up to the first 8 samples, the following panels are concatenated along
    the width: rgb, visualised input flow, ground-truth segmentation, the
    optional 'gwm_seg' segmentation (only if every sample has it), the argmax
    prediction, the reconstructed flow(s), one panel per soft slot mask, and
    optionally 'flow_edges'. Samples are then stacked vertically.

    Args:
        model: Only ``model.device`` is read here.
        cfg: Provides ``cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES``.
        sample: List of dicts with 'rgb', 'flow', 'sem_seg' and optionally
            'gwm_seg' / 'flow_edges' tensors.
        preds: List of dicts with per-sample 'sem_seg' slot logits.
        criterion: Provides ``process_flow``, ``flow_reconstruction`` and
            ``viz_flow``.

    Returns:
        ``(image_viz, header_text)``: a ``torch.uint8`` tensor laid out as
        ``c (b h) w`` and the list of panel captions in left-to-right order.
    """

    def _paint(one_hot_maps, num_labels):
        # Turn one-hot label maps into colour images via the colour table.
        return torch.einsum('b k h w, k c -> b c h w', one_hot_maps, label_colors[:num_labels])

    logits = torch.stack([p['sem_seg'] for p in preds], 0)

    with torch.no_grad():
        flow = torch.stack([s['flow'].to(model.device) for s in sample]).clip(-20, 20)

    masks_softmaxed = torch.softmax(logits, dim=1)
    rec_flows = criterion.flow_reconstruction(sample, criterion.process_flow(sample, flow), masks_softmaxed)
    # A second reconstructed flow, when present, is captioned as the backward one.
    rec_headers = ['rec_flow', 'rec_bwd_flow'] if len(rec_flows) > 1 else ['rec_flow']

    rgb = torch.stack([s['rgb'] for s in sample])
    flow = criterion.viz_flow(criterion.process_flow(sample, flow).cpu()) * 255
    rec_flows = [
        (criterion.viz_flow(rec.detach().cpu()) * 255).clip(0, 255).to(torch.uint8)
        for rec in rec_flows
    ]

    gt_labels = torch.stack([s['sem_seg'] for s in sample])
    num_gt = gt_labels.max().item() + 1
    gt_seg = _paint(F.one_hot(gt_labels, num_gt).permute(0, 3, 1, 2), num_gt)

    target_K = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
    hard_masks = F.one_hot(masks_softmaxed.argmax(1).cpu(), target_K).permute(0, 3, 1, 2)
    pred_seg = _paint(hard_masks, target_K)

    # Grey-scale panel per slot: replicate to 3 channels, add a 1px white
    # separator on the right, then lay the slots out side by side.
    masks_each = torch.stack([masks_softmaxed] * 3, 2) * 255
    masks_each = F.pad(masks_each.cpu(), pad=[0, 1], value=255)
    masks_each = einops.rearrange(masks_each, 'b n c h w -> b c h (n w)')

    if all('gwm_seg' in s for s in sample):
        gwm_labels = torch.stack([s['gwm_seg'] for s in sample])
        num_gwm = gwm_labels.max().item() + 1
        gwm_seg = _paint(F.one_hot(gwm_labels, num_gwm).permute(0, 3, 1, 2), num_gwm)
        panels = [
            rgb,
            flow,
            F.pad(gt_seg.cpu(), pad=[0, 1], value=255),
            F.pad(gwm_seg, pad=[0, 1], value=255),
            pred_seg.cpu(),
            *rec_flows,
        ]
        header_text = ['rgb', 'gt_flow', 'gt_seg', 'GWM', 'pred_seg', *rec_headers]
    else:
        panels = [rgb, flow, gt_seg.cpu(), pred_seg.cpu(), *rec_flows]
        header_text = ['rgb', 'gt_flow', 'gt_seg', 'pred_seg', *rec_headers]

    image_viz = torch.cat(panels, -1)
    image_viz = torch.cat([image_viz, masks_each], -1)
    header_text += ['slot'] * masks_softmaxed.shape[1]

    if 'flow_edges' in sample[0]:
        flow_edges = torch.stack([s['flow_edges'].to(image_viz.device) for s in sample])
        ndim = len(flow_edges.shape)
        if ndim >= 4:
            # Collapse dim 1; keep it as a singleton only for 4-D input.
            flow_edges = flow_edges.sum(1, keepdim=ndim == 4)
        flow_edges = flow_edges.expand(-1, 3, -1, -1) * 255
        image_viz = torch.cat([image_viz, flow_edges], -1)
        header_text.append('flow_edges')

    # Keep at most 8 samples and stack them vertically.
    image_viz = einops.rearrange(image_viz[:8], 'b c h w -> c (b h) w')
    image_viz = image_viz.detach().clip(0, 255).to(torch.uint8)
    return image_viz, header_text
def get_frame_vis(model, cfg, sample, preds):
    """Build per-frame visualisations: rgb | flow | gt seg | pred seg | rec flow.

    The reconstructed flow is chosen by config:
    * ``cfg.GWM.SIMPLE_REC``: every slot contributes its mask-weighted mean flow;
    * ``cfg.GWM.HOMOGRAPHY``: ``flow_reconstruction.get_quad_flow(masks, flow)``;
    * otherwise: ``get_quad_flow`` with a meshgrid built for ``cfg.GWM.RESOLUTION``.

    Args:
        model: Only ``model.device`` is read here.
        cfg: Config with the ``GWM`` switches above and
            ``MODEL.MASK_FORMER.NUM_OBJECT_QUERIES``.
        sample: List of dicts with 'rgb', 'flow' and 'sem_seg' tensors.
        preds: List of dicts with per-sample 'sem_seg' slot logits.

    Returns:
        ``torch.uint8`` tensor laid out as ``b c h w`` with the five panels
        concatenated along the width.
    """

    def _paint(one_hot_maps, num_labels):
        # Turn one-hot label maps into colour images via the colour table.
        return torch.einsum('b k h w, k c -> b c h w', one_hot_maps, label_colors[:num_labels])

    logits = torch.stack([p['sem_seg'] for p in preds], 0)
    flow = torch.stack([s['flow'].to(model.device) for s in sample]).clip(-20, 20)
    masks_softmaxed = torch.softmax(logits, dim=1)

    if cfg.GWM.SIMPLE_REC:
        # Small epsilon keeps empty slots from dividing by zero.
        slot_mass = einops.reduce(masks_softmaxed, 'b k h w -> b k 1', 'sum') + 1e-7
        slot_means = torch.einsum('brhw, bchw -> brc', masks_softmaxed, flow) / slot_mass
        rec_flow = torch.einsum('bkhw, bkc-> bchw', masks_softmaxed, slot_means)
    elif cfg.GWM.HOMOGRAPHY:
        rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow)
    else:
        grid_x, grid_y = grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)
        rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow, grid_x, grid_y)

    rgb = torch.stack([s['rgb'] for s in sample])
    flow_rgb = torch.stack([visualisation.flow2rgb_torch(f) for f in flow.cpu()]) * 255
    rec_rgb = torch.stack([visualisation.flow2rgb_torch(f) for f in rec_flow.detach().cpu()]) * 255
    rec_rgb = rec_rgb.clip(0, 255).to(torch.uint8)

    gt_labels = torch.stack([s['sem_seg'] for s in sample])
    num_gt = gt_labels.max().item() + 1
    gt_onehot = F.one_hot(gt_labels, num_gt).permute(0, 3, 1, 2)

    num_slots = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
    pred_onehot = F.one_hot(logits.argmax(1).cpu(), num_slots).permute(0, 3, 1, 2)

    gt_seg = _paint(gt_onehot, num_gt)
    pred_seg = _paint(pred_onehot, num_slots)

    frame_vis = torch.cat([rgb, flow_rgb, gt_seg.cpu(), pred_seg.cpu(), rec_rgb], -1)
    # Identity pattern: leaves the layout untouched, only requires a 4-D tensor.
    frame_vis = einops.rearrange(frame_vis, 'b c h w -> b c h w')
    return frame_vis.detach().clip(0, 255).to(torch.uint8)
def is_2comp_dataset(dataset):
    """Tell whether ``dataset`` names one of the DAVIS / FBMS / STv2 datasets.

    Only the part of ``dataset`` before the first '+' is inspected (surrounding
    whitespace removed). The check is a case-sensitive substring match, so a
    name that merely contains e.g. 'DAVIS' also matches.

    Args:
        dataset: Dataset name, optionally several names joined by '+'.

    Returns:
        ``True`` if the first name contains 'DAVIS', 'FBMS' or 'STv2'.
    """
    # split() returns the whole string when there is no '+', so no branch is needed.
    d = dataset.split('+')[0].strip()
    logger.info_once(f"Is 2comp dataset? {d}")
    # Substring matching already covers exact equality, so no extra membership
    # test is required afterwards.
    return any(s in d for s in ('DAVIS', 'FBMS', 'STv2'))
def eval_unsupmf(cfg, val_loader, model, criterion, writer=None, writer_iteration=0, use_wandb=False):
    """Evaluate the model on ``val_loader`` and return the mean IoU in percent.

    For each batch the slot logits are softmaxed and, when more than two
    object queries are configured, merged into two masks by ``MaskMerger``.
    Masks are resized to the resolution of 'sem_seg_ori' if needed and the best
    IoU over the resulting masks is recorded per frame. The score is the mean
    over frames, except when the first dataset name contains 'DAVIS': then it
    is the mean over sequences of per-sequence means, skipping frames whose
    numeric id is not greater than 1.

    Args:
        cfg: Config providing ``LOG_ID``, ``GWM.*`` and
            ``MODEL.MASK_FORMER.NUM_OBJECT_QUERIES``.
        val_loader: Sized iterable of batches; each batch is an iterable of
            lists of sample dicts (it is flattened before use).
        model: Model exposing ``forward_base``, ``device`` and ``training``.
        criterion: Passed to ``get_image_vis`` for the logged visualisations.
        writer: Optional summary writer with ``add_image`` / ``add_scalar``.
        writer_iteration: Step index used for the writer calls.
        use_wandb: Only reported in the log output here.

    Returns:
        Mean IoU scaled to 0-100, or NaN when the loader yields no frames.
    """
    logger.info(f'Running Evaluation: {cfg.LOG_ID} {"Simple" if cfg.GWM.SIMPLE_REC else "Gradient"}:')
    logger.info(f'Model mode: {"train" if model.training else "eval"}, wandb: {use_wandb}')
    logger.info(f'Dataset: {cfg.GWM.DATASET} # components: {cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES}')

    # Merging down to two masks is only needed when there are more than two slots.
    merger = None
    if cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES > 2:
        merger = MaskMerger(cfg, model)

    # Visualise up to 10 random batches; never ask for more than exist.
    num_batches = len(val_loader)
    print_idxs = set(random.sample(range(num_batches), k=min(10, num_batches)))

    images_viz = []
    header_text = None
    header_cell_width = None
    ious_davis_eval = defaultdict(list)
    ious = defaultdict(list)
    t = 1  # frames per clip assumed by the rearrange below

    for idx, sample in enumerate(tqdm(val_loader)):
        sample = [e for s in sample for e in s]
        category = [s['category'] for s in sample]
        preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
        masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
        masks_softmaxed = torch.softmax(masks_raw, dim=1)

        if merger is not None:
            masks = merger(sample, masks_softmaxed)['cos']
        else:
            # Two (or fewer) slots: evaluate the soft masks directly.
            masks = masks_softmaxed

        if writer and idx in print_idxs:
            # Only the width is needed for the header, so stay on the host.
            header_cell_width = torch.stack([x['flow'] for x in sample]).size(3)
            img_viz, header_text = get_image_vis(model, cfg, sample, preds, criterion)
            images_viz.append(img_viz)

        gt_seg = torch.stack([x['sem_seg_ori'] for x in sample]).cpu()
        HW = gt_seg.shape[-2:]
        if HW != masks.shape[-2:]:
            logger.info_once(f"Upsampling predicted masks to {HW} for evaluation")
            masks_softmaxed_sel = F.interpolate(masks.detach().cpu(), size=HW, mode='bilinear', align_corners=False)
        else:
            masks_softmaxed_sel = masks.detach().cpu()
        masks_ = einops.rearrange(masks_softmaxed_sel, '(b t) s h w -> b t s 1 h w', t=t).detach()
        gt_seg = einops.rearrange(gt_seg, 'b h w -> b 1 h w').float()

        for i in range(masks_.size(0)):
            masks_k = F.interpolate(masks_[i], size=(1, gt_seg.shape[-2], gt_seg.shape[-1]))  # t s 1 h w
            mask_iou = iou(masks_k[:, :, 0], gt_seg[i, 0], thres=0.5)  # t s
            iou_max, _ = mask_iou.max(dim=1)  # best-matching mask per frame
            seq_name, frame_id = category[i][0], category[i][1]
            ious[seq_name].append(iou_max)
            ious_davis_eval[seq_name].append((frame_id.strip().replace('.png', ''), iou_max))

    frameious = [v for seq_ious in ious.values() for v in seq_ious]
    if frameious:
        frame_mean_iou = torch.cat(frameious).sum().item() * 100 / len(frameious)
    else:
        logger.info("No frames were evaluated; mIoU is undefined")
        frame_mean_iou = float('nan')

    if frameious and 'DAVIS' in cfg.GWM.DATASET.split('+')[0]:
        logger.info_once("Using DAVIS evaluator methods for evaluting IoU -- mean of mean of sequences without first frame")
        seq_scores = dict()
        for c in ious_davis_eval:
            seq_scores[c] = np.nanmean([v.item() for n, v in ious_davis_eval[c] if int(n) > 1])
        frame_mean_iou = np.nanmean(list(seq_scores.values())) * 100

    if writer:
        if images_viz:
            header = get_vis_header(images_viz[0].size(2), header_cell_width, header_text)
            images_viz = torch.cat(images_viz, dim=1)
            images_viz = torch.cat([header, images_viz], dim=1)
            writer.add_image('val/images', images_viz, writer_iteration)  # C H W
        writer.add_scalar('eval/mIoU', frame_mean_iou, writer_iteration)

    logger.info(f"mIoU: {frame_mean_iou:.3f} \n")
    return frame_mean_iou
class MaskMerger:
    """Merge several soft slot masks into two groups using ViT descriptors.

    Each slot is summarised by the mask-weighted average of per-pixel
    descriptors from a ``ViTExtractor``; slots are then split into two groups
    by spectral clustering on their cosine similarities, and the masks of each
    group are summed.
    """

    def __init__(self, cfg, model, merger_model="dino_vits8"):
        """Create the extractor on ``model.device``; ``cfg`` is not read here."""
        device = model.device
        self.extractor = ViTExtractor(model_type=merger_model, device=device)
        self.out_dim = 384
        # Normalisation statistics reshaped to broadcast over (b, c, h, w) images.
        self.mu = torch.tensor(self.extractor.mean).to(device).view(1, -1, 1, 1)
        self.sigma = torch.tensor(self.extractor.std).to(device).view(1, -1, 1, 1)
        self.start_idx = 0

    def get_feats(self, batch):
        """Return descriptors for ``batch`` resized to its spatial size.

        The extractor output is reshaped onto the ``num_patches`` grid, moved
        to channel-first layout and bilinearly interpolated to
        ``batch.shape[-2:]``.
        """
        with torch.no_grad():
            descriptors = self.extractor.extract_descriptors(batch, facet='key', layer=11, bin=False)
            patch_grid = descriptors.reshape(descriptors.size(0), *self.extractor.num_patches, -1)
            channel_first = patch_grid.permute(0, 3, 1, 2)
            return F.interpolate(channel_first, batch.shape[-2:], mode='bilinear')

    def spectral(self, A):
        """Split the rows of affinity matrix ``A`` into two index arrays.

        Returns the indices assigned to cluster 0 and to cluster 1.
        """
        affinity = A.detach().cpu().numpy()
        model = SpectralClustering(n_clusters=2,
                                   affinity='precomputed',
                                   random_state=0)
        labels = model.fit(affinity).labels_
        indices = np.arange(A.shape[-1])
        return indices[labels == 0], indices[labels == 1]

    def cos_merge(self, basis, masks):
        """Sum ``masks`` into two groups chosen from slot cosine similarity.

        ``basis`` holds one descriptor per slot. Only the similarity matrix of
        the first batch element drives the grouping, which is then applied to
        every element of ``masks``.
        """
        norms = torch.linalg.vector_norm(basis, dim=-1, keepdim=True).clamp(min=1e-6)
        unit = basis / norms
        similarity = torch.einsum('brc, blc -> brl', unit, unit)
        # Affinities are floored at a small positive value before clustering.
        A = similarity[0].clamp(min=1e-6)
        group_a, group_b = self.spectral(A)
        merged = [masks[:, group_a].sum(1), masks[:, group_b].sum(1)]
        return torch.stack(merged, 1)

    def __call__(self, sample, masks_softmaxed):
        """Merge ``masks_softmaxed`` for ``sample``; returns ``{'cos': merged}``.

        Slots before ``self.start_idx`` are dropped. Images are read from
        ``x['rgb']``, scaled by 1/255 and normalised with ``mu`` / ``sigma``
        before descriptor extraction.
        """
        with torch.no_grad():
            slots = masks_softmaxed[:, self.start_idx:]
            images = torch.stack([x['rgb'].to(slots.device) for x in sample], 0) / 255.0
            features = self.get_feats((images - self.mu) / self.sigma)
            # Mask-weighted mean descriptor per slot.
            weighted = torch.einsum('brhw, bchw -> brc', slots, features)
            slot_mass = einops.reduce(slots, 'b r h w -> b r 1', 'sum').clamp_min(1e-12)
            basis = weighted / slot_mass
            return {
                'cos': self.cos_merge(basis, slots),
            }