# Guess-What-Moves — eval_utils.py
# (Hugging Face file-viewer chrome removed; source: commit 5e88f62 by subhc, 12.8 kB.)
import functools
import random
from collections import defaultdict
import einops
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import SpectralClustering
from tqdm import tqdm
import flow_reconstruction
from utils import visualisation, log, grid
from utils.vit_extractor import ViTExtractor
# Fixed colour palette indexed by class label, used to render segmentation maps.
label_colors = visualisation.create_label_colormap()
# Project logger; provides info_once (log-once-per-message) used below.
logger = log.getLogger('gwm')
def __default_font(fontsize):
    """Load the DejaVu mono font at ``fontsize`` points.

    Falls back to the proportional DejaVu Sans face when the mono face
    cannot be opened; if the fallback is also missing, the OSError from
    Pillow propagates to the caller.
    """
    try:
        return ImageFont.truetype("dejavu/DejaVuSansMono.ttf", fontsize)
    except OSError:
        # Mono face unavailable on this system — use the proportional variant.
        return ImageFont.truetype("dejavu/DejaVuSans.ttf", fontsize)
@functools.lru_cache(None)  # cache the result per size limit
def autosized_default_font(size_limit: float) -> ImageFont.ImageFont:
    """Return the largest default font whose text height fits ``size_limit``.

    Grows the point size one step at a time until the rendered height of a
    probe string reaches the limit, then steps back one point.

    Fix: ``FreeTypeFont.getsize`` was removed in Pillow 10; fall back to
    ``getbbox`` when ``getsize`` is absent.
    """

    def _text_height(font):
        # Pillow < 10: getsize returns (width, height).
        if hasattr(font, 'getsize'):
            return font.getsize('test123')[1]
        # Pillow >= 10: derive height from the bounding box.
        # NOTE(review): bbox height can differ by a pixel or two from the old
        # getsize height — acceptable for font auto-sizing.
        left, top, right, bottom = font.getbbox('test123')
        return bottom - top

    fontsize = 1  # starting font size
    font = __default_font(fontsize)
    while _text_height(font) < size_limit:
        fontsize += 1
        font = __default_font(fontsize)
    fontsize -= 1  # last size that still fit within the limit
    return __default_font(fontsize)
def iou(masks, gt, thres=0.5):
    """Intersection-over-union of each mask against a single ground truth.

    ``masks`` is binarised at ``thres``; its trailing (h, w) dims are
    contracted against ``gt``.  The union is clamped away from zero so an
    empty mask/gt pair yields 0 rather than NaN.
    """
    binary = (masks > thres).float()
    overlap = torch.tensordot(binary, gt, dims=([-2, -1], [0, 1]))
    total = binary.sum(dim=[-2, -1]) + gt.sum(dim=[-2, -1]) - overlap
    return overlap / total.clip(min=1e-12)
def get_unsup_image_viz(model, cfg, sample, criterion):
    """Run an eval-mode forward pass and build the visualisation strip.

    If the model is currently in training mode it is switched to eval for
    the forward pass and restored afterwards; otherwise it is left as-is.
    """
    was_training = model.training
    if was_training:
        model.eval()
    preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
    if was_training:
        model.train()
    return get_image_vis(model, cfg, sample, preds, criterion)
def get_vis_header(header_size, image_size, header_texts, header_height=20):
    """Render a horizontal strip of centred column titles.

    Each entry of ``header_texts`` is drawn in its own white cell of size
    (``image_size`` x ``header_height``); cells are concatenated left to
    right and padded with white up to ``header_size`` pixels wide.

    Fix: ``ImageDraw.textsize`` was removed in Pillow 10; fall back to
    ``textbbox`` when it is absent.

    Returns a uint8 tensor of shape (3, header_height, header_size).
    """
    W, H = (image_size, header_height)
    header_labels = []
    font = autosized_default_font(0.8 * H)
    for text in header_texts:
        im = Image.new("RGB", (W, H), "white")
        draw = ImageDraw.Draw(im)
        try:
            w, h = draw.textsize(text, font=font)
        except AttributeError:
            # Pillow >= 10: derive the text extent from its bounding box.
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            w, h = right - left, bottom - top
        draw.text(((W - w) / 2, (H - h) / 2), text, fill="black", font=font)
        header_labels.append(torch.from_numpy(np.array(im)))
    header_labels = torch.cat(header_labels, dim=1)
    ret = (torch.ones((header_height, header_size, 3)) * 255)
    ret[:, :header_labels.size(1)] = header_labels
    return ret.permute(2, 0, 1).clip(0, 255).to(torch.uint8)
def get_image_vis(model, cfg, sample, preds, criterion):
    """Assemble a wide side-by-side visualisation strip for a batch.

    Panels, left to right: rgb, GT flow, GT segmentation, optional GWM
    pseudo-label segmentation, predicted segmentation, reconstructed
    flow(s), one greyscale panel per mask slot, and optionally flow edges.
    Batch items (at most 8) are stacked vertically.

    Returns:
        image_viz: uint8 tensor of shape (3, b*h, total_w).
        header_text: list of column titles, aligned with the panels.
    """
    masks_pred = torch.stack([x['sem_seg'] for x in preds], 0)
    with torch.no_grad():
        flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)
        masks_softmaxed = torch.softmax(masks_pred, dim=1)
        masks_pred = masks_softmaxed
        rec_flows = criterion.flow_reconstruction(sample, criterion.process_flow(sample, flow), masks_softmaxed)
        rec_headers = ['rec_flow']
        # A second reconstructed flow, when present, is the backward flow.
        if len(rec_flows) > 1:
            rec_headers.append('rec_bwd_flow')

        rgb = torch.stack([x['rgb'] for x in sample])
        flow = criterion.viz_flow(criterion.process_flow(sample, flow).cpu()) * 255
        # NOTE(review): the second .cpu() below is redundant (tensor is already on CPU).
        rec_flows = [
            (criterion.viz_flow(rec_flow_.detach().cpu().cpu()) * 255).clip(0, 255).to(torch.uint8) for rec_flow_ in rec_flows
        ]

        gt_labels = torch.stack([x['sem_seg'] for x in sample])
        gt = F.one_hot(gt_labels, gt_labels.max().item() + 1).permute(0, 3, 1, 2)
        target_K = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        # Hard-assign each pixel to its argmax slot for the colourised view.
        masks = F.one_hot(masks_pred.argmax(1).cpu(), target_K).permute(0, 3, 1, 2)
        # Per-slot soft masks as greyscale RGB panels, right-padded by one
        # white column and tiled horizontally.
        masks_each = torch.stack([masks_softmaxed, masks_softmaxed, masks_softmaxed], 2) * 255
        masks_each = einops.rearrange(F.pad(masks_each.cpu(), pad=[0, 1], value=255), 'b n c h w -> b c h (n w)')

        # Colourise label maps with the module-level label_colors palette.
        gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:gt_labels.max().item() + 1])
        pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:target_K])
        if all('gwm_seg' in d for d in sample):
            # GWM pseudo-label panel is shown only when every sample carries one.
            gwm_labels = torch.stack([x['gwm_seg'] for x in sample])
            mg = F.one_hot(gwm_labels, gwm_labels.max().item() + 1).permute(0, 3, 1, 2)
            gwm_seg = torch.einsum('b k h w, k c -> b c h w', mg, label_colors[:gwm_labels.max().item() + 1])
            image_viz = torch.cat(
                [rgb, flow, F.pad(gt_seg.cpu(), pad=[0, 1], value=255), F.pad(gwm_seg, pad=[0, 1], value=255),
                 pred_seg.cpu(), *rec_flows], -1)
            header_text = ['rgb', 'gt_flow', 'gt_seg', 'GWM', 'pred_seg', *rec_headers]
        else:
            image_viz = torch.cat([rgb, flow, gt_seg.cpu(), pred_seg.cpu(), *rec_flows], -1)
            header_text = ['rgb', 'gt_flow', 'gt_seg', 'pred_seg', *rec_headers]

        image_viz = torch.cat([image_viz, masks_each], -1)
        header_text.extend(['slot'] * masks_softmaxed.shape[1])
        if 'flow_edges' in sample[0]:
            flow_edges = torch.stack([x['flow_edges'].to(image_viz.device) for x in sample])
            if len(flow_edges.shape) >= 4:
                # Collapse any extra channel dim; keepdim only for rank-4 input
                # so the expand(-1, 3, -1, -1) below sees a channel axis.
                flow_edges = flow_edges.sum(1, keepdim=len(flow_edges.shape) == 4)
            flow_edges = flow_edges.expand(-1, 3, -1, -1)
            flow_edges = flow_edges * 255
            image_viz = torch.cat([image_viz, flow_edges], -1)
            header_text.append('flow_edges')
        # Keep at most 8 batch items and stack them vertically into one image.
        image_viz = einops.rearrange(image_viz[:8], 'b c h w -> c (b h) w').detach().clip(0, 255).to(torch.uint8)
    return image_viz, header_text
def get_frame_vis(model, cfg, sample, preds):
    """Build a per-frame visualisation: rgb | GT flow | GT seg | pred seg | rec flow.

    Unlike get_image_vis, frames are kept separate (no vertical stacking).

    Fix: dropped the identity ``einops.rearrange(..., 'b c h w -> b c h w')``
    that the original applied before the final cast — it was a no-op.

    Returns a uint8 tensor of shape (b, 3, h, total_w).
    """
    masks_pred = torch.stack([x['sem_seg'] for x in preds], 0)
    flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)
    masks_softmaxed = torch.softmax(masks_pred, dim=1)
    if cfg.GWM.SIMPLE_REC:
        # Reconstruct flow as a per-slot constant: the soft-mask-weighted
        # mean flow of each component, splatted back through the masks.
        mask_denom = einops.reduce(masks_softmaxed, 'b k h w -> b k 1', 'sum') + 1e-7
        means = torch.einsum('brhw, bchw -> brc', masks_softmaxed, flow) / mask_denom
        rec_flow = torch.einsum('bkhw, bkc-> bchw', masks_softmaxed, means)
    elif cfg.GWM.HOMOGRAPHY:
        rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow)
    else:
        grid_x, grid_y = grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)
        rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow, grid_x, grid_y)
    rgb = torch.stack([x['rgb'] for x in sample])
    flow = torch.stack([visualisation.flow2rgb_torch(x) for x in flow.cpu()]) * 255
    rec_flow = torch.stack([visualisation.flow2rgb_torch(x) for x in rec_flow.detach().cpu()]) * 255
    gt_labels = torch.stack([x['sem_seg'] for x in sample])
    gt = F.one_hot(gt_labels, gt_labels.max().item() + 1).permute(0, 3, 1, 2)
    masks = F.one_hot(masks_pred.argmax(1).cpu(), cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES).permute(0, 3, 1, 2)
    gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:gt_labels.max().item() + 1])
    pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES])
    frame_vis = torch.cat([rgb, flow, gt_seg.cpu(), pred_seg.cpu(), rec_flow.clip(0, 255).to(torch.uint8)], -1)
    frame_vis = frame_vis.detach().clip(0, 255).to(torch.uint8)
    return frame_vis
def is_2comp_dataset(dataset):
    """Return True if ``dataset`` is a two-component (fg/bg) benchmark.

    For combined dataset names ('a+b') only the first component is tested.
    Matching is by substring against DAVIS, FBMS and STv2.

    Fix: the original trailing ``return d in ['DAVIS', 'FBMS', 'STv2']`` was
    dead code — any exact match is already caught by the substring loop, so
    that line could only ever return False; it has been removed.
    """
    if '+' in dataset:
        d = dataset.split('+')[0].strip()
    else:
        d = dataset.strip()
    logger.info_once(f"Is 2comp dataset? {d}")
    return any(s in d for s in ('DAVIS', 'FBMS', 'STv2'))
def eval_unsupmf(cfg, val_loader, model, criterion, writer=None, writer_iteration=0, use_wandb=False):
    """Evaluate unsupervised motion segmentation over ``val_loader``.

    For each batch, runs the model, merges the softmaxed masks via
    MaskMerger, upsamples them to the original GT resolution if needed and
    records the best-slot IoU per frame.  Optionally logs a visualisation
    strip and the mIoU scalar to ``writer``.

    Returns:
        frame_mean_iou: mean IoU as a percentage. For DAVIS-style datasets
        this is the mean of per-sequence means, skipping frame ids <= 1.
    """
    logger.info(f'Running Evaluation: {cfg.LOG_ID} {"Simple" if cfg.GWM.SIMPLE_REC else "Gradient"}:')
    logger.info(f'Model mode: {"train" if model.training else "eval"}, wandb: {use_wandb}')
    logger.info(f'Dataset: {cfg.GWM.DATASET} # components: {cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES}')
    merger = None
    if cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES > 2:
        merger = MaskMerger(cfg, model)
    # NOTE(review): merger is called unconditionally below, so this function
    # effectively requires NUM_OBJECT_QUERIES > 2; with <= 2 queries the
    # merger(...) call would raise on None. Confirm intended usage.
    # NOTE(review): random.sample(..., k=10) raises if len(val_loader) < 10.
    print_idxs = random.sample(range(len(val_loader)), k=10)

    images_viz = []
    ious_davis_eval = defaultdict(list)  # category -> [(frame_id, iou)]
    ious = defaultdict(list)             # category -> [iou per frame]
    for idx, sample in enumerate(tqdm(val_loader)):
        t = 1  # frames per clip; evaluation here is single-frame
        # Flatten the nested loader output into a flat list of frame dicts.
        sample = [e for s in sample for e in s]
        category = [s['category'] for s in sample]
        preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
        masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
        masks_softmaxed = torch.softmax(masks_raw, dim=1)
        masks_dict = merger(sample, masks_softmaxed)

        # Collect visualisations for a fixed random subset of batches.
        if writer and idx in print_idxs:
            flow = torch.stack([x['flow'] for x in sample]).to(model.device)
            img_viz, header_text = get_image_vis(model, cfg, sample, preds, criterion)
            images_viz.append(img_viz)

        masks = masks_dict['cos']
        gt_seg = torch.stack([x['sem_seg_ori'] for x in sample]).cpu()
        HW = gt_seg.shape[-2:]
        if HW != masks.shape[-2:]:
            logger.info_once(f"Upsampling predicted masks to {HW} for evaluation")
            masks_softmaxed_sel = F.interpolate(masks.detach().cpu(), size=HW, mode='bilinear', align_corners=False)
        else:
            masks_softmaxed_sel = masks.detach().cpu()
        masks_ = einops.rearrange(masks_softmaxed_sel, '(b t) s h w -> b t s 1 h w', t=t).detach()
        gt_seg = einops.rearrange(gt_seg, 'b h w -> b 1 h w').float()
        for i in range(masks_.size(0)):
            masks_k = F.interpolate(masks_[i], size=(1, gt_seg.shape[-2], gt_seg.shape[-1]))  # t s 1 h w
            mask_iou = iou(masks_k[:, :, 0], gt_seg[i, 0], thres=0.5)  # t s
            # Score the frame by its best-matching slot.
            iou_max, slot_max = mask_iou.max(dim=1)

            ious[category[i][0]].append(iou_max)
            frame_id = category[i][1]
            ious_davis_eval[category[i][0]].append((frame_id.strip().replace('.png', ''), iou_max))

    frameious = sum(ious.values(), [])
    frame_mean_iou = torch.cat(frameious).sum().item() * 100 / len(frameious)
    if 'DAVIS' in cfg.GWM.DATASET.split('+')[0]:
        logger.info_once("Using DAVIS evaluator methods for evaluting IoU -- mean of mean of sequences without first frame")
        seq_scores = dict()
        for c in ious_davis_eval:
            # Skip the first frame(s) of each sequence, DAVIS-style.
            seq_scores[c] = np.nanmean([v.item() for n, v in ious_davis_eval[c] if int(n) > 1])

        frame_mean_iou = np.nanmean(list(seq_scores.values())) * 100

    if writer:
        # NOTE(review): `flow` and `header_text` are only bound inside the
        # print_idxs branch above; this assumes at least one sampled batch
        # was reached — confirm loaders are always >= 10 batches long.
        header = get_vis_header(images_viz[0].size(2), flow.size(3), header_text)
        images_viz = torch.cat(images_viz, dim=1)
        images_viz = torch.cat([header, images_viz], dim=1)
        writer.add_image('val/images', images_viz, writer_iteration)  # C H W
        writer.add_scalar('eval/mIoU', frame_mean_iou, writer_iteration)

    logger.info(f"mIoU: {frame_mean_iou:.3f} \n")
    return frame_mean_iou
class MaskMerger:
    """Merges K predicted soft masks into a binary split by spectrally
    clustering the masks' DINO-ViT appearance descriptors.

    Used at evaluation time when the model predicts more than two
    components but the benchmark expects a two-way segmentation.
    """

    def __init__(self, cfg, model, merger_model="dino_vits8"):
        # Frozen ViT feature extractor on the same device as the model.
        self.extractor = ViTExtractor(model_type=merger_model, device=model.device)
        # Descriptor dimensionality — presumably that of dino_vits8; not
        # referenced by the visible methods. TODO confirm.
        self.out_dim = 384

        # Channel normalisation stats from the extractor, shaped (1, C, 1, 1)
        # for broadcasting over (b, c, h, w) image batches.
        self.mu = torch.tensor(self.extractor.mean).to(model.device).view(1, -1, 1, 1)
        self.sigma = torch.tensor(self.extractor.std).to(model.device).view(1, -1, 1, 1)
        # First mask slot to keep; slots before this index are dropped in __call__.
        self.start_idx = 0

    def get_feats(self, batch):
        """Extract dense ViT descriptors for ``batch`` and bilinearly
        upsample them to the input spatial resolution."""
        with torch.no_grad():
            feat = self.extractor.extract_descriptors(batch, facet='key', layer=11, bin=False)
            # (b, patches, dim) -> (b, dim, ph, pw) feature map.
            feat = feat.reshape(feat.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
            return F.interpolate(feat, batch.shape[-2:], mode='bilinear')

    def spectral(self, A):
        """Two-way spectral clustering of the precomputed affinity matrix ``A``.

        Returns (indices in cluster 0, indices in cluster 1) as numpy arrays.
        """
        clustering = SpectralClustering(n_clusters=2,
                                        affinity='precomputed',
                                        random_state=0).fit(A.detach().cpu().numpy())
        return np.arange(A.shape[-1])[clustering.labels_ == 0], np.arange(A.shape[-1])[clustering.labels_ == 1]

    def cos_merge(self, basis, masks):
        """Group mask slots by cosine similarity of their descriptors and
        sum each group into one of two output masks."""
        # L2-normalise so the einsum below yields cosine similarities.
        basis = basis / torch.linalg.vector_norm(basis, dim=-1, keepdim=True).clamp(min=1e-6)
        # NOTE(review): only batch element 0 ([0]) drives the clustering —
        # assumes batch size 1 or identical grouping across the batch; confirm.
        A = torch.einsum('brc, blc -> brl', basis, basis)[0].clamp(min=1e-6)
        inda, indb = self.spectral(A)
        return torch.stack([masks[:, inda].sum(1),
                            masks[:, indb].sum(1)], 1)

    def __call__(self, sample, masks_softmaxed):
        """Return {'cos': merged (b, 2, h, w) masks} for the given samples."""
        with torch.no_grad():
            masks_softmaxed = masks_softmaxed[:, self.start_idx:]
            # RGB in [0, 1], then normalised with the extractor's stats.
            batch = torch.stack([x['rgb'].to(masks_softmaxed.device) for x in sample], 0) / 255.0
            features = self.get_feats((batch - self.mu) / self.sigma)
            # Mask-weighted mean descriptor per slot: (b, r, c).
            basis = torch.einsum('brhw, bchw -> brc', masks_softmaxed, features)
            basis /= einops.reduce(masks_softmaxed, 'b r h w -> b r 1', 'sum').clamp_min(1e-12)

            return {
                'cos': self.cos_merge(basis, masks_softmaxed),
            }