diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..24745a1c043b8f335e01eed42ede150d0b9b3a4c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +checkpoints/checkpoint_best.pth filter=lfs diff=lfs merge=lfs -text +samples/1920px-Woman_at_work,_Gujarat.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index bcc43c092a18b1e7f850794bc21731cf5b2ebd52..51ff5bcc16cf3fb8a56974fddade2f973eeb471e 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ --- -title: Guess What Moves -emoji: 🐨 -colorFrom: red -colorTo: gray +title: GWM +emoji: 🏄 +colorFrom: purple +colorTo: red sdk: gradio sdk_version: 3.17.0 app_file: app.py @@ -10,4 +10,4 @@ pinned: false license: mit --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +This is a demo for https://www.robots.ox.ac.uk/~vgg/research/gwm/. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..392193f44e17d094eaba76f177e6627f9b5f991e --- /dev/null +++ b/app.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python + +import os +try: + import detectron2 +except: + os.system('pip install git+https://github.com/facebookresearch/detectron2.git') + +import logging +logging.disable(logging.CRITICAL) # comment out to enable verbose logging + +######################################################### +import pathlib +import gradio as gr +import numpy as np +import PIL.Image as Image +import os +import numpy as np +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt +from PIL import Image +from collections import defaultdict +from pathlib import Path +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms as T +from tqdm import tqdm +from types import SimpleNamespace +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.data import MetadataCatalog +from detectron2.utils.visualizer import Visualizer + +import config +import utils as ut +from eval_utils import MaskMerger +from mask_former_trainer import setup, Trainer + + +def load_model_cfg(dataset=None): + + args = SimpleNamespace(config_file='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml', opts=["GWM.DATASET", dataset], wandb_sweep_mode=False, resume_path=str('checkpoints/checkpoint_best.pth'), eval_only=True) + cfg = setup(args) + cfg.defrost() + cfg.MODEL.DEVICE = 'cpu' + cfg.freeze() + random_state = ut.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE)) + + model = Trainer.build_model(cfg) + checkpointer = DetectionCheckpointer(model, + random_state=random_state, + save_dir=None) + + checkpoint_path = 'checkpoints/checkpoint_best.pth' + checkpoint = checkpointer.resume_or_load(checkpoint_path, resume=False) + model.eval() + + return model, cfg + +def edgeness(masks): + + em = torch.zeros(1, masks.shape[-2], masks.shape[-1], device=masks.device) + lm = em.clone() + lm[..., :2] = 1. + rm = em.clone() + rm[...,-2:] = 1. + tm = em.clone() + tm[..., :2, :] = 1. + bm = em.clone() + bm[..., -2:,:] = 1. 
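+    # lm/rm/tm/bm are 2-pixel-wide indicator maps for the left, right, top and
+    # bottom image borders. Below, each mask's mean activation on every border
+    # is computed (saturated at 0.3) and the four scores are summed, so the
+    # channel that touches the image borders the most can later be treated as
+    # background (see the argmax over edgeness() in evaluate_image).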
+ + one = torch.tensor(1.,dtype= masks.dtype, device=masks.device) + + l = (masks * lm).flatten(-2).sum(-1) / lm.sum() + l = torch.where(l > 0.3, one, l) + r = (masks * rm).flatten(-2).sum(-1) / rm.sum() + r = torch.where(r > 0.3, one, r) + t = (masks * tm).flatten(-2).sum(-1) / tm.sum() + t = torch.where(t > 0.3, one, t) + b = (masks * bm).flatten(-2).sum(-1) / bm.sum() + b = torch.where(b > 0.3, one, b) + return (l + r + t + b ) + +def expand2sizedivisible(pil_img, background_color, size_divisibility): + width, height = pil_img.size + if width % size_divisibility == 0 and height % size_divisibility == 0: + return pil_img + result = Image.new(pil_img.mode, (width + (size_divisibility - width%size_divisibility)%size_divisibility, height + (size_divisibility - height%size_divisibility)%size_divisibility), background_color) + result.paste(pil_img, (((size_divisibility - width%size_divisibility)%size_divisibility) // 2, ((size_divisibility - height%size_divisibility)%size_divisibility) // 2)) + + return result + +def cropfromsizedivisible(img, size_divisibility, orig_size): + height, width = img.shape[:2] + owidth, oheight = orig_size + result = img[(height-oheight)//2:oheight+(height-oheight)//2, (width-owidth)//2:owidth+(width-owidth)//2] + + return result + + +def evaluate_image(image_path): + binary_threshold = 0.5 + metadata = MetadataCatalog.get("__unused") + + model, cfg = load_model_cfg("DAVIS") + + merger = MaskMerger(cfg, model, merger_model="dino_vitb8") + + + image_pil = Image.open(image_path).convert('RGB') + + image_pil.thumbnail((384, 384)) + + osize = image_pil.size + if cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY > 0: + image_pil = expand2sizedivisible(image_pil, 0, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY) + + image = np.asarray(image_pil) + image_pt = torch.from_numpy(np.array(image)).permute(2,0,1) + + with torch.no_grad(): + sample = [{'rgb': image_pt}] + preds = model.forward_base(sample, keys=['rgb'], get_eval=True) + masks_raw = torch.stack([x['sem_seg'] for x in preds], 0) + + K = masks_raw.shape[1] + if K > 2: + masks_softmaxed = torch.softmax(masks_raw, dim=1) + masks_dict = merger(sample, masks_softmaxed) + K = 2 + masks = masks_dict['cos'] + else: + print(K) + masks = masks_raw.softmax(1) + masks_raw = F.interpolate(masks, size=(image_pt.shape[-2], image_pt.shape[-1]), mode='bilinear') # t s 1 h w + bg = edgeness(masks_raw)[0].argmax().item() + + masks = masks_raw[0] > binary_threshold + frame_visualizer = Visualizer(image, metadata) + out = frame_visualizer.overlay_instances( + masks=masks[[int(bg==0)]], + alpha=0.3, + assigned_colors=[(1,0,1)] + ).get_image() + + return cropfromsizedivisible(out, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, osize) + + +paths = sorted(pathlib.Path('samples').glob('*.jpg')) +css = ".component-1 {height: 256px !important;}" +demo = gr.Interface( + fn=evaluate_image, + inputs=gr.Image(label='Image', type='filepath'), + outputs=gr.Image(label='Annotated Image', type='numpy'), + examples=[[path.as_posix(), 0.15, 6] for path in paths], + title="Guess What Moves", + description="#### Unsupervised Image segmentation mode of [Guess What Moves](https://www.robots.ox.ac.uk/~vgg/research/gwm/)", + css=css) +demo.queue().launch() diff --git a/checkpoints/checkpoint_best.pth b/checkpoints/checkpoint_best.pth new file mode 100644 index 0000000000000000000000000000000000000000..9ae0f6b16b1afd463f1819c47fbd4a602d342617 --- /dev/null +++ b/checkpoints/checkpoint_best.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:a573e43ffc78e84dbf8b4f2c9e31195bed3c44d8e7be942daba92c7437b8b17d +size 63672283 diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8632b0a6886856bf7910fda1509f2db6314b32 --- /dev/null +++ b/config.py @@ -0,0 +1,377 @@ +import copy +import itertools +import logging +import os +from pathlib import Path + +import numpy as np +import torch.utils.data +from detectron2.config import CfgNode as CN + +import utils +from datasets import FlowPairDetectron, FlowEvalDetectron + +logger = logging.getLogger('gwm') + +def scan_train_flow(folders, res, pairs, basepath): + pair_list = [p for p in itertools.combinations(pairs, 2)] + + flow_dir = {} + for pair in pair_list: + p1, p2 = pair + flowpairs = [] + for f in folders: + path1 = basepath / f'Flows_gap{p1}' / res / f + path2 = basepath / f'Flows_gap{p2}' / res / f + + flows1 = [p.name for p in path1.glob('*.flo')] + flows2 = [p.name for p in path2.glob('*.flo')] + + flows1 = sorted(flows1) + flows2 = sorted(flows2) + + intersect = list(set(flows1).intersection(flows2)) + intersect.sort() + + flowpair = np.array([[path1 / i, path2 / i] for i in intersect]) + flowpairs += [flowpair] + flow_dir['gap_{}_{}'.format(p1, p2)] = flowpairs + + # flow_dir is a dictionary, with keys indicating the flow gap, and each value is a list of sequence names, + # each item then is an array with Nx2, N indicates the number of available pairs. + return flow_dir + + +def setup_dataset(cfg=None, multi_val=False): + dataset_str = cfg.GWM.DATASET + if '+' in dataset_str: + datasets = dataset_str.split('+') + logger.info(f'Multiple datasets detected: {datasets}') + train_datasets = [] + val_datasets = [] + for ds in datasets: + proxy_cfg = copy.deepcopy(cfg) + proxy_cfg.merge_from_list(['GWM.DATASET', ds]), + train_ds, val_ds = setup_dataset(proxy_cfg, multi_val=multi_val) + train_datasets.append(train_ds) + val_datasets.append(val_ds) + logger.info(f'Multiple datasets detected: {datasets}') + logger.info(f'Validation is still : {datasets[0]}') + return torch.utils.data.ConcatDataset(train_datasets), val_datasets[0] + + resolution = cfg.GWM.RESOLUTION # h,w + res = "" + with_gt = True + pairs = [1, 2, -1, -2] + trainval_data_dir = None + + if cfg.GWM.DATASET == 'DAVIS': + basepath = '/DAVIS2016' + img_dir = '/DAVIS2016/JPEGImages/480p' + gt_dir = '/DAVIS2016/Annotations/480p' + + val_flow_dir = '/DAVIS2016/Flows_gap1/1080p' + val_seq = ['dog', 'cows', 'goat', 'camel', 'libby', 'parkour', 'soapbox', 'blackswan', 'bmx-trees', + 'kite-surf', 'car-shadow', 'breakdance', 'dance-twirl', 'scooter-black', 'drift-chicane', + 'motocross-jump', 'horsejump-high', 'drift-straight', 'car-roundabout', 'paragliding-launch'] + val_data_dir = [val_flow_dir, img_dir, gt_dir] + res = "1080p" + + elif cfg.GWM.DATASET in ['FBMS']: + basepath = '/FBMS_clean' + img_dir = '/FBMS_clean/JPEGImages/' + gt_dir = '/FBMS_clean/Annotations/' + + val_flow_dir = '/FBMS_val/Flows_gap1/' + val_seq = ['camel01', 'cars1', 'cars10', 'cars4', 'cars5', 'cats01', 'cats03', 'cats06', + 'dogs01', 'dogs02', 'farm01', 'giraffes01', 'goats01', 'horses02', 'horses04', + 'horses05', 'lion01', 'marple12', 'marple2', 'marple4', 'marple6', 'marple7', 'marple9', + 'people03', 'people1', 'people2', 'rabbits02', 'rabbits03', 'rabbits04', 'tennis'] + val_img_dir = '/FBMS_val/JPEGImages/' + val_gt_dir = '/FBMS_val/Annotations/' + val_data_dir = [val_flow_dir, val_img_dir, val_gt_dir] + with_gt = False + pairs = [3, 6, -3, -6] + + elif cfg.GWM.DATASET in ['STv2']: + basepath 
= '/SegTrackv2' + img_dir = '/SegTrackv2/JPEGImages' + gt_dir = '/SegTrackv2/Annotations' + + val_flow_dir = '/SegTrackv2/Flows_gap1/' + val_seq = ['drift', 'birdfall', 'girl', 'cheetah', 'worm', 'parachute', 'monkeydog', + 'hummingbird', 'soldier', 'bmx', 'frog', 'penguin', 'monkey', 'bird_of_paradise'] + val_data_dir = [val_flow_dir, img_dir, gt_dir] + + else: + raise ValueError('Unknown Setting/Dataset.') + + # Switching this section to pathlib, which should prevent double // errors in paths and dict keys + + root_path_str = cfg.GWM.DATA_ROOT + logger.info(f"Found DATA_ROOT in config: {root_path_str}") + root_path_str = '../data' + + if root_path_str.startswith('/'): + root_path = Path(f"/{root_path_str.lstrip('/').rstrip('/')}") + else: + root_path = Path(f"{root_path_str.lstrip('/').rstrip('/')}") + + logger.info(f"Loading dataset from: {root_path}") + + basepath = root_path / basepath.lstrip('/').rstrip('/') + img_dir = root_path / img_dir.lstrip('/').rstrip('/') + gt_dir = root_path / gt_dir.lstrip('/').rstrip('/') + val_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in val_data_dir] + + folders = [p.name for p in (basepath / f'Flows_gap{pairs[0]}' / res).iterdir() if p.is_dir()] + folders = sorted(folders) + + # flow_dir is a dictionary, with keys indicating the flow gap, and each value is a list of sequence names, + # each item then is an array with Nx2, N indicates the number of available pairs. + + flow_dir = scan_train_flow(folders, res, pairs, basepath) + data_dir = [flow_dir, img_dir, gt_dir] + + force1080p = ('DAVIS' not in cfg.GWM.DATASET) and 'RGB_BIG' in cfg.GWM.SAMPLE_KEYS + + enable_photometric_augmentations = cfg.FLAGS.INF_TPS + + train_dataset = FlowPairDetectron(data_dir=data_dir, + resolution=resolution, + to_rgb=cfg.GWM.FLOW2RGB, + size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1, + enable_photo_aug=enable_photometric_augmentations, + flow_clip=cfg.GWM.FLOW_CLIP, + norm=cfg.GWM.FLOW_NORM, + force1080p=force1080p, + flow_res=cfg.GWM.FLOW_RES, ) + if multi_val: + print(f"Using multiple validation datasets from {val_data_dir}") + val_dataset = [FlowEvalDetectron(data_dir=val_data_dir, + resolution=resolution, + pair_list=pairs, + val_seq=[vs], + to_rgb=cfg.GWM.FLOW2RGB, + with_rgb=False, + size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1, + flow_clip=cfg.GWM.FLOW_CLIP, + norm=cfg.GWM.FLOW_NORM, + force1080p=force1080p) for vs in val_seq] + for vs, vds in zip(val_seq, val_dataset): + print(f"Validation dataset for {vs}: {len(vds)}") + if len(vds) == 0: + raise ValueError(f"Empty validation dataset for {vs}") + + if cfg.GWM.TTA_AS_TRAIN: + if trainval_data_dir is None: + trainval_data_dir = val_data_dir + else: + trainval_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in trainval_data_dir] + trainval_dataset = [] + tvd_basepath = root_path / str(trainval_data_dir[0].relative_to(root_path)).split('/')[0] + print("TVD BASE DIR", tvd_basepath) + for vs in val_seq: + tvd_data_dir = [scan_train_flow([vs], res, pairs, tvd_basepath), *trainval_data_dir[1:]] + tvd = FlowPairDetectron(data_dir=tvd_data_dir, + resolution=resolution, + to_rgb=cfg.GWM.FLOW2RGB, + size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1, + enable_photo_aug=cfg.GWM.LOSS_MULT.EQV is not None, + flow_clip=cfg.GWM.FLOW_CLIP, + norm=cfg.GWM.FLOW_NORM, + force1080p=force1080p, + flow_res=cfg.GWM.FLOW_RES, ) + 
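+            # tvd is a train-style FlowPairDetectron restricted to a single
+            # validation sequence, so test-time adaptation (GWM.TTA_AS_TRAIN)
+            # can fit on flow pairs from one sequence at a time.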
trainval_dataset.append(tvd) + print(f'Seq {trainval_data_dir[0]}/{vs} dataset: {len(tvd)}') + else: + if trainval_data_dir is None: + trainval_dataset = val_dataset + else: + trainval_data_dir = [root_path / path.lstrip('/').rstrip('/') for path in trainval_data_dir] + trainval_dataset = [] + for vs in val_seq: + tvd = FlowEvalDetectron(data_dir=trainval_data_dir, + resolution=resolution, + pair_list=pairs, + val_seq=[vs], + to_rgb=cfg.GWM.FLOW2RGB, + with_rgb=False, + size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1, + flow_clip=cfg.GWM.FLOW_CLIP, + norm=cfg.GWM.FLOW_NORM, + force1080p=force1080p) + trainval_dataset.append(tvd) + print(f'Seq {trainval_data_dir[0]}/{vs} dataset: {len(tvd)}') + return train_dataset, val_dataset, trainval_dataset + val_dataset = FlowEvalDetectron(data_dir=val_data_dir, + resolution=resolution, + pair_list=pairs, + val_seq=val_seq, + to_rgb=cfg.GWM.FLOW2RGB, + with_rgb=False, + size_divisibility=cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY if not cfg.FLAGS.IGNORE_SIZE_DIV else -1, + flow_clip=cfg.GWM.FLOW_CLIP, + norm=cfg.GWM.FLOW_NORM, + force1080p=force1080p) + + return train_dataset, val_dataset + + +def loaders(cfg): + train_dataset, val_dataset = setup_dataset(cfg) + logger.info(f"Sourcing data from {val_dataset.data_dir[0]}") + + if cfg.FLAGS.DEV_DATA: + subset = cfg.SOLVER.IMS_PER_BATCH * 3 + train_dataset = torch.utils.data.Subset(train_dataset, list(range(subset))) + val_dataset = torch.utils.data.Subset(val_dataset, list(range(subset))) + + g = torch.Generator() + data_generator_seed = int(torch.randint(int(1e6), (1,)).item()) + logger.info(f"Dataloaders generator seed {data_generator_seed}") + g.manual_seed(data_generator_seed) + + train_loader = torch.utils.data.DataLoader(train_dataset, + num_workers=cfg.DATALOADER.NUM_WORKERS, + batch_size=cfg.SOLVER.IMS_PER_BATCH, + collate_fn=lambda x: x, + shuffle=True, + pin_memory=True, + drop_last=True, + persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0, + worker_init_fn=utils.random_state.worker_init_function, + generator=g + ) + val_loader = torch.utils.data.DataLoader(val_dataset, + num_workers=cfg.DATALOADER.NUM_WORKERS, + batch_size=1, + shuffle=False, + pin_memory=True, + collate_fn=lambda x: x, + drop_last=False, + persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0, + worker_init_fn=utils.random_state.worker_init_function, + generator=g) + return train_loader, val_loader + + +def multi_loaders(cfg): + train_dataset, val_datasets, train_val_datasets = setup_dataset(cfg, multi_val=True) + logger.info(f"Sourcing multiple loaders from {len(val_datasets)}") + logger.info(f"Sourcing data from {val_datasets[0].data_dir[0]}") + + g = torch.Generator() + data_generator_seed = int(torch.randint(int(1e6), (1,)).item()) + logger.info(f"Dataloaders generator seed {data_generator_seed}") + g.manual_seed(data_generator_seed) + + train_loader = torch.utils.data.DataLoader(train_dataset, + num_workers=cfg.DATALOADER.NUM_WORKERS, + batch_size=cfg.SOLVER.IMS_PER_BATCH, + collate_fn=lambda x: x, + shuffle=True, + pin_memory=True, + drop_last=True, + persistent_workers=cfg.DATALOADER.NUM_WORKERS > 0, + worker_init_fn=utils.random_state.worker_init_function, + generator=g + ) + + val_loaders = [(torch.utils.data.DataLoader(val_dataset, + num_workers=0, + batch_size=1, + shuffle=False, + pin_memory=True, + collate_fn=lambda x: x, + drop_last=False, + persistent_workers=False, + worker_init_fn=utils.random_state.worker_init_function, + generator=g), + 
torch.utils.data.DataLoader(tv_dataset, + num_workers=0, + batch_size=cfg.SOLVER.IMS_PER_BATCH, + shuffle=True, + pin_memory=False, + collate_fn=lambda x: x, + drop_last=False, + persistent_workers=False, + worker_init_fn=utils.random_state.worker_init_function, + generator=g)) + for val_dataset, tv_dataset in zip(val_datasets, train_val_datasets)] + + return train_loader, val_loaders + + +def add_gwm_config(cfg): + cfg.GWM = CN() + cfg.GWM.MODEL = "MASKFORMER" + cfg.GWM.RESOLUTION = (128, 224) + cfg.GWM.FLOW_RES = (480, 854) + cfg.GWM.SAMPLE_KEYS = ["rgb"] + cfg.GWM.ADD_POS_EMB = False + cfg.GWM.CRITERION = "L2" + cfg.GWM.L1_OPTIMIZE = False + cfg.GWM.HOMOGRAPHY = 'quad' # False + cfg.GWM.HOMOGRAPHY_SUBSAMPLE = 8 + cfg.GWM.HOMOGRAPHY_SKIP = 0.4 + cfg.GWM.DATASET = 'DAVIS' + cfg.GWM.DATA_ROOT = None + cfg.GWM.FLOW2RGB = False + cfg.GWM.SIMPLE_REC = False + cfg.GWM.DAVIS_SINGLE_VID = None + cfg.GWM.USE_MULT_FLOW = False + cfg.GWM.FLOW_COLORSPACE_REC = None + + cfg.GWM.FLOW_CLIP_U_LOW = float('-inf') + cfg.GWM.FLOW_CLIP_U_HIGH = float('inf') + cfg.GWM.FLOW_CLIP_V_LOW = float('-inf') + cfg.GWM.FLOW_CLIP_V_HIGH = float('inf') + + cfg.GWM.FLOW_CLIP = float('inf') + cfg.GWM.FLOW_NORM = False + + cfg.GWM.LOSS_MULT = CN() + cfg.GWM.LOSS_MULT.REC = 0.03 + cfg.GWM.LOSS_MULT.HEIR_W = [0.1, 0.3, 0.6] + + + cfg.GWM.TTA = 100 # Test-time-adaptation + cfg.GWM.TTA_AS_TRAIN = False # Use train-like data logic for test-time-adaptation + + cfg.GWM.LOSS = 'OG' + + cfg.FLAGS = CN() + cfg.FLAGS.MAKE_VIS_VIDEOS = False # Making videos is kinda slow + cfg.FLAGS.EXTENDED_FLOW_RECON_VIS = False # Does not cost much + cfg.FLAGS.COMP_NLL_FOR_GT = False # Should we log loss against ground truth? + cfg.FLAGS.DEV_DATA = False + cfg.FLAGS.KEEP_ALL = True # Keep all checkoints + cfg.FLAGS.ORACLE_CHECK = False # Use oracle check to estimate max performance when grouping multiple components + + cfg.FLAGS.INF_TPS = False + + # cfg.FLAGS.UNFREEZE_AT = [(1, 10000), (0, 20000), (-1, 30000)] + cfg.FLAGS.UNFREEZE_AT = [(4, 0), (2, 500), (1, 1000), (-1, 10000)] + + cfg.FLAGS.IGNORE_SIZE_DIV = False + + cfg.FLAGS.IGNORE_TMP = True + + cfg.WANDB = CN() + cfg.WANDB.ENABLE = False + cfg.WANDB.BASEDIR = '../' + + cfg.DEBUG = False + + cfg.LOG_ID = 'exp' + cfg.LOG_FREQ = 250 + cfg.OUTPUT_BASEDIR = '../outputs' + cfg.SLURM = False + cfg.SKIP_TB = False + cfg.TOTAL_ITER = 20000 + cfg.CONFIG_FILE = None + + if os.environ.get('SLURM_JOB_ID', None): + cfg.LOG_ID = os.environ.get('SLURM_JOB_NAME', cfg.LOG_ID) + logger.info(f"Setting name {cfg.LOG_ID} based on SLURM job name") diff --git a/configs/README.md b/configs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b449419905c5bb6f5a7016e99f8c9fac81c892bc --- /dev/null +++ b/configs/README.md @@ -0,0 +1,16 @@ +## Available configs: + + +Use `main.py --config-file=` + +No need to specify `GWM.MODEL`. 
It is already defined inside the config files + + +### Available configs: +``` +maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml +maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml +maskformer/swin/maskformer_swin_small_bs16_160k.yaml +maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml +maskformer/maskformer_R50_bs16_160k.yaml +``` diff --git a/configs/maskformer/Base-ADE20K-150.yaml b/configs/maskformer/Base-ADE20K-150.yaml new file mode 100644 index 0000000000000000000000000000000000000000..451fca76667bf9af3aa5986cd2b94523bbb23ba2 --- /dev/null +++ b/configs/maskformer/Base-ADE20K-150.yaml @@ -0,0 +1,60 @@ +_BASE_: Base-unsup-vidseg.yaml +MODEL: + BACKBONE: + FREEZE_AT: 0 + NAME: "build_resnet_backbone" + WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + RESNETS: + DEPTH: 50 + STEM_TYPE: "basic" # not used + STEM_OUT_CHANNELS: 64 + STRIDE_IN_1X1: False + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + # NORM: "SyncBN" + RES5_MULTI_GRID: [1, 1, 1] # not used +DATASETS: + TRAIN: ("ade20k_sem_seg_train",) + TEST: ("ade20k_sem_seg_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.0001 + MAX_ITER: 160000 + WARMUP_FACTOR: 1.0 + WARMUP_ITERS: 0 + WEIGHT_DECAY: 0.0001 + OPTIMIZER: "ADAMW" + LR_SCHEDULER_NAME: "WarmupPolyLR" + BACKBONE_MULTIPLIER: 0.1 + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "full_model" + CLIP_VALUE: 0.01 + NORM_TYPE: 2.0 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 512 + MAX_SIZE_TRAIN: 2048 + MAX_SIZE_TEST: 2048 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (512, 512) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 512 # used in dataset mapper + FORMAT: "RGB" + DATASET_MAPPER_NAME: "mask_former_semantic" +TEST: + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [256, 384, 512, 640, 768, 896] + MAX_SIZE: 3584 + FLIP: True +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: True + NUM_WORKERS: 10 +VERSION: 2 diff --git a/configs/maskformer/Base-unsup-vidseg.yaml b/configs/maskformer/Base-unsup-vidseg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43c3bb496bbd52d4a48f1583e2d4a34d6ecc3c2c --- /dev/null +++ b/configs/maskformer/Base-unsup-vidseg.yaml @@ -0,0 +1,3 @@ +SEED: 42 +GWM: + MODEL: "MASKFORMER" diff --git a/configs/maskformer/cityscapes-19/Base-Cityscapes-19.yaml b/configs/maskformer/cityscapes-19/Base-Cityscapes-19.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b52542657bc7170432cac0890858e5178ec4677 --- /dev/null +++ b/configs/maskformer/cityscapes-19/Base-Cityscapes-19.yaml @@ -0,0 +1,59 @@ +MODEL: + BACKBONE: + FREEZE_AT: 0 + NAME: "build_resnet_backbone" + WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + RESNETS: + DEPTH: 50 + STEM_TYPE: "basic" # not used + STEM_OUT_CHANNELS: 64 + STRIDE_IN_1X1: False + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + # NORM: "SyncBN" + RES5_MULTI_GRID: [1, 1, 1] # not used +DATASETS: + TRAIN: ("cityscapes_fine_sem_seg_train",) + TEST: ("cityscapes_fine_sem_seg_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.0001 + MAX_ITER: 90000 + WARMUP_FACTOR: 1.0 + WARMUP_ITERS: 0 + WEIGHT_DECAY: 0.0001 + OPTIMIZER: "ADAMW" + LR_SCHEDULER_NAME: "WarmupPolyLR" + BACKBONE_MULTIPLIER: 0.1 + CLIP_GRADIENTS: + ENABLED: True + 
CLIP_TYPE: "full_model" + CLIP_VALUE: 0.01 + NORM_TYPE: 2.0 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 1024) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 1024 + MAX_SIZE_TRAIN: 4096 + MAX_SIZE_TEST: 2048 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (512, 1024) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: -1 + FORMAT: "RGB" + DATASET_MAPPER_NAME: "mask_former_semantic" +TEST: + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [512, 768, 1024, 1280, 1536, 1792] + MAX_SIZE: 4096 + FLIP: True +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: True + NUM_WORKERS: 4 +VERSION: 2 diff --git a/configs/maskformer/cityscapes-19/maskformer_R101_bs16_90k.yaml b/configs/maskformer/cityscapes-19/maskformer_R101_bs16_90k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1017a7b3586939525b754e6010cf66e64c2ff93 --- /dev/null +++ b/configs/maskformer/cityscapes-19/maskformer_R101_bs16_90k.yaml @@ -0,0 +1,36 @@ +_BASE_: Base-Cityscapes-19.yaml +MODEL: + WEIGHTS: "R-101.pkl" + RESNETS: + DEPTH: 101 + STEM_TYPE: "basic" # not used + STEM_OUT_CHANNELS: 64 + STRIDE_IN_1X1: False + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + # NORM: "SyncBN" + RES5_MULTI_GRID: [1, 1, 1] # not used + META_ARCHITECTURE: "MaskFormer" + SEM_SEG_HEAD: + NAME: "MaskFormerHead" + IN_FEATURES: ["res2", "res3", "res4", "res5"] + IGNORE_VALUE: 255 + NUM_CLASSES: 19 + COMMON_STRIDE: 4 # not used, hard-coded + LOSS_WEIGHT: 1.0 + CONVS_DIM: 256 + MASK_DIM: 256 + NORM: "GN" + MASK_FORMER: + TRANSFORMER_IN_FEATURE: "res5" + DEEP_SUPERVISION: True + NO_OBJECT_WEIGHT: 0.1 + DICE_WEIGHT: 1.0 + MASK_WEIGHT: 20.0 + HIDDEN_DIM: 256 + NUM_OBJECT_QUERIES: 100 + NHEADS: 8 + DROPOUT: 0.1 + DIM_FEEDFORWARD: 2048 + ENC_LAYERS: 0 + DEC_LAYERS: 6 + PRE_NORM: False diff --git a/configs/maskformer/cityscapes-19/maskformer_R101c_bs16_90k.yaml b/configs/maskformer/cityscapes-19/maskformer_R101c_bs16_90k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e07bbeea78419af7404dfd01d11cc7e3cd2f10f6 --- /dev/null +++ b/configs/maskformer/cityscapes-19/maskformer_R101c_bs16_90k.yaml @@ -0,0 +1,16 @@ +_BASE_: maskformer_R101_bs16_90k.yaml +MODEL: + BACKBONE: + FREEZE_AT: 0 + NAME: "build_resnet_deeplab_backbone" + WEIGHTS: "detectron2://DeepLab/R-103.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + RESNETS: + DEPTH: 101 + STEM_TYPE: "deeplab" + STEM_OUT_CHANNELS: 128 + STRIDE_IN_1X1: False + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + # NORM: "SyncBN" + RES5_MULTI_GRID: [1, 2, 4] diff --git a/configs/maskformer/maskformer_R50_bs16_160k.yaml b/configs/maskformer/maskformer_R50_bs16_160k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..07319d6e95ad9bd3bdb1e55a7e70a4615335d945 --- /dev/null +++ b/configs/maskformer/maskformer_R50_bs16_160k.yaml @@ -0,0 +1,27 @@ +_BASE_: Base-ADE20K-150.yaml +MODEL: + META_ARCHITECTURE: "MaskFormer" + SEM_SEG_HEAD: + NAME: "MaskFormerHead" + IN_FEATURES: ["res2", "res3", "res4", "res5"] + IGNORE_VALUE: 255 + NUM_CLASSES: 2 + COMMON_STRIDE: 4 # not used, hard-coded + LOSS_WEIGHT: 1.0 + CONVS_DIM: 256 + MASK_DIM: 256 + NORM: "GN" + MASK_FORMER: + TRANSFORMER_IN_FEATURE: "res5" + DEEP_SUPERVISION: False + NO_OBJECT_WEIGHT: 0.1 + DICE_WEIGHT: 1.0 + MASK_WEIGHT: 20.0 + HIDDEN_DIM: 256 + NUM_OBJECT_QUERIES: 2 + NHEADS: 8 + DROPOUT: 0.1 + DIM_FEEDFORWARD: 2048 + ENC_LAYERS: 0 + DEC_LAYERS: 6 + PRE_NORM: False diff --git 
a/configs/maskformer/maskformer_R50_bs16_160k_dino.yaml b/configs/maskformer/maskformer_R50_bs16_160k_dino.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3dc93d1a97da9e9393a17a050a8e2c35d3bc92a --- /dev/null +++ b/configs/maskformer/maskformer_R50_bs16_160k_dino.yaml @@ -0,0 +1,31 @@ +_BASE_: ./maskformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2ViTTransformer" + FREEZE_AT: -1 + SWIN: + EMBED_DIM: 768 + DEPTHS: [2, 2, 6, 2] + NUM_HEADS: [3, 6, 12, 24] + WINDOW_SIZE: 7 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + WEIGHTS: None + MASK_FORMER: + NUM_OBJECT_QUERIES: 4 + SEM_SEG_HEAD: + PIXEL_DECODER_NAME: BigPixelDecoder +SOLVER: + BASE_LR: 0.00015 + IMS_PER_BATCH: 8 + WARMUP_FACTOR: 1e-6 + WARMUP_ITERS: 1500 + WEIGHT_DECAY: 0.01 + WEIGHT_DECAY_NORM: 0.0 + WEIGHT_DECAY_EMBED: 0.0 + BACKBONE_MULTIPLIER: 1.0 +FLAGS: + UNFREEZE_AT: [] diff --git a/configs/maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml b/configs/maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml new file mode 100644 index 0000000000000000000000000000000000000000..794a2f5b66d22c2060cf50e6ff8d3741ef79fd3e --- /dev/null +++ b/configs/maskformer/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml @@ -0,0 +1,45 @@ +_BASE_: ../maskformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2SwinTransformer" + SWIN: + EMBED_DIM: 128 + DEPTHS: [2, 2, 18, 2] + NUM_HEADS: [4, 8, 16, 32] + WINDOW_SIZE: 12 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + PRETRAIN_IMG_SIZE: 384 + WEIGHTS: "pretrained_weights/swin_base_patch4_window12_384_22k.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] +SOLVER: + BASE_LR: 0.00006 + WARMUP_FACTOR: 1e-6 + WARMUP_ITERS: 1500 + WEIGHT_DECAY: 0.01 + WEIGHT_DECAY_NORM: 0.0 + WEIGHT_DECAY_EMBED: 0.0 + BACKBONE_MULTIPLIER: 1.0 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 640 + MAX_SIZE_TRAIN: 2560 + MAX_SIZE_TEST: 2560 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (640, 640) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 640 # used in dataset mapper + FORMAT: "RGB" +TEST: + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [320, 480, 640, 800, 960, 1120] + MAX_SIZE: 4480 + FLIP: True diff --git a/configs/maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml b/configs/maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c2e2337204a1097f7ba94fa3a1c97296c1695fb --- /dev/null +++ b/configs/maskformer/swin/maskformer_swin_large_IN21k_384_bs16_160k_res640.yaml @@ -0,0 +1,45 @@ +_BASE_: ../maskformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2SwinTransformer" + SWIN: + EMBED_DIM: 192 + DEPTHS: [2, 2, 18, 2] + NUM_HEADS: [6, 12, 24, 48] + WINDOW_SIZE: 12 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + PRETRAIN_IMG_SIZE: 384 + WEIGHTS: "pretrained_weights/swin_large_patch4_window12_384_22k.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] +SOLVER: + BASE_LR: 0.00006 + WARMUP_FACTOR: 1e-6 + WARMUP_ITERS: 1500 + WEIGHT_DECAY: 0.01 + WEIGHT_DECAY_NORM: 0.0 + WEIGHT_DECAY_EMBED: 0.0 + BACKBONE_MULTIPLIER: 1.0 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] + 
MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 640 + MAX_SIZE_TRAIN: 2560 + MAX_SIZE_TEST: 2560 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (640, 640) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 640 # used in dataset mapper + FORMAT: "RGB" +TEST: + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [320, 480, 640, 800, 960, 1120] + MAX_SIZE: 4480 + FLIP: True diff --git a/configs/maskformer/swin/maskformer_swin_small_bs16_160k.yaml b/configs/maskformer/swin/maskformer_swin_small_bs16_160k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..843f96881a9dca4186c222da4b923b83705d2ca9 --- /dev/null +++ b/configs/maskformer/swin/maskformer_swin_small_bs16_160k.yaml @@ -0,0 +1,23 @@ +_BASE_: ../maskformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2SwinTransformer" + SWIN: + EMBED_DIM: 96 + DEPTHS: [2, 2, 18, 2] + NUM_HEADS: [3, 6, 12, 24] + WINDOW_SIZE: 7 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + WEIGHTS: "pretrained_weights/swin_small_patch4_window7_224.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] +SOLVER: + BASE_LR: 0.00006 + WARMUP_FACTOR: 1e-6 + WARMUP_ITERS: 1500 + WEIGHT_DECAY: 0.01 + WEIGHT_DECAY_NORM: 0.0 + WEIGHT_DECAY_EMBED: 0.0 + BACKBONE_MULTIPLIER: 1.0 diff --git a/configs/maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml b/configs/maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08b533a7991efe3c2f7f54d1231e89b5da0d8dc0 --- /dev/null +++ b/configs/maskformer/swin/maskformer_swin_tiny_bs16_160k.yaml @@ -0,0 +1,23 @@ +_BASE_: ../maskformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2SwinTransformer" + SWIN: + EMBED_DIM: 96 + DEPTHS: [2, 2, 6, 2] + NUM_HEADS: [3, 6, 12, 24] + WINDOW_SIZE: 7 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + WEIGHTS: "pretrained_weights/swin_tiny_patch4_window7_224.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] +SOLVER: + BASE_LR: 0.00006 + WARMUP_FACTOR: 1e-6 + WARMUP_ITERS: 1500 + WEIGHT_DECAY: 0.01 + WEIGHT_DECAY_NORM: 0.0 + WEIGHT_DECAY_EMBED: 0.0 + BACKBONE_MULTIPLIER: 1.0 diff --git a/configs/maskformer/swin/maskformer_swin_tiny_bs16_160k_moby.yaml b/configs/maskformer/swin/maskformer_swin_tiny_bs16_160k_moby.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ea9c210dac35850aa00042b1890a44d8ead4f98 --- /dev/null +++ b/configs/maskformer/swin/maskformer_swin_tiny_bs16_160k_moby.yaml @@ -0,0 +1,23 @@ +_BASE_: ../maskformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2SwinTransformer" + SWIN: + EMBED_DIM: 96 + DEPTHS: [2, 2, 6, 2] + NUM_HEADS: [3, 6, 12, 24] + WINDOW_SIZE: 7 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + WEIGHTS: "pretrained_weights/moby_swin_t_300ep_pretrained.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] +SOLVER: + BASE_LR: 0.00006 + WARMUP_FACTOR: 1e-6 + WARMUP_ITERS: 1500 + WEIGHT_DECAY: 0.01 + WEIGHT_DECAY_NORM: 0.0 + WEIGHT_DECAY_EMBED: 0.0 + BACKBONE_MULTIPLIER: 1.0 diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb50111224f11e4f63ca2e8f4df059a4637dda9f --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,2 @@ +from .flow_eval_detectron import FlowEvalDetectron +from .flow_pair_detectron import FlowPairDetectron diff --git a/datasets/flow_eval_detectron.py b/datasets/flow_eval_detectron.py new file mode 100644 index 
0000000000000000000000000000000000000000..5a8dbc4a96695113d5e4af23033e48c8c644cd42 --- /dev/null +++ b/datasets/flow_eval_detectron.py @@ -0,0 +1,209 @@ +import math +import os +from pathlib import Path + +import detectron2.data.transforms as DT +import einops +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from detectron2.data import detection_utils as d2_utils +from detectron2.structures import Instances, BitMasks +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset + +from utils.data import read_flow + + +class FlowEvalDetectron(Dataset): + def __init__(self, data_dir, resolution, pair_list, val_seq, to_rgb=False, with_rgb=False, size_divisibility=None, + small_val=0, flow_clip=1., norm=True, read_big=True, eval_size=True, force1080p=False): + self.val_seq = val_seq + self.to_rgb = to_rgb + self.with_rgb = with_rgb + self.data_dir = data_dir + self.pair_list = pair_list + self.resolution = resolution + + self.eval_size = eval_size + + self.samples = [] + self.samples_fid = {} + for v in self.val_seq: + seq_dir = Path(self.data_dir[0]) / v + frames_paths = sorted(seq_dir.glob('*.flo')) + self.samples_fid[str(seq_dir)] = {fp: i for i, fp in enumerate(frames_paths)} + self.samples.extend(frames_paths) + self.samples = [os.path.join(x.parent.name, x.name) for x in self.samples] + if small_val > 0: + _, self.samples = train_test_split(self.samples, test_size=small_val, random_state=42) + self.gaps = ['gap{}'.format(i) for i in pair_list] + self.neg_gaps = ['gap{}'.format(-i) for i in pair_list] + self.size_divisibility = size_divisibility + self.ignore_label = -1 + self.transforms = DT.AugmentationList([ + DT.Resize(self.resolution, interp=Image.BICUBIC), + ]) + self.flow_clip=flow_clip + self.norm_flow=norm + self.read_big=read_big + self.force1080p_transforms=None + if force1080p: + self.force1080p_transforms = DT.AugmentationList([ + DT.Resize((1088, 1920), interp=Image.BICUBIC), + ]) + + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + dataset_dicts = [] + + dataset_dict = {} + flow_dir = Path(self.data_dir[0]) / self.samples[idx] + fid = self.samples_fid[str(flow_dir.parent)][flow_dir] + flo = einops.rearrange(read_flow(str(flow_dir), self.resolution, self.to_rgb), 'c h w -> h w c') + dataset_dict["gap"] = 'gap1' + + suffix = '.png' if 'CLEVR' in self.samples[idx] else '.jpg' + rgb_dir = (self.data_dir[1] / self.samples[idx]).with_suffix(suffix) + gt_dir = (self.data_dir[2] / self.samples[idx]).with_suffix('.png') + + rgb = d2_utils.read_image(str(rgb_dir)).astype(np.float32) + original_rgb = torch.as_tensor(np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float() + if self.read_big: + rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32) + rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.) + if self.force1080p_transforms is not None: + rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0] + + input = DT.AugInput(rgb) + + # Apply the augmentation: + preprocessing_transforms = self.transforms(input) # type: DT.Transform + rgb = input.image + rgb = np.transpose(rgb, (2, 0, 1)) + rgb = rgb.clip(0., 255.) 
+ d2_utils.check_image_size(dataset_dict, flo) + + if gt_dir.exists(): + sem_seg_gt_ori = d2_utils.read_image(gt_dir) + sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt_ori) + if sem_seg_gt.ndim == 3: + sem_seg_gt = sem_seg_gt[:, :, 0] + sem_seg_gt_ori = sem_seg_gt_ori[:, :, 0] + if sem_seg_gt.max() == 255: + sem_seg_gt = (sem_seg_gt > 128).astype(int) + sem_seg_gt_ori = (sem_seg_gt_ori > 128).astype(int) + else: + sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1])) + sem_seg_gt_ori = np.zeros((original_rgb.shape[-2], original_rgb.shape[-1])) + + gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / self.samples[idx]).with_suffix( + '.png') + if gwm_dir.exists(): + gwm_seg_gt = preprocessing_transforms.apply_segmentation(d2_utils.read_image(str(gwm_dir))) + gwm_seg_gt = np.array(gwm_seg_gt) + if gwm_seg_gt.ndim == 3: + gwm_seg_gt = gwm_seg_gt[:, :, 0] + if gwm_seg_gt.max() == 255: + gwm_seg_gt[gwm_seg_gt == 255] = 1 + else: + gwm_seg_gt = None + + if sem_seg_gt is None: + raise ValueError( + "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format( + dataset_dict["file_name"] + ) + ) + + # Pad image and segmentation label here! + if self.to_rgb: + flo = torch.as_tensor(np.ascontiguousarray(flo.transpose(2, 0, 1))) / 2 + .5 + flo = flo * 255 + else: + flo = torch.as_tensor(np.ascontiguousarray(flo.transpose(2, 0, 1))) + if self.norm_flow: + flo = flo/(flo ** 2).sum(0).max().sqrt() + flo = flo.clip(-self.flow_clip, self.flow_clip) + + rgb = torch.as_tensor(np.ascontiguousarray(rgb)).float() + if sem_seg_gt is not None: + sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) + sem_seg_gt_ori = torch.as_tensor(sem_seg_gt_ori.astype("long")) + if gwm_seg_gt is not None: + gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype("long")) + + if self.size_divisibility > 0: + image_size = (flo.shape[-2], flo.shape[-1]) + padding_size = [ + 0, + int(self.size_divisibility * math.ceil(image_size[1] // self.size_divisibility)) - image_size[1], + 0, + int(self.size_divisibility * math.ceil(image_size[0] // self.size_divisibility)) - image_size[0], + ] + flo = F.pad(flo, padding_size, value=0).contiguous() + rgb = F.pad(rgb, padding_size, value=128).contiguous() + if sem_seg_gt is not None: + sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous() + if gwm_seg_gt is not None: + gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous() + + image_shape = (flo.shape[-2], flo.shape[-1]) # h, w + if self.eval_size: + image_shape = (sem_seg_gt_ori.shape[-2], sem_seg_gt_ori.shape[-1]) + + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. 
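+        # Keys below form the evaluation sample: "flow" and "rgb" are at the
+        # (padded) working resolution, while "original_rgb" and "sem_seg_ori"
+        # keep the original ground-truth resolution; with eval_size=True the
+        # Instances image_shape also follows the ground-truth size so metrics
+        # can be computed at full resolution.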
+ dataset_dict["flow"] = flo + dataset_dict["rgb"] = rgb + + + dataset_dict["original_rgb"] = F.interpolate(original_rgb[None], mode='bicubic', size=sem_seg_gt_ori.shape[-2:], align_corners=False).clip(0.,255.)[0] + if self.read_big: + dataset_dict["RGB_BIG"] = rgb_big + + dataset_dict["category"] = str(gt_dir).split('/')[-2:] + dataset_dict['frame_id'] = fid + + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = sem_seg_gt.long() + dataset_dict["sem_seg_ori"] = sem_seg_gt_ori.long() + + if gwm_seg_gt is not None: + dataset_dict["gwm_seg"] = gwm_seg_gt.long() + + if "annotations" in dataset_dict: + raise ValueError("Semantic segmentation dataset should not have 'annotations'.") + + # Prepare per-category binary masks + if sem_seg_gt is not None: + sem_seg_gt = sem_seg_gt.numpy() + instances = Instances(image_shape) + classes = np.unique(sem_seg_gt) + # remove ignored region + classes = classes[classes != self.ignore_label] + instances.gt_classes = torch.tensor(classes, dtype=torch.int64) + + masks = [] + for class_id in classes: + masks.append(sem_seg_gt == class_id) + + if len(masks) == 0: + # Some image does not have annotation (all ignored) + instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])) + else: + masks = BitMasks( + torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) + ) + instances.gt_masks = masks.tensor + + dataset_dict["instances"] = instances + dataset_dicts.append(dataset_dict) + + return dataset_dicts diff --git a/datasets/flow_pair_detectron.py b/datasets/flow_pair_detectron.py new file mode 100644 index 0000000000000000000000000000000000000000..cac5a4c1b073dbfffc947a13fc369bbeebce871d --- /dev/null +++ b/datasets/flow_pair_detectron.py @@ -0,0 +1,275 @@ +import math +from pathlib import Path +import random + +import detectron2.data.transforms as DT +import einops +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as T +from PIL import Image +from detectron2.data import detection_utils as d2_utils +from detectron2.structures import Instances, BitMasks +from torch.utils.data import Dataset + +from utils.data import read_flow, read_flo + + +def load_flow_tensor(path, resize=None, normalize=True, align_corners=True): + """ + Load flow, scale the pixel values according to the resized scale. + If normalize is true, return rescaled in normalized pixel coordinates + where pixel coordinates are in range [-1, 1]. 
+ NOTE: RAFT USES ALIGN_CORNERS=TRUE SO WE NEED TO ACCOUNT FOR THIS + Returns (2, H, W) float32 + """ + flow = read_flo(path).astype(np.float32) + H, W, _ = flow.shape + h, w = (H, W) if resize is None else resize + u, v = flow[..., 0], flow[..., 1] + if normalize: + if align_corners: + u = 2.0 * u / (W - 1) + v = 2.0 * v / (H - 1) + else: + u = 2.0 * u / W + v = 2.0 * v / H + else: + h, w = resize + u = w * u / W + v = h * v / H + + if h != H or w !=W: + u = Image.fromarray(u).resize((w, h), Image.ANTIALIAS) + v = Image.fromarray(v).resize((w, h), Image.ANTIALIAS) + u, v = np.array(u), np.array(v) + return torch.from_numpy(np.stack([u, v], axis=0)) + + +class FlowPairDetectron(Dataset): + def __init__(self, data_dir, resolution, to_rgb=False, size_divisibility=None, enable_photo_aug=False, flow_clip=1., norm=True, read_big=True, force1080p=False, flow_res=None): + self.eval = eval + self.to_rgb = to_rgb + self.data_dir = data_dir + self.flow_dir = {k: [e for e in v if e.shape[0] > 0] for k, v in data_dir[0].items()} + self.flow_dir = {k: v for k, v in self.flow_dir.items() if len(v) > 0} + self.resolution = resolution + self.size_divisibility = size_divisibility + self.ignore_label = -1 + self.transforms = DT.AugmentationList([ + DT.Resize(self.resolution, interp=Image.BICUBIC), + ]) + self.photometric_aug = T.Compose([ + T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)]), + p=0.8), + T.RandomGrayscale(p=0.2), + ]) if enable_photo_aug else None + self.flow_clip=flow_clip + self.norm_flow=norm + self.read_big = read_big + self.force1080p_transforms = None + if force1080p: + self.force1080p_transforms = DT.AugmentationList([ + DT.Resize((1088, 1920), interp=Image.BICUBIC), + ]) + self.big_flow_resolution = flow_res + + def __len__(self): + return sum([cat.shape[0] for cat in next(iter(self.flow_dir.values()))]) if len( + self.flow_dir.values()) > 0 else 0 + + def __getitem__(self, idx): + + dataset_dicts = [] + + random_gap = random.choice(list(self.flow_dir.keys())) + flowgaps = self.flow_dir[random_gap] + vid = random.choice(flowgaps) + flos = random.choice(vid) + dataset_dict = {} + + fname = Path(flos[0]).stem + dname = Path(flos[0]).parent.name + suffix = '.png' if 'CLEVR' in fname else '.jpg' + rgb_dir = (self.data_dir[1] / dname / fname).with_suffix(suffix) + gt_dir = (self.data_dir[2] / dname / fname).with_suffix('.png') + + flo0 = einops.rearrange(read_flow(str(flos[0]), self.resolution, self.to_rgb), 'c h w -> h w c') + flo1 = einops.rearrange(read_flow(str(flos[1]), self.resolution, self.to_rgb), 'c h w -> h w c') + if self.big_flow_resolution is not None: + flo0_big = einops.rearrange(read_flow(str(flos[0]), self.big_flow_resolution, self.to_rgb), 'c h w -> h w c') + flo1_big = einops.rearrange(read_flow(str(flos[1]), self.big_flow_resolution, self.to_rgb), 'c h w -> h w c') + rgb = d2_utils.read_image(rgb_dir).astype(np.float32) + original_rgb = torch.as_tensor(np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)).clip(0., 255.))).float() + if self.read_big: + rgb_big = d2_utils.read_image(str(rgb_dir).replace('480p', '1080p')).astype(np.float32) + rgb_big = (torch.as_tensor(np.ascontiguousarray(rgb_big))[:, :, :3]).permute(2, 0, 1).clamp(0., 255.) 
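+            # "RGB_BIG" carries the full-resolution frame; for non-DAVIS data
+            # with RGB_BIG sampling (force1080p in config.py) it is additionally
+            # resampled to 1080x1920 below.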
+ if self.force1080p_transforms is not None: + rgb_big = F.interpolate(rgb_big[None], size=(1080, 1920), mode='bicubic').clamp(0., 255.)[0] + + # print('not here', rgb.min(), rgb.max()) + input = DT.AugInput(rgb) + + # Apply the augmentation: + preprocessing_transforms = self.transforms(input) # type: DT.Transform + rgb = input.image + if self.photometric_aug: + rgb_aug = Image.fromarray(rgb.astype(np.uint8)) + rgb_aug = self.photometric_aug(rgb_aug) + rgb_aug = d2_utils.convert_PIL_to_numpy(rgb_aug, 'RGB') + rgb_aug = np.transpose(rgb_aug, (2, 0, 1)).astype(np.float32) + rgb = np.transpose(rgb, (2, 0, 1)) + rgb = rgb.clip(0., 255.) + # print('here', rgb.min(), rgb.max()) + d2_utils.check_image_size(dataset_dict, flo0) + if gt_dir.exists(): + sem_seg_gt = d2_utils.read_image(str(gt_dir)) + sem_seg_gt = preprocessing_transforms.apply_segmentation(sem_seg_gt) + # sem_seg_gt = cv2.resize(sem_seg_gt, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_NEAREST) + if sem_seg_gt.ndim == 3: + sem_seg_gt = sem_seg_gt[:, :, 0] + if sem_seg_gt.max() == 255: + sem_seg_gt = (sem_seg_gt > 128).astype(int) + else: + sem_seg_gt = np.zeros((self.resolution[0], self.resolution[1])) + + + gwm_dir = (Path(str(self.data_dir[2]).replace('Annotations', 'gwm')) / dname / fname).with_suffix('.png') + if gwm_dir.exists(): + gwm_seg_gt = d2_utils.read_image(str(gwm_dir)) + gwm_seg_gt = preprocessing_transforms.apply_segmentation(gwm_seg_gt) + gwm_seg_gt = np.array(gwm_seg_gt) + # gwm_seg_gt = cv2.resize(gwm_seg_gt, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_NEAREST) + if gwm_seg_gt.ndim == 3: + gwm_seg_gt = gwm_seg_gt[:, :, 0] + if gwm_seg_gt.max() == 255: + gwm_seg_gt[gwm_seg_gt == 255] = 1 + else: + gwm_seg_gt = None + + if sem_seg_gt is None: + raise ValueError( + "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format( + dataset_dict["file_name"] + ) + ) + + # Pad image and segmentation label here! 
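+        # When flows were converted to RGB (to_rgb), map them from [-1, 1] to
+        # [0, 255]; otherwise keep raw u/v values and, if requested, normalise
+        # by the maximum flow magnitude (GWM.FLOW_NORM) and clip to
+        # GWM.FLOW_CLIP. The same treatment is applied to the optional
+        # full-resolution copies flo0_big/flo1_big.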
+ if self.to_rgb: + flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1))) / 2 + .5 + flo0 = flo0 * 255 + flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1))) / 2 + .5 + flo1 = flo1 * 255 + if self.big_flow_resolution is not None: + flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1))) / 2 + .5 + flo0_big = flo0_big * 255 + flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1))) / 2 + .5 + flo1_big = flo1_big * 255 + else: + flo0 = torch.as_tensor(np.ascontiguousarray(flo0.transpose(2, 0, 1))) + flo1 = torch.as_tensor(np.ascontiguousarray(flo1.transpose(2, 0, 1))) + + if self.norm_flow: + flo0 = flo0 / (flo0 ** 2).sum(0).max().sqrt() + flo1 = flo1 / (flo1 ** 2).sum(0).max().sqrt() + + flo0 = flo0.clip(-self.flow_clip, self.flow_clip) + flo1 = flo1.clip(-self.flow_clip, self.flow_clip) + + if self.big_flow_resolution is not None: + flo0_big = torch.as_tensor(np.ascontiguousarray(flo0_big.transpose(2, 0, 1))) + flo1_big = torch.as_tensor(np.ascontiguousarray(flo1_big.transpose(2, 0, 1))) + if self.norm_flow: + flo0_big = flo0_big / (flo0_big ** 2).sum(0).max().sqrt() + flo1_big = flo1_big / (flo1_big ** 2).sum(0).max().sqrt() + flo0_big = flo0_big.clip(-self.flow_clip, self.flow_clip) + flo1_big = flo1_big.clip(-self.flow_clip, self.flow_clip) + + rgb = torch.as_tensor(np.ascontiguousarray(rgb)) + if self.photometric_aug: + rgb_aug = torch.as_tensor(np.ascontiguousarray(rgb_aug)) + + if sem_seg_gt is not None: + sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) + if gwm_seg_gt is not None: + gwm_seg_gt = torch.as_tensor(gwm_seg_gt.astype("long")) + + if self.size_divisibility > 0: + image_size = (flo0.shape[-2], flo0.shape[-1]) + padding_size = [ + 0, + int(self.size_divisibility * math.ceil(image_size[1] // self.size_divisibility)) - image_size[1], + 0, + int(self.size_divisibility * math.ceil(image_size[0] // self.size_divisibility)) - image_size[0], + ] + flo0 = F.pad(flo0, padding_size, value=0).contiguous() + flo1 = F.pad(flo1, padding_size, value=0).contiguous() + rgb = F.pad(rgb, padding_size, value=128).contiguous() + if self.photometric_aug: + rgb_aug = F.pad(rgb_aug, padding_size, value=128).contiguous() + if sem_seg_gt is not None: + sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous() + if gwm_seg_gt is not None: + gwm_seg_gt = F.pad(gwm_seg_gt, padding_size, value=self.ignore_label).contiguous() + + image_shape = (rgb.shape[-2], rgb.shape[-1]) # h, w + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. 
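+        # "flow"/"flow_2" hold the two flow fields of the sampled pair (same
+        # frame, different gaps); "flow_big"/"flow_big_2" are their GWM.FLOW_RES
+        # copies, and "rgb_aug" is the photometrically augmented frame when
+        # photometric augmentation is enabled.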
+ dataset_dict["flow"] = flo0 + dataset_dict["flow_2"] = flo1 + + # dataset_dict["flow_fwd"] = flo_norm_fwd + # dataset_dict["flow_bwd"] = flo_norm_bwd + # dataset_dict["flow_rgb"] = rgb_flo0 + # dataset_dict["flow_gap"] = gap + + dataset_dict["rgb"] = rgb + dataset_dict["original_rgb"] = original_rgb + if self.read_big: + dataset_dict["RGB_BIG"] = rgb_big + if self.photometric_aug: + dataset_dict["rgb_aug"] = rgb_aug + + if self.big_flow_resolution is not None: + dataset_dict["flow_big"] = flo0_big + dataset_dict["flow_big_2"] = flo1_big + + + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = sem_seg_gt.long() + + if gwm_seg_gt is not None: + dataset_dict["gwm_seg"] = gwm_seg_gt.long() + + if "annotations" in dataset_dict: + raise ValueError("Semantic segmentation dataset should not have 'annotations'.") + + # Prepare per-category binary masks + if sem_seg_gt is not None: + sem_seg_gt = sem_seg_gt.numpy() + instances = Instances(image_shape) + classes = np.unique(sem_seg_gt) + # remove ignored region + classes = classes[classes != self.ignore_label] + instances.gt_classes = torch.tensor(classes, dtype=torch.int64) + + masks = [] + for class_id in classes: + masks.append(sem_seg_gt == class_id) + + if len(masks) == 0: + # Some image does not have annotation (all ignored) + instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])) + else: + masks = BitMasks( + torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) + ) + instances.gt_masks = masks.tensor + + dataset_dict["instances"] = instances + dataset_dicts.append(dataset_dict) + + return dataset_dicts diff --git a/determinism.py b/determinism.py new file mode 100644 index 0000000000000000000000000000000000000000..2056e9c6d0d76d7c50aef3bf116b7aaf10d0d26a --- /dev/null +++ b/determinism.py @@ -0,0 +1,24 @@ +import os +lvl = int(os.environ.get('TRY_DETERMISM_LVL', '0')) +if lvl > 0: + print(f'Attempting to enable deterministic cuDNN and cuBLAS operations to lvl {lvl}') +if lvl >= 2: + # turn on deterministic operations + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" #Need to set before torch gets loaded + import torch + # Since using unstable torch version, it looks like 1.12.0.devXXXXXXX + if torch.version.__version__ >= '1.12.0': + torch.use_deterministic_algorithms(True, warn_only=(lvl < 3)) + elif lvl >= 3: + torch.use_deterministic_algorithms(True) # This will throw errors if implementations are missing + else: + print(f"Torch verions is only {torch.version.__version__}, which will cause errors on lvl {lvl}") +if lvl >= 1: + import torch + if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = False + + +def i_do_nothing_but_dont_remove_me_otherwise_things_break(): + """This exists to prevent formatters from treating this file as dead code""" + pass diff --git a/dist.py b/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..ca142146452262683f88c6106b115fe8538d9bbf --- /dev/null +++ b/dist.py @@ -0,0 +1,34 @@ +import functools + +import torch +import torch.distributions + +import utils + +LOGGER = utils.log.getLogger(__name__) + +__defined_kl = False + +EPS = 1e-5 + + +def clamp_probs(probs): + probs = probs.clamp(EPS, 1. 
- EPS) # Will no longer sum to 1 + return probs / probs.sum(-1, keepdim=True) # to simplex + + +def grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False): + hr = torch.arange(h + 2 * pad, device=device) - pad + wr = torch.arange(w + 2 * pad, device=device) - pad + if norm: + hr = hr / (h + 2 * pad - 1) + wr = wr / (w + 2 * pad - 1) + ig, jg = torch.meshgrid(hr, wr) + g = torch.stack([jg, ig]).to(dtype)[None] + return g + + +@functools.lru_cache(2) +def cached_grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False): + return grid(h, w, pad, device, dtype, norm) + diff --git a/eval_utils.py b/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5aebcb686d9e94d070e3514f627ac48efeb77b7e --- /dev/null +++ b/eval_utils.py @@ -0,0 +1,282 @@ +import functools +import random +from collections import defaultdict + +import einops +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image, ImageDraw, ImageFont +from sklearn.cluster import SpectralClustering +from tqdm import tqdm + +import flow_reconstruction +from utils import visualisation, log, grid +from utils.vit_extractor import ViTExtractor + +label_colors = visualisation.create_label_colormap() +logger = log.getLogger('gwm') + + +def __default_font(fontsize): + try: + FNT = ImageFont.truetype("dejavu/DejaVuSansMono.ttf", fontsize) + except OSError: + FNT = ImageFont.truetype("dejavu/DejaVuSans.ttf", fontsize) + return FNT + + +@functools.lru_cache(None) # cache the result +def autosized_default_font(size_limit: float) -> ImageFont.ImageFont: + fontsize = 1 # starting font size + font = __default_font(fontsize) + while font.getsize('test123')[1] < size_limit: + fontsize += 1 + font = __default_font(fontsize) + fontsize -= 1 + font = __default_font(fontsize) + return font + + +def iou(masks, gt, thres=0.5): + masks = (masks > thres).float() + intersect = torch.tensordot(masks, gt, dims=([-2, -1], [0, 1])) + union = masks.sum(dim=[-2, -1]) + gt.sum(dim=[-2, -1]) - intersect + return intersect / union.clip(min=1e-12) + + +def get_unsup_image_viz(model, cfg, sample, criterion): + if model.training: + model.eval() + preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True) + model.train() + else: + preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True) + return get_image_vis(model, cfg, sample, preds, criterion) + +def get_vis_header(header_size, image_size, header_texts, header_height=20): + W, H = (image_size, header_height) + header_labels = [] + font = autosized_default_font(0.8 * H) + + for text in header_texts: + im = Image.new("RGB", (W, H), "white") + draw = ImageDraw.Draw(im) + w, h = draw.textsize(text, font=font) + draw.text(((W - w) / 2, (H - h) / 2), text, fill="black", font=font) + header_labels.append(torch.from_numpy(np.array(im))) + header_labels = torch.cat(header_labels, dim=1) + ret = (torch.ones((header_height, header_size, 3)) * 255) + ret[:, :header_labels.size(1)] = header_labels + + return ret.permute(2, 0, 1).clip(0, 255).to(torch.uint8) + +def get_image_vis(model, cfg, sample, preds, criterion): + masks_pred = torch.stack([x['sem_seg'] for x in preds], 0) + + with torch.no_grad(): + flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20) + + masks_softmaxed = torch.softmax(masks_pred, dim=1) + masks_pred = masks_softmaxed + rec_flows = criterion.flow_reconstruction(sample, criterion.process_flow(sample, flow), masks_softmaxed) + rec_headers = ['rec_flow'] + if len(rec_flows) > 1: + 
rec_headers.append('rec_bwd_flow') + + rgb = torch.stack([x['rgb'] for x in sample]) + flow = criterion.viz_flow(criterion.process_flow(sample, flow).cpu()) * 255 + rec_flows = [ + (criterion.viz_flow(rec_flow_.detach().cpu().cpu()) * 255).clip(0, 255).to(torch.uint8) for rec_flow_ in rec_flows + ] + + + gt_labels = torch.stack([x['sem_seg'] for x in sample]) + gt = F.one_hot(gt_labels, gt_labels.max().item() + 1).permute(0, 3, 1, 2) + target_K = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES + masks = F.one_hot(masks_pred.argmax(1).cpu(), target_K).permute(0, 3, 1, 2) + masks_each = torch.stack([masks_softmaxed, masks_softmaxed, masks_softmaxed], 2) * 255 + masks_each = einops.rearrange(F.pad(masks_each.cpu(), pad=[0, 1], value=255), 'b n c h w -> b c h (n w)') + + gt_seg = torch.einsum('b k h w, k c -> b c h w', gt, label_colors[:gt_labels.max().item() + 1]) + pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:target_K]) + if all('gwm_seg' in d for d in sample): + gwm_labels = torch.stack([x['gwm_seg'] for x in sample]) + mg = F.one_hot(gwm_labels, gwm_labels.max().item() + 1).permute(0, 3, 1, 2) + gwm_seg = torch.einsum('b k h w, k c -> b c h w', mg, label_colors[:gwm_labels.max().item() + 1]) + image_viz = torch.cat( + [rgb, flow, F.pad(gt_seg.cpu(), pad=[0, 1], value=255), F.pad(gwm_seg, pad=[0, 1], value=255), + pred_seg.cpu(), *rec_flows], -1) + header_text = ['rgb', 'gt_flow', 'gt_seg', 'GWM', 'pred_seg', *rec_headers] + else: + image_viz = torch.cat([rgb, flow, gt_seg.cpu(), pred_seg.cpu(), *rec_flows], -1) + header_text = ['rgb', 'gt_flow', 'gt_seg', 'pred_seg', *rec_headers] + + image_viz = torch.cat([image_viz, masks_each], -1) + header_text.extend(['slot'] * masks_softmaxed.shape[1]) + if 'flow_edges' in sample[0]: + flow_edges = torch.stack([x['flow_edges'].to(image_viz.device) for x in sample]) + if len(flow_edges.shape) >= 4: + flow_edges = flow_edges.sum(1, keepdim=len(flow_edges.shape) == 4) + flow_edges = flow_edges.expand(-1, 3, -1, -1) + flow_edges = flow_edges * 255 + image_viz = torch.cat([image_viz, flow_edges], -1) + header_text.append('flow_edges') + image_viz = einops.rearrange(image_viz[:8], 'b c h w -> c (b h) w').detach().clip(0, 255).to(torch.uint8) + + return image_viz, header_text + + +def get_frame_vis(model, cfg, sample, preds): + masks_pred = torch.stack([x['sem_seg'] for x in preds], 0) + flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20) + + masks_softmaxed = torch.softmax(masks_pred, dim=1) + if cfg.GWM.SIMPLE_REC: + mask_denom = einops.reduce(masks_softmaxed, 'b k h w -> b k 1', 'sum') + 1e-7 + means = torch.einsum('brhw, bchw -> brc', masks_softmaxed, flow) / mask_denom + rec_flow = torch.einsum('bkhw, bkc-> bchw', masks_softmaxed, means) + elif cfg.GWM.HOMOGRAPHY: + rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow) + else: + grid_x, grid_y = grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device) + rec_flow = flow_reconstruction.get_quad_flow(masks_softmaxed, flow, grid_x, grid_y) + + rgb = torch.stack([x['rgb'] for x in sample]) + flow = torch.stack([visualisation.flow2rgb_torch(x) for x in flow.cpu()]) * 255 + rec_flow = torch.stack([visualisation.flow2rgb_torch(x) for x in rec_flow.detach().cpu()]) * 255 + + gt_labels = torch.stack([x['sem_seg'] for x in sample]) + gt = F.one_hot(gt_labels, gt_labels.max().item() + 1).permute(0, 3, 1, 2) + + masks = F.one_hot(masks_pred.argmax(1).cpu(), cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES).permute(0, 3, 1, 2) + + gt_seg = torch.einsum('b k h w, k c -> 
b c h w', gt, label_colors[:gt_labels.max().item() + 1]) + pred_seg = torch.einsum('b k h w, k c -> b c h w', masks, label_colors[:cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES]) + frame_vis = torch.cat([rgb, flow, gt_seg.cpu(), pred_seg.cpu(), rec_flow.clip(0, 255).to(torch.uint8)], -1) + frame_vis = einops.rearrange(frame_vis, 'b c h w -> b c h w').detach().clip(0, 255).to(torch.uint8) + return frame_vis + + +def is_2comp_dataset(dataset): + if '+' in dataset: + d = dataset.split('+')[0].strip() + else: + d = dataset.strip() + logger.info_once(f"Is 2comp dataset? {d}") + for s in ['DAVIS', 'FBMS', 'STv2']: + if s in d: + return True + return d in ['DAVIS', + 'FBMS', + 'STv2'] + +def eval_unsupmf(cfg, val_loader, model, criterion, writer=None, writer_iteration=0, use_wandb=False): + logger.info(f'Running Evaluation: {cfg.LOG_ID} {"Simple" if cfg.GWM.SIMPLE_REC else "Gradient"}:') + logger.info(f'Model mode: {"train" if model.training else "eval"}, wandb: {use_wandb}') + logger.info(f'Dataset: {cfg.GWM.DATASET} # components: {cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES}') + + merger = None + if cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES > 2: + merger = MaskMerger(cfg, model) + + print_idxs = random.sample(range(len(val_loader)), k=10) + + images_viz = [] + ious_davis_eval = defaultdict(list) + ious = defaultdict(list) + + for idx, sample in enumerate(tqdm(val_loader)): + t = 1 + sample = [e for s in sample for e in s] + category = [s['category'] for s in sample] + preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True) + masks_raw = torch.stack([x['sem_seg'] for x in preds], 0) + + masks_softmaxed = torch.softmax(masks_raw, dim=1) + masks_dict = merger(sample, masks_softmaxed) + + if writer and idx in print_idxs: + flow = torch.stack([x['flow'] for x in sample]).to(model.device) + img_viz, header_text = get_image_vis(model, cfg, sample, preds, criterion) + images_viz.append(img_viz) + + masks = masks_dict['cos'] + gt_seg = torch.stack([x['sem_seg_ori'] for x in sample]).cpu() + HW = gt_seg.shape[-2:] + if HW != masks.shape[-2:]: + logger.info_once(f"Upsampling predicted masks to {HW} for evaluation") + masks_softmaxed_sel = F.interpolate(masks.detach().cpu(), size=HW, mode='bilinear', align_corners=False) + else: + masks_softmaxed_sel = masks.detach().cpu() + masks_ = einops.rearrange(masks_softmaxed_sel, '(b t) s h w -> b t s 1 h w', t=t).detach() + gt_seg = einops.rearrange(gt_seg, 'b h w -> b 1 h w').float() + for i in range(masks_.size(0)): + masks_k = F.interpolate(masks_[i], size=(1, gt_seg.shape[-2], gt_seg.shape[-1])) # t s 1 h w + mask_iou = iou(masks_k[:, :, 0], gt_seg[i, 0], thres=0.5) # t s + iou_max, slot_max = mask_iou.max(dim=1) + + ious[category[i][0]].append(iou_max) + frame_id = category[i][1] + ious_davis_eval[category[i][0]].append((frame_id.strip().replace('.png', ''), iou_max)) + + frameious = sum(ious.values(), []) + frame_mean_iou = torch.cat(frameious).sum().item() * 100 / len(frameious) + if 'DAVIS' in cfg.GWM.DATASET.split('+')[0]: + logger.info_once("Using DAVIS evaluator methods for evaluting IoU -- mean of mean of sequences without first frame") + seq_scores = dict() + for c in ious_davis_eval: + seq_scores[c] = np.nanmean([v.item() for n, v in ious_davis_eval[c] if int(n) > 1]) + + frame_mean_iou = np.nanmean(list(seq_scores.values())) * 100 + + if writer: + header = get_vis_header(images_viz[0].size(2), flow.size(3), header_text) + images_viz = torch.cat(images_viz, dim=1) + images_viz = torch.cat([header, images_viz], dim=1) + 
writer.add_image('val/images', images_viz, writer_iteration) # C H W + writer.add_scalar('eval/mIoU', frame_mean_iou, writer_iteration) + + logger.info(f"mIoU: {frame_mean_iou:.3f} \n") + return frame_mean_iou + + +class MaskMerger: + def __init__(self, cfg, model, merger_model="dino_vits8"): + self.extractor = ViTExtractor(model_type=merger_model, device=model.device) + self.out_dim = 384 + + self.mu = torch.tensor(self.extractor.mean).to(model.device).view(1, -1, 1, 1) + self.sigma = torch.tensor(self.extractor.std).to(model.device).view(1, -1, 1, 1) + self.start_idx = 0 + + def get_feats(self, batch): + with torch.no_grad(): + feat = self.extractor.extract_descriptors(batch, facet='key', layer=11, bin=False) + feat = feat.reshape(feat.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2) + return F.interpolate(feat, batch.shape[-2:], mode='bilinear') + + def spectral(self, A): + clustering = SpectralClustering(n_clusters=2, + affinity='precomputed', + random_state=0).fit(A.detach().cpu().numpy()) + return np.arange(A.shape[-1])[clustering.labels_ == 0], np.arange(A.shape[-1])[clustering.labels_ == 1] + + def cos_merge(self, basis, masks): + basis = basis / torch.linalg.vector_norm(basis, dim=-1, keepdim=True).clamp(min=1e-6) + A = torch.einsum('brc, blc -> brl', basis, basis)[0].clamp(min=1e-6) + inda, indb = self.spectral(A) + return torch.stack([masks[:, inda].sum(1), + masks[:, indb].sum(1)], 1) + + def __call__(self, sample, masks_softmaxed): + with torch.no_grad(): + masks_softmaxed = masks_softmaxed[:, self.start_idx:] + batch = torch.stack([x['rgb'].to(masks_softmaxed.device) for x in sample], 0) / 255.0 + features = self.get_feats((batch - self.mu) / self.sigma) + basis = torch.einsum('brhw, bchw -> brc', masks_softmaxed, features) + basis /= einops.reduce(masks_softmaxed, 'b r h w -> b r 1', 'sum').clamp_min(1e-12) + + return { + 'cos': self.cos_merge(basis, masks_softmaxed), + } diff --git a/flow_reconstruction.py b/flow_reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..4c71cec1bfe3b94293e00d29e640fe838dbafcd7 --- /dev/null +++ b/flow_reconstruction.py @@ -0,0 +1,54 @@ +import sys + +import torch + +from dist import LOGGER + + +def lstq(A, F_u, F_v, lamda=0.01): + # cols = A.shape[2] + # assert all(cols == torch.linalg.matrix_rank(A)) # something better? 
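+    # Batched least-squares solve of A @ theta = F_u and A @ theta = F_v via a
+    # QR decomposition (theta = R^-1 Q^T F). A is the per-mask quadratic basis
+    # [x, y, x^2, y^2, x*y, 1] assembled in get_quad_flow below, and the QR
+    # route assumes A has full column rank (hence the commented-out rank check
+    # above). Note that `lamda` appears to be unused here, i.e. no ridge
+    # regularisation is actually applied.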
+ try: + Q, R = torch.linalg.qr(A) + theta_x = torch.bmm(torch.bmm(torch.linalg.inv(R), Q.transpose(1, 2)), F_u) + theta_y = torch.bmm(torch.bmm(torch.linalg.inv(R), Q.transpose(1, 2)), F_v) + except: + LOGGER.exception("Least Squares failed") + sys.exit(-1) + return theta_x, theta_y + +def get_quad_flow(masks_softmaxed, flow, grid_x, grid_y): + rec_flow = 0 + for k in range(masks_softmaxed.size(1)): + mask = masks_softmaxed[:, k].unsqueeze(1) + _F = flow * mask + M = mask.flatten(1) + bs = _F.shape[0] + x = grid_x.unsqueeze(0).flatten(1) + y = grid_y.unsqueeze(0).flatten(1) + + F_u = _F[:, 0].flatten(1).unsqueeze(2) # B x L x 1 + F_v = _F[:, 1].flatten(1).unsqueeze(2) # B x L x 1 + A = torch.stack([x * M, y * M, x*x *M, y*y*M, x*y*M, torch.ones_like(y) * M], 2) # B x L x 2 + + theta_x, theta_y = lstq(A, F_u, F_v, lamda=.01) + rec_flow_m = torch.stack([torch.einsum('bln,bnk->blk', A, theta_x).view(bs, *grid_x.shape), + torch.einsum('bln,bnk->blk', A, theta_y).view(bs, *grid_y.shape)], 1) + + rec_flow += rec_flow_m + return rec_flow + + +SUBSAMPLE = 8 +SKIP = 0.4 +SIZE = 0.3 +NITER = 50 +METHOD = 'inv_score' + +def set_subsample_skip(sub=None, skip=None, size=None, niter=None, method=None): + global SUBSAMPLE, SKIP, SIZE, NITER, METHOD + if sub is not None: SUBSAMPLE=sub + if skip is not None: SKIP=skip + if size is not None: SIZE=size + if niter is not None: NITER=niter + if method is not None: METHOD=method diff --git a/losses/__init__.py b/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3a3e2140dca7e152f3282cfb6788d2541ba26a6 --- /dev/null +++ b/losses/__init__.py @@ -0,0 +1,28 @@ +from .reconstruction_loss import ReconstructionLoss +import torch + + +class CriterionDict: + def __init__(self, dict): + self.criterions = dict + + def __call__(self, sample, flow, masks_softmaxed, iteration, train=True, prefix=''): + loss = torch.tensor(0., device=masks_softmaxed.device, dtype=masks_softmaxed.dtype) + log_dict = {} + for name_i, (criterion_i, loss_multiplier_i, anneal_fn_i) in self.criterions.items(): + loss_i = loss_multiplier_i * anneal_fn_i(iteration) * criterion_i(sample, flow, masks_softmaxed, iteration, train=train) + loss += loss_i + log_dict[f'loss_{name_i}'] = loss_i.item() + + log_dict['loss_total'] = loss.item() + return loss, log_dict + + def flow_reconstruction(self, sample, flow, masks_softmaxed): + return self.criterions['reconstruction'][0].rec_flow(sample, flow, masks_softmaxed) + + def process_flow(self, sample, flow): + return self.criterions['reconstruction'][0].process_flow(sample, flow) + + def viz_flow(self, flow): + return self.criterions['reconstruction'][0].viz_flow(flow) + diff --git a/losses/reconstruction_loss.py b/losses/reconstruction_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..49ff2db70acce4a66d0d91ebe547bf80a625b0dc --- /dev/null +++ b/losses/reconstruction_loss.py @@ -0,0 +1,85 @@ +import torch +import functools + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +import flow_reconstruction +import utils +from utils.visualisation import flow2rgb_torch + +logger = utils.log.getLogger(__name__) + +class ReconstructionLoss: + def __init__(self, cfg, model): + self.criterion = nn.MSELoss() if cfg.GWM.CRITERION == 'L2' else nn.L1Loss() + self.l1_optimize = cfg.GWM.L1_OPTIMIZE + self.homography = cfg.GWM.HOMOGRAPHY + self.device=model.device + self.cfg = cfg + self.grid_x, self.grid_y = utils.grid.get_meshgrid(cfg.GWM.RESOLUTION, model.device) + # 
self.mult_flow = cfg.GWM.USE_MULT_FLOW + self.flow_colorspace_rec = cfg.GWM.FLOW_COLORSPACE_REC + flow_reconstruction.set_subsample_skip(cfg.GWM.HOMOGRAPHY_SUBSAMPLE, cfg.GWM.HOMOGRAPHY_SKIP) + self.flow_u_low = cfg.GWM.FLOW_CLIP_U_LOW + self.flow_u_high = cfg.GWM.FLOW_CLIP_U_HIGH + self.flow_v_low = cfg.GWM.FLOW_CLIP_V_LOW + self.flow_v_high = cfg.GWM.FLOW_CLIP_V_HIGH + + self._recon_fn = self.flow_quad + logger.info(f'Using reconstruction method {self._recon_fn.__name__}') + self.it = 0 + self._extra_losses = [] + + def __call__(self, sample, flow, masks_softmaxed, it, train=True): + return self.loss(sample, flow, masks_softmaxed, it, train=train) + + def loss(self, sample, flow, mask_softmaxed, it, train=True): + self.training = train + flow = self.process_flow(sample, flow) + self.it = it + self._extra_losses = [] + + if self.cfg.GWM.FLOW_RES is not None: + if flow.shape[-2:] != mask_softmaxed.shape[-2:]: + logger.debug_once(f'Resizing predicted masks to {self.cfg.GWM.FLOW_RES}') + mask_softmaxed = F.interpolate(mask_softmaxed, flow.shape[-2:], mode='bilinear', align_corners=False) + + rec_flows = self.rec_flow(sample, flow, mask_softmaxed) + if not isinstance(rec_flows, (list, tuple)): + rec_flows = (rec_flows,) + k = len(rec_flows) + loss = sum(self.criterion(flow, rec_flow) / k for rec_flow in rec_flows) + if len(self._extra_losses): + loss = loss + sum(self._extra_losses, 0.) / len(self._extra_losses) + self._extra_losses = [] + return loss + + def flow_quad(self, sample, flow, masks_softmaxed, it, **_): + logger.debug_once(f'Reconstruction using quadratic. Masks shape {masks_softmaxed.shape} | ' + f'Flow shape {flow.shape} | ' + f'Grid shape {self.grid_x.shape, self.grid_y.shape}') + return flow_reconstruction.get_quad_flow(masks_softmaxed, flow, self.grid_x, self.grid_y) + + def _clipped_recon_fn(self, *args, **kwargs): + flow = self._recon_fn(*args, **kwargs) + flow_o = flow[:, :-2] + flow_u = flow[:, -2:-1].clip(self.flow_u_low, self.flow_u_high) + flow_v = flow[:, -1:].clip(self.flow_v_low, self.flow_v_high) + return torch.cat([flow_o, flow_u, flow_v], dim=1) + + def rec_flow(self, sample, flow, masks_softmaxed): + it = self.it + if self.cfg.GWM.FLOW_RES is not None and flow.shape[-2:] != self.grid_x.shape[-2:]: + logger.debug_once(f'Generating new grid predicted masks of {flow.shape[-2:]}') + self.grid_x, self.grid_y = utils.grid.get_meshgrid(flow.shape[-2:], self.device) + return [self._clipped_recon_fn(sample, flow, masks_softmaxed, it)] + + def process_flow(self, sample, flow_cuda): + return flow_cuda + + def viz_flow(self, flow): + return torch.stack([flow2rgb_torch(x) for x in flow]) + diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b3835209acc93adcc4a01229cb8e8dc353cc34 --- /dev/null +++ b/main.py @@ -0,0 +1,270 @@ +import determinism # noqa + +determinism.i_do_nothing_but_dont_remove_me_otherwise_things_break() # noqa + +import argparse +import bisect +import copy +import os +import sys +import time +from argparse import ArgumentParser + +import torch +import wandb +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.engine import PeriodicCheckpointer +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +import config +import losses +import utils +from eval_utils import eval_unsupmf, get_unsup_image_viz, get_vis_header +from mask_former_trainer import setup, Trainer + + +logger = utils.log.getLogger('gwm') + +def freeze(module, set=False): + for param in 
module.parameters(): + param.requires_grad = set + + +def main(args): + cfg = setup(args) + logger.info(f"Called as {' '.join(sys.argv)}") + logger.info(f'Output dir {cfg.OUTPUT_DIR}') + + random_state = utils.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE)) + random_state.seed_everything() + utils.log.checkpoint_code(cfg.OUTPUT_DIR) + + if not cfg.SKIP_TB: + writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR) + else: + writer = None + + # initialize model + model = Trainer.build_model(cfg) + optimizer = Trainer.build_optimizer(cfg, model) + scheduler = Trainer.build_lr_scheduler(cfg, optimizer) + + logger.info(f'Optimiser is {type(optimizer)}') + + + checkpointer = DetectionCheckpointer(model, + save_dir=os.path.join(cfg.OUTPUT_DIR, 'checkpoints'), + random_state=random_state, + optimizer=optimizer, + scheduler=scheduler) + periodic_checkpointer = PeriodicCheckpointer(checkpointer=checkpointer, + period=cfg.SOLVER.CHECKPOINT_PERIOD, + max_iter=cfg.SOLVER.MAX_ITER, + max_to_keep=None if cfg.FLAGS.KEEP_ALL else 5, + file_prefix='checkpoint') + checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume_path is not None) + iteration = 0 if args.resume_path is None else checkpoint['iteration'] + + train_loader, val_loader = config.loaders(cfg) + + # overfit single batch for debug + # sample = next(iter(loader)) + + criterions = { + 'reconstruction': (losses.ReconstructionLoss(cfg, model), cfg.GWM.LOSS_MULT.REC, lambda x: 1)} + + criterion = losses.CriterionDict(criterions) + + if args.eval_only: + if len(val_loader.dataset) == 0: + logger.error("Training dataset: empty") + sys.exit(0) + model.eval() + iou = eval_unsupmf(cfg=cfg, val_loader=val_loader, model=model, criterion=criterion, writer=writer, + writer_iteration=iteration) + logger.info(f"Results: iteration: {iteration} IOU = {iou}") + return + if len(train_loader.dataset) == 0: + logger.error("Training dataset: empty") + sys.exit(0) + + logger.info( + f'Start of training: dataset {cfg.GWM.DATASET},' + f' train {len(train_loader.dataset)}, val {len(val_loader.dataset)},' + f' device {model.device}, keys {cfg.GWM.SAMPLE_KEYS}, ' + f'multiple flows {cfg.GWM.USE_MULT_FLOW}') + + iou_best = 0 + timestart = time.time() + dilate_kernel = torch.ones((2, 2), device=model.device) + + total_iter = cfg.TOTAL_ITER if cfg.TOTAL_ITER else cfg.SOLVER.MAX_ITER # early stop + with torch.autograd.set_detect_anomaly(cfg.DEBUG) and \ + tqdm(initial=iteration, total=total_iter, disable=utils.environment.is_slurm()) as pbar: + while iteration < total_iter: + for sample in train_loader: + + if cfg.MODEL.META_ARCHITECTURE != 'UNET' and cfg.FLAGS.UNFREEZE_AT: + if hasattr(model.backbone, 'frozen_stages'): + assert cfg.MODEL.BACKBONE.FREEZE_AT == -1, f"MODEL initial parameters forced frozen" + stages = [s for s, m in cfg.FLAGS.UNFREEZE_AT] + milest = [m for s, m in cfg.FLAGS.UNFREEZE_AT] + pos = bisect.bisect_right(milest, iteration) - 1 + if pos >= 0: + curr_setting = model.backbone.frozen_stages + if curr_setting != stages[pos]: + logger.info(f"Updating backbone freezing stages from {curr_setting} to {stages[pos]}") + model.backbone.frozen_stages = stages[pos] + model.train() + else: + assert cfg.MODEL.BACKBONE.FREEZE_AT == -1, f"MODEL initial parameters forced frozen" + stages = [s for s, m in cfg.FLAGS.UNFREEZE_AT] + milest = [m for s, m in cfg.FLAGS.UNFREEZE_AT] + pos = bisect.bisect_right(milest, iteration) - 1 + freeze(model, set=False) + freeze(model.sem_seg_head.predictor, set=True) + if pos >= 0: + stage = 
stages[pos] + if stage <= 2: + freeze(model.sem_seg_head, set=True) + if stage <= 1: + freeze(model.backbone, set=True) + model.train() + + else: + logger.debug_once(f'Unfreezing disabled schedule: {cfg.FLAGS.UNFREEZE_AT}') + + sample = [e for s in sample for e in s] + flow_key = 'flow' + raw_sem_seg = False + if cfg.GWM.FLOW_RES is not None: + flow_key = 'flow_big' + raw_sem_seg = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME == 'MegaBigPixelDecoder' + + flow = torch.stack([x[flow_key].to(model.device) for x in sample]).clip(-20, 20) + logger.debug_once(f'flow shape: {flow.shape}') + preds = model.forward_base(sample, keys=cfg.GWM.SAMPLE_KEYS, get_eval=True, raw_sem_seg=raw_sem_seg) + masks_raw = torch.stack([x['sem_seg'] for x in preds], 0) + logger.debug_once(f'mask shape: {masks_raw.shape}') + masks_softmaxed_list = [torch.softmax(masks_raw, dim=1)] + + + total_losses = [] + log_dicts = [] + for mask_idx, masks_softmaxed in enumerate(masks_softmaxed_list): + + loss, log_dict = criterion(sample, flow, masks_softmaxed, iteration) + + if cfg.GWM.USE_MULT_FLOW: + flow2 = torch.stack([x[flow_key + '_2'].to(model.device) for x in sample]).clip(-20, 20) + other_loss, other_log_dict = criterion(sample, flow2, masks_softmaxed, iteration) + loss = loss / 2 + other_loss / 2 + for k, v in other_log_dict.items(): + log_dict[k] = other_log_dict[k] / 2 + v / 2 + total_losses.append(loss) + log_dicts.append(log_dict) + + loss_ws = cfg.GWM.LOSS_MULT.HEIR_W + total_w = float(sum(loss_ws[:len(total_losses)])) + log_dict = {} + if len(total_losses) == 1: + log_dict = log_dicts[0] + loss = total_losses[0] + else: + loss = 0 + for i, (tl, w, ld) in enumerate(zip(total_losses, loss_ws, log_dicts)): + for k, v in ld.items(): + log_dict[f'{k}_{i}'] = v * w / total_w + loss += tl * w / total_w + + train_log_dict = {f'train/{k}': v for k, v in log_dict.items()} + del log_dict + train_log_dict['train/learning_rate'] = optimizer.param_groups[-1]['lr'] + train_log_dict['train/loss_total'] = loss.item() + + + optimizer.zero_grad() + + + loss.backward() + optimizer.step() + scheduler.step() + + pbar.set_postfix(loss=loss.item()) + pbar.update() + + # Sanity check for RNG state + if (iteration + 1) % 1000 == 0 or iteration + 1 in {1, 50}: + logger.info( + f'Iteration {iteration + 1}. 
RNG outputs {utils.random_state.get_randstate_magic_numbers(model.device)}') + + if cfg.DEBUG or (iteration + 1) % 100 == 0: + logger.info( + f'Iteration: {iteration + 1}, time: {time.time() - timestart:.01f}s, loss: {loss.item():.02f}.') + + for k, v in train_log_dict.items(): + if writer: + writer.add_scalar(k, v, iteration + 1) + + if cfg.WANDB.ENABLE: + wandb.log(train_log_dict, step=iteration + 1) + + if (iteration + 1) % cfg.LOG_FREQ == 0 or (iteration + 1) in [1, 50, 500]: + model.eval() + if writer: + flow = torch.stack([x['flow'].to(model.device) for x in sample]).clip(-20, 20) + image_viz, header_text = get_unsup_image_viz(model, cfg, sample, criterion) + header = get_vis_header(image_viz.size(2), flow.size(3), header_text) + image_viz = torch.cat([header, image_viz], dim=1) + writer.add_image('train/images', image_viz, iteration + 1) + if cfg.WANDB.ENABLE and (iteration + 1) % 2500 == 0: + image_viz = get_unsup_image_viz(model, cfg, sample) + wandb.log({'train/viz': wandb.Image(image_viz.float())}, step=iteration + 1) + + if iou := eval_unsupmf(cfg=cfg, val_loader=val_loader, model=model, criterion=criterion, + writer=writer, writer_iteration=iteration + 1, use_wandb=cfg.WANDB.ENABLE): + if cfg.SOLVER.CHECKPOINT_PERIOD and iou > iou_best: + iou_best = iou + if not args.wandb_sweep_mode: + checkpointer.save(name='checkpoint_best', iteration=iteration + 1, loss=loss, + iou=iou_best) + logger.info(f'New best IoU {iou_best:.02f} after iteration {iteration + 1}') + if cfg.WANDB.ENABLE: + wandb.log({'eval/IoU_best': iou_best}, step=iteration + 1) + if writer: + writer.add_scalar('eval/IoU_best', iou_best, iteration + 1) + + + model.train() + + periodic_checkpointer.step(iteration=iteration + 1, loss=loss) + + iteration += 1 + timestart = time.time() + + +def get_argparse_args(): + parser = ArgumentParser() + parser.add_argument('--resume_path', type=str, default=None) + parser.add_argument('--use_wandb', dest='wandb_sweep_mode', action='store_true') # for sweep + parser.add_argument('--config-file', type=str, + default='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml') + parser.add_argument('--eval_only', action='store_true') + parser.add_argument( + "opts", + help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. " + "See config references at " + "https://detectron2.readthedocs.io/modules/config.html#config-references", + default=None, + nargs=argparse.REMAINDER, + ) + return parser + + +if __name__ == "__main__": + args = get_argparse_args().parse_args() + if args.resume_path: + args.config_file = "/".join(args.resume_path.split('/')[:-2]) + '/config.yaml' + print(args.config_file) + main(args) diff --git a/mask_former/__init__.py b/mask_former/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef4d245cfcf49cf13a5195b60d57d92e0af016e --- /dev/null +++ b/mask_former/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from . import data # register all new datasets +from . 
import modeling + +# config +from .config import add_mask_former_config + +# dataset loading +from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper +from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import ( + MaskFormerPanopticDatasetMapper, +) +from .data.dataset_mappers.mask_former_semantic_dataset_mapper import ( + MaskFormerSemanticDatasetMapper, +) + +# models +from .mask_former_model import MaskFormer +from .test_time_augmentation import SemanticSegmentorWithTTA diff --git a/mask_former/config.py b/mask_former/config.py new file mode 100644 index 0000000000000000000000000000000000000000..90ba988efd4bad35106b44f2c17de4bf71382a04 --- /dev/null +++ b/mask_former/config.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.config import CfgNode as CN + + +def add_mask_former_config(cfg): + """ + Add config for MASK_FORMER. + """ + # data config + # select the dataset mapper + cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" + # Color augmentation + cfg.INPUT.COLOR_AUG_SSD = False + # We retry random cropping until no single category in semantic segmentation GT occupies more + # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. + cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 + # Pad image and segmentation GT in dataset mapper. + cfg.INPUT.SIZE_DIVISIBILITY = -1 + + # solver config + # weight decay on embedding + cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 + # optimizer + cfg.SOLVER.OPTIMIZER = "ADAMW" + cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 + + # mask_former model config + cfg.MODEL.MASK_FORMER = CN() + + # loss + cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True + cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1 + cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0 + cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0 + + # transformer config + cfg.MODEL.MASK_FORMER.NHEADS = 8 + cfg.MODEL.MASK_FORMER.DROPOUT = 0.1 + cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 + cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0 + cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6 + cfg.MODEL.MASK_FORMER.PRE_NORM = False + + cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 + cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100 + + cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5" + cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False + + # mask_former inference config + cfg.MODEL.MASK_FORMER.TEST = CN() + cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False + cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0 + cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0 + cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False + + # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. 
ResNet) + # you can use this config to override + cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32 + + # pixel decoder config + cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 + # adding transformer in pixel decoder + cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0 + # pixel decoder + cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder" + + # swin transformer backbone + cfg.MODEL.SWIN = CN() + cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224 + cfg.MODEL.SWIN.PATCH_SIZE = 4 + cfg.MODEL.SWIN.EMBED_DIM = 96 + cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] + cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] + cfg.MODEL.SWIN.WINDOW_SIZE = 7 + cfg.MODEL.SWIN.MLP_RATIO = 4.0 + cfg.MODEL.SWIN.QKV_BIAS = True + cfg.MODEL.SWIN.QK_SCALE = None + cfg.MODEL.SWIN.DROP_RATE = 0.0 + cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0 + cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3 + cfg.MODEL.SWIN.APE = False + cfg.MODEL.SWIN.PATCH_NORM = True + cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"] diff --git a/mask_former/data/__init__.py b/mask_former/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63ba265b1effc69f1eef16e57a04db8902ee347e --- /dev/null +++ b/mask_former/data/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from . import datasets diff --git a/mask_former/data/dataset_mappers/__init__.py b/mask_former/data/dataset_mappers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8 --- /dev/null +++ b/mask_former/data/dataset_mappers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py b/mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4a296f2fbbd24b190b312b464ce2d4c1957b221c --- /dev/null +++ b/mask_former/data/dataset_mappers/detr_panoptic_dataset_mapper.py @@ -0,0 +1,180 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py +import copy +import logging + +import numpy as np +import torch + +from detectron2.config import configurable +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T +from detectron2.data.transforms import TransformGen +from detectron2.structures import BitMasks, Instances + +__all__ = ["DETRPanopticDatasetMapper"] + + +def build_transform_gen(cfg, is_train): + """ + Create a list of :class:`TransformGen` from config. + Returns: + list[TransformGen] + """ + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = "choice" + if sample_style == "range": + assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format( + len(min_size) + ) + + logger = logging.getLogger(__name__) + tfm_gens = [] + if is_train: + tfm_gens.append(T.RandomFlip()) + tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) + if is_train: + logger.info("TransformGens used in training: " + str(tfm_gens)) + return tfm_gens + + +# This is specifically designed for the COCO dataset. +class DETRPanopticDatasetMapper: + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into a format used by MaskFormer. 
+ + This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. + + The callable currently does the following: + + 1. Read the image from "file_name" + 2. Applies geometric transforms to the image and annotation + 3. Find and applies suitable cropping to the image and annotation + 4. Prepare image and annotation to Tensors + """ + + @configurable + def __init__( + self, + is_train=True, + *, + crop_gen, + tfm_gens, + image_format, + ): + """ + NOTE: this interface is experimental. + Args: + is_train: for training or inference + augmentations: a list of augmentations or deterministic transforms to apply + crop_gen: crop augmentation + tfm_gens: data augmentation + image_format: an image format supported by :func:`detection_utils.read_image`. + """ + self.crop_gen = crop_gen + self.tfm_gens = tfm_gens + logging.getLogger(__name__).info( + "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format( + str(self.tfm_gens), str(self.crop_gen) + ) + ) + + self.img_format = image_format + self.is_train = is_train + + @classmethod + def from_config(cls, cfg, is_train=True): + # Build augmentation + if cfg.INPUT.CROP.ENABLED and is_train: + crop_gen = [ + T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), + T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), + ] + else: + crop_gen = None + + tfm_gens = build_transform_gen(cfg, is_train) + + ret = { + "is_train": is_train, + "crop_gen": crop_gen, + "tfm_gens": tfm_gens, + "image_format": cfg.INPUT.FORMAT, + } + return ret + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + image = utils.read_image(dataset_dict["file_name"], format=self.img_format) + utils.check_image_size(dataset_dict, image) + + if self.crop_gen is None: + image, transforms = T.apply_transform_gens(self.tfm_gens, image) + else: + if np.random.rand() > 0.5: + image, transforms = T.apply_transform_gens(self.tfm_gens, image) + else: + image, transforms = T.apply_transform_gens( + self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image + ) + + image_shape = image.shape[:2] # h, w + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + + if not self.is_train: + # USER: Modify this if you want to keep them for some reason. 
+ dataset_dict.pop("annotations", None) + return dataset_dict + + if "pan_seg_file_name" in dataset_dict: + pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") + segments_info = dataset_dict["segments_info"] + + # apply the same transformation to panoptic segmentation + pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) + + from panopticapi.utils import rgb2id + + pan_seg_gt = rgb2id(pan_seg_gt) + + instances = Instances(image_shape) + classes = [] + masks = [] + for segment_info in segments_info: + class_id = segment_info["category_id"] + if not segment_info["iscrowd"]: + classes.append(class_id) + masks.append(pan_seg_gt == segment_info["id"]) + + classes = np.array(classes) + instances.gt_classes = torch.tensor(classes, dtype=torch.int64) + if len(masks) == 0: + # Some image does not have annotation (all ignored) + instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])) + else: + masks = BitMasks( + torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) + ) + instances.gt_masks = masks.tensor + + dataset_dict["instances"] = instances + + return dataset_dict diff --git a/mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py b/mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ddbc2bd77fb1b17540dd5272cfc6534ee2b6e2df --- /dev/null +++ b/mask_former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py @@ -0,0 +1,165 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import logging + +import numpy as np +import torch +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T +from detectron2.structures import BitMasks, Instances + +from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper + +__all__ = ["MaskFormerPanopticDatasetMapper"] + + +class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper): + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into a format used by MaskFormer for panoptic segmentation. + + The callable currently does the following: + + 1. Read the image from "file_name" + 2. Applies geometric transforms to the image and annotation + 3. Find and applies suitable cropping to the image and annotation + 4. Prepare image and annotation to Tensors + """ + + @configurable + def __init__( + self, + is_train=True, + *, + augmentations, + image_format, + ignore_label, + size_divisibility, + ): + """ + NOTE: this interface is experimental. + Args: + is_train: for training or inference + augmentations: a list of augmentations or deterministic transforms to apply + image_format: an image format supported by :func:`detection_utils.read_image`. + ignore_label: the label that is ignored to evaluation + size_divisibility: pad image size to be divisible by this value + """ + super().__init__( + is_train, + augmentations=augmentations, + image_format=image_format, + ignore_label=ignore_label, + size_divisibility=size_divisibility, + ) + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!" 
+ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + image = utils.read_image(dataset_dict["file_name"], format=self.img_format) + utils.check_image_size(dataset_dict, image) + + # semantic segmentation + if "sem_seg_file_name" in dataset_dict: + # PyTorch transformation not implemented for uint16, so converting it to double first + sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double") + else: + sem_seg_gt = None + + # panoptic segmentation + if "pan_seg_file_name" in dataset_dict: + pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB") + segments_info = dataset_dict["segments_info"] + else: + pan_seg_gt = None + segments_info = None + + if pan_seg_gt is None: + raise ValueError( + "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format( + dataset_dict["file_name"] + ) + ) + + aug_input = T.AugInput(image, sem_seg=sem_seg_gt) + aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input) + image = aug_input.image + if sem_seg_gt is not None: + sem_seg_gt = aug_input.sem_seg + + # apply the same transformation to panoptic segmentation + pan_seg_gt = transforms.apply_segmentation(pan_seg_gt) + + from panopticapi.utils import rgb2id + + pan_seg_gt = rgb2id(pan_seg_gt) + + # Pad image and segmentation label here! + image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + if sem_seg_gt is not None: + sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) + pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long")) + + if self.size_divisibility > 0: + image_size = (image.shape[-2], image.shape[-1]) + padding_size = [ + 0, + self.size_divisibility - image_size[1], + 0, + self.size_divisibility - image_size[0], + ] + image = F.pad(image, padding_size, value=128).contiguous() + if sem_seg_gt is not None: + sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous() + pan_seg_gt = F.pad( + pan_seg_gt, padding_size, value=0 + ).contiguous() # 0 is the VOID panoptic label + + image_shape = (image.shape[-2], image.shape[-1]) # h, w + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. 
+ dataset_dict["image"] = image + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = sem_seg_gt.long() + + if "annotations" in dataset_dict: + raise ValueError("Pemantic segmentation dataset should not have 'annotations'.") + + # Prepare per-category binary masks + pan_seg_gt = pan_seg_gt.numpy() + instances = Instances(image_shape) + classes = [] + masks = [] + for segment_info in segments_info: + class_id = segment_info["category_id"] + if not segment_info["iscrowd"]: + classes.append(class_id) + masks.append(pan_seg_gt == segment_info["id"]) + + classes = np.array(classes) + instances.gt_classes = torch.tensor(classes, dtype=torch.int64) + if len(masks) == 0: + # Some image does not have annotation (all ignored) + instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1])) + else: + masks = BitMasks( + torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) + ) + instances.gt_masks = masks.tensor + + dataset_dict["instances"] = instances + + return dataset_dict diff --git a/mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py b/mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..36ff3153b0c84462ea14f1bf3273668217f14678 --- /dev/null +++ b/mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.py @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import logging + +import numpy as np +import torch +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.data import MetadataCatalog +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T +from detectron2.projects.point_rend import ColorAugSSDTransform +from detectron2.structures import BitMasks, Instances + +__all__ = ["MaskFormerSemanticDatasetMapper"] + + +class MaskFormerSemanticDatasetMapper: + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into a format used by MaskFormer for semantic segmentation. + + The callable currently does the following: + + 1. Read the image from "file_name" + 2. Applies geometric transforms to the image and annotation + 3. Find and applies suitable cropping to the image and annotation + 4. Prepare image and annotation to Tensors + """ + + @configurable + def __init__( + self, + is_train=True, + *, + augmentations, + image_format, + ignore_label, + size_divisibility, + ): + """ + NOTE: this interface is experimental. + Args: + is_train: for training or inference + augmentations: a list of augmentations or deterministic transforms to apply + image_format: an image format supported by :func:`detection_utils.read_image`. 
+ ignore_label: the label that is ignored to evaluation + size_divisibility: pad image size to be divisible by this value + """ + self.is_train = is_train + self.tfm_gens = augmentations + self.img_format = image_format + self.ignore_label = ignore_label + self.size_divisibility = size_divisibility + + logger = logging.getLogger(__name__) + mode = "training" if is_train else "inference" + logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}") + + @classmethod + def from_config(cls, cfg, is_train=True): + # Build augmentation + augs = [ + T.ResizeShortestEdge( + cfg.INPUT.MIN_SIZE_TRAIN, + cfg.INPUT.MAX_SIZE_TRAIN, + cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING, + ) + ] + if cfg.INPUT.CROP.ENABLED: + augs.append( + T.RandomCrop_CategoryAreaConstraint( + cfg.INPUT.CROP.TYPE, + cfg.INPUT.CROP.SIZE, + cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA, + cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + ) + ) + if cfg.INPUT.COLOR_AUG_SSD: + augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT)) + augs.append(T.RandomFlip()) + + # Assume always applies to the training set. + dataset_names = cfg.DATASETS.TRAIN + meta = MetadataCatalog.get(dataset_names[0]) + ignore_label = meta.ignore_label + + ret = { + "is_train": is_train, + "augmentations": augs, + "image_format": cfg.INPUT.FORMAT, + "ignore_label": ignore_label, + "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY, + } + return ret + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!" + + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + image = utils.read_image(dataset_dict["file_name"], format=self.img_format) + utils.check_image_size(dataset_dict, image) + + if "sem_seg_file_name" in dataset_dict: + # PyTorch transformation not implemented for uint16, so converting it to double first + sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double") + else: + sem_seg_gt = None + + if sem_seg_gt is None: + raise ValueError( + "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format( + dataset_dict["file_name"] + ) + ) + + aug_input = T.AugInput(image, sem_seg=sem_seg_gt) + aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input) + image = aug_input.image + sem_seg_gt = aug_input.sem_seg + + # Pad image and segmentation label here! + image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + if sem_seg_gt is not None: + sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) + + if self.size_divisibility > 0: + image_size = (image.shape[-2], image.shape[-1]) + padding_size = [ + 0, + self.size_divisibility - image_size[1], + 0, + self.size_divisibility - image_size[0], + ] + image = F.pad(image, padding_size, value=128).contiguous() + if sem_seg_gt is not None: + sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous() + + image_shape = (image.shape[-2], image.shape[-1]) # h, w + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. 
+ dataset_dict["image"] = image + + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = sem_seg_gt.long() + + if "annotations" in dataset_dict: + raise ValueError("Semantic segmentation dataset should not have 'annotations'.") + + # Prepare per-category binary masks + if sem_seg_gt is not None: + sem_seg_gt = sem_seg_gt.numpy() + instances = Instances(image_shape) + classes = np.unique(sem_seg_gt) + # remove ignored region + classes = classes[classes != self.ignore_label] + instances.gt_classes = torch.tensor(classes, dtype=torch.int64) + + masks = [] + for class_id in classes: + masks.append(sem_seg_gt == class_id) + + if len(masks) == 0: + # Some image does not have annotation (all ignored) + instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])) + else: + masks = BitMasks( + torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks]) + ) + instances.gt_masks = masks.tensor + + dataset_dict["instances"] = instances + + return dataset_dict diff --git a/mask_former/data/datasets/__init__.py b/mask_former/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2492d388415250569ac01ae79108a6de05e57528 --- /dev/null +++ b/mask_former/data/datasets/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from . import ( + register_ade20k_full, + register_ade20k_panoptic, + register_coco_stuff_10k, + register_mapillary_vistas, +) diff --git a/mask_former/data/datasets/register_ade20k_full.py b/mask_former/data/datasets/register_ade20k_full.py new file mode 100644 index 0000000000000000000000000000000000000000..7121a22227583b29a6e167b560703e33371f1081 --- /dev/null +++ b/mask_former/data/datasets/register_ade20k_full.py @@ -0,0 +1,964 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
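+# Category metadata (human-readable name, raw dataset id, contiguous trainId)
+# for the full ADE20K label set; used below to register the ADE20K-Full
+# semantic-segmentation splits with Detectron2's DatasetCatalog/MetadataCatalog.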
+import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import load_sem_seg + +ADE20K_SEM_SEG_FULL_CATEGORIES = [ + {"name": "wall", "id": 2978, "trainId": 0}, + {"name": "building, edifice", "id": 312, "trainId": 1}, + {"name": "sky", "id": 2420, "trainId": 2}, + {"name": "tree", "id": 2855, "trainId": 3}, + {"name": "road, route", "id": 2131, "trainId": 4}, + {"name": "floor, flooring", "id": 976, "trainId": 5}, + {"name": "ceiling", "id": 447, "trainId": 6}, + {"name": "bed", "id": 165, "trainId": 7}, + {"name": "sidewalk, pavement", "id": 2377, "trainId": 8}, + {"name": "earth, ground", "id": 838, "trainId": 9}, + {"name": "cabinet", "id": 350, "trainId": 10}, + {"name": "person, individual, someone, somebody, mortal, soul", "id": 1831, "trainId": 11}, + {"name": "grass", "id": 1125, "trainId": 12}, + {"name": "windowpane, window", "id": 3055, "trainId": 13}, + {"name": "car, auto, automobile, machine, motorcar", "id": 401, "trainId": 14}, + {"name": "mountain, mount", "id": 1610, "trainId": 15}, + {"name": "plant, flora, plant life", "id": 1910, "trainId": 16}, + {"name": "table", "id": 2684, "trainId": 17}, + {"name": "chair", "id": 471, "trainId": 18}, + {"name": "curtain, drape, drapery, mantle, pall", "id": 687, "trainId": 19}, + {"name": "door", "id": 774, "trainId": 20}, + {"name": "sofa, couch, lounge", "id": 2473, "trainId": 21}, + {"name": "sea", "id": 2264, "trainId": 22}, + {"name": "painting, picture", "id": 1735, "trainId": 23}, + {"name": "water", "id": 2994, "trainId": 24}, + {"name": "mirror", "id": 1564, "trainId": 25}, + {"name": "house", "id": 1276, "trainId": 26}, + {"name": "rug, carpet, carpeting", "id": 2178, "trainId": 27}, + {"name": "shelf", "id": 2329, "trainId": 28}, + {"name": "armchair", "id": 57, "trainId": 29}, + {"name": "fence, fencing", "id": 907, "trainId": 30}, + {"name": "field", "id": 913, "trainId": 31}, + {"name": "lamp", "id": 1395, "trainId": 32}, + {"name": "rock, stone", "id": 2138, "trainId": 33}, + {"name": "seat", "id": 2272, "trainId": 34}, + {"name": "river", "id": 2128, "trainId": 35}, + {"name": "desk", "id": 724, "trainId": 36}, + {"name": "bathtub, bathing tub, bath, tub", "id": 155, "trainId": 37}, + {"name": "railing, rail", "id": 2053, "trainId": 38}, + {"name": "signboard, sign", "id": 2380, "trainId": 39}, + {"name": "cushion", "id": 689, "trainId": 40}, + {"name": "path", "id": 1788, "trainId": 41}, + {"name": "work surface", "id": 3087, "trainId": 42}, + {"name": "stairs, steps", "id": 2530, "trainId": 43}, + {"name": "column, pillar", "id": 581, "trainId": 44}, + {"name": "sink", "id": 2388, "trainId": 45}, + {"name": "wardrobe, closet, press", "id": 2985, "trainId": 46}, + {"name": "snow", "id": 2454, "trainId": 47}, + {"name": "refrigerator, icebox", "id": 2096, "trainId": 48}, + {"name": "base, pedestal, stand", "id": 137, "trainId": 49}, + {"name": "bridge, span", "id": 294, "trainId": 50}, + {"name": "blind, screen", "id": 212, "trainId": 51}, + {"name": "runway", "id": 2185, "trainId": 52}, + {"name": "cliff, drop, drop-off", "id": 524, "trainId": 53}, + {"name": "sand", "id": 2212, "trainId": 54}, + {"name": "fireplace, hearth, open fireplace", "id": 943, "trainId": 55}, + {"name": "pillow", "id": 1869, "trainId": 56}, + {"name": "screen door, screen", "id": 2251, "trainId": 57}, + {"name": "toilet, can, commode, crapper, pot, potty, stool, throne", "id": 2793, "trainId": 58}, + {"name": "skyscraper", "id": 2423, "trainId": 59}, + {"name": "grandstand, covered stand", 
"id": 1121, "trainId": 60}, + {"name": "box", "id": 266, "trainId": 61}, + {"name": "pool table, billiard table, snooker table", "id": 1948, "trainId": 62}, + {"name": "palm, palm tree", "id": 1744, "trainId": 63}, + {"name": "double door", "id": 783, "trainId": 64}, + {"name": "coffee table, cocktail table", "id": 571, "trainId": 65}, + {"name": "counter", "id": 627, "trainId": 66}, + {"name": "countertop", "id": 629, "trainId": 67}, + {"name": "chest of drawers, chest, bureau, dresser", "id": 491, "trainId": 68}, + {"name": "kitchen island", "id": 1374, "trainId": 69}, + {"name": "boat", "id": 223, "trainId": 70}, + {"name": "waterfall, falls", "id": 3016, "trainId": 71}, + { + "name": "stove, kitchen stove, range, kitchen range, cooking stove", + "id": 2598, + "trainId": 72, + }, + {"name": "flower", "id": 978, "trainId": 73}, + {"name": "bookcase", "id": 239, "trainId": 74}, + {"name": "controls", "id": 608, "trainId": 75}, + {"name": "book", "id": 236, "trainId": 76}, + {"name": "stairway, staircase", "id": 2531, "trainId": 77}, + {"name": "streetlight, street lamp", "id": 2616, "trainId": 78}, + { + "name": "computer, computing machine, computing device, data processor, electronic computer, information processing system", + "id": 591, + "trainId": 79, + }, + { + "name": "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle", + "id": 327, + "trainId": 80, + }, + {"name": "swivel chair", "id": 2679, "trainId": 81}, + {"name": "light, light source", "id": 1451, "trainId": 82}, + {"name": "bench", "id": 181, "trainId": 83}, + {"name": "case, display case, showcase, vitrine", "id": 420, "trainId": 84}, + {"name": "towel", "id": 2821, "trainId": 85}, + {"name": "fountain", "id": 1023, "trainId": 86}, + {"name": "embankment", "id": 855, "trainId": 87}, + { + "name": "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box", + "id": 2733, + "trainId": 88, + }, + {"name": "van", "id": 2928, "trainId": 89}, + {"name": "hill", "id": 1240, "trainId": 90}, + {"name": "awning, sunshade, sunblind", "id": 77, "trainId": 91}, + {"name": "poster, posting, placard, notice, bill, card", "id": 1969, "trainId": 92}, + {"name": "truck, motortruck", "id": 2880, "trainId": 93}, + {"name": "airplane, aeroplane, plane", "id": 14, "trainId": 94}, + {"name": "pole", "id": 1936, "trainId": 95}, + {"name": "tower", "id": 2828, "trainId": 96}, + {"name": "court", "id": 631, "trainId": 97}, + {"name": "ball", "id": 103, "trainId": 98}, + { + "name": "aircraft carrier, carrier, flattop, attack aircraft carrier", + "id": 3144, + "trainId": 99, + }, + {"name": "buffet, counter, sideboard", "id": 308, "trainId": 100}, + {"name": "hovel, hut, hutch, shack, shanty", "id": 1282, "trainId": 101}, + {"name": "apparel, wearing apparel, dress, clothes", "id": 38, "trainId": 102}, + {"name": "minibike, motorbike", "id": 1563, "trainId": 103}, + {"name": "animal, animate being, beast, brute, creature, fauna", "id": 29, "trainId": 104}, + {"name": "chandelier, pendant, pendent", "id": 480, "trainId": 105}, + {"name": "step, stair", "id": 2569, "trainId": 106}, + {"name": "booth, cubicle, stall, kiosk", "id": 247, "trainId": 107}, + {"name": "bicycle, bike, wheel, cycle", "id": 187, "trainId": 108}, + {"name": "doorframe, doorcase", "id": 778, "trainId": 109}, + {"name": "sconce", "id": 2243, "trainId": 110}, + {"name": "pond", "id": 1941, "trainId": 111}, + {"name": "trade name, brand name, brand, marque", "id": 2833, "trainId": 112}, 
+ {"name": "bannister, banister, balustrade, balusters, handrail", "id": 120, "trainId": 113}, + {"name": "bag", "id": 95, "trainId": 114}, + {"name": "traffic light, traffic signal, stoplight", "id": 2836, "trainId": 115}, + {"name": "gazebo", "id": 1087, "trainId": 116}, + {"name": "escalator, moving staircase, moving stairway", "id": 868, "trainId": 117}, + {"name": "land, ground, soil", "id": 1401, "trainId": 118}, + {"name": "board, plank", "id": 220, "trainId": 119}, + {"name": "arcade machine", "id": 47, "trainId": 120}, + {"name": "eiderdown, duvet, continental quilt", "id": 843, "trainId": 121}, + {"name": "bar", "id": 123, "trainId": 122}, + {"name": "stall, stand, sales booth", "id": 2537, "trainId": 123}, + {"name": "playground", "id": 1927, "trainId": 124}, + {"name": "ship", "id": 2337, "trainId": 125}, + {"name": "ottoman, pouf, pouffe, puff, hassock", "id": 1702, "trainId": 126}, + { + "name": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "id": 64, + "trainId": 127, + }, + {"name": "bottle", "id": 249, "trainId": 128}, + {"name": "cradle", "id": 642, "trainId": 129}, + {"name": "pot, flowerpot", "id": 1981, "trainId": 130}, + { + "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter", + "id": 609, + "trainId": 131, + }, + {"name": "train, railroad train", "id": 2840, "trainId": 132}, + {"name": "stool", "id": 2586, "trainId": 133}, + {"name": "lake", "id": 1393, "trainId": 134}, + {"name": "tank, storage tank", "id": 2704, "trainId": 135}, + {"name": "ice, water ice", "id": 1304, "trainId": 136}, + {"name": "basket, handbasket", "id": 146, "trainId": 137}, + {"name": "manhole", "id": 1494, "trainId": 138}, + {"name": "tent, collapsible shelter", "id": 2739, "trainId": 139}, + {"name": "canopy", "id": 389, "trainId": 140}, + {"name": "microwave, microwave oven", "id": 1551, "trainId": 141}, + {"name": "barrel, cask", "id": 131, "trainId": 142}, + {"name": "dirt track", "id": 738, "trainId": 143}, + {"name": "beam", "id": 161, "trainId": 144}, + {"name": "dishwasher, dish washer, dishwashing machine", "id": 747, "trainId": 145}, + {"name": "plate", "id": 1919, "trainId": 146}, + {"name": "screen, crt screen", "id": 3109, "trainId": 147}, + {"name": "ruins", "id": 2179, "trainId": 148}, + {"name": "washer, automatic washer, washing machine", "id": 2989, "trainId": 149}, + {"name": "blanket, cover", "id": 206, "trainId": 150}, + {"name": "plaything, toy", "id": 1930, "trainId": 151}, + {"name": "food, solid food", "id": 1002, "trainId": 152}, + {"name": "screen, silver screen, projection screen", "id": 2254, "trainId": 153}, + {"name": "oven", "id": 1708, "trainId": 154}, + {"name": "stage", "id": 2526, "trainId": 155}, + {"name": "beacon, lighthouse, beacon light, pharos", "id": 160, "trainId": 156}, + {"name": "umbrella", "id": 2901, "trainId": 157}, + {"name": "sculpture", "id": 2262, "trainId": 158}, + {"name": "aqueduct", "id": 44, "trainId": 159}, + {"name": "container", "id": 597, "trainId": 160}, + {"name": "scaffolding, staging", "id": 2235, "trainId": 161}, + {"name": "hood, exhaust hood", "id": 1260, "trainId": 162}, + {"name": "curb, curbing, kerb", "id": 682, "trainId": 163}, + {"name": "roller coaster", "id": 2151, "trainId": 164}, + {"name": "horse, equus caballus", "id": 3107, "trainId": 165}, + {"name": "catwalk", "id": 432, "trainId": 166}, + {"name": "glass, drinking glass", "id": 1098, "trainId": 167}, + {"name": "vase", "id": 2932, "trainId": 168}, + {"name": "central reservation", 
"id": 461, "trainId": 169}, + {"name": "carousel", "id": 410, "trainId": 170}, + {"name": "radiator", "id": 2046, "trainId": 171}, + {"name": "closet", "id": 533, "trainId": 172}, + {"name": "machine", "id": 1481, "trainId": 173}, + {"name": "pier, wharf, wharfage, dock", "id": 1858, "trainId": 174}, + {"name": "fan", "id": 894, "trainId": 175}, + {"name": "inflatable bounce game", "id": 1322, "trainId": 176}, + {"name": "pitch", "id": 1891, "trainId": 177}, + {"name": "paper", "id": 1756, "trainId": 178}, + {"name": "arcade, colonnade", "id": 49, "trainId": 179}, + {"name": "hot tub", "id": 1272, "trainId": 180}, + {"name": "helicopter", "id": 1229, "trainId": 181}, + {"name": "tray", "id": 2850, "trainId": 182}, + {"name": "partition, divider", "id": 1784, "trainId": 183}, + {"name": "vineyard", "id": 2962, "trainId": 184}, + {"name": "bowl", "id": 259, "trainId": 185}, + {"name": "bullring", "id": 319, "trainId": 186}, + {"name": "flag", "id": 954, "trainId": 187}, + {"name": "pot", "id": 1974, "trainId": 188}, + {"name": "footbridge, overcrossing, pedestrian bridge", "id": 1013, "trainId": 189}, + {"name": "shower", "id": 2356, "trainId": 190}, + {"name": "bag, traveling bag, travelling bag, grip, suitcase", "id": 97, "trainId": 191}, + {"name": "bulletin board, notice board", "id": 318, "trainId": 192}, + {"name": "confessional booth", "id": 592, "trainId": 193}, + {"name": "trunk, tree trunk, bole", "id": 2885, "trainId": 194}, + {"name": "forest", "id": 1017, "trainId": 195}, + {"name": "elevator door", "id": 851, "trainId": 196}, + {"name": "laptop, laptop computer", "id": 1407, "trainId": 197}, + {"name": "instrument panel", "id": 1332, "trainId": 198}, + {"name": "bucket, pail", "id": 303, "trainId": 199}, + {"name": "tapestry, tapis", "id": 2714, "trainId": 200}, + {"name": "platform", "id": 1924, "trainId": 201}, + {"name": "jacket", "id": 1346, "trainId": 202}, + {"name": "gate", "id": 1081, "trainId": 203}, + {"name": "monitor, monitoring device", "id": 1583, "trainId": 204}, + { + "name": "telephone booth, phone booth, call box, telephone box, telephone kiosk", + "id": 2727, + "trainId": 205, + }, + {"name": "spotlight, spot", "id": 2509, "trainId": 206}, + {"name": "ring", "id": 2123, "trainId": 207}, + {"name": "control panel", "id": 602, "trainId": 208}, + {"name": "blackboard, chalkboard", "id": 202, "trainId": 209}, + {"name": "air conditioner, air conditioning", "id": 10, "trainId": 210}, + {"name": "chest", "id": 490, "trainId": 211}, + {"name": "clock", "id": 530, "trainId": 212}, + {"name": "sand dune", "id": 2213, "trainId": 213}, + {"name": "pipe, pipage, piping", "id": 1884, "trainId": 214}, + {"name": "vault", "id": 2934, "trainId": 215}, + {"name": "table football", "id": 2687, "trainId": 216}, + {"name": "cannon", "id": 387, "trainId": 217}, + {"name": "swimming pool, swimming bath, natatorium", "id": 2668, "trainId": 218}, + {"name": "fluorescent, fluorescent fixture", "id": 982, "trainId": 219}, + {"name": "statue", "id": 2547, "trainId": 220}, + { + "name": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "id": 1474, + "trainId": 221, + }, + {"name": "exhibitor", "id": 877, "trainId": 222}, + {"name": "ladder", "id": 1391, "trainId": 223}, + {"name": "carport", "id": 414, "trainId": 224}, + {"name": "dam", "id": 698, "trainId": 225}, + {"name": "pulpit", "id": 2019, "trainId": 226}, + {"name": "skylight, fanlight", "id": 2422, "trainId": 227}, + {"name": "water tower", "id": 3010, "trainId": 228}, + {"name": "grill, grille, 
grillwork", "id": 1139, "trainId": 229}, + {"name": "display board", "id": 753, "trainId": 230}, + {"name": "pane, pane of glass, window glass", "id": 1747, "trainId": 231}, + {"name": "rubbish, trash, scrap", "id": 2175, "trainId": 232}, + {"name": "ice rink", "id": 1301, "trainId": 233}, + {"name": "fruit", "id": 1033, "trainId": 234}, + {"name": "patio", "id": 1789, "trainId": 235}, + {"name": "vending machine", "id": 2939, "trainId": 236}, + {"name": "telephone, phone, telephone set", "id": 2730, "trainId": 237}, + {"name": "net", "id": 1652, "trainId": 238}, + { + "name": "backpack, back pack, knapsack, packsack, rucksack, haversack", + "id": 90, + "trainId": 239, + }, + {"name": "jar", "id": 1349, "trainId": 240}, + {"name": "track", "id": 2830, "trainId": 241}, + {"name": "magazine", "id": 1485, "trainId": 242}, + {"name": "shutter", "id": 2370, "trainId": 243}, + {"name": "roof", "id": 2155, "trainId": 244}, + {"name": "banner, streamer", "id": 118, "trainId": 245}, + {"name": "landfill", "id": 1402, "trainId": 246}, + {"name": "post", "id": 1957, "trainId": 247}, + {"name": "altarpiece, reredos", "id": 3130, "trainId": 248}, + {"name": "hat, chapeau, lid", "id": 1197, "trainId": 249}, + {"name": "arch, archway", "id": 52, "trainId": 250}, + {"name": "table game", "id": 2688, "trainId": 251}, + {"name": "bag, handbag, pocketbook, purse", "id": 96, "trainId": 252}, + {"name": "document, written document, papers", "id": 762, "trainId": 253}, + {"name": "dome", "id": 772, "trainId": 254}, + {"name": "pier", "id": 1857, "trainId": 255}, + {"name": "shanties", "id": 2315, "trainId": 256}, + {"name": "forecourt", "id": 1016, "trainId": 257}, + {"name": "crane", "id": 643, "trainId": 258}, + {"name": "dog, domestic dog, canis familiaris", "id": 3105, "trainId": 259}, + {"name": "piano, pianoforte, forte-piano", "id": 1849, "trainId": 260}, + {"name": "drawing", "id": 791, "trainId": 261}, + {"name": "cabin", "id": 349, "trainId": 262}, + { + "name": "ad, advertisement, advertizement, advertising, advertizing, advert", + "id": 6, + "trainId": 263, + }, + {"name": "amphitheater, amphitheatre, coliseum", "id": 3114, "trainId": 264}, + {"name": "monument", "id": 1587, "trainId": 265}, + {"name": "henhouse", "id": 1233, "trainId": 266}, + {"name": "cockpit", "id": 559, "trainId": 267}, + {"name": "heater, warmer", "id": 1223, "trainId": 268}, + {"name": "windmill, aerogenerator, wind generator", "id": 3049, "trainId": 269}, + {"name": "pool", "id": 1943, "trainId": 270}, + {"name": "elevator, lift", "id": 853, "trainId": 271}, + {"name": "decoration, ornament, ornamentation", "id": 709, "trainId": 272}, + {"name": "labyrinth", "id": 1390, "trainId": 273}, + {"name": "text, textual matter", "id": 2748, "trainId": 274}, + {"name": "printer", "id": 2007, "trainId": 275}, + {"name": "mezzanine, first balcony", "id": 1546, "trainId": 276}, + {"name": "mattress", "id": 1513, "trainId": 277}, + {"name": "straw", "id": 2600, "trainId": 278}, + {"name": "stalls", "id": 2538, "trainId": 279}, + {"name": "patio, terrace", "id": 1790, "trainId": 280}, + {"name": "billboard, hoarding", "id": 194, "trainId": 281}, + {"name": "bus stop", "id": 326, "trainId": 282}, + {"name": "trouser, pant", "id": 2877, "trainId": 283}, + {"name": "console table, console", "id": 594, "trainId": 284}, + {"name": "rack", "id": 2036, "trainId": 285}, + {"name": "notebook", "id": 1662, "trainId": 286}, + {"name": "shrine", "id": 2366, "trainId": 287}, + {"name": "pantry", "id": 1754, "trainId": 288}, + {"name": "cart", "id": 
418, "trainId": 289}, + {"name": "steam shovel", "id": 2553, "trainId": 290}, + {"name": "porch", "id": 1951, "trainId": 291}, + {"name": "postbox, mailbox, letter box", "id": 1963, "trainId": 292}, + {"name": "figurine, statuette", "id": 918, "trainId": 293}, + {"name": "recycling bin", "id": 2086, "trainId": 294}, + {"name": "folding screen", "id": 997, "trainId": 295}, + {"name": "telescope", "id": 2731, "trainId": 296}, + {"name": "deck chair, beach chair", "id": 704, "trainId": 297}, + {"name": "kennel", "id": 1365, "trainId": 298}, + {"name": "coffee maker", "id": 569, "trainId": 299}, + {"name": "altar, communion table, lord's table", "id": 3108, "trainId": 300}, + {"name": "fish", "id": 948, "trainId": 301}, + {"name": "easel", "id": 839, "trainId": 302}, + {"name": "artificial golf green", "id": 63, "trainId": 303}, + {"name": "iceberg", "id": 1305, "trainId": 304}, + {"name": "candlestick, candle holder", "id": 378, "trainId": 305}, + {"name": "shower stall, shower bath", "id": 2362, "trainId": 306}, + {"name": "television stand", "id": 2734, "trainId": 307}, + { + "name": "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle", + "id": 2982, + "trainId": 308, + }, + {"name": "skeleton", "id": 2398, "trainId": 309}, + {"name": "grand piano, grand", "id": 1119, "trainId": 310}, + {"name": "candy, confect", "id": 382, "trainId": 311}, + {"name": "grille door", "id": 1141, "trainId": 312}, + {"name": "pedestal, plinth, footstall", "id": 1805, "trainId": 313}, + {"name": "jersey, t-shirt, tee shirt", "id": 3102, "trainId": 314}, + {"name": "shoe", "id": 2341, "trainId": 315}, + {"name": "gravestone, headstone, tombstone", "id": 1131, "trainId": 316}, + {"name": "shanty", "id": 2316, "trainId": 317}, + {"name": "structure", "id": 2626, "trainId": 318}, + {"name": "rocking chair, rocker", "id": 3104, "trainId": 319}, + {"name": "bird", "id": 198, "trainId": 320}, + {"name": "place mat", "id": 1896, "trainId": 321}, + {"name": "tomb", "id": 2800, "trainId": 322}, + {"name": "big top", "id": 190, "trainId": 323}, + {"name": "gas pump, gasoline pump, petrol pump, island dispenser", "id": 3131, "trainId": 324}, + {"name": "lockers", "id": 1463, "trainId": 325}, + {"name": "cage", "id": 357, "trainId": 326}, + {"name": "finger", "id": 929, "trainId": 327}, + {"name": "bleachers", "id": 209, "trainId": 328}, + {"name": "ferris wheel", "id": 912, "trainId": 329}, + {"name": "hairdresser chair", "id": 1164, "trainId": 330}, + {"name": "mat", "id": 1509, "trainId": 331}, + {"name": "stands", "id": 2539, "trainId": 332}, + {"name": "aquarium, fish tank, marine museum", "id": 3116, "trainId": 333}, + {"name": "streetcar, tram, tramcar, trolley, trolley car", "id": 2615, "trainId": 334}, + {"name": "napkin, table napkin, serviette", "id": 1644, "trainId": 335}, + {"name": "dummy", "id": 818, "trainId": 336}, + {"name": "booklet, brochure, folder, leaflet, pamphlet", "id": 242, "trainId": 337}, + {"name": "sand trap", "id": 2217, "trainId": 338}, + {"name": "shop, store", "id": 2347, "trainId": 339}, + {"name": "table cloth", "id": 2686, "trainId": 340}, + {"name": "service station", "id": 2300, "trainId": 341}, + {"name": "coffin", "id": 572, "trainId": 342}, + {"name": "drawer", "id": 789, "trainId": 343}, + {"name": "cages", "id": 358, "trainId": 344}, + {"name": "slot machine, coin machine", "id": 2443, "trainId": 345}, + {"name": "balcony", "id": 101, "trainId": 346}, + {"name": "volleyball court", "id": 2969, "trainId": 347}, + {"name": "table tennis", "id": 
2692, "trainId": 348}, + {"name": "control table", "id": 606, "trainId": 349}, + {"name": "shirt", "id": 2339, "trainId": 350}, + {"name": "merchandise, ware, product", "id": 1533, "trainId": 351}, + {"name": "railway", "id": 2060, "trainId": 352}, + {"name": "parterre", "id": 1782, "trainId": 353}, + {"name": "chimney", "id": 495, "trainId": 354}, + {"name": "can, tin, tin can", "id": 371, "trainId": 355}, + {"name": "tanks", "id": 2707, "trainId": 356}, + {"name": "fabric, cloth, material, textile", "id": 889, "trainId": 357}, + {"name": "alga, algae", "id": 3156, "trainId": 358}, + {"name": "system", "id": 2683, "trainId": 359}, + {"name": "map", "id": 1499, "trainId": 360}, + {"name": "greenhouse", "id": 1135, "trainId": 361}, + {"name": "mug", "id": 1619, "trainId": 362}, + {"name": "barbecue", "id": 125, "trainId": 363}, + {"name": "trailer", "id": 2838, "trainId": 364}, + {"name": "toilet tissue, toilet paper, bathroom tissue", "id": 2792, "trainId": 365}, + {"name": "organ", "id": 1695, "trainId": 366}, + {"name": "dishrag, dishcloth", "id": 746, "trainId": 367}, + {"name": "island", "id": 1343, "trainId": 368}, + {"name": "keyboard", "id": 1370, "trainId": 369}, + {"name": "trench", "id": 2858, "trainId": 370}, + {"name": "basket, basketball hoop, hoop", "id": 145, "trainId": 371}, + {"name": "steering wheel, wheel", "id": 2565, "trainId": 372}, + {"name": "pitcher, ewer", "id": 1892, "trainId": 373}, + {"name": "goal", "id": 1103, "trainId": 374}, + {"name": "bread, breadstuff, staff of life", "id": 286, "trainId": 375}, + {"name": "beds", "id": 170, "trainId": 376}, + {"name": "wood", "id": 3073, "trainId": 377}, + {"name": "file cabinet", "id": 922, "trainId": 378}, + {"name": "newspaper, paper", "id": 1655, "trainId": 379}, + {"name": "motorboat", "id": 1602, "trainId": 380}, + {"name": "rope", "id": 2160, "trainId": 381}, + {"name": "guitar", "id": 1151, "trainId": 382}, + {"name": "rubble", "id": 2176, "trainId": 383}, + {"name": "scarf", "id": 2239, "trainId": 384}, + {"name": "barrels", "id": 132, "trainId": 385}, + {"name": "cap", "id": 394, "trainId": 386}, + {"name": "leaves", "id": 1424, "trainId": 387}, + {"name": "control tower", "id": 607, "trainId": 388}, + {"name": "dashboard", "id": 700, "trainId": 389}, + {"name": "bandstand", "id": 116, "trainId": 390}, + {"name": "lectern", "id": 1425, "trainId": 391}, + {"name": "switch, electric switch, electrical switch", "id": 2676, "trainId": 392}, + {"name": "baseboard, mopboard, skirting board", "id": 141, "trainId": 393}, + {"name": "shower room", "id": 2360, "trainId": 394}, + {"name": "smoke", "id": 2449, "trainId": 395}, + {"name": "faucet, spigot", "id": 897, "trainId": 396}, + {"name": "bulldozer", "id": 317, "trainId": 397}, + {"name": "saucepan", "id": 2228, "trainId": 398}, + {"name": "shops", "id": 2351, "trainId": 399}, + {"name": "meter", "id": 1543, "trainId": 400}, + {"name": "crevasse", "id": 656, "trainId": 401}, + {"name": "gear", "id": 1088, "trainId": 402}, + {"name": "candelabrum, candelabra", "id": 373, "trainId": 403}, + {"name": "sofa bed", "id": 2472, "trainId": 404}, + {"name": "tunnel", "id": 2892, "trainId": 405}, + {"name": "pallet", "id": 1740, "trainId": 406}, + {"name": "wire, conducting wire", "id": 3067, "trainId": 407}, + {"name": "kettle, boiler", "id": 1367, "trainId": 408}, + {"name": "bidet", "id": 188, "trainId": 409}, + { + "name": "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher", + "id": 79, + "trainId": 410, + }, + {"name": "music 
stand", "id": 1633, "trainId": 411}, + {"name": "pipe, tube", "id": 1885, "trainId": 412}, + {"name": "cup", "id": 677, "trainId": 413}, + {"name": "parking meter", "id": 1779, "trainId": 414}, + {"name": "ice hockey rink", "id": 1297, "trainId": 415}, + {"name": "shelter", "id": 2334, "trainId": 416}, + {"name": "weeds", "id": 3027, "trainId": 417}, + {"name": "temple", "id": 2735, "trainId": 418}, + {"name": "patty, cake", "id": 1791, "trainId": 419}, + {"name": "ski slope", "id": 2405, "trainId": 420}, + {"name": "panel", "id": 1748, "trainId": 421}, + {"name": "wallet", "id": 2983, "trainId": 422}, + {"name": "wheel", "id": 3035, "trainId": 423}, + {"name": "towel rack, towel horse", "id": 2824, "trainId": 424}, + {"name": "roundabout", "id": 2168, "trainId": 425}, + {"name": "canister, cannister, tin", "id": 385, "trainId": 426}, + {"name": "rod", "id": 2148, "trainId": 427}, + {"name": "soap dispenser", "id": 2465, "trainId": 428}, + {"name": "bell", "id": 175, "trainId": 429}, + {"name": "canvas", "id": 390, "trainId": 430}, + {"name": "box office, ticket office, ticket booth", "id": 268, "trainId": 431}, + {"name": "teacup", "id": 2722, "trainId": 432}, + {"name": "trellis", "id": 2857, "trainId": 433}, + {"name": "workbench", "id": 3088, "trainId": 434}, + {"name": "valley, vale", "id": 2926, "trainId": 435}, + {"name": "toaster", "id": 2782, "trainId": 436}, + {"name": "knife", "id": 1378, "trainId": 437}, + {"name": "podium", "id": 1934, "trainId": 438}, + {"name": "ramp", "id": 2072, "trainId": 439}, + {"name": "tumble dryer", "id": 2889, "trainId": 440}, + {"name": "fireplug, fire hydrant, plug", "id": 944, "trainId": 441}, + {"name": "gym shoe, sneaker, tennis shoe", "id": 1158, "trainId": 442}, + {"name": "lab bench", "id": 1383, "trainId": 443}, + {"name": "equipment", "id": 867, "trainId": 444}, + {"name": "rocky formation", "id": 2145, "trainId": 445}, + {"name": "plastic", "id": 1915, "trainId": 446}, + {"name": "calendar", "id": 361, "trainId": 447}, + {"name": "caravan", "id": 402, "trainId": 448}, + {"name": "check-in-desk", "id": 482, "trainId": 449}, + {"name": "ticket counter", "id": 2761, "trainId": 450}, + {"name": "brush", "id": 300, "trainId": 451}, + {"name": "mill", "id": 1554, "trainId": 452}, + {"name": "covered bridge", "id": 636, "trainId": 453}, + {"name": "bowling alley", "id": 260, "trainId": 454}, + {"name": "hanger", "id": 1186, "trainId": 455}, + {"name": "excavator", "id": 871, "trainId": 456}, + {"name": "trestle", "id": 2859, "trainId": 457}, + {"name": "revolving door", "id": 2103, "trainId": 458}, + {"name": "blast furnace", "id": 208, "trainId": 459}, + {"name": "scale, weighing machine", "id": 2236, "trainId": 460}, + {"name": "projector", "id": 2012, "trainId": 461}, + {"name": "soap", "id": 2462, "trainId": 462}, + {"name": "locker", "id": 1462, "trainId": 463}, + {"name": "tractor", "id": 2832, "trainId": 464}, + {"name": "stretcher", "id": 2617, "trainId": 465}, + {"name": "frame", "id": 1024, "trainId": 466}, + {"name": "grating", "id": 1129, "trainId": 467}, + {"name": "alembic", "id": 18, "trainId": 468}, + {"name": "candle, taper, wax light", "id": 376, "trainId": 469}, + {"name": "barrier", "id": 134, "trainId": 470}, + {"name": "cardboard", "id": 407, "trainId": 471}, + {"name": "cave", "id": 434, "trainId": 472}, + {"name": "puddle", "id": 2017, "trainId": 473}, + {"name": "tarp", "id": 2717, "trainId": 474}, + {"name": "price tag", "id": 2005, "trainId": 475}, + {"name": "watchtower", "id": 2993, "trainId": 476}, + {"name": 
"meters", "id": 1545, "trainId": 477}, + { + "name": "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb", + "id": 1445, + "trainId": 478, + }, + {"name": "tracks", "id": 2831, "trainId": 479}, + {"name": "hair dryer", "id": 1161, "trainId": 480}, + {"name": "skirt", "id": 2411, "trainId": 481}, + {"name": "viaduct", "id": 2949, "trainId": 482}, + {"name": "paper towel", "id": 1769, "trainId": 483}, + {"name": "coat", "id": 552, "trainId": 484}, + {"name": "sheet", "id": 2327, "trainId": 485}, + {"name": "fire extinguisher, extinguisher, asphyxiator", "id": 939, "trainId": 486}, + {"name": "water wheel", "id": 3013, "trainId": 487}, + {"name": "pottery, clayware", "id": 1986, "trainId": 488}, + {"name": "magazine rack", "id": 1486, "trainId": 489}, + {"name": "teapot", "id": 2723, "trainId": 490}, + {"name": "microphone, mike", "id": 1549, "trainId": 491}, + {"name": "support", "id": 2649, "trainId": 492}, + {"name": "forklift", "id": 1020, "trainId": 493}, + {"name": "canyon", "id": 392, "trainId": 494}, + {"name": "cash register, register", "id": 422, "trainId": 495}, + {"name": "leaf, leafage, foliage", "id": 1419, "trainId": 496}, + {"name": "remote control, remote", "id": 2099, "trainId": 497}, + {"name": "soap dish", "id": 2464, "trainId": 498}, + {"name": "windshield, windscreen", "id": 3058, "trainId": 499}, + {"name": "cat", "id": 430, "trainId": 500}, + {"name": "cue, cue stick, pool cue, pool stick", "id": 675, "trainId": 501}, + {"name": "vent, venthole, vent-hole, blowhole", "id": 2941, "trainId": 502}, + {"name": "videos", "id": 2955, "trainId": 503}, + {"name": "shovel", "id": 2355, "trainId": 504}, + {"name": "eaves", "id": 840, "trainId": 505}, + {"name": "antenna, aerial, transmitting aerial", "id": 32, "trainId": 506}, + {"name": "shipyard", "id": 2338, "trainId": 507}, + {"name": "hen, biddy", "id": 1232, "trainId": 508}, + {"name": "traffic cone", "id": 2834, "trainId": 509}, + {"name": "washing machines", "id": 2991, "trainId": 510}, + {"name": "truck crane", "id": 2879, "trainId": 511}, + {"name": "cds", "id": 444, "trainId": 512}, + {"name": "niche", "id": 1657, "trainId": 513}, + {"name": "scoreboard", "id": 2246, "trainId": 514}, + {"name": "briefcase", "id": 296, "trainId": 515}, + {"name": "boot", "id": 245, "trainId": 516}, + {"name": "sweater, jumper", "id": 2661, "trainId": 517}, + {"name": "hay", "id": 1202, "trainId": 518}, + {"name": "pack", "id": 1714, "trainId": 519}, + {"name": "bottle rack", "id": 251, "trainId": 520}, + {"name": "glacier", "id": 1095, "trainId": 521}, + {"name": "pergola", "id": 1828, "trainId": 522}, + {"name": "building materials", "id": 311, "trainId": 523}, + {"name": "television camera", "id": 2732, "trainId": 524}, + {"name": "first floor", "id": 947, "trainId": 525}, + {"name": "rifle", "id": 2115, "trainId": 526}, + {"name": "tennis table", "id": 2738, "trainId": 527}, + {"name": "stadium", "id": 2525, "trainId": 528}, + {"name": "safety belt", "id": 2194, "trainId": 529}, + {"name": "cover", "id": 634, "trainId": 530}, + {"name": "dish rack", "id": 740, "trainId": 531}, + {"name": "synthesizer", "id": 2682, "trainId": 532}, + {"name": "pumpkin", "id": 2020, "trainId": 533}, + {"name": "gutter", "id": 1156, "trainId": 534}, + {"name": "fruit stand", "id": 1036, "trainId": 535}, + {"name": "ice floe, floe", "id": 1295, "trainId": 536}, + {"name": "handle, grip, handgrip, hold", "id": 1181, "trainId": 537}, + {"name": "wheelchair", "id": 3037, "trainId": 538}, + {"name": "mousepad, mouse mat", 
"id": 1614, "trainId": 539}, + {"name": "diploma", "id": 736, "trainId": 540}, + {"name": "fairground ride", "id": 893, "trainId": 541}, + {"name": "radio", "id": 2047, "trainId": 542}, + {"name": "hotplate", "id": 1274, "trainId": 543}, + {"name": "junk", "id": 1361, "trainId": 544}, + {"name": "wheelbarrow", "id": 3036, "trainId": 545}, + {"name": "stream", "id": 2606, "trainId": 546}, + {"name": "toll plaza", "id": 2797, "trainId": 547}, + {"name": "punching bag", "id": 2022, "trainId": 548}, + {"name": "trough", "id": 2876, "trainId": 549}, + {"name": "throne", "id": 2758, "trainId": 550}, + {"name": "chair desk", "id": 472, "trainId": 551}, + {"name": "weighbridge", "id": 3028, "trainId": 552}, + {"name": "extractor fan", "id": 882, "trainId": 553}, + {"name": "hanging clothes", "id": 1189, "trainId": 554}, + {"name": "dish, dish aerial, dish antenna, saucer", "id": 743, "trainId": 555}, + {"name": "alarm clock, alarm", "id": 3122, "trainId": 556}, + {"name": "ski lift", "id": 2401, "trainId": 557}, + {"name": "chain", "id": 468, "trainId": 558}, + {"name": "garage", "id": 1061, "trainId": 559}, + {"name": "mechanical shovel", "id": 1523, "trainId": 560}, + {"name": "wine rack", "id": 3059, "trainId": 561}, + {"name": "tramway", "id": 2843, "trainId": 562}, + {"name": "treadmill", "id": 2853, "trainId": 563}, + {"name": "menu", "id": 1529, "trainId": 564}, + {"name": "block", "id": 214, "trainId": 565}, + {"name": "well", "id": 3032, "trainId": 566}, + {"name": "witness stand", "id": 3071, "trainId": 567}, + {"name": "branch", "id": 277, "trainId": 568}, + {"name": "duck", "id": 813, "trainId": 569}, + {"name": "casserole", "id": 426, "trainId": 570}, + {"name": "frying pan", "id": 1039, "trainId": 571}, + {"name": "desk organizer", "id": 727, "trainId": 572}, + {"name": "mast", "id": 1508, "trainId": 573}, + {"name": "spectacles, specs, eyeglasses, glasses", "id": 2490, "trainId": 574}, + {"name": "service elevator", "id": 2299, "trainId": 575}, + {"name": "dollhouse", "id": 768, "trainId": 576}, + {"name": "hammock", "id": 1172, "trainId": 577}, + {"name": "clothes hanging", "id": 537, "trainId": 578}, + {"name": "photocopier", "id": 1847, "trainId": 579}, + {"name": "notepad", "id": 1664, "trainId": 580}, + {"name": "golf cart", "id": 1110, "trainId": 581}, + {"name": "footpath", "id": 1014, "trainId": 582}, + {"name": "cross", "id": 662, "trainId": 583}, + {"name": "baptismal font", "id": 121, "trainId": 584}, + {"name": "boiler", "id": 227, "trainId": 585}, + {"name": "skip", "id": 2410, "trainId": 586}, + {"name": "rotisserie", "id": 2165, "trainId": 587}, + {"name": "tables", "id": 2696, "trainId": 588}, + {"name": "water mill", "id": 3005, "trainId": 589}, + {"name": "helmet", "id": 1231, "trainId": 590}, + {"name": "cover curtain", "id": 635, "trainId": 591}, + {"name": "brick", "id": 292, "trainId": 592}, + {"name": "table runner", "id": 2690, "trainId": 593}, + {"name": "ashtray", "id": 65, "trainId": 594}, + {"name": "street box", "id": 2607, "trainId": 595}, + {"name": "stick", "id": 2574, "trainId": 596}, + {"name": "hangers", "id": 1188, "trainId": 597}, + {"name": "cells", "id": 456, "trainId": 598}, + {"name": "urinal", "id": 2913, "trainId": 599}, + {"name": "centerpiece", "id": 459, "trainId": 600}, + {"name": "portable fridge", "id": 1955, "trainId": 601}, + {"name": "dvds", "id": 827, "trainId": 602}, + {"name": "golf club", "id": 1111, "trainId": 603}, + {"name": "skirting board", "id": 2412, "trainId": 604}, + {"name": "water cooler", "id": 2997, "trainId": 
605}, + {"name": "clipboard", "id": 528, "trainId": 606}, + {"name": "camera, photographic camera", "id": 366, "trainId": 607}, + {"name": "pigeonhole", "id": 1863, "trainId": 608}, + {"name": "chips", "id": 500, "trainId": 609}, + {"name": "food processor", "id": 1001, "trainId": 610}, + {"name": "post box", "id": 1958, "trainId": 611}, + {"name": "lid", "id": 1441, "trainId": 612}, + {"name": "drum", "id": 809, "trainId": 613}, + {"name": "blender", "id": 210, "trainId": 614}, + {"name": "cave entrance", "id": 435, "trainId": 615}, + {"name": "dental chair", "id": 718, "trainId": 616}, + {"name": "obelisk", "id": 1674, "trainId": 617}, + {"name": "canoe", "id": 388, "trainId": 618}, + {"name": "mobile", "id": 1572, "trainId": 619}, + {"name": "monitors", "id": 1584, "trainId": 620}, + {"name": "pool ball", "id": 1944, "trainId": 621}, + {"name": "cue rack", "id": 674, "trainId": 622}, + {"name": "baggage carts", "id": 99, "trainId": 623}, + {"name": "shore", "id": 2352, "trainId": 624}, + {"name": "fork", "id": 1019, "trainId": 625}, + {"name": "paper filer", "id": 1763, "trainId": 626}, + {"name": "bicycle rack", "id": 185, "trainId": 627}, + {"name": "coat rack", "id": 554, "trainId": 628}, + {"name": "garland", "id": 1066, "trainId": 629}, + {"name": "sports bag", "id": 2508, "trainId": 630}, + {"name": "fish tank", "id": 951, "trainId": 631}, + {"name": "towel dispenser", "id": 2822, "trainId": 632}, + {"name": "carriage", "id": 415, "trainId": 633}, + {"name": "brochure", "id": 297, "trainId": 634}, + {"name": "plaque", "id": 1914, "trainId": 635}, + {"name": "stringer", "id": 2619, "trainId": 636}, + {"name": "iron", "id": 1338, "trainId": 637}, + {"name": "spoon", "id": 2505, "trainId": 638}, + {"name": "flag pole", "id": 955, "trainId": 639}, + {"name": "toilet brush", "id": 2786, "trainId": 640}, + {"name": "book stand", "id": 238, "trainId": 641}, + {"name": "water faucet, water tap, tap, hydrant", "id": 3000, "trainId": 642}, + {"name": "ticket office", "id": 2763, "trainId": 643}, + {"name": "broom", "id": 299, "trainId": 644}, + {"name": "dvd", "id": 822, "trainId": 645}, + {"name": "ice bucket", "id": 1288, "trainId": 646}, + {"name": "carapace, shell, cuticle, shield", "id": 3101, "trainId": 647}, + {"name": "tureen", "id": 2894, "trainId": 648}, + {"name": "folders", "id": 992, "trainId": 649}, + {"name": "chess", "id": 489, "trainId": 650}, + {"name": "root", "id": 2157, "trainId": 651}, + {"name": "sewing machine", "id": 2309, "trainId": 652}, + {"name": "model", "id": 1576, "trainId": 653}, + {"name": "pen", "id": 1810, "trainId": 654}, + {"name": "violin", "id": 2964, "trainId": 655}, + {"name": "sweatshirt", "id": 2662, "trainId": 656}, + {"name": "recycling materials", "id": 2087, "trainId": 657}, + {"name": "mitten", "id": 1569, "trainId": 658}, + {"name": "chopping board, cutting board", "id": 503, "trainId": 659}, + {"name": "mask", "id": 1505, "trainId": 660}, + {"name": "log", "id": 1468, "trainId": 661}, + {"name": "mouse, computer mouse", "id": 1613, "trainId": 662}, + {"name": "grill", "id": 1138, "trainId": 663}, + {"name": "hole", "id": 1256, "trainId": 664}, + {"name": "target", "id": 2715, "trainId": 665}, + {"name": "trash bag", "id": 2846, "trainId": 666}, + {"name": "chalk", "id": 477, "trainId": 667}, + {"name": "sticks", "id": 2576, "trainId": 668}, + {"name": "balloon", "id": 108, "trainId": 669}, + {"name": "score", "id": 2245, "trainId": 670}, + {"name": "hair spray", "id": 1162, "trainId": 671}, + {"name": "roll", "id": 2149, "trainId": 672}, + 
{"name": "runner", "id": 2183, "trainId": 673}, + {"name": "engine", "id": 858, "trainId": 674}, + {"name": "inflatable glove", "id": 1324, "trainId": 675}, + {"name": "games", "id": 1055, "trainId": 676}, + {"name": "pallets", "id": 1741, "trainId": 677}, + {"name": "baskets", "id": 149, "trainId": 678}, + {"name": "coop", "id": 615, "trainId": 679}, + {"name": "dvd player", "id": 825, "trainId": 680}, + {"name": "rocking horse", "id": 2143, "trainId": 681}, + {"name": "buckets", "id": 304, "trainId": 682}, + {"name": "bread rolls", "id": 283, "trainId": 683}, + {"name": "shawl", "id": 2322, "trainId": 684}, + {"name": "watering can", "id": 3017, "trainId": 685}, + {"name": "spotlights", "id": 2510, "trainId": 686}, + {"name": "post-it", "id": 1960, "trainId": 687}, + {"name": "bowls", "id": 265, "trainId": 688}, + {"name": "security camera", "id": 2282, "trainId": 689}, + {"name": "runner cloth", "id": 2184, "trainId": 690}, + {"name": "lock", "id": 1461, "trainId": 691}, + {"name": "alarm, warning device, alarm system", "id": 3113, "trainId": 692}, + {"name": "side", "id": 2372, "trainId": 693}, + {"name": "roulette", "id": 2166, "trainId": 694}, + {"name": "bone", "id": 232, "trainId": 695}, + {"name": "cutlery", "id": 693, "trainId": 696}, + {"name": "pool balls", "id": 1945, "trainId": 697}, + {"name": "wheels", "id": 3039, "trainId": 698}, + {"name": "spice rack", "id": 2494, "trainId": 699}, + {"name": "plant pots", "id": 1908, "trainId": 700}, + {"name": "towel ring", "id": 2827, "trainId": 701}, + {"name": "bread box", "id": 280, "trainId": 702}, + {"name": "video", "id": 2950, "trainId": 703}, + {"name": "funfair", "id": 1044, "trainId": 704}, + {"name": "breads", "id": 288, "trainId": 705}, + {"name": "tripod", "id": 2863, "trainId": 706}, + {"name": "ironing board", "id": 1342, "trainId": 707}, + {"name": "skimmer", "id": 2409, "trainId": 708}, + {"name": "hollow", "id": 1258, "trainId": 709}, + {"name": "scratching post", "id": 2249, "trainId": 710}, + {"name": "tricycle", "id": 2862, "trainId": 711}, + {"name": "file box", "id": 920, "trainId": 712}, + {"name": "mountain pass", "id": 1607, "trainId": 713}, + {"name": "tombstones", "id": 2802, "trainId": 714}, + {"name": "cooker", "id": 610, "trainId": 715}, + {"name": "card game, cards", "id": 3129, "trainId": 716}, + {"name": "golf bag", "id": 1108, "trainId": 717}, + {"name": "towel paper", "id": 2823, "trainId": 718}, + {"name": "chaise lounge", "id": 476, "trainId": 719}, + {"name": "sun", "id": 2641, "trainId": 720}, + {"name": "toilet paper holder", "id": 2788, "trainId": 721}, + {"name": "rake", "id": 2070, "trainId": 722}, + {"name": "key", "id": 1368, "trainId": 723}, + {"name": "umbrella stand", "id": 2903, "trainId": 724}, + {"name": "dartboard", "id": 699, "trainId": 725}, + {"name": "transformer", "id": 2844, "trainId": 726}, + {"name": "fireplace utensils", "id": 942, "trainId": 727}, + {"name": "sweatshirts", "id": 2663, "trainId": 728}, + { + "name": "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "id": 457, + "trainId": 729, + }, + {"name": "tallboy", "id": 2701, "trainId": 730}, + {"name": "stapler", "id": 2540, "trainId": 731}, + {"name": "sauna", "id": 2231, "trainId": 732}, + {"name": "test tube", "id": 2746, "trainId": 733}, + {"name": "palette", "id": 1738, "trainId": 734}, + {"name": "shopping carts", "id": 2350, "trainId": 735}, + {"name": "tools", "id": 2808, "trainId": 736}, + {"name": "push button, push, button", "id": 2025, "trainId": 737}, + {"name": "star", "id": 2541, 
"trainId": 738}, + {"name": "roof rack", "id": 2156, "trainId": 739}, + {"name": "barbed wire", "id": 126, "trainId": 740}, + {"name": "spray", "id": 2512, "trainId": 741}, + {"name": "ear", "id": 831, "trainId": 742}, + {"name": "sponge", "id": 2503, "trainId": 743}, + {"name": "racket", "id": 2039, "trainId": 744}, + {"name": "tins", "id": 2774, "trainId": 745}, + {"name": "eyeglasses", "id": 886, "trainId": 746}, + {"name": "file", "id": 919, "trainId": 747}, + {"name": "scarfs", "id": 2240, "trainId": 748}, + {"name": "sugar bowl", "id": 2636, "trainId": 749}, + {"name": "flip flop", "id": 963, "trainId": 750}, + {"name": "headstones", "id": 1218, "trainId": 751}, + {"name": "laptop bag", "id": 1406, "trainId": 752}, + {"name": "leash", "id": 1420, "trainId": 753}, + {"name": "climbing frame", "id": 526, "trainId": 754}, + {"name": "suit hanger", "id": 2639, "trainId": 755}, + {"name": "floor spotlight", "id": 975, "trainId": 756}, + {"name": "plate rack", "id": 1921, "trainId": 757}, + {"name": "sewer", "id": 2305, "trainId": 758}, + {"name": "hard drive", "id": 1193, "trainId": 759}, + {"name": "sprinkler", "id": 2517, "trainId": 760}, + {"name": "tools box", "id": 2809, "trainId": 761}, + {"name": "necklace", "id": 1647, "trainId": 762}, + {"name": "bulbs", "id": 314, "trainId": 763}, + {"name": "steel industry", "id": 2560, "trainId": 764}, + {"name": "club", "id": 545, "trainId": 765}, + {"name": "jack", "id": 1345, "trainId": 766}, + {"name": "door bars", "id": 775, "trainId": 767}, + { + "name": "control panel, instrument panel, control board, board, panel", + "id": 603, + "trainId": 768, + }, + {"name": "hairbrush", "id": 1163, "trainId": 769}, + {"name": "napkin holder", "id": 1641, "trainId": 770}, + {"name": "office", "id": 1678, "trainId": 771}, + {"name": "smoke detector", "id": 2450, "trainId": 772}, + {"name": "utensils", "id": 2915, "trainId": 773}, + {"name": "apron", "id": 42, "trainId": 774}, + {"name": "scissors", "id": 2242, "trainId": 775}, + {"name": "terminal", "id": 2741, "trainId": 776}, + {"name": "grinder", "id": 1143, "trainId": 777}, + {"name": "entry phone", "id": 862, "trainId": 778}, + {"name": "newspaper stand", "id": 1654, "trainId": 779}, + {"name": "pepper shaker", "id": 1826, "trainId": 780}, + {"name": "onions", "id": 1689, "trainId": 781}, + { + "name": "central processing unit, cpu, c p u , central processor, processor, mainframe", + "id": 3124, + "trainId": 782, + }, + {"name": "tape", "id": 2710, "trainId": 783}, + {"name": "bat", "id": 152, "trainId": 784}, + {"name": "coaster", "id": 549, "trainId": 785}, + {"name": "calculator", "id": 360, "trainId": 786}, + {"name": "potatoes", "id": 1982, "trainId": 787}, + {"name": "luggage rack", "id": 1478, "trainId": 788}, + {"name": "salt", "id": 2203, "trainId": 789}, + {"name": "street number", "id": 2612, "trainId": 790}, + {"name": "viewpoint", "id": 2956, "trainId": 791}, + {"name": "sword", "id": 2681, "trainId": 792}, + {"name": "cd", "id": 437, "trainId": 793}, + {"name": "rowing machine", "id": 2171, "trainId": 794}, + {"name": "plug", "id": 1933, "trainId": 795}, + {"name": "andiron, firedog, dog, dog-iron", "id": 3110, "trainId": 796}, + {"name": "pepper", "id": 1824, "trainId": 797}, + {"name": "tongs", "id": 2803, "trainId": 798}, + {"name": "bonfire", "id": 234, "trainId": 799}, + {"name": "dog dish", "id": 764, "trainId": 800}, + {"name": "belt", "id": 177, "trainId": 801}, + {"name": "dumbbells", "id": 817, "trainId": 802}, + {"name": "videocassette recorder, vcr", "id": 3145, 
"trainId": 803}, + {"name": "hook", "id": 1262, "trainId": 804}, + {"name": "envelopes", "id": 864, "trainId": 805}, + {"name": "shower faucet", "id": 2359, "trainId": 806}, + {"name": "watch", "id": 2992, "trainId": 807}, + {"name": "padlock", "id": 1725, "trainId": 808}, + {"name": "swimming pool ladder", "id": 2667, "trainId": 809}, + {"name": "spanners", "id": 2484, "trainId": 810}, + {"name": "gravy boat", "id": 1133, "trainId": 811}, + {"name": "notice board", "id": 1667, "trainId": 812}, + {"name": "trash bags", "id": 2847, "trainId": 813}, + {"name": "fire alarm", "id": 932, "trainId": 814}, + {"name": "ladle", "id": 1392, "trainId": 815}, + {"name": "stethoscope", "id": 2573, "trainId": 816}, + {"name": "rocket", "id": 2140, "trainId": 817}, + {"name": "funnel", "id": 1046, "trainId": 818}, + {"name": "bowling pins", "id": 264, "trainId": 819}, + {"name": "valve", "id": 2927, "trainId": 820}, + {"name": "thermometer", "id": 2752, "trainId": 821}, + {"name": "cups", "id": 679, "trainId": 822}, + {"name": "spice jar", "id": 2493, "trainId": 823}, + {"name": "night light", "id": 1658, "trainId": 824}, + {"name": "soaps", "id": 2466, "trainId": 825}, + {"name": "games table", "id": 1057, "trainId": 826}, + {"name": "slotted spoon", "id": 2444, "trainId": 827}, + {"name": "reel", "id": 2093, "trainId": 828}, + {"name": "scourer", "id": 2248, "trainId": 829}, + {"name": "sleeping robe", "id": 2432, "trainId": 830}, + {"name": "desk mat", "id": 726, "trainId": 831}, + {"name": "dumbbell", "id": 816, "trainId": 832}, + {"name": "hammer", "id": 1171, "trainId": 833}, + {"name": "tie", "id": 2766, "trainId": 834}, + {"name": "typewriter", "id": 2900, "trainId": 835}, + {"name": "shaker", "id": 2313, "trainId": 836}, + {"name": "cheese dish", "id": 488, "trainId": 837}, + {"name": "sea star", "id": 2265, "trainId": 838}, + {"name": "racquet", "id": 2043, "trainId": 839}, + {"name": "butane gas cylinder", "id": 332, "trainId": 840}, + {"name": "paper weight", "id": 1771, "trainId": 841}, + {"name": "shaving brush", "id": 2320, "trainId": 842}, + {"name": "sunglasses", "id": 2646, "trainId": 843}, + {"name": "gear shift", "id": 1089, "trainId": 844}, + {"name": "towel rail", "id": 2826, "trainId": 845}, + {"name": "adding machine, totalizer, totaliser", "id": 3148, "trainId": 846}, +] + + +def _get_ade20k_full_meta(): + # Id 0 is reserved for ignore_label, we change ignore_label for 0 + # to 255 in our pre-processing, so all ids are shifted by 1. 
+ stuff_ids = [k["id"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES] + assert len(stuff_ids) == 847, len(stuff_ids) + + # For semantic segmentation, this mapping maps from contiguous stuff id + # (in [0, 91], used in models) to ids in the dataset (used for processing results) + stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} + stuff_classes = [k["name"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES] + + ret = { + "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, + "stuff_classes": stuff_classes, + } + return ret + + +def register_all_ade20k_full(root): + root = os.path.join(root, "ADE20K_2021_17_01") + meta = _get_ade20k_full_meta() + for name, dirname in [("train", "training"), ("val", "validation")]: + image_dir = os.path.join(root, "images_detectron2", dirname) + gt_dir = os.path.join(root, "annotations_detectron2", dirname) + name = f"ade20k_full_sem_seg_{name}" + DatasetCatalog.register( + name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="tif", image_ext="jpg") + ) + MetadataCatalog.get(name).set( + stuff_classes=meta["stuff_classes"][:], + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=65535, # NOTE: gt is saved in 16-bit TIFF images + ) + + +_root = os.getenv("DETECTRON2_DATASETS", "datasets") +register_all_ade20k_full(_root) diff --git a/mask_former/data/datasets/register_ade20k_panoptic.py b/mask_former/data/datasets/register_ade20k_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..4ce9888663e69c28d959443d87fadf31a5be3547 --- /dev/null +++ b/mask_former/data/datasets/register_ade20k_panoptic.py @@ -0,0 +1,387 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import json +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.utils.file_io import PathManager + +ADE20K_150_CATEGORIES = [ + {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"}, + {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"}, + {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"}, + {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"}, + {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"}, + {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"}, + {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"}, + {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"}, + {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "}, + {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"}, + {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"}, + {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"}, + {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"}, + {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"}, + {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"}, + {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"}, + {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"}, + {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"}, + {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"}, + {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"}, + {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"}, + {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"}, + {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"}, + {"color": [11, 102, 255], "id": 
23, "isthing": 1, "name": "sofa"}, + {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"}, + {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"}, + {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"}, + {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"}, + {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"}, + {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"}, + {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"}, + {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"}, + {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"}, + {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"}, + {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"}, + {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"}, + {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"}, + {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"}, + {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"}, + {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"}, + {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"}, + {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"}, + {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"}, + {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"}, + { + "color": [6, 51, 255], + "id": 44, + "isthing": 1, + "name": "chest of drawers, chest, bureau, dresser", + }, + {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"}, + {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"}, + {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"}, + {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"}, + {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"}, + {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"}, + {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"}, + {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"}, + {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"}, + {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"}, + {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"}, + { + "color": [255, 71, 0], + "id": 56, + "isthing": 1, + "name": "pool table, billiard table, snooker table", + }, + {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"}, + {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"}, + {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"}, + {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"}, + {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"}, + {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"}, + {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"}, + {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"}, + { + "color": [0, 255, 133], + "id": 65, + "isthing": 1, + "name": "toilet, can, commode, crapper, pot, potty, stool, throne", + }, + {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"}, + {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"}, + {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"}, + {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"}, + {"color": 
[0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"}, + {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"}, + {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"}, + {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"}, + {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"}, + {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"}, + {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"}, + {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"}, + {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"}, + {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"}, + {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"}, + {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"}, + {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"}, + {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"}, + {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"}, + {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"}, + {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"}, + {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"}, + {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"}, + {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"}, + {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"}, + {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"}, + {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"}, + {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"}, + {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"}, + { + "color": [0, 122, 255], + "id": 95, + "isthing": 1, + "name": "bannister, banister, balustrade, balusters, handrail", + }, + { + "color": [0, 255, 163], + "id": 96, + "isthing": 0, + "name": "escalator, moving staircase, moving stairway", + }, + { + "color": [255, 153, 0], + "id": 97, + "isthing": 1, + "name": "ottoman, pouf, pouffe, puff, hassock", + }, + {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"}, + {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"}, + { + "color": [143, 255, 0], + "id": 100, + "isthing": 0, + "name": "poster, posting, placard, notice, bill, card", + }, + {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"}, + {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"}, + {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"}, + {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"}, + { + "color": [133, 0, 255], + "id": 105, + "isthing": 0, + "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter", + }, + {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"}, + { + "color": [184, 0, 255], + "id": 107, + "isthing": 1, + "name": "washer, automatic washer, washing machine", + }, + {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"}, + {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"}, + {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"}, + {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"}, + {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"}, + {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"}, + {"color": [112, 224, 255], "id": 114, 
"isthing": 0, "name": "tent"}, + {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"}, + {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"}, + {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"}, + {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"}, + {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"}, + {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"}, + {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"}, + {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"}, + {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"}, + {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"}, + {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"}, + {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"}, + {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"}, + {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"}, + {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"}, + {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"}, + {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"}, + {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"}, + {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"}, + {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"}, + {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"}, + {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"}, + {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"}, + {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"}, + {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"}, + {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"}, + {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"}, + {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"}, + {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"}, + {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"}, + {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"}, + {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"}, + {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"}, + {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"}, + {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"}, +] + +ADE20k_COLORS = [k["color"] for k in ADE20K_150_CATEGORIES] + +MetadataCatalog.get("ade20k_sem_seg_train").set( + stuff_colors=ADE20k_COLORS[:], +) + +MetadataCatalog.get("ade20k_sem_seg_val").set( + stuff_colors=ADE20k_COLORS[:], +) + + +def load_ade20k_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta): + """ + Args: + image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". + gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". + json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". + Returns: + list[dict]: a list of dicts in Detectron2 standard format. 
(See + `Using Custom Datasets `_ ) + """ + + def _convert_category_id(segment_info, meta): + if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]: + segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][ + segment_info["category_id"] + ] + segment_info["isthing"] = True + else: + segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][ + segment_info["category_id"] + ] + segment_info["isthing"] = False + return segment_info + + with PathManager.open(json_file) as f: + json_info = json.load(f) + + ret = [] + for ann in json_info["annotations"]: + image_id = ann["image_id"] + # TODO: currently we assume image and label has the same filename but + # different extension, and images have extension ".jpg" for COCO. Need + # to make image extension a user-provided argument if we extend this + # function to support other COCO-like datasets. + image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg") + label_file = os.path.join(gt_dir, ann["file_name"]) + sem_label_file = os.path.join(semseg_dir, ann["file_name"]) + segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]] + ret.append( + { + "file_name": image_file, + "image_id": image_id, + "pan_seg_file_name": label_file, + "sem_seg_file_name": sem_label_file, + "segments_info": segments_info, + } + ) + assert len(ret), f"No images found in {image_dir}!" + assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] + assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"] + assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"] + return ret + + +def register_ade20k_panoptic( + name, metadata, image_root, panoptic_root, semantic_root, panoptic_json, instances_json=None +): + """ + Register a "standard" version of ADE20k panoptic segmentation dataset named `name`. + The dictionaries in this registered dataset follows detectron2's standard format. + Hence it's called "standard". + Args: + name (str): the name that identifies a dataset, + e.g. "ade20k_panoptic_train" + metadata (dict): extra metadata associated with this dataset. + image_root (str): directory which contains all the images + panoptic_root (str): directory which contains panoptic annotation images in COCO format + panoptic_json (str): path to the json panoptic annotation file in COCO format + sem_seg_root (none): not used, to be consistent with + `register_coco_panoptic_separated`. 
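To make the loader above concrete: each element of `ret` is one image record in detectron2's standard format, with `category_id` already remapped to a contiguous id by `_convert_category_id`. A rough sketch of one record follows; every concrete value in it is an invented placeholder, not taken from the dataset:

```python
# Rough shape of a single record produced by load_ade20k_panoptic_json.
# All concrete values are invented placeholders for illustration.
example_record = {
    "file_name": "ADEChallengeData2016/images/training/ADE_train_00000001.jpg",
    "image_id": "ADE_train_00000001",
    "pan_seg_file_name": "ADEChallengeData2016/ade20k_panoptic_train/ADE_train_00000001.png",
    "sem_seg_file_name": "ADEChallengeData2016/annotations_detectron2/training/ADE_train_00000001.png",
    "segments_info": [
        {"id": 3, "category_id": 7, "isthing": True},  # e.g. one "bed" segment (id 7 above)
    ],
}
```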
+ instances_json (str): path to the json instance annotation file + """ + panoptic_name = name + DatasetCatalog.register( + panoptic_name, + lambda: load_ade20k_panoptic_json( + panoptic_json, image_root, panoptic_root, semantic_root, metadata + ), + ) + MetadataCatalog.get(panoptic_name).set( + panoptic_root=panoptic_root, + image_root=image_root, + panoptic_json=panoptic_json, + json_file=instances_json, + evaluator_type="ade20k_panoptic_seg", + ignore_label=255, + label_divisor=1000, + **metadata, + ) + + +_PREDEFINED_SPLITS_ADE20K_PANOPTIC = { + "ade20k_panoptic_train": ( + "ADEChallengeData2016/images/training", + "ADEChallengeData2016/ade20k_panoptic_train", + "ADEChallengeData2016/ade20k_panoptic_train.json", + "ADEChallengeData2016/annotations_detectron2/training", + ), + "ade20k_panoptic_val": ( + "ADEChallengeData2016/images/validation", + "ADEChallengeData2016/ade20k_panoptic_val", + "ADEChallengeData2016/ade20k_panoptic_val.json", + "ADEChallengeData2016/annotations_detectron2/validation", + ), +} + + +def get_metadata(): + meta = {} + # The following metadata maps contiguous id from [0, #thing categories + + # #stuff categories) to their names and colors. We have to replica of the + # same name and color under "thing_*" and "stuff_*" because the current + # visualization function in D2 handles thing and class classes differently + # due to some heuristic used in Panoptic FPN. We keep the same naming to + # enable reusing existing visualization functions. + thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES] + thing_colors = [k["color"] for k in ADE20K_150_CATEGORIES] + stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES] + stuff_colors = [k["color"] for k in ADE20K_150_CATEGORIES] + + meta["thing_classes"] = thing_classes + meta["thing_colors"] = thing_colors + meta["stuff_classes"] = stuff_classes + meta["stuff_colors"] = stuff_colors + + # Convert category id for training: + # category id: like semantic segmentation, it is the class id for each + # pixel. Since there are some classes not used in evaluation, the category + # id is not always contiguous and thus we have two set of category ids: + # - original category id: category id in the original dataset, mainly + # used for evaluation. + # - contiguous category id: [0, #classes), in order to train the linear + # softmax classifier. + thing_dataset_id_to_contiguous_id = {} + stuff_dataset_id_to_contiguous_id = {} + + for i, cat in enumerate(ADE20K_150_CATEGORIES): + if cat["isthing"]: + thing_dataset_id_to_contiguous_id[cat["id"]] = i + # else: + # stuff_dataset_id_to_contiguous_id[cat["id"]] = i + + # in order to use sem_seg evaluator + stuff_dataset_id_to_contiguous_id[cat["id"]] = i + + meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id + meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id + + return meta + + +def register_all_ade20k_panoptic(root): + metadata = get_metadata() + for ( + prefix, + (image_root, panoptic_root, panoptic_json, semantic_root), + ) in _PREDEFINED_SPLITS_ADE20K_PANOPTIC.items(): + # The "standard" version of COCO panoptic segmentation dataset, + # e.g. 
used by Panoptic-DeepLab + register_ade20k_panoptic( + prefix, + metadata, + os.path.join(root, image_root), + os.path.join(root, panoptic_root), + os.path.join(root, semantic_root), + os.path.join(root, panoptic_json), + ) + + +_root = os.getenv("DETECTRON2_DATASETS", "datasets") +register_all_ade20k_panoptic(_root) diff --git a/mask_former/data/datasets/register_coco_stuff_10k.py b/mask_former/data/datasets/register_coco_stuff_10k.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ec0375858ada8e4270b534fcd58106254c7fa9 --- /dev/null +++ b/mask_former/data/datasets/register_coco_stuff_10k.py @@ -0,0 +1,223 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import load_sem_seg + +COCO_CATEGORIES = [ + {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, + {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, + {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, + {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, + {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, + {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, + {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, + {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, + {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, + {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, + {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, + {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, + {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, + {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, + {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, + {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, + {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, + {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, + {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, + {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, + {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, + {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, + {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, + {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, + {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, + {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, + {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, + {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, + {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, + {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, + {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, + {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, + {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, + {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, + {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, + {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, + {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, + {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, + {"color": [255, 208, 
186], "isthing": 1, "id": 43, "name": "tennis racket"}, + {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, + {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, + {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, + {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, + {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, + {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, + {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, + {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, + {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, + {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, + {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, + {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, + {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, + {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, + {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, + {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, + {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, + {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, + {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, + {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, + {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, + {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, + {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, + {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, + {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, + {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, + {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, + {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, + {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, + {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, + {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, + {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, + {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, + {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, + {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, + {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, + {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, + {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, + {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, + {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, + {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, + {"id": 92, "name": "banner", "supercategory": "textile"}, + {"id": 93, "name": "blanket", "supercategory": "textile"}, + {"id": 94, "name": "branch", "supercategory": "plant"}, + {"id": 95, "name": "bridge", "supercategory": "building"}, + {"id": 96, "name": "building-other", "supercategory": "building"}, + {"id": 97, "name": "bush", "supercategory": "plant"}, + {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"}, + {"id": 99, "name": "cage", "supercategory": "structural"}, + {"id": 100, "name": "cardboard", "supercategory": "raw-material"}, + {"id": 101, "name": 
"carpet", "supercategory": "floor"}, + {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"}, + {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"}, + {"id": 104, "name": "cloth", "supercategory": "textile"}, + {"id": 105, "name": "clothes", "supercategory": "textile"}, + {"id": 106, "name": "clouds", "supercategory": "sky"}, + {"id": 107, "name": "counter", "supercategory": "furniture-stuff"}, + {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"}, + {"id": 109, "name": "curtain", "supercategory": "textile"}, + {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"}, + {"id": 111, "name": "dirt", "supercategory": "ground"}, + {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"}, + {"id": 113, "name": "fence", "supercategory": "structural"}, + {"id": 114, "name": "floor-marble", "supercategory": "floor"}, + {"id": 115, "name": "floor-other", "supercategory": "floor"}, + {"id": 116, "name": "floor-stone", "supercategory": "floor"}, + {"id": 117, "name": "floor-tile", "supercategory": "floor"}, + {"id": 118, "name": "floor-wood", "supercategory": "floor"}, + {"id": 119, "name": "flower", "supercategory": "plant"}, + {"id": 120, "name": "fog", "supercategory": "water"}, + {"id": 121, "name": "food-other", "supercategory": "food-stuff"}, + {"id": 122, "name": "fruit", "supercategory": "food-stuff"}, + {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"}, + {"id": 124, "name": "grass", "supercategory": "plant"}, + {"id": 125, "name": "gravel", "supercategory": "ground"}, + {"id": 126, "name": "ground-other", "supercategory": "ground"}, + {"id": 127, "name": "hill", "supercategory": "solid"}, + {"id": 128, "name": "house", "supercategory": "building"}, + {"id": 129, "name": "leaves", "supercategory": "plant"}, + {"id": 130, "name": "light", "supercategory": "furniture-stuff"}, + {"id": 131, "name": "mat", "supercategory": "textile"}, + {"id": 132, "name": "metal", "supercategory": "raw-material"}, + {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"}, + {"id": 134, "name": "moss", "supercategory": "plant"}, + {"id": 135, "name": "mountain", "supercategory": "solid"}, + {"id": 136, "name": "mud", "supercategory": "ground"}, + {"id": 137, "name": "napkin", "supercategory": "textile"}, + {"id": 138, "name": "net", "supercategory": "structural"}, + {"id": 139, "name": "paper", "supercategory": "raw-material"}, + {"id": 140, "name": "pavement", "supercategory": "ground"}, + {"id": 141, "name": "pillow", "supercategory": "textile"}, + {"id": 142, "name": "plant-other", "supercategory": "plant"}, + {"id": 143, "name": "plastic", "supercategory": "raw-material"}, + {"id": 144, "name": "platform", "supercategory": "ground"}, + {"id": 145, "name": "playingfield", "supercategory": "ground"}, + {"id": 146, "name": "railing", "supercategory": "structural"}, + {"id": 147, "name": "railroad", "supercategory": "ground"}, + {"id": 148, "name": "river", "supercategory": "water"}, + {"id": 149, "name": "road", "supercategory": "ground"}, + {"id": 150, "name": "rock", "supercategory": "solid"}, + {"id": 151, "name": "roof", "supercategory": "building"}, + {"id": 152, "name": "rug", "supercategory": "textile"}, + {"id": 153, "name": "salad", "supercategory": "food-stuff"}, + {"id": 154, "name": "sand", "supercategory": "ground"}, + {"id": 155, "name": "sea", "supercategory": "water"}, + {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"}, + {"id": 157, "name": "sky-other", "supercategory": 
"sky"}, + {"id": 158, "name": "skyscraper", "supercategory": "building"}, + {"id": 159, "name": "snow", "supercategory": "ground"}, + {"id": 160, "name": "solid-other", "supercategory": "solid"}, + {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"}, + {"id": 162, "name": "stone", "supercategory": "solid"}, + {"id": 163, "name": "straw", "supercategory": "plant"}, + {"id": 164, "name": "structural-other", "supercategory": "structural"}, + {"id": 165, "name": "table", "supercategory": "furniture-stuff"}, + {"id": 166, "name": "tent", "supercategory": "building"}, + {"id": 167, "name": "textile-other", "supercategory": "textile"}, + {"id": 168, "name": "towel", "supercategory": "textile"}, + {"id": 169, "name": "tree", "supercategory": "plant"}, + {"id": 170, "name": "vegetable", "supercategory": "food-stuff"}, + {"id": 171, "name": "wall-brick", "supercategory": "wall"}, + {"id": 172, "name": "wall-concrete", "supercategory": "wall"}, + {"id": 173, "name": "wall-other", "supercategory": "wall"}, + {"id": 174, "name": "wall-panel", "supercategory": "wall"}, + {"id": 175, "name": "wall-stone", "supercategory": "wall"}, + {"id": 176, "name": "wall-tile", "supercategory": "wall"}, + {"id": 177, "name": "wall-wood", "supercategory": "wall"}, + {"id": 178, "name": "water-other", "supercategory": "water"}, + {"id": 179, "name": "waterdrops", "supercategory": "water"}, + {"id": 180, "name": "window-blind", "supercategory": "window"}, + {"id": 181, "name": "window-other", "supercategory": "window"}, + {"id": 182, "name": "wood", "supercategory": "solid"}, +] + + +def _get_coco_stuff_meta(): + # Id 0 is reserved for ignore_label, we change ignore_label for 0 + # to 255 in our pre-processing. + stuff_ids = [k["id"] for k in COCO_CATEGORIES] + assert len(stuff_ids) == 171, len(stuff_ids) + + # For semantic segmentation, this mapping maps from contiguous stuff id + # (in [0, 91], used in models) to ids in the dataset (used for processing results) + stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} + stuff_classes = [k["name"] for k in COCO_CATEGORIES] + + ret = { + "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, + "stuff_classes": stuff_classes, + } + return ret + + +def register_all_coco_stuff_10k(root): + root = os.path.join(root, "coco", "coco_stuff_10k") + meta = _get_coco_stuff_meta() + for name, image_dirname, sem_seg_dirname in [ + ("train", "images_detectron2/train", "annotations_detectron2/train"), + ("test", "images_detectron2/test", "annotations_detectron2/test"), + ]: + image_dir = os.path.join(root, image_dirname) + gt_dir = os.path.join(root, sem_seg_dirname) + name = f"coco_2017_{name}_stuff_10k_sem_seg" + DatasetCatalog.register( + name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg") + ) + MetadataCatalog.get(name).set( + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=255, + **meta, + ) + + +_root = os.getenv("DETECTRON2_DATASETS", "datasets") +register_all_coco_stuff_10k(_root) diff --git a/mask_former/data/datasets/register_mapillary_vistas.py b/mask_former/data/datasets/register_mapillary_vistas.py new file mode 100644 index 0000000000000000000000000000000000000000..ce3874b65d943c333d093abd6998500f8a3775f5 --- /dev/null +++ b/mask_former/data/datasets/register_mapillary_vistas.py @@ -0,0 +1,507 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import load_sem_seg + +MAPILLARY_VISTAS_SEM_SEG_CATEGORIES = [ + { + "color": [165, 42, 42], + "instances": True, + "readable": "Bird", + "name": "animal--bird", + "evaluate": True, + }, + { + "color": [0, 192, 0], + "instances": True, + "readable": "Ground Animal", + "name": "animal--ground-animal", + "evaluate": True, + }, + { + "color": [196, 196, 196], + "instances": False, + "readable": "Curb", + "name": "construction--barrier--curb", + "evaluate": True, + }, + { + "color": [190, 153, 153], + "instances": False, + "readable": "Fence", + "name": "construction--barrier--fence", + "evaluate": True, + }, + { + "color": [180, 165, 180], + "instances": False, + "readable": "Guard Rail", + "name": "construction--barrier--guard-rail", + "evaluate": True, + }, + { + "color": [90, 120, 150], + "instances": False, + "readable": "Barrier", + "name": "construction--barrier--other-barrier", + "evaluate": True, + }, + { + "color": [102, 102, 156], + "instances": False, + "readable": "Wall", + "name": "construction--barrier--wall", + "evaluate": True, + }, + { + "color": [128, 64, 255], + "instances": False, + "readable": "Bike Lane", + "name": "construction--flat--bike-lane", + "evaluate": True, + }, + { + "color": [140, 140, 200], + "instances": True, + "readable": "Crosswalk - Plain", + "name": "construction--flat--crosswalk-plain", + "evaluate": True, + }, + { + "color": [170, 170, 170], + "instances": False, + "readable": "Curb Cut", + "name": "construction--flat--curb-cut", + "evaluate": True, + }, + { + "color": [250, 170, 160], + "instances": False, + "readable": "Parking", + "name": "construction--flat--parking", + "evaluate": True, + }, + { + "color": [96, 96, 96], + "instances": False, + "readable": "Pedestrian Area", + "name": "construction--flat--pedestrian-area", + "evaluate": True, + }, + { + "color": [230, 150, 140], + "instances": False, + "readable": "Rail Track", + "name": "construction--flat--rail-track", + "evaluate": True, + }, + { + "color": [128, 64, 128], + "instances": False, + "readable": "Road", + "name": "construction--flat--road", + "evaluate": True, + }, + { + "color": [110, 110, 110], + "instances": False, + "readable": "Service Lane", + "name": "construction--flat--service-lane", + "evaluate": True, + }, + { + "color": [244, 35, 232], + "instances": False, + "readable": "Sidewalk", + "name": "construction--flat--sidewalk", + "evaluate": True, + }, + { + "color": [150, 100, 100], + "instances": False, + "readable": "Bridge", + "name": "construction--structure--bridge", + "evaluate": True, + }, + { + "color": [70, 70, 70], + "instances": False, + "readable": "Building", + "name": "construction--structure--building", + "evaluate": True, + }, + { + "color": [150, 120, 90], + "instances": False, + "readable": "Tunnel", + "name": "construction--structure--tunnel", + "evaluate": True, + }, + { + "color": [220, 20, 60], + "instances": True, + "readable": "Person", + "name": "human--person", + "evaluate": True, + }, + { + "color": [255, 0, 0], + "instances": True, + "readable": "Bicyclist", + "name": "human--rider--bicyclist", + "evaluate": True, + }, + { + "color": [255, 0, 100], + "instances": True, + "readable": "Motorcyclist", + "name": "human--rider--motorcyclist", + "evaluate": True, + }, + { + "color": [255, 0, 200], + "instances": True, + "readable": "Other Rider", + "name": "human--rider--other-rider", + "evaluate": True, + }, + { + "color": [200, 128, 128], + 
"instances": True, + "readable": "Lane Marking - Crosswalk", + "name": "marking--crosswalk-zebra", + "evaluate": True, + }, + { + "color": [255, 255, 255], + "instances": False, + "readable": "Lane Marking - General", + "name": "marking--general", + "evaluate": True, + }, + { + "color": [64, 170, 64], + "instances": False, + "readable": "Mountain", + "name": "nature--mountain", + "evaluate": True, + }, + { + "color": [230, 160, 50], + "instances": False, + "readable": "Sand", + "name": "nature--sand", + "evaluate": True, + }, + { + "color": [70, 130, 180], + "instances": False, + "readable": "Sky", + "name": "nature--sky", + "evaluate": True, + }, + { + "color": [190, 255, 255], + "instances": False, + "readable": "Snow", + "name": "nature--snow", + "evaluate": True, + }, + { + "color": [152, 251, 152], + "instances": False, + "readable": "Terrain", + "name": "nature--terrain", + "evaluate": True, + }, + { + "color": [107, 142, 35], + "instances": False, + "readable": "Vegetation", + "name": "nature--vegetation", + "evaluate": True, + }, + { + "color": [0, 170, 30], + "instances": False, + "readable": "Water", + "name": "nature--water", + "evaluate": True, + }, + { + "color": [255, 255, 128], + "instances": True, + "readable": "Banner", + "name": "object--banner", + "evaluate": True, + }, + { + "color": [250, 0, 30], + "instances": True, + "readable": "Bench", + "name": "object--bench", + "evaluate": True, + }, + { + "color": [100, 140, 180], + "instances": True, + "readable": "Bike Rack", + "name": "object--bike-rack", + "evaluate": True, + }, + { + "color": [220, 220, 220], + "instances": True, + "readable": "Billboard", + "name": "object--billboard", + "evaluate": True, + }, + { + "color": [220, 128, 128], + "instances": True, + "readable": "Catch Basin", + "name": "object--catch-basin", + "evaluate": True, + }, + { + "color": [222, 40, 40], + "instances": True, + "readable": "CCTV Camera", + "name": "object--cctv-camera", + "evaluate": True, + }, + { + "color": [100, 170, 30], + "instances": True, + "readable": "Fire Hydrant", + "name": "object--fire-hydrant", + "evaluate": True, + }, + { + "color": [40, 40, 40], + "instances": True, + "readable": "Junction Box", + "name": "object--junction-box", + "evaluate": True, + }, + { + "color": [33, 33, 33], + "instances": True, + "readable": "Mailbox", + "name": "object--mailbox", + "evaluate": True, + }, + { + "color": [100, 128, 160], + "instances": True, + "readable": "Manhole", + "name": "object--manhole", + "evaluate": True, + }, + { + "color": [142, 0, 0], + "instances": True, + "readable": "Phone Booth", + "name": "object--phone-booth", + "evaluate": True, + }, + { + "color": [70, 100, 150], + "instances": False, + "readable": "Pothole", + "name": "object--pothole", + "evaluate": True, + }, + { + "color": [210, 170, 100], + "instances": True, + "readable": "Street Light", + "name": "object--street-light", + "evaluate": True, + }, + { + "color": [153, 153, 153], + "instances": True, + "readable": "Pole", + "name": "object--support--pole", + "evaluate": True, + }, + { + "color": [128, 128, 128], + "instances": True, + "readable": "Traffic Sign Frame", + "name": "object--support--traffic-sign-frame", + "evaluate": True, + }, + { + "color": [0, 0, 80], + "instances": True, + "readable": "Utility Pole", + "name": "object--support--utility-pole", + "evaluate": True, + }, + { + "color": [250, 170, 30], + "instances": True, + "readable": "Traffic Light", + "name": "object--traffic-light", + "evaluate": True, + }, + { + "color": [192, 192, 192], 
+ "instances": True, + "readable": "Traffic Sign (Back)", + "name": "object--traffic-sign--back", + "evaluate": True, + }, + { + "color": [220, 220, 0], + "instances": True, + "readable": "Traffic Sign (Front)", + "name": "object--traffic-sign--front", + "evaluate": True, + }, + { + "color": [140, 140, 20], + "instances": True, + "readable": "Trash Can", + "name": "object--trash-can", + "evaluate": True, + }, + { + "color": [119, 11, 32], + "instances": True, + "readable": "Bicycle", + "name": "object--vehicle--bicycle", + "evaluate": True, + }, + { + "color": [150, 0, 255], + "instances": True, + "readable": "Boat", + "name": "object--vehicle--boat", + "evaluate": True, + }, + { + "color": [0, 60, 100], + "instances": True, + "readable": "Bus", + "name": "object--vehicle--bus", + "evaluate": True, + }, + { + "color": [0, 0, 142], + "instances": True, + "readable": "Car", + "name": "object--vehicle--car", + "evaluate": True, + }, + { + "color": [0, 0, 90], + "instances": True, + "readable": "Caravan", + "name": "object--vehicle--caravan", + "evaluate": True, + }, + { + "color": [0, 0, 230], + "instances": True, + "readable": "Motorcycle", + "name": "object--vehicle--motorcycle", + "evaluate": True, + }, + { + "color": [0, 80, 100], + "instances": False, + "readable": "On Rails", + "name": "object--vehicle--on-rails", + "evaluate": True, + }, + { + "color": [128, 64, 64], + "instances": True, + "readable": "Other Vehicle", + "name": "object--vehicle--other-vehicle", + "evaluate": True, + }, + { + "color": [0, 0, 110], + "instances": True, + "readable": "Trailer", + "name": "object--vehicle--trailer", + "evaluate": True, + }, + { + "color": [0, 0, 70], + "instances": True, + "readable": "Truck", + "name": "object--vehicle--truck", + "evaluate": True, + }, + { + "color": [0, 0, 192], + "instances": True, + "readable": "Wheeled Slow", + "name": "object--vehicle--wheeled-slow", + "evaluate": True, + }, + { + "color": [32, 32, 32], + "instances": False, + "readable": "Car Mount", + "name": "void--car-mount", + "evaluate": True, + }, + { + "color": [120, 10, 10], + "instances": False, + "readable": "Ego Vehicle", + "name": "void--ego-vehicle", + "evaluate": True, + }, + { + "color": [0, 0, 0], + "instances": False, + "readable": "Unlabeled", + "name": "void--unlabeled", + "evaluate": False, + }, +] + + +def _get_mapillary_vistas_meta(): + stuff_classes = [k["readable"] for k in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES if k["evaluate"]] + assert len(stuff_classes) == 65 + + stuff_colors = [k["color"] for k in MAPILLARY_VISTAS_SEM_SEG_CATEGORIES if k["evaluate"]] + assert len(stuff_colors) == 65 + + ret = { + "stuff_classes": stuff_classes, + "stuff_colors": stuff_colors, + } + return ret + + +def register_all_mapillary_vistas(root): + root = os.path.join(root, "mapillary_vistas") + meta = _get_mapillary_vistas_meta() + for name, dirname in [("train", "training"), ("val", "validation")]: + image_dir = os.path.join(root, dirname, "images") + gt_dir = os.path.join(root, dirname, "labels") + name = f"mapillary_vistas_sem_seg_{name}" + DatasetCatalog.register( + name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg") + ) + MetadataCatalog.get(name).set( + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=65, # different from other datasets, Mapillary Vistas sets ignore_label to 65 + **meta, + ) + + +_root = os.getenv("DETECTRON2_DATASETS", "datasets") +register_all_mapillary_vistas(_root) diff --git a/mask_former/mask_former_model.py 
b/mask_former/mask_former_model.py new file mode 100644 index 0000000000000000000000000000000000000000..965e1ff61121c5124b9f8725f2f2492d076c40ef --- /dev/null +++ b/mask_former/mask_former_model.py @@ -0,0 +1,355 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import Tuple + +import torch +from detectron2.config import configurable +from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head +from detectron2.modeling.backbone import Backbone +from detectron2.structures import ImageList +from torch import nn +from torch.nn import functional as F +from torchvision.transforms import functional as Ftv + +from utils.log import getLogger +from .modeling.criterion import SetCriterion +from .modeling.matcher import HungarianMatcher + +logger = getLogger(__name__) + + +def interpolate_or_crop(img, + size=(128, 128), + mode="bilinear", + align_corners=False, + tol=1.1): + h, w = img.shape[-2:] + H, W = size + if h == H and w == W: + return img + if H <= h < tol * H and W <= w < tol * W: + logger.info_once(f"Using center cropping instead of interpolation") + return Ftv.center_crop(img, output_size=size) + return F.interpolate(img, size=size, mode=mode, align_corners=align_corners) + + +@META_ARCH_REGISTRY.register() +class MaskFormer(nn.Module): + """ + Main class for mask classification semantic segmentation architectures. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + sem_seg_head: nn.Module, + criterion: nn.Module, + num_queries: int, + panoptic_on: bool, + object_mask_threshold: float, + overlap_threshold: float, + metadata, + size_divisibility: int, + sem_seg_postprocess_before_inference: bool, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + crop_not_upsample: bool=True + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + sem_seg_head: a module that predicts semantic segmentation from backbone features + criterion: a module that defines the loss + num_queries: int, number of queries + panoptic_on: bool, whether to output panoptic segmentation prediction + object_mask_threshold: float, threshold to filter query based on classification score + for panoptic segmentation inference + overlap_threshold: overlap threshold used in general inference for panoptic segmentation + metadata: dataset meta, get `thing` and `stuff` category names for panoptic + segmentation inference + size_divisibility: Some backbones require the input height and width to be divisible by a + specific integer. We can use this to override such requirement. + sem_seg_postprocess_before_inference: whether to resize the prediction back + to original input size before semantic segmentation inference or after. + For high-resolution dataset like Mapillary, resizing predictions before + inference will cause OOM error. 
+ pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__() + self.crop_not_upsample = crop_not_upsample + self.backbone = backbone + self.sem_seg_head = sem_seg_head + self.criterion = criterion + self.num_queries = num_queries + self.overlap_threshold = overlap_threshold + self.panoptic_on = panoptic_on + self.object_mask_threshold = object_mask_threshold + self.metadata = metadata + if size_divisibility < 0: + # use backbone size_divisibility if not set + size_divisibility = self.backbone.size_divisibility + self.size_divisibility = size_divisibility + self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + out_shape = backbone.output_shape() + if len(cfg.GWM.SAMPLE_KEYS) > 1: + for k, v in out_shape.items(): + out_shape[k] = v._replace(channels=v.channels * len(cfg.GWM.SAMPLE_KEYS)) + sem_seg_head = build_sem_seg_head(cfg, out_shape) + + # Loss parameters: + deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT + dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT + mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT + + # building criterion + matcher = HungarianMatcher( + cost_class=1, + cost_mask=mask_weight, + cost_dice=dice_weight, + ) + + weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight} + if deep_supervision: + dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS + aux_weight_dict = {} + for i in range(dec_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ["labels", "masks"] + + criterion = SetCriterion( + sem_seg_head.num_classes, + matcher=matcher, + weight_dict=weight_dict, + eos_coef=no_object_weight, + losses=losses, + ) + + return { + "backbone": backbone, + "sem_seg_head": sem_seg_head, + "criterion": criterion, + "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES, + "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON, + "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD, + "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD, + "metadata": None, # MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), + "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, + "sem_seg_postprocess_before_inference": ( + cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE + or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON + ), + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + 'crop_not_upsample': cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME != 'BasePixelDecoder' + } + + @property + def device(self): + return self.pixel_mean.device + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + * "image": Tensor, image in (C, H, W) format. + * "instances": per-region ground truth + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model (may be different + from input resolution), used in inference. 
+ Returns: + list[dict]: + each dict has the results for one image. The dict contains the following keys: + + * "sem_seg": + A Tensor that represents the + per-pixel segmentation prediced by the head. + The prediction has shape KxHxW that represents the logits of + each class for each pixel. + * "panoptic_seg": + A tuple that represent panoptic output + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. + segments_info (list[dict]): Describe each segment in `panoptic_seg`. + Each dict contains keys "id", "category_id", "isthing". + """ + return self.forward_base(batched_inputs, keys=["image"], get_train=not self.training, + get_eval=not self.training) + + def forward_base(self, batched_inputs, keys, get_train=False, get_eval=False, raw_sem_seg=False): + for i, key in enumerate(keys): + images = [x[key].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + logger.debug_once(f"Maskformer input {key} shape: {images.tensor.shape}") + out = self.backbone(images.tensor) + if i == 0: + features = out + else: + features = {k: torch.cat([features[k], v], 1) for k, v in out.items()} + outputs = self.sem_seg_head(features) + + if get_train: + # mask classification target + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + targets = self.prepare_targets(gt_instances, images) + else: + targets = None + + # bipartite matching-based loss + losses = self.criterion(outputs, targets) + + for k in list(losses.keys()): + if k in self.criterion.weight_dict: + losses[k] *= self.criterion.weight_dict[k] + else: + # remove this loss if not specified in `weight_dict` + losses.pop(k) + if not get_eval: + return losses + + if get_eval: + # mask_cls_results = outputs["pred_logits"] + mask_pred_results = outputs["pred_masks"] + mask_cls_results = mask_pred_results + logger.debug_once(f"Maskformer mask_pred_results shape: {mask_pred_results.shape}") + # upsample masks + # mask_pred_results = interpolate_or_crop( + # mask_pred_results, + # size=(images.tensor.shape[-2], images.tensor.shape[-1]), + # mode="bilinear", + # align_corners=False, + # ) + + processed_results = [] + for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( + mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes + ): + + if raw_sem_seg: + processed_results.append({"sem_seg": mask_pred_result}) + continue + + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + logger.debug_once(f"Maskformer mask_pred_results target HW: {height, width}") + r = interpolate_or_crop(mask_pred_result[None], size=(height, width), mode="bilinear", align_corners=False)[0] + + processed_results.append({"sem_seg": r}) + + # panoptic segmentation inference + # if self.panoptic_on: + # panoptic_r = self.panoptic_inference(mask_cls_result, mask_pred_result) + # processed_results[-1]["panoptic_seg"] = panoptic_r + + # if 'features' in outputs: + # features = outputs['features'] + # features = interpolate_or_crop( + # features, + # size=(images.tensor.shape[-2], images.tensor.shape[-1]), + # mode="bilinear", + # align_corners=False, + # ) + # for res, f in zip(processed_results, features): + # res['features'] = f + del outputs + + if not get_train: + return processed_results + + return losses, processed_results + + + def prepare_targets(self, targets, images): + h, 
w = images.tensor.shape[-2:] + new_targets = [] + for targets_per_image in targets: + # pad gt + gt_masks = targets_per_image.gt_masks + padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device) + padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks + new_targets.append( + { + "labels": targets_per_image.gt_classes, + "masks": padded_masks, + } + ) + return new_targets + + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) + return semseg + + + def panoptic_inference(self, mask_cls, mask_pred): + scores, labels = F.softmax(mask_cls, dim=-1).max(-1) + mask_pred = mask_pred.sigmoid() + + keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold) + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_masks = mask_pred[keep] + cur_mask_cls = mask_cls[keep] + cur_mask_cls = cur_mask_cls[:, :-1] + + cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks + + h, w = cur_masks.shape[-2:] + panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) + segments_info = [] + + current_segment_id = 0 + + if cur_masks.shape[0] == 0: + # We didn't detect any mask :( + return panoptic_seg, segments_info + else: + # take argmax + cur_mask_ids = cur_prob_masks.argmax(0) + stuff_memory_list = {} + for k in range(cur_classes.shape[0]): + pred_class = cur_classes[k].item() + isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values() + mask = cur_mask_ids == k + mask_area = mask.sum().item() + original_area = (cur_masks[k] >= 0.5).sum().item() + + if mask_area > 0 and original_area > 0: + if mask_area / original_area < self.overlap_threshold: + continue + + # merge stuff regions + if not isthing: + if int(pred_class) in stuff_memory_list.keys(): + panoptic_seg[mask] = stuff_memory_list[int(pred_class)] + continue + else: + stuff_memory_list[int(pred_class)] = current_segment_id + 1 + + current_segment_id += 1 + panoptic_seg[mask] = current_segment_id + + segments_info.append( + { + "id": current_segment_id, + "isthing": bool(isthing), + "category_id": int(pred_class), + } + ) + + return panoptic_seg, segments_info diff --git a/mask_former/modeling/__init__.py b/mask_former/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfafc0bea5858ddb8de852e0f938282ee970bd3 --- /dev/null +++ b/mask_former/modeling/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .backbone.swin import D2SwinTransformer +from .backbone.vit import D2ViTTransformer +from .heads.mask_former_head import MaskFormerHead +from .heads.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead +from .heads.pixel_decoder import BasePixelDecoder +from .heads.big_pixel_decoder import BigPixelDecoder +from .heads.mega_big_pixel_decoder import MegaBigPixelDecoder +from .heads.mask_former_head_baseline import MaskFormerBaselineHead diff --git a/mask_former/modeling/backbone/__init__.py b/mask_former/modeling/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8 --- /dev/null +++ b/mask_former/modeling/backbone/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
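For reference, the mask-classification readout in `MaskFormer.semantic_inference` above boils down to a softmax over per-query class logits (dropping the trailing "no object" channel), a sigmoid over per-query mask logits, and one einsum that mixes them into per-pixel class scores. A self-contained toy version with made-up shapes, shown only to make the tensor contraction concrete:

import torch
import torch.nn.functional as F

Q, C, H, W = 4, 3, 8, 8                      # queries, real classes, height, width (illustrative)
mask_cls = torch.randn(Q, C + 1)             # per-query class logits; last channel is "no object"
mask_pred = torch.randn(Q, H, W)             # per-query mask logits

cls_prob = F.softmax(mask_cls, dim=-1)[..., :-1]           # (Q, C): drop the "no object" channel
mask_prob = mask_pred.sigmoid()                            # (Q, H, W)
semseg = torch.einsum("qc,qhw->chw", cls_prob, mask_prob)  # (C, H, W) per-pixel class scores
print(semseg.shape, semseg.argmax(0).shape)                # torch.Size([3, 8, 8]) torch.Size([8, 8])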
diff --git a/mask_former/modeling/backbone/swin.py b/mask_former/modeling/backbone/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..a616fec3273a49b96faeba007d71ae7776bec202 --- /dev/null +++ b/mask_former/modeling/backbone/swin.py @@ -0,0 +1,772 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py +import logging + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec +logger = logging.getLogger('gwm') + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. 
Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
+ """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. 
+ attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1], + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + + self._freeze_stages() + logger.info(f"Freezing {self.frozen_stages} Layers") + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + 
"""Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = {} + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs["res{}".format(i + 2)] = out + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +@BACKBONE_REGISTRY.register() +class D2SwinTransformer(SwinTransformer, Backbone): + def __init__(self, cfg, input_shape): + + pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE + patch_size = cfg.MODEL.SWIN.PATCH_SIZE + in_chans = 3 + embed_dim = cfg.MODEL.SWIN.EMBED_DIM + depths = cfg.MODEL.SWIN.DEPTHS + num_heads = cfg.MODEL.SWIN.NUM_HEADS + window_size = cfg.MODEL.SWIN.WINDOW_SIZE + mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO + qkv_bias = cfg.MODEL.SWIN.QKV_BIAS + qk_scale = cfg.MODEL.SWIN.QK_SCALE + drop_rate = cfg.MODEL.SWIN.DROP_RATE + attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE + drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE + norm_layer = nn.LayerNorm + ape = cfg.MODEL.SWIN.APE + patch_norm = cfg.MODEL.SWIN.PATCH_NORM + frozen_stages = cfg.MODEL.BACKBONE.FREEZE_AT + + super().__init__( + pretrain_img_size, + patch_size, + in_chans, + embed_dim, + depths, + num_heads, + window_size, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_layer, + ape, + patch_norm, + frozen_stages=frozen_stages, + ) + + self._out_features = cfg.MODEL.SWIN.OUT_FEATURES + + self._out_feature_strides = { + "res2": 4, + "res3": 8, + "res4": 16, + "res5": 32, + } + self._out_feature_channels = { + "res2": self.num_features[0], + "res3": self.num_features[1], + "res4": self.num_features[2], + "res5": self.num_features[3], + } + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert ( + x.dim() == 4 + ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" 
+ outputs = {} + y = super().forward(x) + for k in y.keys(): + if k in self._out_features: + outputs[k] = y[k] + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + @property + def size_divisibility(self): + return 32 diff --git a/mask_former/modeling/backbone/vit.py b/mask_former/modeling/backbone/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f9919d03b507785c056aec92c4ea9304d837b071 --- /dev/null +++ b/mask_former/modeling/backbone/vit.py @@ -0,0 +1,441 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py +import logging + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec +logger = logging.getLogger('gwm') +import argparse +import torch +import torchvision.transforms +from torch import nn +from torchvision import transforms +import torch.nn.modules.utils as nn_utils +import math +import timm +import types +from pathlib import Path +from typing import Union, List, Tuple +from PIL import Image +import einops + +class ViTExtractor: + """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT. + + We use the following notation in the documentation of the module's methods: + B - batch size + h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW + p - patch size of the ViT. either 8 or 16. + t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width + of the input image. + d - the embedding dimension in the ViT. + """ + + def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'): + """ + :param model_type: A string specifying the type of model to extract from. + [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | + vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224] + :param stride: stride of first convolution layer. small stride -> higher resolution. + :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor. + should be compatible with model_type. 
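+        Example (illustrative sketch only; the image path, load size and device are
+        placeholders, and the DINO weights are fetched via torch.hub on first use):
+            extractor = ViTExtractor(model_type='dino_vitb8', stride=4, device='cpu')
+            image_batch, image_pil = extractor.preprocess('samples/example.jpg', load_size=224)
+            descriptors = extractor.extract_descriptors(image_batch, layer=11, facet='key')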
+ """ + self.model_type = model_type + self.device = device + if model is not None: + self.model = model + else: + self.model = ViTExtractor.create_model(model_type) + + self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride) + # self.model.eval() + self.model.to(self.device) + self.p = self.model.patch_embed.patch_size + self.stride = self.model.patch_embed.proj.stride + + self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5) + self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5) + + self._feats = [] + self.hook_handlers = [] + self.load_size = None + self.num_patches = None + + @staticmethod + def create_model(model_type: str) -> nn.Module: + """ + :param model_type: a string specifying which model to load. [dino_vits8 | dino_vits16 | dino_vitb8 | + dino_vitb16 | vit_small_patch8_224 | vit_small_patch16_224 | vit_base_patch8_224 | + vit_base_patch16_224] + :return: the model + """ + if 'dino' in model_type: + model = torch.hub.load('facebookresearch/dino:main', model_type) + else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images). + temp_model = timm.create_model(model_type, pretrained=True) + model_type_dict = { + 'vit_small_patch16_224': 'dino_vits16', + 'vit_small_patch8_224': 'dino_vits8', + 'vit_base_patch16_224': 'dino_vitb16', + 'vit_base_patch8_224': 'dino_vitb8' + } + model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type]) + temp_state_dict = temp_model.state_dict() + del temp_state_dict['head.weight'] + del temp_state_dict['head.bias'] + model.load_state_dict(temp_state_dict) + return model + + @staticmethod + def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]): + """ + Creates a method for position encoding interpolation. + :param patch_size: patch size of the model. + :param stride_hw: A tuple containing the new height and width stride respectively. + :return: the interpolation method + """ + def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + # compute number of tokens taking stride into account + w0 = 1 + (w - patch_size) // stride_hw[1] + h0 = 1 + (h - patch_size) // stride_hw[0] + assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and + stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}""" + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + align_corners=False, recompute_scale_factor=False + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return interpolate_pos_encoding + + @staticmethod + def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module: + """ + change resolution of model output by changing the stride of the patch extraction. 
+ :param model: the model to change resolution for. + :param stride: the new stride parameter. + :return: the adjusted model + """ + patch_size = model.patch_embed.patch_size + if stride == patch_size: # nothing to do + return model + + stride = nn_utils._pair(stride) + assert all([(patch_size // s_) * s_ == patch_size for s_ in + stride]), f'stride {stride} should divide patch_size {patch_size}' + + # fix the stride + model.patch_embed.proj.stride = stride + # fix the positional encoding code + model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model) + return model + + def preprocess(self, image_path: Union[str, Path], + load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]: + """ + Preprocesses an image before extraction. + :param image_path: path to image to be extracted. + :param load_size: optional. Size to resize image before the rest of preprocessing. + :return: a tuple containing: + (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW. + (2) the pil image in relevant dimensions + """ + pil_image = Image.open(image_path).convert('RGB') + if load_size is not None: + pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image) + prep = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std) + ]) + prep_img = prep(pil_image)[None, ...] + return prep_img, pil_image + + def _get_hook(self, facet: str): + """ + generate a hook method for a specific block and facet. + """ + if facet in ['attn', 'token']: + def _hook(model, input, output): + self._feats.append(output) + return _hook + + if facet == 'query': + facet_idx = 0 + elif facet == 'key': + facet_idx = 1 + elif facet == 'value': + facet_idx = 2 + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _inner_hook(module, input, output): + input = input[0] + B, N, C = input.shape + qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) + self._feats.append(qkv[facet_idx]) #Bxhxtxd + return _inner_hook + + def _register_hooks(self, layers: List[int], facet: str) -> None: + """ + register hook to extract features. + :param layers: layers from which to extract features. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + """ + for block_idx, block in enumerate(self.model.blocks): + if block_idx in layers: + if facet == 'token': + self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) + elif facet == 'attn': + self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) + elif facet in ['key', 'query', 'value']: + self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _unregister_hooks(self) -> None: + """ + unregisters the hooks. should be called after feature extraction. + """ + for handle in self.hook_handlers: + handle.remove() + self.hook_handlers = [] + + def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]: + """ + extract features from the model + :param batch: batch to extract features for. Has shape BxCxHxW. + :param layers: layer to extract. A number between 0 to 11. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + :return : tensor of features. 
+ if facet is 'key' | 'query' | 'value' has shape Bxhxtxd + if facet is 'attn' has shape Bxhxtxt + if facet is 'token' has shape Bxtxd + """ + B, C, H, W = batch.shape + self._feats = [] + self._register_hooks(layers, facet) + with torch.no_grad(): + _ = self.model(batch) + self._unregister_hooks() + self.load_size = (H, W) + self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) + return self._feats + + def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: + """ + create a log-binned descriptor. + :param x: tensor of features. Has shape Bxhxtxd. + :param hierarchy: how many bin hierarchies to use. + """ + B = x.shape[0] + num_bins = 1 + 8 * hierarchy + + bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh) + bin_x = bin_x.permute(0, 2, 1) + bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1]) + # Bx(dxh)xnum_patches[0]xnum_patches[1] + sub_desc_dim = bin_x.shape[1] + + avg_pools = [] + # compute bins of all sizes for all spatial locations. + for k in range(0, hierarchy): + # avg pooling with kernel 3**kx3**k + win_size = 3 ** k + avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False) + avg_pools.append(avg_pool(bin_x)) + + bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device) + for y in range(self.num_patches[0]): + for x in range(self.num_patches[1]): + part_idx = 0 + # fill all bins for a spatial location (y, x) + for k in range(0, hierarchy): + kernel_size = 3 ** k + for i in range(y - kernel_size, y + kernel_size + 1, kernel_size): + for j in range(x - kernel_size, x + kernel_size + 1, kernel_size): + if i == y and j == x and k != 0: + continue + if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]: + bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ + :, :, i, j] + else: # handle padding in a more delicate way than zero padding + temp_i = max(0, min(i, self.num_patches[0] - 1)) + temp_j = max(0, min(j, self.num_patches[1] - 1)) + bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ + :, :, temp_i, + temp_j] + part_idx += 1 + bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) + # Bx1x(t-1)x(dxh) + return bin_x + + def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key', + bin: bool = False, include_cls: bool = False) -> torch.Tensor: + """ + extract descriptors from the model + :param batch: batch to extract descriptors for. Has shape BxCxHxW. + :param layers: layer to extract. A number between 0 to 11. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token'] + :param bin: apply log binning to the descriptor. default is False. + :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors. + """ + assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors. + choose from ['key' | 'query' | 'value' | 'token'] """ + self._extract_features(batch, [layer], facet) + x = self._feats[0] + if facet == 'token': + x.unsqueeze_(dim=1) #Bx1xtxd + if not include_cls: + x = x[:, :, 1:, :] # remove cls token + else: + assert not bin, "bin = True and include_cls = True are not supported together, set one of them False." 
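+        # x now holds the requested facet: Bxhxtxd for 'key'/'query'/'value', or
+        # Bx1xtxd for 'token'. It is either flattened across heads into a single
+        # descriptor per token, or aggregated with the log-binning scheme (_log_bin).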
+ if not bin: + desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) + else: + desc = self._log_bin(x) + return desc + + def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor: + """ + extract saliency maps. The saliency maps are extracted by averaging several attention heads from the last layer + in of the CLS token. All values are then normalized to range between 0 and 1. + :param batch: batch to extract saliency maps for. Has shape BxCxHxW. + :return: a tensor of saliency maps. has shape Bxt-1 + """ + assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type." + self._extract_features(batch, [11], 'attn') + head_idxs = [0, 2, 4, 5] + curr_feats = self._feats[0] #Bxhxtxt + cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) #Bx(t-1) + temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0] + cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins) # normalize to range [0,1] + return cls_attn_maps + +@BACKBONE_REGISTRY.register() +class D2ViTTransformer(Backbone): + def __init__(self, cfg, input_shape): + + pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE + patch_size = cfg.MODEL.SWIN.PATCH_SIZE + in_chans = 3 + embed_dim = cfg.MODEL.SWIN.EMBED_DIM + depths = cfg.MODEL.SWIN.DEPTHS + num_heads = cfg.MODEL.SWIN.NUM_HEADS + window_size = cfg.MODEL.SWIN.WINDOW_SIZE + mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO + qkv_bias = cfg.MODEL.SWIN.QKV_BIAS + qk_scale = cfg.MODEL.SWIN.QK_SCALE + drop_rate = cfg.MODEL.SWIN.DROP_RATE + attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE + drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE + norm_layer = nn.LayerNorm + ape = cfg.MODEL.SWIN.APE + patch_norm = cfg.MODEL.SWIN.PATCH_NORM + frozen_stages = cfg.MODEL.BACKBONE.FREEZE_AT + + super().__init__() + self.num_layers = 12 + num_features = [int(embed_dim) for i in range(self.num_layers)] + self.num_features = num_features + self.frozen_stages = frozen_stages + self.extractor = ViTExtractor( model_type='dino_vitb8', stride = 4, model = None, device = cfg.MODEL.DEVICE) + if self.frozen_stages >= 0: + for block_idx, block in enumerate(self.extractor.model.blocks): + if block_idx <= self.frozen_stages: + block.eval() + for p in block.parameters(): + p.requires_grad = False + + for block_idx, block in enumerate(self.extractor.model.blocks): + if all(p.requires_grad == False for p in block.parameters()): + print(f"DINO {block_idx} frozen") + + + self._out_features = cfg.MODEL.SWIN.OUT_FEATURES + + self._out_feature_strides = { + "res2": 4, + "res3": 8, + "res4": 16, + "res5": 32, + } + self._out_feature_channels = { + "res2": self.num_features[0], + "res3": self.num_features[1], + "res4": self.num_features[2], + "res5": self.num_features[3], + } + + def forward(self, x): + facet = 'key' + self.extractor._extract_features(x, [5, 7, 9, 11], facet=facet) + res2 = self.extractor._feats[0].unsqueeze_(dim=1) # Bx1xtxd + res3 = self.extractor._feats[1].unsqueeze_(dim=1) # Bx1xtxd + res4 = self.extractor._feats[2].unsqueeze_(dim=1) # Bx1xtxd + res5 = self.extractor._feats[3].unsqueeze_(dim=1) # Bx1xtxd + if facet == 'key': + res2 = einops.rearrange(res2, 'b c h t d -> b c t (d h)') # Bx1xtxd + res3 = einops.rearrange(res3, 'b c h t d -> b c t (d h)') # Bx1xtxd + res4 = einops.rearrange(res4, 'b c h t d -> b c t (d h)') # Bx1xtxd + res5 = einops.rearrange(res5, 'b c h t d -> b c t (d h)') # Bx1xtxd + + res2 = res2.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # 
Bx1xtx(dxh) + res3 = res3.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) + res4 = res4.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) + res5 = res5.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) + + res2 = res2[:, :, 1:, :] # remove cls token + res3 = res3[:, :, 1:, :] # remove cls token + res4 = res4[:, :, 1:, :] # remove cls token + res5 = res5[:, :, 1:, :] # remove cls token + + res2 = res2.reshape(res2.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2) + res3 = res3.reshape(res3.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2) + res4 = res4.reshape(res4.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2) + res5 = res5.reshape(res5.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2) + + return { + "res2": res2, + "res3": res3, + "res4": res4, + "res5": res5, + } + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + @property + def size_divisibility(self): + return 32 diff --git a/mask_former/modeling/criterion.py b/mask_former/modeling/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..7631ee3766bc0920f3a0e8ca648782e6c5c2e0bf --- /dev/null +++ b/mask_former/modeling/criterion.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py +""" +MaskFormer criterion. +""" +import torch +import torch.nn.functional as F +from torch import nn + +from detectron2.utils.comm import get_world_size + +from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list + + +def dice_loss(inputs, targets, num_masks): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + +def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. 
+ Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_masks + + +class SetCriterion(nn.Module): + """This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): + """Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + + def loss_labels(self, outputs, targets, indices, num_masks): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert "pred_logits" in outputs + src_logits = outputs["pred_logits"] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) + losses = {"loss_ce": loss_ce} + return losses + + def loss_masks(self, outputs, targets, indices, num_masks): + """Compute the losses related to the masks: the focal loss and the dice loss. 
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + src_masks = outputs["pred_masks"] + src_masks = src_masks[src_idx] + masks = [t["masks"] for t in targets] + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(src_masks) + target_masks = target_masks[tgt_idx] + + # upsample predictions to the target size + src_masks = F.interpolate( + src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False + ) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(src_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks), + "loss_dice": dice_loss(src_masks, target_masks, num_masks), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_masks): + loss_map = {"labels": self.loss_labels, "masks": self.loss_masks} + assert loss in loss_map, f"do you really want to compute {loss} loss?" + return loss_map[loss](outputs, targets, indices, num_masks) + + def forward(self, outputs, targets): + """This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_masks = sum(len(t["labels"]) for t in targets) + num_masks = torch.as_tensor( + [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device + ) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_masks) + num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_masks)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
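+        # Each set of intermediate predictions is re-matched against the targets and
+        # its losses are stored under suffixed keys (e.g. "loss_mask_0", "loss_dice_0").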
+ if "aux_outputs" in outputs: + for i, aux_outputs in enumerate(outputs["aux_outputs"]): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses diff --git a/mask_former/modeling/heads/__init__.py b/mask_former/modeling/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8 --- /dev/null +++ b/mask_former/modeling/heads/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/mask_former/modeling/heads/big_pixel_decoder.py b/mask_former/modeling/heads/big_pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3b673804e9ae15f699979229387b588d8eabd580 --- /dev/null +++ b/mask_former/modeling/heads/big_pixel_decoder.py @@ -0,0 +1,228 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +from typing import Callable, Dict, Optional, Union + +import fvcore.nn.weight_init as weight_init +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY +from torch import nn +from torch.nn import functional as F + +from ..transformer.position_encoding import PositionEmbeddingSine +from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer + +@SEM_SEG_HEADS_REGISTRY.register() +class BigPixelDecoder(nn.Module): + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + conv_dim: int, + mask_dim: int, + norm: Optional[Union[str, Callable]] = None, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + conv_dims: number of output channels for the intermediate conv layers. + mask_dim: number of output channels for the final conv layer. 
+ norm (str or callable): normalization for all conv layers + """ + super().__init__() + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + feature_channels = [v.channels for k, v in input_shape] + + lateral_convs = [] + output_convs = [] + + use_bias = norm == "" + for idx, in_channels in enumerate(feature_channels): + if idx == len(self.in_features) - 1: + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + in_channels, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(output_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(None) + output_convs.append(output_conv) + else: + lateral_norm = get_norm(norm, conv_dim) + output_norm = get_norm(norm, conv_dim) + + lateral_conv = Conv2d( + in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm + ) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(lateral_conv) + weight_init.c2_xavier_fill(output_conv) + self.add_module("adapter_{}".format(idx + 1), lateral_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. + self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + + self.mask_dim = mask_dim + # self.mask_features = Conv2d( + # conv_dim, + # mask_dim, + # kernel_size=3, + # stride=1, + # padding=1, + # ) + + # weight_init.c2_xavier_fill(self.mask_features) + + self.mask_features = nn.Sequential( + Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + nn.UpsamplingNearest2d(scale_factor=2), + Conv2d( + conv_dim, + conv_dim, + kernel_size=1, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + nn.UpsamplingNearest2d(scale_factor=2), + Conv2d( + conv_dim, + conv_dim, + kernel_size=1, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + Conv2d( + conv_dim, + mask_dim, + kernel_size=3, + stride=1, + padding=1, + ) + ) + + for name, module in self.mask_features.named_modules(): + if 'Conv2d' in name: + weight_init.c2_xavier_fill(module) + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = {} + ret["input_shape"] = { + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + } + ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM + ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM + return ret + + def forward_features(self, features): + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + y = output_conv(x) + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(y, 
size=cur_fpn.shape[-2:], mode="nearest") + y = output_conv(y) + return self.mask_features(y), None + + def forward(self, features, targets=None): + logger = logging.getLogger(__name__) + logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") + return self.forward_features(features) + + +class TransformerEncoderOnly(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + if mask is not None: + mask = mask.flatten(1) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + return memory.permute(1, 2, 0).view(bs, c, h, w) diff --git a/mask_former/modeling/heads/mask_former_head.py b/mask_former/modeling/heads/mask_former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..095b0e625975429e35276713af6974bf9846a12c --- /dev/null +++ b/mask_former/modeling/heads/mask_former_head.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +from copy import deepcopy +from typing import Callable, Dict, List, Optional, Tuple, Union + +import fvcore.nn.weight_init as weight_init +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from ..transformer.transformer_predictor import TransformerPredictor +from .pixel_decoder import build_pixel_decoder + + +@SEM_SEG_HEADS_REGISTRY.register() +class MaskFormerHead(nn.Module): + + _version = 2 + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + # Do not warn if train from scratch + scratch = True + logger = logging.getLogger(__name__) + for k in list(state_dict.keys()): + newk = k + if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): + newk = k.replace(prefix, prefix + "pixel_decoder.") + # logger.debug(f"{k} ==> {newk}") + if newk != k: + state_dict[newk] = state_dict[k] + del state_dict[k] + scratch = False + + if not scratch: + logger.warning( + f"Weight format of {self.__class__.__name__} have changed! " + "Please upgrade your models. Applying automatic conversion now ..." + ) + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + num_classes: int, + pixel_decoder: nn.Module, + loss_weight: float = 1.0, + ignore_value: int = -1, + # extra parameters + transformer_predictor: nn.Module, + transformer_in_feature: str, + ): + """ + NOTE: this interface is experimental. 
+ Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + pixel_decoder: the pixel decoder module + loss_weight: loss weight + ignore_value: category id to be ignored during training. + transformer_predictor: the transformer decoder that makes prediction + transformer_in_feature: input feature name to the transformer_predictor + """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + self.ignore_value = ignore_value + self.common_stride = 4 + self.loss_weight = loss_weight + + self.pixel_decoder = pixel_decoder + self.predictor = transformer_predictor + self.transformer_in_feature = transformer_in_feature + + self.num_classes = num_classes + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + return { + "input_shape": { + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + "pixel_decoder": build_pixel_decoder(cfg, input_shape), + "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, + "transformer_predictor": TransformerPredictor( + cfg, + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" + else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, + mask_classification=True, + ), + } + + def forward(self, features): + return self.layers(features) + + def layers(self, features): + mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features) + if self.transformer_in_feature == "transformer_encoder": + assert ( + transformer_encoder_features is not None + ), "Please use the TransformerEncoderPixelDecoder." + predictions = self.predictor(transformer_encoder_features, mask_features) + else: + predictions = self.predictor(features[self.transformer_in_feature], mask_features) + # predictions['features'] = mask_features + return predictions diff --git a/mask_former/modeling/heads/mask_former_head_baseline.py b/mask_former/modeling/heads/mask_former_head_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..45ec19010070349ec0d70a022ac3446aaf51c86e --- /dev/null +++ b/mask_former/modeling/heads/mask_former_head_baseline.py @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
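+# Note: this baseline head builds the same pixel decoder as MaskFormerHead, but its
+# layers() returns the mask features directly (the transformer predictor is not called)
+# and forward() pools them into a single scalar per image via out_layers.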
+import logging +from copy import deepcopy +from typing import Callable, Dict, List, Optional, Tuple, Union + +import fvcore.nn.weight_init as weight_init +from torch import nn +import torch +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from ..transformer.transformer_predictor import TransformerPredictor +from .pixel_decoder import build_pixel_decoder + + +@SEM_SEG_HEADS_REGISTRY.register() +class MaskFormerBaselineHead(nn.Module): + + _version = 2 + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + # Do not warn if train from scratch + scratch = True + logger = logging.getLogger(__name__) + for k in list(state_dict.keys()): + newk = k + if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): + newk = k.replace(prefix, prefix + "pixel_decoder.") + # logger.debug(f"{k} ==> {newk}") + if newk != k: + state_dict[newk] = state_dict[k] + del state_dict[k] + scratch = False + + if not scratch: + logger.warning( + f"Weight format of {self.__class__.__name__} have changed! " + "Please upgrade your models. Applying automatic conversion now ..." + ) + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + num_classes: int, + pixel_decoder: nn.Module, + loss_weight: float = 1.0, + ignore_value: int = -1, + # extra parameters + transformer_predictor: nn.Module, + transformer_in_feature: str, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + pixel_decoder: the pixel decoder module + loss_weight: loss weight + ignore_value: category id to be ignored during training. 
+ transformer_predictor: the transformer decoder that makes prediction + transformer_in_feature: input feature name to the transformer_predictor + """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + self.ignore_value = ignore_value + self.common_stride = 4 + self.loss_weight = loss_weight + + self.pixel_decoder = pixel_decoder + self.predictor = transformer_predictor + self.transformer_in_feature = transformer_in_feature + inc = 256 + self.out_layers = nn.Sequential(nn.Conv2d(inc, inc, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((1,1)), nn.Flatten(), nn.Linear(inc, 1)) + self.num_classes = num_classes + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + return { + "input_shape": { + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + "pixel_decoder": build_pixel_decoder(cfg, input_shape), + "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, + "transformer_predictor": TransformerPredictor( + cfg, + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" + else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, + mask_classification=True, + ), + } + + def forward(self, features): + f = self.layers(features) + + return self.out_layers(f).squeeze(-1) + + def layers(self, features): + mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features) + # if self.transformer_in_feature == "transformer_encoder": + # assert ( + # transformer_encoder_features is not None + # ), "Please use the TransformerEncoderPixelDecoder." + # predictions = self.predictor(transformer_encoder_features, mask_features) + # else: + # predictions = self.predictor(features[self.transformer_in_feature], mask_features) + return mask_features diff --git a/mask_former/modeling/heads/mega_big_pixel_decoder.py b/mask_former/modeling/heads/mega_big_pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7951ff599e53b45ea52144cc7b2e19e3040002 --- /dev/null +++ b/mask_former/modeling/heads/mega_big_pixel_decoder.py @@ -0,0 +1,249 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +from typing import Callable, Dict, Optional, Union + +import fvcore.nn.weight_init as weight_init +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY +from torch import nn +from torch.nn import functional as F + +from ..transformer.position_encoding import PositionEmbeddingSine +from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer + +@SEM_SEG_HEADS_REGISTRY.register() +class MegaBigPixelDecoder(nn.Module): + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + conv_dim: int, + mask_dim: int, + norm: Optional[Union[str, Callable]] = None, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + conv_dims: number of output channels for the intermediate conv layers. 
+ mask_dim: number of output channels for the final conv layer. + norm (str or callable): normalization for all conv layers + """ + super().__init__() + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + feature_channels = [v.channels for k, v in input_shape] + + lateral_convs = [] + output_convs = [] + + use_bias = norm == "" + for idx, in_channels in enumerate(feature_channels): + if idx == len(self.in_features) - 1: + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + in_channels, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(output_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(None) + output_convs.append(output_conv) + else: + lateral_norm = get_norm(norm, conv_dim) + output_norm = get_norm(norm, conv_dim) + + lateral_conv = Conv2d( + in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm + ) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(lateral_conv) + weight_init.c2_xavier_fill(output_conv) + self.add_module("adapter_{}".format(idx + 1), lateral_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. + self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + + self.mask_dim = mask_dim + # self.mask_features = Conv2d( + # conv_dim, + # mask_dim, + # kernel_size=3, + # stride=1, + # padding=1, + # ) + + # weight_init.c2_xavier_fill(self.mask_features) + + self.mask_features = nn.Sequential( + Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + nn.UpsamplingNearest2d(scale_factor=2), + Conv2d( + conv_dim, + conv_dim, + kernel_size=1, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + nn.UpsamplingNearest2d(scale_factor=2), + Conv2d( + conv_dim, + conv_dim, + kernel_size=1, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + nn.UpsamplingNearest2d(scale_factor=2), + Conv2d( + conv_dim, + conv_dim, + kernel_size=1, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ), + Conv2d( + conv_dim, + mask_dim, + kernel_size=3, + stride=1, + padding=1, + ) + ) + + for name, module in self.mask_features.named_modules(): + if 'Conv2d' in name: + weight_init.c2_xavier_fill(module) + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = {} + ret["input_shape"] = { + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + } + ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM + ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM + return ret + + def forward_features(self, features): + # Reverse feature 
maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + y = output_conv(x) + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") + y = output_conv(y) + return self.mask_features(y), None + + def forward(self, features, targets=None): + logger = logging.getLogger(__name__) + logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") + return self.forward_features(features) + + +class TransformerEncoderOnly(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + if mask is not None: + mask = mask.flatten(1) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + return memory.permute(1, 2, 0).view(bs, c, h, w) diff --git a/mask_former/modeling/heads/per_pixel_baseline.py b/mask_former/modeling/heads/per_pixel_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..a99f508e7b4a87ada0af6f10209f10edefa7e412 --- /dev/null +++ b/mask_former/modeling/heads/per_pixel_baseline.py @@ -0,0 +1,243 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +from typing import Callable, Dict, List, Optional, Tuple, Union + +import fvcore.nn.weight_init as weight_init +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from ..transformer.transformer_predictor import TransformerPredictor +from .pixel_decoder import build_pixel_decoder + + +@SEM_SEG_HEADS_REGISTRY.register() +class PerPixelBaselineHead(nn.Module): + + _version = 2 + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + logger = logging.getLogger(__name__) + # Do not warn if train from scratch + scratch = True + logger = logging.getLogger(__name__) + for k in list(state_dict.keys()): + newk = k + if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): + newk = k.replace(prefix, prefix + "pixel_decoder.") + # logger.warning(f"{k} ==> {newk}") + if newk != k: + state_dict[newk] = state_dict[k] + del state_dict[k] + scratch = False + + if not scratch: + logger.warning( + f"Weight format of {self.__class__.__name__} have changed! " + "Please upgrade your models. Applying automatic conversion now ..." 
+ ) + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + num_classes: int, + pixel_decoder: nn.Module, + loss_weight: float = 1.0, + ignore_value: int = -1, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + pixel_decoder: the pixel decoder module + loss_weight: loss weight + ignore_value: category id to be ignored during training. + """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + self.ignore_value = ignore_value + self.common_stride = 4 + self.loss_weight = loss_weight + + self.pixel_decoder = pixel_decoder + self.predictor = Conv2d( + self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0 + ) + weight_init.c2_msra_fill(self.predictor) + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + return { + "input_shape": { + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + "pixel_decoder": build_pixel_decoder(cfg, input_shape), + "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + } + + def forward(self, features, targets=None): + """ + Returns: + In training, returns (None, dict of losses) + In inference, returns (CxHxW logits, {}) + """ + x = self.layers(features) + if self.training: + return None, self.losses(x, targets) + else: + x = F.interpolate( + x, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + return x, {} + + def layers(self, features): + x, _ = self.pixel_decoder.forward_features(features) + x = self.predictor(x) + return x + + def losses(self, predictions, targets): + predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163 + predictions = F.interpolate( + predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + loss = F.cross_entropy( + predictions, targets, reduction="mean", ignore_index=self.ignore_value + ) + losses = {"loss_sem_seg": loss * self.loss_weight} + return losses + + +@SEM_SEG_HEADS_REGISTRY.register() +class PerPixelBaselinePlusHead(PerPixelBaselineHead): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + # Do not warn if train from scratch + scratch = True + logger = logging.getLogger(__name__) + for k in list(state_dict.keys()): + newk = k + if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): + newk = k.replace(prefix, prefix + "pixel_decoder.") + logger.debug(f"{k} ==> {newk}") + if newk != k: + state_dict[newk] = state_dict[k] + del state_dict[k] + scratch = False + + if not scratch: + logger.warning( + f"Weight format of {self.__class__.__name__} have changed! " + "Please upgrade your models. Applying automatic conversion now ..." 
+ ) + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + # extra parameters + transformer_predictor: nn.Module, + transformer_in_feature: str, + deep_supervision: bool, + # inherit parameters + num_classes: int, + pixel_decoder: nn.Module, + loss_weight: float = 1.0, + ignore_value: int = -1, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + transformer_predictor: the transformer decoder that makes prediction + transformer_in_feature: input feature name to the transformer_predictor + deep_supervision: whether or not to add supervision to the output of + every transformer decoder layer + num_classes: number of classes to predict + pixel_decoder: the pixel decoder module + loss_weight: loss weight + ignore_value: category id to be ignored during training. + """ + super().__init__( + input_shape, + num_classes=num_classes, + pixel_decoder=pixel_decoder, + loss_weight=loss_weight, + ignore_value=ignore_value, + ) + + del self.predictor + + self.predictor = transformer_predictor + self.transformer_in_feature = transformer_in_feature + self.deep_supervision = deep_supervision + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = super().from_config(cfg, input_shape) + ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE + if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder": + in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + else: + in_channels = input_shape[ret["transformer_in_feature"]].channels + ret["transformer_predictor"] = TransformerPredictor( + cfg, in_channels, mask_classification=False + ) + ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + return ret + + def forward(self, features, targets=None): + """ + Returns: + In training, returns (None, dict of losses) + In inference, returns (CxHxW logits, {}) + """ + x, aux_outputs = self.layers(features) + if self.training: + if self.deep_supervision: + losses = self.losses(x, targets) + for i, aux_output in enumerate(aux_outputs): + losses["loss_sem_seg" + f"_{i}"] = self.losses( + aux_output["pred_masks"], targets + )["loss_sem_seg"] + return None, losses + else: + return None, self.losses(x, targets) + else: + x = F.interpolate( + x, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + return x, {} + + def layers(self, features): + mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features) + if self.transformer_in_feature == "transformer_encoder": + assert ( + transformer_encoder_features is not None + ), "Please use the TransformerEncoderPixelDecoder." + predictions = self.predictor(transformer_encoder_features, mask_features) + else: + predictions = self.predictor(features[self.transformer_in_feature], mask_features) + if self.deep_supervision: + return predictions["pred_masks"], predictions["aux_outputs"] + else: + return predictions["pred_masks"], None diff --git a/mask_former/modeling/heads/pixel_decoder.py b/mask_former/modeling/heads/pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..20f3791731d35e111756dad3255a569bff3581db --- /dev/null +++ b/mask_former/modeling/heads/pixel_decoder.py @@ -0,0 +1,294 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
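+# This module defines build_pixel_decoder() together with the decoders it can build:
+# BasePixelDecoder, an FPN-style top-down decoder over the backbone features, and
+# TransformerEncoderPixelDecoder, which additionally runs a transformer encoder
+# (TransformerEncoderOnly) on the backbone features.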
+import logging +from typing import Callable, Dict, List, Optional, Tuple, Union + +import fvcore.nn.weight_init as weight_init +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from ..transformer.position_encoding import PositionEmbeddingSine +from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer + + +def build_pixel_decoder(cfg, input_shape): + """ + Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`. + """ + name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME + model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape) + forward_features = getattr(model, "forward_features", None) + if not callable(forward_features): + raise ValueError( + "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " + f"Please implement forward_features for {name} to only return mask features." + ) + return model + + +@SEM_SEG_HEADS_REGISTRY.register() +class BasePixelDecoder(nn.Module): + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + conv_dim: int, + mask_dim: int, + norm: Optional[Union[str, Callable]] = None, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + conv_dims: number of output channels for the intermediate conv layers. + mask_dim: number of output channels for the final conv layer. + norm (str or callable): normalization for all conv layers + """ + super().__init__() + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + feature_channels = [v.channels for k, v in input_shape] + + lateral_convs = [] + output_convs = [] + + use_bias = norm == "" + for idx, in_channels in enumerate(feature_channels): + if idx == len(self.in_features) - 1: + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + in_channels, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(output_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(None) + output_convs.append(output_conv) + else: + lateral_norm = get_norm(norm, conv_dim) + output_norm = get_norm(norm, conv_dim) + + lateral_conv = Conv2d( + in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm + ) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(lateral_conv) + weight_init.c2_xavier_fill(output_conv) + self.add_module("adapter_{}".format(idx + 1), lateral_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. 
+ self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + + self.mask_dim = mask_dim + self.mask_features = Conv2d( + conv_dim, + mask_dim, + kernel_size=3, + stride=1, + padding=1, + ) + weight_init.c2_xavier_fill(self.mask_features) + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = {} + ret["input_shape"] = { + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + } + ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM + ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM + return ret + + def forward_features(self, features): + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + y = output_conv(x) + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") + y = output_conv(y) + return self.mask_features(y), None + + def forward(self, features, targets=None): + logger = logging.getLogger(__name__) + logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") + return self.forward_features(features) + + +class TransformerEncoderOnly(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + if mask is not None: + mask = mask.flatten(1) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + return memory.permute(1, 2, 0).view(bs, c, h, w) + + +@SEM_SEG_HEADS_REGISTRY.register() +class TransformerEncoderPixelDecoder(BasePixelDecoder): + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + transformer_dropout: float, + transformer_nheads: int, + transformer_dim_feedforward: int, + transformer_enc_layers: int, + transformer_pre_norm: bool, + conv_dim: int, + mask_dim: int, + norm: Optional[Union[str, Callable]] = None, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + transformer_dropout: dropout probability in transformer + transformer_nheads: number of heads in transformer + transformer_dim_feedforward: dimension of feedforward network + transformer_enc_layers: number of transformer encoder layers + transformer_pre_norm: whether to use pre-layernorm or not + conv_dims: number of output channels for the intermediate conv layers. + mask_dim: number of output channels for the final conv layer. 
+ norm (str or callable): normalization for all conv layers + """ + super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm) + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + in_channels = feature_channels[len(self.in_features) - 1] + self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1) + weight_init.c2_xavier_fill(self.input_proj) + self.transformer = TransformerEncoderOnly( + d_model=conv_dim, + dropout=transformer_dropout, + nhead=transformer_nheads, + dim_feedforward=transformer_dim_feedforward, + num_encoder_layers=transformer_enc_layers, + normalize_before=transformer_pre_norm, + ) + N_steps = conv_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + # update layer + use_bias = norm == "" + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(output_conv) + delattr(self, "layer_{}".format(len(self.in_features))) + self.add_module("layer_{}".format(len(self.in_features)), output_conv) + self.output_convs[0] = output_conv + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = super().from_config(cfg, input_shape) + ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT + ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS + ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD + ret[ + "transformer_enc_layers" + ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config + ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM + return ret + + def forward_features(self, features): + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + transformer = self.input_proj(x) + pos = self.pe_layer(x) + transformer = self.transformer(transformer, None, pos) + y = output_conv(transformer) + # save intermediate feature as input to Transformer decoder + transformer_encoder_features = transformer + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") + y = output_conv(y) + return self.mask_features(y), transformer_encoder_features + + def forward(self, features, targets=None): + logger = logging.getLogger(__name__) + logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") + return self.forward_features(features) diff --git a/mask_former/modeling/matcher.py b/mask_former/modeling/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d11706577bf353bb0df13fe57032da3c66e8f5 --- /dev/null +++ b/mask_former/modeling/matcher.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py +""" +Modules to compute the matching cost and solve the corresponding LSAP. 
+""" +import torch +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from torch import nn + + +def batch_dice_loss(inputs, targets): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + hw = inputs.shape[1] + + prob = inputs.sigmoid() + focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + if alpha >= 0: + focal_pos = focal_pos * alpha + focal_neg = focal_neg * (1 - alpha) + + loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum( + "nc,mc->nm", focal_neg, (1 - targets) + ) + + return loss / hw + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). 
+ """ + + def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost + cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" + + @torch.no_grad() + def memory_efficient_forward(self, outputs, targets): + """More memory-friendly matching""" + bs, num_queries = outputs["pred_logits"].shape[:2] + + # Work out the mask padding size + masks = [v["masks"] for v in targets] + h_max = max([m.shape[1] for m in masks]) + w_max = max([m.shape[2] for m in masks]) + + indices = [] + + # Iterate through batch size + for b in range(bs): + + out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes] + out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] + + tgt_ids = targets[b]["labels"] + # gt masks are already padded when preparing target + tgt_mask = targets[b]["masks"].to(out_mask) + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -out_prob[:, tgt_ids] + + # Downsample gt masks to save memory + tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest") + + # Flatten spatial dimension + out_mask = out_mask.flatten(1) # [batch_size * num_queries, H*W] + tgt_mask = tgt_mask[:, 0].flatten(1) # [num_total_targets, H*W] + + # Compute the focal loss between masks + cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask) + + # Compute the dice loss betwen masks + cost_dice = batch_dice_loss(out_mask, tgt_mask) + + # Final cost matrix + C = ( + self.cost_mask * cost_mask + + self.cost_class * cost_class + + self.cost_dice * cost_dice + ) + C = C.reshape(num_queries, -1).cpu() + + indices.append(linear_sum_assignment(C)) + return [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices + ] + + @torch.no_grad() + def forward(self, outputs, targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + return self.memory_efficient_forward(outputs, targets) + + def __repr__(self): + head = "Matcher " + self.__class__.__name__ + 
body = [ + "cost_class: {}".format(self.cost_class), + "cost_mask: {}".format(self.cost_mask), + "cost_dice: {}".format(self.cost_dice), + ] + _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) diff --git a/mask_former/modeling/transformer/__init__.py b/mask_former/modeling/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8 --- /dev/null +++ b/mask_former/modeling/transformer/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/mask_former/modeling/transformer/position_encoding.py b/mask_former/modeling/transformer/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..f60587a41d5f3b26c247ef569523ec4a595bd4b8 --- /dev/null +++ b/mask_former/modeling/transformer/position_encoding.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py +""" +Various positional encodings for the transformer. +""" +import math + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/mask_former/modeling/transformer/transformer.py b/mask_former/modeling/transformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8caa0108f5e136a9739320ab69a3e1b6f40298 --- /dev/null +++ b/mask_former/modeling/transformer/transformer.py @@ -0,0 +1,369 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py +""" +Transformer class. 
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Transformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + if mask is not None: + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder( + tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed + ) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + 
intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + 
tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/mask_former/modeling/transformer/transformer_predictor.py b/mask_former/modeling/transformer/transformer_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..465d8bbcbc41245a6152aefc33f251a4c288146f --- /dev/null +++ b/mask_former/modeling/transformer/transformer_predictor.py @@ -0,0 +1,181 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
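The Transformer above keeps DETR's sequence-first conventions: the NxCxHxW feature map is flattened to HWxNxC, the decoder starts from zero targets plus learned query embeddings, and with return_intermediate_dec=True it returns the activations of every decoder layer. A small shape check under assumed sizes; the 256-d model, 16x16 feature map and 100 queries are placeholders, not settings from this diff.

import torch
from mask_former.modeling.transformer.transformer import Transformer

model = Transformer(d_model=256, nhead=8, num_encoder_layers=2,
                    num_decoder_layers=2, return_intermediate_dec=True)
src = torch.randn(1, 256, 16, 16)        # N x C x H x W feature map
pos_embed = torch.randn(1, 256, 16, 16)  # e.g. output of PositionEmbeddingSine
query_embed = torch.randn(100, 256)      # one learned embedding per query slot

hs, memory = model(src, mask=None, query_embed=query_embed, pos_embed=pos_embed)
print(hs.shape)      # (num_decoder_layers, N, num_queries, d_model) -> (2, 1, 100, 256)
print(memory.shape)  # encoder memory restored to (N, C, H, W)      -> (1, 256, 16, 16)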
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d + +from .position_encoding import PositionEmbeddingSine +from .transformer import Transformer + + +class TransformerPredictor(nn.Module): + @configurable + def __init__( + self, + in_channels, + mask_classification=True, + *, + num_classes: int, + hidden_dim: int, + num_queries: int, + nheads: int, + dropout: float, + dim_feedforward: int, + enc_layers: int, + dec_layers: int, + pre_norm: bool, + deep_supervision: bool, + mask_dim: int, + enforce_input_project: bool, + ): + """ + NOTE: this interface is experimental. + Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dropout: dropout in Transformer + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + deep_supervision: whether to add supervision to every decoder layers + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv evens if input + channels and hidden dim is identical + """ + super().__init__() + + # self.mask_classification = mask_classification + self.mask_classification = False + + # positional encoding + N_steps = hidden_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + transformer = Transformer( + d_model=hidden_dim, + dropout=dropout, + nhead=nheads, + dim_feedforward=dim_feedforward, + num_encoder_layers=enc_layers, + num_decoder_layers=dec_layers, + normalize_before=pre_norm, + return_intermediate_dec=deep_supervision, + ) + + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + if in_channels != hidden_dim or enforce_input_project: + self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1) + weight_init.c2_xavier_fill(self.input_proj) + else: + self.input_proj = nn.Sequential() + self.aux_loss = deep_supervision + + # output FFNs + if self.mask_classification: + self.class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + if self.num_queries != 2: + self.mask_querie_comb = nn.Sequential( + Conv2d(self.num_queries, self.num_queries, kernel_size=1, activation=F.relu), + Conv2d(self.num_queries, 2, kernel_size=5, activation=F.relu) + ) + + @classmethod + def from_config(cls, cfg, in_channels, mask_classification): + ret = {} + ret["in_channels"] = in_channels + ret["mask_classification"] = mask_classification + + ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES + ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM + ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES + # Transformer parameters: + ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS + ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT + ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD + ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS + ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS + ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM + 
ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ + + ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM + + return ret + + def forward(self, x, mask_features): + pos = self.pe_layer(x) + + src = x + mask = None + hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos) + + if self.mask_classification: + outputs_class = self.class_embed(hs) + out = {"pred_logits": outputs_class[-1]} + else: + out = {} + + if self.aux_loss: + # [l, bs, queries, embed] + mask_embed = self.mask_embed(hs) + outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features) + pred_masks = outputs_seg_masks[-1] + if self.num_queries > 2: + pred_masks = self.mask_querie_comb(pred_masks) + out["pred_masks"] = pred_masks + out["aux_outputs"] = self._set_aux_loss( + outputs_class if self.mask_classification else None, outputs_seg_masks + ) + else: + # FIXME h_boxes takes the last one computed, keep this in mind + # [bs, queries, embed] + mask_embed = self.mask_embed(hs[-1]) + outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + out["pred_masks"] = outputs_seg_masks + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_seg_masks): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + if self.mask_classification: + return [ + {"pred_logits": a, "pred_masks": b} + for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) + ] + else: + return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/mask_former/test_time_augmentation.py b/mask_former/test_time_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..8d250b6bb7792b54ddeaaab62cc6c170d74d3bb9 --- /dev/null +++ b/mask_former/test_time_augmentation.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +from itertools import count + +import numpy as np +import torch +from fvcore.transforms import HFlipTransform +from torch import nn +from torch.nn.parallel import DistributedDataParallel + +from detectron2.data.detection_utils import read_image +from detectron2.modeling import DatasetMapperTTA + +__all__ = [ + "SemanticSegmentorWithTTA", +] + + +class SemanticSegmentorWithTTA(nn.Module): + """ + A SemanticSegmentor with test-time augmentation enabled. + Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`. + """ + + def __init__(self, cfg, model, tta_mapper=None, batch_size=1): + """ + Args: + cfg (CfgNode): + model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. + tta_mapper (callable): takes a dataset dict and returns a list of + augmented versions of the dataset dict. Defaults to + `DatasetMapperTTA(cfg)`. + batch_size (int): batch the augmented images into this batch size for inference. 
+ """ + super().__init__() + if isinstance(model, DistributedDataParallel): + model = model.module + self.cfg = cfg.clone() + + self.model = model + + if tta_mapper is None: + tta_mapper = DatasetMapperTTA(cfg) + self.tta_mapper = tta_mapper + self.batch_size = batch_size + + def _batch_inference(self, batched_inputs): + """ + Execute inference on a list of inputs, + using batch size = self.batch_size, instead of the length of the list. + Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward` + """ + outputs = [] + inputs = [] + for idx, input in zip(count(), batched_inputs): + inputs.append(input) + if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: + with torch.no_grad(): + outputs.extend(self.model(inputs)) + inputs = [] + return outputs + + def __call__(self, batched_inputs): + """ + Same input/output format as :meth:`SemanticSegmentor.forward` + """ + + def _maybe_read_image(dataset_dict): + ret = copy.copy(dataset_dict) + if "image" not in ret: + image = read_image(ret.pop("file_name"), self.model.input_format) + image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW + ret["image"] = image + if "height" not in ret and "width" not in ret: + ret["height"] = image.shape[1] + ret["width"] = image.shape[2] + return ret + + return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs] + + def _inference_one_image(self, input): + """ + Args: + input (dict): one dataset dict with "image" field being a CHW tensor + Returns: + dict: one output dict + """ + augmented_inputs, tfms = self._get_augmented_inputs(input) + # 1: forward with all augmented images + outputs = self._batch_inference(augmented_inputs) + # Delete now useless variables to avoid being out of memory + del augmented_inputs + # 2: merge the results + # handle flip specially + new_outputs = [] + for output, tfm in zip(outputs, tfms): + if any(isinstance(t, HFlipTransform) for t in tfm.transforms): + new_outputs.append(output.pop("sem_seg").flip(dims=[2])) + else: + new_outputs.append(output.pop("sem_seg")) + del outputs + # to avoid OOM with torch.stack + final_predictions = new_outputs[0] + for i in range(1, len(new_outputs)): + final_predictions += new_outputs[i] + final_predictions = final_predictions / len(new_outputs) + del new_outputs + return {"sem_seg": final_predictions} + + def _get_augmented_inputs(self, input): + augmented_inputs = self.tta_mapper(input) + tfms = [x.pop("transforms") for x in augmented_inputs] + return augmented_inputs, tfms diff --git a/mask_former/utils/__init__.py b/mask_former/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8 --- /dev/null +++ b/mask_former/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/mask_former/utils/misc.py b/mask_former/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..874d9805b482f52bbffc1be620e36e0cffc07c46 --- /dev/null +++ b/mask_former/utils/misc.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. 
+""" +from typing import List, Optional + +import torch +import torch.distributed as dist +import torchvision +from torch import Tensor + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max( + torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) + ).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True diff --git a/mask_former_trainer.py b/mask_former_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2edbbeb4da01691204afbb417b74b56e7eef039e --- /dev/null +++ b/mask_former_trainer.py @@ -0,0 +1,273 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +MaskFormer Training Script. + +This script is a simplified version of the training script in detectron2/tools. +""" +import copy +import itertools +import json +import logging +import os +import sys +from typing import Any, Dict, List, Set + +import detectron2.utils.comm as comm +import torch +import wandb +from detectron2.config import get_cfg, CfgNode +from detectron2.engine import DefaultTrainer, default_setup +from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler +from detectron2.solver.build import maybe_add_gradient_clipping +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import setup_logger + +import utils +# MaskFormer +from config import add_gwm_config + +logger = logging.getLogger('gwm') + + +class Trainer(DefaultTrainer): + """ + Extension of the Trainer class adapted to DETR. + """ + + @classmethod + def build_evaluator(cls, cfg, dataset_name): + pass + + @classmethod + def build_lr_scheduler(cls, cfg, optimizer): + """ + It now calls :func:`detectron2.solver.build_lr_scheduler`. + Overwrite it if you'd like a different scheduler. 
+ """ + return build_lr_scheduler(cfg, optimizer) + + @classmethod + def build_optimizer(cls, cfg, model): + weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM + weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED + + defaults = {} + defaults["lr"] = cfg.SOLVER.BASE_LR + defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY + + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + # NaiveSyncBatchNorm inherits from BatchNorm2d + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + + params: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + for module_name, module in model.named_modules(): + for module_param_name, value in module.named_parameters(recurse=False): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + + hyperparams = copy.copy(defaults) + if "backbone" in module_name: + hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER + if ( + "relative_position_bias_table" in module_param_name + or "absolute_pos_embed" in module_param_name + ): + hyperparams["weight_decay"] = 0.0 + if isinstance(module, norm_module_types): + hyperparams["weight_decay"] = weight_decay_norm + if isinstance(module, torch.nn.Embedding): + hyperparams["weight_decay"] = weight_decay_embed + params.append({"params": [value], **hyperparams}) + + def maybe_add_full_model_gradient_clipping(optim): + # detectron2 doesn't have full model gradient clipping now + clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE + enable = ( + cfg.SOLVER.CLIP_GRADIENTS.ENABLED + and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" + and clip_norm_val > 0.0 + ) + + class FullModelGradientClippingOptimizer(optim): + def step(self, closure=None): + all_params = itertools.chain(*[x["params"] for x in self.param_groups]) + torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) + super().step(closure=closure) + + return FullModelGradientClippingOptimizer if enable else optim + + optimizer_type = cfg.SOLVER.OPTIMIZER + if optimizer_type == "SGD": + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( + params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM + ) + elif optimizer_type == "ADAMW": + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( + params, cfg.SOLVER.BASE_LR + ) + elif optimizer_type == "RMSProp": + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.RMSprop)( + params, cfg.SOLVER.BASE_LR + ) + else: + raise NotImplementedError(f"no optimizer type {optimizer_type}") + if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": + optimizer = maybe_add_gradient_clipping(cfg, optimizer) + return optimizer + + +def setup(args): + """ + Create configs and perform basic setups. 
+ """ + + wandb_inited = False + if 'CONFIG_FILE' in args.opts and not args.wandb_sweep_mode: + logger.warning( + f"Found CONFIG_FILE key in OPT args and using {args.opts[args.opts.index('CONFIG_FILE') + 1]} instead of {args.config_file}") + args.config_file = args.opts[args.opts.index('CONFIG_FILE') + 1] + else: + cfg = get_cfg() + add_gwm_config(cfg) + wandb_basedir = cfg.WANDB.BASEDIR + cfg_dict = CfgNode.load_yaml_with_base(args.config_file, allow_unsafe=True) + if 'WANDB' in cfg_dict and 'BASEDIR' in cfg_dict['WANDB']: + wandb_basedir = cfg_dict['WANDB']['BASEDIR'] + if 'CONFIG_FILE' in cfg_dict and cfg_dict['CONFIG_FILE'] is not None: + logger.warning( + f"Found CONFIG_FILE key in the config.yaml file and using {cfg_dict['CONFIG_FILE']} instead of {args.config_file}") + args.config_file = cfg_dict['CONFIG_FILE'] + + if args.wandb_sweep_mode: + if PathManager.isfile('wandb.yaml'): + wandb_cfg = CfgNode.load_yaml_with_base('wandb.yaml', allow_unsafe=False) + wandb.init(project=wandb_cfg['PROJECT'], entity=wandb_cfg['USER'], dir=wandb_basedir) + wandb_inited = True + if wandb.run.sweep_id: # sweep active + sweep_dict = dict(wandb.config) + if 'CONFIG_FILE' in sweep_dict: + args.config_file = sweep_dict['CONFIG_FILE'] + logger.warning(f"Loading CONFIG_FILE as set in sweep config: {args.config_file}") + elif 'CONFIG_FILE' in args.opts: + args.config_file = args.opts[args.opts.index('CONFIG_FILE') + 1] + logger.warning(f"Loading CONFIG_FILE as set in the optional arguments: {args.config_file}") + + if 'GWM.MODEL' in args.opts and not args.wandb_sweep_mode: + logger.warning( + "It is advised to not set GWM.MODEL in OPT args and instead set it in the config.yaml file") + model = args.opts[args.opts.index('GWM.MODEL') + 1] + else: + cfg = get_cfg() + add_gwm_config(cfg) + model = cfg.GWM.MODEL + wandb_basedir = cfg.WANDB.BASEDIR + cfg_dict = CfgNode.load_yaml_with_base(args.config_file, allow_unsafe=True) + if 'GWM' in cfg_dict and 'MODEL' in cfg_dict['GWM']: + model = cfg_dict['GWM']['MODEL'] + if 'WANDB' in cfg_dict and 'BASEDIR' in cfg_dict['WANDB']: + wandb_basedir = cfg_dict['WANDB']['BASEDIR'] + + if args.wandb_sweep_mode: + if PathManager.isfile('wandb.yaml'): + if not wandb_inited: + wandb_cfg = CfgNode.load_yaml_with_base('wandb.yaml', allow_unsafe=False) + wandb.init(project=wandb_cfg['PROJECT'], entity=wandb_cfg['USER'], dir=wandb_basedir) + wandb_inited = True + + if args.wandb_sweep_mode: + sweep_dict = dict(wandb.config) + if 'GWM.MODEL' in sweep_dict: + logger.warning( + "It is advised to not set GWM.MODEL in sweep config and instead set it in the config.yaml file") + model = sweep_dict['GWM.MODEL'] + elif 'GWM.MODEL' in args.opts: + logger.warning( + "It is advised to not set GWM.MODEL in optional arguments and instead set it in the config.yaml file") + model = args.opts[args.opts.index('GWM.MODEL') + 1] + + cfg = get_cfg() + # for poly lr schedule + add_deeplab_config(cfg) + if model == "MASKFORMER": + from mask_former import add_mask_former_config + add_mask_former_config(cfg) + else: + logger.error(f'Unknown Model: {model}. 
Exiting..') + sys.exit(0) + + add_gwm_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.WANDB.ENABLE = (cfg.WANDB.ENABLE or args.wandb_sweep_mode) and not args.eval_only + + datestring = utils.log.get_datestring_for_the_run() + if cfg.WANDB.ENABLE: + if PathManager.isfile('wandb.yaml'): + if not wandb_inited: + wandb_cfg = CfgNode.load_yaml_with_base('wandb.yaml', allow_unsafe=False) + wandb.init(project=wandb_cfg['PROJECT'], entity=wandb_cfg['USER'], dir=cfg.WANDB.BASEDIR) + + if args.wandb_sweep_mode: # sweep active + sweep_list = [(k, v) for k, v in dict(wandb.config).items()] + sweep_list = [item for items in sweep_list for item in items] + cfg.merge_from_list(sweep_list) + + if cfg.LOG_ID is not None: + api = wandb.Api() + run = api.run(path=f"{wandb_cfg['USER']}/{wandb_cfg['PROJECT']}/{wandb.run.id}") + run.name = f'{cfg.LOG_ID}/{datestring}-{wandb.run.id}' + run.save() + + else: + logger.error("W&B config file 'src/wandb.yaml' does not exist!") + cfg.WANDB.ENABLE = False + + if args.resume_path: + cfg.OUTPUT_DIR = "/".join(args.resume_path.split('/')[:-2]) # LOG_ID/datestring/checkpoints/checkpoints.pth + + if args.eval_only: + cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, 'eval', datestring) + + else: + if cfg.LOG_ID and not cfg.SLURM: + cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_BASEDIR, cfg.LOG_ID) + else: + cfg.OUTPUT_DIR = cfg.OUTPUT_BASEDIR + + if args.eval_only: + cfg.OUTPUT_DIR = None + else: + cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, datestring) + os.makedirs(f'{cfg.OUTPUT_DIR}/checkpoints', exist_ok=True) + + + if cfg.WANDB.ENABLE: + wandb.config.update(cfg, allow_val_change=True) + + if cfg.GWM.LOSS == 'OG': + cfg.FLAGS.EXTENDED_FLOW_RECON_VIS = False + cfg.FLAGS.COMP_NLL_FOR_GT = False + + cfg.freeze() + default_setup(cfg, args) + + # Setup logger for "gwm" module + setup_logger(output=f'{cfg.OUTPUT_DIR}/main.log', distributed_rank=comm.get_rank(), name="gwm") + with open(f'{cfg.OUTPUT_DIR}/args.json', 'w') as f: + json.dump(args.__dict__, f, indent=2) + return cfg diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cfb4b243b1d031560252bc912653ba4c63db028b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch==1.12.0+cpu +torchvision==0.13.0+cpu +jupyter +tensorboard +timm +einops +scikit-learn +scikit-image +tqdm +cvbase +opencv-python +wandb +matplotlib diff --git a/samples/1920px-Woman_at_work,_Gujarat.jpg b/samples/1920px-Woman_at_work,_Gujarat.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ff988239a7cd53f9536de272473eb26061491996 --- /dev/null +++ b/samples/1920px-Woman_at_work,_Gujarat.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4341db3b34a43b1b7589f674db3fc88a4160bd8e1eae65e304cfa6de3c7b4e3c +size 1051116 diff --git a/samples/2560px-2011_Toyota_Corolla_--_NHTSA.jpg b/samples/2560px-2011_Toyota_Corolla_--_NHTSA.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dfb4eb234949c71a51a676972707d1af02aeebe5 Binary files /dev/null and b/samples/2560px-2011_Toyota_Corolla_--_NHTSA.jpg differ diff --git a/samples/Brooks_Chase_Ranger_of_Jolly_Dogs_Jack_Russell.jpg b/samples/Brooks_Chase_Ranger_of_Jolly_Dogs_Jack_Russell.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1487838b4f8d13ebe3f1625a46b05afb25d514d5 Binary files /dev/null and b/samples/Brooks_Chase_Ranger_of_Jolly_Dogs_Jack_Russell.jpg differ 
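On the code side, setup() in mask_former_trainer.py assembles its final configuration by layering sources in a fixed order: detectron2 defaults, the DeepLab and MaskFormer extensions, the GWM additions, the YAML config file, and finally the opts list. A condensed sketch of that order is below; the YAML path is a hypothetical example, while the GWM.MODEL / MASKFORMER pair mirrors the key checked in setup() above.

import os
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from mask_former import add_mask_former_config
from config import add_gwm_config

cfg = get_cfg()                                   # detectron2 defaults
add_deeplab_config(cfg)                           # poly LR schedule keys
add_mask_former_config(cfg)                       # MaskFormer keys
add_gwm_config(cfg)                               # GWM-specific keys
if os.path.isfile("configs/example.yaml"):        # hypothetical config file
    cfg.merge_from_file("configs/example.yaml")   # YAML overrides the defaults
cfg.merge_from_list(["GWM.MODEL", "MASKFORMER"])  # opts override the YAML
cfg.freeze()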
diff --git a/samples/Felis_catus-cat_on_snow.jpg b/samples/Felis_catus-cat_on_snow.jpg new file mode 100644 index 0000000000000000000000000000000000000000..92b42ae2372eb0d58f7b9d0b02f772daf3542e2e Binary files /dev/null and b/samples/Felis_catus-cat_on_snow.jpg differ diff --git a/samples/LICENSE b/samples/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..765d95634a652b5b7d383ae6442084d2bac93e67 --- /dev/null +++ b/samples/LICENSE @@ -0,0 +1,5 @@ +Licenses +https://en.wikipedia.org/wiki/Cat#/media/File:Felis_catus-cat_on_snow.jpg +https://en.wikipedia.org/wiki/Farmer#/media/File:Woman_at_work,_Gujarat.jpg +https://en.wikipedia.org/wiki/Car#/media/File:2011_Toyota_Corolla_--_NHTSA.jpg +https://en.wikipedia.org/wiki/Dog#/media/File:Brooks_Chase_Ranger_of_Jolly_Dogs_Jack_Russell.jpg \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77266a334e6241108ac43078811d06a88e336448 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,14 @@ +# Do not reorder this +from . import log +from . import data +from . import environment +from . import extras +from . import grid +from . import visualisation +from . import random_state + + +## have to keep it because it's here: +# https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/vision_transformer.py#L24 +## otherwise torch.hub.load(dino) will throw error +from .extras import trunc_normal_ diff --git a/utils/convert.py b/utils/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..569f5fddc5861f6bbe2568bfec47c0a3b00897af --- /dev/null +++ b/utils/convert.py @@ -0,0 +1,25 @@ +import torch +import itertools + +def cast_like(maybe_tensor, example_tensor): + if not torch.is_tensor(maybe_tensor): + maybe_tensor = torch.tensor(maybe_tensor) + maybe_tensor = maybe_tensor.to(example_tensor.device).to(example_tensor.dtype) + shape = [*maybe_tensor.shape, *[1] * len(example_tensor.shape)] + if not shape: + shape = [1] + return maybe_tensor.view(*shape) + + +def lofd_2_dofl(list_of_dicts, make_tensor=True): + keys = set(itertools.chain.from_iterable(list_of_dicts)) + out_dict = {} + for k in keys: + out_dict[k] = [d[k] for d in list_of_dicts if k in d] + if make_tensor: + example_tensor = next((v for v in out_dict[k] if torch.is_tensor(v)), None) + if example_tensor is None: + out_dict[k] = torch.tensor(out_dict[k]) + else: + out_dict[k] = torch.cat([cast_like(t, example_tensor) for t in out_dict[k]], 0) + return out_dict diff --git a/utils/data.py b/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..c4cd8149379e87c57912d9671a8b67ea4e22d771 --- /dev/null +++ b/utils/data.py @@ -0,0 +1,90 @@ +import logging +import os +import subprocess +from functools import lru_cache +from pathlib import Path + +import cv2 +import einops +import numpy as np +import torch +from cvbase.optflow.visualize import flow2rgb +from detectron2.data import detection_utils as d2_utils + +__LOGGER = logging.Logger(__name__) +__TAR_SP = [Path('/usr/bin/tar'), Path('/bin/tar')] + +TAG_FLOAT = 202021.25 + + +def read_flo(file): + assert type(file) is str, "file is not str %r" % str(file) + assert os.path.isfile(file) is True, "file does not exist %r" % str(file) + assert file[-4:] == '.flo', "file ending is not .flo %r" % file[-4:] + f = open(file, 'rb') + flo_number = np.fromfile(f, np.float32, count=1)[0] + assert flo_number == TAG_FLOAT, 'Flow number %r incorrect. 
Invalid .flo file' % flo_number + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + data = np.fromfile(f, np.float32, count=2 * w[0] * h[0]) + # Reshape data into 3D array (columns, rows, bands) + flow = np.resize(data, (int(h), int(w), 2)) + f.close() + return flow + + +def read_flow(sample_dir, resolution=None, to_rgb=False): + flow = read_flo(sample_dir) + h, w, _ = np.shape(flow) + if resolution: + flow = cv2.resize(flow, (resolution[1], resolution[0]), interpolation=cv2.INTER_NEAREST) + flow[:, :, 0] = flow[:, :, 0] * resolution[1] / w + flow[:, :, 1] = flow[:, :, 1] * resolution[0] / h + if to_rgb: + flow = np.clip((flow2rgb(flow) - 0.5) * 2, -1., 1.) + return einops.rearrange(flow, 'h w c -> c h w') + + +def read_rgb(sample_dir, resolution=None): + rgb = d2_utils.read_image(sample_dir) + rgb = ((rgb / 255.0) - 0.5) * 2.0 + if resolution: + rgb = cv2.resize(rgb, (resolution[1], resolution[0]), interpolation=cv2.INTER_LINEAR) + rgb = np.clip(rgb, -1., 1.) + return einops.rearrange(rgb, 'h w c -> c h w') + + +### from: https://github.com/pytorch/pytorch/issues/15849#issuecomment-518126031 +class _RepeatSampler(object): + """ Sampler that repeats forever. + Args: + sampler (Sampler) + """ + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) + + +# https://github.com/pytorch/pytorch/issues/15849#issuecomment-573921048 +class FastDataLoader(torch.utils.data.dataloader.DataLoader): + '''for reusing cpu workers, to save time''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) + # self.batch_sampler = _RepeatSampler(self.batch_sampler) + self.iterator = super().__iter__() + + def __len__(self): + return len(self.batch_sampler.sampler) + + def __iter__(self): + for i in range(len(self)): + yield next(self.iterator) + +# Originally written by wkentaro +# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py diff --git a/utils/environment.py b/utils/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..04a4603c5eb4141650b3c00f78026105ccdb778a --- /dev/null +++ b/utils/environment.py @@ -0,0 +1,57 @@ +import os +import subprocess +from functools import lru_cache +from pathlib import Path + + +@lru_cache(None) +def __hostname(): + # env variables set up wrong on aimscdt2.dns.eng.ox.ac.uk + # if 'HOST' in os.environ: + # return str(os.environ['HOST']) + # if 'HOSTNAME' in os.environ: + # return str(os.environ['HOSTNAME']) + # else: + return str(subprocess.check_output('hostname', shell=True).decode().strip()) + + +@lru_cache(None) +def user(): + if 'USER' in os.environ: + return str(os.environ['USER']) + else: + return str(subprocess.check_output('whoami', shell=True).decode().strip()) + + +def is_slurm(): + return 'SLURM_JOB_ID' in os.environ and os.environ['SLURM_JOB_NAME'] not in ['zsh', 'bash'] + + +def get_slurm_id(): + return os.environ.get('SLURM_JOB_ID', None) + + +def is_aims_machine(): + hostname = __hostname() + return 'aims' in hostname + + +def is_vggdev_machine(): + hostname = __hostname() + return 'vggdev' in hostname or 'vggdebug' in hostname + + +def can_fit_in_tmp(path): + tmp_avail = int(str(subprocess.check_output(['/usr/bin/df', '-k', '--output=avail', str(os.environ['TMPDIR'])], + close_fds=True).decode().strip()).split()[-1].strip()) * 1024 + path_size = int(Path(path).stat().st_size) + print(f"{Path(path).name} size 
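read_flo above parses the Middlebury .flo layout: a float32 magic tag (202021.25), int32 width and height, then w*h*2 float32 values. A minimal writer sketch that produces files the reader accepts; write_flo and the temporary path are illustrative, not part of the repo:

```python
import numpy as np

TAG_FLOAT = 202021.25  # magic number checked by read_flo

def write_flo(path, flow):
    """flow: (h, w, 2) float32 array of (u, v) displacements."""
    h, w, _ = flow.shape
    with open(path, 'wb') as f:
        np.array([TAG_FLOAT], dtype=np.float32).tofile(f)
        np.array([w], dtype=np.int32).tofile(f)
        np.array([h], dtype=np.int32).tofile(f)
        flow.astype(np.float32).tofile(f)

flow = np.random.randn(4, 6, 2).astype(np.float32)
write_flo('/tmp/example.flo', flow)
# read_flo('/tmp/example.flo') now returns a (4, 6, 2) array equal to `flow`.
```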
{path_size / 2 ** 30:.2f}GB vs {tmp_avail / 2 ** 30:.2f}GB") + return path_size < tmp_avail + + +def check_user(username, partial=True): + username = username.lower() + run_user = user().lower() + if partial: + return username in run_user + return username == run_user diff --git a/utils/extras.py b/utils/extras.py new file mode 100644 index 0000000000000000000000000000000000000000..5085975d9badbe68586b9d06f291ea3e0e70e6b6 --- /dev/null +++ b/utils/extras.py @@ -0,0 +1,84 @@ +import math +import warnings + +import torch + +from dist import cached_grid +from utils import log as log_utils + +LOGGER = log_utils.getLogger(__name__) + + +def mask_selector(masks_softmaxed, top=2, size_norm=False): + """Select centre most masks and sumthem """ + b, k, *other_dims, h, w = masks_softmaxed.shape + masks_softmaxed = masks_softmaxed.view(b, k, 1, h, w) + g = cached_grid(h, w, device=masks_softmaxed.device, dtype=masks_softmaxed.dtype) + x = g[0, 0] / (w - 1) - .5 + y = g[0, 1] / (h - 1) - .5 + + v = (x ** 2 + y ** 2) * 2 + assert len(v.shape) == 2 + v = v.view(*[1] * (len(masks_softmaxed) - 2), h, w) + scores = (masks_softmaxed * (1 - v)).sum([-1, -2]).view(b, k) + scores = scores / (masks_softmaxed.flatten(-3).sum(-1) + 1e-6) + + LOGGER.debug_once(f"Selector -- masks in {masks_softmaxed.shape}; scores {scores.shape}") + + best_idxs = scores.topk(top, dim=-1).indices[..., None, None, None].expand(-1, -1, -1, h, w) + wrst_idxs = (-scores).topk(k - top, dim=-1).indices[..., None, None, None].expand(-1, -1, -1, h, w) + + LOGGER.debug_once(f"Selector -- inds {best_idxs.shape} {wrst_idxs.shape}") + + masks_out = torch.empty(b, 2, 1, h, w, device=masks_softmaxed.device, dtype=masks_softmaxed.dtype) + + centre_most_masks = torch.gather(masks_softmaxed, 1, best_idxs).sum(1, keepdim=True) + other_masks = torch.gather(masks_softmaxed, 1, wrst_idxs).sum(1, keepdim=True) + + LOGGER.debug_once(f"Selector -- best {centre_most_masks.shape} others {other_masks.shape}") + + masks_out[:, 1:] = centre_most_masks + masks_out[:, :1] = other_masks + + return masks_out.view(b, 2, *other_dims, h, w) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
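mask_selector above scores each soft mask by how central its mass is, using a quadratic distance-from-centre prior. A standalone sketch of that scoring; the grid is rebuilt with torch.meshgrid here because the repo's dist.cached_grid helper is not part of this diff:

```python
import torch

def centre_scores(masks_softmaxed: torch.Tensor) -> torch.Tensor:
    """masks_softmaxed: (b, k, h, w) soft masks. Returns (b, k) scores that are
    high for masks whose mass sits near the image centre."""
    b, k, h, w = masks_softmaxed.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    x = xs / (w - 1) - 0.5
    y = ys / (h - 1) - 0.5
    v = (x ** 2 + y ** 2) * 2                      # 0 at the centre, ~1 at the corners
    scores = (masks_softmaxed * (1 - v)).sum((-1, -2))
    return scores / (masks_softmaxed.flatten(-2).sum(-1) + 1e-6)

masks = torch.rand(1, 4, 32, 32).softmax(1)
print(centre_scores(masks))  # higher score -> more centred mask
```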
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/utils/grid.py b/utils/grid.py new file mode 100644 index 0000000000000000000000000000000000000000..52d71b8fd2377cf36c210cc3b8641a4484ef45bb --- /dev/null +++ b/utils/grid.py @@ -0,0 +1,9 @@ +import torch + + +def get_meshgrid(resolution, device): + grid_x, grid_y = torch.meshgrid(torch.arange(resolution[0]).float() / resolution[0], + torch.arange(resolution[1]).float() / resolution[1], indexing='ij') + grid_x = grid_x.to(device) + grid_y = grid_y.to(device) + return grid_x, grid_y diff --git a/utils/log.py b/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..ff73c94d4621218e2fbc29d617be04b5a60bf5db --- /dev/null +++ b/utils/log.py @@ -0,0 +1,101 @@ +import functools +import logging +from pathlib import Path +import shutil +import glob + +from datetime import datetime, timedelta + +_LOG_DICT = {} + +@functools.lru_cache(None) # always the same :) +def get_datestring_for_the_run(): + return datetime.now().strftime("%Y%m%d_%H%M%S") + +def _make_key(msg, args, kwargs): + args_str = ', '.join([str(arg) for arg in args]) + kwargs_str = ', '.join([f'{str(k)}={str(v)}' for k, v in kwargs.items()]) + r = [msg] + if args_str or kwargs_str: + r.append(' % (') + r.append(args_str) + if args_str: + r.append(', ') + r.append(kwargs_str) + if args_str or kwargs_str: + r.append(')') + # MyMessage % (arg1, arg2, kw1=v1m, kw2=v2m) + return ''.join(r) + + +def debug_once(msg, *args, logger=None, **kwargs): + key = _make_key(msg, args, kwargs) + + lvl = logging.DEBUG + t = datetime.now() + should_log = True + + if key in _LOG_DICT: + plvl, pt = _LOG_DICT[key] + # Do not overwrite + if plvl > lvl: + t = pt + lvl = plvl + should_log = False + + _LOG_DICT[key] = (lvl, t) + if should_log: + logger.debug(msg, *args, **kwargs) + + +def info_once(msg, *args, logger=None, **kwargs): + key = _make_key(msg, args, kwargs) + + lvl = logging.INFO + t = datetime.now() + should_log = True + + if key in _LOG_DICT: + plvl, pt = _LOG_DICT[key] + should_log = plvl <= lvl and t - pt > timedelta(minutes=5) + lvl = max(lvl, plvl) + + _LOG_DICT[key] = (lvl, t) + if should_log: + logger.info(msg, *args, **kwargs) + + +def getLogger(name): + if name != 'gwm' and not name.startswith('gwm.'): + name = 'gwm.' 
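The vendored _no_grad_trunc_normal_/trunc_normal_ above follow the same inverse-CDF construction as torch.nn.init.trunc_normal_. A quick sanity check of the clamping bounds and target standard deviation; the printed values are what I would expect, not output taken from the repo:

```python
import torch

w = torch.empty(768, 384)
torch.nn.init.trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)  # in-place init
print(w.min() >= -2.0, w.max() <= 2.0)  # tensor(True) tensor(True) -- clamped to [a, b]
print(round(w.std().item(), 3))         # ~0.02
```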
+ name + logger = logging.getLogger(name) + logger.info_once = functools.partial(info_once, logger=logger) + logger.debug_once = functools.partial(info_once, logger=logger) + return logger + + +def checkpoint_code(log_path): + code_path = Path(log_path) / 'code' + if code_path.exists(): + code_path = code_path.with_name(f'code_{get_datestring_for_the_run()}') + code_path.mkdir(parents=True, exist_ok=True) + for file in glob.glob('*.py'): + shutil.copy(file, code_path) + shutil.copytree('datasets', code_path / 'datasets', ignore=shutil.ignore_patterns('*.pyc', '__pycache__')) + shutil.copytree('losses', code_path / 'losses', ignore=shutil.ignore_patterns('*.pyc', '__pycache__')) + shutil.copytree('utils', code_path / 'utils', ignore=shutil.ignore_patterns('*.pyc', '__pycache__')) + shutil.copytree('mask_former', code_path / 'mask_former', ignore=shutil.ignore_patterns('*.pyc', '__pycache__')) + + +class log_level: + def __init__(self, logger, lvl=logging.INFO): + self.logger = logging.getLogger(logger) + self.lvl = lvl + self.current_lvl = self.logger.level + + def __enter__(self): + self.current_lvl = self.logger.level + self.logger.setLevel(self.lvl) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.logger.setLevel(self.current_lvl) diff --git a/utils/random_state.py b/utils/random_state.py new file mode 100644 index 0000000000000000000000000000000000000000..91d0849c28023db8e2886354e0931dcbe6cdde75 --- /dev/null +++ b/utils/random_state.py @@ -0,0 +1,90 @@ +import os +import random + +import numpy as np +import torch + +from .log import getLogger + +# TODO: finish implementing this + +LOGGER = getLogger(__name__) + + +def worker_init_function(worker_id): + seed = torch.utils.data.get_worker_info().seed + np_seed = seed + if np_seed > 2**32 - 1: + np_seed = seed % (2**32 - 1) - 526 + int(worker_id) + np.random.seed(np_seed) + torch.manual_seed(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + + +def get_randstate_magic_numbers(device=None): + """Use these to check that randstate advances the same accross runs""" + np_int = np.random.randint(0, int(1e6)) + random_int = random.randint(0, int(1e6)) + torch_cpu_int = torch.randint(int(1e6), (1,), device='cpu').item() + if device is not None: + torch_device_int = torch.randint(int(1e6), (1,), device=device).item() + else: + torch_device_int = None + return (random_int, np_int, torch_cpu_int, torch_device_int) + +class PytorchRNGState(torch.nn.Module): + """Class to save/restore PRNG states that masquarades as nn.Module for checkpointing""" + + __RANDOM_PRNG_STATE__ = '__random_prng_state__' + __NUMPY_PRNG_STATE__ = '__numpy_prng_state__' + __TORCH_PRNG_STATE__ = '__torch_prng_state__' + __CUDA_PRNG_STATE__ = '__cuda_prng_state__' + + def __init__(self, seed=42): + super(PytorchRNGState, self).__init__() + self.register_buffer('initial_seed', torch.tensor(seed, dtype=torch.long), persistent=True) + self.register_buffer('already_seeded', torch.tensor(False, dtype=torch.bool), persistent=True) + + @property + def device(self): + return self.initial_seed.device + + def seed_everything(self): + if torch.all(self.already_seeded): # sticky for checkpointing; do only once + return + else: + seed = int(self.initial_seed.item()) + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + self.already_seeded = torch.logical_not(self.already_seeded) # keep it as buffer i.e. 
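The log_level context manager above temporarily changes a named logger's level and restores it on exit. A usage sketch, assuming the full repo (and the modules its utils package imports) is importable; 'gwm.demo' is an arbitrary child logger name:

```python
import logging
from utils import log as log_utils

logger = log_utils.getLogger('gwm.demo')
with log_utils.log_level('gwm', lvl=logging.ERROR):
    logger.info('suppressed inside the context')    # below ERROR, not emitted
logger.info('visible again once the level is restored')
```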
tensor + LOGGER.info(f'Seed set to {seed}') + + def state_dict(self, destination=None, prefix='', keep_vars=False): + state_dict = super(PytorchRNGState, self).state_dict(destination, prefix, keep_vars) + state_dict[self.__RANDOM_PRNG_STATE__] = random.getstate() + state_dict[self.__NUMPY_PRNG_STATE__] = np.random.get_state() + state_dict[self.__TORCH_PRNG_STATE__] = torch.random.get_rng_state() + if torch.cuda.is_available() and 'cuda' in str(self.device): + cuda_state = torch.cuda.get_rng_state(self.device) + state_dict[self.__CUDA_PRNG_STATE__] = cuda_state + return state_dict + + def load_state_dict(self, state_dict, strict=True): + random.setstate(state_dict.pop(self.__RANDOM_PRNG_STATE__)) + np.random.set_state(state_dict.pop(self.__NUMPY_PRNG_STATE__)) + torch.set_rng_state(state_dict.pop(self.__TORCH_PRNG_STATE__)) + LOGGER.debug(f'Restored state to python process and ') + if strict: + if torch.cuda.is_available() and 'cuda' in str(self.device) and self.__CUDA_PRNG_STATE__ not in state_dict: + raise RuntimeError(f'Error in restoring CUDA PRNG state: state missing') + if self.__CUDA_PRNG_STATE__ in state_dict and (torch.cuda.is_available() or 'cuda' not in str(self.device)): + raise RuntimeError(f'Error in restoring CUDA PRNG state: CUDA not available') + if self.__CUDA_PRNG_STATE__ in state_dict and torch.cuda.is_available() and 'cuda' in str(self.device): + torch.cuda.set_rng_state(state_dict.pop(self.__CUDA_PRNG_STATE__), self.device) + return super(PytorchRNGState, self).load_state_dict(state_dict, strict) + + + diff --git a/utils/visualisation.py b/utils/visualisation.py new file mode 100644 index 0000000000000000000000000000000000000000..bf46f2e92d6a88b1a8f1e11809a9bcce95fe29fe --- /dev/null +++ b/utils/visualisation.py @@ -0,0 +1,38 @@ +import colorsys + +import torch +import numpy as np +from cvbase.optflow.visualize import flow2rgb + + +def flow2rgb_torch(x): + return torch.from_numpy(flow2rgb(x.permute(1, 2, 0).numpy())).permute(2, 0, 1) + + +def create_label_colormap(): + """Creates a label colormap used in CITYSCAPES segmentation benchmark. + Returns: + A colormap for visualizing segmentation results. + """ + colormap = np.zeros((256, 3), dtype=np.int64) + colormap[0] = [0, 0, 0] + colormap[1] = [166, 206, 227] + colormap[2] = [31, 120, 180] + colormap[3] = [178, 223, 138] + colormap[4] = [51, 160, 44] + colormap[5] = [251, 154, 153] + colormap[6] = [227, 26, 28] + colormap[7] = [253, 191, 111] + colormap[8] = [255, 127, 0] + colormap[9] = [202, 178, 214] + colormap[10] = [106, 61, 154] + colormap[11] = [255, 255, 153] + colormap[12] = [177, 89, 40] + colormap[13] = [0, 0, 142] + colormap[14] = [0, 0, 70] + colormap[15] = [0, 60, 100] + colormap[16] = [0, 80, 100] + colormap[17] = [0, 0, 230] + colormap[18] = [119, 11, 32] + + return torch.from_numpy(colormap).long() diff --git a/utils/vit_extractor.py b/utils/vit_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..3fa42d9a10d5580af96d9ae3b379018a332dd945 --- /dev/null +++ b/utils/vit_extractor.py @@ -0,0 +1,364 @@ +import argparse +import math +import types +from pathlib import Path +from typing import Union, List, Tuple + +import timm +import torch +import torch.nn.modules.utils as nn_utils +from PIL import Image +from torch import nn +from torchvision import transforms + + +class ViTExtractor: + """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT. 
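PytorchRNGState above checkpoints the Python, NumPy, and torch PRNG streams alongside ordinary buffers so training can resume with identical randomness. A round-trip sketch showing that restoring a saved state replays the same draws; it assumes the repo and its dependencies are importable:

```python
from utils.random_state import PytorchRNGState, get_randstate_magic_numbers

rng = PytorchRNGState(seed=42)
rng.seed_everything()
state = rng.state_dict()                 # captures python/numpy/torch PRNG states
before = get_randstate_magic_numbers()   # advances all three streams
rng.load_state_dict(state)               # rewind to the captured states
after = get_randstate_magic_numbers()    # same draws again
print(before == after)                   # True
```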
+ + We use the following notation in the documentation of the module's methods: + B - batch size + h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW + p - patch size of the ViT. either 8 or 16. + t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width + of the input image. + d - the embedding dimension in the ViT. + """ + + def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'): + """ + :param model_type: A string specifying the type of model to extract from. + [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | + vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224] + :param stride: stride of first convolution layer. small stride -> higher resolution. + :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor. + should be compatible with model_type. + """ + self.model_type = model_type + self.device = device + if model is not None: + self.model = model + else: + self.model = ViTExtractor.create_model(model_type) + + self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride) + self.model.eval() + self.model.to(self.device) + self.p = self.model.patch_embed.patch_size + self.stride = self.model.patch_embed.proj.stride + + self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5) + self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5) + + self._feats = [] + self.hook_handlers = [] + self.load_size = None + self.num_patches = None + + @staticmethod + def create_model(model_type: str) -> nn.Module: + """ + :param model_type: a string specifying which model to load. [dino_vits8 | dino_vits16 | dino_vitb8 | + dino_vitb16 | vit_small_patch8_224 | vit_small_patch16_224 | vit_base_patch8_224 | + vit_base_patch16_224] + :return: the model + """ + if 'dino' in model_type: + model = torch.hub.load('facebookresearch/dino:main', model_type) + else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images). + temp_model = timm.create_model(model_type, pretrained=True) + model_type_dict = { + 'vit_small_patch16_224': 'dino_vits16', + 'vit_small_patch8_224': 'dino_vits8', + 'vit_base_patch16_224': 'dino_vitb16', + 'vit_base_patch8_224': 'dino_vitb8' + } + model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type]) + temp_state_dict = temp_model.state_dict() + del temp_state_dict['head.weight'] + del temp_state_dict['head.bias'] + model.load_state_dict(temp_state_dict) + return model + + @staticmethod + def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]): + """ + Creates a method for position encoding interpolation. + :param patch_size: patch size of the model. + :param stride_hw: A tuple containing the new height and width stride respectively. 
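The constructor above lets a smaller first-convolution stride trade compute for a denser token grid. A quick sketch of the resulting grid arithmetic, matching the grid-size computation used in _fix_pos_enc and _extract_features below; the numbers are for dino_vits8 (8-pixel patches) at stride 4:

```python
def token_grid(h: int, w: int, patch: int = 8, stride: int = 4) -> tuple:
    # one token per stride step that still fits a full patch, per spatial axis
    return 1 + (h - patch) // stride, 1 + (w - patch) // stride

gh, gw = token_grid(224, 224, patch=8, stride=4)
print(gh, gw, gh * gw + 1)   # 55 55 3026 tokens (patch tokens plus the CLS token)
```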
+ :return: the interpolation method + """ + + def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + # compute number of tokens taking stride into account + w0 = 1 + (w - patch_size) // stride_hw[1] + h0 = 1 + (h - patch_size) // stride_hw[0] + assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and + stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}""" + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + align_corners=False, recompute_scale_factor=False + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return interpolate_pos_encoding + + @staticmethod + def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module: + """ + change resolution of model output by changing the stride of the patch extraction. + :param model: the model to change resolution for. + :param stride: the new stride parameter. + :return: the adjusted model + """ + patch_size = model.patch_embed.patch_size + if stride == patch_size: # nothing to do + return model + + stride = nn_utils._pair(stride) + assert all([(patch_size // s_) * s_ == patch_size for s_ in + stride]), f'stride {stride} should divide patch_size {patch_size}' + + # fix the stride + model.patch_embed.proj.stride = stride + # fix the positional encoding code + model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model) + return model + + def preprocess(self, image_path: Union[str, Path], + load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]: + """ + Preprocesses an image before extraction. + :param image_path: path to image to be extracted. + :param load_size: optional. Size to resize image before the rest of preprocessing. + :return: a tuple containing: + (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW. + (2) the pil image in relevant dimensions + """ + pil_image = Image.open(image_path).convert('RGB') + if load_size is not None: + pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image) + prep = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=self.mean, std=self.std) + ]) + prep_img = prep(pil_image)[None, ...] + return prep_img, pil_image + + def _get_hook(self, facet: str): + """ + generate a hook method for a specific block and facet. 
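patch_vit_resolution above only accepts strides that divide the patch size (see its assert). A one-liner showing which strides are legal for the 8-pixel patches of dino_vits8:

```python
patch_size = 8
legal = [s for s in range(1, patch_size + 1) if (patch_size // s) * s == patch_size]
print(legal)   # [1, 2, 4, 8]
```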
+ """ + if facet in ['attn', 'token']: + def _hook(model, input, output): + self._feats.append(output) + + return _hook + + if facet == 'query': + facet_idx = 0 + elif facet == 'key': + facet_idx = 1 + elif facet == 'value': + facet_idx = 2 + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _inner_hook(module, input, output): + input = input[0] + B, N, C = input.shape + qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) + self._feats.append(qkv[facet_idx]) # Bxhxtxd + + return _inner_hook + + def _register_hooks(self, layers: List[int], facet: str) -> None: + """ + register hook to extract features. + :param layers: layers from which to extract features. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + """ + for block_idx, block in enumerate(self.model.blocks): + if block_idx in layers: + if facet == 'token': + self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) + elif facet == 'attn': + self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) + elif facet in ['key', 'query', 'value']: + self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _unregister_hooks(self) -> None: + """ + unregisters the hooks. should be called after feature extraction. + """ + for handle in self.hook_handlers: + handle.remove() + self.hook_handlers = [] + + def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]: + """ + extract features from the model + :param batch: batch to extract features for. Has shape BxCxHxW. + :param layers: layer to extract. A number between 0 to 11. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + :return : tensor of features. + if facet is 'key' | 'query' | 'value' has shape Bxhxtxd + if facet is 'attn' has shape Bxhxtxt + if facet is 'token' has shape Bxtxd + """ + B, C, H, W = batch.shape + self._feats = [] + self._register_hooks(layers, facet) + _ = self.model(batch) + self._unregister_hooks() + self.load_size = (H, W) + self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) + return self._feats + + def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: + """ + create a log-binned descriptor. + :param x: tensor of features. Has shape Bxhxtxd. + :param hierarchy: how many bin hierarchies to use. + """ + B = x.shape[0] + num_bins = 1 + 8 * hierarchy + + bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1) # Bx(t-1)x(dxh) + bin_x = bin_x.permute(0, 2, 1) + bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1]) + # Bx(dxh)xnum_patches[0]xnum_patches[1] + sub_desc_dim = bin_x.shape[1] + + avg_pools = [] + # compute bins of all sizes for all spatial locations. 
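_get_hook and _register_hooks above capture intermediate ViT activations with forward hooks and then detach them in _unregister_hooks. The same pattern on a toy module; the model and hook names here are illustrative only:

```python
import torch
from torch import nn

feats = []
def save_output(module, inputs, output):
    feats.append(output.detach())        # stash the activation, like self._feats

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
handle = model[1].register_forward_hook(save_output)   # hook one layer, like one ViT block
_ = model(torch.randn(3, 4))
handle.remove()                                          # mirrors _unregister_hooks
print(feats[0].shape)                                    # torch.Size([3, 8])
```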
+ for k in range(0, hierarchy): + # avg pooling with kernel 3**kx3**k + win_size = 3 ** k + avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False) + avg_pools.append(avg_pool(bin_x)) + + bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device) + for y in range(self.num_patches[0]): + for x in range(self.num_patches[1]): + part_idx = 0 + # fill all bins for a spatial location (y, x) + for k in range(0, hierarchy): + kernel_size = 3 ** k + for i in range(y - kernel_size, y + kernel_size + 1, kernel_size): + for j in range(x - kernel_size, x + kernel_size + 1, kernel_size): + if i == y and j == x and k != 0: + continue + if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]: + bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ + :, :, i, j] + else: # handle padding in a more delicate way than zero padding + temp_i = max(0, min(i, self.num_patches[0] - 1)) + temp_j = max(0, min(j, self.num_patches[1] - 1)) + bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = avg_pools[k][ + :, :, temp_i, + temp_j] + part_idx += 1 + bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1) + # Bx1x(t-1)x(dxh) + return bin_x + + def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key', + bin: bool = False, include_cls: bool = False) -> torch.Tensor: + """ + extract descriptors from the model + :param batch: batch to extract descriptors for. Has shape BxCxHxW. + :param layers: layer to extract. A number between 0 to 11. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token'] + :param bin: apply log binning to the descriptor. default is False. + :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors. + """ + assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors. + choose from ['key' | 'query' | 'value' | 'token'] """ + self._extract_features(batch, [layer], facet) + x = self._feats[0] + if facet == 'token': + x.unsqueeze_(dim=1) # Bx1xtxd + if not include_cls: + x = x[:, :, 1:, :] # remove cls token + else: + assert not bin, "bin = True and include_cls = True are not supported together, set one of them False." + if not bin: + desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) + else: + desc = self._log_bin(x) + return desc + + def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor: + """ + extract saliency maps. The saliency maps are extracted by averaging several attention heads from the last layer + in of the CLS token. All values are then normalized to range between 0 and 1. + :param batch: batch to extract saliency maps for. Has shape BxCxHxW. + :return: a tensor of saliency maps. has shape Bxt-1 + """ + assert self.model_type == "dino_vits8", f"saliency maps are supported only for dino_vits model_type." 
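An end-to-end usage sketch for extract_descriptors above, mirroring the __main__ block further below but kept on CPU; 'cat.jpg' stands in for a real image path, and the first call downloads DINO weights via torch.hub:

```python
import torch
from utils.vit_extractor import ViTExtractor

with torch.no_grad():
    extractor = ViTExtractor('dino_vits8', stride=4, device='cpu')
    batch, pil_img = extractor.preprocess('cat.jpg', load_size=128)
    desc = extractor.extract_descriptors(batch, layer=11, facet='key', bin=False)
    print(desc.shape)   # (1, 1, #patch tokens, head_dim * num_heads) key descriptors
```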
+ self._extract_features(batch, [11], 'attn') + head_idxs = [0, 2, 4, 5] + curr_feats = self._feats[0] # Bxhxtxt + cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1) # Bx(t-1) + temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0] + cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins) # normalize to range [0,1] + return cls_attn_maps + + +""" taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse""" + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Facilitate ViT Descriptor extraction.') + parser.add_argument('--image_path', type=str, required=True, help='path of the extracted image.') + parser.add_argument('--output_path', type=str, required=True, help='path to file containing extracted descriptors.') + parser.add_argument('--load_size', default=224, type=int, help='load size of the input image.') + parser.add_argument('--stride', default=4, type=int, help="""stride of first convolution layer. + small stride -> higher resolution.""") + parser.add_argument('--model_type', default='dino_vits8', type=str, + help="""type of model to extract. + Choose from [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | + vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]""") + parser.add_argument('--facet', default='key', type=str, help="""facet to create descriptors from. + options: ['key' | 'query' | 'value' | 'token']""") + parser.add_argument('--layer', default=11, type=int, help="layer to create descriptors from.") + parser.add_argument('--bin', default='False', type=str2bool, help="create a binned descriptor if True.") + + args = parser.parse_args() + + with torch.no_grad(): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + extractor = ViTExtractor(args.model_type, args.stride, device=device) + image_batch, image_pil = extractor.preprocess(args.image_path, args.load_size) + print(f"Image {args.image_path} is preprocessed to tensor of size {image_batch.shape}.") + descriptors = extractor.extract_descriptors(image_batch.to(device), args.layer, args.facet, args.bin) + print(f"Descriptors are of size: {descriptors.shape}") + torch.save(descriptors, args.output_path) + print(f"Descriptors saved to: {args.output_path}")
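The CLI above saves the extracted descriptors with torch.save. A sketch of loading them back for downstream use; the command line mirrors the argparse flags defined above, and 'descriptors.pt' is a placeholder for whatever --output_path was given:

```python
# e.g. after: python utils/vit_extractor.py --image_path cat.jpg --output_path descriptors.pt
import torch

descriptors = torch.load('descriptors.pt', map_location='cpu')
print(descriptors.shape)   # Bx1xtxd', as documented by extract_descriptors
```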