Spaces:
Runtime error
Runtime error
import functools | |
import random | |
from collections import defaultdict | |
import einops | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image, ImageDraw, ImageFont | |
from sklearn.cluster import SpectralClustering | |
from tqdm import tqdm | |
import flow_reconstruction | |
from utils import visualisation, log, grid | |
from utils.vit_extractor import ViTExtractor | |
# Colour table indexed by label id; the visualisation helpers below slice it
# to turn one-hot segmentations into colour images.
label_colors = visualisation.create_label_colormap()

# Logger shared by every helper in this module.
logger = log.getLogger('gwm')
def __default_font(fontsize):
    """Load the default DejaVu TrueType font at ``fontsize``.

    The monospaced face is tried first; if loading it raises ``OSError`` the
    proportional DejaVu Sans face is loaded instead.
    """
    try:
        return ImageFont.truetype("dejavu/DejaVuSansMono.ttf", fontsize)
    except OSError:
        # Monospaced face could not be loaded -- fall back to DejaVu Sans.
        return ImageFont.truetype("dejavu/DejaVuSans.ttf", fontsize)
@functools.lru_cache(maxsize=None)  # cache the result
def autosized_default_font(size_limit: float) -> ImageFont.ImageFont:
    """Return the default font at the last size tried before its height reaches ``size_limit``.

    The font size is increased one point at a time, starting from 1, until the
    rendered height of the probe string ``'test123'`` is no longer below
    ``size_limit``; the size just before that is returned. Every step reloads
    the font from disk, so results are memoised per ``size_limit``.

    Args:
        size_limit: Upper bound (in pixels) for the probe string height.
            Must be hashable because of the cache.

    Returns:
        The loaded font. The size is never smaller than 1, even when a
        size-1 font already reaches ``size_limit``.
    """

    def _probe_height(font):
        # ``getsize`` was removed in Pillow 10; ``getbbox`` is the replacement.
        # The bottom coordinate of the box is the height measured from the
        # text origin, which is what ``getsize(...)[1]`` used to report.
        if hasattr(font, "getbbox"):
            return font.getbbox('test123')[3]
        return font.getsize('test123')[1]

    fontsize = 1  # starting font size
    while _probe_height(__default_font(fontsize)) < size_limit:
        fontsize += 1

    # Step back to the last size that fitted; clamp so that we never ask the
    # font loader for a size of 0.
    return __default_font(max(fontsize - 1, 1))
def iou(masks, gt, thres=0.5):
    """Intersection-over-union between thresholded masks and one ground-truth mask.

    Args:
        masks: Tensor whose last two dimensions are spatial; any leading
            dimensions are kept in the result.
        gt: 2-D ground-truth mask with the same spatial size.
        thres: Values of ``masks`` strictly above this become 1, others 0.

    Returns:
        Tensor of IoU values with the leading dimensions of ``masks``. The
        union is clipped to ``1e-12`` so empty masks yield 0 rather than NaN.
    """
    binary = (masks > thres).float()
    # Contract both spatial axes of the masks against the ground truth.
    overlap = torch.tensordot(binary, gt, dims=([-2, -1], [0, 1]))
    mask_area = binary.sum(dim=[-2, -1])
    gt_area = gt.sum(dim=[-2, -1])
    union = mask_area + gt_area - overlap
    return overlap / union.clip(min=1e-12)
def get_unsup_image_viz(model, cfg, sample, criterion):
    """Run the model in eval mode on ``sample`` and render the visualisation.

    If the model is currently training it is switched to eval mode for the
    forward pass and switched back afterwards, including when the forward
    pass raises.

    Args:
        model: Model exposing ``training``, ``eval()``, ``train()`` and
            ``forward_base(sample, keys=..., get_eval=True)``.
        cfg: Config providing ``GWM.SAMPLE_KEYS``.
        sample: Batch forwarded to the model and to ``get_image_vis``.
        criterion: Passed through to ``get_image_vis``.

    Returns:
        Whatever ``get_image_vis`` returns for these predictions.
    """
    was_training = model.training
    if was_training:
        model.eval()
    try:
        preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
    finally:
        # Restore the caller's mode even if the forward pass failed.
        if was_training:
            model.train()
    return get_image_vis(model, cfg, sample, preds, criterion)
def get_vis_header(header_size, image_size, header_texts, header_height=20):
    """Render a row of centred text labels to sit above a visualisation strip.

    Each entry of ``header_texts`` is drawn in black on a white tile that is
    ``image_size`` wide and ``header_height`` tall; the tiles are laid out
    left to right on a white canvas that is ``header_size`` wide.

    Args:
        header_size: Total width of the returned header in pixels.
        image_size: Width of one label tile in pixels.
        header_texts: Iterable of label strings; may be empty.
        header_height: Height of the header in pixels.

    Returns:
        ``torch.uint8`` tensor of shape ``(3, header_height, header_size)``.
    """
    tile_w, tile_h = image_size, header_height
    font = autosized_default_font(0.8 * tile_h)

    tiles = []
    for text in header_texts:
        im = Image.new("RGB", (tile_w, tile_h), "white")
        draw = ImageDraw.Draw(im)
        if hasattr(draw, "textbbox"):
            # ``textsize`` was removed in Pillow 10; ``textbbox`` replaces it.
            # Height is taken from the drawing origin so the text's top
            # offset is included.
            left, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            w, h = right - left, bottom
        else:
            w, h = draw.textsize(text, font=font)
        draw.text(((tile_w - w) / 2, (tile_h - h) / 2), text, fill="black", font=font)
        tiles.append(torch.from_numpy(np.array(im)))

    # White canvas; any width not covered by label tiles stays white.
    ret = torch.ones((header_height, header_size, 3)) * 255
    if tiles:
        header_labels = torch.cat(tiles, dim=1)
        # Crop rather than fail when the labels are wider than the canvas.
        used = min(header_labels.size(1), header_size)
        ret[:, :used] = header_labels[:, :used]
    return ret.permute(2, 0, 1).clip(0, 255).to(torch.uint8)
def get_image_vis(model, cfg, sample, preds, criterion):
    """Assemble a side-by-side visualisation for a batch and its column titles.

    Columns are, in order: RGB, ground-truth flow, ground-truth segmentation,
    optionally the ``gwm_seg`` segmentation (only when every sample has it),
    predicted segmentation, reconstructed flow(s), one greyscale panel per
    slot and, when ``sample[0]`` has ``'flow_edges'``, the flow edges.

    Args:
        model: Only ``model.device`` is read here.
        cfg: Config providing ``MODEL.MASK_FORMER.NUM_OBJECT_QUERIES``.
        sample: List of per-sample dicts (``'flow'``, ``'rgb'``, ``'sem_seg'``, ...).
        preds: List of per-sample prediction dicts with a ``'sem_seg'`` entry.
        criterion: Provides ``process_flow``, ``flow_reconstruction`` and ``viz_flow``.

    Returns:
        ``(image_viz, header_text)``: a ``uint8`` tensor laid out as
        ``c (b h) w`` built from at most the first 8 samples, and the list of
        column titles.
    """
    slot_logits = torch.stack([p['sem_seg'] for p in preds], 0)

    # NOTE(review): the extent of this no_grad block was reconstructed from
    # de-indented source; only the flow stacking is kept inside -- confirm.
    with torch.no_grad():
        flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)

    masks_softmaxed = torch.softmax(slot_logits, dim=1)
    rec_flows = criterion.flow_reconstruction(sample, criterion.process_flow(sample, flow), masks_softmaxed)
    rec_headers = ['rec_flow']
    if len(rec_flows) > 1:
        rec_headers.append('rec_bwd_flow')

    rgb = torch.stack([x['rgb'] for x in sample])
    flow = criterion.viz_flow(criterion.process_flow(sample, flow).cpu()) * 255
    rec_flows = [
        (criterion.viz_flow(rec.detach().cpu()) * 255).clip(0, 255).to(torch.uint8)
        for rec in rec_flows
    ]

    # One-hot encode ground truth and the hard (argmax) prediction.
    gt_labels = torch.stack([x['sem_seg'] for x in sample])
    num_gt = gt_labels.max().item() + 1
    gt = F.one_hot(gt_labels, num_gt).permute(0, 3, 1, 2)
    target_K = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
    masks = F.one_hot(masks_softmaxed.argmax(1).cpu(), target_K).permute(0, 3, 1, 2)

    # Per-slot soft masks as 3-channel greyscale panels, each padded by one
    # pixel column of 255 on the right, concatenated along the width.
    masks_each = torch.stack([masks_softmaxed] * 3, 2) * 255
    masks_each = einops.rearrange(F.pad(masks_each.cpu(), pad=[0, 1], value=255), 'b n c h w -> b c h (n w)')

    gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:num_gt])
    pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:target_K])

    if all('gwm_seg' in d for d in sample):
        gwm_labels = torch.stack([x['gwm_seg'] for x in sample])
        num_gwm = gwm_labels.max().item() + 1
        mg = F.one_hot(gwm_labels, num_gwm).permute(0, 3, 1, 2)
        gwm_seg = torch.einsum('b k h w, k c -> b c h w', mg, label_colors[:num_gwm])
        columns = [
            rgb,
            flow,
            F.pad(gt_seg.cpu(), pad=[0, 1], value=255),
            F.pad(gwm_seg, pad=[0, 1], value=255),
            pred_seg.cpu(),
            *rec_flows,
        ]
        header_text = ['rgb', 'gt_flow', 'gt_seg', 'GWM', 'pred_seg', *rec_headers]
    else:
        columns = [rgb, flow, gt_seg.cpu(), pred_seg.cpu(), *rec_flows]
        header_text = ['rgb', 'gt_flow', 'gt_seg', 'pred_seg', *rec_headers]

    image_viz = torch.cat(columns, -1)
    image_viz = torch.cat([image_viz, masks_each], -1)
    header_text.extend(['slot'] * masks_softmaxed.shape[1])

    if 'flow_edges' in sample[0]:
        flow_edges = torch.stack([x['flow_edges'].to(image_viz.device) for x in sample])
        if flow_edges.dim() >= 4:
            # Collapse dim 1; keep it (as size 1) only in the 4-D case so the
            # result can be expanded to 3 channels below.
            flow_edges = flow_edges.sum(1, keepdim=flow_edges.dim() == 4)
        flow_edges = flow_edges.expand(-1, 3, -1, -1)
        flow_edges = flow_edges * 255
        image_viz = torch.cat([image_viz, flow_edges], -1)
        header_text.append('flow_edges')

    # Keep at most 8 samples and stack them vertically.
    image_viz = einops.rearrange(image_viz[:8], 'b c h w -> c (b h) w').detach().clip(0, 255).to(torch.uint8)
    return image_viz, header_text
def get_frame_vis(model, cfg, sample, preds):
    """Build a per-sample visualisation: RGB, flow, GT seg, predicted seg, reconstructed flow.

    The reconstructed flow is chosen by config: a per-slot mean flow when
    ``cfg.GWM.SIMPLE_REC`` is set, otherwise
    ``flow_reconstruction.get_quad_flow`` (called without a grid when
    ``cfg.GWM.HOMOGRAPHY`` is set, with a mesh grid otherwise).

    Args:
        model: Only ``model.device`` is read here.
        cfg: Config providing the ``GWM`` switches and
            ``MODEL.MASK_FORMER.NUM_OBJECT_QUERIES``.
        sample: List of per-sample dicts with ``'flow'``, ``'rgb'`` and ``'sem_seg'``.
        preds: List of per-sample prediction dicts with a ``'sem_seg'`` entry.

    Returns:
        ``uint8`` tensor laid out as ``b c h w`` with the five panels
        concatenated along the width.
    """
    num_slots = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
    masks_pred = torch.stack([p['sem_seg'] for p in preds], 0)
    flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20)
    masks_softmaxed = torch.softmax(masks_pred, dim=1)

    if cfg.GWM.SIMPLE_REC:
        # Mask-weighted mean flow per slot, broadcast back through the masks.
        mask_denom = einops.reduce(masks_softmaxed, 'b k h w -> b k 1', 'sum') + 1e-7
        means = torch.einsum('brhw, bchw -> brc', masks_softmaxed, flow) / mask_denom
        rec_flow = torch.einsum('bkhw, bkc-> bchw', masks_softmaxed, means)
    elif cfg.GWM.HOMOGRAPHY:
        rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow)
    else:
        grid_x, grid_y = grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device)
        rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow, grid_x, grid_y)

    rgb = torch.stack([x['rgb'] for x in sample])
    flow = torch.stack([visualisation.flow2rgb_torch(f) for f in flow.cpu()]) * 255
    rec_flow = torch.stack([visualisation.flow2rgb_torch(f) for f in rec_flow.detach().cpu()]) * 255

    gt_labels = torch.stack([x['sem_seg'] for x in sample])
    num_gt = gt_labels.max().item() + 1
    gt = F.one_hot(gt_labels, num_gt).permute(0, 3, 1, 2)
    masks = F.one_hot(masks_pred.argmax(1).cpu(), num_slots).permute(0, 3, 1, 2)

    gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:num_gt])
    pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:num_slots])

    panels = [rgb, flow, gt_seg.cpu(), pred_seg.cpu(), rec_flow.clip(0, 255).to(torch.uint8)]
    frame_vis = torch.cat(panels, -1)
    # Identity pattern: the layout is unchanged, frames stay separate.
    frame_vis = einops.rearrange(frame_vis, 'b c h w -> b c h w').detach().clip(0, 255).to(torch.uint8)
    return frame_vis
def is_2comp_dataset(dataset):
    """Return True when the primary dataset name mentions DAVIS, FBMS or STv2.

    Only the part of ``dataset`` before the first ``'+'`` is considered, with
    surrounding whitespace stripped. That name is logged via
    ``logger.info_once`` and then matched by substring.

    Args:
        dataset: Dataset identifier, optionally several names joined by ``'+'``.

    Returns:
        ``True`` if the primary name contains one of the three dataset tags,
        ``False`` otherwise.
    """
    # ``split`` returns the whole string when there is no '+', so no separate
    # branch is needed for single-dataset names.
    d = dataset.split('+')[0].strip()
    logger.info_once(f"Is 2comp dataset? {d}")
    # Substring matching already covers exact equality, so no extra
    # membership test is required afterwards.
    return any(name in d for name in ('DAVIS', 'FBMS', 'STv2'))
def eval_unsupmf(cfg, val_loader, model, criterion, writer=None, writer_iteration=0, use_wandb=False):
    """Evaluate the model on ``val_loader`` and return the mean IoU in percent.

    For each batch the slot masks are softmaxed, merged into two masks by
    ``MaskMerger`` when the model has more than two object queries (otherwise
    the softmaxed masks are used as they are), resized to the original
    ground-truth resolution if needed, and scored with ``iou`` using the
    best-matching slot per frame. When the primary dataset name contains
    ``'DAVIS'`` the score is the mean over sequences of the per-sequence mean,
    skipping frames whose numeric id is not greater than 1.

    Args:
        cfg: Config providing ``LOG_ID``, ``GWM.*`` and
            ``MODEL.MASK_FORMER.NUM_OBJECT_QUERIES``.
        val_loader: Iterable with ``len()``; each item is a nested list of
            sample dicts which is flattened here.
        model: Model exposing ``forward_base``, ``device`` and ``training``.
        criterion: Passed to ``get_image_vis`` for visualisation.
        writer: Optional summary writer with ``add_image`` / ``add_scalar``.
        writer_iteration: Step index used when writing to ``writer``.
        use_wandb: Only reported in the log line.

    Returns:
        Mean IoU scaled by 100.
    """
    logger.info(f'Running Evaluation: {cfg.LOG_ID} {"Simple" if cfg.GWM.SIMPLE_REC else "Gradient"}:')
    logger.info(f'Model mode: {"train" if model.training else "eval"}, wandb: {use_wandb}')
    logger.info(f'Dataset: {cfg.GWM.DATASET} # components: {cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES}')

    # Merging down to two masks is only needed with more than two slots.
    merger = None
    if cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES > 2:
        merger = MaskMerger(cfg, model)

    # Visualise up to 10 random batches; cap k so short loaders do not make
    # random.sample raise.
    num_batches = len(val_loader)
    print_idxs = set(random.sample(range(num_batches), k=min(10, num_batches)))

    images_viz = []
    header_text = None
    vis_width = None
    ious_davis_eval = defaultdict(list)
    ious = defaultdict(list)

    for idx, sample in enumerate(tqdm(val_loader)):
        t = 1
        sample = [e for s in sample for e in s]
        category = [s['category'] for s in sample]
        preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True)
        masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
        masks_softmaxed = torch.softmax(masks_raw, dim=1)

        if merger is not None:
            masks = merger(sample, masks_softmaxed)['cos']
        else:
            # Two (or fewer) slots: nothing to merge, score the slots directly.
            masks = masks_softmaxed

        if writer and idx in print_idxs:
            # Width of one input image, used later to size the header tiles.
            vis_width = torch.stack([x['flow'] for x in sample]).size(3)
            img_viz, header_text = get_image_vis(model, cfg, sample, preds, criterion)
            images_viz.append(img_viz)

        gt_seg = torch.stack([x['sem_seg_ori'] for x in sample]).cpu()
        HW = gt_seg.shape[-2:]
        if HW != masks.shape[-2:]:
            logger.info_once(f"Upsampling predicted masks to {HW} for evaluation")
            masks_softmaxed_sel = F.interpolate(masks.detach().cpu(), size=HW, mode='bilinear', align_corners=False)
        else:
            masks_softmaxed_sel = masks.detach().cpu()

        masks_ = einops.rearrange(masks_softmaxed_sel, '(b t) s h w -> b t s 1 h w', t=t).detach()
        gt_seg = einops.rearrange(gt_seg, 'b h w -> b 1 h w').float()
        for i in range(masks_.size(0)):
            masks_k = F.interpolate(masks_[i], size=(1, gt_seg.shape[-2], gt_seg.shape[-1]))  # t s 1 h w
            mask_iou = iou(masks_k[:, :, 0], gt_seg[i, 0], thres=0.5)  # t s
            iou_max, slot_max = mask_iou.max(dim=1)

            seq_name, frame_id = category[i][0], category[i][1]
            ious[seq_name].append(iou_max)
            ious_davis_eval[seq_name].append((frame_id.strip().replace('.png', ''), iou_max))

    frameious = sum(ious.values(), [])
    frame_mean_iou = torch.cat(frameious).sum().item() * 100 / len(frameious)

    if 'DAVIS' in cfg.GWM.DATASET.split('+')[0]:
        logger.info_once("Using DAVIS evaluator methods for evaluting IoU -- mean of mean of sequences without first frame")
        seq_scores = dict()
        for c in ious_davis_eval:
            seq_scores[c] = np.nanmean([v.item() for n, v in ious_davis_eval[c] if int(n) > 1])
        frame_mean_iou = np.nanmean(list(seq_scores.values())) * 100

    if writer:
        # Only write the image grid when at least one batch was visualised.
        if images_viz:
            header = get_vis_header(images_viz[0].size(2), vis_width, header_text)
            images_viz = torch.cat(images_viz, dim=1)
            images_viz = torch.cat([header, images_viz], dim=1)
            writer.add_image('val/images', images_viz, writer_iteration)  # C H W
        writer.add_scalar('eval/mIoU', frame_mean_iou, writer_iteration)

    logger.info(f"mIoU: {frame_mean_iou:.3f} \n")
    return frame_mean_iou
class MaskMerger:
    """Merge a set of soft slot masks into two masks.

    Slots are described by mask-weighted averages of ViT features, compared
    by cosine similarity, and split into two groups with spectral
    clustering; the masks in each group are summed.
    """

    def __init__(self, cfg, model, merger_model="dino_vits8"):
        """Create the feature extractor on ``model.device``.

        ``cfg`` is accepted but not read here.
        """
        device = model.device
        self.extractor = ViTExtractor(model_type=merger_model, device=device)
        self.out_dim = 384
        # Normalisation statistics reshaped to broadcast over (b, c, h, w).
        self.mu = torch.tensor(self.extractor.mean).to(device).view(1, -1, 1, 1)
        self.sigma = torch.tensor(self.extractor.std).to(device).view(1, -1, 1, 1)
        # Slots before this index are dropped in ``__call__``.
        self.start_idx = 0

    def get_feats(self, batch):
        """Extract descriptors for ``batch`` and resize them to its spatial size."""
        with torch.no_grad():
            descriptors = self.extractor.extract_descriptors(batch, facet='key', layer=11, bin=False)
            # Lay the patch descriptors out on the patch grid, channels first.
            patch_grid = descriptors.reshape(descriptors.size(0), *self.extractor.num_patches, -1)
            feature_map = patch_grid.permute(0, 3, 1, 2)
            return F.interpolate(feature_map, batch.shape[-2:], mode='bilinear')

    def spectral(self, A):
        """Split the rows of affinity matrix ``A`` into two clusters.

        Returns a pair of index arrays: members of cluster 0 and of cluster 1.
        """
        affinity = A.detach().cpu().numpy()
        clustering = SpectralClustering(n_clusters=2,
                                        affinity='precomputed',
                                        random_state=0).fit(affinity)
        labels = clustering.labels_
        indices = np.arange(A.shape[-1])
        return indices[labels == 0], indices[labels == 1]

    def cos_merge(self, basis, masks):
        """Group slots by cosine similarity of ``basis`` and sum masks per group.

        The grouping is computed from the first batch element only and then
        applied to every element of ``masks``.
        """
        unit = basis / torch.linalg.vector_norm(basis, dim=-1, keepdim=True).clamp(min=1e-6)
        # Pairwise cosine similarity between slots, floored at 1e-6.
        similarity = torch.einsum('brc, blc -> brl', unit, unit)
        A = similarity[0].clamp(min=1e-6)
        inda, indb = self.spectral(A)
        group_a = masks[:, inda].sum(1)
        group_b = masks[:, indb].sum(1)
        return torch.stack([group_a, group_b], 1)

    def __call__(self, sample, masks_softmaxed):
        """Merge ``masks_softmaxed`` for ``sample``.

        Returns a dict with a single key ``'cos'`` holding the two merged masks.
        """
        with torch.no_grad():
            masks_softmaxed = masks_softmaxed[:, self.start_idx:]
            rgb = torch.stack([x['rgb'].to(masks_softmaxed.device) for x in sample], 0)
            batch = rgb / 255.0
            features = self.get_feats((batch - self.mu) / self.sigma)
            # Mask-weighted average feature per slot.
            basis = torch.einsum('brhw, bchw -> brc', masks_softmaxed, features)
            basis /= einops.reduce(masks_softmaxed, 'b r h w -> b r 1', 'sum').clamp_min(1e-12)
            return {
                'cos': self.cos_merge(basis, masks_softmaxed),
            }