diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bd764a800a9d80da448fe912b9e8263364fdc229 --- /dev/null +++ b/.gitignore @@ -0,0 +1,131 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/\ + +flagged/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b076d86084a9743afbd07dac765b7fdabb8e064f --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, Aastha Singh +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
index 91e7a1244e05851f5bc1073302d94ebb97f2321d..7fd235c0f0d7f48e345b3fb1c9eae7903a67cdd3 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,94 @@
----
-title: GLIP BLIP Object Detection VQA
-emoji: 📊
-colorFrom: indigo
-colorTo: pink
-sdk: gradio
-sdk_version: 3.4.1
-app_file: app.py
-pinned: false
-license: bsd-3-clause
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Vision-Language Object Detection and Visual Question Answering
+This repository combines Microsoft's GLIP and Salesforce's BLIP into an ensembled demo for text-prompted object detection and Visual Question Answering.
+
+
+
+## About GLIP: Grounded Language-Image Pre-training
+> GLIP demonstrates strong zero-shot and few-shot transferability to various object-level recognition tasks.
+
+> The model used in this repo is GLIP-T, which was originally pre-trained on Conceptual Captions 3M and SBU captions.
+
+
+
+## About BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
+
+> BLIP introduces a new model architecture that enables a wider range of downstream tasks than existing methods, together with a new dataset bootstrapping method for learning from noisy web data.
+
+
+
+## Installation and Setup
+
+***Environment*** - Due to limitations with `maskrcnn_benchmark`, this repo requires PyTorch 1.10 and torchvision.
+
+Use `requirements.txt` to install the dependencies:
+
+```sh
+pip3 install -r requirements.txt
+```
+Build `maskrcnn_benchmark`:
+```sh
+python setup.py build develop --user
+```
+
+To verify a successful build, check the terminal for the message
+"Finished processing dependencies for maskrcnn-benchmark==0.1"
+
+## Checkpoints
+
+> Download the pre-trained models into the `checkpoints` folder.
+
+
+ +```sh +mkdir checkpoints +cd checkpoints +``` + +Model | Weight +-- | -- +**GLIP-T** | [weight](https://drive.google.com/file/d/1nlPL6PHkslarP6RiWJJu6QGKjqHG4tkc/view?usp=sharing) +**BLIP** | [weight](https://drive.google.com/file/d/1QliNGiAcyCCJLd22eNOxWvMUDzb7GzrO/view?usp=sharing) + +
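+
+If you prefer to script the download instead of using the browser links above, one possible approach is sketched below. This is not part of the original setup: it assumes the `gdown` package is installed separately (e.g. `pip3 install gdown`), and the Google Drive file IDs are taken directly from the links in the table.
+
+```python
+# download_checkpoints.py -- hypothetical helper, not shipped with this repo.
+# Fetches the two pre-trained weights into ./checkpoints using gdown.
+import os
+import gdown
+
+os.makedirs("checkpoints", exist_ok=True)
+
+# GLIP-T weights (file ID from the GLIP-T link above)
+gdown.download(
+    "https://drive.google.com/uc?id=1nlPL6PHkslarP6RiWJJu6QGKjqHG4tkc",
+    "checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth",
+    quiet=False,
+)
+
+# BLIP VQA weights (file ID from the BLIP link above)
+gdown.download(
+    "https://drive.google.com/uc?id=1QliNGiAcyCCJLd22eNOxWvMUDzb7GzrO",
+    "checkpoints/model_base_vqa_capfilt_large.pth",
+    quiet=False,
+)
+```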
+
+## If you have an NVIDIA GPU with 8GB VRAM, run the local demo using the Gradio interface
+
+```sh
+python3 app.py
+```
+## Future Work
+
+- [x] Frame-based Visual Question Answering
+- [ ] Per-object Visual Question Answering
+
+
+## Citations
+
+```txt
+@inproceedings{li2022blip,
+  title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
+  author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
+  year={2022},
+  booktitle={ICML},
+}
+@inproceedings{li2021grounded,
+  title={Grounded Language-Image Pre-training},
+  author={Liunian Harold Li* and Pengchuan Zhang* and Haotian Zhang* and Jianwei Yang and Chunyuan Li and Yiwu Zhong and Lijuan Wang and Lu Yuan and Lei Zhang and Jenq-Neng Hwang and Kai-Wei Chang and Jianfeng Gao},
+  year={2022},
+  booktitle={CVPR},
+}
+@article{zhang2022glipv2,
+  title={GLIPv2: Unifying Localization and Vision-Language Understanding},
+  author={Zhang, Haotian* and Zhang, Pengchuan* and Hu, Xiaowei and Chen, Yen-Chun and Li, Liunian Harold and Dai, Xiyang and Wang, Lijuan and Yuan, Lu and Hwang, Jenq-Neng and Gao, Jianfeng},
+  journal={arXiv preprint arXiv:2206.05836},
+  year={2022}
+}
+@article{li2022elevater,
+  title={ELEVATER: A Benchmark and Toolkit for Evaluating Language-Augmented Visual Models},
+  author={Li*, Chunyuan and Liu*, Haotian and Li, Liunian Harold and Zhang, Pengchuan and Aneja, Jyoti and Yang, Jianwei and Jin, Ping and Lee, Yong Jae and Hu, Houdong and Liu, Zicheng and others},
+  journal={arXiv preprint arXiv:2204.08790},
+  year={2022}
+}
+```
+## Acknowledgement
+The implementation of this work relies on resources from BLIP, GLIP, Hugging Face Transformers, and timm. We thank the original authors for open-sourcing their work.
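+
+## Programmatic usage (optional)
+
+The Gradio app in `app.py` simply chains the two models: GLIP grounds the detection prompt in the image, and BLIP answers the question on the same frame. Below is a minimal headless sketch of that flow, reusing the same calls that `app.py` makes (`GLIPDemo.run_on_web_image`, `vqa.VQA.vqa_demo`). It assumes the build and checkpoint steps above are complete and a CUDA GPU is available; `demo.jpg`, the object prompt, and the question are placeholders for your own inputs.
+
+```python
+# Minimal sketch of the GLIP + BLIP chain without the Gradio UI,
+# mirroring the calls made in app.py.
+import numpy as np
+from PIL import Image
+
+from maskrcnn_benchmark.config import cfg
+from maskrcnn_benchmark.engine.predictor_glip import GLIPDemo
+import vqa
+
+config_file = "configs/glip_Swin_T_O365_GoldG.yaml"
+weight_file = "checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth"
+
+cfg.local_rank = 0
+cfg.num_gpus = 1
+cfg.merge_from_file(config_file)
+cfg.merge_from_list(["MODEL.WEIGHT", weight_file])
+cfg.merge_from_list(["MODEL.DEVICE", "cuda"])
+
+glip_demo = GLIPDemo(cfg, min_image_size=800, confidence_threshold=0.7, show_mask_heatmaps=False)
+blip_demo = vqa.VQA(model_path="checkpoints/model_base_vqa_capfilt_large.pth")
+
+image = np.array(Image.open("demo.jpg").convert("RGB"))  # placeholder image, RGB HxWxC
+objects = "person. dog"                                   # detection prompt, classes separated by ". "
+question = "What is the person doing?"                    # placeholder question
+
+# GLIP expects BGR input, so the channel order is flipped (exactly as in app.py).
+detections, _ = glip_demo.run_on_web_image(image[:, :, [2, 1, 0]], objects, 0.5)
+answer = blip_demo.vqa_demo(image, question)
+print(answer)
+```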
diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6a03c5bcd0e21ebcacac2453c128052f9deac0 --- /dev/null +++ b/app.py @@ -0,0 +1,57 @@ +import os +import gradio as gr +import warnings + +warnings.filterwarnings("ignore") + +os.system("python setup.py build develop --user") + +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.engine.predictor_glip import GLIPDemo +import vqa +import vqa + +# Use this command for evaluate the GLIP-T model +config_file = "configs/glip_Swin_T_O365_GoldG.yaml" +weight_file = "checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth" + +# manual override some options +cfg.local_rank = 0 +cfg.num_gpus = 1 +cfg.merge_from_file(config_file) +cfg.merge_from_list(["MODEL.WEIGHT", weight_file]) +cfg.merge_from_list(["MODEL.DEVICE", "cuda"]) + +glip_demo = GLIPDemo( + cfg, + min_image_size=800, + confidence_threshold=0.7, + show_mask_heatmaps=False +) +blip_demo = vqa.VQA( + model_path = 'checkpoints/model_base_vqa_capfilt_large.pth' +) + +def predict(image, object, question): + result, _ = glip_demo.run_on_web_image(image[:, :, [2, 1, 0]], object, 0.5) + answer = blip_demo.vqa_demo(image, question) + return result[:, :, [2, 1, 0]], answer + +image = gr.inputs.Image() + +gr.Interface( + description="GLIP + BLIP VQA Demo.", + fn=predict, + inputs=[ + "image", + gr.Textbox(label='Objects', lines=1, placeholder="Objects here.."), + gr.Textbox(label='Question', lines=1, placeholder="Question here..")], + + outputs=[ + gr.outputs.Image( + type="pil", + label="grounding results" + ), + gr.Textbox(label="Answer") + ], +).launch() \ No newline at end of file diff --git a/checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth b/checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth new file mode 100644 index 0000000000000000000000000000000000000000..d05b8d5d3318107871c13ca068ee094644600779 --- /dev/null +++ b/checkpoints/glip_tiny_model_o365_goldg_cc_sbu.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bec0a3dea804fcb278d7106c5438de5116ee888e49dfae46270e7ad7bc4ccbf +size 3710104213 diff --git a/checkpoints/model_base_vqa_capfilt_large.pth b/checkpoints/model_base_vqa_capfilt_large.pth new file mode 100644 index 0000000000000000000000000000000000000000..df8c62ad684ab84409a19a947cd33b920b78b5ad --- /dev/null +++ b/checkpoints/model_base_vqa_capfilt_large.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a7d546209f1ccfa8b3cd3a0138c53e0d1e95e4a4bc280bef8f67e20fe4925ae +size 1446244375 diff --git a/configs/glip_Swin_T_O365_GoldG.yaml b/configs/glip_Swin_T_O365_GoldG.yaml new file mode 100644 index 0000000000000000000000000000000000000000..80b9edba1b47a83f5da99254dd081dac3f80354a --- /dev/null +++ b/configs/glip_Swin_T_O365_GoldG.yaml @@ -0,0 +1,100 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedVLRCNN" + WEIGHT: "swin_tiny_patch4_window7_224.pth" + RPN_ONLY: True + RPN_ARCHITECTURE: "VLDYHEAD" + + BACKBONE: + CONV_BODY: "SWINT-FPN-RETINANET" + OUT_CHANNELS: 256 + FREEZE_CONV_BODY_AT: -1 + + LANGUAGE_BACKBONE: + FREEZE: False + MODEL_TYPE: "bert-base-uncased" # "roberta-base", "clip" + MASK_SPECIAL: False + + RPN: + USE_FPN: True + ANCHOR_SIZES: (64, 128, 256, 512, 1024) + ANCHOR_STRIDE: (8, 16, 32, 64, 128) + ASPECT_RATIOS: (1.0,) + SCALES_PER_OCTAVE: 1 + + DYHEAD: + CHANNELS: 256 + NUM_CONVS: 6 + USE_GN: True + USE_DYRELU: True + USE_DFCONV: True + USE_DYFUSE: True + TOPK: 9 # topk for selecting candidate positive samples from each level + SCORE_AGG: "MEAN" + LOG_SCALE: 0.0 + + 
FUSE_CONFIG: + EARLY_FUSE_ON: True + TYPE: "MHA-B" + USE_CLASSIFICATION_LOSS: False + USE_TOKEN_LOSS: False + USE_CONTRASTIVE_ALIGN_LOSS: False + CONTRASTIVE_HIDDEN_DIM: 64 + USE_DOT_PRODUCT_TOKEN_LOSS: True + USE_FUSED_FEATURES_DOT_PRODUCT: True + USE_LAYER_SCALE: True + CLAMP_MIN_FOR_UNDERFLOW: True + CLAMP_MAX_FOR_OVERFLOW: True + CLAMP_BERTATTN_MIN_FOR_UNDERFLOW: True + CLAMP_BERTATTN_MAX_FOR_OVERFLOW: True + CLAMP_DOT_PRODUCT: True + + USE_CHECKPOINT: True + +TEST: + DURING_TRAINING: False + IMS_PER_BATCH: 64 + +# use for grounding model +DATASETS: + TRAIN: ("object365_dt_train", "mixed_train_no_coco", "flickr30k_train", ) + TEST: ("coco_2017_val", ) + DISABLE_SHUFFLE: False + ADD_DET_PROMPT: False + RANDOM_SAMPLE_NEG: 85 + CONTROL_PROB: (0.0, 0.0, 0.5, 0.0) + + SEPARATION_TOKENS: ". " + +INPUT: + PIXEL_MEAN: [ 103.530, 116.280, 123.675 ] + PIXEL_STD: [ 57.375, 57.120, 58.395 ] + MIN_SIZE_TRAIN: 800 + MAX_SIZE_TRAIN: 1333 + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1333 + +AUGMENT: + MULT_MIN_SIZE_TRAIN: (480,560,640,720,800) + +DATALOADER: + SIZE_DIVISIBILITY: 32 + +SOLVER: + OPTIMIZER: ADAMW + BASE_LR: 0.0001 + LANG_LR: 0.00001 + WEIGHT_DECAY: 0.0001 + STEPS: (0.67, 0.89) + MAX_EPOCH: 30 + IMS_PER_BATCH: 64 + WARMUP_ITERS: 2000 + WARMUP_FACTOR: 0.001 + USE_AMP: True + MODEL_EMA: 0.999 + FIND_UNUSED_PARAMETERS: False + + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "full_model" + CLIP_VALUE: 1.0 + NORM_TYPE: 2.0 \ No newline at end of file diff --git a/configs/med_config.json b/configs/med_config.json new file mode 100644 index 0000000000000000000000000000000000000000..0ffad0a6f3c2f9f11b8faa84529d9860bb70327a --- /dev/null +++ b/configs/med_config.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30524, + "encoder_width": 768, + "add_cross_attention": true +} diff --git a/configs/vqa.yaml b/configs/vqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..74327e6d0a34672023b44569558fe8beeb052548 --- /dev/null +++ b/configs/vqa.yaml @@ -0,0 +1,25 @@ +vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/ +vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/ +train_files: ['vqa_train','vqa_val','vg_qa'] +ann_root: 'annotation' + +# set pretrained as a file path or an url +pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' + +# size of vit model; base or large +vit: 'base' +batch_size_train: 16 +batch_size_test: 32 +vit_grad_ckpt: False +vit_ckpt_layer: 0 +init_lr: 2e-5 + +image_size: 480 + +k_test: 128 +inference: 'rank' + +# optimizer +weight_decay: 0.05 +min_lr: 0 +max_epoch: 10 \ No newline at end of file diff --git a/itm.py b/itm.py new file mode 100644 index 0000000000000000000000000000000000000000..6da8af6dfe782beff41de4efb952f481fa97a6c6 --- /dev/null +++ b/itm.py @@ -0,0 +1,77 @@ +import sys +from PIL import Image +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from models.blip_vqa import blip_vqa +from models.blip_itm import blip_itm + + +class VQA: + def __init__(self, model_path, 
image_size=480):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
+        self.model.eval()
+        self.model = self.model.to(self.device)
+
+    def load_demo_image(self, image_size, img_path, device):
+        raw_image = Image.open(img_path).convert('RGB')
+        w,h = raw_image.size
+        transform = transforms.Compose([
+            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+            ])
+        image = transform(raw_image).unsqueeze(0).to(device)
+        return raw_image, image
+
+    def vqa(self, img_path, question):
+        raw_image, image = self.load_demo_image(image_size=480, img_path=img_path, device=self.device)
+        with torch.no_grad():
+            answer = self.model(image, question, train=False, inference='generate')
+            return answer[0]
+
+class ITM:
+    def __init__(self, model_path, image_size=384):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model = blip_itm(pretrained=model_path, image_size=image_size, vit='base')
+        self.model.eval()
+        # Keep the model on the device selected above instead of forcing CPU.
+        self.model = self.model.to(self.device)
+
+    def load_demo_image(self, image_size, img_path, device):
+        raw_image = Image.open(img_path).convert('RGB')
+        w,h = raw_image.size
+        transform = transforms.Compose([
+            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+            ])
+        image = transform(raw_image).unsqueeze(0).to(device)
+        return raw_image, image
+
+    def itm(self, img_path, caption):
+        raw_image, image = self.load_demo_image(image_size=384,img_path=img_path, device=self.device)
+        itm_output = self.model(image,caption,match_head='itm')
+        itm_score = torch.nn.functional.softmax(itm_output,dim=1)[:,1]
+        itc_score = self.model(image,caption,match_head='itc')
+        return itm_score, itc_score
+
+if __name__=="__main__":
+    if not len(sys.argv) == 3:
+        print('Format: python3 itm.py <path_to_image> "<caption>"')
+        print('Sample: python3 itm.py sample.jpg "a horse standing in a field"')
+
+    else:
+        model_path = 'checkpoints/model_base_vqa_capfilt_large.pth'
+        model2_path = 'model_base_retrieval_coco.pth'
+        # vqa_object = VQA(model_path=model_path)
+        itm_object = ITM(model_path=model2_path)
+        img_path = sys.argv[1]
+        caption = sys.argv[2]
+        # answer = vqa_object.vqa(img_path, caption)
+        itm_score, itc_score = itm_object.itm(img_path, caption)
+        # print('Question: {} | Answer: {}'.format(caption, answer))
+        print('Caption: {}'.format(caption))
+        print('The image and text are matched with a probability of %.4f' % itm_score)
+        print('The image feature and text feature have a cosine similarity of %.4f' % itc_score)
+
diff --git a/maskrcnn_benchmark/__init__.py b/maskrcnn_benchmark/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bc96c7a6bf8379e1adfb3e4adf536107b385fa9
--- /dev/null
+++ b/maskrcnn_benchmark/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
diff --git a/maskrcnn_benchmark/config/__init__.py b/maskrcnn_benchmark/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2015d6bd830bc3e0ec8b1ca7fcb63b4781a41ad --- /dev/null +++ b/maskrcnn_benchmark/config/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .defaults import _C as cfg +from .paths_catalog import try_to_find \ No newline at end of file diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..bd62a9ea307b727e0db06985264707046e8c7234 --- /dev/null +++ b/maskrcnn_benchmark/config/defaults.py @@ -0,0 +1,861 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import os + +from yacs.config import CfgNode as CN + +# ----------------------------------------------------------------------------- +# Convention about Training / Test specific parameters +# ----------------------------------------------------------------------------- +# Whenever an argument can be either used for training or for testing, the +# corresponding name will be post-fixed by a _TRAIN for a training parameter, +# or _TEST for a test-specific parameter. +# For example, the number of images during training will be +# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be +# IMAGES_PER_BATCH_TEST + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() + +_C.MODEL = CN() +_C.MODEL.RPN_ONLY = False +_C.MODEL.BOX_ON = True +_C.MODEL.MASK_ON = False +_C.MODEL.KEYPOINT_ON = False +_C.MODEL.DEVICE = "cuda" + +_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN" + +_C.MODEL.RPN_ARCHITECTURE = "RPN" +_C.MODEL.DEBUG = False # add debug flag +_C.MODEL.ONNX = False # add onnx flag + +# If the WEIGHT starts with a catalog://, like :R-50, the code will look for +# the path in paths_catalog. Else, it will use it as the specified absolute +# path +_C.MODEL.WEIGHT = "" +_C.MODEL.PRETRAIN_NAME = "" + +# If LINEAR_PROB = True, only the last linear layers in rpn and roi_head are trainable +_C.MODEL.LINEAR_PROB = False + +# ----------------------------------------------------------------------------- +# Multitask Training / Test specific parameters +# ----------------------------------------------------------------------------- +_C.MODEL.MULTITASK = CN(new_allowed=True) + +# ----------------------------------------------------------------------------- +# INPUT +# ----------------------------------------------------------------------------- +_C.INPUT = CN() +# Size of the smallest side of the image during training +_C.INPUT.MIN_SIZE_TRAIN = 800 # (800,) +# Maximum size of the side of the image during training +_C.INPUT.MAX_SIZE_TRAIN = 1333 +# Size of the smallest side of the image during testing +_C.INPUT.MIN_SIZE_TEST = 800 +# Maximum size of the side of the image during testing +_C.INPUT.MAX_SIZE_TEST = 1333 +# Values to be used for image normalization +_C.INPUT.PIXEL_MEAN = [102.9801, 115.9465, 122.7717] +# Values to be used for image normalization +_C.INPUT.PIXEL_STD = [1., 1., 1.] 
+# Convert image to BGR format (for Caffe2 models), in range 0-255 +_C.INPUT.TO_BGR255 = True +_C.INPUT.FORMAT = '' +_C.INPUT.FIX_RES = False + +# ----------------------------------------------------------------------------- +# Augmentation +# ----------------------------------------------------------------------------- +_C.AUGMENT = CN() +_C.AUGMENT.USE_RA = 0 +_C.AUGMENT.FLIP_PROB_TRAIN = 0.5 +_C.AUGMENT.VERTICAL_FLIP_PROB_TRAIN = 0.0 +_C.AUGMENT.MULT_MIN_SIZE_TRAIN = () + +_C.AUGMENT.BRIGHTNESS = 0.0 +_C.AUGMENT.CONTRAST = 0.0 +_C.AUGMENT.SATURATION = 0.0 +_C.AUGMENT.HUE = 0.0 + +_C.AUGMENT.CROP_PROB = 0.5 +_C.AUGMENT.CROP_MIN_IOUS = (0.1, 0.3, 0.5, 0.7, 0.9) +_C.AUGMENT.CROP_MIN_SIZE = 0.3 + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASETS = CN() +# List of the dataset names for training, as present in paths_catalog.py +_C.DATASETS.TRAIN = () +# List of the dataset names for testing, as present in paths_catalog.py +_C.DATASETS.TEST = () +# Use is_crowd label +_C.DATASETS.USE_CROWD = False +_C.DATASETS.CLASS_AGNOSTIC = False +_C.DATASETS.CLASS_CONCAT = False +_C.DATASETS.MAX_BOX = -1 +_C.DATASETS.SAMPLE_RATIO = 0.0 +_C.DATASETS.FEW_SHOT = 0 +# SHUFFLE_SEED != 0 means shuffle the dataset in the few shot setting +_C.DATASETS.SHUFFLE_SEED = 0 +_C.DATASETS.PREDEFINED_TEXT = '' +_C.DATASETS.ALTERNATIVE_TRAINING = False +_C.DATASETS.MULTISTAGE_TRAINING = False +_C.DATASETS.REGISTER = CN(new_allowed=True) +_C.DATASETS.BOX_THRESHOLD = 0.1 +# Duplicate Dataset +_C.DATASETS.COCO_COPY = 1 +_C.DATASETS.LVIS_COPY = 1 +_C.DATASETS.FLICKR_COPY = 1 +_C.DATASETS.MIXED_COPY = 1 +_C.DATASETS.OBJECT365_COPY = 1 +_C.DATASETS.VG_COPY = 1 +_C.DATASETS.OI_COPY = 1 +_C.DATASETS.IN_COPY = 1 + +# Duplicate Dataset +_C.DATASETS.COCO_COPY = 1 +_C.DATASETS.FLICKR_COPY = 1 +_C.DATASETS.MIXED_COPY = 1 +_C.DATASETS.OBJECT365_COPY = 1 +_C.DATASETS.VG_COPY = 1 +_C.DATASETS.OI_COPY = 1 +_C.DATASETS.IN_COPY = 1 +_C.DATASETS.GENERAL_COPY = -1 +_C.DATASETS.GENERAL_COPY_TEST = -1 + +# OD to Grounding +_C.DATASETS.RANDOM_SAMPLE_NEG = -1 +_C.DATASETS.ADD_DET_PROMPT = False +_C.DATASETS.ADD_DET_PROMPT_ADVANCED = False +_C.DATASETS.USE_OD_AUG = False +_C.DATASETS.USE_COCO_FORMAT = False +_C.DATASETS.CONTROL_PROB = () +_C.DATASETS.DISABLE_SHUFFLE = False +_C.DATASETS.PROMPT_VERSION = "" +_C.DATASETS.PROMPT_LIMIT_NEG = -1 +_C.DATASETS.POS_QUESTION_PROB = 0.6 +_C.DATASETS.NEG_QUESTION_PROB = 0.8 +_C.DATASETS.FULL_QUESTION_PROB = 0.5 +_C.DATASETS.ONE_HOT = False +_C.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT = False + +_C.DATASETS.DISABLE_CLIP_TO_IMAGE = False +_C.DATASETS.SEPARATION_TOKENS = " " + +# LVIS +_C.DATASETS.LVIS_USE_NORMAL_AP = False +_C.DATASETS.SPECIAL_SAFEGUARD_FOR_COCO_GROUNDING = False + +# Caption +_C.DATASETS.BING_INDEX_LIST = [] +_C.DATASETS.CAPTION_MIN_BOX = 1 +_C.DATASETS.REPLACE_CLEAN_LABEL = False +_C.DATASETS.FURTHER_SCREEN = False +_C.DATASETS.CAPTION_CONF = 0.9 +_C.DATASETS.CAPTION_NMS = 0.9 +_C.DATASETS.PACK_RANDOM_CAPTION_NUMBER = 0 +_C.DATASETS.INFERENCE_CAPTION = False +_C.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA = -1.0 +_C.DATASETS.RANDOM_PACK_PROB = -1.0 +_C.DATASETS.NO_RANDOM_PACK_PROBABILITY = 0.0 +_C.DATASETS.SAFEGUARD_POSITIVE_CAPTION = True +_C.DATASETS.CAPTION_FORMAT_VERSION = "v1" +_C.DATASETS.LOCAL_DEBUG = False + + +# Od in the wild +_C.DATASETS.PREDEFINED_TEXT = None +_C.DATASETS.TRAIN_DATASETNAME_SUFFIX = "" +_C.DATASETS.TEST_DATASETNAME_SUFFIX = "" 
+_C.DATASETS.OVERRIDE_CATEGORY = None +_C.DATASETS.USE_OVERRIDE_CATEGORY = False +_C.DATASETS.SUPRESS_QUERY = None +_C.DATASETS.USE_SUPRESS_QUERY = False +_C.DATASETS.USE_CAPTION_PROMPT = False +_C.DATASETS.CAPTION_PROMPT = None + +_C.DATASETS.FLICKR_GT_TYPE = "separate" + +# VQA +_C.DATASETS.DIVER_BOX_FOR_VQA = False +# ----------------------------------------------------------------------------- +# DataLoader +# ----------------------------------------------------------------------------- +_C.DATALOADER = CN() +# Number of data loading threads +_C.DATALOADER.NUM_WORKERS = 4 +# If > 0, this enforces that each collated batch should have a size divisible +# by SIZE_DIVISIBILITY +_C.DATALOADER.SIZE_DIVISIBILITY = 0 +# If True, each batch should contain only images for which the aspect ratio +# is compatible. This groups portrait images together, and landscape images +# are not batched with portrait images. +_C.DATALOADER.ASPECT_RATIO_GROUPING = True +# Define min number of keypoints required from GT, for example 10 out of 17 +_C.DATALOADER.MIN_KPS_PER_IMS = 0 +# Use random sampler during training +_C.DATALOADER.USE_RANDOM_SEED = False + +_C.DATALOADER.DISTRIBUTE_CHUNK_AMONG_NODE = False +# ---------------------------------------------------------------------------- # +# Backbone options +# ---------------------------------------------------------------------------- # +_C.MODEL.BACKBONE = CN() + +# The backbone conv body to use +# The string must match a function that is imported in modeling.model_builder +# (e.g., 'FPN.add_fpn_ResNet101_conv5_body' to specify a ResNet-101-FPN +# backbone) +_C.MODEL.BACKBONE.CONV_BODY = "R-50-C4" + +# Add StopGrad at a specified stage so the bottom layers are frozen +_C.MODEL.BACKBONE.FREEZE_CONV_BODY_AT = 2 +_C.MODEL.BACKBONE.FREEZE = False +_C.MODEL.BACKBONE.GROUP = 1 +_C.MODEL.BACKBONE.OUT_CHANNELS = 256 * 4 +# Option to reset bn running statics +_C.MODEL.BACKBONE.RESET_BN = False +# Backbone Normalization Level +_C.MODEL.BACKBONE.NORM_LEVEL = 3 +# BN for backbone +_C.MODEL.BACKBONE.USE_BN = False +# Sync BN for backbone +_C.MODEL.BACKBONE.USE_SYNCBN = False +_C.MODEL.BACKBONE.USE_NSYNCBN = False +# GN for backbone +_C.MODEL.BACKBONE.USE_GN = False +# Evo Norm for backbone +_C.MODEL.BACKBONE.USE_EN = False +# Layers for backbone +_C.MODEL.BACKBONE.USE_DFCONV = False +_C.MODEL.BACKBONE.USE_DYRELU = False +_C.MODEL.BACKBONE.USE_SE = False +_C.MODEL.BACKBONE.LAYER_SETUP = (3, 4, 6, 3) +_C.MODEL.BACKBONE.LAYER_SEARCH = CN(new_allowed=True) +_C.MODEL.BACKBONE.OUT_FEATURES = ("stage2", "stage3", "stage4", "stage5") +_C.MODEL.BACKBONE.FPN_LAYER = () +_C.MODEL.BACKBONE.USE_CHECKPOINT = False +# Add JF efficient det cfgs +_C.MODEL.BACKBONE.EFFICIENT_DET_START_FROM = 3 +_C.MODEL.BACKBONE.EFFICIENT_DET_COMPOUND = 0 +_C.MODEL.BACKBONE.EFFICIENT_DET_BIFPN_VERSION = 0 + +_C.MODEL.LANGUAGE_BACKBONE = CN() +_C.MODEL.LANGUAGE_BACKBONE.WEIGHT = "" +_C.MODEL.LANGUAGE_BACKBONE.FREEZE = False +_C.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT = False +_C.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE = "bert-base-uncased" +_C.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE = "bert-base-uncased" +_C.MODEL.LANGUAGE_BACKBONE.LANG_DIM = 768 +_C.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN = 256 +_C.MODEL.LANGUAGE_BACKBONE.N_LAYERS = 1 +_C.MODEL.LANGUAGE_BACKBONE.UNUSED_TOKEN = 106 +_C.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL = False + +_C.MODEL.LANGUAGE_BACKBONE.RNN_TYPE = "lstm" +_C.MODEL.LANGUAGE_BACKBONE.VARIABLE_LENGTH = True +_C.MODEL.LANGUAGE_BACKBONE.WORD_EMBEDDING_SIZE = 512 
+_C.MODEL.LANGUAGE_BACKBONE.WORD_VEC_SIZE = 512 +_C.MODEL.LANGUAGE_BACKBONE.HIDDEN_SIZE = 512 +_C.MODEL.LANGUAGE_BACKBONE.BIDIRECTIONAL = True +_C.MODEL.LANGUAGE_BACKBONE.INPUT_DROPOUT_P = 0.5 +_C.MODEL.LANGUAGE_BACKBONE.DROPOUT_P = 0.2 +_C.MODEL.LANGUAGE_BACKBONE.CORPUS_PATH = "" +_C.MODEL.LANGUAGE_BACKBONE.VOCAB_SIZE = 0 + +_C.MODEL.LANGUAGE_BACKBONE.PAD_MAX = True +# ---------------------------------------------------------------------------- # +# FPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.FPN = CN() +_C.MODEL.FPN.FREEZE = False +_C.MODEL.FPN.USE_GN = False +_C.MODEL.FPN.USE_RELU = False +_C.MODEL.FPN.USE_DYRELU = False +_C.MODEL.FPN.DROP_BLOCK = True +_C.MODEL.FPN.DROP_PROB = 0.3 +_C.MODEL.FPN.DROP_SIZE = 3 +_C.MODEL.FPN.USE_SPP = False +_C.MODEL.FPN.USE_PAN = False +_C.MODEL.FPN.USE_DYHEAD = False +_C.MODEL.FPN.RETURN_SWINT_FEATURE_BEFORE_FUSION = False +# ---------------------------------------------------------------------------- # +# BIFPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.BIFPN = CN() +_C.MODEL.BIFPN.NUM_REPEATS = 1 +_C.MODEL.BIFPN.USE_ATTENTION = True + +# ---------------------------------------------------------------------------- # +# Group Norm options +# ---------------------------------------------------------------------------- # +_C.MODEL.GROUP_NORM = CN() +# Number of dimensions per group in GroupNorm (-1 if using NUM_GROUPS) +_C.MODEL.GROUP_NORM.DIM_PER_GP = -1 +# Number of groups in GroupNorm (-1 if using DIM_PER_GP) +_C.MODEL.GROUP_NORM.NUM_GROUPS = 16 +# GroupNorm's small constant in the denominator +_C.MODEL.GROUP_NORM.EPSILON = 1e-5 + +# ---------------------------------------------------------------------------- # +# Evo Norm options +# ---------------------------------------------------------------------------- # +_C.MODEL.EVO_NORM = CN() +# Number of groups in EvoNorm (-1 if using DIM_PER_GP) +_C.MODEL.EVO_NORM.NUM_GROUPS = 8 +# EvoNorm's small constant in the denominator +_C.MODEL.EVO_NORM.EPSILON = 1e-5 + +# ---------------------------------------------------------------------------- # +# RetinaNet Options (Follow the Detectron version) +# ---------------------------------------------------------------------------- # +_C.MODEL.RETINANET = CN() +# This is the number of foreground classes and background. +_C.MODEL.RETINANET.NUM_CLASSES = 81 +# Convolutions to use in the cls and bbox tower +# NOTE: this doesn't include the last conv for logits +_C.MODEL.RETINANET.NUM_CONVS = 4 +# During inference, #locs to select based on cls score before NMS is performed +# per FPN level +_C.MODEL.RETINANET.PRE_NMS_TOP_N = 1000 +# Prior prob for the positives at the beginning of training. 
This is used to set +# the bias init for the logits layer +_C.MODEL.RETINANET.PRIOR_PROB = 0.01 +# Inference cls score threshold, anchors with score > INFERENCE_TH are +# considered for inference +_C.MODEL.RETINANET.INFERENCE_TH = 0.05 +# NMS threshold used in RetinaNet +_C.MODEL.RETINANET.NMS_TH = 0.4 +_C.MODEL.RETINANET.DETECTIONS_PER_IMG = 100 + +# ---------------------------------------------------------------------------- # +# Focal Loss Options (Follow the Detectron version) +# ---------------------------------------------------------------------------- # +_C.MODEL.FOCAL = CN() +# Weight for bbox_regression loss +_C.MODEL.FOCAL.BBOX_REG_WEIGHT = 4.0 +# Smooth L1 loss beta for bbox regression +_C.MODEL.FOCAL.BBOX_REG_BETA = 0.11 +# IoU overlap ratio for labeling an anchor as positive +# Anchors with >= iou overlap are labeled positive +_C.MODEL.FOCAL.FG_IOU_THRESHOLD = 0.5 +# IoU overlap ratio for labeling an anchor as negative +# Anchors with < iou overlap are labeled negative +_C.MODEL.FOCAL.BG_IOU_THRESHOLD = 0.4 +# Focal loss parameter: alpha +_C.MODEL.FOCAL.LOSS_ALPHA = 0.25 +# Focal loss parameter: gamma +_C.MODEL.FOCAL.LOSS_GAMMA = 2.0 + +# ---------------------------------------------------------------------------- # +# FCOS Options +# ---------------------------------------------------------------------------- # +_C.MODEL.FCOS = CN() +_C.MODEL.FCOS.NUM_CLASSES = 81 # the number of classes including background +_C.MODEL.FCOS.FPN_STRIDES = [8, 16, 32, 64, 128] +_C.MODEL.FCOS.PRIOR_PROB = 0.01 +_C.MODEL.FCOS.INFERENCE_TH = 0.05 +_C.MODEL.FCOS.NMS_TH = 0.6 +_C.MODEL.FCOS.PRE_NMS_TOP_N = 1000 + +# the number of convolutions used in the cls and bbox tower +_C.MODEL.FCOS.NUM_CONVS = 4 +# if use deformable conv to align features +_C.MODEL.FCOS.USE_DFCONV = False + +# if CENTER_SAMPLING_RADIUS <= 0, it will disable center sampling +_C.MODEL.FCOS.CENTER_SAMPLING_RADIUS = 0.0 +# IOU_LOSS_TYPE can be "iou", "linear_iou" or "giou" +_C.MODEL.FCOS.IOU_LOSS_TYPE = "iou" + +_C.MODEL.FCOS.NORM_REG_TARGETS = False +_C.MODEL.FCOS.CENTERNESS_ON_REG = False +_C.MODEL.FCOS.USE_GT_CENTER = False + +_C.MODEL.FCOS.DETECTIONS_PER_IMG = 100 +_C.MODEL.FCOS.USE_GN = False +_C.MODEL.FCOS.USE_BN = False + +_C.MODEL.FCOS.INFERENCE_TH_TRAIN = 0.0 +_C.MODEL.FCOS.PRE_NMS_TOP_N_TRAIN = 3000 +_C.MODEL.FCOS.POST_NMS_TOP_N_TRAIN = 1000 + +# ---------------------------------------------------------------------------- # +# ATSS Options +# ---------------------------------------------------------------------------- # +_C.MODEL.ATSS = CN() +_C.MODEL.ATSS.NUM_CLASSES = 81 # the number of classes including background +_C.MODEL.ATSS.PRIOR_PROB = 0.01 +_C.MODEL.ATSS.INFERENCE_TH = 0.05 +_C.MODEL.ATSS.NMS_TH = 0.6 +_C.MODEL.ATSS.PRE_NMS_TOP_N = 1000 + +# the number of convolutions used in the cls and bbox tower +_C.MODEL.ATSS.NUM_CONVS = 4 +# the channels of convolutions used in the cls and bbox tower +_C.MODEL.ATSS.CHANNELS = 128 +# if use deformable conv to align features +_C.MODEL.ATSS.USE_DFCONV = False + +# topk for selecting candidate positive samples from each level +_C.MODEL.ATSS.TOPK = 9 + +# Weight for bbox_regression loss +_C.MODEL.ATSS.REG_LOSS_WEIGHT = 2.0 + +_C.MODEL.ATSS.DETECTIONS_PER_IMG = 100 +_C.MODEL.ATSS.USE_GN = False +_C.MODEL.ATSS.USE_BN = False + +_C.MODEL.ATSS.USE_DYRELU = False +_C.MODEL.ATSS.USE_SE = False + +_C.MODEL.ATSS.INFERENCE_TH_TRAIN = 0.0 +_C.MODEL.ATSS.PRE_NMS_TOP_N_TRAIN = 3000 +_C.MODEL.ATSS.POST_NMS_TOP_N_TRAIN = 1000 +# 
---------------------------------------------------------------------------- # +# DYHEAD Options +# ---------------------------------------------------------------------------- # +_C.MODEL.DYHEAD = CN() +_C.MODEL.DYHEAD.NUM_CLASSES = 81 # the number of classes including background +_C.MODEL.DYHEAD.PRIOR_PROB = 0.01 + +# the number of convolutions used in the cls and bbox tower +_C.MODEL.DYHEAD.NUM_CONVS = 4 +# the channels of convolutions used in the cls and bbox tower +_C.MODEL.DYHEAD.CHANNELS = 128 +_C.MODEL.DYHEAD.GROUPS = 1 +# if use deformable conv to align features +_C.MODEL.DYHEAD.USE_DFCONV = False + +# topk for selecting candidate positive samples from each level +_C.MODEL.DYHEAD.TOPK = 9 + +_C.MODEL.DYHEAD.SCORE_AGG = "MEAN" # MEAN or MAX, for binary focal loss score aggregation + +_C.MODEL.DYHEAD.LOG_SCALE = 0.0 # temperature (dot product) +_C.MODEL.DYHEAD.SHALLOW_LOG_SCALE = 0.0 # # temperature (shallow contrastive) + +_C.MODEL.DYHEAD.USE_GN = False +_C.MODEL.DYHEAD.USE_NSYNCBN = False +_C.MODEL.DYHEAD.USE_SYNCBN = False + +_C.MODEL.DYHEAD.USE_DYFUSE = False +_C.MODEL.DYHEAD.USE_DYRELU = False + +_C.MODEL.DYHEAD.CONV_FUNC = '' + +# CosineSimOutputLayers: https://github.com/ucbdrive/few-shot-object-detection/blob/master/fsdet/modeling/roi_heads/fast_rcnn.py#L448-L464 +_C.MODEL.DYHEAD.COSINE_SCALE = -1.0 + +_C.MODEL.DYHEAD.FUSE_CONFIG = CN() +_C.MODEL.DYHEAD.FUSE_CONFIG.EARLY_FUSE_ON = False +_C.MODEL.DYHEAD.FUSE_CONFIG.TYPE = "" +_C.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE = 256 +_C.MODEL.DYHEAD.FUSE_CONFIG.JOINT_OUT_SIZE = 256 +_C.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT = 0.1 +_C.MODEL.DYHEAD.FUSE_CONFIG.JOINT_MLP_LAYERS = 2 + +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_CLASSIFICATION_LOSS = False + +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS = False +_C.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_LOSS_WEIGHT = 1.0 +_C.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_GAMMA = 2.0 +_C.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_ALPHA = 0.25 + +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS = False +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS = False +_C.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_HIDDEN_DIM = 64 +_C.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_ALIGN_LOSS_WEIGHT = 1.0 +_C.MODEL.DYHEAD.FUSE_CONFIG.DOT_PRODUCT_TOKEN_LOSS_WEIGHT = 1.0 +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_LAYER_SCALE = True +_C.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL = False +_C.MODEL.DYHEAD.FUSE_CONFIG.STABLE_SOFTMAX_2D = False + +_C.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT = False + +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT = False + +# Controls for +_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW = False +_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW = False +_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW = False +_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW = False +_C.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_DOT_PRODUCT = False + +# MLM Loss +_C.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS = False +_C.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_FOR_ONLY_POSITIVES = True +_C.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_OD = False +_C.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_GOLD = False +_C.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_COEF = 1.0 +_C.MODEL.DYHEAD.FUSE_CONFIG.MLM_OBJ_FOR_ONLY_POSITIVE = False + +# Shallow Contrastive Loss (FPN) +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS = False +_C.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS = 100 +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS = False +_C.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_HIDDEN_DIM = 64 
+_C.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_LOSS_WEIGHT = 1.0 + +# Shallow Contrastive Loss (BACKBONE) +_C.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS = False + +_C.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER = False + +# use checkpoint to save memory +_C.MODEL.DYHEAD.USE_CHECKPOINT = False + +# ---------------------------------------------------------------------------- # +# RPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.RPN = CN() +_C.MODEL.RPN.USE_FPN = False +# Base RPN anchor sizes given in absolute pixels w.r.t. the scaled network input +_C.MODEL.RPN.ANCHOR_SIZES = (32, 64, 128, 256, 512) +# Stride of the feature map that RPN is attached. +# For FPN, number of strides should match number of scales +_C.MODEL.RPN.ANCHOR_STRIDE = (16,) +# RPN anchor aspect ratios +_C.MODEL.RPN.ASPECT_RATIOS = (0.5, 1.0, 2.0) +# Anchor shift away ration from the center for r,t,l,d +_C.MODEL.RPN.ANCHOR_SHIFT = (0.0, 0.0, 0.0, 0.0) +# Use center to decide anchor size +_C.MODEL.RPN.USE_RELATIVE_SIZE = False +# Remove RPN anchors that go outside the image by RPN_STRADDLE_THRESH pixels +# Set to -1 or a large value, e.g. 100000, to disable pruning anchors +_C.MODEL.RPN.STRADDLE_THRESH = 0 +# Anchor scales per octave for complex anchors +_C.MODEL.RPN.OCTAVE = 2.0 +_C.MODEL.RPN.SCALES_PER_OCTAVE = 3 +# Minimum overlap required between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD +# ==> positive RPN example) +_C.MODEL.RPN.FG_IOU_THRESHOLD = 0.7 +# Maximum overlap allowed between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD +# ==> negative RPN example) +_C.MODEL.RPN.BG_IOU_THRESHOLD = 0.3 +# Total number of RPN examples per image +_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256 +# Target fraction of foreground (positive) examples per RPN minibatch +_C.MODEL.RPN.POSITIVE_FRACTION = 0.5 +# Number of top scoring RPN proposals to keep before applying NMS +# When FPN is used, this is *per FPN level* (not total) +_C.MODEL.RPN.PRE_NMS_TOP_N_TRAIN = 12000 +_C.MODEL.RPN.PRE_NMS_TOP_N_TEST = 6000 +# Number of top scoring RPN proposals to keep after applying NMS +_C.MODEL.RPN.POST_NMS_TOP_N_TRAIN = 2000 +_C.MODEL.RPN.POST_NMS_TOP_N_TEST = 1000 +# NMS threshold used on RPN proposals +_C.MODEL.RPN.NMS_THRESH = 0.7 +# Proposal height and width both need to be greater than RPN_MIN_SIZE +# (a the scale used during training or inference) +_C.MODEL.RPN.MIN_SIZE = 0 +# Number of top scoring RPN proposals to keep after combining proposals from +# all FPN levels +_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000 +_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000 +# Custom rpn head, empty to use default conv or separable conv +_C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead" +_C.MODEL.RPN.FREEZE = False +_C.MODEL.RPN.FORCE_BOXES = False +_C.MODEL.RPN.RETURN_FUSED_FEATURES = False + +# ---------------------------------------------------------------------------- # +# ROI HEADS options +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_HEADS = CN() +_C.MODEL.ROI_HEADS.USE_FPN = False +# Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD) +_C.MODEL.ROI_HEADS.FG_IOU_THRESHOLD = 0.5 +# Overlap threshold for an RoI to be considered background +# (class = 0 if overlap in [0, BG_IOU_THRESHOLD)) +_C.MODEL.ROI_HEADS.BG_IOU_THRESHOLD = 0.5 +# Default weights on (dx, dy, dw, dh) for 
normalizing bbox regression targets +# These are empirically chosen to approximately lead to unit variance targets +_C.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS = (10., 10., 5., 5.) +# RoI minibatch size *per image* (number of regions of interest [ROIs]) +# Total number of RoIs per training minibatch = +# TRAIN.BATCH_SIZE_PER_IM * TRAIN.IMS_PER_BATCH * NUM_GPUS +# E.g., a common configuration is: 512 * 2 * 8 = 8192 +_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512 +# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0) +_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25 + +# Only used on test mode + +# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to +# balance obtaining high recall with not having too many low precision +# detections that will slow down inference post processing steps (like NMS) +_C.MODEL.ROI_HEADS.SCORE_THRESH = 0.05 +# Overlap threshold used for non-maximum suppression (suppress boxes with +# IoU >= this threshold) +_C.MODEL.ROI_HEADS.NMS = 0.5 +# Maximum number of detections to return per image (100 is based on the limit +# established for the COCO dataset) +_C.MODEL.ROI_HEADS.DETECTIONS_PER_IMG = 100 + +_C.MODEL.ROI_BOX_HEAD = CN() +_C.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor" +_C.MODEL.ROI_BOX_HEAD.PREDICTOR = "FastRCNNPredictor" +_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_BOX_HEAD.POOLER_SCALES = (1.0 / 16,) +_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81 +# Hidden layer dimension when using an MLP for the RoI box head +_C.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM = 1024 +# GN +_C.MODEL.ROI_BOX_HEAD.USE_GN = False +# Dilation +_C.MODEL.ROI_BOX_HEAD.DILATION = 1 +_C.MODEL.ROI_BOX_HEAD.CONV_HEAD_DIM = 256 +_C.MODEL.ROI_BOX_HEAD.NUM_STACKED_CONVS = 4 +# Use D2 style ROIAlignV2 +_C.MODEL.ROI_BOX_HEAD.POOLER_ALIGNED = False + +_C.MODEL.ROI_MASK_HEAD = CN() +_C.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR = "ResNet50Conv5ROIFeatureExtractor" +_C.MODEL.ROI_MASK_HEAD.PREDICTOR = "MaskRCNNC4Predictor" +_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_MASK_HEAD.POOLER_SCALES = (1.0 / 16,) +_C.MODEL.ROI_MASK_HEAD.MLP_HEAD_DIM = 1024 +_C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256) +_C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14 +_C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True +# Whether or not resize and translate masks to the input image. 
+_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = False +_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD = 0.5 +# Dilation +_C.MODEL.ROI_MASK_HEAD.DILATION = 1 +# GN +_C.MODEL.ROI_MASK_HEAD.USE_GN = False +# HG +_C.MODEL.ROI_MASK_HEAD.HG_SCALE = 1 + +_C.MODEL.ROI_KEYPOINT_HEAD = CN() +_C.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR = "KeypointRCNNFeatureExtractor" +_C.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR = "KeypointRCNNPredictor" +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES = (1.0 / 16,) +_C.MODEL.ROI_KEYPOINT_HEAD.MLP_HEAD_DIM = 1024 +_C.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS = tuple(512 for _ in range(8)) +_C.MODEL.ROI_KEYPOINT_HEAD.RESOLUTION = 14 +_C.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES = 17 +_C.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME = () # If left empty, use default names +_C.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True + +# ---------------------------------------------------------------------------- # +# ResNe[X]t options (ResNets = {ResNet, ResNeXt} +# Note that parts of a resnet may be used for both the backbone and the head +# These options apply to both +# ---------------------------------------------------------------------------- # +_C.MODEL.RESNETS = CN() + +_C.MODEL.RESNETS.USE_STEM3X3 = False +_C.MODEL.RESNETS.WITH_SE = False +_C.MODEL.RESNETS.USE_AVG_DOWN = False + +# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt +_C.MODEL.RESNETS.NUM_GROUPS = 1 + +# Baseline width of each group +_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64 + +# Place the stride 2 conv on the 1x1 filter +# Use True only for the original MSRA ResNet; use False for C2 and Torch models +_C.MODEL.RESNETS.STRIDE_IN_1X1 = True + +# Residual transformation function +_C.MODEL.RESNETS.TRANS_FUNC = "BottleneckWithFixedBatchNorm" +# ResNet's stem function (conv1 and pool1) +_C.MODEL.RESNETS.STEM_FUNC = "StemWithFixedBatchNorm" + +# Apply dilation in stage "res5" +_C.MODEL.RESNETS.RES5_DILATION = 1 + +_C.MODEL.RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4 +_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256 +_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64 + +_C.MODEL.RESNETS.REVISION = "resnet_light" +# Deformable convolutions +_C.MODEL.RESNETS.STAGE_WITH_DCN = (False, False, False, False) +_C.MODEL.RESNETS.WITH_MODULATED_DCN = False +_C.MODEL.RESNETS.DEFORMABLE_GROUPS = 1 + +# ---------------------------------------------------------------------------- # +# Swin Transformer +# ---------------------------------------------------------------------------- # +_C.MODEL.SWINT = CN() +_C.MODEL.SWINT.EMBED_DIM = 96 +_C.MODEL.SWINT.OUT_CHANNELS = (96, 192, 384, 768) +_C.MODEL.SWINT.DEPTHS = (2, 2, 6, 2) +_C.MODEL.SWINT.NUM_HEADS = (3, 6, 12, 24) +_C.MODEL.SWINT.WINDOW_SIZE = 7 +_C.MODEL.SWINT.MLP_RATIO = 4 +_C.MODEL.SWINT.DROP_PATH_RATE = 0.2 +_C.MODEL.SWINT.APE = False +_C.MODEL.SWINT.VERSION = "v1" +_C.MODEL.SWINT.OUT_NORM = True +_C.MODEL.SWINT.LAYER_SCALE = 0 + +# ---------------------------------------------------------------------------- # +# CVT SPEC +# ---------------------------------------------------------------------------- # +_C.MODEL.SPEC = CN(new_allowed=True) + +# ---------------------------------------------------------------------------- # +# CLIP SPEC +# ---------------------------------------------------------------------------- # +_C.MODEL.CLIP = CN() +_C.MODEL.CLIP.CONTEXT_LENGTH = 256 # default 77 +_C.MODEL.CLIP.WIDTH = 512 +_C.MODEL.CLIP.LAYERS = 12 +_C.MODEL.CLIP.HEADS = 8 +_C.MODEL.CLIP.DROP_PATH = 0.0 
+_C.MODEL.CLIP.TOKENIZER = "clip" +_C.MODEL.CLIP.VOCAB_SIZE = 49408 + +# ---------------------------------------------------------------------------- # +# SEARCH +# ---------------------------------------------------------------------------- # + +_C.SEARCH = CN() +_C.SEARCH.MAX_EPOCH = 20 +_C.SEARCH.SELECT_NUM = 20 +_C.SEARCH.POPULATION_NUM = 64 +_C.SEARCH.MUTATION_NUM = 24 +_C.SEARCH.CROSSOVER_NUM = 24 +_C.SEARCH.MUTATION_PROB = 0.1 + +# ---------------------------------------------------------------------------- # +# Solver +# ---------------------------------------------------------------------------- # +_C.SOLVER = CN() +_C.SOLVER.USE_AMP = False + +_C.SOLVER.MAX_ITER = 40000 +_C.SOLVER.MULTI_MAX_ITER = () # set different max epoch for different stage +_C.SOLVER.MAX_EPOCH = 0 # any epoch number>0 will overwrite max_iter +_C.SOLVER.MULTI_MAX_EPOCH = () # set different max epoch for different stage + +_C.SOLVER.OPTIMIZER = "SGD" # "ADAMW" + +_C.SOLVER.BASE_LR = 0.001 + +_C.SOLVER.LANG_LR = 0.00001 +_C.SOLVER.BACKBONE_BODY_LR_FACTOR = 1.0 + +_C.SOLVER.BIAS_LR_FACTOR = 2 +_C.SOLVER.GRAD_CLIP = 0.0 +# D2 gradient clip +_C.SOLVER.CLIP_GRADIENTS = CN() +_C.SOLVER.CLIP_GRADIENTS.ENABLED = False +_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 0.0 +_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model" +_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0 +_C.SOLVER.MODEL_EMA = 0.0 + +_C.SOLVER.MOMENTUM = 0.9 + +_C.SOLVER.WEIGHT_DECAY = 0.0005 +_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0 +_C.SOLVER.WEIGHT_DECAY_NORM_FACTOR = 1.0 + +# use cosine lr to replace default multistage +_C.SOLVER.USE_COSINE = False +_C.SOLVER.MIN_LR = 0.000001 + +_C.SOLVER.GAMMA = 0.1 +_C.SOLVER.STEPS = (30000,) + +_C.SOLVER.USE_AUTOSTEP = False +_C.SOLVER.STEP_PATIENCE = 5 + +_C.SOLVER.WARMUP_FACTOR = 1.0 / 3 +_C.SOLVER.WARMUP_ITERS = 500 +_C.SOLVER.WARMUP_METHOD = "linear" + +_C.SOLVER.CHECKPOINT_PERIOD = 2500 +_C.SOLVER.CHECKPOINT_PER_EPOCH = -1.0 +_C.SOLVER.TEST_WITH_INFERENCE = False +_C.SOLVER.AUTO_TERMINATE_PATIENCE = -1 +# Number of images per batch +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will +# see 2 images per batch +_C.SOLVER.IMS_PER_BATCH = 16 +# This is the max negative ratio allowed per batch +_C.SOLVER.MAX_NEG_PER_BATCH = 0.1 + +_C.SOLVER.SEED = 0 +_C.SOLVER.DISABLE_OUTPUT_DISTRIBUTED = False + + +_C.SOLVER.PROMPT_PROBING_LEVEL = -1.0 +# -1 means tuning the whole model; +# 1 means tuning the whole language model; 1.5 means tuning the box head as well + +_C.SOLVER.FIND_UNUSED_PARAMETERS = True +_C.SOLVER.DATASET_LENGTH = -1 # Just for logging purpose +_C.SOLVER.TUNING_HIGHLEVEL_OVERRIDE = None +_C.SOLVER.USE_EMA_FOR_MONITOR = False + +_C.SOLVER.WEIGHT_DECAY_SCHEDULE = False +_C.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO = 0.667 + +# ---------------------------------------------------------------------------- # +# Specific test options +# ---------------------------------------------------------------------------- # +_C.TEST = CN() +_C.TEST.EXPECTED_RESULTS = [] +_C.TEST.EXPECTED_RESULTS_SIGMA_TOL = 4 +_C.TEST.DURING_TRAINING = False +# Number of images per batch +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will +# see 2 images per batch +_C.TEST.IMS_PER_BATCH = 16 +# Special Test Configuration +_C.TEST.USE_MULTISCALE = False +# _C.TEST.SCALES = (400, 600, 800, 1000, 1200, 1400) +# _C.TEST.RANGES = ((96, 10000), (64, 10000), (0, 10000), (0, 10000), (0, 256), (0, 192)) +_C.TEST.SCALES = (400, 500, 600, 640, 700, 900, 1000, 1100, 1200, 1300, 1400, 1800) +_C.TEST.RANGES = ((96, 10000), (96, 
10000), (64, 10000), (64, 10000), (64, 10000), (0, 10000), (0, 10000), (0, 256), (0, 256), (0, 192), (0, 192), (0, 96)) +_C.TEST.MAX_SIZE = 2500 +_C.TEST.FLIP = True +_C.TEST.SPECIAL_NMS = 'none' # ('none', 'soft-nms', 'vote', 'soft-vote') +_C.TEST.TH = 0.6 # threshold for nms or vote +_C.TEST.PRE_NMS_TOP_N = 1000 +_C.TEST.NUM_CLASSES = 81 +_C.TEST.SELECT_CLASSES = () + +_C.TEST.EVAL_TASK = "" +_C.TEST.SUBSET = -1 +_C.TEST.CHUNKED_EVALUATION = -1 +_C.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM = -1 +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # +_C.OUTPUT_DIR = "OUTPUT" + +_C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py") + +# TensorBoard experiment location +_C.TENSORBOARD_EXP = "OUTPUT" + + +_C.GLIPKNOW = CN() +_C.GLIPKNOW.KNOWLEDGE_FILE = "" +_C.GLIPKNOW.KNOWLEDGE_TYPE = "" +_C.GLIPKNOW.MAX_NUM_CLASSES_PER_BATCH_TRAIN = -1 +_C.GLIPKNOW.PARALLEL_LANGUAGE_INPUT = False +_C.GLIPKNOW.LAN_FEATURE_AGG_TYPE = "first" +_C.GLIPKNOW.GPT3_NUM = 5 +_C.GLIPKNOW.WIKI_AND_GPT3 = False \ No newline at end of file diff --git a/maskrcnn_benchmark/config/paths_catalog.py b/maskrcnn_benchmark/config/paths_catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..be63e5715434d696cb1480c8a5b436b642808afb --- /dev/null +++ b/maskrcnn_benchmark/config/paths_catalog.py @@ -0,0 +1,447 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +"""Centralized catalog of paths.""" + +import os + + +def try_to_find(file, return_dir=False, search_path=['./DATASET', './OUTPUT', './data', './MODEL']): + if not file: + return file + + if file.startswith('catalog://'): + return file + + DATASET_PATH = ['./'] + if 'DATASET' in os.environ: + DATASET_PATH.append(os.environ['DATASET']) + DATASET_PATH += search_path + + for path in DATASET_PATH: + if os.path.exists(os.path.join(path, file)): + if return_dir: + return path + else: + return os.path.join(path, file) + + print('Cannot find {} in {}'.format(file, DATASET_PATH)) + exit(1) + + +class DatasetCatalog(object): + DATASETS = { + # pretrained grounding dataset + # mixed vg and coco + "mixed_train": { + "coco_img_dir": "coco/train2014", + "vg_img_dir": "gqa/images", + "ann_file": "mdetr_annotations/final_mixed_train.json", + }, + "mixed_train_no_coco": { + "coco_img_dir": "coco/train2014", + "vg_img_dir": "gqa/images", + "ann_file": "mdetr_annotations/final_mixed_train_no_coco.json", + }, + + # flickr30k + "flickr30k_train": { + "img_folder": "flickr30k/flickr30k_images/train", + "ann_file": "mdetr_annotations/final_flickr_separateGT_train.json", + "is_train": True + }, + "flickr30k_val": { + "img_folder": "flickr30k/flickr30k_images/val", + "ann_file": "mdetr_annotations/final_flickr_separateGT_val.json", + "is_train": False + }, + "flickr30k_test": { + "img_folder": "flickr30k/flickr30k_images/test", + "ann_file": "mdetr_annotations/final_flickr_separateGT_test.json", + "is_train": False + }, + + # refcoco + "refexp_all_val": { + "img_dir": "refcoco/train2014", + "ann_file": "mdetr_annotations/final_refexp_val.json", + "is_train": False + }, + + # gqa + "gqa_val": { + "img_dir": "gqa/images", + "ann_file": "mdetr_annotations/final_gqa_val.json", + "is_train": False + }, + + # phrasecut + "phrasecut_train": { + "img_dir": "gqa/images", + "ann_file": "mdetr_annotations/finetune_phrasecut_train.json", + "is_train": True + }, + + + # od to grounding + # coco tsv + "coco_dt_train": { 
+ "dataset_file": "coco_dt", + "yaml_path": "coco_tsv/coco_obj.yaml", + "is_train": True, + }, + "COCO_odinw_train_8copy_dt_train": { + "dataset_file": "coco_odinw_dt", + "yaml_path": "coco_tsv/COCO_odinw_train_8copy.yaml", + "is_train": True, + }, + "COCO_odinw_val_dt_train": { + "dataset_file": "coco_odinw_dt", + "yaml_path": "coco_tsv/COCO_odinw_val.yaml", + "is_train": False, + }, + # lvis tsv + "lvisv1_dt_train": { + "dataset_file": "lvisv1_dt", + "yaml_path": "coco_tsv/LVIS_v1_train.yaml", + "is_train": True, + }, + "LVIS_odinw_train_8copy_dt_train": { + "dataset_file": "coco_odinw_dt", + "yaml_path": "coco_tsv/LVIS_odinw_train_8copy.yaml", + "is_train": True, + }, + # object365 tsv + "object365_dt_train": { + "dataset_file": "object365_dt", + "yaml_path": "Objects365/objects365_train_vgoiv6.cas2000.yaml", + "is_train": True, + }, + "object365_odinw_2copy_dt_train": { + "dataset_file": "object365_odinw_dt", + "yaml_path": "Objects365/objects365_train_odinw.cas2000_2copy.yaml", + "is_train": True, + }, + "objects365_odtsv_train": { + "dataset_file": "objects365_odtsv", + "yaml_path": "Objects365/train.cas2000.yaml", + "is_train": True, + }, + "objects365_odtsv_val": { + "dataset_file": "objects365_odtsv", + "yaml_path": "Objects365/val.yaml", + "is_train": False, + }, + + # ImagetNet OD + "imagenetod_train_odinw_2copy_dt": { + "dataset_file": "imagenetod_odinw_dt", + "yaml_path": "imagenet_od/imagenetod_train_odinw_2copy.yaml", + "is_train": True, + }, + + # OpenImage OD + "oi_train_odinw_dt": { + "dataset_file": "oi_odinw_dt", + "yaml_path": "openimages_v5c/oi_train_odinw.cas.2000.yaml", + "is_train": True, + }, + + # vg tsv + "vg_dt_train": { + "dataset_file": "vg_dt", + "yaml_path": "visualgenome/train_vgoi6_clipped.yaml", + "is_train": True, + }, + + "vg_odinw_clipped_8copy_dt_train": { + "dataset_file": "vg_odinw_clipped_8copy_dt", + "yaml_path": "visualgenome/train_odinw_clipped_8copy.yaml", + "is_train": True, + }, + "vg_vgoi6_clipped_8copy_dt_train": { + "dataset_file": "vg_vgoi6_clipped_8copy_dt", + "yaml_path": "visualgenome/train_vgoi6_clipped_8copy.yaml", + "is_train": True, + }, + + # coco json + "coco_grounding_train": { + "img_dir": "coco/train2017", + "ann_file": "coco/annotations/instances_train2017.json", + "is_train": True, + }, + + "lvis_grounding_train": { + "img_dir": "coco", + "ann_file": "coco/annotations/lvis_od_train.json" + }, + + + "lvis_val": { + "img_dir": "coco", + "ann_file": "coco/annotations/lvis_od_val.json" + }, + "coco_2017_train": { + "img_dir": "coco/train2017", + "ann_file": "coco/annotations/instances_train2017.json" + }, + "coco_2017_val": { + "img_dir": "coco/val2017", + "ann_file": "coco/annotations/instances_val2017.json" + }, + "coco_2017_test": { + "img_dir": "coco/test2017", + "ann_file": "coco/annotations/image_info_test-dev2017.json" + }, + "coco_2014_train": { + "img_dir": "coco/train2014", + "ann_file": "coco/annotations/instances_train2014.json" + }, + "coco_2014_val": { + "img_dir": "coco/val2014", + "ann_file": "coco/annotations/instances_val2014.json" + }, + "coco_2014_minival": { + "img_dir": "coco/val2014", + "ann_file": "coco/annotations/instances_minival2014.json" + }, + } + + @staticmethod + def set(name, info): + DatasetCatalog.DATASETS.update({name: info}) + + @staticmethod + def get(name): + + if name.endswith('_bg'): + attrs = DatasetCatalog.DATASETS[name] + data_dir = try_to_find(attrs["ann_file"], return_dir=True) + args = dict( + root=os.path.join(data_dir, attrs["img_dir"]), + ann_file=os.path.join(data_dir, 
attrs["ann_file"]), + ) + return dict( + factory="Background", + args=args, + ) + else: + if "bing" in name.split("_"): + attrs = DatasetCatalog.DATASETS["bing_caption_train"] + else: + attrs = DatasetCatalog.DATASETS[name] + + if "voc" in name and 'split' in attrs: + data_dir = try_to_find(attrs["data_dir"], return_dir=True) + args = dict( + data_dir=os.path.join(data_dir, attrs["data_dir"]), + split=attrs["split"], + ) + return dict( + factory="PascalVOCDataset", + args=args, + ) + elif "mixed" in name: + vg_img_dir = try_to_find(attrs["vg_img_dir"], return_dir=True) + coco_img_dir = try_to_find(attrs["coco_img_dir"], return_dir=True) + ann_file = try_to_find(attrs["ann_file"], return_dir=True) + args = dict( + img_folder_coco=os.path.join(coco_img_dir, attrs["coco_img_dir"]), + img_folder_vg=os.path.join(vg_img_dir, attrs["vg_img_dir"]), + ann_file=os.path.join(ann_file, attrs["ann_file"]) + ) + return dict( + factory="MixedDataset", + args=args, + ) + elif "flickr" in name: + img_dir = try_to_find(attrs["img_folder"], return_dir=True) + ann_dir = try_to_find(attrs["ann_file"], return_dir=True) + args = dict( + img_folder=os.path.join(img_dir, attrs["img_folder"]), + ann_file=os.path.join(ann_dir, attrs["ann_file"]), + is_train=attrs["is_train"] + ) + return dict( + factory="FlickrDataset", + args=args, + ) + elif "refexp" in name: + img_dir = try_to_find(attrs["img_dir"], return_dir=True) + ann_dir = try_to_find(attrs["ann_file"], return_dir=True) + args = dict( + img_folder=os.path.join(img_dir, attrs["img_dir"]), + ann_file=os.path.join(ann_dir, attrs["ann_file"]), + ) + return dict( + factory="RefExpDataset", + args=args, + ) + elif "gqa" in name: + img_dir = try_to_find(attrs["img_dir"], return_dir=True) + ann_dir = try_to_find(attrs["ann_file"], return_dir=True) + args = dict( + img_folder=os.path.join(img_dir, attrs["img_dir"]), + ann_file=os.path.join(ann_dir, attrs["ann_file"]), + ) + return dict( + factory="GQADataset", + args=args, + ) + elif "phrasecut" in name: + img_dir = try_to_find(attrs["img_dir"], return_dir=True) + ann_dir = try_to_find(attrs["ann_file"], return_dir=True) + args = dict( + img_folder=os.path.join(img_dir, attrs["img_dir"]), + ann_file=os.path.join(ann_dir, attrs["ann_file"]), + ) + return dict( + factory="PhrasecutDetection", + args=args, + ) + elif "_caption" in name: + yaml_path = try_to_find(attrs["yaml_path"], return_dir=True) + if "no_coco" in name: + yaml_name = attrs["yaml_name_no_coco"] + else: + yaml_name = attrs["yaml_name"] + yaml_file_name = "{}.{}.yaml".format(yaml_name, name.split("_")[2]) + args = dict( + yaml_file=os.path.join(yaml_path, attrs["yaml_path"], yaml_file_name) + ) + return dict( + factory="CaptionTSV", + args=args, + ) + elif "inferencecap" in name: + yaml_file_name = try_to_find(attrs["yaml_path"]) + args = dict( + yaml_file=yaml_file_name) + return dict( + factory="CaptionTSV", + args=args, + ) + elif "pseudo_data" in name: + args = dict( + yaml_file=try_to_find(attrs["yaml_path"]) + ) + return dict( + factory="PseudoData", + args=args, + ) + elif "_dt" in name: + dataset_file = attrs["dataset_file"] + yaml_path = try_to_find(attrs["yaml_path"], return_dir=True) + args = dict( + name=dataset_file, + yaml_file=os.path.join(yaml_path, attrs["yaml_path"]), + ) + return dict( + factory="CocoDetectionTSV", + args=args, + ) + elif "_odtsv" in name: + dataset_file = attrs["dataset_file"] + yaml_path = try_to_find(attrs["yaml_path"], return_dir=True) + args = dict( + name=dataset_file, + yaml_file=os.path.join(yaml_path, 
attrs["yaml_path"]), + ) + return dict( + factory="ODTSVDataset", + args=args, + ) + elif "_grounding" in name: + img_dir = try_to_find(attrs["img_dir"], return_dir=True) + ann_dir = try_to_find(attrs["ann_file"], return_dir=True) + args = dict( + img_folder=os.path.join(img_dir, attrs["img_dir"]), + ann_file=os.path.join(ann_dir, attrs["ann_file"]), + ) + return dict( + factory="CocoGrounding", + args=args, + ) + elif "lvis_evaluation" in name: + img_dir = try_to_find(attrs["img_dir"], return_dir=True) + ann_dir = try_to_find(attrs["ann_file"], return_dir=True) + args = dict( + img_folder=os.path.join(img_dir, attrs["img_dir"]), + ann_file=os.path.join(ann_dir, attrs["ann_file"]), + ) + return dict( + factory="LvisDetection", + args=args, + ) + else: + ann_dir = try_to_find(attrs["ann_file"], return_dir=True) + img_dir = try_to_find(attrs["img_dir"], return_dir=True) + args = dict( + root=os.path.join(img_dir, attrs["img_dir"]), + ann_file=os.path.join(ann_dir, attrs["ann_file"]), + ) + for k, v in attrs.items(): + args.update({k: os.path.join(ann_dir, v)}) + return dict( + factory="COCODataset", + args=args, + ) + + raise RuntimeError("Dataset not available: {}".format(name)) + + +class ModelCatalog(object): + S3_C2_DETECTRON_URL = "https://dl.fbaipublicfiles.com/detectron" + C2_IMAGENET_MODELS = { + "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl", + "MSRA/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl", + "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl", + "MSRA/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl", + "FAIR/20171220/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl", + "FAIR/20171220/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl", + } + + C2_DETECTRON_SUFFIX = "output/train/coco_2014_train%3Acoco_2014_valminusminival/generalized_rcnn/model_final.pkl" + C2_DETECTRON_MODELS = { + "35857197/e2e_faster_rcnn_R-50-C4_1x": "01_33_49.iAX0mXvW", + "35857345/e2e_faster_rcnn_R-50-FPN_1x": "01_36_30.cUF7QR7I", + "35857890/e2e_faster_rcnn_R-101-FPN_1x": "01_38_50.sNxI7sX7", + "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "06_31_39.5MIHi1fZ", + "35858791/e2e_mask_rcnn_R-50-C4_1x": "01_45_57.ZgkA7hPB", + "35858933/e2e_mask_rcnn_R-50-FPN_1x": "01_48_14.DzEQe4wC", + "35861795/e2e_mask_rcnn_R-101-FPN_1x": "02_31_37.KqyEK4tT", + "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "06_35_59.RZotkLKI", + } + + @staticmethod + def get(name): + if name.startswith("Caffe2Detectron/COCO"): + return ModelCatalog.get_c2_detectron_12_2017_baselines(name) + if name.startswith("ImageNetPretrained"): + return ModelCatalog.get_c2_imagenet_pretrained(name) + raise RuntimeError("model not present in the catalog {}".format(name)) + + @staticmethod + def get_c2_imagenet_pretrained(name): + prefix = ModelCatalog.S3_C2_DETECTRON_URL + name = name[len("ImageNetPretrained/"):] + name = ModelCatalog.C2_IMAGENET_MODELS[name] + url = "/".join([prefix, name]) + return url + + @staticmethod + def get_c2_detectron_12_2017_baselines(name): + # Detectron C2 models are stored following the structure + # prefix//2012_2017_baselines/.yaml./suffix + # we use as identifiers in the catalog Caffe2Detectron/COCO// + prefix = ModelCatalog.S3_C2_DETECTRON_URL + suffix = ModelCatalog.C2_DETECTRON_SUFFIX + # remove identification prefix + name = name[len("Caffe2Detectron/COCO/"):] + # split in and + model_id, model_name = name.split("/") + # parsing to make it match the url address from the Caffe2 models + model_name = "{}.yaml".format(model_name) + signature = ModelCatalog.C2_DETECTRON_MODELS[name] + 
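+        # e.g. model_name "e2e_faster_rcnn_R-50-FPN_1x.yaml" and signature "01_36_30.cUF7QR7I"
+        # are joined below into the versioned file name used in the download URL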
unique_name = ".".join([model_name, signature]) + url = "/".join([prefix, model_id, "12_2017_baselines", unique_name, suffix]) + return url diff --git a/maskrcnn_benchmark/csrc/ROIAlign.h b/maskrcnn_benchmark/csrc/ROIAlign.h new file mode 100644 index 0000000000000000000000000000000000000000..2683dbf52e120eebb7b60bb2257cd3527c5a86c3 --- /dev/null +++ b/maskrcnn_benchmark/csrc/ROIAlign.h @@ -0,0 +1,46 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +// Interface for Python +at::Tensor ROIAlign_forward(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +at::Tensor ROIAlign_backward(const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + if (grad.device().is_cuda()) { +#ifdef WITH_CUDA + return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/maskrcnn_benchmark/csrc/ROIPool.h b/maskrcnn_benchmark/csrc/ROIPool.h new file mode 100644 index 0000000000000000000000000000000000000000..9b62b2dcb8f69ac65bc1fdf0eeb5fa556539bc13 --- /dev/null +++ b/maskrcnn_benchmark/csrc/ROIPool.h @@ -0,0 +1,48 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
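+// Dispatcher for max RoI pooling. The forward pass returns the pooled features together with
+// the argmax indices that the backward pass needs; only the CUDA implementation is provided.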
+#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +std::tuple ROIPool_forward(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return ROIPool_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +at::Tensor ROIPool_backward(const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& rois, + const at::Tensor& argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + if (grad.device().is_cuda()) { +#ifdef WITH_CUDA + return ROIPool_backward_cuda(grad, input, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + + diff --git a/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h b/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h new file mode 100644 index 0000000000000000000000000000000000000000..e220c12ae558a176f6b4b0a6640e724358f2ecb0 --- /dev/null +++ b/maskrcnn_benchmark/csrc/SigmoidFocalLoss.h @@ -0,0 +1,41 @@ +#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +// Interface for Python +at::Tensor SigmoidFocalLoss_forward( + const at::Tensor& logits, + const at::Tensor& targets, + const int num_classes, + const float gamma, + const float alpha) { + if (logits.device().is_cuda()) { +#ifdef WITH_CUDA + return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +at::Tensor SigmoidFocalLoss_backward( + const at::Tensor& logits, + const at::Tensor& targets, + const at::Tensor& d_losses, + const int num_classes, + const float gamma, + const float alpha) { + if (logits.device().is_cuda()) { +#ifdef WITH_CUDA + return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses, num_classes, gamma, alpha); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp b/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0c061351588df7752293ed84bba1c900768e3ab8 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cpu/ROIAlign_cpu.cpp @@ -0,0 +1,257 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
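+// CPU forward pass of RoIAlign: for every output bin the bilinear sample positions and weights
+// are precomputed once (shared across channels) and the sampled values are averaged.
+// No CPU backward is provided; training relies on the CUDA kernels.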
+#include "cpu/vision.h" + +// implementation taken from Caffe2 +template +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +template +void pre_calc_for_bilinear_interpolate( + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int iy_upper, + const int ix_upper, + T roi_start_h, + T roi_start_w, + T bin_size_h, + T bin_size_w, + int roi_bin_grid_h, + int roi_bin_grid_w, + std::vector>& pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < iy_upper; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T x = xx; + T y = yy; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indeces + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +template +void ROIAlignForward_cpu_kernel( + const int nthreads, + const T* bottom_data, + const T& spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* bottom_rois, + //int roi_cols, + T* top_data) { + //AT_ASSERT(roi_cols == 4 || roi_cols == 5); + int roi_cols = 5; + + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + // roi could have 4 or 5 columns + const T* offset_bottom_rois = bottom_rois + n * roi_cols; + int roi_batch_ind = 0; + if (roi_cols == 5) { + roi_batch_ind = offset_bottom_rois[0]; + offset_bottom_rois++; + } + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[0] * spatial_scale; + T roi_start_h = offset_bottom_rois[1] * spatial_scale; + T roi_end_w = offset_bottom_rois[2] * spatial_scale; + T roi_end_h = offset_bottom_rois[3] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_h = 
round(offset_bottom_rois[3] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); + T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + // we want to precalculate indeces and weights shared by all chanels, + // this is the key point of optimiation + std::vector> pre_calc( + roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + pre_calc_for_bilinear_interpolate( + height, + width, + pooled_height, + pooled_width, + roi_bin_grid_h, + roi_bin_grid_w, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_bottom_data[pc.pos1] + + pc.w2 * offset_bottom_data[pc.pos2] + + pc.w3 * offset_bottom_data[pc.pos3] + + pc.w4 * offset_bottom_data[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; + + top_data[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} + +at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + AT_ASSERTM(!input.device().is_cuda(), "input must be a CPU tensor"); + AT_ASSERTM(!rois.device().is_cuda(), "rois must be a CPU tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] { + ROIAlignForward_cpu_kernel( + output_size, + input.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois.data_ptr(), + output.data_ptr()); + }); + return output; +} diff --git a/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp b/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11b7aa60fdca907352b334f142faadb46d662f99 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
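+// Greedy hard NMS on CPU: boxes are visited in descending score order and any box whose IoU
+// with an already kept box reaches the threshold is suppressed; the kept indices are returned.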
+#include "cpu/vision.h"
+
+
+template <typename scalar_t>
+at::Tensor nms_cpu_kernel(const at::Tensor& dets,
+                          const at::Tensor& scores,
+                          const float threshold) {
+  AT_ASSERTM(!dets.device().is_cuda(), "dets must be a CPU tensor");
+  AT_ASSERTM(!scores.device().is_cuda(), "scores must be a CPU tensor");
+  AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");
+
+  if (dets.numel() == 0) {
+    return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
+  }
+
+  auto x1_t = dets.select(1, 0).contiguous();
+  auto y1_t = dets.select(1, 1).contiguous();
+  auto x2_t = dets.select(1, 2).contiguous();
+  auto y2_t = dets.select(1, 3).contiguous();
+
+  at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
+
+  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
+
+  auto ndets = dets.size(0);
+  at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));
+
+  auto suppressed = suppressed_t.data_ptr<uint8_t>();
+  auto order = order_t.data_ptr<int64_t>();
+  auto x1 = x1_t.data_ptr<scalar_t>();
+  auto y1 = y1_t.data_ptr<scalar_t>();
+  auto x2 = x2_t.data_ptr<scalar_t>();
+  auto y2 = y2_t.data_ptr<scalar_t>();
+  auto areas = areas_t.data_ptr<scalar_t>();
+
+  for (int64_t _i = 0; _i < ndets; _i++) {
+    auto i = order[_i];
+    if (suppressed[i] == 1)
+      continue;
+    auto ix1 = x1[i];
+    auto iy1 = y1[i];
+    auto ix2 = x2[i];
+    auto iy2 = y2[i];
+    auto iarea = areas[i];
+
+    for (int64_t _j = _i + 1; _j < ndets; _j++) {
+      auto j = order[_j];
+      if (suppressed[j] == 1)
+        continue;
+      auto xx1 = std::max(ix1, x1[j]);
+      auto yy1 = std::max(iy1, y1[j]);
+      auto xx2 = std::min(ix2, x2[j]);
+      auto yy2 = std::min(iy2, y2[j]);
+
+      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
+      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
+      auto inter = w * h;
+      auto ovr = inter / (iarea + areas[j] - inter);
+      if (ovr >= threshold)
+        suppressed[j] = 1;
+    }
+  }
+  return at::nonzero(suppressed_t == 0).squeeze(1);
+}
+
+at::Tensor nms_cpu(const at::Tensor& dets,
+                   const at::Tensor& scores,
+                   const float threshold) {
+  at::Tensor result;
+  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
+    result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
+  });
+  return result;
+}
diff --git a/maskrcnn_benchmark/csrc/cpu/soft_nms.cpp b/maskrcnn_benchmark/csrc/cpu/soft_nms.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..423941d71e29f5b9823006d57cdf0088646586ed
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cpu/soft_nms.cpp
@@ -0,0 +1,117 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
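+// Soft-NMS (Bodla et al., 2017), Gaussian variant: overlapping boxes are not discarded outright;
+// each score is decayed by exp(-IoU^2 / sigma) and a box is dropped only once its score falls
+// below the threshold (cf. TEST.SPECIAL_NMS / TEST.TH in the config defaults above).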
+#include "cpu/vision.h" + + +template +std::pair soft_nms_cpu_kernel(const at::Tensor& dets, + const at::Tensor& scores, + const float threshold, + const float sigma) { + AT_ASSERTM(!dets.device().is_cuda(), "dets must be a CPU tensor"); + AT_ASSERTM(!scores.device().is_cuda(), "scores must be a CPU tensor"); + AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores"); + + if (dets.numel() == 0) { + return std::make_pair(at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU)), + at::empty({0}, scores.options().dtype(at::kFloat).device(at::kCPU))); + } + + auto x1_t = dets.select(1, 0).contiguous(); + auto y1_t = dets.select(1, 1).contiguous(); + auto x2_t = dets.select(1, 2).contiguous(); + auto y2_t = dets.select(1, 3).contiguous(); + + auto scores_t = scores.clone(); + + at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1); + auto ndets = dets.size(0); + auto inds_t = at::arange(ndets, dets.options().dtype(at::kLong).device(at::kCPU)); + + auto x1 = x1_t.data_ptr(); + auto y1 = y1_t.data_ptr(); + auto x2 = x2_t.data_ptr(); + auto y2 = y2_t.data_ptr(); + auto s = scores_t.data_ptr(); + auto inds = inds_t.data_ptr(); + auto areas = areas_t.data_ptr(); + + for (int64_t i = 0; i < ndets; i++) { + + auto ix1 = x1[i]; + auto iy1 = y1[i]; + auto ix2 = x2[i]; + auto iy2 = y2[i]; + auto is = s[i]; + auto ii = inds[i]; + auto iarea = areas[i]; + + auto maxpos = scores_t.slice(0, i, ndets).argmax().item() + i; + + // add max box as a detection + x1[i] = x1[maxpos]; + y1[i] = y1[maxpos]; + x2[i] = x2[maxpos]; + y2[i] = y2[maxpos]; + s[i] = s[maxpos]; + inds[i] = inds[maxpos]; + areas[i] = areas[maxpos]; + + // swap ith box with position of max box + x1[maxpos] = ix1; + y1[maxpos] = iy1; + x2[maxpos] = ix2; + y2[maxpos] = iy2; + s[maxpos] = is; + inds[maxpos] = ii; + areas[maxpos] = iarea; + + ix1 = x1[i]; + iy1 = y1[i]; + ix2 = x2[i]; + iy2 = y2[i]; + iarea = areas[i]; + + // NMS iterations, note that ndets changes if detection boxes + // fall below threshold + for (int64_t j = i + 1; j < ndets; j++) { + auto xx1 = std::max(ix1, x1[j]); + auto yy1 = std::max(iy1, y1[j]); + auto xx2 = std::min(ix2, x2[j]); + auto yy2 = std::min(iy2, y2[j]); + + auto w = std::max(static_cast(0), xx2 - xx1 + 1); + auto h = std::max(static_cast(0), yy2 - yy1 + 1); + + auto inter = w * h; + auto ovr = inter / (iarea + areas[j] - inter); + + s[j] = s[j] * std::exp(- std::pow(ovr, 2.0) / sigma); + + // if box score falls below threshold, discard the box by + // swapping with last box update ndets + if (s[j] < threshold) { + x1[j] = x1[ndets - 1]; + y1[j] = y1[ndets - 1]; + x2[j] = x2[ndets - 1]; + y2[j] = y2[ndets - 1]; + s[j] = s[ndets - 1]; + inds[j] = inds[ndets - 1]; + areas[j] = areas[ndets - 1]; + j--; + ndets--; + } + } + } + return std::make_pair(inds_t.slice(0, 0, ndets), scores_t.slice(0, 0, ndets)); +} + +std::pair soft_nms_cpu(const at::Tensor& dets, + const at::Tensor& scores, + const float threshold, + const float sigma) { + std::pair result; + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "soft_nms", [&] { + result = soft_nms_cpu_kernel(dets, scores, threshold, sigma); + }); + return result; +} \ No newline at end of file diff --git a/maskrcnn_benchmark/csrc/cpu/vision.h b/maskrcnn_benchmark/csrc/cpu/vision.h new file mode 100644 index 0000000000000000000000000000000000000000..e00ef683150eb9d46d0e4f6a30f55a7230a52e93 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cpu/vision.h @@ -0,0 +1,22 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
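+// Declarations of the CPU operators used by the dispatch headers above
+// (RoIAlign forward, hard NMS and soft-NMS).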
+#pragma once
+#include <torch/extension.h>
+
+
+at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
+                                const at::Tensor& rois,
+                                const float spatial_scale,
+                                const int pooled_height,
+                                const int pooled_width,
+                                const int sampling_ratio);
+
+
+at::Tensor nms_cpu(const at::Tensor& dets,
+                   const at::Tensor& scores,
+                   const float threshold);
+
+
+std::pair<at::Tensor, at::Tensor> soft_nms_cpu(const at::Tensor& dets,
+                                               const at::Tensor& scores,
+                                               const float threshold,
+                                               const float sigma);
\ No newline at end of file
diff --git a/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9ed1a0adfd841a17d3574dee6ac703820fcfe144
--- /dev/null
+++ b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu
@@ -0,0 +1,346 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+// TODO make it in a common file
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+
+template <typename T>
+__device__ T bilinear_interpolate(const T* bottom_data,
+                                  const int height, const int width,
+                                  T y, T x,
+                                  const int index /* index for debug only*/) {
+
+  // deal with cases that inverse elements are out of feature map boundary
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    //empty
+    return 0;
+  }
+
+  if (y <= 0) y = 0;
+  if (x <= 0) x = 0;
+
+  int y_low = (int) y;
+  int x_low = (int) x;
+  int y_high;
+  int x_high;
+
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T) y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T) x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1.
- lx; + // do bilinear interpolation + T v1 = bottom_data[y_low * width + x_low]; + T v2 = bottom_data[y_low * width + x_high]; + T v3 = bottom_data[y_high * width + x_low]; + T v4 = bottom_data[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +__global__ void RoIAlignForward(const int nthreads, const T* bottom_data, + const T spatial_scale, const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, + const T* bottom_rois, T* top_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[1] * spatial_scale; + T roi_start_h = offset_bottom_rois[2] * spatial_scale; + T roi_end_w = offset_bottom_rois[3] * spatial_scale; + T roi_end_h = offset_bottom_rois[4] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix ++) + { + const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + top_data[index] = output_val; + } +} + + +template +__device__ void bilinear_interpolate_gradient( + const int height, const int width, + T y, T x, + T & w1, T & w2, T & w3, T & w4, + int & x_low, int & x_high, int & y_low, int & y_high, + const int index /* index for debug only*/) { + + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + //empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int) y; + x_low = (int) x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T) y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T) x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = bottom_data[y_low * width + x_low]; + // T v2 = bottom_data[y_low * width + x_high]; + // T v3 = bottom_data[y_high * width + x_low]; + // T v4 = bottom_data[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template +__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, + const int num_rois, const T spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, + T* bottom_diff, + const T* bottom_rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[1] * spatial_scale; + T roi_start_h = offset_bottom_rois[2] * spatial_scale; + T roi_end_w = offset_bottom_rois[3] * spatial_scale; + T roi_end_h = offset_bottom_rois[4] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; + + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_top_diff = top_diff + top_offset; + 
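+    // Each pooled output cell carries a single upstream gradient; below it is split over the
+    // 2x2 bilinear neighbours of every sample point and accumulated into bottom_diff via atomicAdd.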
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix ++) + { + const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, + w1, w2, w3, w4, + x_low, x_high, y_low, y_high, + index); + + T g1 = top_diff_this_bin * w1 / count; + T g2 = top_diff_this_bin * w2 / count; + T g3 = top_diff_this_bin * w3 / count; + T g4 = top_diff_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) + { + atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast(g1)); + atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast(g2)); + atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast(g3)); + atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // CUDA_1D_KERNEL_LOOP +} // RoIAlignBackward + + +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); + dim3 block(512); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return output; + } + + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_forward", [&] { + RoIAlignForward<<>>( + output_size, + input.contiguous().data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois.contiguous().data_ptr(), + output.data_ptr()); + }); + THCudaCheck(cudaGetLastError()); + return output; +} + +// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 
grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return grad_input; + } + + AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIAlign_backward", [&] { + RoIAlignBackwardFeature<<>>( + grad.numel(), + grad.contiguous().data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + grad_input.data_ptr(), + rois.contiguous().data_ptr()); + }); + THCudaCheck(cudaGetLastError()); + return grad_input; +} diff --git a/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..60fc9fbc55956304c7ff6b48cbf3c086029b8354 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu @@ -0,0 +1,202 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include +#include + +#include +#include +#include + + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +template +__global__ void RoIPoolFForward(const int nthreads, const T* bottom_data, + const T spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const T* bottom_rois, T* top_data, int* argmax_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) + / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) + / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) + * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) + * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) + * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) + * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, 0), height); + hend = min(max(hend + roi_start_h, 0), height); + wstart = min(max(wstart + roi_start_w, 0), width); + wend = min(max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 
0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + const T* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int bottom_index = h * width + w; + if (offset_bottom_data[bottom_index] > maxval) { + maxval = offset_bottom_data[bottom_index]; + maxidx = bottom_index; + } + } + } + top_data[index] = maxval; + argmax_data[index] = maxidx; + } +} + +template +__global__ void RoIPoolFBackward(const int nthreads, const T* top_diff, + const int* argmax_data, const int num_rois, const T spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, T* bottom_diff, + const T* bottom_rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + int bottom_offset = (roi_batch_ind * channels + c) * height * width; + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_top_diff = top_diff + top_offset; + T* offset_bottom_diff = bottom_diff + bottom_offset; + const int* offset_argmax_data = argmax_data + top_offset; + + int argmax = offset_argmax_data[ph * pooled_width + pw]; + if (argmax != -1) { + atomicAdd( + offset_bottom_diff + argmax, + static_cast(offset_top_diff[ph * pooled_width + pw])); + + } + } +} + +std::tuple ROIPool_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + auto argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt)); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); + dim3 block(512); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, argmax); + } + + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIPool_forward", [&] { + RoIPoolFForward<<>>( + output_size, + input.contiguous().data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois.contiguous().data_ptr(), + output.data_ptr(), + argmax.data_ptr()); + }); + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, argmax); +} + +// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& rois, + const at::Tensor& argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor"); + 
AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + // TODO add more checks + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return grad_input; + } + + AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "ROIPool_backward", [&] { + RoIPoolFBackward<<>>( + grad.numel(), + grad.contiguous().data_ptr(), + argmax.data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.data_ptr(), + rois.contiguous().data_ptr()); + }); + THCudaCheck(cudaGetLastError()); + return grad_input; +} diff --git a/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..8aeceae0f825598cd36ea99add8da613c5e2482a --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu @@ -0,0 +1,188 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +// This file is modified from https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu +// Cheng-Yang Fu +// cyfu@cs.unc.edu +#include +#include + +#include +#include +#include + +#include + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +template +__global__ void SigmoidFocalLossForward(const int nthreads, + const T* logits, + const int* targets, + const int num_classes, + const float gamma, + const float alpha, + const int num, + T* losses) { + CUDA_1D_KERNEL_LOOP(i, nthreads) { + + int n = i / num_classes; + int d = i % num_classes; // current class[0~79]; + int t = targets[n]; // target class [1~80]; + + // Decide it is positive or negative case. + T c1 = (t == (d+1)); + T c2 = (t>=0 & t != (d+1)); + + T zn = (1.0 - alpha); + T zp = (alpha); + + // p = 1. / 1. + expf(-x); p = sigmoid(x) + T p = 1. / (1. + expf(-logits[i])); + + // (1-p)**gamma * log(p) where + T term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN)); + + // p**gamma * log(1-p) + T term2 = powf(p, gamma) * + (-1. * logits[i] * (logits[i] >= 0) - + logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))); + + losses[i] = 0.0; + losses[i] += -c1 * term1 * zp; + losses[i] += -c2 * term2 * zn; + + } // CUDA_1D_KERNEL_LOOP +} // SigmoidFocalLossForward + + +template +__global__ void SigmoidFocalLossBackward(const int nthreads, + const T* logits, + const int* targets, + const T* d_losses, + const int num_classes, + const float gamma, + const float alpha, + const int num, + T* d_logits) { + CUDA_1D_KERNEL_LOOP(i, nthreads) { + + int n = i / num_classes; + int d = i % num_classes; // current class[0~79]; + int t = targets[n]; // target class [1~80], 0 is background; + + // Decide it is positive or negative case. + T c1 = (t == (d+1)); + T c2 = (t>=0 & t != (d+1)); + + T zn = (1.0 - alpha); + T zp = (alpha); + // p = 1. / 1. + expf(-x); p = sigmoid(x) + T p = 1. / (1. + expf(-logits[i])); + + // (1-p)**g * (1 - p - g*p*log(p) + T term1 = powf((1. - p), gamma) * + (1. - p - (p * gamma * logf(max(p, FLT_MIN)))); + + // (p**g) * (g*(1-p)*log(1-p) - p) + T term2 = powf(p, gamma) * + ((-1. * logits[i] * (logits[i] >= 0) - + logf(1. 
+ expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) * + (1. - p) * gamma - p); + d_logits[i] = 0.0; + d_logits[i] += -c1 * term1 * zp; + d_logits[i] += -c2 * term2 * zn; + d_logits[i] = d_logits[i] * d_losses[i]; + + } // CUDA_1D_KERNEL_LOOP +} // SigmoidFocalLossBackward + + +at::Tensor SigmoidFocalLoss_forward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const int num_classes, + const float gamma, + const float alpha) { + AT_ASSERTM(logits.device().is_cuda(), "logits must be a CUDA tensor"); + AT_ASSERTM(targets.device().is_cuda(), "targets must be a CUDA tensor"); + AT_ASSERTM(logits.dim() == 2, "logits should be NxClass"); + + const int num_samples = logits.size(0); + + auto losses = at::empty({num_samples, logits.size(1)}, logits.options()); + auto losses_size = num_samples * logits.size(1); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L)); + dim3 block(512); + + if (losses.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return losses; + } + + AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_forward", [&] { + SigmoidFocalLossForward<<>>( + losses_size, + logits.contiguous().data_ptr(), + targets.contiguous().data_ptr(), + num_classes, + gamma, + alpha, + num_samples, + losses.data_ptr()); + }); + THCudaCheck(cudaGetLastError()); + return losses; +} + + +at::Tensor SigmoidFocalLoss_backward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const at::Tensor& d_losses, + const int num_classes, + const float gamma, + const float alpha) { + AT_ASSERTM(logits.device().is_cuda(), "logits must be a CUDA tensor"); + AT_ASSERTM(targets.device().is_cuda(), "targets must be a CUDA tensor"); + AT_ASSERTM(d_losses.device().is_cuda(), "d_losses must be a CUDA tensor"); + + AT_ASSERTM(logits.dim() == 2, "logits should be NxClass"); + + const int num_samples = logits.size(0); + AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes"); + + auto d_logits = at::zeros({num_samples, num_classes}, logits.options()); + auto d_logits_size = num_samples * logits.size(1); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L)); + dim3 block(512); + + if (d_logits.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return d_logits; + } + + AT_DISPATCH_FLOATING_TYPES(logits.scalar_type(), "SigmoidFocalLoss_backward", [&] { + SigmoidFocalLossBackward<<>>( + d_logits_size, + logits.contiguous().data_ptr(), + targets.contiguous().data_ptr(), + d_losses.contiguous().data_ptr(), + num_classes, + gamma, + alpha, + num_samples, + d_logits.data_ptr()); + }); + + THCudaCheck(cudaGetLastError()); + return d_logits; +} + diff --git a/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..2cdf8d61957e50d452dd230c97b5754dacd2fa0e --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu @@ -0,0 +1,691 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#include +#include +#include + + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int 
dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) +{ + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + 
dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) +{ + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * 
ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) +{ + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, 
inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) +{ + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + 
gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, 
int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) +{ + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) +{ + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); 
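+  // Per-sample backward outline (see the loop below):
+  //   1) columns = W^T * grad_output[b]            (col-space gradient)
+  //   2) col2im_coord(columns)  -> grad_offset[b], grad_mask[b]
+  //   3) col2im(columns)        -> grad_input[b]
+  //   4) im2col(input[b]), then grad_weight += grad_output[b] * columns^T
+  //      (and grad_bias accumulated via the all-ones vector when with_bias)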
+ + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. 
weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..ee15810103a4edaf213abdb222a70249d622c0f9 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/deform_conv_kernel_cuda.cu @@ -0,0 +1,874 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + + +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +/* +const int CUDA_NUM_THREADS = 1024; + +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +}*/ + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + 
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * 
kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % 
width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + 
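+  // One thread per offset-map element (b, offset channel, h, w). The gradient
+  // w.r.t. that offset coordinate is accumulated over all input channels of its
+  // deformable group: the column gradient times the bilinear sampling
+  // derivative returned by get_coordinate_weight().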
CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, 
dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * 
im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int 
kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = 
kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t 
*data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, 
height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..bbe22d77b49be70f174ae3f17647b09968358255 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu @@ -0,0 +1,87 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c + +// based on +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +#include +#include + +#include +#include + +#include +#include +#include + + +void DeformablePSROIPoolForward( + const at::Tensor data, const at::Tensor bbox, const at::Tensor trans, + at::Tensor out, at::Tensor top_count, const int batch, const int channels, + const int height, const int width, const int num_bbox, + const int channels_trans, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void DeformablePSROIPoolBackwardAcc( + const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox, + const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad, + at::Tensor trans_grad, const int batch, const int channels, + const int height, const int width, const int num_bbox, + const int channels_trans, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void deform_psroi_pooling_cuda_forward( + at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, + at::Tensor top_count, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std) +{ + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + const int channels_trans = no_trans ? 
2 : trans.size(1); + + const int num_bbox = bbox.size(0); + if (num_bbox != out.size(0)) + AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", + out.size(0), num_bbox); + + DeformablePSROIPoolForward( + input, bbox, trans, out, top_count, batch, channels, height, width, + num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size, + pooled_size, part_size, sample_per_part, trans_std); +} + +void deform_psroi_pooling_cuda_backward( + at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, + at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, + const int no_trans, const float spatial_scale, const int output_dim, + const int group_size, const int pooled_size, const int part_size, + const int sample_per_part, const float trans_std) +{ + TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + const int channels_trans = no_trans ? 2 : trans.size(1); + + const int num_bbox = bbox.size(0); + if (num_bbox != out_grad.size(0)) + AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", + out_grad.size(0), num_bbox); + + DeformablePSROIPoolBackwardAcc( + out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch, + channels, height, width, num_bbox, channels_trans, no_trans, + spatial_scale, output_dim, group_size, pooled_size, part_size, + sample_per_part, trans_std); +} diff --git a/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..3f6c4cb22f6ecbae242e21c9530f474e709c6e90 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/deform_pool_kernel_cuda.cu @@ -0,0 +1,365 @@ +/*! 
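The host wrapper `deform_psroi_pooling_cuda_forward` above mainly enforces a shape contract before launching the kernel: `bbox` rows are `(batch_idx, x1, y1, x2, y2)`, `channels_trans` is read from `trans.size(1)` whenever offsets are enabled, and `out.size(0)` must match the number of boxes. The sketch below spells that contract out on dummy tensors; the `_C` module name and the commented call (against the binding registered later in `vision.cpp`) are assumptions for illustration, not repository code.

```python
# Shape contract for deformable PSROI pooling: an illustrative sketch, not repo code.
import torch

num_bbox, output_dim, pooled_size = 8, 256, 7
feats = torch.randn(2, 1024, 50, 50, device="cuda")   # (batch, channels, H, W), contiguous
rois = torch.zeros(num_bbox, 5, device="cuda")        # each row: (batch_idx, x1, y1, x2, y2)
rois[:, 3:] = 32.0                                    # dummy 32x32 boxes in image 0
# channels_trans = trans.size(1) = 2 * num_classes (an x and a y offset plane per class)
trans = torch.zeros(num_bbox, 4, pooled_size, pooled_size, device="cuda")

out = feats.new_zeros(num_bbox, output_dim, pooled_size, pooled_size)        # out.size(0) == num_bbox
top_count = feats.new_zeros(num_bbox, output_dim, pooled_size, pooled_size)  # samples per output cell

# from maskrcnn_benchmark import _C   # hypothetical import of the compiled extension
# _C.deform_psroi_pooling_forward(feats, rois, trans, out, top_count,
#                                 0, 1.0 / 16, output_dim, 1,        # no_trans, spatial_scale, output_dim, group_size
#                                 pooled_size, pooled_size, 4, 0.1)  # pooled_size, part_size, sample_per_part, trans_std
```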
+ * Copyright (c) 2017 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file deformable_psroi_pooling.cu + * \brief + * \author Yi Li, Guodong Zhang, Jifeng Dai +*/ +/***************** Adapted by Charles Shang *********************/ +// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu + + +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__device__ scalar_t bilinear_interp( + const scalar_t *data, + const scalar_t x, + const scalar_t y, + const int width, + const int height) +{ + int x1 = floor(x); + int x2 = ceil(x); + int y1 = floor(y); + int y2 = ceil(y); + scalar_t dist_x = (scalar_t)(x - x1); + scalar_t dist_y = (scalar_t)(y - y1); + scalar_t value11 = data[y1 * width + x1]; + scalar_t value12 = data[y2 * width + x1]; + scalar_t value21 = data[y1 * width + x2]; + scalar_t value22 = data[y2 * width + x2]; + scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; + return value; +} + +template +__global__ void DeformablePSROIPoolForwardKernel( + const int count, + const scalar_t *bottom_data, + const scalar_t spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const scalar_t *bottom_rois, const scalar_t *bottom_trans, + const int no_trans, + const scalar_t trans_std, + const int sample_per_part, + const int output_dim, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class, + scalar_t *top_data, + scalar_t *top_count) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const scalar_t *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height); + scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width); + + scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part); + scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part); + + int part_h = floor((scalar_t)(ph) / pooled_height * part_size); + int part_w = floor((scalar_t)(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + scalar_t trans_x = no_trans ? 
(scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std; + scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std; + + scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + scalar_t sum = 0; + int count = 0; + int gw = floor((scalar_t)(pw)*group_size / pooled_width); + int gh = floor((scalar_t)(ph)*group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + scalar_t w = wstart + iw * sub_bin_size_w; + scalar_t h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); + sum += val; + count++; + } + } + top_data[index] = count == 0 ? (scalar_t)(0) : sum / count; + top_count[index] = count; + } +} + +template +__global__ void DeformablePSROIPoolBackwardAccKernel( + const int count, + const scalar_t *top_diff, + const scalar_t *top_count, + const int num_rois, + const scalar_t spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int output_dim, + scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff, + const scalar_t *bottom_data, + const scalar_t *bottom_rois, + const scalar_t *bottom_trans, + const int no_trans, + const scalar_t trans_std, + const int sample_per_part, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const scalar_t *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height); + scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width); + + scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part); + scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part); + + int part_h = floor((scalar_t)(ph) / pooled_height * part_size); + int part_w = floor((scalar_t)(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std; + scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std; + + scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + if (top_count[index] <= 0) + { + continue; + } + scalar_t diff_val = top_diff[index] / top_count[index]; + const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; + scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; + int gw = floor((scalar_t)(pw)*group_size / pooled_width); + int gh = floor((scalar_t)(ph)*group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + scalar_t w = wstart + iw * sub_bin_size_w; + scalar_t h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + // backward on feature + int x0 = floor(w); + int x1 = ceil(w); + int y0 = floor(h); + int y1 = ceil(h); + scalar_t dist_x = w - x0, dist_y = h - y0; + scalar_t q00 = (1 - dist_x) * (1 - dist_y); + scalar_t q01 = (1 - dist_x) * dist_y; + scalar_t q10 = dist_x * (1 - dist_y); + scalar_t q11 = dist_x * dist_y; + int bottom_index_base = c * height * width; + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); + + if (no_trans) + { + continue; + } + scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; + diff_x *= roi_width; + scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; + diff_y *= roi_height; + + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + 
part_h) * part_size + part_w, diff_x); + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); + } + } + } +} + +void DeformablePSROIPoolForward(const at::Tensor data, + const at::Tensor bbox, + const at::Tensor trans, + at::Tensor out, + at::Tensor top_count, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + const int pooled_height = pooled_size; + const int pooled_width = pooled_size; + const int count = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data.scalar_type(), "deformable_psroi_pool_forward", ([&] { + const scalar_t *bottom_data = data.data_ptr(); + const scalar_t *bottom_rois = bbox.data_ptr(); + const scalar_t *bottom_trans = no_trans ? NULL : trans.data_ptr(); + scalar_t *top_data = out.data_ptr(); + scalar_t *top_count_data = top_count.data_ptr(); + + DeformablePSROIPoolForwardKernel<<>>( + count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width, + bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim, + group_size, part_size, num_classes, channels_each_class, top_data, top_count_data); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); + } +} + +void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad, + const at::Tensor data, + const at::Tensor bbox, + const at::Tensor trans, + const at::Tensor top_count, + at::Tensor in_grad, + at::Tensor trans_grad, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + // LOG(INFO) << "DeformablePSROIPoolBackward"; + const int num_rois = num_bbox; + const int pooled_height = pooled_size; + const int pooled_width = pooled_size; + const int count = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + out_grad.scalar_type(), "deformable_psroi_pool_backward_acc", ([&] { + const scalar_t *top_diff = out_grad.data_ptr(); + const scalar_t *bottom_data = data.data_ptr(); + const scalar_t *bottom_rois = bbox.data_ptr(); + const scalar_t *bottom_trans = no_trans ? NULL : trans.data_ptr(); + scalar_t *bottom_data_diff = in_grad.data_ptr(); + scalar_t *bottom_trans_diff = no_trans ? 
NULL : trans_grad.data_ptr(); + const scalar_t *top_count_data = top_count.data_ptr(); + + DeformablePSROIPoolBackwardAccKernel<<>>( + count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width, + pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff, + bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, + group_size, part_size, num_classes, channels_each_class); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/maskrcnn_benchmark/csrc/cuda/ml_nms.cu b/maskrcnn_benchmark/csrc/cuda/ml_nms.cu new file mode 100644 index 0000000000000000000000000000000000000000..cd958a0899a9e3adc69ca053170beb2b34fbd8ef --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/ml_nms.cu @@ -0,0 +1,136 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include +#include + +#include +#include + +#include +#include + +int const threadsPerBlock = sizeof(unsigned long long) * 8; + +__device__ inline float devIoU(float const * const a, float const * const b) { + if (a[5] != b[5]) { + return 0.0; + } + float left = max(a[0], b[0]), right = min(a[2], b[2]); + float top = max(a[1], b[1]), bottom = min(a[3], b[3]); + float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); + float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); + return interS / (Sa + Sb - interS); +} + +__global__ void ml_nms_kernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, unsigned long long *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ float block_boxes[threadsPerBlock * 6]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 6 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 0]; + block_boxes[threadIdx.x * 6 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 1]; + block_boxes[threadIdx.x * 6 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 2]; + block_boxes[threadIdx.x * 6 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 3]; + block_boxes[threadIdx.x * 6 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 4]; + block_boxes[threadIdx.x * 6 + 5] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 5]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 6; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 6) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +// boxes is a N x 6 tensor +at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) { + using scalar_t = float; + AT_ASSERTM(boxes.device().is_cuda(), "boxes must be a CUDA tensor"); + auto scores = boxes.select(1, 
4); + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + auto boxes_sorted = boxes.index_select(0, order_t); + + int boxes_num = boxes.size(0); + + const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock); + + scalar_t* boxes_dev = boxes_sorted.data_ptr(); + + THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState + + unsigned long long* mask_dev = NULL; + //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev, + // boxes_num * col_blocks * sizeof(unsigned long long))); + + mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long)); + + dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock), + THCCeilDiv(boxes_num, threadsPerBlock)); + dim3 threads(threadsPerBlock); + ml_nms_kernel<<>>(boxes_num, + nms_overlap_thresh, + boxes_dev, + mask_dev); + + std::vector mask_host(boxes_num * col_blocks); + THCudaCheck(cudaMemcpy(&mask_host[0], + mask_dev, + sizeof(unsigned long long) * boxes_num * col_blocks, + cudaMemcpyDeviceToHost)); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + int num_to_keep = 0; + for (int i = 0; i < boxes_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long *p = &mask_host[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + THCudaFree(state, mask_dev); + // TODO improve this part + return std::get<0>(order_t.index({ + keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to( + order_t.device(), keep.scalar_type()) + }).sort(0, false)); +} diff --git a/maskrcnn_benchmark/csrc/cuda/nms.cu b/maskrcnn_benchmark/csrc/cuda/nms.cu new file mode 100644 index 0000000000000000000000000000000000000000..d6221b85fa8f6b40cf498b76d6dbfc3c8438e25e --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/nms.cu @@ -0,0 +1,131 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
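`ml_nms.cu` above (and the near-identical `nms.cu` that follows) use the standard blocked-bitmask scheme: boxes are sorted by score, every 64-box tile computes a suppression bitmask against every other tile on the GPU (`threadsPerBlock` is the 64 bits of an `unsigned long long`), and a cheap sequential pass on the host keeps a box only if no higher-scored kept box has flagged it. The multi-label variant differs only in `devIoU`, which returns 0 whenever the two boxes carry different labels, so suppression never crosses classes. Below is a plain-PyTorch reference of the same greedy rule, including the inclusive `+1` width/height convention; it is an illustration, not the CUDA kernel.

```python
# Reference greedy NMS matching devIoU's conventions: a CPU sketch, not the CUDA kernel.
import torch

def ref_ml_nms(boxes, scores, labels, iou_thresh):
    """boxes: (N, 4) as (x1, y1, x2, y2); returns kept indices, highest score first."""
    order = scores.argsort(descending=True)
    keep = []
    for i in order.tolist():
        suppressed = False
        for j in keep:
            if labels[i] != labels[j]:      # ml_nms: different labels never suppress each other
                continue
            # devIoU uses an inclusive-pixel (+1) width/height convention
            lt = torch.max(boxes[i, :2], boxes[j, :2])
            rb = torch.min(boxes[i, 2:], boxes[j, 2:])
            wh = (rb - lt + 1).clamp(min=0)
            inter = wh[0] * wh[1]
            area_i = (boxes[i, 2] - boxes[i, 0] + 1) * (boxes[i, 3] - boxes[i, 1] + 1)
            area_j = (boxes[j, 2] - boxes[j, 0] + 1) * (boxes[j, 3] - boxes[j, 1] + 1)
            if inter / (area_i + area_j - inter) > iou_thresh:
                suppressed = True
                break
        if not suppressed:
            keep.append(i)
    return torch.tensor(keep, dtype=torch.long)
```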
+#include +#include + +#include +#include + +#include +#include + +int const threadsPerBlock = sizeof(unsigned long long) * 8; + +__device__ inline float devIoU(float const * const a, float const * const b) { + float left = max(a[0], b[0]), right = min(a[2], b[2]); + float top = max(a[1], b[1]), bottom = min(a[3], b[3]); + float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); + float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); + return interS / (Sa + Sb - interS); +} + +__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, unsigned long long *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ float block_boxes[threadsPerBlock * 5]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 5; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +// boxes is a N x 5 tensor +at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) { + using scalar_t = float; + AT_ASSERTM(boxes.device().is_cuda(), "boxes must be a CUDA tensor"); + auto scores = boxes.select(1, 4); + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + auto boxes_sorted = boxes.index_select(0, order_t); + + int boxes_num = boxes.size(0); + + const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock); + + scalar_t* boxes_dev = boxes_sorted.data_ptr(); + + THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState + + unsigned long long* mask_dev = NULL; + //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev, + // boxes_num * col_blocks * sizeof(unsigned long long))); + + mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long)); + + dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock), + THCCeilDiv(boxes_num, threadsPerBlock)); + dim3 threads(threadsPerBlock); + nms_kernel<<>>(boxes_num, + nms_overlap_thresh, + boxes_dev, + mask_dev); + + std::vector mask_host(boxes_num * col_blocks); + THCudaCheck(cudaMemcpy(&mask_host[0], + mask_dev, + sizeof(unsigned long long) * boxes_num * col_blocks, + cudaMemcpyDeviceToHost)); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = 
at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + int num_to_keep = 0; + for (int i = 0; i < boxes_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long *p = &mask_host[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + THCudaFree(state, mask_dev); + // TODO improve this part + return std::get<0>(order_t.index({ + keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to( + order_t.device(), keep.scalar_type()) + }).sort(0, false)); +} diff --git a/maskrcnn_benchmark/csrc/cuda/vision.h b/maskrcnn_benchmark/csrc/cuda/vision.h new file mode 100644 index 0000000000000000000000000000000000000000..16a7f644ed5798d1917d32cda0590161b6da8c64 --- /dev/null +++ b/maskrcnn_benchmark/csrc/cuda/vision.h @@ -0,0 +1,116 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include + + +at::Tensor SigmoidFocalLoss_forward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const int num_classes, + const float gamma, + const float alpha); + +at::Tensor SigmoidFocalLoss_backward_cuda( + const at::Tensor& logits, + const at::Tensor& targets, + const at::Tensor& d_losses, + const int num_classes, + const float gamma, + const float alpha); + +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio); + +at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio); + + +std::tuple ROIPool_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width); + +at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& rois, + const at::Tensor& argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width); + +at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh); +at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh); + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void 
modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); + +void deform_psroi_pooling_cuda_forward( + at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, + at::Tensor top_count, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void deform_psroi_pooling_cuda_backward( + at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, + at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, + const int no_trans, const float spatial_scale, const int output_dim, + const int group_size, const int pooled_size, const int part_size, + const int sample_per_part, const float trans_std); + + +at::Tensor compute_flow_cuda(const at::Tensor& boxes, + const int height, + const int width); diff --git a/maskrcnn_benchmark/csrc/deform_conv.h b/maskrcnn_benchmark/csrc/deform_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..56452c18cb8677ed964ca08c9e6e68b368da39a6 --- /dev/null +++ b/maskrcnn_benchmark/csrc/deform_conv.h @@ -0,0 +1,191 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
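`vision.h` above is the CUDA-side catalogue for the extension: focal loss, ROIAlign/ROIPool, both NMS entry points, and the deformable conv/pooling routines are declared once here and bound to Python in `vision.cpp` further down. As a concrete reminder of what one of these declarations computes, here is a plain-PyTorch sketch of sigmoid focal loss with the same `(logits, targets, gamma, alpha)` parameters; it follows the standard formulation and is not a line-for-line port of `SigmoidFocalLoss_forward_cuda`.

```python
# Reference sigmoid focal loss (standard formulation): a sketch, not the CUDA kernel.
import torch

def sigmoid_focal_loss_ref(logits, targets, gamma=2.0, alpha=0.25):
    """logits: (N, C); targets: (N,) with class ids in 1..C, 0 meaning background."""
    n, c = logits.shape
    class_range = torch.arange(1, c + 1, device=logits.device).unsqueeze(0)  # (1, C)
    t = (targets.unsqueeze(1) == class_range).float()                        # one-hot; background rows stay all-zero
    p = torch.sigmoid(logits)
    # -alpha * (1-p)^gamma * t * log(p)  -  (1-alpha) * p^gamma * (1-t) * log(1-p)
    pos = -alpha * (1 - p) ** gamma * t * torch.log(p.clamp(min=1e-8))
    neg = -(1 - alpha) * p ** gamma * (1 - t) * torch.log((1 - p).clamp(min=1e-8))
    return pos + neg   # per-element loss of shape (N, C); sum or average as needed
```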
+#pragma once +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +// Interface for Python +int deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) +{ + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda( + input, weight, offset, output, columns, ones, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, im2col_step + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +int deform_conv_backward_input( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) +{ + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda( + input, offset, gradOutput, gradInput, gradOffset, weight, columns, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, im2col_step + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +int deform_conv_backward_parameters( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + float scale, + int im2col_step) +{ + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda( + input, offset, gradOutput, gradWeight, columns, ones, + kW, kH, dW, dH, padW, padH, dilationW, dilationH, + group, deformable_group, scale, im2col_step + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +void modulated_deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias) +{ + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward( + input, weight, bias, ones, offset, mask, output, columns, + kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, + group, deformable_group, with_bias + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +void modulated_deform_conv_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int group, + int deformable_group, 
+ const bool with_bias) +{ + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward( + input, weight, bias, ones, offset, mask, columns, + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, + group, deformable_group, with_bias + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} \ No newline at end of file diff --git a/maskrcnn_benchmark/csrc/deform_pool.h b/maskrcnn_benchmark/csrc/deform_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..b3379e205caa43d854447ba896ce5848ccd65c89 --- /dev/null +++ b/maskrcnn_benchmark/csrc/deform_pool.h @@ -0,0 +1,70 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +// Interface for Python +void deform_psroi_pooling_forward( + at::Tensor input, + at::Tensor bbox, + at::Tensor trans, + at::Tensor out, + at::Tensor top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_psroi_pooling_cuda_forward( + input, bbox, trans, out, top_count, + no_trans, spatial_scale, output_dim, group_size, + pooled_size, part_size, sample_per_part, trans_std + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + + +void deform_psroi_pooling_backward( + at::Tensor out_grad, + at::Tensor input, + at::Tensor bbox, + at::Tensor trans, + at::Tensor top_count, + at::Tensor input_grad, + at::Tensor trans_grad, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_psroi_pooling_cuda_backward( + out_grad, input, bbox, trans, top_count, input_grad, trans_grad, + no_trans, spatial_scale, output_dim, group_size, pooled_size, + part_size, sample_per_part, trans_std + ); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/maskrcnn_benchmark/csrc/ml_nms.h b/maskrcnn_benchmark/csrc/ml_nms.h new file mode 100644 index 0000000000000000000000000000000000000000..bb4370d0576a3280b324ae69257f41789dd2416d --- /dev/null +++ b/maskrcnn_benchmark/csrc/ml_nms.h @@ -0,0 +1,27 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
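`deform_conv.h` and `deform_pool.h` above expose the CUDA routines behind a CPU-safe guard but carry no autograd logic of their own; that glue normally lives in a small `torch.autograd.Function` on the Python side. The sketch below shows one way to wire the `deform_psroi_pooling_*` bindings into autograd, mirroring the argument order of the wrappers above. It is a sketch under assumptions: the import path of the compiled module is assumed, and the repository's own Python wrapper may differ.

```python
# A minimal autograd wrapper around the deform_psroi_pooling bindings: illustrative only.
import torch
from torch.autograd import Function

from maskrcnn_benchmark import _C  # assumed path of the pybind11 module built from vision.cpp


class DeformRoIPoolingSketch(Function):
    @staticmethod
    def forward(ctx, data, rois, offset, no_trans, spatial_scale, output_dim,
                group_size, pooled_size, part_size, sample_per_part, trans_std):
        out = data.new_zeros(rois.size(0), output_dim, pooled_size, pooled_size)
        top_count = torch.zeros_like(out)  # per-cell sample counts, reused by the backward kernel
        _C.deform_psroi_pooling_forward(
            data, rois, offset, out, top_count, no_trans, spatial_scale, output_dim,
            group_size, pooled_size, part_size, sample_per_part, trans_std)
        ctx.save_for_backward(data, rois, offset, top_count)
        ctx.params = (no_trans, spatial_scale, output_dim, group_size,
                      pooled_size, part_size, sample_per_part, trans_std)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        data, rois, offset, top_count = ctx.saved_tensors
        grad_input = torch.zeros_like(data)
        grad_offset = torch.zeros_like(offset)
        _C.deform_psroi_pooling_backward(
            grad_out.contiguous(), data, rois, offset, top_count,
            grad_input, grad_offset, *ctx.params)
        # only the feature map and the offsets receive gradients; scalar args get None
        return (grad_input, None, grad_offset) + (None,) * 8
```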
+#pragma once +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +at::Tensor ml_nms(const at::Tensor& dets, + const at::Tensor& scores, + const at::Tensor& labels, + const float threshold) { + + if (dets.device().is_cuda()) { +#ifdef WITH_CUDA + // TODO raise error if not compiled with CUDA + if (dets.numel() == 0) + return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU)); + auto b = at::cat({dets, scores.unsqueeze(1), labels.unsqueeze(1)}, 1); + return ml_nms_cuda(b, threshold); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("CPU version not implemented"); +} diff --git a/maskrcnn_benchmark/csrc/nms.h b/maskrcnn_benchmark/csrc/nms.h new file mode 100644 index 0000000000000000000000000000000000000000..cb86028949747e215a8f5c74d768ece8937f4f81 --- /dev/null +++ b/maskrcnn_benchmark/csrc/nms.h @@ -0,0 +1,45 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + + +at::Tensor nms(const at::Tensor& dets, + const at::Tensor& scores, + const float threshold) { + + if (dets.device().is_cuda()) { +#ifdef WITH_CUDA + // TODO raise error if not compiled with CUDA + if (dets.numel() == 0) + return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU)); + auto b = at::cat({dets, scores.unsqueeze(1)}, 1); + return nms_cuda(b, threshold); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + + at::Tensor result = nms_cpu(dets, scores, threshold); + return result; +} + + +std::pair soft_nms(const at::Tensor& dets, + const at::Tensor& scores, + const float threshold, + const float sigma) { + + if (dets.device().is_cuda()) { +#ifdef WITH_CUDA + AT_ERROR("Soft NMS Does Not have GPU support"); +#endif + } + + std::pair result = soft_nms_cpu(dets, scores, threshold, sigma); + + return result; +} \ No newline at end of file diff --git a/maskrcnn_benchmark/csrc/vision.cpp b/maskrcnn_benchmark/csrc/vision.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5bd4751b67aa35f7649dd3f5b733982e38088d1 --- /dev/null +++ b/maskrcnn_benchmark/csrc/vision.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+#include "nms.h" +#include "ml_nms.h" +#include "ROIAlign.h" +#include "ROIPool.h" +#include "SigmoidFocalLoss.h" +#include "deform_conv.h" +#include "deform_pool.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("nms", &nms, "non-maximum suppression"); + m.def("ml_nms", &ml_nms, "multi-label non-maximum suppression"); + m.def("soft_nms", &soft_nms, "soft non-maximum suppression"); + m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward"); + m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward"); + m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward"); + m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward"); + m.def("sigmoid_focalloss_forward", &SigmoidFocalLoss_forward, "SigmoidFocalLoss_forward"); + m.def("sigmoid_focalloss_backward", &SigmoidFocalLoss_backward, "SigmoidFocalLoss_backward"); + m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, "modulated_deform_conv_forward"); + m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, "modulated_deform_conv_backward"); + m.def("deform_psroi_pooling_forward", &deform_psroi_pooling_forward, "deform_psroi_pooling_forward"); + m.def("deform_psroi_pooling_backward", &deform_psroi_pooling_backward, "deform_psroi_pooling_backward"); +} diff --git a/maskrcnn_benchmark/data/__init__.py b/maskrcnn_benchmark/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0210bc1653fd56b4fcea06e22f185ffaa57e06 --- /dev/null +++ b/maskrcnn_benchmark/data/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .build import make_data_loader diff --git a/maskrcnn_benchmark/data/build.py b/maskrcnn_benchmark/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..14b5973b5642d9d1d99093887a49bda869d0246a --- /dev/null +++ b/maskrcnn_benchmark/data/build.py @@ -0,0 +1,489 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import bisect +import copy +import logging +import os + +import torch.utils.data +import torch.distributed as dist +from maskrcnn_benchmark.utils.comm import get_world_size +from maskrcnn_benchmark.utils.imports import import_file + +from . import datasets as D +from . import samplers + +from .collate_batch import BatchCollator, BBoxAugCollator +from .transforms import build_transforms + +from transformers import AutoTokenizer +from .datasets.duplicate_dataset import create_duplicate_dataset + +def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True, class_concat=False, extra_args={}): + """ + Arguments: + dataset_list (list[str]): Contains the names of the datasets, i.e., + coco_2014_trian, coco_2014_val, etc + transforms (callable): transforms to apply to each (image, target) sample + dataset_catalog (DatasetCatalog): contains the information on how to + construct a dataset. 
+ is_train (bool): whether to setup the dataset for training or testing + """ + if not isinstance(dataset_list, (list, tuple)): + raise RuntimeError( + "dataset_list should be a list of strings, got {}".format(dataset_list) + ) + datasets = [] + num_category = 1 + for dataset_id, dataset_name in enumerate(dataset_list, 1): + if is_train: + dataset_name = dataset_name + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX + else: + dataset_name = dataset_name + cfg.DATASETS.TEST_DATASETNAME_SUFFIX + data = dataset_catalog.get(dataset_name) + factory = getattr(D, data["factory"]) + args = data["args"] + # for COCODataset, we want to remove images without annotations + # during training + if data["factory"] == "COCODataset": + args["remove_images_without_annotations"] = is_train + + if data["factory"] == "PascalVOCDataset": + args["use_difficult"] = not is_train + if data["factory"] in ["VGTSVDataset", "CocoDetectionTSV", "ODTSVDataset"]: + args["extra_fields"] = ["class"] + if cfg.MODEL.MASK_ON: + args["extra_fields"].append("mask") + + if data["factory"] in ["CocoGrounding", "CocoDetectionTSV", "CaptionTSV", "MixedDataset", "FlickrDataset", "RefExpDataset", "GQADataset", "PseudoData", "PhrasecutDetection"]: + # args["return_masks"] = False + args["return_masks"] = cfg.MODEL.MASK_ON + args["return_tokens"] = True + args["max_num_labels"] = cfg.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM + args["max_query_len"] = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN + + args["transforms"] = transforms + args.update(extra_args) + + if dataset_name == "flickr30k_train": + copy = cfg.DATASETS.FLICKR_COPY + elif dataset_name in ["mixed_train", "mixed_train_no_coco"]: + copy = cfg.DATASETS.MIXED_COPY + elif dataset_name == "COCO_odinw_train_8copy_dt_train": + copy = cfg.DATASETS.COCO_COPY + elif dataset_name == "LVIS_odinw_train_8copy_dt_train": + copy = cfg.DATASETS.LVIS_COPY + elif dataset_name == "object365_odinw_2copy_dt_train": + copy = cfg.DATASETS.OBJECT365_COPY + elif dataset_name == "vg_odinw_clipped_8copy_dt_train": + copy = cfg.DATASETS.VG_COPY + elif dataset_name == "vg_vgoi6_clipped_8copy_dt_train": + copy = cfg.DATASETS.VG_COPY + elif dataset_name == "imagenetod_train_odinw_2copy_dt": + copy = cfg.DATASETS.IN_COPY + elif dataset_name == "oi_train_odinw_dt": + copy = cfg.DATASETS.OI_COPY + elif is_train: + copy = cfg.DATASETS.GENERAL_COPY + elif not is_train: + copy = cfg.DATASETS.GENERAL_COPY_TEST + else: + copy = -1 # do not ever copy test + + if copy != -1: + new_factory = create_duplicate_dataset(factory) + dataset = new_factory(copy=copy, **args) + else: + # make dataset from factory + dataset = factory(**args) + + print(dataset_name, 'has the {} data points'.format(len(dataset)), data["factory"]) + + if class_concat: + category = list(dataset.contiguous_category_id_to_json_id.values()) + dataset.contiguous_category_id_to_json_id = {} + dataset.json_category_id_to_contiguous_id = {} + for id, cat in enumerate(category, start=num_category): + dataset.json_category_id_to_contiguous_id[cat] = id + dataset.contiguous_category_id_to_json_id[id] = cat + num_category += len(category) + print("Found {} #category after group {}, concating ...".format(num_category, dataset_id)) + datasets.append(dataset) + + # for testing, return a list of datasets + if not is_train: + return datasets + + # for training, concatenate all datasets into a single one + dataset = datasets[0] + if len(datasets) > 1: + dataset = D.ConcatDataset(datasets) + + return [dataset] + + +def build_dataset_by_group(dataset_list, transforms, 
dataset_catalog, is_train=True, class_by_group=True, + class_concat=False, extra_args={}): + """ + Arguments: + dataset_list (list[str]): Contains the names of the datasets, i.e., + coco_2014_trian, coco_2014_val, etc + transforms (callable): transforms to apply to each (image, target) sample + dataset_catalog (DatasetCatalog): contains the information on how to + construct a dataset. + is_train (bool): whether to setup the dataset for training or testing + """ + if not isinstance(dataset_list, (list, tuple)): + raise RuntimeError( + "dataset_list should be a list of strings, got {}".format(dataset_list) + ) + + num_category = 1 + grouped_datasets = [] + for group_id, group in enumerate(dataset_list, 1): + datasets = [] + for dataset_name in group: + data = dataset_catalog.get(dataset_name) + factory = getattr(D, data["factory"]) + args = data["args"] + # for COCODataset, we want to remove images without annotations + # during training + if data["factory"] == "COCODataset": + args["remove_images_without_annotations"] = is_train + if data["factory"] == "PascalVOCDataset": + args["use_difficult"] = not is_train + args["transforms"] = transforms + args.update(extra_args) + # make dataset from factory + dataset = factory(**args) + + # check if dataset is grouped by task, assume one class per task + if class_by_group and data["factory"] != "Background": + category = dataset.contiguous_category_id_to_json_id[1] + del dataset.contiguous_category_id_to_json_id[1] + dataset.json_category_id_to_contiguous_id[category] = group_id + dataset.contiguous_category_id_to_json_id[group_id] = category + + datasets.append(dataset) + + if class_concat: + for dataset in datasets: + category = list(dataset.contiguous_category_id_to_json_id.values()) + dataset.contiguous_category_id_to_json_id = {} + dataset.json_category_id_to_contiguous_id = {} + for id, cat in enumerate(category, start=num_category): + dataset.json_category_id_to_contiguous_id[cat] = id + dataset.contiguous_category_id_to_json_id[id] = cat + num_category += len(category) + print("Found {} #category after group {}, concating ...".format(num_category, group_id)) + + if is_train: + datasets = D.ConcatDataset(datasets) + + grouped_datasets.append(datasets) + + # for testing, return a list of datasets + if not is_train: + datasets = [dataset for group in grouped_datasets for dataset in group] + return datasets + if class_concat: + grouped_datasets = D.ConcatDataset(grouped_datasets) + return [grouped_datasets] + + # for training, concatenate all datasets into a single one + return grouped_datasets + + +def make_data_sampler(dataset, shuffle, distributed, num_replicas=None, rank=None, use_random_seed=True): + if distributed: + return samplers.DistributedSampler(dataset, shuffle=shuffle, num_replicas=num_replicas, rank=rank, + use_random=use_random_seed) + if shuffle: + sampler = torch.utils.data.sampler.RandomSampler(dataset) + else: + sampler = torch.utils.data.sampler.SequentialSampler(dataset) + return sampler + + +def _quantize(x, bins): + bins = copy.copy(bins) + bins = sorted(bins) + quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) + return quantized + + +def _compute_aspect_ratios(dataset): + aspect_ratios = [] + for i in range(len(dataset)): + img_info = dataset.get_img_info(i) + aspect_ratio = float(img_info["height"]) / float(img_info["width"]) + aspect_ratios.append(aspect_ratio) + return aspect_ratios + + +def make_batch_data_sampler( + dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0, 
drop_last=False +): + if aspect_grouping: + if not isinstance(aspect_grouping, (list, tuple)): + aspect_grouping = [aspect_grouping] + aspect_ratios = _compute_aspect_ratios(dataset) + group_ids = _quantize(aspect_ratios, aspect_grouping) + batch_sampler = samplers.GroupedBatchSampler( + sampler, group_ids, images_per_batch, drop_uneven=drop_last + ) + else: + batch_sampler = torch.utils.data.sampler.BatchSampler( + sampler, images_per_batch, drop_last=drop_last + ) + if num_iters is not None: + batch_sampler = samplers.IterationBasedBatchSampler( + batch_sampler, num_iters, start_iter + ) + return batch_sampler + +def make_data_loader(cfg, is_train=True, is_distributed=False, num_replicas=None, rank=None, start_iter=0): + num_gpus = num_replicas or get_world_size() + + if is_train: + images_per_batch = cfg.SOLVER.IMS_PER_BATCH + assert ( + images_per_batch % num_gpus == 0 + ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number " + "of GPUs ({}) used.".format(images_per_batch, num_gpus) + images_per_gpu = images_per_batch // num_gpus + shuffle = True + num_iters = cfg.SOLVER.MAX_ITER + else: + images_per_batch = cfg.TEST.IMS_PER_BATCH + assert ( + images_per_batch % num_gpus == 0 + ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number " + "of GPUs ({}) used.".format(images_per_batch, num_gpus) + images_per_gpu = images_per_batch // num_gpus + shuffle = False if not is_distributed else True + num_iters = None + start_iter = 0 + + if images_per_gpu > 1: + logger = logging.getLogger(__name__) + logger.warning( + "When using more than one image per GPU you may encounter " + "an out-of-memory (OOM) error if your GPU does not have " + "sufficient memory. If this happens, you can reduce " + "SOLVER.IMS_PER_BATCH (for training) or " + "TEST.IMS_PER_BATCH (for inference). For training, you must " + "also adjust the learning rate and schedule length according " + "to the linear scaling rule. See for example: " + "https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14" + ) + + # group images which have similar aspect ratio. 
In this case, we only + # group in two cases: those with width / height > 1, and the other way around, + # but the code supports more general grouping strategy + aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else [] + + paths_catalog = import_file( + "maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True + ) + + DatasetCatalog = paths_catalog.DatasetCatalog + if len(cfg.DATASETS.REGISTER) > 0: + for new_dataset in cfg.DATASETS.REGISTER: + # img_dir = cfg.DATASETS.REGISTER[new_dataset]["img_dir"] + # if "ann_file" in cfg.DATASETS.REGISTER[new_dataset]: + # ann_file = cfg.DATASETS.REGISTER[new_dataset]["ann_file"] + # else: + # ann_file = None + attrs = dict(cfg.DATASETS.REGISTER[new_dataset]) + if is_train: + new_dataset = new_dataset + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX + else: + new_dataset = new_dataset + cfg.DATASETS.TEST_DATASETNAME_SUFFIX + DatasetCatalog.set(new_dataset, attrs) + + + dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST + + # Haotian: expand bing dataset + if "bing_caption_train" in dataset_list and len(cfg.DATASETS.BING_INDEX_LIST) > 0: + dataset_list = list(dataset_list) + dataset_list.remove("bing_caption_train") + for bing_index in cfg.DATASETS.BING_INDEX_LIST: + dataset_list.insert(len(dataset_list), "bing_caption_{}_train".format(bing_index)) + dataset_list = tuple(dataset_list) + + if "bing_caption_train_no_coco" in dataset_list and len(cfg.DATASETS.BING_INDEX_LIST) > 0: + dataset_list = list(dataset_list) + dataset_list.remove("bing_caption_train_no_coco") + for bing_index in cfg.DATASETS.BING_INDEX_LIST: + dataset_list.insert(len(dataset_list), "bing_caption_{}_train_no_coco".format(bing_index)) + dataset_list = tuple(dataset_list) + + print("The combined datasets are: {}.".format(dataset_list)) + + transforms = None if not is_train and cfg.TEST.USE_MULTISCALE else build_transforms(cfg, is_train) + + extra_args = {} + if is_train and cfg.DATASETS.USE_CROWD: + extra_args['ignore_crowd'] = False + if is_train and cfg.DATASETS.MAX_BOX > 0: + extra_args['max_box'] = cfg.DATASETS.MAX_BOX + if is_train and cfg.DATASETS.FEW_SHOT>0: + extra_args['few_shot'] = cfg.DATASETS.FEW_SHOT + if is_train and cfg.DATASETS.SHUFFLE_SEED != 0: + extra_args['shuffle_seed'] = cfg.DATASETS.SHUFFLE_SEED + + # od to grounding + if is_train and cfg.DATASETS.RANDOM_SAMPLE_NEG > 0: + extra_args['random_sample_negative'] = cfg.DATASETS.RANDOM_SAMPLE_NEG + if is_train and cfg.DATASETS.ADD_DET_PROMPT: + extra_args["add_detection_prompt"] = True + if is_train and cfg.DATASETS.USE_OD_AUG: + extra_args["use_od_data_aug"] = True + if is_train and cfg.DATASETS.DISABLE_SHUFFLE: + extra_args["disable_shuffle"] = True + if cfg.DATASETS.ONE_HOT: + extra_args["one_hot"] = True + if is_train and len(cfg.DATASETS.PROMPT_VERSION) > 0: + extra_args["prompt_engineer_version"] = cfg.DATASETS.PROMPT_VERSION + if is_train and len(cfg.DATASETS.CONTROL_PROB) == 4: + extra_args["control_probabilities"] = cfg.DATASETS.CONTROL_PROB + if is_train and cfg.DATASETS.DISABLE_CLIP_TO_IMAGE: + extra_args["disable_clip_to_image"] = cfg.DATASETS.DISABLE_CLIP_TO_IMAGE + if is_train and cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT: + extra_args["no_minus_one_for_one_hot"] = cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT + if is_train: + extra_args["separation_tokens"] = cfg.DATASETS.SEPARATION_TOKENS + # caption + if is_train and cfg.DATASETS.CAPTION_MIN_BOX > 0: + extra_args["caption_min_box"] = cfg.DATASETS.CAPTION_MIN_BOX + if is_train and cfg.DATASETS.REPLACE_CLEAN_LABEL: + 
extra_args["replace_clean_label"] = True + if is_train and cfg.DATASETS.FURTHER_SCREEN: + extra_args["further_screen"] = True + if is_train and cfg.DATASETS.CAPTION_CONF > 0.0: + extra_args["caption_conf"] = cfg.DATASETS.CAPTION_CONF + if is_train: + extra_args["caption_nms"] = cfg.DATASETS.CAPTION_NMS + if is_train and cfg.DATASETS.PACK_RANDOM_CAPTION_NUMBER > 0: + extra_args["pack_random_caption_number"] = cfg.DATASETS.PACK_RANDOM_CAPTION_NUMBER + if is_train and cfg.DATASETS.INFERENCE_CAPTION: + extra_args["inference_caption"] = True + if is_train and cfg.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA > 0: + extra_args["sample_negative_for_grounding_data"] = cfg.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA + if is_train and cfg.DATASETS.RANDOM_PACK_PROB > 0: + extra_args["random_pack_prob"] = cfg.DATASETS.RANDOM_PACK_PROB + if is_train and cfg.DATASETS.NO_RANDOM_PACK_PROBABILITY > 0: + extra_args["no_random_pack_probability"] = cfg.DATASETS.NO_RANDOM_PACK_PROBABILITY + if is_train: + extra_args["safeguard_positive_caption"] = cfg.DATASETS.SAFEGUARD_POSITIVE_CAPTION + if is_train: + extra_args["local_debug"] = cfg.DATASETS.LOCAL_DEBUG + if is_train: + extra_args["no_mask_for_od"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_OD + if is_train: + extra_args["no_mask_for_gold"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.NO_MASK_FOR_GOLD + if is_train: + extra_args["mlm_obj_for_only_positive"] = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_OBJ_FOR_ONLY_POSITIVE + if cfg.DATASETS.OVERRIDE_CATEGORY and cfg.DATASETS.USE_OVERRIDE_CATEGORY: + extra_args["override_category"] = cfg.DATASETS.OVERRIDE_CATEGORY + if is_train: + extra_args["caption_format_version"] = cfg.DATASETS.CAPTION_FORMAT_VERSION + if is_train: + extra_args["special_safeguard_for_coco_grounding"] = cfg.DATASETS.SPECIAL_SAFEGUARD_FOR_COCO_GROUNDING + if is_train: + extra_args["diver_box_for_vqa"] = cfg.DATASETS.DIVER_BOX_FOR_VQA + extra_args["caption_prompt"] = cfg.DATASETS.CAPTION_PROMPT + extra_args["use_caption_prompt"] = cfg.DATASETS.USE_CAPTION_PROMPT + + # extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) + if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": + # extra_args['tokenizer'] = build_tokenizer("clip") + from transformers import CLIPTokenizerFast + if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: + extra_args["tokenizer"] = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True, mask_token='ðŁĴij') + else: + extra_args["tokenizer"] = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True) + else: + extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) + + if isinstance(dataset_list[0], (tuple, list)): + datasets = build_dataset_by_group(dataset_list, transforms, DatasetCatalog, is_train, + class_by_group=cfg.DATASETS.ALTERNATIVE_TRAINING, + class_concat=cfg.DATASETS.CLASS_CONCAT, + extra_args=extra_args) + else: + datasets = build_dataset(cfg, dataset_list, transforms, DatasetCatalog, is_train, + class_concat=cfg.DATASETS.CLASS_CONCAT, + extra_args=extra_args) + + data_loaders = [] + for di, dataset in enumerate(datasets): + if is_train and cfg.SOLVER.MAX_EPOCH > 0: + num_iters = cfg.SOLVER.MAX_EPOCH * len(dataset) // cfg.SOLVER.IMS_PER_BATCH + print("Number of iterations are {}".format(num_iters)) + cfg.defrost() + cfg.SOLVER.MAX_ITER = num_iters + cfg.SOLVER.DATASET_LENGTH = len(dataset) + cfg.freeze() + if is_train and cfg.SOLVER.MULTI_MAX_EPOCH: + num_iters = None + cfg.defrost() + 
cfg.SOLVER.MULTI_MAX_ITER += (cfg.SOLVER.MULTI_MAX_EPOCH[di] * len(dataset) // cfg.SOLVER.IMS_PER_BATCH,) + cfg.freeze() + + if is_train and cfg.DATALOADER.DISTRIBUTE_CHUNK_AMONG_NODE: + from .datasets.custom_distributed_sampler import DistributedSamplerChunkByNode + chunk_or_not = [] + for i in dataset_list: + if "bing_caption" in i: + chunk_or_not.append(True) + else: + chunk_or_not.append(False) + assert(len(chunk_or_not) == len(dataset.datasets)) + ''' + If we are training on 4 nodes, each with 8 GPUs + ''' + num_nodes = int(os.getenv('NODE_COUNT', os.getenv('OMPI_COMM_WORLD_SIZE', 1))) + local_size = cfg.num_gpus//num_nodes + node_rank = int(os.getenv('NODE_RANK', os.getenv('OMPI_COMM_WORLD_RANK', 0))) + local_rank = cfg.local_rank + sampler = DistributedSamplerChunkByNode( + dataset = dataset, + all_datasets = dataset.datasets, # Assumming dataset is a ConcateDataset instance, + chunk_or_not = chunk_or_not, + num_replicas = cfg.num_gpus, # total GPU number, e.g., 32 + rank = dist.get_rank(), # Global Rank, e.g., 0~31 + node_rank = node_rank, # Node Rank, e.g., 0~3 + node_number = num_nodes, # how many node e.g., 4 + process_num_per_node = local_size, # e.g., 8 + rank_within_local_node = local_rank, # e.g., 0~7 + ) + else: + sampler = make_data_sampler(dataset, shuffle, is_distributed, num_replicas=num_replicas, rank=rank, + use_random_seed=cfg.DATALOADER.USE_RANDOM_SEED) + batch_sampler = make_batch_data_sampler( + dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, drop_last=is_train + ) + collator = BBoxAugCollator() if not is_train and cfg.TEST.USE_MULTISCALE else BatchCollator( + cfg.DATALOADER.SIZE_DIVISIBILITY) + num_workers = cfg.DATALOADER.NUM_WORKERS + data_loader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_sampler=batch_sampler, + collate_fn=collator, + ) + data_loaders.append(data_loader) + if is_train and cfg.SOLVER.MULTI_MAX_EPOCH: + cfg.defrost() + cfg.SOLVER.MULTI_MAX_ITER += ( + cfg.SOLVER.MULTI_MAX_EPOCH[-1] * min([len(dataset) // cfg.SOLVER.IMS_PER_BATCH for dataset in datasets]),) + cfg.freeze() + + if is_train and not cfg.DATASETS.ALTERNATIVE_TRAINING and not cfg.DATASETS.MULTISTAGE_TRAINING: + # during training, a single (possibly concatenated) data_loader is returned + assert len(data_loaders) == 1 + return data_loaders[0] + + return data_loaders diff --git a/maskrcnn_benchmark/data/collate_batch.py b/maskrcnn_benchmark/data/collate_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..bf08fd9b5fd67ef41e659bd6df8ae20933359435 --- /dev/null +++ b/maskrcnn_benchmark/data/collate_batch.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from maskrcnn_benchmark.structures.image_list import to_image_list + +import pdb +class BatchCollator(object): + """ + From a list of samples from the dataset, + returns the batched images and targets. 
+ This should be passed to the DataLoader + """ + + def __init__(self, size_divisible=0): + self.size_divisible = size_divisible + + def __call__(self, batch): + transposed_batch = list(zip(*batch)) + + images = to_image_list(transposed_batch[0], self.size_divisible) + targets = transposed_batch[1] + img_ids = transposed_batch[2] + positive_map = None + positive_map_eval = None + greenlight_map = None + + if isinstance(targets[0], dict): + return images, targets, img_ids, positive_map, positive_map_eval + + if "greenlight_map" in transposed_batch[1][0].fields(): + greenlight_map = torch.stack([i.get_field("greenlight_map") for i in transposed_batch[1]], dim = 0) + + if "positive_map" in transposed_batch[1][0].fields(): + # we batch the positive maps here + # Since in general each batch element will have a different number of boxes, + # we collapse a single batch dimension to avoid padding. This is sufficient for our purposes. + max_len = max([v.get_field("positive_map").shape[1] for v in transposed_batch[1]]) + nb_boxes = sum([v.get_field("positive_map").shape[0] for v in transposed_batch[1]]) + batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool) + cur_count = 0 + for v in transposed_batch[1]: + cur_pos = v.get_field("positive_map") + batched_pos_map[cur_count: cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos + cur_count += len(cur_pos) + + assert cur_count == len(batched_pos_map) + positive_map = batched_pos_map.float() + + + if "positive_map_eval" in transposed_batch[1][0].fields(): + # we batch the positive maps here + # Since in general each batch element will have a different number of boxes, + # we collapse a single batch dimension to avoid padding. This is sufficient for our purposes. + max_len = max([v.get_field("positive_map_eval").shape[1] for v in transposed_batch[1]]) + nb_boxes = sum([v.get_field("positive_map_eval").shape[0] for v in transposed_batch[1]]) + batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool) + cur_count = 0 + for v in transposed_batch[1]: + cur_pos = v.get_field("positive_map_eval") + batched_pos_map[cur_count: cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos + cur_count += len(cur_pos) + + assert cur_count == len(batched_pos_map) + # assert batched_pos_map.sum().item() == sum([v["positive_map"].sum().item() for v in batch[1]]) + positive_map_eval = batched_pos_map.float() + + + return images, targets, img_ids, positive_map, positive_map_eval, greenlight_map + + +class BBoxAugCollator(object): + """ + From a list of samples from the dataset, + returns the images and targets. + Images should be converted to batched images in `im_detect_bbox_aug` + """ + + def __call__(self, batch): + # return list(zip(*batch)) + transposed_batch = list(zip(*batch)) + + images = transposed_batch[0] + targets = transposed_batch[1] + img_ids = transposed_batch[2] + positive_map = None + positive_map_eval = None + + if isinstance(targets[0], dict): + return images, targets, img_ids, positive_map, positive_map_eval + + return images, targets, img_ids, positive_map, positive_map_eval + + + diff --git a/maskrcnn_benchmark/data/datasets/__init__.py b/maskrcnn_benchmark/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1136cd63ccdf0cf6207226ab7fad98181e3aa0dc --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
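+# This package collects every dataset implementation used by the data pipeline.
+# The names exported through __all__ below are presumably the "factory" names the
+# paths_catalog / build_dataset machinery resolves against this module, so a new
+# dataset should be imported here and appended to __all__.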
+from .coco import COCODataset +from .voc import PascalVOCDataset +from .concat_dataset import ConcatDataset +from .background import Background +from .tsv import TSVDataset, ODTSVDataset + +from .modulated_coco import ModulatedDataset, CocoDetection, CocoGrounding +from .flickr import FlickrDataset +from .refexp import RefExpDataset +from .mixed import MixedDataset +from .gqa import GQADataset + +from .coco_dt import CocoDetectionTSV +from .caption import CaptionTSV +from .lvis import LvisDetection +from .pseudo_data import PseudoData +from .phrasecut import PhrasecutDetection + +__all__ = ["COCODataset", "TSVDataset", "ODTSVDataset", "ConcatDataset", "PascalVOCDataset", "Background", + "ModulatedDataset", "MixedDataset", "CocoDetection", "FlickrDataset", "RefExpDataset", "GQADataset", + "CocoDetectionTSV", "CocoGrounding", "CaptionTSV", "LvisDetection", "PseudoData", "PhrasecutDetection" + ] diff --git a/maskrcnn_benchmark/data/datasets/background.py b/maskrcnn_benchmark/data/datasets/background.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2051da45b046fc3481e6116d75769fbe42be0d --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/background.py @@ -0,0 +1,53 @@ +import os +import os.path +import json +from PIL import Image + +import torch +import torchvision +import torch.utils.data as data +from maskrcnn_benchmark.structures.bounding_box import BoxList + +class Background(data.Dataset): + """ Background + + Args: + root (string): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + """ + + def __init__(self, ann_file, root, remove_images_without_annotations=None, transforms=None): + self.root = root + + with open(ann_file, 'r') as f: + self.ids = json.load(f)['images'] + self.transform = transforms + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. 
+ """ + im_info = self.ids[index] + path = im_info['file_name'] + fp = os.path.join(self.root, path) + + img = Image.open(fp).convert('RGB') + if self.transform is not None: + img, _ = self.transform(img, None) + null_target = BoxList(torch.zeros((0,4)), (img.shape[-1], img.shape[-2])) + null_target.add_field('labels', torch.zeros(0)) + + return img, null_target, index + + def __len__(self): + return len(self.ids) + + def get_img_info(self, index): + im_info = self.ids[index] + return im_info \ No newline at end of file diff --git a/maskrcnn_benchmark/data/datasets/box_label_loader.py b/maskrcnn_benchmark/data/datasets/box_label_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..1d40c758aacd1915f97f163f4736ecd86311fcc0 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/box_label_loader.py @@ -0,0 +1,251 @@ +import torch +import numpy as np +import math +import base64 +import collections +import pycocotools.mask as mask_utils + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask + + +class LabelLoader(object): + def __init__(self, labelmap, extra_fields=(), filter_duplicate_relations=False, ignore_attr=None, ignore_rel=None, + mask_mode="poly"): + self.labelmap = labelmap + self.extra_fields = extra_fields + self.supported_fields = ["class", "conf", "attributes", 'scores_all', 'boxes_all', 'feature', "mask"] + self.filter_duplicate_relations = filter_duplicate_relations + self.ignore_attr = set(ignore_attr) if ignore_attr != None else set() + self.ignore_rel = set(ignore_rel) if ignore_rel != None else set() + assert mask_mode == "poly" or mask_mode == "mask" + self.mask_mode = mask_mode + + def __call__(self, annotations, img_size, remove_empty=False, load_fields=None): + boxes = [obj["rect"] for obj in annotations] + boxes = torch.as_tensor(boxes).reshape(-1, 4) + target = BoxList(boxes, img_size, mode="xyxy") + + if load_fields is None: + load_fields = self.extra_fields + + for field in load_fields: + assert field in self.supported_fields, "Unsupported field {}".format(field) + if field == "class": + classes = self.add_classes(annotations) + target.add_field("labels", classes) + elif field == "conf": + confidences = self.add_confidences(annotations) + target.add_field("scores", confidences) + elif field == "attributes": + attributes = self.add_attributes(annotations) + target.add_field("attributes", attributes) + elif field == "scores_all": + scores_all = self.add_scores_all(annotations) + target.add_field("scores_all", scores_all) + elif field == "boxes_all": + boxes_all = self.add_boxes_all(annotations) + target.add_field("boxes_all", boxes_all) + elif field == "feature": + features = self.add_features(annotations) + target.add_field("box_features", features) + elif field == "mask": + masks, is_box_mask = self.add_masks(annotations, img_size) + target.add_field("masks", masks) + target.add_field("is_box_mask", is_box_mask) + + target = target.clip_to_image(remove_empty=remove_empty) + return target + + def get_box_mask(self, rect, img_size): + x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3] + if self.mask_mode == "poly": + return [[x1, y1, x1, y2, x2, y2, x2, y1]] + elif self.mask_mode == "mask": + # note the order of height/width order in mask is opposite to image + mask = np.zeros([img_size[1], img_size[0]], dtype=np.uint8) + mask[math.floor(y1):math.ceil(y2), math.floor(x1):math.ceil(x2)] = 255 + encoded_mask = mask_utils.encode(np.asfortranarray(mask)) + 
encoded_mask["counts"] = encoded_mask["counts"].decode("utf-8") + return encoded_mask + + def add_masks(self, annotations, img_size): + masks = [] + is_box_mask = [] + for obj in annotations: + if "mask" in obj: + masks.append(obj["mask"]) + is_box_mask.append(0) + else: + masks.append(self.get_box_mask(obj["rect"], img_size)) + is_box_mask.append(1) + masks = SegmentationMask(masks, img_size, mode=self.mask_mode) + is_box_mask = torch.tensor(is_box_mask) + return masks, is_box_mask + + def add_classes(self, annotations): + class_names = [obj["class"] for obj in annotations] + classes = [None] * len(class_names) + for i in range(len(class_names)): + classes[i] = self.labelmap['class_to_ind'][class_names[i]] + return torch.tensor(classes) + + def add_confidences(self, annotations): + confidences = [] + for obj in annotations: + if "conf" in obj: + confidences.append(obj["conf"]) + else: + confidences.append(1.0) + return torch.tensor(confidences) + + def add_attributes(self, annotations): + # the maximal number of attributes per object is 16 + attributes = [[0] * 16 for _ in range(len(annotations))] + for i, obj in enumerate(annotations): + for j, attr in enumerate(obj["attributes"]): + attributes[i][j] = self.labelmap['attribute_to_ind'][attr] + return torch.tensor(attributes) + + def add_features(self, annotations): + features = [] + for obj in annotations: + features.append(np.frombuffer(base64.b64decode(obj['feature']), np.float32)) + return torch.tensor(features) + + def add_scores_all(self, annotations): + scores_all = [] + for obj in annotations: + scores_all.append(np.frombuffer(base64.b64decode(obj['scores_all']), np.float32)) + return torch.tensor(scores_all) + + def add_boxes_all(self, annotations): + boxes_all = [] + for obj in annotations: + boxes_all.append(np.frombuffer(base64.b64decode(obj['boxes_all']), np.float32).reshape(-1, 4)) + return torch.tensor(boxes_all) + + def relation_loader(self, relation_annos, target): + if self.filter_duplicate_relations: + # Filter out dupes! 
+ all_rel_sets = collections.defaultdict(list) + for triplet in relation_annos: + all_rel_sets[(triplet['subj_id'], triplet['obj_id'])].append(triplet) + relation_annos = [np.random.choice(v) for v in all_rel_sets.values()] + + # get M*M pred_labels + relation_triplets = [] + relations = torch.zeros([len(target), len(target)], dtype=torch.int64) + for i in range(len(relation_annos)): + if len(self.ignore_rel) != 0 and relation_annos[i]['class'] in self.ignore_rel: + continue + subj_id = relation_annos[i]['subj_id'] + obj_id = relation_annos[i]['obj_id'] + predicate = self.labelmap['relation_to_ind'][relation_annos[i]['class']] + relations[subj_id, obj_id] = predicate + relation_triplets.append([subj_id, obj_id, predicate]) + + relation_triplets = torch.tensor(relation_triplets) + target.add_field("relation_labels", relation_triplets) + target.add_field("pred_labels", relations) + return target + + +class BoxLabelLoader(object): + def __init__(self, labelmap, extra_fields=(), ignore_attrs=(), + mask_mode="poly"): + self.labelmap = labelmap + self.extra_fields = extra_fields + self.ignore_attrs = ignore_attrs + assert mask_mode == "poly" or mask_mode == "mask" + self.mask_mode = mask_mode + self.all_fields = ["class", "mask", "confidence", + "attributes_encode", "IsGroupOf", "IsProposal"] + + def __call__(self, annotations, img_size, remove_empty=True): + boxes = [obj["rect"] for obj in annotations] + boxes = torch.as_tensor(boxes).reshape(-1, 4) + target = BoxList(boxes, img_size, mode="xyxy") + + for field in self.extra_fields: + assert field in self.all_fields, "Unsupported field {}".format(field) + if field == "class": + classes = self.add_classes_with_ignore(annotations) + target.add_field("labels", classes) + elif field == "mask": + masks, is_box_mask = self.add_masks(annotations, img_size) + target.add_field("masks", masks) + target.add_field("is_box_mask", is_box_mask) + elif field == "confidence": + confidences = self.add_confidences(annotations) + target.add_field("confidences", confidences) + elif field == "attributes_encode": + attributes = self.add_attributes(annotations) + target.add_field("attributes", attributes) + elif field == "IsGroupOf": + is_group = [1 if 'IsGroupOf' in obj and obj['IsGroupOf'] == 1 else 0 + for obj in annotations] + target.add_field("IsGroupOf", torch.tensor(is_group)) + elif field == "IsProposal": + is_proposal = [1 if "IsProposal" in obj and obj['IsProposal'] == 1 else 0 + for obj in annotations] + target.add_field("IsProposal", torch.tensor(is_proposal)) + + target = target.clip_to_image(remove_empty=remove_empty) + return target + + def add_classes_with_ignore(self, annotations): + class_names = [obj["class"] for obj in annotations] + classes = [None] * len(class_names) + if self.ignore_attrs: + for i, obj in enumerate(annotations): + if any([obj[attr] for attr in self.ignore_attrs if attr in obj]): + classes[i] = -1 + for i, cls in enumerate(classes): + if cls != -1: + classes[i] = self.labelmap[class_names[i]] + 1 # 0 is saved for background + return torch.tensor(classes) + + def add_masks(self, annotations, img_size): + masks = [] + is_box_mask = [] + for obj in annotations: + if "mask" in obj: + masks.append(obj["mask"]) + is_box_mask.append(0) + else: + masks.append(self.get_box_mask(obj["rect"], img_size)) + is_box_mask.append(1) + masks = SegmentationMask(masks, img_size, mode=self.mask_mode) + is_box_mask = torch.tensor(is_box_mask) + return masks, is_box_mask + + def get_box_mask(self, rect, img_size): + x1, y1, x2, y2 = rect[0], rect[1], 
rect[2], rect[3] + if self.mask_mode == "poly": + return [[x1, y1, x1, y2, x2, y2, x2, y1]] + elif self.mask_mode == "mask": + # note the order of height/width order in mask is opposite to image + mask = np.zeros([img_size[1], img_size[0]], dtype=np.uint8) + mask[math.floor(y1):math.ceil(y2), math.floor(x1):math.ceil(x2)] = 255 + encoded_mask = mask_utils.encode(np.asfortranarray(mask)) + encoded_mask["counts"] = encoded_mask["counts"].decode("utf-8") + return encoded_mask + + def add_confidences(self, annotations): + confidences = [] + for obj in annotations: + if "confidence" in obj: + confidences.append(obj["confidence"]) + elif "conf" in obj: + confidences.append(obj["conf"]) + else: + confidences.append(1.0) + return torch.tensor(confidences) + + def add_attributes(self, annotations): + # we know that the maximal number of attributes per object is 16 + attributes = [[0] * 16 for _ in range(len(annotations))] + for i, obj in enumerate(annotations): + attributes[i][:len(obj["attributes_encode"])] = obj["attributes_encode"] + return torch.tensor(attributes) diff --git a/maskrcnn_benchmark/data/datasets/caption.py b/maskrcnn_benchmark/data/datasets/caption.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a5ec88ab02e589a0333a5d65f907b545d710ed --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/caption.py @@ -0,0 +1,279 @@ +import torch +import torch.distributed as dist +import time +from torchvision.ops import nms +import random +import numpy as np +from PIL import Image, ImageDraw +import pdb +from maskrcnn_benchmark.structures.bounding_box import BoxList +from .modulated_coco import ConvertCocoPolysToMask +from .tsv import ODTSVDataset, TSVYamlDataset +from .od_to_grounding import sanity_check_target_after_processing + +class CaptionTSV(TSVYamlDataset): + def __init__(self, + yaml_file, + transforms, + return_tokens, + return_masks, + tokenizer, + caption_min_box=1, + replace_clean_label=False, + further_screen=False, + caption_conf=0.5, + caption_nms=-1, + pack_random_caption_number=0, + inference_caption=False, + sample_negative_for_grounding_data=-1, + random_pack_prob=-1.0, + no_random_pack_probability=0.0, + safeguard_positive_caption=True, + mlm_obj_for_only_positive=False, + caption_format_version="v1", + local_debug=False, + max_query_len=256, + **kwargs + ): + super(CaptionTSV, self).__init__(yaml_file, None, replace_clean_label) + self.yaml_file = yaml_file + self._transforms = transforms + self.max_query_len = max_query_len + self.prepare = ConvertCocoPolysToMask(return_masks=return_masks, + return_tokens=return_tokens, + tokenizer=tokenizer, + max_query_len=max_query_len) + self.tokenizer = tokenizer + self.caption_min_box = caption_min_box + self.replace_clean_label = replace_clean_label + self.further_screen = further_screen + self.pack_random_caption_number = pack_random_caption_number + self.caption_format_version = caption_format_version + + self.caption_conf = caption_conf + self.caption_nms = caption_nms + self.inference_caption = inference_caption + self.sample_negative_for_grounding_data = sample_negative_for_grounding_data + self.random_pack_prob = random_pack_prob + self.no_random_pack_probability = no_random_pack_probability + self.safeguard_positive_caption = safeguard_positive_caption + self.mlm_obj_for_only_positive = mlm_obj_for_only_positive + try: + self.rank = dist.get_rank() + except: + self.rank = 0 + + def __len__(self): + return super(CaptionTSV, self).__len__() + + def pack_caption(self, positive_caption, negative_captions, 
original_tokens_positive): + if len(negative_captions) == 0: + return positive_caption, original_tokens_positive, [(0, len(positive_caption))] + if self.safeguard_positive_caption: + length_of_each_caption = [] + for caption in negative_captions + [positive_caption]: + tokenized = self.tokenizer(caption, return_tensors="pt") + length_of_each_caption.append(tokenized.input_ids.size(-1)) + max_length = self.max_query_len - length_of_each_caption[-1] + indexes = list(range(len(negative_captions))) + random.shuffle(indexes) + new_caption_list = [positive_caption] + for i in indexes: + if length_of_each_caption[i] < max_length: + new_caption_list.append(negative_captions[i]) + max_length -= length_of_each_caption[i] + else: + new_caption_list = [positive_caption] + negative_captions + random.shuffle(new_caption_list) + + new_caption = '' + + for i in new_caption_list: + if i == positive_caption: + start_position = len(new_caption) + new_caption += i + if not i.endswith("."): + new_caption += "." + new_caption += " " + + # shift the token positions the boxes are aligned to + for index, i in enumerate(original_tokens_positive): + original_tokens_positive[index] = [tuple(j) for j in i] + for i in original_tokens_positive: + for index, j in enumerate(i): + i[index] = (j[0] + start_position, j[1] + start_position) + + return new_caption, original_tokens_positive, [(start_position, start_position + len(positive_caption))] + + def __get_negative_captions__(self, idx, negative_size=7): + negative_captions = [] + for i in range(negative_size): + img, anno, _, scale = super(CaptionTSV, self).__getitem__(np.random.choice(len(self))) + caption = anno["caption"] + negative_captions.append(caption) + + return negative_captions + + def __getitem__(self, idx): + try: + img, anno, _, scale = super(CaptionTSV, self).__getitem__(idx) + if self.inference_caption: + caption = None + if isinstance(anno, list): + caption = anno[0]["caption"] # inference mode for bing + anno = [] + elif len(anno) == 1: + caption = anno["caption"] # inference mode for googlecc + anno = [] + else: + caption = " ".join(anno["captions"]) + anno = [] + else: + ''' + An example + {'img_h': 1154, 'img_w': 1600, 'caption': 'xxx', 'tokens_positive': [[[47, 50], [51, 53], [54, 59]], [[32, 35], [36, 41]], [[32, 35], [36, 41]], [[0, 3], [3, 6], [6, 10], [11, 16], [17, 19], [20, 23]], [[32, 35], [36, 41]], [[32, 35], [36, 41]]], 'bboxes': [[7.344961166381836, 10.479412078857422, 1592.2679443359375, 1090.0028076171875], [950.32861328125, 346.572021484375, 1333.2373046875, 679.3215942382812], [927.44140625, 342.7712707519531, 1389.833984375, 719.5758666992188], [90.48786163330078, 363.67572021484375, 1381.8631591796875, 1078.687744140625], [122.84217071533203, 422.6786193847656, 507.845703125, 667.2651977539062], [80.62384033203125, 416.500244140625, 563.1666259765625, 734.603271484375]], 'scores': [0.7966700196266174, 0.8952182531356812, 0.8186006546020508, 0.9995516538619995, 0.8021856546401978, 0.8923134803771973]} + ''' + if len(anno["bboxes"]) < self.caption_min_box: # Retry triggered! 
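+                    # Degenerate sample: too few boxes survive for this caption,
+                    # so resample a random index instead of emitting an empty target.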
+ return self[np.random.choice(len(self))] + + if self.caption_format_version == "v2": + anno = self.convert_anno_from_v2_to_v1(anno) + + try: + if self.further_screen: + conf = self.caption_conf + nms_thre = self.caption_nms + + bboxes = torch.as_tensor(anno["bboxes"]).float() + scores = torch.as_tensor(anno["scores"]) + tokens_positive = anno["tokens_positive"] + + # print("\n\n\n\n tokens_positive in original data", tokens_positive) + + keep = scores > conf + scores = scores[keep] + bboxes = bboxes[keep] + tokens_positive = [i for index, i in enumerate(tokens_positive) if keep[index]] + + assert (len(tokens_positive) == len(bboxes) == len(scores)) + + if len(bboxes) < self.caption_min_box: # Retry triggered! + return self[np.random.choice(len(self))] + + if nms_thre > 0: + keep = nms(boxes=bboxes, scores=scores, iou_threshold=nms_thre) + scores = scores[keep] + bboxes = bboxes[keep] + tokens_positive = [tokens_positive[i] for i in keep] + assert (len(tokens_positive) == len(bboxes) == len(scores)) + + # Write back + anno["bboxes"] = bboxes.tolist() + anno["scores"] = scores.tolist() + anno["tokens_positive"] = tokens_positive + + boxes = torch.as_tensor(anno["bboxes"]) + + if len(boxes) < self.caption_min_box: # Retry triggered! + return self[np.random.choice(len(self))] + + target = BoxList(boxes, (anno["img_w"], anno["img_h"]), mode="xyxy") + target = target.clip_to_image(remove_empty=True) + + caption = anno["caption"] + # print("original caption", caption) + empty_everything = False + if self.sample_negative_for_grounding_data != -1: + if random.random() < self.sample_negative_for_grounding_data: + empty_everything = True + + if empty_everything: + caption = self.__get_negative_captions__(idx, negative_size=1)[0] + + if self.pack_random_caption_number != 0: + if self.random_pack_prob != -1.0: + if random.random() < self.no_random_pack_probability: + negative_pack_number = 0 + elif random.random() < self.random_pack_prob: + negative_pack_number = self.pack_random_caption_number + else: + negative_pack_number = np.random.choice(self.pack_random_caption_number) + else: + negative_pack_number = self.pack_random_caption_number + + negative_captions = self.__get_negative_captions__(idx, negative_size=negative_pack_number) + + caption, anno["tokens_positive"], greenlight_span_for_masked_lm_objective = self.pack_caption( + caption, negative_captions, anno["tokens_positive"]) + else: + greenlight_span_for_masked_lm_objective = [(0, len(caption))] + + if not self.mlm_obj_for_only_positive: + greenlight_span_for_masked_lm_objective = [(0, len(caption))] + + new_anno = [] + areas = target.area() + for i in range(len(target)): + new_anno_i = {} + new_anno_i["area"] = areas[i] + new_anno_i["iscrowd"] = 0 + new_anno_i["image_id"] = idx + new_anno_i["category_id"] = 1 # following vg and others + new_anno_i["id"] = None + new_anno_i['bbox'] = target.bbox[i].numpy().tolist() + new_anno_i["tokens_positive"] = anno["tokens_positive"][i] + new_anno.append(new_anno_i) + + except: + return self[np.random.choice(len(self))] + + anno = new_anno + if empty_everything: + anno = [] + + annotations = {"image_id": idx, "annotations": anno, "caption": caption} + annotations["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective + img, annotations = self.prepare(img, annotations, box_format="xyxy") + + if self._transforms is not None: + img, target = self._transforms(img, target) + + # add additional property + for ann in annotations: + target.add_field(ann, annotations[ann]) + 
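+        # Any failure above (corrupt row, empty target, tokenization mismatch, ...)
+        # falls through to the outer retry below, which resamples a random index.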
except: + print("Outter Retry triggered!!") + return self[np.random.choice(len(self))] + + sanity_check_target_after_processing(target) + + return img, target, idx + + def convert_anno_from_v2_to_v1(self, anno): + flatterned_bboxes = [] + flatterned_tokens_positive = [] + flatterned_bboxes_scores = [] + for i in range(len(anno["bboxes"])): + # i is the index for entity + for j in range(len(anno["bboxes"][i])): + # j is the index for each box + flatterned_bboxes.append(anno["bboxes"][i][j]) + flatterned_tokens_positive.append( + anno["tokens_positive"][i]) # Assume this box corresponds to all the token_spans for this entity + flatterned_bboxes_scores.append(anno["scores"][i][j]) + anno["bboxes"] = flatterned_bboxes + anno["tokens_positive"] = flatterned_tokens_positive + anno["scores"] = flatterned_bboxes_scores + return anno + + + def get_raw_image(self, idx): + image, *_ = super(CaptionTSV, self).__getitem__(idx) + return image + + def get_img_id(self, idx): + line_no = self.get_line_no(idx) + if self.label_tsv is not None: + row = self.label_tsv.seek(line_no) + img_id = row[0] + return img_id diff --git a/maskrcnn_benchmark/data/datasets/coco.py b/maskrcnn_benchmark/data/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..095af9ea67f08acb93fe4d6b175708cca60809ea --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/coco.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import os +import os.path +import math +from PIL import Image, ImageDraw + +import random +import numpy as np + +import torch +import torchvision +import torch.utils.data as data + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask +from maskrcnn_benchmark.structures.keypoint import PersonKeypoints +from maskrcnn_benchmark.config import cfg +import pdb + +def _count_visible_keypoints(anno): + return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) + + +def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + +def has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + # keypoints task have a slight different critera for considering + # if an annotation is valid + if "keypoints" not in anno[0]: + return True + # for keypoint detection tasks, only consider valid images those + # containing at least min_keypoints_per_image + if _count_visible_keypoints(anno) >= cfg.DATALOADER.MIN_KPS_PER_IMS: + return True + return False + + +def pil_loader(path, retry=5): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + ri = 0 + while ri < retry: + try: + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + except: + ri += 1 + + +def rgb2id(color): + if isinstance(color, np.ndarray) and len(color.shape) == 3: + if color.dtype == np.uint8: + color = color.astype(np.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + +class CocoDetection(data.Dataset): + """`MS Coco Detection `_ Dataset. + + Args: + root (string): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. 
+ transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + def __init__(self, root, annFile, transform=None, target_transform=None): + from pycocotools.coco import COCO + self.root = root + self.coco = COCO(annFile) + self.ids = list(self.coco.imgs.keys()) + self.transform = transform + self.target_transform = target_transform + + def __getitem__(self, index, return_meta=False): + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. + """ + coco = self.coco + img_id = self.ids[index] + if isinstance(img_id, str): + img_id = [img_id] + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + + meta = coco.loadImgs(img_id)[0] + path = meta['file_name'] + img = pil_loader(os.path.join(self.root, path)) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + if return_meta: + return img, target, meta + else: + return img, target + + def __len__(self): + return len(self.ids) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + +class COCODataset(CocoDetection): + def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None, ignore_crowd=True, + max_box=-1, + few_shot=0, one_hot=False, override_category=None, **kwargs + ): + super(COCODataset, self).__init__(root, ann_file) + # sort indices for reproducible results + self.ids = sorted(self.ids) + + # filter images without detection annotations + if remove_images_without_annotations: + ids = [] + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + if has_valid_annotation(anno): + ids.append(img_id) + self.ids = ids + + if few_shot: + ids = [] + cats_freq = [few_shot]*len(self.coco.cats.keys()) + if 'shuffle_seed' in kwargs and kwargs['shuffle_seed'] != 0: + import random + random.Random(kwargs['shuffle_seed']).shuffle(self.ids) + print("Shuffle the dataset with random seed: ", kwargs['shuffle_seed']) + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + cat = set([ann['category_id'] for ann in anno]) #set/tuple corresponde to instance/image level + is_needed = sum([cats_freq[c-1]>0 for c in cat]) + if is_needed: + ids.append(img_id) + for c in cat: + cats_freq[c-1] -= 1 + # print(cat, cats_freq) + self.ids = ids + + if override_category is not None: + self.coco.dataset["categories"] = override_category + print("Override category: ", override_category) + + self.json_category_id_to_contiguous_id = { + v: i + 1 for i, v in enumerate(self.coco.getCatIds()) + } 
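+        # COCO category ids are sparse (e.g. 1..90 with gaps for the 80 classes);
+        # the map above squeezes them into contiguous labels 1..N, keeping 0 free
+        # for the background class. The inverse map below is kept so predictions
+        # can be translated back to the original json ids, e.g. for COCO evaluation.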
+ self.contiguous_category_id_to_json_id = { + v: k for k, v in self.json_category_id_to_contiguous_id.items() + } + self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} + self.transforms = transforms + self.ignore_crowd = ignore_crowd + self.max_box = max_box + self.one_hot = one_hot + + def categories(self, no_background=True): + categories = self.coco.dataset["categories"] + label_list = {} + for index, i in enumerate(categories): + if not no_background or (i["name"] != "__background__" and i['id'] != 0): + label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"] + return label_list + + def __getitem__(self, idx): + + + img, anno = super(COCODataset, self).__getitem__(idx) + + # filter crowd annotations + if self.ignore_crowd: + anno = [obj for obj in anno if obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + boxes = torch.as_tensor(boxes).reshape(-1, 4) # guard against no boxes + if self.max_box > 0 and len(boxes) > self.max_box: + rand_idx = torch.randperm(self.max_box) + boxes = boxes[rand_idx, :] + else: + rand_idx = None + target = BoxList(boxes, img.size, mode="xywh").convert("xyxy") + + classes = [obj["category_id"] for obj in anno] + classes = [self.json_category_id_to_contiguous_id[c] for c in classes] + classes = torch.tensor(classes) + + if rand_idx is not None: + classes = classes[rand_idx] + if cfg.DATASETS.CLASS_AGNOSTIC: + classes = torch.ones_like(classes) + target.add_field("labels", classes) + + if anno and "segmentation" in anno[0]: + masks = [obj["segmentation"] for obj in anno] + masks = SegmentationMask(masks, img.size, mode='poly') + target.add_field("masks", masks) + + if anno and "cbox" in anno[0]: + cboxes = [obj["cbox"] for obj in anno] + cboxes = torch.as_tensor(cboxes).reshape(-1, 4) # guard against no boxes + cboxes = BoxList(cboxes, img.size, mode="xywh").convert("xyxy") + target.add_field("cbox", cboxes) + + if anno and "keypoints" in anno[0]: + keypoints = [] + gt_keypoint = self.coco.cats[1]['keypoints'] # a better way to get keypoint description + use_keypoint = cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME + for obj in anno: + if len(use_keypoint) > 0: + kps = [] + for name in use_keypoint: + kp_idx = slice(3 * gt_keypoint.index(name), 3 * gt_keypoint.index(name) + 3) + kps += obj["keypoints"][kp_idx] + keypoints.append(kps) + else: + keypoints.append(obj["keypoints"]) + keypoints = PersonKeypoints(keypoints, img.size) + target.add_field("keypoints", keypoints) + + target = target.clip_to_image(remove_empty=True) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + if cfg.DATASETS.SAMPLE_RATIO != 0.0: + ratio = cfg.DATASETS.SAMPLE_RATIO + num_sample_target = math.ceil(len(target) * ratio) if ratio > 0 else math.ceil(-ratio) + sample_idx = torch.randperm(len(target))[:num_sample_target] + target = target[sample_idx] + return img, target, idx + + def get_img_info(self, index): + img_id = self.id_to_img_map[index] + img_data = self.coco.imgs[img_id] + return img_data diff --git a/maskrcnn_benchmark/data/datasets/coco_dt.py b/maskrcnn_benchmark/data/datasets/coco_dt.py new file mode 100644 index 0000000000000000000000000000000000000000..b050b3ed0dd2fa5b4974ef17fc8bdb26ed08fee7 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/coco_dt.py @@ -0,0 +1,154 @@ +""" +COCO dataset which returns image_id for evaluation. 
+ +Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py +""" + +import torch +import json +from PIL import Image, ImageDraw + +from .modulated_coco import ConvertCocoPolysToMask +from .tsv import ODTSVDataset +from pycocotools.coco import COCO +from maskrcnn_benchmark.structures.bounding_box import BoxList +import random +from .od_to_grounding import convert_object_detection_to_grounding_optimized_for_od, check_for_positive_overflow, sanity_check_target_after_processing + + +class CocoDetectionTSV(ODTSVDataset): + def __init__(self, + name, + yaml_file, + transforms, + return_tokens, + tokenizer, + extra_fields, + random_sample_negative=-1, + add_detection_prompt=False, + add_detection_prompt_advanced=False, + use_od_data_aug=False, + control_probabilities={}, + disable_shuffle=False, + prompt_engineer_version="v2", + prompt_limit_negative=-1, + positive_question_probability=0.6, + negative_question_probability=0.8, + full_question_probability=0.5, + disable_clip_to_image=False, + separation_tokens=" ", + no_mask_for_od=False, + max_num_labels=-1, + max_query_len=256, + **kwargs + ): + super(CocoDetectionTSV, self).__init__(yaml_file, extra_fields, **kwargs) + + self._transforms = transforms + self.name = name + self.max_query_len = max_query_len + self.prepare = ConvertCocoPolysToMask( + return_masks=False, + return_tokens=return_tokens, + tokenizer=tokenizer, + max_query_len=max_query_len + ) + self.tokenizer = tokenizer + + self.control_probabilities = control_probabilities + self.random_sample_negative = random_sample_negative + self.add_detection_prompt = add_detection_prompt + self.add_detection_prompt_advanced = add_detection_prompt_advanced + self.use_od_data_aug = use_od_data_aug + + self.prompt_engineer_version = prompt_engineer_version + self.prompt_limit_negative = prompt_limit_negative + self.positive_question_probability = positive_question_probability + self.negative_question_probability = negative_question_probability + self.full_question_probability = full_question_probability + self.separation_tokens = separation_tokens + self.disable_clip_to_image = disable_clip_to_image + self.disable_shuffle = disable_shuffle + self.no_mask_for_od = no_mask_for_od + self.max_num_labels = max_num_labels + + def __len__(self): + return super(CocoDetectionTSV, self).__len__() + + def categories(self, no_background=True): + categories = self.coco.dataset["categories"] + label_list = {} + for index, i in enumerate(categories): + # assert(index + 1 == i["id"]) + if not no_background or (i["name"] != "__background__" and i['id'] != 0): + label_list[i["id"]] = i["name"] + return label_list + + def __getitem__(self, idx): + # tgt is a BoxList + img, target, _, scale = super(CocoDetectionTSV, self).__getitem__(idx) + image_id = self.get_img_id(idx) + restricted_negative_list = None + + if not self.disable_clip_to_image: + target = target.clip_to_image(remove_empty=True) + + original_box_num = len(target) + + target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2) # leave some space for the special tokens + + if len(target) < original_box_num: + print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target))) + + annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od( + target=target, + image_id=image_id, + ind_to_class=self.ind_to_class, + 
disable_shuffle=self.disable_shuffle, + add_detection_prompt=self.add_detection_prompt, + add_detection_prompt_advanced=self.add_detection_prompt_advanced, + random_sample_negative=self.random_sample_negative, + control_probabilities=self.control_probabilities, + restricted_negative_list=restricted_negative_list, + separation_tokens=self.separation_tokens, + max_num_labels=self.max_num_labels, + positive_caption_length=positive_caption_length, + tokenizer=self.tokenizer, + max_seq_length=self.max_query_len-2 + ) + + # assert(len(self.tokenizer.tokenize(caption)) <= self.max_query_len-2) + + # print(caption) + anno = {"image_id": image_id, "annotations": annotations, "caption": caption, "label_to_positions": label_to_positions} + anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective + + if self.no_mask_for_od: + anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1)) + + img, anno = self.prepare(img, anno, box_format="xyxy") + + if self._transforms is not None: + img, target = self._transforms(img, target) + + # add additional property + for ann in anno: + target.add_field(ann, anno[ann]) + + sanity_check_target_after_processing(target) + + return img, target, idx + + def get_raw_image(self, idx): + image, *_ = super(CocoDetectionTSV, self).__getitem__(idx) + return image + + def get_img_id(self, idx): + line_no = self.get_line_no(idx) + if self.label_tsv is not None: + row = self.label_tsv.seek(line_no) + img_id = row[0] + try: + return int(img_id) + except: + return idx diff --git a/maskrcnn_benchmark/data/datasets/concat_dataset.py b/maskrcnn_benchmark/data/datasets/concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5cb6c0d96f906056c0b6d0d001db00c6eac2a5ae --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/concat_dataset.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
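+# ConcatDataset below only adds index bookkeeping on top of torch's version:
+# get_idxs() maps a global index to (dataset_idx, sample_idx) with a binary
+# search over cumulative_sizes. E.g. for two datasets of length 100 and 50,
+# cumulative_sizes == [100, 150], so get_idxs(120) -> (1, 20) and
+# get_img_info(120) is answered by the second dataset's own get_img_info(20).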
+import bisect + +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + + +class ConcatDataset(_ConcatDataset): + """ + Same as torch.utils.data.dataset.ConcatDataset, but exposes an extra + method for querying the sizes of the image + """ + + def get_idxs(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return dataset_idx, sample_idx + + def get_img_info(self, idx): + dataset_idx, sample_idx = self.get_idxs(idx) + return self.datasets[dataset_idx].get_img_info(sample_idx) diff --git a/maskrcnn_benchmark/data/datasets/custom_distributed_sampler.py b/maskrcnn_benchmark/data/datasets/custom_distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf8d8c4ea2b5f603d4a3a94cb114c154b56566d --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/custom_distributed_sampler.py @@ -0,0 +1,185 @@ +import math +from typing import TypeVar, Optional, Iterator + +import torch +from torch.utils.data import Sampler, Dataset +import torch.distributed as dist +import random +import numpy as np +import torch + + +class DistributedSamplerChunkByNode(torch.utils.data.Sampler): + + def __init__(self, + dataset, + all_datasets, + chunk_or_not, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + node_rank=0, + node_number=1, process_num_per_node=1, + rank_within_local_node=0) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.node_number = node_number + self.node_rank = node_rank + self.chunk_or_not = chunk_or_not + self.process_num_per_node = process_num_per_node + self.rank_within_local_node = rank_within_local_node + + assert (self.process_num_per_node * self.node_number == self.num_replicas) + + # 1. divide the datasets into two parts + normal_datasets = [] + chunked_datasets = [] + for dataset_i, chunk_i in zip(all_datasets, chunk_or_not): + if chunk_i: + chunked_datasets.append(dataset_i) + else: + normal_datasets.append(dataset_i) + + # 2. calculate dataset sizes: + self.normal_dataset_size = sum( + [len(i) for i in normal_datasets]) # this part we follow the conventional distributed sampler + + # 3. 
Divide + self.current_node_start_range = -1 + self.current_node_end_range = -1 + assert (len(chunked_datasets) >= self.node_number) + chunk_size = len(chunked_datasets) // self.node_number + current_example_num = self.normal_dataset_size + + for index in range(len(chunked_datasets)): + if index == self.node_rank * chunk_size: + self.current_node_start_range = current_example_num + current_example_num += len(chunked_datasets[index]) + if index == (self.node_rank + 1) * chunk_size - 1: + self.current_node_end_range = current_example_num + + if self.current_node_end_range == -1: # boundary + self.current_node_end_range = current_example_num + + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + # `type:ignore` is required because Dataset cannot provide a default __len__ + # see NOTE in pytorch/torch/utils/data/sampler.py + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self): + indices = self.generate_indices_within_range_with_rank( + seed=self.seed, + epoch=self.epoch, + + # NOTE: Distribute among all processes + process_num=self.num_replicas, + rank=self.rank, + generate_length=-1, + valid_indices=list(range(self.normal_dataset_size)), + prefix="Normal " + ) + + addition_indices = self.generate_indices_within_range_with_rank( + seed=self.seed, + epoch=self.epoch, + + # NOTE : very important arguments, distribute among local nodes + process_num=self.process_num_per_node, + rank=self.rank_within_local_node, + + generate_length=self.num_samples - len(indices), + valid_indices=list(range(self.current_node_start_range, self.current_node_end_range)), + prefix="Distribute " + ) + + indices.extend(addition_indices) + random.seed(self.seed + self.epoch + 10 * self.rank) # Set the seed to maximize randomness + random.shuffle(indices) # Reshuffle + assert len(indices) == self.num_samples + return iter(indices) + + def generate_indices_within_range_with_rank(self, seed, epoch, process_num, generate_length, valid_indices, rank=-1, + shuffle=True, prefix=""): + ''' + Use scenario : we want to sample 2500 examples from 10000 examples, while not sampling overlapping examples with other three process. + Modified from DistributedSampler + ''' + dataset_size = len(valid_indices) + if shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(seed + epoch) + indices = torch.randperm(dataset_size, generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(dataset_size)) # type: ignore[arg-type] + + indices = [valid_indices[i] for i in indices] + + num_samples_normal = math.ceil( + (dataset_size - process_num) / process_num # type: ignore[arg-type] + ) + # remove tail of data to make it evenly divisible. 
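+        # Worked example: dataset_size = 10000 valid indices shared by
+        # process_num = 4 ranks gives num_samples_normal = ceil(9996 / 4) = 2499;
+        # the list is cut to 2499 * 4 = 9996 entries and each rank then takes the
+        # strided slice indices[rank::4], i.e. 2499 disjoint indices per rank.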
+ indices = indices[:num_samples_normal * process_num] + + print("\n") + print(prefix, + "Global Rank {} Local Rank {} generate_length {} valid_indices {} process_num {} indices_before_subsample {} {}".format( + self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10])) + + # subsample + indices = indices[rank:num_samples_normal * process_num: process_num] + + print(prefix, + "Global Rank {} Local Rank {} generate_length {} valid_indices {} process_num {} indices_after_subsample {} {}".format( + self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10])) + print("\n") + + if generate_length != -1: + if len(indices) > generate_length: + indices = indices[:generate_length] + else: + indices.extend(np.random.choice(valid_indices, generate_length - len(indices)).tolist()) + return indices + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch diff --git a/maskrcnn_benchmark/data/datasets/duplicate_dataset.py b/maskrcnn_benchmark/data/datasets/duplicate_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4c3506c968d496bdc90954418c8f955f6012beb6 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/duplicate_dataset.py @@ -0,0 +1,31 @@ +import math +from typing import TypeVar, Optional, Iterator + +import torch +from torch.utils.data import Sampler, Dataset +import torch.distributed as dist +import random +import numpy as np + + +def create_duplicate_dataset(DatasetBaseClass): + class DupDataset(DatasetBaseClass): + + def __init__(self, copy, **kwargs): + super(DupDataset, self).__init__(**kwargs) + + self.copy = copy + self.length = super(DupDataset, self).__len__() + + def __len__(self): + return self.copy * self.length + + def __getitem__(self, index): + true_index = index % self.length + return super(DupDataset, self).__getitem__(true_index) + + def get_img_info(self, index): + true_index = index % self.length + return super(DupDataset, self).get_img_info(true_index) + + return DupDataset diff --git a/maskrcnn_benchmark/data/datasets/evaluation/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d84ead6b86ddcc1e0ae4088a5e36546ebef0efd --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/__init__.py @@ -0,0 +1,56 @@ +from maskrcnn_benchmark.data import datasets + +from .coco import coco_evaluation +from .voc import voc_evaluation +from .vg import vg_evaluation +from .box_aug import im_detect_bbox_aug +from .od_to_grounding import od_to_grounding_evaluation + + +def evaluate(dataset, predictions, output_folder, **kwargs): + """evaluate dataset using different methods based on dataset type. + Args: + dataset: Dataset object + predictions(list[BoxList]): each item in the list represents the + prediction results for one image. + output_folder: output folder, to save evaluation files or results. + **kwargs: other args. 
+ Returns: + evaluation result + """ + args = dict( + dataset=dataset, predictions=predictions, output_folder=output_folder, **kwargs + ) + if isinstance(dataset, datasets.COCODataset) or isinstance(dataset, datasets.TSVDataset): + return coco_evaluation(**args) + # elif isinstance(dataset, datasets.VGTSVDataset): + # return vg_evaluation(**args) + elif isinstance(dataset, datasets.PascalVOCDataset): + return voc_evaluation(**args) + elif isinstance(dataset, datasets.CocoDetectionTSV): + return od_to_grounding_evaluation(**args) + elif isinstance(dataset, datasets.LvisDetection): + pass + else: + dataset_name = dataset.__class__.__name__ + raise NotImplementedError("Unsupported dataset type {}.".format(dataset_name)) + + +def evaluate_mdetr(dataset, predictions, output_folder, cfg): + + args = dict( + dataset=dataset, predictions=predictions, output_folder=output_folder, **kwargs + ) + if isinstance(dataset, datasets.COCODataset) or isinstance(dataset, datasets.TSVDataset): + return coco_evaluation(**args) + # elif isinstance(dataset, datasets.VGTSVDataset): + # return vg_evaluation(**args) + elif isinstance(dataset, datasets.PascalVOCDataset): + return voc_evaluation(**args) + elif isinstance(dataset, datasets.CocoDetectionTSV): + return od_to_grounding_evaluation(**args) + elif isinstance(dataset, datasets.LvisDetection): + pass + else: + dataset_name = dataset.__class__.__name__ + raise NotImplementedError("Unsupported dataset type {}.".format(dataset_name)) diff --git a/maskrcnn_benchmark/data/datasets/evaluation/box_aug.py b/maskrcnn_benchmark/data/datasets/evaluation/box_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..da7e5d1907bc3fa69ce85a78723479988e532b2e --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/box_aug.py @@ -0,0 +1,349 @@ +import torch +import numpy as np + +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.data import transforms as T +from maskrcnn_benchmark.structures.image_list import to_image_list +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.layers import nms, soft_nms + + +def im_detect_bbox_aug(model, images, device, captions=None, positive_map_label_to_token=None): + # Collect detections computed under different transformations + boxlists_ts = [] + for _ in range(len(images)): + boxlists_ts.append([]) + + def add_preds_t(boxlists_t): + for i, boxlist_t in enumerate(boxlists_t): + # Resize the boxlist as the first one + boxlists_ts[i].append(boxlist_t.resize(images[i].size)) + + # Compute detections at different scales + if len(cfg.TEST.RANGES)==len(cfg.TEST.SCALES): + keep_ranges = cfg.TEST.RANGES + else: + keep_ranges = [None for _ in cfg.TEST.SCALES] + + for scale, keep_range in zip(cfg.TEST.SCALES, keep_ranges): + max_size = cfg.TEST.MAX_SIZE + boxlists_scl = im_detect_bbox_scale( + model, images, scale, max_size, device, + captions=captions, + positive_map_label_to_token=positive_map_label_to_token, + ) + if keep_range is not None: + boxlists_scl = remove_boxes(boxlists_scl, *keep_range) + add_preds_t(boxlists_scl) + + if cfg.TEST.FLIP: + boxlists_scl_hf = im_detect_bbox_scale( + model, images, scale, max_size, device, + captions=captions, + positive_map_label_to_token=positive_map_label_to_token, + hflip=True + ) + if keep_range is not None: + boxlists_scl_hf = remove_boxes(boxlists_scl_hf, *keep_range) + add_preds_t(boxlists_scl_hf) + + # Merge boxlists detected by different bbox aug params 
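+    # Every per-transform BoxList has already been resized back to the original
+    # image size in add_preds_t, so the boxes, scores and labels from all
+    # scale / flip passes are simply concatenated per image here; the per-class
+    # NMS and top-k capping happen in merge_result_from_multi_scales below.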
+ boxlists = [] + for i, boxlist_ts in enumerate(boxlists_ts): + bbox = torch.cat([boxlist_t.bbox for boxlist_t in boxlist_ts]) + scores = torch.cat([boxlist_t.get_field('scores') for boxlist_t in boxlist_ts]) + labels = torch.cat([boxlist_t.get_field('labels') for boxlist_t in boxlist_ts]) + boxlist = BoxList(bbox, boxlist_ts[0].size, boxlist_ts[0].mode) + boxlist.add_field('scores', scores) + boxlist.add_field('labels', labels) + boxlists.append(boxlist) + results = merge_result_from_multi_scales(boxlists) + return results + + +def im_detect_bbox(model, images, target_scale, target_max_size, device, + captions=None, + positive_map_label_to_token=None + ): + """ + Performs bbox detection on the original image. + """ + if cfg.INPUT.FORMAT is not '': + input_format = cfg.INPUT.FORMAT + elif cfg.INPUT.TO_BGR255: + input_format = 'bgr255' + transform = T.Compose([ + T.Resize(target_scale, target_max_size), + T.ToTensor(), + T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, format=input_format + ) + ]) + images = [transform(image) for image in images] + images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY) + if captions is None: + return model(images.to(device)) + else: + return model(images.to(device), + captions=captions, + positive_map=positive_map_label_to_token + ) + + +def im_detect_bbox_hflip(model, images, target_scale, target_max_size, device, + captions=None, + positive_map_label_to_token=None + ): + """ + Performs bbox detection on the horizontally flipped image. + Function signature is the same as for im_detect_bbox. + """ + if cfg.INPUT.FORMAT is not '': + input_format = cfg.INPUT.FORMAT + elif cfg.INPUT.TO_BGR255: + input_format = 'bgr255' + transform = T.Compose([ + T.Resize(target_scale, target_max_size), + T.RandomHorizontalFlip(1.0), + T.ToTensor(), + T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, format=input_format + ) + ]) + images = [transform(image) for image in images] + images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY) + if captions is None: + boxlists = model(images.to(device)) + else: + boxlists = model(images.to(device), + captions=captions, + positive_map=positive_map_label_to_token + ) + + # Invert the detections computed on the flipped image + boxlists_inv = [boxlist.transpose(0) for boxlist in boxlists] + return boxlists_inv + + +def im_detect_bbox_scale(model, images, target_scale, target_max_size, device, + captions=None, + positive_map_label_to_token=None, + hflip=False): + """ + Computes bbox detections at the given scale. + Returns predictions in the scaled image space. 
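+    When ``hflip`` is True the image is flipped horizontally before inference
+    and the resulting boxes are flipped back by ``im_detect_bbox_hflip``.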
+ """ + if hflip: + boxlists_scl = im_detect_bbox_hflip(model, images, target_scale, target_max_size, device, + captions=captions, + positive_map_label_to_token=positive_map_label_to_token + ) + else: + boxlists_scl = im_detect_bbox(model, images, target_scale, target_max_size, device, + captions=captions, + positive_map_label_to_token=positive_map_label_to_token + ) + return boxlists_scl + + +def remove_boxes(boxlist_ts, min_scale, max_scale): + new_boxlist_ts = [] + for _, boxlist_t in enumerate(boxlist_ts): + mode = boxlist_t.mode + boxlist_t = boxlist_t.convert("xyxy") + boxes = boxlist_t.bbox + keep = [] + for j, box in enumerate(boxes): + w = box[2] - box[0] + 1 + h = box[3] - box[1] + 1 + if (w * h > min_scale * min_scale) and (w * h < max_scale * max_scale): + keep.append(j) + new_boxlist_ts.append(boxlist_t[keep].convert(mode)) + return new_boxlist_ts + + +def merge_result_from_multi_scales(boxlists): + num_images = len(boxlists) + results = [] + for i in range(num_images): + scores = boxlists[i].get_field("scores") + labels = boxlists[i].get_field("labels") + boxes = boxlists[i].bbox + boxlist = boxlists[i] + result = [] + # test on classes + if len(cfg.TEST.SELECT_CLASSES): + class_list = cfg.TEST.SELECT_CLASSES + else: + class_list = range(1, cfg.TEST.NUM_CLASSES) + for j in class_list: + inds = (labels == j).nonzero().view(-1) + + scores_j = scores[inds] + boxes_j = boxes[inds, :].view(-1, 4) + boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") + boxlist_for_class.add_field("scores", scores_j) + boxlist_for_class = boxlist_nms(boxlist_for_class, cfg.TEST.TH, score_field="scores", nms_type=cfg.TEST.SPECIAL_NMS) + num_labels = len(boxlist_for_class) + boxlist_for_class.add_field("labels", torch.full((num_labels,), j, dtype=torch.int64, device=scores.device)) + result.append(boxlist_for_class) + + result = cat_boxlist(result) + number_of_detections = len(result) + + # Limit to max_per_image detections **over all classes** + if number_of_detections > cfg.TEST.PRE_NMS_TOP_N > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = torch.kthvalue( + cls_scores.cpu(), + number_of_detections - cfg.TEST.PRE_NMS_TOP_N + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + results.append(result) + return results + + +def boxlist_nms(boxlist, thresh, max_proposals=-1, score_field="scores", nms_type='nms'): + if thresh <= 0: + return boxlist + mode = boxlist.mode + boxlist = boxlist.convert("xyxy") + boxes = boxlist.bbox + score = boxlist.get_field(score_field) + + if nms_type == 'vote': + boxes_vote, scores_vote = bbox_vote(boxes, score, thresh) + if len(boxes_vote) > 0: + boxlist.bbox = boxes_vote + boxlist.extra_fields['scores'] = scores_vote + elif nms_type == 'soft-vote': + boxes_vote, scores_vote = soft_bbox_vote(boxes, score, thresh) + if len(boxes_vote) > 0: + boxlist.bbox = boxes_vote + boxlist.extra_fields['scores'] = scores_vote + elif nms_type == 'soft-nms': + keep, new_score = soft_nms(boxes.cpu(), score.cpu(), thresh, 0.95) + if max_proposals > 0: + keep = keep[: max_proposals] + boxlist = boxlist[keep] + boxlist.extra_fields['scores'] = new_score + else: + keep = nms(boxes, score, thresh) + if max_proposals > 0: + keep = keep[: max_proposals] + boxlist = boxlist[keep] + return boxlist.convert(mode) + + +def bbox_vote(boxes, scores, vote_thresh): + boxes = boxes.cpu().numpy() + scores = scores.cpu().numpy().reshape(-1, 1) + det = np.concatenate((boxes, scores), axis=1) + if det.shape[0] <= 1: 
+ return np.zeros((0, 5)), np.zeros((0, 1)) + order = det[:, 4].ravel().argsort()[::-1] + det = det[order, :] + dets = [] + while det.shape[0] > 0: + # IOU + area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1) + xx1 = np.maximum(det[0, 0], det[:, 0]) + yy1 = np.maximum(det[0, 1], det[:, 1]) + xx2 = np.minimum(det[0, 2], det[:, 2]) + yy2 = np.minimum(det[0, 3], det[:, 3]) + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + o = inter / (area[0] + area[:] - inter) + + # get needed merge det and delete these det + merge_index = np.where(o >= vote_thresh)[0] + det_accu = det[merge_index, :] + det = np.delete(det, merge_index, 0) + + if merge_index.shape[0] <= 1: + try: + dets = np.row_stack((dets, det_accu)) + except: + dets = det_accu + continue + else: + det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4)) + max_score = np.max(det_accu[:, 4]) + det_accu_sum = np.zeros((1, 5)) + det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:]) + det_accu_sum[:, 4] = max_score + try: + dets = np.row_stack((dets, det_accu_sum)) + except: + dets = det_accu_sum + + boxes = torch.from_numpy(dets[:, :4]).float().cuda() + scores = torch.from_numpy(dets[:, 4]).float().cuda() + + return boxes, scores + + +def soft_bbox_vote(boxes, scores, vote_thresh): + boxes = boxes.cpu().numpy() + scores = scores.cpu().numpy().reshape(-1, 1) + det = np.concatenate((boxes, scores), axis=1) + if det.shape[0] <= 1: + return np.zeros((0, 5)), np.zeros((0, 1)) + order = det[:, 4].ravel().argsort()[::-1] + det = det[order, :] + dets = [] + while det.shape[0] > 0: + # IOU + area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1) + xx1 = np.maximum(det[0, 0], det[:, 0]) + yy1 = np.maximum(det[0, 1], det[:, 1]) + xx2 = np.minimum(det[0, 2], det[:, 2]) + yy2 = np.minimum(det[0, 3], det[:, 3]) + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + o = inter / (area[0] + area[:] - inter) + + # get needed merge det and delete these det + merge_index = np.where(o >= vote_thresh)[0] + det_accu = det[merge_index, :] + det_accu_iou = o[merge_index] + det = np.delete(det, merge_index, 0) + + if merge_index.shape[0] <= 1: + try: + dets = np.row_stack((dets, det_accu)) + except: + dets = det_accu + continue + else: + soft_det_accu = det_accu.copy() + soft_det_accu[:, 4] = soft_det_accu[:, 4] * (1 - det_accu_iou) + soft_index = np.where(soft_det_accu[:, 4] >= cfg.MODEL.RETINANET.INFERENCE_TH)[0] + soft_det_accu = soft_det_accu[soft_index, :] + + det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4)) + max_score = np.max(det_accu[:, 4]) + det_accu_sum = np.zeros((1, 5)) + det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:]) + det_accu_sum[:, 4] = max_score + + if soft_det_accu.shape[0] > 0: + det_accu_sum = np.row_stack((det_accu_sum, soft_det_accu)) + + try: + dets = np.row_stack((dets, det_accu_sum)) + except: + dets = det_accu_sum + + order = dets[:, 4].ravel().argsort()[::-1] + dets = dets[order, :] + + boxes = torch.from_numpy(dets[:, :4]).float().cuda() + scores = torch.from_numpy(dets[:, 4]).float().cuda() + + return boxes, scores \ No newline at end of file diff --git a/maskrcnn_benchmark/data/datasets/evaluation/coco/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/coco/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a25c9b536e131b4d8bfd8e7ceb24c783d8d97cd --- /dev/null +++ 
b/maskrcnn_benchmark/data/datasets/evaluation/coco/__init__.py @@ -0,0 +1,21 @@ +from .coco_eval import do_coco_evaluation + + +def coco_evaluation( + dataset, + predictions, + output_folder, + box_only=False, + iou_types=("bbox",), + expected_results=(), + expected_results_sigma_tol=4, +): + return do_coco_evaluation( + dataset=dataset, + predictions=predictions, + box_only=box_only, + output_folder=output_folder, + iou_types=iou_types, + expected_results=expected_results, + expected_results_sigma_tol=expected_results_sigma_tol, + ) diff --git a/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..be79d7429b14c848ca161ccfe434512749c06af8 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/coco/coco_eval.py @@ -0,0 +1,531 @@ +import logging +import tempfile +import os +import torch +import numpy as np +import json + +from collections import OrderedDict +from tqdm import tqdm + +from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou + + +def do_coco_evaluation( + dataset, + predictions, + box_only, + output_folder, + iou_types, + expected_results, + expected_results_sigma_tol, +): + logger = logging.getLogger("maskrcnn_benchmark.inference") + + if box_only: + logger.info("Evaluating bbox proposals") + if dataset.coco is None and output_folder: + json_results = prepare_for_tsv_detection(predictions, dataset) + with open(os.path.join(output_folder, "box_proposals.json"), "w") as f: + json.dump(json_results, f) + return None + areas = {"all": "", "small": "s", "medium": "m", "large": "l"} + res = COCOResults("box_proposal") + for limit in [100, 1000]: + for area, suffix in areas.items(): + stats = evaluate_box_proposals( + predictions, dataset, area=area, limit=limit + ) + key = "AR{}@{:d}".format(suffix, limit) + res.results["box_proposal"][key] = stats["ar"].item() + logger.info(res) + check_expected_results(res, expected_results, expected_results_sigma_tol) + if output_folder: + torch.save(res, os.path.join(output_folder, "box_proposals.pth")) + return res, predictions + logger.info("Preparing results for COCO format") + coco_results = {} + if "bbox" in iou_types: + logger.info("Preparing bbox results") + if dataset.coco is None: + coco_results["bbox"] = prepare_for_tsv_detection(predictions, dataset) + else: + coco_results["bbox"] = prepare_for_coco_detection(predictions, dataset) + if "segm" in iou_types: + logger.info("Preparing segm results") + coco_results["segm"] = prepare_for_coco_segmentation(predictions, dataset) + if 'keypoints' in iou_types: + logger.info('Preparing keypoints results') + coco_results['keypoints'] = prepare_for_coco_keypoint(predictions, dataset) + + results = COCOResults(*iou_types) + logger.info("Evaluating predictions") + for iou_type in iou_types: + with tempfile.NamedTemporaryFile() as f: + file_path = f.name + if output_folder: + file_path = os.path.join(output_folder, iou_type + ".json") + if dataset.coco: + res = evaluate_predictions_on_coco( + dataset.coco, coco_results[iou_type], file_path, iou_type + ) + results.update(res) + elif output_folder: + with open(file_path, "w") as f: + json.dump(coco_results[iou_type], f) + + logger.info(results) + check_expected_results(results, expected_results, expected_results_sigma_tol) + if output_folder: + 
torch.save(results, os.path.join(output_folder, "coco_results.pth")) + return results, coco_results + + +def prepare_for_tsv_detection(predictions, dataset): + # assert isinstance(dataset, COCODataset) + proposal_results = [] + image_list = [] + for im_id, prediction in enumerate(predictions): + image_info = dataset.get_img_info(im_id) + if len(prediction) == 0: + continue + + # TODO replace with get_img_info? + image_id = image_info["id"] + image_width = image_info["width"] + image_height = image_info["height"] + prediction = prediction.resize((image_width, image_height)) + prediction = prediction.convert("xywh") + + boxes = prediction.bbox.tolist() + scores = prediction.get_field("scores").tolist() + labels = prediction.get_field("labels").tolist() + if prediction.has_field("centers"): + centers = prediction.get_field("centers") + else: + centers = None + + for k, box in enumerate(boxes): + proposal = { + "image_id": image_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + "area": image_width * image_height, + "iscrowd": 0, + } + if centers is not None: + proposal.update(center=centers[k].tolist()) + proposal_results.append(proposal) + + image_list.append(image_info) + + # categories = [{'supercategory': 'proposal', 'id': 0, 'name': 'proposal'}] + return dict(images=image_list, annotations=proposal_results) + + +def prepare_for_coco_detection(predictions, dataset): + # assert isinstance(dataset, COCODataset) + coco_results = [] + for image_id, prediction in enumerate(predictions): + original_id = dataset.id_to_img_map[image_id] + if len(prediction) == 0: + continue + + # TODO replace with get_img_info? + image_width = dataset.coco.imgs[original_id]["width"] + image_height = dataset.coco.imgs[original_id]["height"] + prediction = prediction.resize((image_width, image_height)) + prediction = prediction.convert("xywh") + + boxes = prediction.bbox.tolist() + scores = prediction.get_field("scores").tolist() + labels = prediction.get_field("labels").tolist() + + for k, box in enumerate(boxes): + if labels[k] in dataset.contiguous_category_id_to_json_id: + coco_results.append( + { + "image_id": original_id, + "category_id": dataset.contiguous_category_id_to_json_id[labels[k]], + "bbox": box, + "score": scores[k], + }) + + return coco_results + + +def prepare_for_coco_segmentation(predictions, dataset): + import pycocotools.mask as mask_util + import numpy as np + + masker = Masker(threshold=0.5, padding=1) + # assert isinstance(dataset, COCODataset) + coco_results = [] + for image_id, prediction in tqdm(enumerate(predictions)): + original_id = dataset.id_to_img_map[image_id] + if len(prediction) == 0: + continue + + # TODO replace with get_img_info? + image_width = dataset.coco.imgs[original_id]["width"] + image_height = dataset.coco.imgs[original_id]["height"] + prediction = prediction.resize((image_width, image_height)) + masks = prediction.get_field("mask") + # t = time.time() + # Masker is necessary only if masks haven't been already resized. 
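For reference, the bbox branch above (prepare_for_coco_detection) boils down to emitting plain COCO-style result records with [x, y, width, height] boxes and the original, non-contiguous category ids. A minimal sketch with invented ids and a hypothetical one-entry category mapping (the helper below is illustrative; the real code does the conversion via BoxList.convert("xywh")):

def xyxy_to_xywh(box):
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]

contiguous_to_json = {1: 18}   # hypothetical: contiguous label 1 -> COCO category id 18
prediction = {"boxes": [[10.0, 20.0, 110.0, 220.0]], "scores": [0.87], "labels": [1]}

coco_results = [
    {
        "image_id": 42,                        # original COCO image id (made up here)
        "category_id": contiguous_to_json[lbl],
        "bbox": xyxy_to_xywh(box),             # COCO json expects [x, y, width, height]
        "score": score,
    }
    for box, score, lbl in zip(prediction["boxes"], prediction["scores"], prediction["labels"])
]
print(coco_results[0]["bbox"])   # [10.0, 20.0, 100.0, 200.0]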
+ if list(masks.shape[-2:]) != [image_height, image_width]: + masks = masker(masks.expand(1, -1, -1, -1, -1), prediction) + masks = masks[0] + # logger.info('Time mask: {}'.format(time.time() - t)) + # prediction = prediction.convert('xywh') + + # boxes = prediction.bbox.tolist() + scores = prediction.get_field("scores").tolist() + labels = prediction.get_field("labels").tolist() + + # rles = prediction.get_field('mask') + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels] + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": mapped_labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + +def prepare_for_coco_keypoint(predictions, dataset): + # assert isinstance(dataset, COCODataset) + coco_results = [] + for image_id, prediction in enumerate(predictions): + original_id = dataset.id_to_img_map[image_id] + if len(prediction.bbox) == 0: + continue + + # TODO replace with get_img_info? + image_width = dataset.coco.imgs[original_id]['width'] + image_height = dataset.coco.imgs[original_id]['height'] + prediction = prediction.resize((image_width, image_height)) + prediction = prediction.convert('xywh') + + boxes = prediction.bbox.tolist() + scores = prediction.get_field('scores').tolist() + labels = prediction.get_field('labels').tolist() + keypoints = prediction.get_field('keypoints') + keypoints = keypoints.resize((image_width, image_height)) + keypoints = keypoints.to_coco_format() + + mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels] + + coco_results.extend([{ + 'image_id': original_id, + 'category_id': mapped_labels[k], + 'keypoints': keypoint, + 'score': scores[k]} for k, keypoint in enumerate(keypoints)]) + return coco_results + + +# inspired from Detectron +def evaluate_box_proposals( + predictions, dataset, thresholds=None, area="all", limit=None +): + """Evaluate detection proposal recall metrics. This function is a much + faster alternative to the official COCO API recall evaluation code. However, + it produces slightly different results. + """ + # Record max overlap value for each gt box + # Return vector of overlap values + areas = { + "all": 0, + "small": 1, + "medium": 2, + "large": 3, + "96-128": 4, + "128-256": 5, + "256-512": 6, + "512-inf": 7, + } + area_ranges = [ + [0 ** 2, 1e5 ** 2], # all + [0 ** 2, 32 ** 2], # small + [32 ** 2, 96 ** 2], # medium + [96 ** 2, 1e5 ** 2], # large + [96 ** 2, 128 ** 2], # 96-128 + [128 ** 2, 256 ** 2], # 128-256 + [256 ** 2, 512 ** 2], # 256-512 + [512 ** 2, 1e5 ** 2], + ] # 512-inf + assert area in areas, "Unknown area range: {}".format(area) + area_range = area_ranges[areas[area]] + gt_overlaps = [] + num_pos = 0 + + for image_id, prediction in enumerate(predictions): + original_id = dataset.id_to_img_map[image_id] + + # TODO replace with get_img_info? 
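evaluate_box_proposals, whose per-image loop continues below, greedily pairs proposals with ground-truth boxes by IoU and reports how much of the ground truth is covered. A compact numpy sketch of that idea (iou_xyxy and proposal_recall are hypothetical helpers; the real function additionally sorts by objectness, applies the area filters above, and uses the legacy +1 width/height convention):

import numpy as np

def iou_xyxy(a, b):
    # a: [N, 4], b: [M, 4] in (x1, y1, x2, y2); returns the [N, M] IoU matrix.
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    lt = np.maximum(a[:, None, :2], b[None, :, :2])
    rb = np.minimum(a[:, None, 2:], b[None, :, 2:])
    wh = np.clip(rb - lt, 0, None)
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area_a[:, None] + area_b[None, :] - inter)

def proposal_recall(proposals, gt, iou_thresh=0.5):
    # proposals are assumed to be sorted by score, highest first.
    overlaps = iou_xyxy(proposals, gt)
    covered = np.zeros(len(gt))
    for _ in range(min(len(proposals), len(gt))):
        best = overlaps.max()
        if best <= 0:
            break
        p, g = np.unravel_index(overlaps.argmax(), overlaps.shape)
        covered[g] = best
        overlaps[p, :] = -1   # each proposal and each gt box is used at most once
        overlaps[:, g] = -1
    return (covered >= iou_thresh).mean()

proposals = np.array([[0, 0, 10, 10], [20, 20, 30, 30]], dtype=float)
gt = np.array([[1, 1, 10, 10], [50, 50, 60, 60]], dtype=float)
print(proposal_recall(proposals, gt))   # 0.5, only the first gt box is covered at IoU >= 0.5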
+ image_width = dataset.coco.imgs[original_id]["width"] + image_height = dataset.coco.imgs[original_id]["height"] + prediction = prediction.resize((image_width, image_height)) + + # sort predictions in descending order + # TODO maybe remove this and make it explicit in the documentation + if prediction.has_field("objectness"): + inds = prediction.get_field("objectness").sort(descending=True)[1] + else: + inds = prediction.get_field("scores").sort(descending=True)[1] + prediction = prediction[inds] + + ann_ids = dataset.coco.getAnnIds(imgIds=original_id) + anno = dataset.coco.loadAnns(ann_ids) + gt_boxes = [obj["bbox"] for obj in anno if obj["iscrowd"] == 0] + gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes + gt_boxes = BoxList(gt_boxes, (image_width, image_height), mode="xywh").convert( + "xyxy" + ) + gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0]) + + if len(gt_boxes) == 0: + continue + + valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1]) + gt_boxes = gt_boxes[valid_gt_inds] + + num_pos += len(gt_boxes) + + if len(gt_boxes) == 0: + continue + + if len(prediction) == 0: + continue + + if limit is not None and len(prediction) > limit: + prediction = prediction[:limit] + + overlaps = boxlist_iou(prediction, gt_boxes) + + _gt_overlaps = torch.zeros(len(gt_boxes)) + for j in range(min(len(prediction), len(gt_boxes))): + # find which proposal box maximally covers each gt box + # and get the iou amount of coverage for each gt box + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # find which gt box is 'best' covered (i.e. 'best' = most iou) + gt_ovr, gt_ind = max_overlaps.max(dim=0) + assert gt_ovr >= 0 + # find the proposal box that covers the best covered gt box + box_ind = argmax_overlaps[gt_ind] + # record the iou coverage of this gt box + _gt_overlaps[j] = overlaps[box_ind, gt_ind] + assert _gt_overlaps[j] == gt_ovr + # mark the proposal box and the gt box as used + overlaps[box_ind, :] = -1 + overlaps[:, gt_ind] = -1 + + # append recorded iou coverage level + gt_overlaps.append(_gt_overlaps) + + if len(gt_overlaps) == 0: + return { + "ar": torch.zeros(1), + "recalls": torch.zeros(1), + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } + + gt_overlaps = torch.cat(gt_overlaps, dim=0) + gt_overlaps, _ = torch.sort(gt_overlaps) + + if thresholds is None: + step = 0.05 + thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32) + recalls = torch.zeros_like(thresholds) + # compute recall for each iou threshold + for i, t in enumerate(thresholds): + recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos) + # ar = 2 * np.trapz(recalls, thresholds) + ar = recalls.mean() + return { + "ar": ar, + "recalls": recalls, + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } + + +def evaluate_predictions_on_coco( + coco_gt, coco_results, json_result_file, iou_type="bbox" +): + import json + + with open(json_result_file, "w") as f: + json.dump(coco_results, f) + + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + coco_dt = coco_gt.loadRes(str(json_result_file)) if coco_results else COCO() + + # coco_dt = coco_gt.loadRes(coco_results) + if iou_type == 'keypoints': + coco_gt = filter_valid_keypoints(coco_gt, coco_dt) + coco_eval = COCOeval(coco_gt, coco_dt, iou_type) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + if iou_type == 'bbox': + summarize_per_category(coco_eval, 
json_result_file.replace('.json', '.csv')) + return coco_eval + + +def summarize_per_category(coco_eval, csv_output=None): + ''' + Compute and display summary metrics for evaluation results. + Note this functin can *only* be applied on the default parameter setting + ''' + + def _summarize(iouThr=None, areaRng='all', maxDets=100): + p = coco_eval.params + titleStr = 'Average Precision' + typeStr = '(AP)' + iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \ + if iouThr is None else '{:0.2f}'.format(iouThr) + result_str = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ], '. \ + format(titleStr, typeStr, iouStr, areaRng, maxDets) + + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + + # dimension of precision: [TxRxKxAxM] + s = coco_eval.eval['precision'] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, :, aind, mind] + + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + # cacluate AP(average precision) for each category + num_classes = len(p.catIds) + avg_ap = 0.0 + for i in range(0, num_classes): + result_str += '{}, '.format(np.mean(s[:, :, i, :])) + avg_ap += np.mean(s[:, :, i, :]) + result_str += ('{} \n'.format(avg_ap / num_classes)) + return result_str + + id2name = {} + for _, cat in coco_eval.cocoGt.cats.items(): + id2name[cat['id']] = cat['name'] + title_str = 'metric, ' + for cid in coco_eval.params.catIds: + title_str += '{}, '.format(id2name[cid]) + title_str += 'avg \n' + + results = [title_str] + results.append(_summarize()) + results.append(_summarize(iouThr=.5, maxDets=coco_eval.params.maxDets[2])) + results.append(_summarize(areaRng='small', maxDets=coco_eval.params.maxDets[2])) + results.append(_summarize(areaRng='medium', maxDets=coco_eval.params.maxDets[2])) + results.append(_summarize(areaRng='large', maxDets=coco_eval.params.maxDets[2])) + + with open(csv_output, 'w') as f: + for result in results: + f.writelines(result) + + +def filter_valid_keypoints(coco_gt, coco_dt): + kps = coco_dt.anns[1]['keypoints'] + for id, ann in coco_gt.anns.items(): + ann['keypoints'][2::3] = [a * b for a, b in zip(ann['keypoints'][2::3], kps[2::3])] + ann['num_keypoints'] = sum(ann['keypoints'][2::3]) + return coco_gt + + +class COCOResults(object): + METRICS = { + "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "box_proposal": [ + "AR@100", + "ARs@100", + "ARm@100", + "ARl@100", + "AR@1000", + "ARs@1000", + "ARm@1000", + "ARl@1000", + ], + "keypoints": ["AP", "AP50", "AP75", "APm", "APl"], + } + + def __init__(self, *iou_types): + allowed_types = ("box_proposal", "bbox", "segm", "keypoints") + assert all(iou_type in allowed_types for iou_type in iou_types) + results = OrderedDict() + for iou_type in iou_types: + results[iou_type] = OrderedDict( + [(metric, -1) for metric in COCOResults.METRICS[iou_type]] + ) + self.results = results + + def update(self, coco_eval): + if coco_eval is None: + return + from pycocotools.cocoeval import COCOeval + + assert isinstance(coco_eval, COCOeval) + s = coco_eval.stats + iou_type = coco_eval.params.iouType + res = self.results[iou_type] + metrics = COCOResults.METRICS[iou_type] + for idx, metric in enumerate(metrics): + res[metric] = s[idx] + + def __repr__(self): + # TODO make it pretty + return repr(self.results) + + +def check_expected_results(results, expected_results, sigma_tol): + if not 
expected_results: + return + + logger = logging.getLogger("maskrcnn_benchmark.inference") + for task, metric, (mean, std) in expected_results: + actual_val = results.results[task][metric] + lo = mean - sigma_tol * std + hi = mean + sigma_tol * std + ok = (lo < actual_val) and (actual_val < hi) + msg = ( + "{} > {} sanity check (actual vs. expected): " + "{:.3f} vs. mean={:.4f}, std={:.4}, range=({:.4f}, {:.4f})" + ).format(task, metric, actual_val, mean, std, lo, hi) + if not ok: + msg = "FAIL: " + msg + logger.error(msg) + else: + msg = "PASS: " + msg + logger.info(msg) diff --git a/maskrcnn_benchmark/data/datasets/evaluation/flickr/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/flickr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd063073c837183ac09aee7c6bbc4d8ad9dd47ef --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/flickr/__init__.py @@ -0,0 +1 @@ +from .flickr_eval import FlickrEvaluator diff --git a/maskrcnn_benchmark/data/datasets/evaluation/flickr/flickr_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/flickr/flickr_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..394bd59aa88b4b9ca67ca7b5ad18f390befe9b99 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/flickr/flickr_eval.py @@ -0,0 +1,440 @@ +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from maskrcnn_benchmark.structures.bounding_box import BoxList +import json +import numpy as np +import os.path as osp +import os +from prettytable import PrettyTable + +import xml.etree.ElementTree as ET +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import maskrcnn_benchmark.utils.mdetr_dist as dist +#### The following loading utilities are imported from +#### https://github.com/BryanPlummer/flickr30k_entities/blob/68b3d6f12d1d710f96233f6bd2b6de799d6f4e5b/flickr30k_entities_utils.py +# Changelog: +# - Added typing information +# - Completed docstrings + +def get_sentence_data(filename) -> List[Dict[str, Any]]: + """ + Parses a sentence file from the Flickr30K Entities dataset + + input: + filename - full file path to the sentence file to parse + + output: + a list of dictionaries for each sentence with the following fields: + sentence - the original sentence + phrases - a list of dictionaries for each phrase with the + following fields: + phrase - the text of the annotated phrase + first_word_index - the position of the first word of + the phrase in the sentence + phrase_id - an identifier for this phrase + phrase_type - a list of the coarse categories this + phrase belongs to + + """ + with open(filename, "r") as f: + sentences = f.read().split("\n") + + annotations = [] + for sentence in sentences: + if not sentence: + continue + + first_word = [] + phrases = [] + phrase_id = [] + phrase_type = [] + words = [] + current_phrase = [] + add_to_phrase = False + for token in sentence.split(): + if add_to_phrase: + if token[-1] == "]": + add_to_phrase = False + token = token[:-1] + current_phrase.append(token) + phrases.append(" ".join(current_phrase)) + current_phrase = [] + else: + current_phrase.append(token) + + words.append(token) + else: + if token[0] == "[": + add_to_phrase = True + first_word.append(len(words)) + parts = token.split("/") + phrase_id.append(parts[1][3:]) + phrase_type.append(parts[2:]) + else: + words.append(token) + + sentence_data = {"sentence": " ".join(words), "phrases": []} + for index, phrase, p_id, 
p_type in zip(first_word, phrases, phrase_id, phrase_type): + sentence_data["phrases"].append( + {"first_word_index": index, "phrase": phrase, "phrase_id": p_id, "phrase_type": p_type} + ) + + annotations.append(sentence_data) + + return annotations + + +def get_annotations(filename) -> Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]]: + """ + Parses the xml files in the Flickr30K Entities dataset + + input: + filename - full file path to the annotations file to parse + + output: + dictionary with the following fields: + scene - list of identifiers which were annotated as + pertaining to the whole scene + nobox - list of identifiers which were annotated as + not being visible in the image + boxes - a dictionary where the fields are identifiers + and the values are its list of boxes in the + [xmin ymin xmax ymax] format + height - int representing the height of the image + width - int representing the width of the image + depth - int representing the depth of the image + """ + tree = ET.parse(filename) + root = tree.getroot() + size_container = root.findall("size")[0] + anno_info: Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]] = {} + all_boxes: Dict[str, List[List[int]]] = {} + all_noboxes: List[str] = [] + all_scenes: List[str] = [] + for size_element in size_container: + assert size_element.text + anno_info[size_element.tag] = int(size_element.text) + + for object_container in root.findall("object"): + for names in object_container.findall("name"): + box_id = names.text + assert box_id + box_container = object_container.findall("bndbox") + if len(box_container) > 0: + if box_id not in all_boxes: + all_boxes[box_id] = [] + xmin = int(box_container[0].findall("xmin")[0].text) + ymin = int(box_container[0].findall("ymin")[0].text) + xmax = int(box_container[0].findall("xmax")[0].text) + ymax = int(box_container[0].findall("ymax")[0].text) + all_boxes[box_id].append([xmin, ymin, xmax, ymax]) + else: + nobndbox = int(object_container.findall("nobndbox")[0].text) + if nobndbox > 0: + all_noboxes.append(box_id) + + scene = int(object_container.findall("scene")[0].text) + if scene > 0: + all_scenes.append(box_id) + anno_info["boxes"] = all_boxes + anno_info["nobox"] = all_noboxes + anno_info["scene"] = all_scenes + + return anno_info + + +#### END of import from flickr30k_entities + + +#### Bounding box utilities imported from torchvision and converted to numpy +def box_area(boxes: np.array) -> np.array: + """ + Computes the area of a set of bounding boxes, which are specified by its + (x1, y1, x2, y2) coordinates. + + Args: + boxes (Tensor[N, 4]): boxes for which the area will be computed. They + are expected to be in (x1, y1, x2, y2) format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. 
+ + Returns: + area (Tensor[N]): area for each box + """ + assert boxes.ndim == 2 and boxes.shape[-1] == 4 + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py +# with slight modifications +def _box_inter_union(boxes1: np.array, boxes2: np.array) -> Tuple[np.array, np.array]: + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clip(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + return inter, union + + +def box_iou(boxes1: np.array, boxes2: np.array) -> np.array: + """ + Return intersection-over-union (Jaccard index) of boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + boxes1 (Tensor[N, 4]) + boxes2 (Tensor[M, 4]) + + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + """ + inter, union = _box_inter_union(boxes1, boxes2) + iou = inter / union + return iou + + +#### End of import of box utilities + +def _merge_boxes(boxes: List[List[int]]) -> List[List[int]]: + """ + Return the boxes corresponding to the smallest enclosing box containing all the provided boxes + The boxes are expected in [x1, y1, x2, y2] format + """ + if len(boxes) == 1: + return boxes + + np_boxes = np.asarray(boxes) + + return [[np_boxes[:, 0].min(), np_boxes[:, 1].min(), np_boxes[:, 2].max(), np_boxes[:, 3].max()]] + + +class RecallTracker: + """ Utility class to track recall@k for various k, split by categories""" + + def __init__(self, topk: Sequence[int]): + """ + Parameters: + - topk : tuple of ints corresponding to the recalls being tracked (eg, recall@1, recall@10, ...) + """ + + self.total_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk} + self.positives_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk} + + def add_positive(self, k: int, category: str): + """Log a positive hit @k for given category""" + if k not in self.total_byk_bycat: + raise RuntimeError(f"{k} is not a valid recall threshold") + self.total_byk_bycat[k][category] += 1 + self.positives_byk_bycat[k][category] += 1 + + def add_negative(self, k: int, category: str): + """Log a negative hit @k for given category""" + if k not in self.total_byk_bycat: + raise RuntimeError(f"{k} is not a valid recall threshold") + self.total_byk_bycat[k][category] += 1 + + def report(self) -> Dict[int, Dict[str, float]]: + """Return a condensed report of the results as a dict of dict. 
+ report[k][cat] is the recall@k for the given category + """ + report: Dict[int, Dict[str, float]] = {} + for k in self.total_byk_bycat: + assert k in self.positives_byk_bycat + report[k] = { + cat: self.positives_byk_bycat[k][cat] / self.total_byk_bycat[k][cat] for cat in self.total_byk_bycat[k] + } + return report + + +class Flickr30kEntitiesRecallEvaluator: + def __init__( + self, + flickr_path: str, + subset: str = "test", + topk: Sequence[int] = (1, 5, 10, -1), + iou_thresh: float = 0.5, + merge_boxes: bool = False, + verbose: bool = True, + ): + assert subset in ["train", "test", "val"], f"Wrong flickr subset {subset}" + + self.topk = topk + self.iou_thresh = iou_thresh + + flickr_path = Path(flickr_path) + + # We load the image ids corresponding to the current subset + with open(flickr_path / f"{subset}.txt") as file_d: + self.img_ids = [line.strip() for line in file_d] + + if verbose: + print(f"Flickr subset contains {len(self.img_ids)} images") + + # Read the box annotations for all the images + self.imgid2boxes: Dict[str, Dict[str, List[List[int]]]] = {} + + if verbose: + print("Loading annotations...") + + for img_id in self.img_ids: + anno_info = get_annotations(flickr_path / "Annotations" / f"{img_id}.xml")["boxes"] + if merge_boxes: + merged = {} + for phrase_id, boxes in anno_info.items(): + merged[phrase_id] = _merge_boxes(boxes) + anno_info = merged + self.imgid2boxes[img_id] = anno_info + + # Read the sentences annotations + self.imgid2sentences: Dict[str, List[List[Optional[Dict]]]] = {} + + if verbose: + print("Loading annotations...") + + self.all_ids: List[str] = [] + tot_phrases = 0 + for img_id in self.img_ids: + sentence_info = get_sentence_data(flickr_path / "Sentences" / f"{img_id}.txt") + self.imgid2sentences[img_id] = [None for _ in range(len(sentence_info))] + + # Some phrases don't have boxes, we filter them. + for sent_id, sentence in enumerate(sentence_info): + phrases = [phrase for phrase in sentence["phrases"] if phrase["phrase_id"] in self.imgid2boxes[img_id]] + if len(phrases) > 0: + self.imgid2sentences[img_id][sent_id] = phrases + tot_phrases += len(phrases) + + self.all_ids += [ + f"{img_id}_{k}" for k in range(len(sentence_info)) if self.imgid2sentences[img_id][k] is not None + ] + + if verbose: + print(f"There are {tot_phrases} phrases in {len(self.all_ids)} sentences to evaluate") + + def evaluate(self, predictions: List[Dict]): + evaluated_ids = set() + + recall_tracker = RecallTracker(self.topk) + + for pred in predictions: + cur_id = f"{pred['image_id']}_{pred['sentence_id']}" + if cur_id in evaluated_ids: + print( + "Warning, multiple predictions found for sentence" + f"{pred['sentence_id']} in image {pred['image_id']}" + ) + continue + + # Skip the sentences with no valid phrase + if cur_id not in self.all_ids: + if len(pred["boxes"]) != 0: + print( + f"Warning, in image {pred['image_id']} we were not expecting predictions " + f"for sentence {pred['sentence_id']}. Ignoring them." 
+ ) + continue + + evaluated_ids.add(cur_id) + + pred_boxes = pred["boxes"] + if str(pred["image_id"]) not in self.imgid2sentences: + raise RuntimeError(f"Unknown image id {pred['image_id']}") + if not 0 <= int(pred["sentence_id"]) < len(self.imgid2sentences[str(pred["image_id"])]): + raise RuntimeError(f"Unknown sentence id {pred['sentence_id']}" f" in image {pred['image_id']}") + target_sentence = self.imgid2sentences[str(pred["image_id"])][int(pred["sentence_id"])] + + phrases = self.imgid2sentences[str(pred["image_id"])][int(pred["sentence_id"])] + if len(pred_boxes) != len(phrases): + raise RuntimeError( + f"Error, got {len(pred_boxes)} predictions, expected {len(phrases)} " + f"for sentence {pred['sentence_id']} in image {pred['image_id']}" + ) + + for cur_boxes, phrase in zip(pred_boxes, phrases): + target_boxes = self.imgid2boxes[str(pred["image_id"])][phrase["phrase_id"]] + + ious = box_iou(np.asarray(cur_boxes), np.asarray(target_boxes)) + for k in self.topk: + maxi = 0 + if k == -1: + maxi = ious.max() + else: + assert k > 0 + maxi = ious[:k].max() + if maxi >= self.iou_thresh: + recall_tracker.add_positive(k, "all") + for phrase_type in phrase["phrase_type"]: + recall_tracker.add_positive(k, phrase_type) + else: + recall_tracker.add_negative(k, "all") + for phrase_type in phrase["phrase_type"]: + recall_tracker.add_negative(k, phrase_type) + + if len(evaluated_ids) != len(self.all_ids): + print("ERROR, the number of evaluated sentence doesn't match. Missing predictions:") + un_processed = set(self.all_ids) - evaluated_ids + for missing in un_processed: + img_id, sent_id = missing.split("_") + print(f"\t sentence {sent_id} in image {img_id}") + raise RuntimeError("Missing predictions") + + return recall_tracker.report() + + +class FlickrEvaluator(object): + def __init__( + self, + flickr_path, + subset, + top_k=(1, 5, 10, -1), + iou_thresh=0.5, + merge_boxes=False, + ): + assert isinstance(top_k, (list, tuple)) + + self.evaluator = Flickr30kEntitiesRecallEvaluator( + flickr_path, subset=subset, topk=top_k, iou_thresh=iou_thresh, merge_boxes=merge_boxes, verbose=False + ) + self.predictions = [] + self.results = None + + def accumulate(self): + pass + + def update(self, predictions): + self.predictions += predictions + + def synchronize_between_processes(self): + all_predictions = dist.all_gather(self.predictions) + self.predictions = sum(all_predictions, []) + + def summarize(self): + if dist.is_main_process(): + self.results = self.evaluator.evaluate(self.predictions) + table = PrettyTable() + all_cat = sorted(list(self.results.values())[0].keys()) + table.field_names = ["Recall@k"] + all_cat + + score = {} + for k, v in self.results.items(): + cur_results = [v[cat] for cat in all_cat] + header = "Upper_bound" if k == -1 else f"Recall@{k}" + + for cat in all_cat: + score[f"{header}_{cat}"] = v[cat] + table.add_row([header] + cur_results) + + print(table) + + return score + + return None, None diff --git a/maskrcnn_benchmark/data/datasets/evaluation/lvis/_change_lvis_annotation.py b/maskrcnn_benchmark/data/datasets/evaluation/lvis/_change_lvis_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..11332d93c1b85fc3df3b6a2480cb1be0e610bac4 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/lvis/_change_lvis_annotation.py @@ -0,0 +1,10 @@ +path = "DATASET/coco/annotations/lvis_v1_minival.json" +import json +with open(path) as f: + all = json.load(f) + +for i in all["images"]: + i["file_name"] = "/".join(i["coco_url"].split("/")[-2:]) + +with 
open("DATASET/coco/annotations/lvis_v1_minival_inserted_image_name.json", "w") as f: + json.dump(all, f) \ No newline at end of file diff --git a/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis.py b/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9288c85562667527e6f41d97f3201c6b71a305 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis.py @@ -0,0 +1,207 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import json +import os +import time +from collections import defaultdict + +import pycocotools.mask as mask_utils +import torchvision +from PIL import Image + + + +def _isArrayLike(obj): + return hasattr(obj, "__iter__") and hasattr(obj, "__len__") + + +class LVIS: + def __init__(self, annotation_path=None): + """Class for reading and visualizing annotations. + Args: + annotation_path (str): location of annotation file + """ + self.anns = {} + self.cats = {} + self.imgs = {} + self.img_ann_map = defaultdict(list) + self.cat_img_map = defaultdict(list) + self.dataset = {} + + if annotation_path is not None: + print("Loading annotations.") + + tic = time.time() + self.dataset = self._load_json(annotation_path) + print("Done (t={:0.2f}s)".format(time.time() - tic)) + + assert type(self.dataset) == dict, "Annotation file format {} not supported.".format(type(self.dataset)) + self._create_index() + + def _load_json(self, path): + with open(path, "r") as f: + return json.load(f) + + def _create_index(self): + print("Creating index.") + + self.img_ann_map = defaultdict(list) + self.cat_img_map = defaultdict(list) + + self.anns = {} + self.cats = {} + self.imgs = {} + + for ann in self.dataset["annotations"]: + self.img_ann_map[ann["image_id"]].append(ann) + self.anns[ann["id"]] = ann + + for img in self.dataset["images"]: + self.imgs[img["id"]] = img + + for cat in self.dataset["categories"]: + self.cats[cat["id"]] = cat + + for ann in self.dataset["annotations"]: + self.cat_img_map[ann["category_id"]].append(ann["image_id"]) + + print("Index created.") + + def get_ann_ids(self, img_ids=None, cat_ids=None, area_rng=None): + """Get ann ids that satisfy given filter conditions. + Args: + img_ids (int array): get anns for given imgs + cat_ids (int array): get anns for given cats + area_rng (float array): get anns for a given area range. e.g [0, inf] + Returns: + ids (int array): integer array of ann ids + """ + if img_ids is not None: + img_ids = img_ids if _isArrayLike(img_ids) else [img_ids] + if cat_ids is not None: + cat_ids = cat_ids if _isArrayLike(cat_ids) else [cat_ids] + anns = [] + if img_ids is not None: + for img_id in img_ids: + anns.extend(self.img_ann_map[img_id]) + else: + anns = self.dataset["annotations"] + + # return early if no more filtering required + if cat_ids is None and area_rng is None: + return [_ann["id"] for _ann in anns] + + cat_ids = set(cat_ids) + + if area_rng is None: + area_rng = [0, float("inf")] + + ann_ids = [ + _ann["id"] + for _ann in anns + if _ann["category_id"] in cat_ids and _ann["area"] > area_rng[0] and _ann["area"] < area_rng[1] + ] + return ann_ids + + def get_cat_ids(self): + """Get all category ids. + Returns: + ids (int array): integer array of category ids + """ + return list(self.cats.keys()) + + def get_img_ids(self): + """Get all img ids. 
+ Returns: + ids (int array): integer array of image ids + """ + return list(self.imgs.keys()) + + def _load_helper(self, _dict, ids): + if ids is None: + return list(_dict.values()) + elif _isArrayLike(ids): + return [_dict[id] for id in ids] + else: + return [_dict[ids]] + + def load_anns(self, ids=None): + """Load anns with the specified ids. If ids=None load all anns. + Args: + ids (int array): integer array of annotation ids + Returns: + anns (dict array) : loaded annotation objects + """ + return self._load_helper(self.anns, ids) + + def load_cats(self, ids): + """Load categories with the specified ids. If ids=None load all + categories. + Args: + ids (int array): integer array of category ids + Returns: + cats (dict array) : loaded category dicts + """ + return self._load_helper(self.cats, ids) + + def load_imgs(self, ids): + """Load categories with the specified ids. If ids=None load all images. + Args: + ids (int array): integer array of image ids + Returns: + imgs (dict array) : loaded image dicts + """ + return self._load_helper(self.imgs, ids) + + def download(self, save_dir, img_ids=None): + """Download images from mscoco.org server. + Args: + save_dir (str): dir to save downloaded images + img_ids (int array): img ids of images to download + """ + imgs = self.load_imgs(img_ids) + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + for img in imgs: + file_name = os.path.join(save_dir, img["file_name"]) + if not os.path.exists(file_name): + from urllib.request import urlretrieve + + urlretrieve(img["coco_url"], file_name) + + def ann_to_rle(self, ann): + """Convert annotation which can be polygons, uncompressed RLE to RLE. + Args: + ann (dict) : annotation object + Returns: + ann (rle) + """ + img_data = self.imgs[ann["image_id"]] + h, w = img_data["height"], img_data["width"] + segm = ann["segmentation"] + if isinstance(segm, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = mask_utils.frPyObjects(segm, h, w) + rle = mask_utils.merge(rles) + elif isinstance(segm["counts"], list): + # uncompressed RLE + rle = mask_utils.frPyObjects(segm, h, w) + else: + # rle + rle = ann["segmentation"] + return rle + + def ann_to_mask(self, ann): + """Convert annotation which can be polygons, uncompressed RLE, or RLE + to binary mask. + Args: + ann (dict) : annotation object + Returns: + binary mask (numpy 2D array) + """ + rle = self.ann_to_rle(ann) + return mask_utils.decode(rle) + diff --git a/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..6eeca5d2b4cb68bcda1dbd96ae25715ae2deb120 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/lvis/lvis_eval.py @@ -0,0 +1,998 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +import copy +import datetime +import json +import os +from collections import OrderedDict, defaultdict + +import numpy as np +import pycocotools.mask as mask_util +import torch +import torch._six + +import maskrcnn_benchmark.utils.mdetr_dist as dist + +from maskrcnn_benchmark.utils.mdetr_dist import all_gather + + +from .lvis import LVIS + +def merge(img_ids, eval_imgs): + all_img_ids = all_gather(img_ids) + all_eval_imgs = all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +################################################################# +# From LVIS, with following changes: +# * fixed LVISEval constructor to accept empty dt +# * Removed logger +# * LVIS results supports numpy inputs +################################################################# + + +class Params: + def __init__(self, iou_type): + """Params for LVIS evaluation API.""" + self.img_ids = [] + self.cat_ids = [] + # np.arange causes trouble. the data point on arange is slightly + # larger than the true value + self.iou_thrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True) + self.rec_thrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True) + self.max_dets = 300 + self.area_rng = [ + [0 ** 2, 1e5 ** 2], + [0 ** 2, 32 ** 2], + [32 ** 2, 96 ** 2], + [96 ** 2, 1e5 ** 2], + ] + self.area_rng_lbl = ["all", "small", "medium", "large"] + self.use_cats = 1 + # We bin categories in three bins based how many images of the training + # set the category is present in. + # r: Rare : < 10 + # c: Common : >= 10 and < 100 + # f: Frequent: >= 100 + self.img_count_lbl = ["r", "c", "f"] + self.iou_type = iou_type + + +class LVISResults(LVIS): + def __init__(self, lvis_gt, results, max_dets=300): + """Constructor for LVIS results. + Args: + lvis_gt (LVIS class instance, or str containing path of + annotation file) + results (str containing path of result file or a list of dicts) + max_dets (int): max number of detections per image. The official + value of max_dets for LVIS is 300. 
+ """ + super(LVISResults, self).__init__() + assert isinstance(lvis_gt, LVIS) + self.dataset["images"] = [img for img in lvis_gt.dataset["images"]] + + if isinstance(results, str): + result_anns = self._load_json(results) + elif type(results) == np.ndarray: + result_anns = self.loadNumpyAnnotations(results) + else: + result_anns = results + + if max_dets >= 0: + result_anns = self.limit_dets_per_image(result_anns, max_dets) + + if len(result_anns) > 0 and "bbox" in result_anns[0]: + self.dataset["categories"] = copy.deepcopy(lvis_gt.dataset["categories"]) + for id, ann in enumerate(result_anns): + x1, y1, w, h = ann["bbox"] + x2 = x1 + w + y2 = y1 + h + + if "segmentation" not in ann: + ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]] + + ann["area"] = w * h + ann["id"] = id + 1 + + elif len(result_anns) > 0 and "segmentation" in result_anns[0]: + self.dataset["categories"] = copy.deepcopy(lvis_gt.dataset["categories"]) + for id, ann in enumerate(result_anns): + # Only support compressed RLE format as segmentation results + ann["area"] = mask_util.area(ann["segmentation"]) + + if "bbox" not in ann: + ann["bbox"] = mask_util.toBbox(ann["segmentation"]) + + ann["id"] = id + 1 + + self.dataset["annotations"] = result_anns + self._create_index() + + # #FIXME: disabling this check for now + # img_ids_in_result = [ann["image_id"] for ann in result_anns] + + # assert set(img_ids_in_result) == ( + # set(img_ids_in_result) & set(self.get_img_ids()) + # ), "Results do not correspond to current LVIS set." + + def limit_dets_per_image(self, anns, max_dets): + img_ann = defaultdict(list) + for ann in anns: + img_ann[ann["image_id"]].append(ann) + + for img_id, _anns in img_ann.items(): + if len(_anns) <= max_dets: + continue + _anns = sorted(_anns, key=lambda ann: ann["score"], reverse=True) + img_ann[img_id] = _anns[:max_dets] + + return [ann for anns in img_ann.values() for ann in anns] + + def get_top_results(self, img_id, score_thrs): + ann_ids = self.get_ann_ids(img_ids=[img_id]) + anns = self.load_anns(ann_ids) + return list(filter(lambda ann: ann["score"] > score_thrs, anns)) + + +class LVISEval: + def __init__(self, lvis_gt, lvis_dt=None, iou_type="segm"): + """Constructor for LVISEval. 
+ Args: + lvis_gt (LVIS class instance, or str containing path of annotation file) + lvis_dt (LVISResult class instance, or str containing path of result file, + or list of dict) + iou_type (str): segm or bbox evaluation + """ + + if iou_type not in ["bbox", "segm"]: + raise ValueError("iou_type: {} is not supported.".format(iou_type)) + + if isinstance(lvis_gt, LVIS): + self.lvis_gt = lvis_gt + elif isinstance(lvis_gt, str): + self.lvis_gt = LVIS(lvis_gt) + else: + raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt)) + + if isinstance(lvis_dt, LVISResults): + self.lvis_dt = lvis_dt + elif isinstance(lvis_dt, (str, list)): + self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt) + elif lvis_dt is not None: + raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt)) + + # per-image per-category evaluation results + self.eval_imgs = defaultdict(list) + self.eval = {} # accumulated evaluation results + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + self.params = Params(iou_type=iou_type) # parameters + self.results = OrderedDict() + self.stats = [] + self.ious = {} # ious between all gts and dts + + self.params.img_ids = sorted(self.lvis_gt.get_img_ids()) + self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids()) + + def _to_mask(self, anns, lvis): + for ann in anns: + rle = lvis.ann_to_rle(ann) + ann["segmentation"] = rle + + def _prepare(self): + """Prepare self._gts and self._dts for evaluation based on params.""" + + cat_ids = self.params.cat_ids if self.params.cat_ids else None + + gts = self.lvis_gt.load_anns(self.lvis_gt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)) + dts = self.lvis_dt.load_anns(self.lvis_dt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids)) + # convert ground truth to mask if iou_type == 'segm' + if self.params.iou_type == "segm": + self._to_mask(gts, self.lvis_gt) + self._to_mask(dts, self.lvis_dt) + + # set ignore flag + for gt in gts: + if "ignore" not in gt: + gt["ignore"] = 0 + + for gt in gts: + self._gts[gt["image_id"], gt["category_id"]].append(gt) + + # For federated dataset evaluation we will filter out all dt for an + # image which belong to categories not present in gt and not present in + # the negative list for an image. In other words detector is not penalized + # for categories about which we don't have gt information about their + # presence or absence in an image. + img_data = self.lvis_gt.load_imgs(ids=self.params.img_ids) + # per image map of categories not present in image + img_nl = {d["id"]: d["neg_category_ids"] for d in img_data} + # per image list of categories present in image + img_pl = defaultdict(set) + for ann in gts: + img_pl[ann["image_id"]].add(ann["category_id"]) + # per image map of categoires which have missing gt. For these + # categories we don't penalize the detector for flase positives. 
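In other words, a detection survives the federated filtering above only if its category is either verified present or verified absent for that image; detections on categories with no annotation information are dropped rather than penalized. A toy illustration with invented ids:

img_pos = {101: {3, 7}}   # categories verified present in image 101
img_neg = {101: {5}}      # categories verified absent in image 101

detections = [
    {"image_id": 101, "category_id": 3, "score": 0.9},   # kept: category known present
    {"image_id": 101, "category_id": 5, "score": 0.8},   # kept: known absent, so it can count as a false positive
    {"image_id": 101, "category_id": 9, "score": 0.7},   # dropped: no information for category 9
]

kept = [
    d for d in detections
    if d["category_id"] in img_pos[d["image_id"]] or d["category_id"] in img_neg[d["image_id"]]
]
print([d["category_id"] for d in kept])   # [3, 5]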
+ self.img_nel = {d["id"]: d["not_exhaustive_category_ids"] for d in img_data} + + for dt in dts: + img_id, cat_id = dt["image_id"], dt["category_id"] + if cat_id not in img_nl[img_id] and cat_id not in img_pl[img_id]: + continue + self._dts[img_id, cat_id].append(dt) + + self.freq_groups = self._prepare_freq_group() + + def _prepare_freq_group(self): + freq_groups = [[] for _ in self.params.img_count_lbl] + cat_data = self.lvis_gt.load_cats(self.params.cat_ids) + for idx, _cat_data in enumerate(cat_data): + frequency = _cat_data["frequency"] + freq_groups[self.params.img_count_lbl.index(frequency)].append(idx) + return freq_groups + + def evaluate(self): + """ + Run per image evaluation on given images and store results + (a list of dict) in self.eval_imgs. + """ + + self.params.img_ids = list(np.unique(self.params.img_ids)) + + if self.params.use_cats: + cat_ids = self.params.cat_ids + else: + cat_ids = [-1] + + self._prepare() + + self.ious = { + (img_id, cat_id): self.compute_iou(img_id, cat_id) for img_id in self.params.img_ids for cat_id in cat_ids + } + + # loop through images, area range, max detection number + self.eval_imgs = [ + self.evaluate_img(img_id, cat_id, area_rng) + for cat_id in cat_ids + for area_rng in self.params.area_rng + for img_id in self.params.img_ids + ] + + def _get_gt_dt(self, img_id, cat_id): + """Create gt, dt which are list of anns/dets. If use_cats is true + only anns/dets corresponding to tuple (img_id, cat_id) will be + used. Else, all anns/dets in image are used and cat_id is not used. + """ + if self.params.use_cats: + gt = self._gts[img_id, cat_id] + dt = self._dts[img_id, cat_id] + else: + gt = [_ann for _cat_id in self.params.cat_ids for _ann in self._gts[img_id, cat_id]] + dt = [_ann for _cat_id in self.params.cat_ids for _ann in self._dts[img_id, cat_id]] + return gt, dt + + def compute_iou(self, img_id, cat_id): + gt, dt = self._get_gt_dt(img_id, cat_id) + + if len(gt) == 0 and len(dt) == 0: + return [] + + # Sort detections in decreasing order of score. + idx = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in idx] + + iscrowd = [int(False)] * len(gt) + + if self.params.iou_type == "segm": + ann_type = "segmentation" + elif self.params.iou_type == "bbox": + ann_type = "bbox" + else: + raise ValueError("Unknown iou_type for iou computation.") + gt = [g[ann_type] for g in gt] + dt = [d[ann_type] for d in dt] + + # compute iou between each dt and gt region + # will return array of shape len(dt), len(gt) + ious = mask_util.iou(dt, gt, iscrowd) + return ious + + def evaluate_img(self, img_id, cat_id, area_rng): + """Perform evaluation for single category and image.""" + gt, dt = self._get_gt_dt(img_id, cat_id) + + if len(gt) == 0 and len(dt) == 0: + return None + + # Add another filed _ignore to only consider anns based on area range. 
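Once the ignore flags are set, the per-threshold matching loop further below reduces to a greedy assignment: walk detections in descending score order and give each one the best still-unmatched ground-truth box whose IoU clears the threshold. A small numpy sketch with the ignore handling left out (greedy_match is a made-up name for illustration only):

import numpy as np

def greedy_match(ious, iou_thr):
    # ious: [num_dt, num_gt], detections already sorted by descending score.
    # Returns, for each detection, the matched gt index or -1.
    matched_gt = np.full(ious.shape[0], -1)
    gt_taken = np.zeros(ious.shape[1], dtype=bool)
    for d in range(ious.shape[0]):
        best, best_iou = -1, iou_thr
        for g in range(ious.shape[1]):
            if gt_taken[g] or ious[d, g] < best_iou:
                continue
            best, best_iou = g, ious[d, g]
        if best >= 0:
            matched_gt[d] = best
            gt_taken[best] = True
    return matched_gt

ious = np.array([[0.9, 0.2],
                 [0.6, 0.4]])
print(greedy_match(ious, 0.5))   # [ 0 -1]: the second detection's best remaining gt is below threshold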
+ for g in gt: + if g["ignore"] or (g["area"] < area_rng[0] or g["area"] > area_rng[1]): + g["_ignore"] = 1 + else: + g["_ignore"] = 0 + + # Sort gt ignore last + gt_idx = np.argsort([g["_ignore"] for g in gt], kind="mergesort") + gt = [gt[i] for i in gt_idx] + + # Sort dt highest score first + dt_idx = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in dt_idx] + + # load computed ious + ious = self.ious[img_id, cat_id][:, gt_idx] if len(self.ious[img_id, cat_id]) > 0 else self.ious[img_id, cat_id] + + num_thrs = len(self.params.iou_thrs) + num_gt = len(gt) + num_dt = len(dt) + + # Array to store the "id" of the matched dt/gt + gt_m = np.zeros((num_thrs, num_gt)) + dt_m = np.zeros((num_thrs, num_dt)) + + gt_ig = np.array([g["_ignore"] for g in gt]) + dt_ig = np.zeros((num_thrs, num_dt)) + + for iou_thr_idx, iou_thr in enumerate(self.params.iou_thrs): + if len(ious) == 0: + break + + for dt_idx, _dt in enumerate(dt): + iou = min([iou_thr, 1 - 1e-10]) + # information about best match so far (m=-1 -> unmatched) + # store the gt_idx which matched for _dt + m = -1 + for gt_idx, _ in enumerate(gt): + # if this gt already matched continue + if gt_m[iou_thr_idx, gt_idx] > 0: + continue + # if _dt matched to reg gt, and on ignore gt, stop + if m > -1 and gt_ig[m] == 0 and gt_ig[gt_idx] == 1: + break + # continue to next gt unless better match made + if ious[dt_idx, gt_idx] < iou: + continue + # if match successful and best so far, store appropriately + iou = ious[dt_idx, gt_idx] + m = gt_idx + + # No match found for _dt, go to next _dt + if m == -1: + continue + + # if gt to ignore for some reason update dt_ig. + # Should not be used in evaluation. + dt_ig[iou_thr_idx, dt_idx] = gt_ig[m] + # _dt match found, update gt_m, and dt_m with "id" + dt_m[iou_thr_idx, dt_idx] = gt[m]["id"] + gt_m[iou_thr_idx, m] = _dt["id"] + + # For LVIS we will ignore any unmatched detection if that category was + # not exhaustively annotated in gt. + dt_ig_mask = [ + d["area"] < area_rng[0] or d["area"] > area_rng[1] or d["category_id"] in self.img_nel[d["image_id"]] + for d in dt + ] + dt_ig_mask = np.array(dt_ig_mask).reshape((1, num_dt)) # 1 X num_dt + dt_ig_mask = np.repeat(dt_ig_mask, num_thrs, 0) # num_thrs X num_dt + # Based on dt_ig_mask ignore any unmatched detection by updating dt_ig + dt_ig = np.logical_or(dt_ig, np.logical_and(dt_m == 0, dt_ig_mask)) + # store results for given image and category + return { + "image_id": img_id, + "category_id": cat_id, + "area_rng": area_rng, + "dt_ids": [d["id"] for d in dt], + "gt_ids": [g["id"] for g in gt], + "dt_matches": dt_m, + "gt_matches": gt_m, + "dt_scores": [d["score"] for d in dt], + "gt_ignore": gt_ig, + "dt_ignore": dt_ig, + } + + def accumulate(self): + """Accumulate per image evaluation results and store the result in + self.eval. 
+ """ + + if not self.eval_imgs: + print("Warning: Please run evaluate first.") + + if self.params.use_cats: + cat_ids = self.params.cat_ids + else: + cat_ids = [-1] + + num_thrs = len(self.params.iou_thrs) + num_recalls = len(self.params.rec_thrs) + num_cats = len(cat_ids) + num_area_rngs = len(self.params.area_rng) + num_imgs = len(self.params.img_ids) + + # -1 for absent categories + precision = -np.ones((num_thrs, num_recalls, num_cats, num_area_rngs)) + recall = -np.ones((num_thrs, num_cats, num_area_rngs)) + + # Initialize dt_pointers + dt_pointers = {} + for cat_idx in range(num_cats): + dt_pointers[cat_idx] = {} + for area_idx in range(num_area_rngs): + dt_pointers[cat_idx][area_idx] = {} + + # Per category evaluation + for cat_idx in range(num_cats): + Nk = cat_idx * num_area_rngs * num_imgs + for area_idx in range(num_area_rngs): + Na = area_idx * num_imgs + E = [self.eval_imgs[Nk + Na + img_idx] for img_idx in range(num_imgs)] + # Remove elements which are None + E = [e for e in E if e is not None] + if len(E) == 0: + continue + + # Append all scores: shape (N,) + dt_scores = np.concatenate([e["dt_scores"] for e in E], axis=0) + dt_ids = np.concatenate([e["dt_ids"] for e in E], axis=0) + + dt_idx = np.argsort(-dt_scores, kind="mergesort") + dt_scores = dt_scores[dt_idx] + dt_ids = dt_ids[dt_idx] + + dt_m = np.concatenate([e["dt_matches"] for e in E], axis=1)[:, dt_idx] + dt_ig = np.concatenate([e["dt_ignore"] for e in E], axis=1)[:, dt_idx] + + gt_ig = np.concatenate([e["gt_ignore"] for e in E]) + # num gt anns to consider + num_gt = np.count_nonzero(gt_ig == 0) + + if num_gt == 0: + continue + + tps = np.logical_and(dt_m, np.logical_not(dt_ig)) + fps = np.logical_and(np.logical_not(dt_m), np.logical_not(dt_ig)) + + tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float) + fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float) + + dt_pointers[cat_idx][area_idx] = { + "dt_ids": dt_ids, + "tps": tps, + "fps": fps, + } + + for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): + tp = np.array(tp) + fp = np.array(fp) + num_tp = len(tp) + rc = tp / num_gt + if num_tp: + recall[iou_thr_idx, cat_idx, area_idx] = rc[-1] + else: + recall[iou_thr_idx, cat_idx, area_idx] = 0 + + # np.spacing(1) ~= eps + pr = tp / (fp + tp + np.spacing(1)) + pr = pr.tolist() + + # Replace each precision value with the maximum precision + # value to the right of that recall level. This ensures + # that the calculated AP value will be less suspectable + # to small variations in the ranking. 
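+                    # For illustration (made-up numbers): pr = [0.9, 0.6, 0.7]
+                    # becomes [0.9, 0.7, 0.7] after the backward pass below, i.e.
+                    # the monotonically non-increasing precision envelope.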
+ for i in range(num_tp - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + + rec_thrs_insert_idx = np.searchsorted(rc, self.params.rec_thrs, side="left") + + pr_at_recall = [0.0] * num_recalls + + try: + for _idx, pr_idx in enumerate(rec_thrs_insert_idx): + pr_at_recall[_idx] = pr[pr_idx] + except Exception: + pass + precision[iou_thr_idx, :, cat_idx, area_idx] = np.array(pr_at_recall) + + self.eval = { + "params": self.params, + "counts": [num_thrs, num_recalls, num_cats, num_area_rngs], + "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "precision": precision, + "recall": recall, + "dt_pointers": dt_pointers, + } + + def _summarize(self, summary_type, iou_thr=None, area_rng="all", freq_group_idx=None): + aidx = [idx for idx, _area_rng in enumerate(self.params.area_rng_lbl) if _area_rng == area_rng] + + if summary_type == "ap": + s = self.eval["precision"] + if iou_thr is not None: + tidx = np.where(iou_thr == self.params.iou_thrs)[0] + s = s[tidx] + if freq_group_idx is not None: + s = s[:, :, self.freq_groups[freq_group_idx], aidx] + else: + s = s[:, :, :, aidx] + else: + s = self.eval["recall"] + if iou_thr is not None: + tidx = np.where(iou_thr == self.params.iou_thrs)[0] + s = s[tidx] + s = s[:, :, aidx] + + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + return mean_s + + def summarize(self): + """Compute and display summary metrics for evaluation results.""" + if not self.eval: + raise RuntimeError("Please run accumulate() first.") + + max_dets = self.params.max_dets + + self.results["AP"] = self._summarize("ap") + self.results["AP50"] = self._summarize("ap", iou_thr=0.50) + self.results["AP75"] = self._summarize("ap", iou_thr=0.75) + self.results["APs"] = self._summarize("ap", area_rng="small") + self.results["APm"] = self._summarize("ap", area_rng="medium") + self.results["APl"] = self._summarize("ap", area_rng="large") + self.results["APr"] = self._summarize("ap", freq_group_idx=0) + self.results["APc"] = self._summarize("ap", freq_group_idx=1) + self.results["APf"] = self._summarize("ap", freq_group_idx=2) + + self.stats = np.zeros((9,)) + self.stats[0] = self._summarize("ap") + self.stats[1] = self._summarize("ap", iou_thr=0.50) + self.stats[2] = self._summarize("ap", iou_thr=0.75) + self.stats[3] = self._summarize("ap", area_rng="small") + self.stats[4] = self._summarize("ap", area_rng="medium") + self.stats[5] = self._summarize("ap", area_rng="large") + self.stats[6] = self._summarize("ap", freq_group_idx=0) + self.stats[7] = self._summarize("ap", freq_group_idx=1) + self.stats[8] = self._summarize("ap", freq_group_idx=2) + + key = "AR@{}".format(max_dets) + self.results[key] = self._summarize("ar") + + for area_rng in ["small", "medium", "large"]: + key = "AR{}@{}".format(area_rng[0], max_dets) + self.results[key] = self._summarize("ar", area_rng=area_rng) + _returned = self.print_results() + return _returned + + def run(self): + """Wrapper function which calculates the results.""" + self.evaluate() + self.accumulate() + self.summarize() + + def print_results(self): + template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}" + out_strings = [] + for key, value in self.results.items(): + max_dets = self.params.max_dets + if "AP" in key: + title = "Average Precision" + _type = "(AP)" + else: + title = "Average Recall" + _type = "(AR)" + + if len(key) > 2 and key[2].isdigit(): + iou_thr = float(key[2:]) / 100 + iou = "{:0.2f}".format(iou_thr) + else: + iou = 
"{:0.2f}:{:0.2f}".format(self.params.iou_thrs[0], self.params.iou_thrs[-1]) + + if len(key) > 2 and key[2] in ["r", "c", "f"]: + cat_group_name = key[2] + else: + cat_group_name = "all" + + if len(key) > 2 and key[2] in ["s", "m", "l"]: + area_rng = key[2] + else: + area_rng = "all" + + print(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value)) + out_strings.append(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value)) + return out_strings + + def get_results(self): + if not self.results: + print("Warning: results is empty. Call run().") + return self.results + + +################################################################# +# end of straight copy from lvis, just fixing constructor +################################################################# + + +class LvisEvaluator(object): + def __init__(self, lvis_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + # lvis_gt = copy.deepcopy(lvis_gt) + self.lvis_gt = lvis_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = LVISEval(lvis_gt, iou_type=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + lvis_dt = LVISResults(self.lvis_gt, results) + lvis_eval = self.coco_eval[iou_type] + + lvis_eval.lvis_dt = lvis_dt + lvis_eval.params.img_ids = list(img_ids) + lvis_eval.evaluate() + eval_imgs = lvis_eval.eval_imgs + eval_imgs = np.asarray(eval_imgs).reshape( + len(lvis_eval.params.cat_ids), len(lvis_eval.params.area_rng), len(lvis_eval.params.img_ids) + ) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_lvis_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) + + def accumulate(self): + for lvis_eval in self.coco_eval.values(): + lvis_eval.accumulate() + + def summarize(self): + for iou_type, lvis_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + lvis_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_lvis_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_lvis_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_lvis_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_lvis_detection(self, predictions): + lvis_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + lvis_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return lvis_results + + def prepare_for_lvis_segmentation(self, predictions): + lvis_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = 
prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + lvis_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return lvis_results + + +def _merge_lists(listA, listB, maxN, key): + result = [] + indA, indB = 0, 0 + while (indA < len(listA) or indB < len(listB)) and len(result) < maxN: + if (indB < len(listB)) and (indA >= len(listA) or key(listA[indA]) < key(listB[indB])): + result.append(listB[indB]) + indB += 1 + else: + result.append(listA[indA]) + indA += 1 + return result + + +# Adapted from https://github.com/achalddave/large-vocab-devil/blob/9aaddc15b00e6e0d370b16743233e40d973cd53f/scripts/evaluate_ap_fixed.py +class LvisEvaluatorFixedAP(object): + def __init__(self, gt: LVIS, topk=10000, fixed_ap=True): + + self.results = [] + self.by_cat = {} + self.gt = gt + self.topk = topk + self.fixed_ap = fixed_ap + + def update(self, predictions): + cur_results = self.prepare(predictions) + if self.fixed_ap: + by_cat = defaultdict(list) + for ann in cur_results: + by_cat[ann["category_id"]].append(ann) + + for cat, cat_anns in by_cat.items(): + if cat not in self.by_cat: + self.by_cat[cat] = [] + + cur = sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk] + self.by_cat[cat] = _merge_lists(self.by_cat[cat], cur, self.topk, key=lambda x: x["score"]) + else: + by_id = defaultdict(list) + for ann in cur_results: + by_id[ann["image_id"]].append(ann) + + for id_anns in by_id.values(): + self.results.extend(sorted(id_anns, key=lambda x: x["score"], reverse=True)[:300]) + + def synchronize_between_processes(self): + if self.fixed_ap: + all_cats = dist.all_gather(self.by_cat) + self.by_cat = defaultdict(list) + for cats in all_cats: + for cat, cat_anns in cats.items(): + self.by_cat[cat].extend(cat_anns) + else: + self.results = sum(dist.all_gather(self.results), []) + + def prepare(self, predictions): + lvis_results = [] + for original_id, prediction in predictions: + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + lvis_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return lvis_results + + def summarize(self): + if not dist.is_main_process(): + return + + if self.fixed_ap: + return self._summarize_fixed() + else: + return self._summarize_standard() + + def _summarize_standard(self): + results = LVISResults(self.gt, self.results) + lvis_eval = LVISEval(self.gt, results, iou_type="bbox") + lvis_eval.run() + lvis_eval.print_results() + + def _summarize_fixed(self): + results = [] + + missing_dets_cats = set() + for cat, cat_anns in self.by_cat.items(): + if len(cat_anns) < self.topk: + missing_dets_cats.add(cat) + results.extend(sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk]) + if missing_dets_cats: + print( + f"\n===\n" + f"{len(missing_dets_cats)} classes had less than {self.topk} detections!\n" + f"Outputting {self.topk} detections for each class will improve AP further.\n" + f"If using detectron2, please use the lvdevil/infer_topk.py script to " + f"output a results file with {self.topk} detections for each class.\n" + 
f"===" + ) + + results = LVISResults(self.gt, results, max_dets=-1) + lvis_eval = LVISEval(self.gt, results, iou_type="bbox") + params = lvis_eval.params + params.max_dets = -1 # No limit on detections per image. + lvis_eval.run() + scores = lvis_eval.print_results() + metrics = {k: v for k, v in lvis_eval.results.items() if k.startswith("AP")} + print("copypaste: %s,%s", ",".join(map(str, metrics.keys())), "path") + return scores + + +class LvisDumper(object): + def __init__(self, topk=10000, fixed_ap=True, out_path="lvis_eval"): + + self.results = [] + self.by_cat = {} + self.topk = topk + self.fixed_ap = fixed_ap + self.out_path = out_path + if dist.is_main_process(): + if not os.path.exists(self.out_path): + os.mkdir(self.out_path) + + def update(self, predictions): + cur_results = self.prepare(predictions) + if self.fixed_ap: + by_cat = defaultdict(list) + for ann in cur_results: + by_cat[ann["category_id"]].append(ann) + + for cat, cat_anns in by_cat.items(): + if cat not in self.by_cat: + self.by_cat[cat] = [] + + cur = sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk] + self.by_cat[cat] = _merge_lists(self.by_cat[cat], cur, self.topk, key=lambda x: x["score"]) + else: + by_id = defaultdict(list) + for ann in cur_results: + by_id[ann["image_id"]].append(ann) + + for id_anns in by_id.values(): + self.results.extend(sorted(id_anns, key=lambda x: x["score"], reverse=True)[:300]) + + def synchronize_between_processes(self): + if self.fixed_ap: + all_cats = dist.all_gather(self.by_cat) + self.by_cat = defaultdict(list) + for cats in all_cats: + for cat, cat_anns in cats.items(): + self.by_cat[cat].extend(cat_anns) + else: + self.results = sum(dist.all_gather(self.results), []) + + def prepare(self, predictions): + lvis_results = [] + for original_id, prediction in predictions: + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + lvis_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return lvis_results + + def summarize(self): + if not dist.is_main_process(): + return + + if self.fixed_ap: + self._summarize_fixed() + else: + self._summarize_standard() + + def _summarize_standard(self): + json_path = os.path.join(self.out_path, "results.json") + print("dumping to ", json_path) + with open(json_path, "w") as f: + json.dump(self.results, f) + + print("dumped") + + def _summarize_fixed(self): + results = [] + + missing_dets_cats = set() + for cat, cat_anns in self.by_cat.items(): + if len(cat_anns) < self.topk: + missing_dets_cats.add(cat) + results.extend(sorted(cat_anns, key=lambda x: x["score"], reverse=True)[: self.topk]) + if missing_dets_cats: + print( + f"\n===\n" + f"{len(missing_dets_cats)} classes had less than {self.topk} detections!\n" + f"Outputting {self.topk} detections for each class will improve AP further.\n" + f"If using detectron2, please use the lvdevil/infer_topk.py script to " + f"output a results file with {self.topk} detections for each class.\n" + f"===" + ) + + json_path = os.path.join(self.out_path, "results.json") + print("dumping to ", json_path) + with open(json_path, "w") as f: + json.dump(results, f) + + print("dumped") + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def 
create_common_lvis_eval(lvis_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + lvis_eval.eval_imgs = eval_imgs + lvis_eval.params.img_ids = img_ids + +def lvis_evaluation(): + pass \ No newline at end of file diff --git a/maskrcnn_benchmark/data/datasets/evaluation/od_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/od_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ea88105b4480c4398ad6ab0864bd291fdf47ff --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/__init__.py @@ -0,0 +1,20 @@ +from .od_eval import do_od_evaluation + + +def od_to_grounding_evaluation( + dataset, + predictions, + output_folder, + box_only=False, + iou_types=("bbox",), + expected_results=(), + expected_results_sigma_tol=4, ): + return do_od_evaluation( + dataset=dataset, + predictions=predictions, + box_only=box_only, + output_folder=output_folder, + iou_types=iou_types, + expected_results=expected_results, + expected_results_sigma_tol=expected_results_sigma_tol, + ) diff --git a/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/od_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/od_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..443e0f7a2c70a48fedb54c9902a93a3ded15fcd0 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/od_to_grounding/od_eval.py @@ -0,0 +1,532 @@ +import logging +import tempfile +import os +import torch +import numpy as np +import json + +from collections import OrderedDict +from tqdm import tqdm + +from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou + + +def do_od_evaluation( + dataset, + predictions, + box_only, + output_folder, + iou_types, + expected_results, + expected_results_sigma_tol, +): + logger = logging.getLogger("maskrcnn_benchmark.inference") + + if box_only: + logger.info("Evaluating bbox proposals") + if dataset.coco is None and output_folder: + json_results = prepare_for_tsv_detection(predictions, dataset) + with open(os.path.join(output_folder, "box_proposals.json"), "w") as f: + json.dump(json_results, f) + return None + areas = {"all": "", "small": "s", "medium": "m", "large": "l"} + res = COCOResults("box_proposal") + for limit in [100, 1000]: + for area, suffix in areas.items(): + stats = evaluate_box_proposals( + predictions, dataset, area=area, limit=limit + ) + key = "AR{}@{:d}".format(suffix, limit) + res.results["box_proposal"][key] = stats["ar"].item() + logger.info(res) + check_expected_results(res, expected_results, expected_results_sigma_tol) + if output_folder: + torch.save(res, os.path.join(output_folder, "box_proposals.pth")) + return res, predictions + logger.info("Preparing results for COCO format") + coco_results = {} + if "bbox" in iou_types: + logger.info("Preparing bbox results") + if dataset.coco is None: + coco_results["bbox"] = prepare_for_tsv_detection(predictions, dataset) + else: + coco_results["bbox"] = prepare_for_coco_detection(predictions, dataset) + if "segm" in iou_types: + logger.info("Preparing segm 
results") + coco_results["segm"] = prepare_for_coco_segmentation(predictions, dataset) + if 'keypoints' in iou_types: + logger.info('Preparing keypoints results') + coco_results['keypoints'] = prepare_for_coco_keypoint(predictions, dataset) + + results = COCOResults(*iou_types) + logger.info("Evaluating predictions") + for iou_type in iou_types: + with tempfile.NamedTemporaryFile() as f: + file_path = f.name + if output_folder: + file_path = os.path.join(output_folder, iou_type + ".json") + if dataset.coco: + res = evaluate_predictions_on_coco( + dataset.coco, coco_results[iou_type], file_path, iou_type + ) + results.update(res) + elif output_folder: + with open(file_path, "w") as f: + json.dump(coco_results[iou_type], f) + + logger.info(results) + check_expected_results(results, expected_results, expected_results_sigma_tol) + if output_folder: + torch.save(results, os.path.join(output_folder, "coco_results.pth")) + return results, coco_results + + +def prepare_for_tsv_detection(predictions, dataset): + # assert isinstance(dataset, COCODataset) + proposal_results = [] + image_list = [] + for im_id, prediction in enumerate(predictions): + image_info = dataset.get_img_info(im_id) + if len(prediction) == 0: + continue + + # TODO replace with get_img_info? + image_id = image_info["id"] + image_width = image_info["width"] + image_height = image_info["height"] + prediction = prediction.resize((image_width, image_height)) + prediction = prediction.convert("xywh") + + boxes = prediction.bbox.tolist() + scores = prediction.get_field("scores").tolist() + labels = prediction.get_field("labels").tolist() + if prediction.has_field("centers"): + centers = prediction.get_field("centers") + else: + centers = None + + for k, box in enumerate(boxes): + proposal = { + "image_id": image_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + "area": image_width * image_height, + "iscrowd": 0, + } + if centers is not None: + proposal.update(center=centers[k].tolist()) + proposal_results.append(proposal) + + image_list.append(image_info) + + # categories = [{'supercategory': 'proposal', 'id': 0, 'name': 'proposal'}] + return dict(images=image_list, annotations=proposal_results) + + +def prepare_for_coco_detection(predictions, dataset): + # assert isinstance(dataset, COCODataset) + coco_results = [] + for image_id, prediction in enumerate(predictions): + original_id = dataset.id_to_img_map[image_id] + if len(prediction) == 0: + continue + + # TODO replace with get_img_info? 
+ image_width = dataset.coco.imgs[original_id]["width"] + image_height = dataset.coco.imgs[original_id]["height"] + prediction = prediction.resize((image_width, image_height)) + prediction = prediction.convert("xywh") + + boxes = prediction.bbox.tolist() + scores = prediction.get_field("scores").tolist() + labels = prediction.get_field("labels").tolist() + + for k, box in enumerate(boxes): + if labels[k] in dataset.contiguous_category_id_to_json_id: + coco_results.append( + { + "image_id": original_id, + "category_id": dataset.contiguous_category_id_to_json_id[labels[k]], + "bbox": box, + "score": scores[k], + }) + + return coco_results + + +def prepare_for_coco_segmentation(predictions, dataset): + import pycocotools.mask as mask_util + import numpy as np + + masker = Masker(threshold=0.5, padding=1) + # assert isinstance(dataset, COCODataset) + coco_results = [] + for image_id, prediction in tqdm(enumerate(predictions)): + original_id = dataset.id_to_img_map[image_id] + if len(prediction) == 0: + continue + + # TODO replace with get_img_info? + image_width = dataset.coco.imgs[original_id]["width"] + image_height = dataset.coco.imgs[original_id]["height"] + prediction = prediction.resize((image_width, image_height)) + masks = prediction.get_field("mask") + # t = time.time() + # Masker is necessary only if masks haven't been already resized. + if list(masks.shape[-2:]) != [image_height, image_width]: + masks = masker(masks.expand(1, -1, -1, -1, -1), prediction) + masks = masks[0] + # logger.info('Time mask: {}'.format(time.time() - t)) + # prediction = prediction.convert('xywh') + + # boxes = prediction.bbox.tolist() + scores = prediction.get_field("scores").tolist() + labels = prediction.get_field("labels").tolist() + + # rles = prediction.get_field('mask') + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels] + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": mapped_labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + +def prepare_for_coco_keypoint(predictions, dataset): + # assert isinstance(dataset, COCODataset) + coco_results = [] + for image_id, prediction in enumerate(predictions): + original_id = dataset.id_to_img_map[image_id] + if len(prediction.bbox) == 0: + continue + + # TODO replace with get_img_info? + image_width = dataset.coco.imgs[original_id]['width'] + image_height = dataset.coco.imgs[original_id]['height'] + prediction = prediction.resize((image_width, image_height)) + prediction = prediction.convert('xywh') + + boxes = prediction.bbox.tolist() + scores = prediction.get_field('scores').tolist() + labels = prediction.get_field('labels').tolist() + keypoints = prediction.get_field('keypoints') + keypoints = keypoints.resize((image_width, image_height)) + keypoints = keypoints.to_coco_format() + + mapped_labels = [dataset.contiguous_category_id_to_json_id[i] for i in labels] + + coco_results.extend([{ + 'image_id': original_id, + 'category_id': mapped_labels[k], + 'keypoints': keypoint, + 'score': scores[k]} for k, keypoint in enumerate(keypoints)]) + return coco_results + + +# inspired from Detectron +def evaluate_box_proposals( + predictions, dataset, thresholds=None, area="all", limit=None +): + """Evaluate detection proposal recall metrics. 
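+    Proposals are greedily matched to ground-truth boxes by IoU, and the returned
+    "ar" is recall averaged over IoU thresholds 0.5:0.95 in steps of 0.05.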
This function is a much + faster alternative to the official COCO API recall evaluation code. However, + it produces slightly different results. + """ + # Record max overlap value for each gt box + # Return vector of overlap values + areas = { + "all": 0, + "small": 1, + "medium": 2, + "large": 3, + "96-128": 4, + "128-256": 5, + "256-512": 6, + "512-inf": 7, + } + area_ranges = [ + [0 ** 2, 1e5 ** 2], # all + [0 ** 2, 32 ** 2], # small + [32 ** 2, 96 ** 2], # medium + [96 ** 2, 1e5 ** 2], # large + [96 ** 2, 128 ** 2], # 96-128 + [128 ** 2, 256 ** 2], # 128-256 + [256 ** 2, 512 ** 2], # 256-512 + [512 ** 2, 1e5 ** 2], + ] # 512-inf + assert area in areas, "Unknown area range: {}".format(area) + area_range = area_ranges[areas[area]] + gt_overlaps = [] + num_pos = 0 + + for image_id, prediction in enumerate(predictions): + original_id = dataset.id_to_img_map[image_id] + + # TODO replace with get_img_info? + image_width = dataset.coco.imgs[original_id]["width"] + image_height = dataset.coco.imgs[original_id]["height"] + prediction = prediction.resize((image_width, image_height)) + + # sort predictions in descending order + # TODO maybe remove this and make it explicit in the documentation + if prediction.has_field("objectness"): + inds = prediction.get_field("objectness").sort(descending=True)[1] + else: + inds = prediction.get_field("scores").sort(descending=True)[1] + prediction = prediction[inds] + + ann_ids = dataset.coco.getAnnIds(imgIds=original_id) + anno = dataset.coco.loadAnns(ann_ids) + gt_boxes = [obj["bbox"] for obj in anno if obj["iscrowd"] == 0] + gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes + gt_boxes = BoxList(gt_boxes, (image_width, image_height), mode="xywh").convert( + "xyxy" + ) + gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0]) + + if len(gt_boxes) == 0: + continue + + valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1]) + gt_boxes = gt_boxes[valid_gt_inds] + + num_pos += len(gt_boxes) + + if len(gt_boxes) == 0: + continue + + if len(prediction) == 0: + continue + + if limit is not None and len(prediction) > limit: + prediction = prediction[:limit] + + overlaps = boxlist_iou(prediction, gt_boxes) + + _gt_overlaps = torch.zeros(len(gt_boxes)) + for j in range(min(len(prediction), len(gt_boxes))): + # find which proposal box maximally covers each gt box + # and get the iou amount of coverage for each gt box + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # find which gt box is 'best' covered (i.e. 
'best' = most iou) + gt_ovr, gt_ind = max_overlaps.max(dim=0) + assert gt_ovr >= 0 + # find the proposal box that covers the best covered gt box + box_ind = argmax_overlaps[gt_ind] + # record the iou coverage of this gt box + _gt_overlaps[j] = overlaps[box_ind, gt_ind] + assert _gt_overlaps[j] == gt_ovr + # mark the proposal box and the gt box as used + overlaps[box_ind, :] = -1 + overlaps[:, gt_ind] = -1 + + # append recorded iou coverage level + gt_overlaps.append(_gt_overlaps) + + if len(gt_overlaps) == 0: + return { + "ar": torch.zeros(1), + "recalls": torch.zeros(1), + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } + + gt_overlaps = torch.cat(gt_overlaps, dim=0) + gt_overlaps, _ = torch.sort(gt_overlaps) + + if thresholds is None: + step = 0.05 + thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32) + recalls = torch.zeros_like(thresholds) + # compute recall for each iou threshold + for i, t in enumerate(thresholds): + recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos) + # ar = 2 * np.trapz(recalls, thresholds) + ar = recalls.mean() + return { + "ar": ar, + "recalls": recalls, + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } + + +def evaluate_predictions_on_coco( + coco_gt, coco_results, json_result_file, iou_type="bbox" +): + import json + + with open(json_result_file, "w") as f: + json.dump(coco_results, f) + + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + coco_dt = coco_gt.loadRes(str(json_result_file)) if coco_results else COCO() + + # coco_dt = coco_gt.loadRes(coco_results) + if iou_type == 'keypoints': + coco_gt = filter_valid_keypoints(coco_gt, coco_dt) + coco_eval = COCOeval(coco_gt, coco_dt, iou_type) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + if iou_type == 'bbox': + summarize_per_category(coco_eval, json_result_file.replace('.json', '.csv')) + return coco_eval + + +def summarize_per_category(coco_eval, csv_output=None): + ''' + Compute and display summary metrics for evaluation results. + Note this functin can *only* be applied on the default parameter setting + ''' + + def _summarize(iouThr=None, areaRng='all', maxDets=100): + p = coco_eval.params + titleStr = 'Average Precision' + typeStr = '(AP)' + iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \ + if iouThr is None else '{:0.2f}'.format(iouThr) + result_str = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ], '. 
\ + format(titleStr, typeStr, iouStr, areaRng, maxDets) + + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + + # dimension of precision: [TxRxKxAxM] + s = coco_eval.eval['precision'] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, :, aind, mind] + + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + # cacluate AP(average precision) for each category + num_classes = len(p.catIds) + avg_ap = 0.0 + for i in range(0, num_classes): + result_str += '{}, '.format(np.mean(s[:, :, i, :])) + avg_ap += np.mean(s[:, :, i, :]) + result_str += ('{} \n'.format(avg_ap / num_classes)) + return result_str + + id2name = {} + for _, cat in coco_eval.cocoGt.cats.items(): + id2name[cat['id']] = cat['name'] + title_str = 'metric, ' + for cid in coco_eval.params.catIds: + title_str += '{}, '.format(id2name[cid]) + title_str += 'avg \n' + + results = [title_str] + results.append(_summarize()) + results.append(_summarize(iouThr=.5, maxDets=coco_eval.params.maxDets[2])) + results.append(_summarize(areaRng='small', maxDets=coco_eval.params.maxDets[2])) + results.append(_summarize(areaRng='medium', maxDets=coco_eval.params.maxDets[2])) + results.append(_summarize(areaRng='large', maxDets=coco_eval.params.maxDets[2])) + + with open(csv_output, 'w') as f: + for result in results: + f.writelines(result) + + +def filter_valid_keypoints(coco_gt, coco_dt): + kps = coco_dt.anns[1]['keypoints'] + for id, ann in coco_gt.anns.items(): + ann['keypoints'][2::3] = [a * b for a, b in zip(ann['keypoints'][2::3], kps[2::3])] + ann['num_keypoints'] = sum(ann['keypoints'][2::3]) + return coco_gt + + +class COCOResults(object): + METRICS = { + "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "box_proposal": [ + "AR@100", + "ARs@100", + "ARm@100", + "ARl@100", + "AR@1000", + "ARs@1000", + "ARm@1000", + "ARl@1000", + ], + "keypoints": ["AP", "AP50", "AP75", "APm", "APl"], + } + + def __init__(self, *iou_types): + allowed_types = ("box_proposal", "bbox", "segm", "keypoints") + assert all(iou_type in allowed_types for iou_type in iou_types) + results = OrderedDict() + for iou_type in iou_types: + results[iou_type] = OrderedDict( + [(metric, -1) for metric in COCOResults.METRICS[iou_type]] + ) + self.results = results + + def update(self, coco_eval): + if coco_eval is None: + return + from pycocotools.cocoeval import COCOeval + + assert isinstance(coco_eval, COCOeval) + s = coco_eval.stats + iou_type = coco_eval.params.iouType + res = self.results[iou_type] + metrics = COCOResults.METRICS[iou_type] + for idx, metric in enumerate(metrics): + res[metric] = s[idx] + + def __repr__(self): + # TODO make it pretty + return repr(self.results) + + +def check_expected_results(results, expected_results, sigma_tol): + if not expected_results: + return + + logger = logging.getLogger("maskrcnn_benchmark.inference") + for task, metric, (mean, std) in expected_results: + actual_val = results.results[task][metric] + lo = mean - sigma_tol * std + hi = mean + sigma_tol * std + ok = (lo < actual_val) and (actual_val < hi) + msg = ( + "{} > {} sanity check (actual vs. expected): " + "{:.3f} vs. 
mean={:.4f}, std={:.4}, range=({:.4f}, {:.4f})" + ).format(task, metric, actual_val, mean, std, lo, hi) + if not ok: + msg = "FAIL: " + msg + logger.error(msg) + else: + msg = "PASS: " + msg + logger.info(msg) + diff --git a/maskrcnn_benchmark/data/datasets/evaluation/vg/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/vg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef18b3e5e9b007018fd7c839c7d053c48c2984d3 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/vg/__init__.py @@ -0,0 +1,16 @@ +import logging + +from .vg_eval import do_vg_evaluation + + +def vg_evaluation(dataset, predictions, output_folder, box_only, eval_attributes=False, **_): + logger = logging.getLogger("maskrcnn_benchmark.inference") + logger.info("performing vg evaluation, ignored iou_types.") + return do_vg_evaluation( + dataset=dataset, + predictions=predictions, + output_folder=output_folder, + box_only=box_only, + eval_attributes=eval_attributes, + logger=logger, + ) diff --git a/maskrcnn_benchmark/data/datasets/evaluation/vg/vg_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/vg/vg_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..cb20fc9f69f1d70efa65eb9e88bab95d438f2b51 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/vg/vg_eval.py @@ -0,0 +1,672 @@ +# A modification version from chainercv repository. +# (See https://github.com/chainer/chainercv/blob/master/chainercv/evaluations/eval_detection_voc.py) +from __future__ import division + +import os +from collections import OrderedDict +import numpy as np +import torch +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou, getUnionBBox + + +# inspired from Detectron +def evaluate_box_proposals( + predictions, dataset, thresholds=None, area="all", limit=None +): + """Evaluate detection proposal recall metrics. This function is a much + faster alternative to the official COCO API recall evaluation code. However, + it produces slightly different results. 
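+    Ground truth comes from dataset.get_groundtruth(image_id); proposals are ranked
+    by their "objectness" (or "scores") field, greedily matched to gt boxes by IoU,
+    and the returned "ar" is recall averaged over IoU thresholds 0.5:0.95:0.05.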
+ """ + # Record max overlap value for each gt box + # Return vector of overlap values + areas = { + "all": 0, + "small": 1, + "medium": 2, + "large": 3, + "96-128": 4, + "128-256": 5, + "256-512": 6, + "512-inf": 7, + } + area_ranges = [ + [0 ** 2, 1e5 ** 2], # all + [0 ** 2, 32 ** 2], # small + [32 ** 2, 96 ** 2], # medium + [96 ** 2, 1e5 ** 2], # large + [96 ** 2, 128 ** 2], # 96-128 + [128 ** 2, 256 ** 2], # 128-256 + [256 ** 2, 512 ** 2], # 256-512 + [512 ** 2, 1e5 ** 2], + ] # 512-inf + assert area in areas, "Unknown area range: {}".format(area) + area_range = area_ranges[areas[area]] + gt_overlaps = [] + num_pos = 0 + + for image_id, prediction in enumerate(predictions): + img_info = dataset.get_img_info(image_id) + image_width = img_info["width"] + image_height = img_info["height"] + prediction = prediction.resize((image_width, image_height)) + + # deal with ground truth + gt_boxes = dataset.get_groundtruth(image_id) + # filter out the field "relations" + gt_boxes = gt_boxes.copy_with_fields(['attributes', 'labels']) + gt_areas = gt_boxes.area() + + if len(gt_boxes) == 0: + continue + + valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1]) + gt_boxes = gt_boxes[valid_gt_inds] + + num_pos += len(gt_boxes) + + if len(gt_boxes) == 0: + continue + + # sort predictions in descending order + # TODO maybe remove this and make it explicit in the documentation + _gt_overlaps = torch.zeros(len(gt_boxes)) + if len(prediction) == 0: + gt_overlaps.append(_gt_overlaps) + continue + if "objectness" in prediction.extra_fields: + inds = prediction.get_field("objectness").sort(descending=True)[1] + elif "scores" in prediction.extra_fields: + inds = prediction.get_field("scores").sort(descending=True)[1] + else: + raise ValueError("Neither objectness nor scores is in the extra_fields!") + prediction = prediction[inds] + + if limit is not None and len(prediction) > limit: + prediction = prediction[:limit] + + overlaps = boxlist_iou(prediction, gt_boxes) + + for j in range(min(len(prediction), len(gt_boxes))): + # find which proposal box maximally covers each gt box + # and get the iou amount of coverage for each gt box + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # find which gt box is 'best' covered (i.e. 
'best' = most iou) + gt_ovr, gt_ind = max_overlaps.max(dim=0) + assert gt_ovr >= 0 + # find the proposal box that covers the best covered gt box + box_ind = argmax_overlaps[gt_ind] + # record the iou coverage of this gt box + _gt_overlaps[j] = overlaps[box_ind, gt_ind] + assert _gt_overlaps[j] == gt_ovr + # mark the proposal box and the gt box as used + overlaps[box_ind, :] = -1 + overlaps[:, gt_ind] = -1 + + # append recorded iou coverage level + gt_overlaps.append(_gt_overlaps) + gt_overlaps = torch.cat(gt_overlaps, dim=0) + gt_overlaps, _ = torch.sort(gt_overlaps) + + if thresholds is None: + step = 0.05 + thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32) + recalls = torch.zeros_like(thresholds) + # compute recall for each iou threshold + for i, t in enumerate(thresholds): + recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos) + # ar = 2 * np.trapz(recalls, thresholds) + ar = recalls.mean() + return { + "ar": ar, + "recalls": recalls, + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } + + +class VGResults(object): + METRICS = { + "bbox": ["AP",], + "segm": ["AP",], + "box_proposal": ["AR@100",], + } + + def __init__(self, iou_type, value): + allowed_types = ("box_proposal", "bbox", "segm", "keypoints") + assert iou_type in allowed_types + results = OrderedDict() + results[iou_type] = OrderedDict([(metric, value) for metric in VGResults.METRICS[iou_type]]) + self.results = results + + +def do_vg_evaluation(dataset, predictions, output_folder, box_only, eval_attributes, logger, save_predictions=True): + # TODO need to make the use_07_metric format available + # for the user to choose + # we use int for box_only. 0: False, 1: box for RPN, 2: box for object detection, + if box_only: + if box_only == 1: + limits = [100, 1000] + elif box_only == 2: + limits = [36, 99] + else: + raise ValueError("box_only can be either 0/1/2, but get {0}".format(box_only)) + areas = {"all": "", "small": "s", "medium": "m", "large": "l"} + result = {} + for area, suffix in areas.items(): + for limit in limits: + logger.info("Evaluating bbox proposals@{:d}".format(limit)) + stats = evaluate_box_proposals( + predictions, dataset, area=area, limit=limit + ) + key_ar = "AR{}@{:d}".format(suffix, limit) + key_num_pos = "num_pos{}@{:d}".format(suffix, limit) + result[key_num_pos] = stats["num_pos"] + result[key_ar] = stats["ar"].item() + key_recalls = "Recalls{}@{:d}".format(suffix, limit) + # result[key_recalls] = stats["recalls"] + print(key_recalls, stats["recalls"]) + print(key_ar, "ar={:.4f}".format(result[key_ar])) + print(key_num_pos, "num_pos={:d}".format(result[key_num_pos])) + if limit != 1000 and dataset.relation_on: + # if True: + # relation @ 1000 (all and large) takes about 2 hs to compute + # relation pair evaluation + logger.info("Evaluating relation proposals@{:d}".format(limit)) + stats = evaluate_box_proposals_for_relation( + predictions, dataset, area=area, limit=limit + ) + key_ar = "AR{}@{:d}_for_relation".format(suffix, limit) + key_num_pos = "num_pos{}@{:d}_for_relation".format(suffix, limit) + result[key_num_pos] = stats["num_pos"] + result[key_ar] = stats["ar"].item() + # key_recalls = "Recalls{}@{:d}_for_relation".format(suffix, limit) + # result[key_recalls] = stats["recalls"] + print(key_ar, "ar={:.4f}".format(result[key_ar])) + print(key_num_pos, "num_pos={:d}".format(result[key_num_pos])) + logger.info(result) + # check_expected_results(result, expected_results, expected_results_sigma_tol) + if output_folder and 
save_predictions: + if box_only == 1: + torch.save(result, os.path.join(output_folder, "rpn_proposals.pth")) + elif box_only == 2: + torch.save(result, os.path.join(output_folder, "box_proposals.pth")) + else: + raise ValueError("box_only can be either 0/1/2, but get {0}".format(box_only)) + return VGResults('box_proposal', result["AR@100"]), {"box_proposal": result} + + pred_boxlists = [] + gt_boxlists = [] + for image_id, prediction in enumerate(predictions): + img_info = dataset.get_img_info(image_id) + if len(prediction) == 0: + continue + image_width = img_info["width"] + image_height = img_info["height"] + prediction = prediction.resize((image_width, image_height)) + pred_boxlists.append(prediction) + + gt_boxlist = dataset.get_groundtruth(image_id) + gt_boxlists.append(gt_boxlist) + if eval_attributes: + classes = dataset.attributes + else: + classes = dataset.classes + result = eval_detection_voc( + pred_boxlists=pred_boxlists, + gt_boxlists=gt_boxlists, + classes=classes, + iou_thresh=0.5, + eval_attributes=eval_attributes, + use_07_metric=False, + ) + result_str = "mAP: {:.4f}\n".format(result["map"]) + logger.info(result_str) + for i, ap in enumerate(result["ap"]): + # if i == 0: # skip background + # continue + # we skipped background in result['ap'], so we need to use i+1 + if eval_attributes: + result_str += "{:<16}: {:.4f}\n".format( + dataset.map_attribute_id_to_attribute_name(i+1), ap + ) + else: + result_str += "{:<16}: {:.4f}\n".format( + dataset.map_class_id_to_class_name(i+1), ap + ) + # return mAP and weighted mAP + vg_result = VGResults('bbox', result["map"]) + if eval_attributes: + if output_folder and save_predictions: + with open(os.path.join(output_folder, "result_attr.txt"), "w") as fid: + fid.write(result_str) + return vg_result, {"attr": {"map": result["map"], "weighted map": result["weighted map"]}} + else: + if output_folder and save_predictions: + with open(os.path.join(output_folder, "result_obj.txt"), "w") as fid: + fid.write(result_str) + return vg_result, {"obj": {"map": result["map"], "weighted map": result["weighted map"]}}, + + +def eval_detection_voc(pred_boxlists, gt_boxlists, classes, iou_thresh=0.5, eval_attributes=False, use_07_metric=False): + """Evaluate on voc dataset. + Args: + pred_boxlists(list[BoxList]): pred boxlist, has labels and scores fields. + gt_boxlists(list[BoxList]): ground truth boxlist, has labels field. + iou_thresh: iou thresh + use_07_metric: boolean + Returns: + dict represents the results + """ + assert len(gt_boxlists) == len( + pred_boxlists + ), "Length of gt and pred lists need to be same." 
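+    # Per class: compute precision/recall and AP, and pick a score threshold that
+    # maximises prec * rec / (prec + rec) (proportional to F1). The summary reports
+    # both plain mAP and an mAP weighted by the per-class ground-truth count (npos).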
+ + aps = [] + nposs = [] + thresh = [] + + for i, classname in enumerate(classes): + if classname == "__background__" or classname == "__no_attribute__": + continue + rec, prec, ap, scores, npos = calc_detection_voc_prec_rec(pred_boxlists=pred_boxlists, gt_boxlists=gt_boxlists, \ + classindex=i, iou_thresh=iou_thresh, + eval_attributes=eval_attributes, + use_07_metric=use_07_metric) + # Determine per class detection thresholds that maximise f score + # if npos > 1: + if npos > 1 and type(scores) != np.int: + f = np.nan_to_num((prec * rec) / (prec + rec)) + thresh += [scores[np.argmax(f)]] + else: + thresh += [0] + aps += [ap] + nposs += [float(npos)] + # print('AP for {} = {:.4f} (npos={:,})'.format(classname, ap, npos)) + # if pickle: + # with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f: + # cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap, + # 'scores': scores, 'npos':npos}, f) + + # Set thresh to mean for classes with poor results + thresh = np.array(thresh) + avg_thresh = np.mean(thresh[thresh != 0]) + thresh[thresh == 0] = avg_thresh + # if eval_attributes: + # filename = 'attribute_thresholds_' + self._image_set + '.txt' + # else: + # filename = 'object_thresholds_' + self._image_set + '.txt' + # path = os.path.join(output_dir, filename) + # with open(path, 'wt') as f: + # for i, cls in enumerate(classes[1:]): + # f.write('{:s} {:.3f}\n'.format(cls, thresh[i])) + + weights = np.array(nposs) + weights /= weights.sum() + # print('Mean AP = {:.4f}'.format(np.mean(aps))) + # print('Weighted Mean AP = {:.4f}'.format(np.average(aps, weights=weights))) + # print('Mean Detection Threshold = {:.3f}'.format(avg_thresh)) + # print('~~~~~~~~') + # print('Results:') + # for ap, npos in zip(aps, nposs): + # print('{:.3f}\t{:.3f}'.format(ap, npos)) + # print('{:.3f}'.format(np.mean(aps))) + # print('~~~~~~~~') + # print('') + # print('--------------------------------------------------------------') + # print('Results computed with the **unofficial** PASCAL VOC Python eval code.') + # print('--------------------------------------------------------------') + + # pdb.set_trace() + return {"ap": aps, "map": np.mean(aps), "weighted map": np.average(aps, weights=weights)} + + +def calc_detection_voc_prec_rec(pred_boxlists, gt_boxlists, classindex, iou_thresh=0.5, eval_attributes=False, + use_07_metric=False): + """Calculate precision and recall based on evaluation code of PASCAL VOC. + This function calculates precision and recall of + predicted bounding boxes obtained from a dataset which has :math:`N` + images. + The code is based on the evaluation code used in PASCAL VOC Challenge. 
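+    For the single class given by classindex it returns
+    (rec, prec, ap, sorted_scores, npos), where npos is the number of non-difficult
+    ground-truth boxes; if there are no ground truths or no detections it returns
+    zeros in place of the arrays.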
+ """ + class_recs = {} + npos = 0 + image_ids = [] + confidence = [] + BB = [] + for image_index, (gt_boxlist, pred_boxlist) in enumerate(zip(gt_boxlists, pred_boxlists)): + pred_bbox = pred_boxlist.bbox.numpy() + gt_bbox = gt_boxlist.bbox.numpy() + if eval_attributes: + gt_label = gt_boxlist.get_field("attributes").numpy() + pred_label = pred_boxlist.get_field("attr_labels").numpy() + pred_score = pred_boxlist.get_field("attr_scores").numpy() + else: + gt_label = gt_boxlist.get_field("labels").numpy() + pred_label = pred_boxlist.get_field("labels").numpy() + pred_score = pred_boxlist.get_field("scores").numpy() + + # get the ground truth information for this class + if eval_attributes: + gt_mask_l = np.array([classindex in i for i in gt_label]) + else: + gt_mask_l = gt_label == classindex + gt_bbox_l = gt_bbox[gt_mask_l] + gt_difficult_l = np.zeros(gt_bbox_l.shape[0], dtype=bool) + det = [False] * gt_bbox_l.shape[0] + npos = npos + sum(~gt_difficult_l) + class_recs[image_index] = {'bbox': gt_bbox_l, + 'difficult': gt_difficult_l, + 'det': det} + + # prediction output for each class + # pdb.set_trace() + if eval_attributes: + pred_mask_l = np.logical_and(pred_label == classindex, np.not_equal(pred_score, 0.0)).nonzero() + pred_bbox_l = pred_bbox[pred_mask_l[0]] + pred_score_l = pred_score[pred_mask_l] + else: + pred_mask_l = pred_label == classindex + pred_bbox_l = pred_bbox[pred_mask_l] + pred_score_l = pred_score[pred_mask_l] + + for bbox_tmp, score_tmp in zip(pred_bbox_l, pred_score_l): + image_ids.append(image_index) + confidence.append(float(score_tmp)) + BB.append([float(z) for z in bbox_tmp]) + + if npos == 0: + # No ground truth examples + return 0, 0, 0, 0, npos + + if len(confidence) == 0: + # No detection examples + return 0, 0, 0, 0, npos + + confidence = np.array(confidence) + BB = np.array(BB) + + # sort by confidence + sorted_ind = np.argsort(-confidence) + sorted_scores = -np.sort(-confidence) + BB = BB[sorted_ind, :] + image_ids = [image_ids[x] for x in sorted_ind] + + # go down dets and mark TPs and FPs + nd = len(image_ids) + tp = np.zeros(nd) + fp = np.zeros(nd) + + for d in range(nd): + R = class_recs[image_ids[d]] + bb = BB[d, :].astype(float) + ovmax = -np.inf + BBGT = R['bbox'].astype(float) + + if BBGT.size > 0: + # compute overlaps + # intersection + ixmin = np.maximum(BBGT[:, 0], bb[0]) + iymin = np.maximum(BBGT[:, 1], bb[1]) + ixmax = np.minimum(BBGT[:, 2], bb[2]) + iymax = np.minimum(BBGT[:, 3], bb[3]) + iw = np.maximum(ixmax - ixmin + 1., 0.) + ih = np.maximum(iymax - iymin + 1., 0.) + inters = iw * ih + + # union + uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) + + (BBGT[:, 2] - BBGT[:, 0] + 1.) * + (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters) + + overlaps = inters / uni + ovmax = np.max(overlaps) + jmax = np.argmax(overlaps) + + if ovmax > iou_thresh: + if not R['difficult'][jmax]: + if not R['det'][jmax]: + tp[d] = 1. + R['det'][jmax] = 1 + else: + fp[d] = 1. + else: + fp[d] = 1. + + # compute precision recall + fp = np.cumsum(fp) + tp = np.cumsum(tp) + rec = tp / float(npos) + # avoid divide by zero in case the first detection matches a difficult + # ground truth + prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) + ap = voc_ap(rec, prec, use_07_metric) + + return rec, prec, ap, sorted_scores, npos + + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). 
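+    With use_07_metric=False the precision envelope is integrated exactly over
+    recall; for example rec = [0.5, 1.0], prec = [1.0, 0.5] gives
+    ap = 0.5 * 1.0 + 0.5 * 0.5 = 0.75.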
+ """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + + +def calc_detection_voc_ap(prec, rec, use_07_metric=False): + """Calculate average precisions based on evaluation code of PASCAL VOC. + This function calculates average precisions + from given precisions and recalls. + The code is based on the evaluation code used in PASCAL VOC Challenge. + Args: + prec (list of numpy.array): A list of arrays. + :obj:`prec[l]` indicates precision for class :math:`l`. + If :obj:`prec[l]` is :obj:`None`, this function returns + :obj:`numpy.nan` for class :math:`l`. + rec (list of numpy.array): A list of arrays. + :obj:`rec[l]` indicates recall for class :math:`l`. + If :obj:`rec[l]` is :obj:`None`, this function returns + :obj:`numpy.nan` for class :math:`l`. + use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric + for calculating average precision. The default value is + :obj:`False`. + Returns: + ~numpy.ndarray: + This function returns an array of average precisions. + The :math:`l`-th value corresponds to the average precision + for class :math:`l`. If :obj:`prec[l]` or :obj:`rec[l]` is + :obj:`None`, the corresponding value is set to :obj:`numpy.nan`. + """ + + n_fg_class = len(prec) + ap = np.empty(n_fg_class) + for l in range(n_fg_class): + if prec[l] is None or rec[l] is None: + ap[l] = np.nan + continue + + if use_07_metric: + # 11 point metric + ap[l] = 0 + for t in np.arange(0.0, 1.1, 0.1): + if np.sum(rec[l] >= t) == 0: + p = 0 + else: + p = np.max(np.nan_to_num(prec[l])[rec[l] >= t]) + ap[l] += p / 11 + else: + # correct AP calculation + # first append sentinel values at the end + mpre = np.concatenate(([0], np.nan_to_num(prec[l]), [0])) + mrec = np.concatenate(([0], rec[l], [1])) + + mpre = np.maximum.accumulate(mpre[::-1])[::-1] + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap[l] = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + + return ap + + +# inspired from Detectron +def evaluate_box_proposals_for_relation( + predictions, dataset, thresholds=None, area="all", limit=None +): + """Evaluate how many relation pairs can be captured by the proposed boxes. 
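+    A ground-truth (subject, object) pair is considered recalled when an ordered
+    pair of proposals covers both boxes; the pair overlap is taken as
+    min(IoU with subject, IoU with object) and matched greedily as above.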
+ """ + # Record max overlap value for each gt box + # Return vector of overlap values + areas = { + "all": 0, + "small": 1, + "medium": 2, + "large": 3, + "96-128": 4, + "128-256": 5, + "256-512": 6, + "512-inf": 7, + } + area_ranges = [ + [0 ** 2, 1e5 ** 2], # all + [0 ** 2, 32 ** 2], # small + [32 ** 2, 96 ** 2], # medium + [96 ** 2, 1e5 ** 2], # large + [96 ** 2, 128 ** 2], # 96-128 + [128 ** 2, 256 ** 2], # 128-256 + [256 ** 2, 512 ** 2], # 256-512 + [512 ** 2, 1e5 ** 2], + ] # 512-inf + assert area in areas, "Unknown area range: {}".format(area) + area_range = area_ranges[areas[area]] + gt_overlaps = [] + num_pos = 0 + + for image_id, prediction in enumerate(predictions): + img_info = dataset.get_img_info(image_id) + image_width = img_info["width"] + image_height = img_info["height"] + prediction = prediction.resize((image_width, image_height)) + + # deal with ground truth + gt_boxes = dataset.get_groundtruth(image_id) + # filter out the field "relation_labels" + gt_triplets = gt_boxes.get_field("relation_labels") + if len(gt_triplets) == 0: + continue + gt_boxes = gt_boxes.copy_with_fields(['attributes', 'labels']) + # get union bounding boxes (the box that cover both) + gt_relations = getUnionBBox(gt_boxes[gt_triplets[:, 0]], gt_boxes[gt_triplets[:, 1]], margin=0) + gt_relations.add_field('rel_classes', gt_triplets[:, 2]) + # focus on the range interested + gt_relation_areas = gt_relations.area() + valid_gt_inds = (gt_relation_areas >= area_range[0]) & (gt_relation_areas <= area_range[1]) + gt_relations = gt_relations[valid_gt_inds] + + num_pos += len(gt_relations) + + if len(gt_relations) == 0: + continue + + # sort predictions in descending order and limit to the number we specify + # TODO maybe remove this and make it explicit in the documentation + _gt_overlaps = torch.zeros(len(gt_relations)) + if len(prediction) == 0: + gt_overlaps.append(_gt_overlaps) + continue + if "objectness" in prediction.extra_fields: + inds = prediction.get_field("objectness").sort(descending=True)[1] + elif "scores" in prediction.extra_fields: + inds = prediction.get_field("scores").sort(descending=True)[1] + else: + raise ValueError("Neither objectness nor scores is in the extra_fields!") + prediction = prediction[inds] + if limit is not None and len(prediction) > limit: + prediction = prediction[:limit] + # get the predicted relation pairs + N = len(prediction) + map_x = np.arange(N) + map_y = np.arange(N) + map_x_g, map_y_g = np.meshgrid(map_x, map_y) + anchor_pairs = torch.from_numpy(np.vstack((map_y_g.ravel(), map_x_g.ravel())).transpose()) + # remove diagonal pairs + keep = anchor_pairs[:, 0] != anchor_pairs[:, 1] + anchor_pairs = anchor_pairs[keep] + # get anchor_relations + # anchor_relations = getUnionBBox(prediction[anchor_pairs[:,0]], prediction[anchor_pairs[:,1]], margin=0) + if len(anchor_pairs) == 0: + continue + + overlaps_sub = boxlist_iou(prediction[anchor_pairs[:, 0]], gt_boxes[gt_triplets[valid_gt_inds, 0]]) + overlaps_obj = boxlist_iou(prediction[anchor_pairs[:, 1]], gt_boxes[gt_triplets[valid_gt_inds, 1]]) + overlaps = torch.min(overlaps_sub, overlaps_obj) + + for j in range(min(len(anchor_pairs), len(gt_relations))): + # find which proposal box maximally covers each gt box + # and get the iou amount of coverage for each gt box + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # find which gt box is 'best' covered (i.e. 
'best' = most iou) + gt_ovr, gt_ind = max_overlaps.max(dim=0) + assert gt_ovr >= 0 + # find the proposal pair that covers the best covered gt pair + pair_ind = argmax_overlaps[gt_ind] + # record the co-iou coverage of this gt pair + _gt_overlaps[j] = overlaps[pair_ind, gt_ind] + assert _gt_overlaps[j] == gt_ovr + # mark the proposal pair and the gt pair as used + overlaps[pair_ind, :] = -1 + overlaps[:, gt_ind] = -1 + + # append recorded iou coverage level + gt_overlaps.append(_gt_overlaps) + gt_overlaps = torch.cat(gt_overlaps, dim=0) + gt_overlaps, _ = torch.sort(gt_overlaps) + + if thresholds is None: + step = 0.05 + thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32) + recalls = torch.zeros_like(thresholds) + # compute recall for each iou threshold + for i, t in enumerate(thresholds): + recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos) + # ar = 2 * np.trapz(recalls, thresholds) + ar = recalls.mean() + return { + "ar": ar, + "recalls": recalls, + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } diff --git a/maskrcnn_benchmark/data/datasets/evaluation/voc/__init__.py b/maskrcnn_benchmark/data/datasets/evaluation/voc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c26048b361ddd41b6e82d4bb9d5ead745f6bb07 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/voc/__init__.py @@ -0,0 +1,16 @@ +import logging + +from .voc_eval import do_voc_evaluation + + +def voc_evaluation(dataset, predictions, output_folder, box_only, **_): + logger = logging.getLogger("maskrcnn_benchmark.inference") + if box_only: + logger.warning("voc evaluation doesn't support box_only, ignored.") + logger.info("performing voc evaluation, ignored iou_types.") + return do_voc_evaluation( + dataset=dataset, + predictions=predictions, + output_folder=output_folder, + logger=logger, + ) diff --git a/maskrcnn_benchmark/data/datasets/evaluation/voc/voc_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/voc/voc_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ac54d768d458861ca994dc7de1fa37f166ed012b --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/evaluation/voc/voc_eval.py @@ -0,0 +1,216 @@ +# A modification version from chainercv repository. 
+# (See https://github.com/chainer/chainercv/blob/master/chainercv/evaluations/eval_detection_voc.py) +from __future__ import division + +import os +from collections import defaultdict +import numpy as np +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou + + +def do_voc_evaluation(dataset, predictions, output_folder, logger): + # TODO need to make the use_07_metric format available + # for the user to choose + pred_boxlists = [] + gt_boxlists = [] + for image_id, prediction in enumerate(predictions): + img_info = dataset.get_img_info(image_id) + if len(prediction) == 0: + continue + image_width = img_info["width"] + image_height = img_info["height"] + prediction = prediction.resize((image_width, image_height)) + pred_boxlists.append(prediction) + + gt_boxlist = dataset.get_groundtruth(image_id) + gt_boxlists.append(gt_boxlist) + result = eval_detection_voc( + pred_boxlists=pred_boxlists, + gt_boxlists=gt_boxlists, + iou_thresh=0.5, + use_07_metric=True, + ) + result_str = "mAP: {:.4f}\n".format(result["map"]) + for i, ap in enumerate(result["ap"]): + if i == 0: # skip background + continue + result_str += "{:<16}: {:.4f}\n".format( + dataset.map_class_id_to_class_name(i), ap + ) + logger.info(result_str) + if output_folder: + with open(os.path.join(output_folder, "result.txt"), "w") as fid: + fid.write(result_str) + return result + + +def eval_detection_voc(pred_boxlists, gt_boxlists, iou_thresh=0.5, use_07_metric=False): + """Evaluate on voc dataset. + Args: + pred_boxlists(list[BoxList]): pred boxlist, has labels and scores fields. + gt_boxlists(list[BoxList]): ground truth boxlist, has labels field. + iou_thresh: iou thresh + use_07_metric: boolean + Returns: + dict represents the results + """ + assert len(gt_boxlists) == len( + pred_boxlists + ), "Length of gt and pred lists need to be same." + prec, rec = calc_detection_voc_prec_rec( + pred_boxlists=pred_boxlists, gt_boxlists=gt_boxlists, iou_thresh=iou_thresh + ) + ap = calc_detection_voc_ap(prec, rec, use_07_metric=use_07_metric) + return {"ap": ap, "map": np.nanmean(ap)} + + +def calc_detection_voc_prec_rec(gt_boxlists, pred_boxlists, iou_thresh=0.5): + """Calculate precision and recall based on evaluation code of PASCAL VOC. + This function calculates precision and recall of + predicted bounding boxes obtained from a dataset which has :math:`N` + images. + The code is based on the evaluation code used in PASCAL VOC Challenge. 
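+    Args:
+        gt_boxlists (list[BoxList]): ground-truth boxes with "labels" and "difficult" fields.
+        pred_boxlists (list[BoxList]): predicted boxes with "labels" and "scores" fields.
+        iou_thresh (float): IoU threshold for matching a prediction to a ground truth.
+    Returns:
+        tuple: two lists (prec, rec) indexed by class id. An entry stays None for class
+        ids that never appear, and rec[l] is None when class l has no non-difficult gt.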
+ """ + n_pos = defaultdict(int) + score = defaultdict(list) + match = defaultdict(list) + for gt_boxlist, pred_boxlist in zip(gt_boxlists, pred_boxlists): + pred_bbox = pred_boxlist.bbox.numpy() + pred_label = pred_boxlist.get_field("labels").numpy() + pred_score = pred_boxlist.get_field("scores").numpy() + gt_bbox = gt_boxlist.bbox.numpy() + gt_label = gt_boxlist.get_field("labels").numpy() + gt_difficult = gt_boxlist.get_field("difficult").numpy() + + for l in np.unique(np.concatenate((pred_label, gt_label)).astype(int)): + pred_mask_l = pred_label == l + pred_bbox_l = pred_bbox[pred_mask_l] + pred_score_l = pred_score[pred_mask_l] + # sort by score + order = pred_score_l.argsort()[::-1] + pred_bbox_l = pred_bbox_l[order] + pred_score_l = pred_score_l[order] + + gt_mask_l = gt_label == l + gt_bbox_l = gt_bbox[gt_mask_l] + gt_difficult_l = gt_difficult[gt_mask_l] + + n_pos[l] += np.logical_not(gt_difficult_l).sum() + score[l].extend(pred_score_l) + + if len(pred_bbox_l) == 0: + continue + if len(gt_bbox_l) == 0: + match[l].extend((0,) * pred_bbox_l.shape[0]) + continue + + # VOC evaluation follows integer typed bounding boxes. + pred_bbox_l = pred_bbox_l.copy() + pred_bbox_l[:, 2:] += 1 + gt_bbox_l = gt_bbox_l.copy() + gt_bbox_l[:, 2:] += 1 + iou = boxlist_iou( + BoxList(pred_bbox_l, gt_boxlist.size), + BoxList(gt_bbox_l, gt_boxlist.size), + ).numpy() + gt_index = iou.argmax(axis=1) + # set -1 if there is no matching ground truth + gt_index[iou.max(axis=1) < iou_thresh] = -1 + del iou + + selec = np.zeros(gt_bbox_l.shape[0], dtype=bool) + for gt_idx in gt_index: + if gt_idx >= 0: + if gt_difficult_l[gt_idx]: + match[l].append(-1) + else: + if not selec[gt_idx]: + match[l].append(1) + else: + match[l].append(0) + selec[gt_idx] = True + else: + match[l].append(0) + + n_fg_class = max(n_pos.keys()) + 1 + prec = [None] * n_fg_class + rec = [None] * n_fg_class + + for l in n_pos.keys(): + score_l = np.array(score[l]) + match_l = np.array(match[l], dtype=np.int8) + + order = score_l.argsort()[::-1] + match_l = match_l[order] + + tp = np.cumsum(match_l == 1) + fp = np.cumsum(match_l == 0) + + # If an element of fp + tp is 0, + # the corresponding element of prec[l] is nan. + prec[l] = tp / (fp + tp) + # If n_pos[l] is 0, rec[l] is None. + if n_pos[l] > 0: + rec[l] = tp / n_pos[l] + + return prec, rec + + +def calc_detection_voc_ap(prec, rec, use_07_metric=False): + """Calculate average precisions based on evaluation code of PASCAL VOC. + This function calculates average precisions + from given precisions and recalls. + The code is based on the evaluation code used in PASCAL VOC Challenge. + Args: + prec (list of numpy.array): A list of arrays. + :obj:`prec[l]` indicates precision for class :math:`l`. + If :obj:`prec[l]` is :obj:`None`, this function returns + :obj:`numpy.nan` for class :math:`l`. + rec (list of numpy.array): A list of arrays. + :obj:`rec[l]` indicates recall for class :math:`l`. + If :obj:`rec[l]` is :obj:`None`, this function returns + :obj:`numpy.nan` for class :math:`l`. + use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric + for calculating average precision. The default value is + :obj:`False`. + Returns: + ~numpy.ndarray: + This function returns an array of average precisions. + The :math:`l`-th value corresponds to the average precision + for class :math:`l`. If :obj:`prec[l]` or :obj:`rec[l]` is + :obj:`None`, the corresponding value is set to :obj:`numpy.nan`. 
+ """ + + n_fg_class = len(prec) + ap = np.empty(n_fg_class) + for l in range(n_fg_class): + if prec[l] is None or rec[l] is None: + ap[l] = np.nan + continue + + if use_07_metric: + # 11 point metric + ap[l] = 0 + for t in np.arange(0.0, 1.1, 0.1): + if np.sum(rec[l] >= t) == 0: + p = 0 + else: + p = np.max(np.nan_to_num(prec[l])[rec[l] >= t]) + ap[l] += p / 11 + else: + # correct AP calculation + # first append sentinel values at the end + mpre = np.concatenate(([0], np.nan_to_num(prec[l]), [0])) + mrec = np.concatenate(([0], rec[l], [1])) + + mpre = np.maximum.accumulate(mpre[::-1])[::-1] + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap[l] = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + + return ap diff --git a/maskrcnn_benchmark/data/datasets/flickr.py b/maskrcnn_benchmark/data/datasets/flickr.py new file mode 100644 index 0000000000000000000000000000000000000000..fe71a932182f0cb88385e990c7f0c22342ef5fbf --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/flickr.py @@ -0,0 +1,8 @@ +import torch +import torchvision +import torch.utils.data as data +from maskrcnn_benchmark.data.datasets.modulated_coco import ModulatedDataset + + +class FlickrDataset(ModulatedDataset): + pass diff --git a/maskrcnn_benchmark/data/datasets/gqa.py b/maskrcnn_benchmark/data/datasets/gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..98d906cf9c9cb7e4d5d2ad17923398b25f11d9f6 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/gqa.py @@ -0,0 +1,91 @@ +import json +from pathlib import Path + +import torch +import torchvision + +from .modulated_coco import ConvertCocoPolysToMask, ModulatedDataset + + +class GQADataset(ModulatedDataset): + pass + + +class GQAQuestionAnswering(torchvision.datasets.CocoDetection): + def __init__(self, img_folder, ann_file, transforms, return_masks, return_tokens, tokenizer, ann_folder): + super(GQAQuestionAnswering, self).__init__(img_folder, ann_file) + self._transforms = transforms + self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer) + with open(ann_folder / "gqa_answer2id.json", "r") as f: + self.answer2id = json.load(f) + with open(ann_folder / "gqa_answer2id_by_type.json", "r") as f: + self.answer2id_by_type = json.load(f) + self.type2id = {"obj": 0, "attr": 1, "rel": 2, "global": 3, "cat": 4} + + def __getitem__(self, idx): + img, target = super(GQAQuestionAnswering, self).__getitem__(idx) + image_id = self.ids[idx] + coco_img = self.coco.loadImgs(image_id)[0] + caption = coco_img["caption"] + dataset_name = coco_img["dataset_name"] + questionId = coco_img["questionId"] + target = {"image_id": image_id, "annotations": target, "caption": caption} + img, target = self.prepare(img, target) + if self._transforms is not None: + img, target = self._transforms(img, target) + target["dataset_name"] = dataset_name + target["questionId"] = questionId + + if coco_img["answer"] not in self.answer2id: + answer = "unknown" + else: + answer = coco_img["answer"] + + target["answer"] = torch.as_tensor(self.answer2id[answer], dtype=torch.long) + target["answer_type"] = torch.as_tensor(self.type2id[coco_img["question_type"]], dtype=torch.long) + + if coco_img["answer"] not in self.answer2id_by_type["answer_attr"]: + answer = "unknown" + else: + answer = coco_img["answer"] + target["answer_attr"] = torch.as_tensor( + self.answer2id_by_type["answer_attr"][answer] if coco_img["question_type"] == "attr" else 
-100, + dtype=torch.long, + ) + + if coco_img["answer"] not in self.answer2id_by_type["answer_global"]: + answer = "unknown" + else: + answer = coco_img["answer"] + target["answer_global"] = torch.as_tensor( + self.answer2id_by_type["answer_global"][answer] if coco_img["question_type"] == "global" else -100, + dtype=torch.long, + ) + + if coco_img["answer"] not in self.answer2id_by_type["answer_rel"]: + answer = "unknown" + else: + answer = coco_img["answer"] + target["answer_rel"] = torch.as_tensor( + self.answer2id_by_type["answer_rel"][answer] if coco_img["question_type"] == "rel" else -100, + dtype=torch.long, + ) + + if coco_img["answer"] not in self.answer2id_by_type["answer_cat"]: + answer = "unknown" + else: + answer = coco_img["answer"] + target["answer_cat"] = torch.as_tensor( + self.answer2id_by_type["answer_cat"][answer] if coco_img["question_type"] == "cat" else -100, + dtype=torch.long, + ) + + if coco_img["answer"] not in self.answer2id_by_type["answer_obj"]: + answer = "unknown" + else: + answer = coco_img["answer"] + target["answer_obj"] = torch.as_tensor( + self.answer2id_by_type["answer_obj"][answer] if coco_img["question_type"] == "obj" else -100, + dtype=torch.long, + ) + return img, target diff --git a/maskrcnn_benchmark/data/datasets/imagenet.py b/maskrcnn_benchmark/data/datasets/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..723ea7dcc89fc3cb2bc68664e3ede90a0083b3b3 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/imagenet.py @@ -0,0 +1,63 @@ +import os +import os.path +import json +from PIL import Image + +import torch.utils.data as data + +def pil_loader(path): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + +class ImageNet(data.Dataset): + """ ImageNet + + Args: + root (string): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + """ + + def __init__(self, ann_file, root, remove_images_without_annotations=None, transforms=None): + + + self.root = root + self.transform = transforms + + meta_file = os.path.join(root, ann_file) + assert os.path.exists(meta_file), 'meta file %s under root %s not found' % (os.path.basename(meta_file), root) + + with open(meta_file, 'r') as f: + meta = json.load(f) + + self.classes = meta['classes'] + self.class_to_idx = meta['class_to_idx'] + self.samples = meta['samples'] + self.num_sample = len(self.samples) + self.allsamples = self.samples + + def select_class(self, cls): + new_samples = [sample for sample in self.allsamples if sample[-1] in cls] + self.samples = new_samples + self.num_sample = len(self.samples) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. 
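+            Note: the dataset index is also returned as a third element.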
+ """ + img_path, target = self.samples[index] + sample = pil_loader(self.root + '/' + img_path) + if self.transform is not None: + sample = self.transform(sample) + + return sample, target, index + + def __len__(self): + return len(self.samples) \ No newline at end of file diff --git a/maskrcnn_benchmark/data/datasets/list_dataset.py b/maskrcnn_benchmark/data/datasets/list_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a4f47fc08c8317ade1a762cf4070b6d16a3edf --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/list_dataset.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Simple dataset class that wraps a list of path names +""" + +from PIL import Image + +from maskrcnn_benchmark.structures.bounding_box import BoxList + + +class ListDataset(object): + def __init__(self, image_lists, transforms=None): + self.image_lists = image_lists + self.transforms = transforms + + def __getitem__(self, item): + img = Image.open(self.image_lists[item]).convert("RGB") + + # dummy target + w, h = img.size + target = BoxList([[0, 0, w, h]], img.size, mode="xyxy") + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.image_lists) + + def get_img_info(self, item): + """ + Return the image dimensions for the image, without + loading and pre-processing it + """ + pass diff --git a/maskrcnn_benchmark/data/datasets/lvis.py b/maskrcnn_benchmark/data/datasets/lvis.py new file mode 100644 index 0000000000000000000000000000000000000000..753bcbc836a855f967403b78f4e843a86ce77e39 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/lvis.py @@ -0,0 +1,268 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import json +import os +import time +from collections import defaultdict + +import pycocotools.mask as mask_utils +import torchvision +from PIL import Image + +# from .coco import ConvertCocoPolysToMask, make_coco_transforms +from .modulated_coco import ConvertCocoPolysToMask + + +def _isArrayLike(obj): + return hasattr(obj, "__iter__") and hasattr(obj, "__len__") + + +class LVIS: + def __init__(self, annotation_path=None): + """Class for reading and visualizing annotations. 
+ Args: + annotation_path (str): location of annotation file + """ + self.anns = {} + self.cats = {} + self.imgs = {} + self.img_ann_map = defaultdict(list) + self.cat_img_map = defaultdict(list) + self.dataset = {} + + if annotation_path is not None: + print("Loading annotations.") + + tic = time.time() + self.dataset = self._load_json(annotation_path) + print("Done (t={:0.2f}s)".format(time.time() - tic)) + + assert type(self.dataset) == dict, "Annotation file format {} not supported.".format(type(self.dataset)) + self._create_index() + + def _load_json(self, path): + with open(path, "r") as f: + return json.load(f) + + def _create_index(self): + print("Creating index.") + + self.img_ann_map = defaultdict(list) + self.cat_img_map = defaultdict(list) + + self.anns = {} + self.cats = {} + self.imgs = {} + + for ann in self.dataset["annotations"]: + self.img_ann_map[ann["image_id"]].append(ann) + self.anns[ann["id"]] = ann + + for img in self.dataset["images"]: + self.imgs[img["id"]] = img + + for cat in self.dataset["categories"]: + self.cats[cat["id"]] = cat + + for ann in self.dataset["annotations"]: + self.cat_img_map[ann["category_id"]].append(ann["image_id"]) + + print("Index created.") + + def get_ann_ids(self, img_ids=None, cat_ids=None, area_rng=None): + """Get ann ids that satisfy given filter conditions. + Args: + img_ids (int array): get anns for given imgs + cat_ids (int array): get anns for given cats + area_rng (float array): get anns for a given area range. e.g [0, inf] + Returns: + ids (int array): integer array of ann ids + """ + if img_ids is not None: + img_ids = img_ids if _isArrayLike(img_ids) else [img_ids] + if cat_ids is not None: + cat_ids = cat_ids if _isArrayLike(cat_ids) else [cat_ids] + anns = [] + if img_ids is not None: + for img_id in img_ids: + anns.extend(self.img_ann_map[img_id]) + else: + anns = self.dataset["annotations"] + + # return early if no more filtering required + if cat_ids is None and area_rng is None: + return [_ann["id"] for _ann in anns] + + cat_ids = set(cat_ids) + + if area_rng is None: + area_rng = [0, float("inf")] + + ann_ids = [ + _ann["id"] + for _ann in anns + if _ann["category_id"] in cat_ids and _ann["area"] > area_rng[0] and _ann["area"] < area_rng[1] + ] + return ann_ids + + def get_cat_ids(self): + """Get all category ids. + Returns: + ids (int array): integer array of category ids + """ + return list(self.cats.keys()) + + def get_img_ids(self): + """Get all img ids. + Returns: + ids (int array): integer array of image ids + """ + return list(self.imgs.keys()) + + def _load_helper(self, _dict, ids): + if ids is None: + return list(_dict.values()) + elif _isArrayLike(ids): + return [_dict[id] for id in ids] + else: + return [_dict[ids]] + + def load_anns(self, ids=None): + """Load anns with the specified ids. If ids=None load all anns. + Args: + ids (int array): integer array of annotation ids + Returns: + anns (dict array) : loaded annotation objects + """ + return self._load_helper(self.anns, ids) + + def load_cats(self, ids): + """Load categories with the specified ids. If ids=None load all + categories. + Args: + ids (int array): integer array of category ids + Returns: + cats (dict array) : loaded category dicts + """ + return self._load_helper(self.cats, ids) + + def load_imgs(self, ids): + """Load categories with the specified ids. If ids=None load all images. 
+ Args: + ids (int array): integer array of image ids + Returns: + imgs (dict array) : loaded image dicts + """ + return self._load_helper(self.imgs, ids) + + def download(self, save_dir, img_ids=None): + """Download images from mscoco.org server. + Args: + save_dir (str): dir to save downloaded images + img_ids (int array): img ids of images to download + """ + imgs = self.load_imgs(img_ids) + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + for img in imgs: + file_name = os.path.join(save_dir, img["file_name"]) + if not os.path.exists(file_name): + from urllib.request import urlretrieve + + urlretrieve(img["coco_url"], file_name) + + def ann_to_rle(self, ann): + """Convert annotation which can be polygons, uncompressed RLE to RLE. + Args: + ann (dict) : annotation object + Returns: + ann (rle) + """ + img_data = self.imgs[ann["image_id"]] + h, w = img_data["height"], img_data["width"] + segm = ann["segmentation"] + if isinstance(segm, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = mask_utils.frPyObjects(segm, h, w) + rle = mask_utils.merge(rles) + elif isinstance(segm["counts"], list): + # uncompressed RLE + rle = mask_utils.frPyObjects(segm, h, w) + else: + # rle + rle = ann["segmentation"] + return rle + + def ann_to_mask(self, ann): + """Convert annotation which can be polygons, uncompressed RLE, or RLE + to binary mask. + Args: + ann (dict) : annotation object + Returns: + binary mask (numpy 2D array) + """ + rle = self.ann_to_rle(ann) + return mask_utils.decode(rle) + + +class LvisDetectionBase(torchvision.datasets.VisionDataset): + def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None): + super(LvisDetectionBase, self).__init__(root, transforms, transform, target_transform) + self.lvis = LVIS(annFile) + self.ids = list(sorted(self.lvis.imgs.keys())) + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. 
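+            Here the annotations come from LVIS.load_anns, and the image path is
+            rebuilt from the last two components of the image's "coco_url".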
+ """ + lvis = self.lvis + img_id = self.ids[index] + ann_ids = lvis.get_ann_ids(img_ids=img_id) + target = lvis.load_anns(ann_ids) + + path = "/".join(self.lvis.load_imgs(img_id)[0]["coco_url"].split("/")[-2:]) + + img = Image.open(os.path.join(self.root, path)).convert("RGB") + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + + def __len__(self): + return len(self.ids) + + +class LvisDetection(LvisDetectionBase): + def __init__(self, img_folder, ann_file, transforms, return_masks=False, **kwargs): + super(LvisDetection, self).__init__(img_folder, ann_file) + self.ann_file = ann_file + self._transforms = transforms + self.prepare = ConvertCocoPolysToMask(return_masks) + + def __getitem__(self, idx): + img, target = super(LvisDetection, self).__getitem__(idx) + image_id = self.ids[idx] + target = {"image_id": image_id, "annotations": target} + img, target = self.prepare(img, target) + if self._transforms is not None: + img = self._transforms(img) + return img, target, idx + + def get_raw_image(self, idx): + img, target = super(LvisDetection, self).__getitem__(idx) + return img + + def categories(self): + id2cat = {c["id"]: c for c in self.lvis.dataset["categories"]} + all_cats = sorted(list(id2cat.keys())) + categories = {} + for l in list(all_cats): + categories[l] = id2cat[l]['name'] + return categories \ No newline at end of file diff --git a/maskrcnn_benchmark/data/datasets/mixed.py b/maskrcnn_benchmark/data/datasets/mixed.py new file mode 100644 index 0000000000000000000000000000000000000000..3aec54451233173b3e9de107593c87feeb8a3691 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/mixed.py @@ -0,0 +1,145 @@ +import os +import os.path +from pathlib import Path +from typing import Any, Callable, Optional, Tuple + +import torch +from maskrcnn_benchmark.structures.bounding_box import BoxList + +from PIL import Image, ImageDraw +from torchvision.datasets.vision import VisionDataset + +from .modulated_coco import ConvertCocoPolysToMask, has_valid_annotation + + +class CustomCocoDetection(VisionDataset): + """Coco-style dataset imported from TorchVision. + It is modified to handle several image sources + + Args: + root_coco (string): Path to the coco images + root_vg (string): Path to the vg images + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry + and returns a transformed version. 
+ """ + + def __init__( + self, + root_coco: str, + root_vg: str, + annFile: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + transforms: Optional[Callable] = None, + ) -> None: + super(CustomCocoDetection, self).__init__(root_coco, transforms, transform, target_transform) + from pycocotools.coco import COCO + + self.coco = COCO(annFile) + self.ids = list(sorted(self.coco.imgs.keys())) + + ids = [] + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + if has_valid_annotation(anno): + ids.append(img_id) + self.ids = ids + + self.root_coco = root_coco + self.root_vg = root_vg + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. + """ + coco = self.coco + img_id = self.ids[index] + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + + img_info = coco.loadImgs(img_id)[0] + path = img_info["file_name"] + dataset = img_info["data_source"] + + cur_root = self.root_coco if dataset == "coco" else self.root_vg + img = Image.open(os.path.join(cur_root, path)).convert("RGB") + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.ids) + + +class MixedDataset(CustomCocoDetection): + """Same as the modulated detection dataset, except with multiple img sources""" + + def __init__(self, + img_folder_coco, + img_folder_vg, + ann_file, + transforms, + return_masks, + return_tokens, + tokenizer=None, + disable_clip_to_image=False, + no_mask_for_gold=False, + max_query_len=256, + **kwargs): + super(MixedDataset, self).__init__(img_folder_coco, img_folder_vg, ann_file) + self._transforms = transforms + self.max_query_len = max_query_len + self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len) + self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} + self.disable_clip_to_image = disable_clip_to_image + self.no_mask_for_gold = no_mask_for_gold + + def __getitem__(self, idx): + img, target = super(MixedDataset, self).__getitem__(idx) + + image_id = self.ids[idx] + caption = self.coco.loadImgs(image_id)[0]["caption"] + anno = {"image_id": image_id, "annotations": target, "caption": caption} + anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))] + if self.no_mask_for_gold: + anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1)) + + img, anno = self.prepare(img, anno) + + # convert to BoxList (bboxes, labels) + boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4) # guard against no boxes + target = BoxList(boxes, img.size, mode="xyxy") + classes = anno["labels"] + target.add_field("labels", classes) + if not self.disable_clip_to_image: + num_boxes = len(boxes) + target = target.clip_to_image(remove_empty=True) + assert len(target.bbox) == num_boxes, "Box removed in MixedDataset!!!" 
+ + if self._transforms is not None: + img, target = self._transforms(img, target) + + # add additional property + for ann in anno: + target.add_field(ann, anno[ann]) + + return img, target, idx + + def get_img_info(self, index): + img_id = self.id_to_img_map[index] + img_data = self.coco.imgs[img_id] + return img_data diff --git a/maskrcnn_benchmark/data/datasets/mixup.py b/maskrcnn_benchmark/data/datasets/mixup.py new file mode 100644 index 0000000000000000000000000000000000000000..110775727526137e5f9af7a85619f6e268b9cdbd --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/mixup.py @@ -0,0 +1,124 @@ +"""Mixup detection dataset wrapper.""" +from __future__ import absolute_import +import numpy as np +import torch +import torch.utils.data as data + + +class MixupDetection(data.Dataset): + """Detection dataset wrapper that performs mixup for normal dataset. + Parameters + ---------- + dataset : mx.gluon.data.Dataset + Gluon dataset object. + mixup : callable random generator, e.g. np.random.uniform + A random mixup ratio sampler, preferably a random generator from numpy.random + A random float will be sampled each time with mixup(*args). + Use None to disable. + *args : list + Additional arguments for mixup random sampler. + """ + def __init__(self, dataset, mixup=None, preproc=None, *args): + super().__init__(dataset.input_dim) + self._dataset = dataset + self.preproc = preproc + self._mixup = mixup + self._mixup_args = args + + def set_mixup(self, mixup=None, *args): + """Set mixup random sampler, use None to disable. + Parameters + ---------- + mixup : callable random generator, e.g. np.random.uniform + A random mixup ratio sampler, preferably a random generator from numpy.random + A random float will be sampled each time with mixup(*args) + *args : list + Additional arguments for mixup random sampler. + """ + self._mixup = mixup + self._mixup_args = args + + def __len__(self): + return len(self._dataset) + + @Dataset.resize_getitem + def __getitem__(self, idx): + self._dataset._input_dim = self.input_dim + # first image + img1, label1, _, _= self._dataset.pull_item(idx) + lambd = 1 + + # draw a random lambda ratio from distribution + if self._mixup is not None: + lambd = max(0, min(1, self._mixup(*self._mixup_args))) + + if lambd >= 1: + weights1 = np.ones((label1.shape[0], 1)) + label1 = np.hstack((label1, weights1)) + height, width, _ = img1.shape + img_info = (width, height) + if self.preproc is not None: + img_o, target_o = self.preproc(img1, label1, self.input_dim) + return img_o, target_o, img_info, idx + + # second image + idx2 = int(np.random.choice(np.delete(np.arange(len(self)), idx))) + img2, label2, _, _ = self._dataset.pull_item(idx2) + + # mixup two images + height = max(img1.shape[0], img2.shape[0]) + width = max(img1.shape[1], img2.shape[1]) + mix_img = np.zeros((height, width, 3),dtype=np.float32) + mix_img[:img1.shape[0], :img1.shape[1], :] = img1.astype(np.float32) * lambd + mix_img[:img2.shape[0], :img2.shape[1], :] += img2.astype(np.float32) * (1. - lambd) + mix_img = mix_img.astype(np.uint8) + + y1 = np.hstack((label1, np.full((label1.shape[0], 1), lambd))) + y2 = np.hstack((label2, np.full((label2.shape[0], 1), 1. 
- lambd))) + mix_label = np.vstack((y1, y2)) + if self.preproc is not None: + mix_img, padded_labels = self.preproc(mix_img, mix_label, self.input_dim) + + img_info = (width, height) + + return mix_img, padded_labels, img_info , idx + + def pull_item(self, idx): + self._dataset._input_dim = self.input_dim + # first image + img1, label1, _, _= self._dataset.pull_item(idx) + lambd = 1 + + # draw a random lambda ratio from distribution + if self._mixup is not None: + lambd = max(0, min(1, self._mixup(*self._mixup_args))) + + if lambd >= 1: + weights1 = np.ones((label1.shape[0], 1)) + label1 = np.hstack((label1, weights1)) + height, width, _ = img1.shape + img_info = (width, height) + if self.preproc is not None: + img_o, target_o = self.preproc(img1, label1, self.input_dim) + return img_o, target_o, img_info, idx + + # second image + idx2 = int(np.random.choice(np.delete(np.arange(len(self)), idx))) + img2, label2 = self._dataset.pull_item(idx2) + + # mixup two images + height = max(img1.shape[0], img2.shape[0]) + width = max(img1.shape[1], img2.shape[1]) + mix_img = np.zeros((height, width, 3),dtype=np.float32) + mix_img[:img1.shape[0], :img1.shape[1], :] = img1.astype(np.float32) * lambd + mix_img[:img2.shape[0], :img2.shape[1], :] += img2.astype(np.float32) * (1. - lambd) + mix_img = mix_img.astype(np.uint8) + + y1 = np.hstack((label1, np.full((label1.shape[0], 1), lambd))) + y2 = np.hstack((label2, np.full((label2.shape[0], 1), 1. - lambd))) + mix_label = np.vstack((y1, y2)) + if self.preproc is not None: + mix_img, padded_labels = self.preproc(mix_img, mix_label, self.input_dim) + + img_info = (width, height) + return mix_img, padded_labels, img_info , idx diff --git a/maskrcnn_benchmark/data/datasets/modulated_coco.py b/maskrcnn_benchmark/data/datasets/modulated_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..23f6d3610a1231bb0ae0c99affe7374b9551df96 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/modulated_coco.py @@ -0,0 +1,654 @@ +import logging +import os +import os.path +import math +from PIL import Image, ImageDraw + +import random +import numpy as np + +import torch +import torchvision +import torch.utils.data as data +from pycocotools import mask as coco_mask + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask +from maskrcnn_benchmark.data.datasets.coco import has_valid_annotation +from .od_to_grounding import convert_od_to_grounding_simple, check_for_positive_overflow, sanity_check_target_after_processing, convert_object_detection_to_grounding_optimized_for_od +import pdb +import json + +class CocoGrounding(torchvision.datasets.CocoDetection): + def __init__(self, + img_folder, + ann_file, + transforms, + return_masks, + return_tokens, + is_train=False, + tokenizer=None, + disable_shuffle=False, + add_detection_prompt=False, + one_hot=False, + disable_clip_to_image=False, + no_minus_one_for_one_hot=False, + separation_tokens=" ", + few_shot=0, + no_mask_for_od=False, + override_category=None, + use_caption_prompt=False, + caption_prompt=None, + max_query_len=256, + special_safeguard_for_coco_grounding=False, + random_sample_negative=-1, + **kwargs + ): + super(CocoGrounding, self).__init__(img_folder, ann_file) + self.ids = sorted(self.ids) + + ids = [] + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = 
self.coco.loadAnns(ann_ids) + if has_valid_annotation(anno): + ids.append(img_id) + + self.ids = ids + + if few_shot: + ids = [] + # cats_freq = [few_shot]*len(self.coco.cats.keys()) + cats_freq = [few_shot]*max(list(self.coco.cats.keys())) + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + cat = set([ann['category_id'] for ann in anno]) #set/tuple corresponde to instance/image level + is_needed = sum([cats_freq[c-1]>0 for c in cat]) + if is_needed: + ids.append(img_id) + for c in cat: + cats_freq[c-1] -= 1 + # print(cat, cats_freq) + self.ids = ids + + + + self.json_category_id_to_contiguous_id = { + v: i + 1 for i, v in enumerate(self.coco.getCatIds()) + } + self.contiguous_category_id_to_json_id = { + v: k for k, v in self.json_category_id_to_contiguous_id.items() + } + + if override_category is not None: + self.coco.dataset["categories"] = override_category + self.use_caption_prompt = use_caption_prompt + self.caption_prompt = caption_prompt + self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding + self.random_sample_negative = random_sample_negative + self.ind_to_class = self.categories(no_background=False) + self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} + self._transforms = transforms + self.max_query_len = max_query_len + self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len) + self.tokenizer = tokenizer + self.is_train = is_train + + self.ind_to_class = self.categories(no_background=False) + + self.disable_shuffle = disable_shuffle + self.add_detection_prompt = add_detection_prompt + self.one_hot = one_hot + self.no_minus_one_for_one_hot = no_minus_one_for_one_hot + + self.disable_clip_to_image = disable_clip_to_image + self.separation_tokens = separation_tokens + self.no_mask_for_od = no_mask_for_od + self.return_masks = return_masks + + def categories(self, no_background=True): + categories = self.coco.dataset["categories"] + label_list = {} + for index, i in enumerate(categories): + # assert(index + 1 == i["id"]) + if not no_background or (i["name"] != "__background__" and i['id'] != 0): + label_list[self.json_category_id_to_contiguous_id[i["id"]]] = i["name"] + return label_list + + def get_box_mask(self, rect, img_size, mode="poly"): + assert mode=="poly", "Only support poly mask right now!" 
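+        # Encode the box as a single rectangular polygon so it can stand in for a missing
+        # segmentation, e.g. rect = [0, 0, 10, 20] -> [[0, 0, 0, 20, 10, 20, 10, 0]]
+        # (corners ordered x1y1, x1y2, x2y2, x2y1).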
+ x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3] + return [[x1, y1, x1, y2, x2, y2, x2, y1]] + + def __getitem__(self, idx): + img, tgt = super(CocoGrounding, self).__getitem__(idx) + image_id = self.ids[idx] + tgt = [obj for obj in tgt if obj["iscrowd"] == 0] + boxes = [obj["bbox"] for obj in tgt] + boxes = torch.as_tensor(boxes).reshape(-1, 4) # guard against no boxes + target = BoxList(boxes, img.size, mode="xywh").convert("xyxy") + classes = [obj["category_id"] for obj in tgt] + classes = [self.json_category_id_to_contiguous_id[c] for c in classes] + classes = torch.tensor(classes) + target.add_field("labels", classes) + + if self.return_masks: + masks = [] + is_box_mask = [] + for obj, bbox in zip(tgt, target.bbox): + if "segmentation" in obj: + masks.append(obj["segmentation"]) + is_box_mask.append(0) + else: + masks.append(self.get_box_mask(bbox, img.size, mode="poly")) + is_box_mask.append(1) + masks = SegmentationMask(masks, img.size, mode="poly") + is_box_mask = torch.tensor(is_box_mask) + target.add_field("masks", masks) + target.add_field("is_box_mask", is_box_mask) + + if not self.disable_clip_to_image: + target = target.clip_to_image(remove_empty=True) + + if self.special_safeguard_for_coco_grounding: + # Intended for LVIS + assert(not self.use_caption_prompt) + + original_box_num = len(target) + target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2) # leave some space for the special tokens + if len(target) < original_box_num: + print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target))) + + annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od( + target=target, + image_id=image_id, + ind_to_class=self.ind_to_class, + disable_shuffle=self.disable_shuffle, + add_detection_prompt=False, + add_detection_prompt_advanced=False, + random_sample_negative=self.random_sample_negative, + control_probabilities=(0.0, 0.0, 1.0, 0.0), # always try to add a lot of negatives + restricted_negative_list=None, + separation_tokens=self.separation_tokens, + max_num_labels=-1, + positive_caption_length=positive_caption_length, + tokenizer=self.tokenizer, + max_seq_length=self.max_query_len-2 + ) + else: + # Intended for COCO / ODinW + annotations, caption, greenlight_span_for_masked_lm_objective = convert_od_to_grounding_simple( + target=target, + image_id=image_id, + ind_to_class=self.ind_to_class, + disable_shuffle=self.disable_shuffle, + add_detection_prompt=self.add_detection_prompt, + separation_tokens=self.separation_tokens, + caption_prompt=self.caption_prompt if self.use_caption_prompt else None, + ) + + anno = {"image_id": image_id, "annotations": annotations, "caption": caption} + anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective + if self.no_mask_for_od: + anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1)) + img, anno = self.prepare(img, anno, box_format="xyxy") + + # for equivalence check + if self.one_hot: + logging.info("using one hot for equivalence check.") + one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float) + text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64) + # create one hot mapping + for ii, cls in enumerate(classes): + if self.no_minus_one_for_one_hot: + one_hot_map[ii, cls] = 1.0 + else: + one_hot_map[ii, cls - 1] = 1.0 + if self.no_minus_one_for_one_hot: + 
text_mask[:] = 1 + else: + text_mask[:len(self.ind_to_class)] = 1 + anno["positive_map"] = one_hot_map + anno["text_mask"] = text_mask + + if self._transforms is not None: + img, target = self._transforms(img, target) + + # add additional property + for ann in anno: + target.add_field(ann, anno[ann]) + + sanity_check_target_after_processing(target) + + return img, target, idx + + def get_img_info(self, index): + img_id = self.id_to_img_map[index] + img_data = self.coco.imgs[img_id] + return img_data + + +class ModulatedDataset(torchvision.datasets.CocoDetection): + def __init__(self, + img_folder, + ann_file, + transforms, + return_masks, + return_tokens, + is_train=False, + tokenizer=None, + disable_clip_to_image=False, + no_mask_for_gold=False, + max_query_len=256, + **kwargs): + super(ModulatedDataset, self).__init__(img_folder, ann_file) + self.ids = sorted(self.ids) + + ids = [] + for img_id in self.ids: + if isinstance(img_id, str): + ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None) + else: + ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = self.coco.loadAnns(ann_ids) + if has_valid_annotation(anno): + ids.append(img_id) + self.ids = ids + + self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} + self._transforms = transforms + self.max_query_len = max_query_len + self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len) + self.is_train = is_train + self.disable_clip_to_image = disable_clip_to_image + self.no_mask_for_gold = no_mask_for_gold + + def __getitem__(self, idx): + img, target = super(ModulatedDataset, self).__getitem__(idx) + image_id = self.ids[idx] + coco_img = self.coco.loadImgs(image_id)[0] + caption = coco_img["caption"] + dataset_name = coco_img["dataset_name"] if "dataset_name" in coco_img else None + anno = {"image_id": image_id, "annotations": target, "caption": caption} + + # This dataset is used for Flickr & Mixed, so the sequence is maskable + anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))] + if self.no_mask_for_gold: + anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1)) + img, anno = self.prepare(img, anno) + + # convert to BoxList (bboxes, labels) + boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4) # guard against no boxes + target = BoxList(boxes, img.size, mode="xyxy") + classes = anno["labels"] + target.add_field("labels", classes) + if self.prepare.return_masks: + target.add_field("masks", anno.pop("masks")) + target.add_field("is_box_mask", anno.pop("is_box_mask")) + if not self.disable_clip_to_image: + num_boxes = len(target.bbox) + target = target.clip_to_image(remove_empty=True) + assert num_boxes == len(target.bbox), "Box got removed in MixedDataset!!!" 
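+            # As in MixedDataset, the fields copied from `anno` below are ordered per box,
+            # so the clip may shrink boxes but must not remove any.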
+ + # Check if bboxes are correct + # draw = ImageDraw.Draw(img) + # boxes = target.bbox + # for box in boxes: + # draw.rectangle([box[0], box[1], box[2], box[3]]) + # img.save('OUTPUT/images/{}.jpg'.format(idx)) + + if self._transforms is not None: + img, target = self._transforms(img, target) + + # add additional property + for ann in anno: + target.add_field(ann, anno[ann]) + + target.add_field("dataset_name", dataset_name) + for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]: + if extra_key in coco_img: + target.add_field(extra_key, coco_img[extra_key]) + + if "tokens_positive_eval" in coco_img and not self.is_train: + tokenized = self.prepare.tokenizer(caption, return_tensors="pt") + target.add_field("positive_map_eval", create_positive_map(tokenized, coco_img["tokens_positive_eval"])) + target.add_field("nb_eval", len(target.get_field("positive_map_eval"))) + + sanity_check_target_after_processing(target) + return img, target, idx + + def get_img_info(self, index): + img_id = self.id_to_img_map[index] + img_data = self.coco.imgs[img_id] + return img_data + + +class CocoDetection(data.Dataset): + """`MS Coco Detection `_ Dataset. + + Args: + root (string): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + def __init__(self, root, annFile, transform=None, target_transform=None): + from pycocotools.coco import COCO + self.root = root + self.coco = COCO(annFile) + self.ids = list(self.coco.imgs.keys()) + self.transform = transform + self.target_transform = target_transform + + def __getitem__(self, index, return_meta=False): + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. + """ + coco = self.coco + img_id = self.ids[index] + if isinstance(img_id, str): + img_id = [img_id] + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + + meta = coco.loadImgs(img_id)[0] + path = meta['file_name'] + img = pil_loader(os.path.join(self.root, path)) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + if return_meta: + return img, target, meta + else: + return img, target + + def __len__(self): + return len(self.ids) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256): + self.return_masks = return_masks + self.return_tokens = return_tokens + self.tokenizer = tokenizer + self.max_query_len = max_query_len + + def get_box_mask(self, rect, img_size, mode="poly"): + assert mode=="poly", "Only support poly mask right now!" 
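+        # Same rectangular-polygon fallback as CocoGrounding.get_box_mask: annotations
+        # without a "segmentation" entry are wrapped as a single 4-corner polygon.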
+ x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3] + return [[x1, y1, x1, y2, x2, y2, x2, y1]] + + def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + caption = target["caption"] if "caption" in target else None + label_to_positions = target.get("label_to_positions", {}) + + greenlight_span_for_masked_lm_objective = target.get("greenlight_span_for_masked_lm_objective", None) + + anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + if box_format == "xywh": + boxes[:, 2:] += boxes[:, :2] - 1 # TO_REMOVE = 1 + boxes[:, 0::2].clamp_(min=0, max=w-1) # TO_REMOVE = 1 + boxes[:, 1::2].clamp_(min=0, max=h-1) # TO_REMOVE = 1 + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + if self.return_masks: + masks = [] + is_box_mask = [] + for obj, bbox in zip(anno, boxes): + if "segmentation" in obj: + masks.append(obj["segmentation"]) + is_box_mask.append(0) + else: + masks.append(self.get_box_mask(bbox, image.size, mode='poly')) + is_box_mask.append(1) + masks = SegmentationMask(masks, image.size, mode='poly') + is_box_mask = torch.tensor(is_box_mask) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + isfinal = None + if anno and "isfinal" in anno[0]: + isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float) + + tokens_positive = [] if self.return_tokens else None + if self.return_tokens and anno and "tokens" in anno[0]: + tokens_positive = [obj["tokens"] for obj in anno] + elif self.return_tokens and anno and "tokens_positive" in anno[0]: + tokens_positive = [obj["tokens_positive"] for obj in anno] + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + if self.return_masks: + masks = masks[keep] + is_box_mask = is_box_mask[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + if caption is not None: + target["caption"] = caption + if self.return_masks: + target["masks"] = masks + target["is_box_mask"] = is_box_mask + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + if tokens_positive is not None: + target["tokens_positive"] = [] + + for i, k in enumerate(keep): + if k or ignore_box_screen: + target["tokens_positive"].append(tokens_positive[i]) + + if isfinal is not None: + target["isfinal"] = isfinal + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + target["area"] = area[keep] + target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + target["size"] = torch.as_tensor([int(h), int(w)]) + + if self.return_tokens and self.tokenizer is not None: + if not ignore_box_screen: + assert len(target["boxes"]) == len(target["tokens_positive"]) + tokenized = self.tokenizer(caption, return_tensors="pt", + max_length=self.max_query_len, + 
truncation=True) + target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"]) + target['greenlight_map'] = create_greenlight_map(greenlight_span_for_masked_lm_objective,tokenized) + target["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions) + + original_od_label = [] + for obj in anno: + original_od_label.append( + obj.get("original_od_label", -10)) # NOTE: The padding value has to be not the same as -1 or -100 + target["original_od_label"] = torch.as_tensor(original_od_label) + + return image, target + +def create_greenlight_map(tok_list, tokenized): + # An example tok_list: + # [(0, 5), (10, 13), (-1, -1, -1)] + # The last one is a special indicator.. + + greenlight_map = torch.zeros(256, dtype=torch.float) + for item in tok_list: + if len(item) != 2: + assert(len(item) == 3) + # Make everything unmakable + greenlight_map[:] = -1 + break + + beg, end = item + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + + assert beg_pos is not None and end_pos is not None + greenlight_map[beg_pos: end_pos + 1].fill_(1) + return greenlight_map + + +def create_positive_map_for_od_labels(tokenized, label_to_positions): + """construct a map such that positive_map[i] = j, where j is the object detection label of the token i""" + """ + {3: [1: 5)} + 256 : -1 3 3 3 3 -1 .. 8 8 .. + the woman in the garden + -1 -1 -1 -1 -1 + """ + positive_map = torch.ones(256, dtype=torch.float) * -1 # -1 means no match + keys = list(label_to_positions.keys()) + for j, key in enumerate(keys): + tok_list = label_to_positions[key] + # one label only mapps to one location + beg, end = tok_list + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + assert beg_pos is not None and end_pos is not None + positive_map[beg_pos: end_pos + 1].fill_(key) + return positive_map + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +def create_positive_map(tokenized, tokens_positive): + """construct a map such that positive_map[i,j] = True iff box i is associated to token j""" + positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) + + for j, tok_list in enumerate(tokens_positive): + for (beg, end) in tok_list: + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + if 
beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + + assert beg_pos is not None and end_pos is not None + positive_map[j, beg_pos: end_pos + 1].fill_(1) + return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) + + +def pil_loader(path, retry=5): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + ri = 0 + while ri < retry: + try: + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + except: + ri += 1 diff --git a/maskrcnn_benchmark/data/datasets/object365.py b/maskrcnn_benchmark/data/datasets/object365.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9bb4aabe13237b9fad229b310be8b50e31727b --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/object365.py @@ -0,0 +1,8 @@ +import torch +import torchvision +import torch.utils.data as data +from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV + + +class Object365DetectionTSV(CocoDetectionTSV): + pass diff --git a/maskrcnn_benchmark/data/datasets/od_to_grounding.py b/maskrcnn_benchmark/data/datasets/od_to_grounding.py new file mode 100644 index 0000000000000000000000000000000000000000..b93aace9ea6e08a0ae8e7d1b8f87729dfcd84bbc --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/od_to_grounding.py @@ -0,0 +1,375 @@ +import numpy as np +import random +import re +import torch +import pdb +import logging + + +def clean_name(name): + name = re.sub(r"\(.*\)", "", name) + name = re.sub(r"_", " ", name) + name = re.sub(r" ", " ", name) + return name + + +def sanity_check_target_after_processing(target): + assert(len(target.bbox) == len(target.extra_fields["boxes"])) + + +def convert_od_to_grounding_simple( + target, + image_id, + ind_to_class, + disable_shuffle=True, + add_detection_prompt=False, + separation_tokens=" ", + caption_prompt=None): + """ + Convert object detection data into grounding data format, on the fly. + ind_to_class: {0: "__background__", 1 : "person" ...}, contiguous id + """ + + def generate_sentence_from_labels(positive_label_list, negative_label_list, disable_shuffle=True): + label_to_positions = {} + label_list = negative_label_list + positive_label_list + if not disable_shuffle: + random.shuffle(label_list) + assert (caption_prompt is None), "Should not specify caption_prompt when shuffle is enabled!!" # avoid potential bug + + if add_detection_prompt: + pheso_caption = "object detection : " + else: + pheso_caption = "" + + + for index, label in enumerate(label_list): + if caption_prompt is not None: + pheso_caption += caption_prompt[index]['prefix'] + + start_index = len(pheso_caption) + if caption_prompt is not None: + pheso_caption += clean_name(caption_prompt[index]['name']) + else: + pheso_caption += clean_name(ind_to_class[label]) # NOTE: slight change... 
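+            # clean_name strips parenthesised qualifiers and underscores,
+            # e.g. "hair_drier" -> "hair drier".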
+ end_index = len(pheso_caption) + + if caption_prompt is not None: + pheso_caption += caption_prompt[index]['suffix'] + + # e.g.: pheso_caption = "cat dog", where cat is label 4, and dog is label 17 + # label_to_positions: {4: (0, 3), 17: (4, 7)} + label_to_positions[label] = [start_index, end_index] + + if index != len(label_list) - 1: + pheso_caption += separation_tokens + + return label_to_positions, pheso_caption + + label_list = list(sorted(ind_to_class.keys())) # do not include the background + label_to_positions, pheso_caption = generate_sentence_from_labels( + positive_label_list=label_list, + negative_label_list=[], + disable_shuffle=disable_shuffle + ) + + new_target = [] + + ''' + Convert into: + {'area': 10506.0, 'iscrowd': 0, 'image_id': 571335, 'category_id': 1, 'id': 2999421, 'bbox': [221, 319, 103, 102], 'tokens_positive': [[0, 3]]} + tokens_positive is the char position + ''' + areas = target.area() + greenlight_span_for_masked_lm_objective = [] + for i in range(len(target)): + new_target_i = {} + new_target_i["area"] = areas[i] + new_target_i["iscrowd"] = 0 + new_target_i["image_id"] = image_id + new_target_i["category_id"] = target.extra_fields["labels"][i].item() + new_target_i["id"] = None + new_target_i['bbox'] = target.bbox[i].numpy().tolist() + + label_i = target.extra_fields["labels"][i].item() + + if label_i in label_to_positions: # NOTE: Only add those that actually appear in the final caption + new_target_i["tokens_positive"] = [label_to_positions[label_i]] + new_target.append(new_target_i) + greenlight_span_for_masked_lm_objective.append(label_to_positions[label_i]) + + return new_target, pheso_caption, greenlight_span_for_masked_lm_objective + + +def check_for_positive_overflow(target, ind_to_class, tokenizer, max_seq_length=256): + # NOTE: Only call this function for OD data; DO NOT USE IT FOR GROUNDING DATA + # NOTE: called only in coco_dt + + # Check if we have too many positive labels + # generate a caption by appending the positive labels + positive_label_set = set() + for i in range(len(target)): + label_i = target.extra_fields["labels"][i].item() + positive_label_set.add(label_i) + positive_label_list = list(positive_label_set) + + # random shuffule so we can sample different annotations at different epochs + random.shuffle(positive_label_list) + + kept_lables = [] + length = 0 + + for index, label in enumerate(positive_label_list): + + label_text = clean_name(ind_to_class[label]) + ". " # "dog. 
" + + tokenized = tokenizer.tokenize(label_text) + + length += len(tokenized) + + if length > max_seq_length: + break + else: + kept_lables.append(label) + + ## filter boxes + keep_box_index = [] + for i in range(len(target)): + label_i = target.extra_fields["labels"][i].item() + if label_i in kept_lables: + keep_box_index.append(i) + + keep_box_index = torch.LongTensor(keep_box_index) + + target = target[keep_box_index] ## filter boxes + + return target, length + + +def convert_object_detection_to_grounding_optimized_for_od( + target, + image_id, + ind_to_class, + disable_shuffle, + add_detection_prompt, + add_detection_prompt_advanced, + random_sample_negative, + control_probabilities, + restricted_negative_list=None, + separation_tokens=" ", + max_num_labels=-1, + max_seq_length=256, + tokenizer=None, + positive_caption_length=0 +): + ''' + ind_to_class: {0: "__background__", 1 : "person" ...} + target: + + restricted_negative_list : for datasets with restricted negatives, sample only the negatives + + Convert object detection data into grounding data format, on the fly. + + Control options: + 1. add_detection_prompt: add "object detection : " to the front of the prompt + 2. num_negatives: randomly sampled negative classes + 3. num_positives: how many positives to keep (-1 means do not cut any) + + Probabilities to generate the control options: + + a. probability_one_negative: only give one negative class to mimic evaluation + b. probability_one_positive: only give one positive class to mimic evaluation + c. probability_full: add both all positive and all negatives + d. other: + randomly sample some negatives and some positives + The below control options are independent of each other: + - probability_random_negative: probability of randomly sample X negatives + - probability_random_positive: probability of randomly sample some positives + ''' + if restricted_negative_list is None: + valid_negative_indexes = list(ind_to_class.keys()) + else: + valid_negative_indexes = restricted_negative_list + + def generate_senetence_given_labels( + positive_label_list, + negative_label_list, + prompt_engineer_version="v2", + disable_shuffle=False, + positive_question_probability=0.6, + negative_question_probability=0.8, + full_question_probability=0.5): + + ''' + v3: with simple prompt such as "there are", "are there?" + v4: try to merge some are there / there are together, to avoid sequence being too long + ''' + + label_to_positions = {} + + assert (prompt_engineer_version == "v2") + num_negatives = len(negative_label_list) + num_positives = len(positive_label_list) + label_list = negative_label_list + positive_label_list + if not disable_shuffle: + random.shuffle(label_list) + + if add_detection_prompt: + if add_detection_prompt_advanced and (num_negatives == 0 or num_positives == 0) and not disable_shuffle: + pheso_caption = "object detection query : " + else: + pheso_caption = "object detection : " + else: + pheso_caption = "" + + for index, label in enumerate(label_list): + + start_index = len(pheso_caption) + + pheso_caption += clean_name(ind_to_class[label]) # NOTE: slight change... 
+ end_index = len(pheso_caption) + + # e.g.: pheso_caption = "cat dog", where cat is label 4, and dog is label 17 + # label_to_positions: {4: (0, 3), 17: (4, 7)} + label_to_positions[label] = [start_index, end_index] + + if index != len(label_list) - 1: + pheso_caption += separation_tokens + + return label_to_positions, pheso_caption + + if disable_shuffle: + label_list = list(sorted(ind_to_class.keys()))[1:] # do not include the background + label_to_positions, pheso_caption = generate_senetence_given_labels( + positive_label_list=label_list, + negative_label_list=[], + disable_shuffle=True) + # print(label_to_positions, pheso_caption) + else: + positive_label_set = set() + for i in range(len(target)): + label_i = target.extra_fields["labels"][i].item() + positive_label_set.add(label_i) + + full_positive = len(positive_label_set) + if max_num_labels <= 0: + full_negative = random_sample_negative + else: + full_negative = max(min(max_num_labels-full_positive, random_sample_negative), 0) + + if full_negative > len(valid_negative_indexes): + full_negative = len(valid_negative_indexes) + + num_negatives, num_positives = generate_control_options_given_probabilities( + control_probabilities=control_probabilities, + full_positive=full_positive, + full_negative=full_negative) + # num_positives not used + + + # Keep some negatives + negative_label_list = set() + if num_negatives != -1: + if num_negatives > len(valid_negative_indexes): + num_negatives = len(valid_negative_indexes) + for i in np.random.choice(valid_negative_indexes, size=num_negatives, replace=False): + # label_sets.add(i) + if i not in positive_label_set: + negative_label_list.add(i) + + # Keep all positives; ignoring num_positives + positive_label_list = list(positive_label_set) + random.shuffle(positive_label_list) + + negative_label_list = list(negative_label_list) # e.g.: [17, 1, 13] where each number is the class name + random.shuffle(negative_label_list) + + # Do a pre-screen. If we cannot afford this many negatives, we will sample less + negative_max_length = max_seq_length - positive_caption_length + screened_negative_label_list = [] + for negative_label in negative_label_list: + label_text = clean_name(ind_to_class[negative_label]) + ". " # "dog. 
" + + tokenized = tokenizer.tokenize(label_text) + + negative_max_length -= len(tokenized) + + if negative_max_length > 0: + screened_negative_label_list.append(negative_label) # keep this negative + else: + break + negative_label_list = screened_negative_label_list + + label_to_positions, pheso_caption = generate_senetence_given_labels( + positive_label_list=positive_label_list, + negative_label_list=negative_label_list) + + new_target = [] + + ''' + Convert into: + {'area': 10506.0, 'iscrowd': 0, 'image_id': 571335, 'category_id': 1, 'id': 2999421, 'bbox': [221, 319, 103, 102], 'tokens_positive': [[0, 3]]} + tokens_positive is the char position + ''' + areas = target.area() + greenlight_span_for_masked_lm_objective = [] + for i in range(len(target)): + new_target_i = {} + new_target_i["area"] = areas[i] + new_target_i["iscrowd"] = 0 + new_target_i["image_id"] = image_id + new_target_i["category_id"] = target.extra_fields["labels"][i].item() + new_target_i["id"] = None + new_target_i['bbox'] = target.bbox[i].numpy().tolist() + + label_i = target.extra_fields["labels"][i].item() + new_target_i["original_od_label"] = label_i + + if label_i in label_to_positions: # NOTE: Only add those that actually appear in the final caption + new_target_i["tokens_positive"] = [label_to_positions[label_i]] + new_target.append(new_target_i) + greenlight_span_for_masked_lm_objective.append(label_to_positions[label_i]) + + return new_target, pheso_caption, greenlight_span_for_masked_lm_objective, label_to_positions + + +def generate_control_options_given_probabilities( + control_probabilities, + full_positive, + full_negative): + + # The function was originally designed to perform data augmentation by randomly dropping negative and positive classes. Later, we decided to only consider dropping negative classes. So the returned 'num_positives' by this function will be ignored. + + outer_prob = random.random() + + probability_one_negative = control_probabilities[0] + probability_one_positive = control_probabilities[1] + probability_full = control_probabilities[2] + probability_drop_positive = control_probabilities[3] + + assert(probability_drop_positive == 0) + + if outer_prob < probability_one_negative: + # a. probability_one_negative: only give one negative class to mimic evaluation (10%) + num_negatives = 1 + num_positives = 0 + elif outer_prob < probability_one_positive + probability_one_negative: + # b. probability_one_positive: only give one positive class to mimic evaluation (10%) + num_negatives = 0 + num_positives = 1 + elif outer_prob < probability_full + probability_one_positive + probability_one_negative: + # c. 
probability_full: add both all positive and all negatives (20%) + num_negatives = full_negative + num_positives = full_positive + else: + if random.random() < 1.0: # - probability_random_negative: probability of randomly sample X negatives (100%) + num_negatives = np.random.choice(max(1, full_negative)) + 1 # mininum 1 + else: + num_negatives = full_negative # Full + + if random.random() < probability_drop_positive: # + num_positives = np.random.choice(max(1, full_positive)) + 1 + else: + num_positives = full_positive # Full + + return num_negatives, num_positives diff --git a/maskrcnn_benchmark/data/datasets/phrasecut.py b/maskrcnn_benchmark/data/datasets/phrasecut.py new file mode 100644 index 0000000000000000000000000000000000000000..2a68262d2372c69ba9e64535014770ce4be98189 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/phrasecut.py @@ -0,0 +1,8 @@ +import torch +import torchvision +import torch.utils.data as data +from maskrcnn_benchmark.data.datasets.modulated_coco import ModulatedDataset + + +class PhrasecutDetection(ModulatedDataset): + pass diff --git a/maskrcnn_benchmark/data/datasets/pseudo_data.py b/maskrcnn_benchmark/data/datasets/pseudo_data.py new file mode 100644 index 0000000000000000000000000000000000000000..70f2ac3e78feed8740f4f9aeec7bf57695f555b3 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/pseudo_data.py @@ -0,0 +1,228 @@ +import torch +import torch.distributed as dist +import time +from torchvision.ops import nms +import random +import numpy as np +from PIL import Image, ImageDraw +import pdb +from maskrcnn_benchmark.structures.bounding_box import BoxList +from .modulated_coco import ConvertCocoPolysToMask +from .tsv import ODTSVDataset, TSVYamlDataset +from .od_to_grounding import sanity_check_target_after_processing +from copy import deepcopy + +class PseudoData(TSVYamlDataset): + def __init__(self, + yaml_file, + transforms, + return_tokens, + return_masks, + tokenizer, + caption_min_box=1, + replace_clean_label=False, + further_screen=False, + caption_conf=0.5, + caption_nms=-1, + pack_random_caption_number=0, + inference_caption=False, + sample_negative_for_grounding_data=-1, + random_pack_prob=-1.0, + no_random_pack_probability=0.0, + safeguard_positive_caption=True, + mlm_obj_for_only_positive=False, + caption_format_version="v1", + local_debug=False, + max_query_len=256, + diver_box_for_vqa=False, + **kwargs + ): + super(PseudoData, self).__init__(yaml_file, None, replace_clean_label) + self.yaml_file = yaml_file + self._transforms = transforms + self.max_query_len = max_query_len + self.prepare = ConvertCocoPolysToMask(return_masks=return_masks, + return_tokens=return_tokens, + tokenizer=tokenizer, + max_query_len=max_query_len) + self.diver_box_for_vqa = diver_box_for_vqa + if "qa" in self.yaml_file: + assert(self.diver_box_for_vqa) # must diver box + self.tokenizer = tokenizer + self.caption_min_box = caption_min_box + self.replace_clean_label = replace_clean_label + self.further_screen = further_screen + self.pack_random_caption_number = pack_random_caption_number + self.caption_format_version = caption_format_version + + self.caption_conf = caption_conf + self.caption_nms = caption_nms + self.inference_caption = inference_caption + self.sample_negative_for_grounding_data = sample_negative_for_grounding_data + self.random_pack_prob = random_pack_prob + self.no_random_pack_probability = no_random_pack_probability + self.safeguard_positive_caption = safeguard_positive_caption + self.mlm_obj_for_only_positive = mlm_obj_for_only_positive + 
self.local_debug = local_debug + try: + self.rank = dist.get_rank() + except: + self.rank = 0 + + def __len__(self): + return super(PseudoData, self).__len__() + + @staticmethod + def check_for_overlap(range1, range2): + if range1[0] > range2[1] or range2[0] > range1[1]: + return False + return True + + def divert_boxes(self, anno): + # first get answer start and end + answer_start = len(anno['text']) + 1 # +1 for the space + answer_end = len(anno["caption"]) + + question = anno["caption"][:answer_start] # get the question + + mask_start = len(question) + # add the mask token + mask_token = self.tokenizer.mask_token + if mask_token is None: + mask_token = 'answer' + question += mask_token + mask_end = len(question) + + # divert the box + for i in range(len(anno["bboxes"])): + # check over lap + for j in range(len(anno["tokens_positive"][i])): + if self.check_for_overlap(anno["tokens_positive"][i][j], [answer_start, answer_end]): + # if overlap, then divert the box to the mask token + anno["tokens_positive"][i][j] = [mask_start, mask_end] + + anno["caption"] = question + return question, anno + + def __getitem__(self, idx): + img, anno, _, scale = super(PseudoData, self).__getitem__(idx) + if self.inference_caption: + caption = None + if isinstance(anno, list): + caption = anno[0]["caption"] # inference mode for bing + anno = [] + elif len(anno) == 1: + caption = anno["caption"] # inference mode for googlecc + anno = [] + else: + caption = " ".join(anno["captions"]) + anno = [] + else: + if self.caption_format_version == "v2": + anno = self.convert_anno_from_yiling_to_ours(anno) + + if self.further_screen: + conf = self.caption_conf + nms_thre = self.caption_nms + + bboxes = torch.as_tensor(anno["bboxes"]).float() + scores = torch.as_tensor(anno["scores"]) + tokens_positive = anno["tokens_positive"] + + keep = scores > conf + scores = scores[keep] + bboxes = bboxes[keep] + tokens_positive = [i for index, i in enumerate(tokens_positive) if keep[index]] + + assert (len(tokens_positive) == len(bboxes) == len(scores)) + + if len(bboxes) < self.caption_min_box: # Retry triggered! + return self[np.random.choice(len(self))] + + if nms_thre > 0: + keep = nms(boxes=bboxes, scores=scores, iou_threshold=nms_thre) + scores = scores[keep] + bboxes = bboxes[keep] + tokens_positive = [tokens_positive[i] for i in keep] + assert (len(tokens_positive) == len(bboxes) == len(scores)) + + # Write back + anno["bboxes"] = bboxes.tolist() + anno["scores"] = scores.tolist() + anno["tokens_positive"] = tokens_positive + + boxes = torch.as_tensor(anno["bboxes"]) + + if len(boxes) < self.caption_min_box: # Retry triggered! 
+ return self[np.random.choice(len(self))] + + target = BoxList(boxes, (anno["img_w"], anno["img_h"]), mode="xyxy") + target = target.clip_to_image(remove_empty=True) + + if self.diver_box_for_vqa: + caption, anno = self.divert_boxes(anno=anno) # will change caption and "tokens_positive" + + caption = anno["caption"] + + greenlight_span_for_masked_lm_objective = [(0, len(caption))] + + new_anno = [] + areas = target.area() + for i in range(len(target)): + new_anno_i = {} + new_anno_i["area"] = areas[i] + new_anno_i["iscrowd"] = 0 + new_anno_i["image_id"] = idx + new_anno_i["category_id"] = 1 # following vg and others + new_anno_i["id"] = None + new_anno_i['bbox'] = target.bbox[i].numpy().tolist() + new_anno_i["tokens_positive"] = anno["tokens_positive"][i] + new_anno.append(new_anno_i) + anno = new_anno + + annotations = {"image_id": idx, "annotations": anno, "caption": caption} + annotations["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective + img, annotations = self.prepare(img, annotations, box_format="xyxy") + + if self._transforms is not None: + img, target = self._transforms(img, target) + + # add additional property + for ann in annotations: + target.add_field(ann, annotations[ann]) + + # This is the real image_id + image_id = self.get_img_id(idx) + # Can insert additional field into target if needed + + sanity_check_target_after_processing(target) + + return img, target, idx + + def convert_anno_from_yiling_to_ours(self, anno): + flatterned_bboxes = [] + flatterned_tokens_positive = [] + flatterned_bboxes_scores = [] + for i in range(len(anno["bboxes"])): + # i is the index for entity + for j in range(len(anno["bboxes"][i])): + # j is the index for each box + flatterned_bboxes.append(anno["bboxes"][i][j]) + flatterned_tokens_positive.append( + anno["tokens_positive"][i]) # Assume this box corresponds to all the token_spans for this entity + flatterned_bboxes_scores.append(anno["scores"][i][j]) + anno["bboxes"] = flatterned_bboxes + anno["tokens_positive"] = flatterned_tokens_positive + anno["scores"] = flatterned_bboxes_scores + return anno + + def get_raw_image(self, idx): + image, *_ = super(PseudoData, self).__getitem__(idx) + return image + + def get_img_id(self, idx): + line_no = self.get_line_no(idx) + if self.label_tsv is not None: + row = self.label_tsv.seek(line_no) + img_id = row[0] + return img_id diff --git a/maskrcnn_benchmark/data/datasets/refexp.py b/maskrcnn_benchmark/data/datasets/refexp.py new file mode 100644 index 0000000000000000000000000000000000000000..a63015aff6919f1c2ea97382bc319f92b742f76a --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/refexp.py @@ -0,0 +1,88 @@ +import copy +from collections import defaultdict +from pathlib import Path + +import torch +import torch.utils.data + +import maskrcnn_benchmark.utils.dist as dist +from maskrcnn_benchmark.layers.set_loss import generalized_box_iou + +from .modulated_coco import ModulatedDataset + + +class RefExpDataset(ModulatedDataset): + pass + + +class RefExpEvaluator(object): + def __init__(self, refexp_gt, iou_types, k=(1, 5, 10), thresh_iou=0.5): + assert isinstance(k, (list, tuple)) + refexp_gt = copy.deepcopy(refexp_gt) + self.refexp_gt = refexp_gt + self.iou_types = iou_types + self.img_ids = self.refexp_gt.imgs.keys() + self.predictions = {} + self.k = k + self.thresh_iou = thresh_iou + + def accumulate(self): + pass + + def update(self, predictions): + self.predictions.update(predictions) + + def synchronize_between_processes(self): + all_predictions = 
dist.all_gather(self.predictions) + merged_predictions = {} + for p in all_predictions: + merged_predictions.update(p) + self.predictions = merged_predictions + + def summarize(self): + if dist.is_main_process(): + dataset2score = { + "refcoco": {k: 0.0 for k in self.k}, + "refcoco+": {k: 0.0 for k in self.k}, + "refcocog": {k: 0.0 for k in self.k}, + } + dataset2count = {"refcoco": 0.0, "refcoco+": 0.0, "refcocog": 0.0} + for image_id in self.img_ids: + ann_ids = self.refexp_gt.getAnnIds(imgIds=image_id) + assert len(ann_ids) == 1 + img_info = self.refexp_gt.loadImgs(image_id)[0] + + target = self.refexp_gt.loadAnns(ann_ids[0]) + prediction = self.predictions[image_id] + assert prediction is not None + sorted_scores_boxes = sorted( + zip(prediction["scores"].tolist(), prediction["boxes"].tolist()), reverse=True + ) + sorted_scores, sorted_boxes = zip(*sorted_scores_boxes) + sorted_boxes = torch.cat([torch.as_tensor(x).view(1, 4) for x in sorted_boxes]) + target_bbox = target[0]["bbox"] + converted_bbox = [ + target_bbox[0], + target_bbox[1], + target_bbox[2] + target_bbox[0], + target_bbox[3] + target_bbox[1], + ] + giou = generalized_box_iou(sorted_boxes, torch.as_tensor(converted_bbox).view(-1, 4)) + for k in self.k: + if max(giou[:k]) >= self.thresh_iou: + dataset2score[img_info["dataset_name"]][k] += 1.0 + dataset2count[img_info["dataset_name"]] += 1.0 + + for key, value in dataset2score.items(): + for k in self.k: + try: + value[k] /= dataset2count[key] + except: + pass + results = {} + for key, value in dataset2score.items(): + results[key] = sorted([v for k, v in value.items()]) + print(f" Dataset: {key} - Precision @ 1, 5, 10: {results[key]} \n") + + return results + return None diff --git a/maskrcnn_benchmark/data/datasets/tsv.py b/maskrcnn_benchmark/data/datasets/tsv.py new file mode 100644 index 0000000000000000000000000000000000000000..64b92fb16631aae21bad47c1569b582ea0b6431e --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/tsv.py @@ -0,0 +1,420 @@ +import os +import os.path as op +import json +# import logging +import base64 +import yaml +import errno +import io +import math +from PIL import Image, ImageDraw + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from .box_label_loader import LabelLoader + + +def load_linelist_file(linelist_file): + if linelist_file is not None: + line_list = [] + with open(linelist_file, 'r') as fp: + for i in fp: + line_list.append(int(i.strip())) + return line_list + + +def img_from_base64(imagestring): + try: + img = Image.open(io.BytesIO(base64.b64decode(imagestring))) + return img.convert('RGB') + except ValueError: + return None + + +def load_from_yaml_file(yaml_file): + with open(yaml_file, 'r') as fp: + return yaml.load(fp, Loader=yaml.CLoader) + + +def find_file_path_in_yaml(fname, root): + if fname is not None: + if op.isfile(fname): + return fname + elif op.isfile(op.join(root, fname)): + return op.join(root, fname) + else: + raise FileNotFoundError( + errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname) + ) + + +def create_lineidx(filein, idxout): + idxout_tmp = idxout + '.tmp' + with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout: + fsize = os.fstat(tsvin.fileno()).st_size + fpos = 0 + while fpos != fsize: + tsvout.write(str(fpos) + "\n") + tsvin.readline() + fpos = tsvin.tell() + os.rename(idxout_tmp, idxout) + + +def read_to_character(fp, c): + result = [] + while True: + s = fp.read(32) + assert s != '' + if c in s: + result.append(s[: s.index(c)]) + break + else: + result.append(s) + 
return ''.join(result) + + +class TSVFile(object): + def __init__(self, tsv_file, generate_lineidx=False): + self.tsv_file = tsv_file + self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' + self._fp = None + self._lineidx = None + # the process always keeps the process which opens the file. + # If the pid is not equal to the currrent pid, we will re-open the file. + self.pid = None + # generate lineidx if not exist + if not op.isfile(self.lineidx) and generate_lineidx: + create_lineidx(self.tsv_file, self.lineidx) + + def __del__(self): + if self._fp: + self._fp.close() + + def __str__(self): + return "TSVFile(tsv_file='{}')".format(self.tsv_file) + + def __repr__(self): + return str(self) + + def num_rows(self): + self._ensure_lineidx_loaded() + return len(self._lineidx) + + def seek(self, idx): + self._ensure_tsv_opened() + self._ensure_lineidx_loaded() + try: + pos = self._lineidx[idx] + except: + # logging.info('{}-{}'.format(self.tsv_file, idx)) + raise + self._fp.seek(pos) + return [s.strip() for s in self._fp.readline().split('\t')] + + def seek_first_column(self, idx): + self._ensure_tsv_opened() + self._ensure_lineidx_loaded() + pos = self._lineidx[idx] + self._fp.seek(pos) + return read_to_character(self._fp, '\t') + + def get_key(self, idx): + return self.seek_first_column(idx) + + def __getitem__(self, index): + return self.seek(index) + + def __len__(self): + return self.num_rows() + + def _ensure_lineidx_loaded(self): + if self._lineidx is None: + # logging.info('loading lineidx: {}'.format(self.lineidx)) + with open(self.lineidx, 'r') as fp: + self._lineidx = [int(i.strip()) for i in fp.readlines()] + + def _ensure_tsv_opened(self): + if self._fp is None: + self._fp = open(self.tsv_file, 'r') + self.pid = os.getpid() + + if self.pid != os.getpid(): + # logging.info('re-open {} because the process id changed'.format(self.tsv_file)) + self._fp = open(self.tsv_file, 'r') + self.pid = os.getpid() + + +class CompositeTSVFile(): + def __init__(self, file_list, seq_file, root='.'): + if isinstance(file_list, str): + self.file_list = load_list_file(file_list) + else: + assert isinstance(file_list, list) + self.file_list = file_list + + self.seq_file = seq_file + self.root = root + self.initialized = False + self.initialize() + + def get_key(self, index): + idx_source, idx_row = self.seq[index] + k = self.tsvs[idx_source].get_key(idx_row) + return '_'.join([self.file_list[idx_source], k]) + + def num_rows(self): + return len(self.seq) + + def __getitem__(self, index): + idx_source, idx_row = self.seq[index] + return self.tsvs[idx_source].seek(idx_row) + + def __len__(self): + return len(self.seq) + + def initialize(self): + ''' + this function has to be called in init function if cache_policy is + enabled. Thus, let's always call it in init funciton to make it simple. + ''' + if self.initialized: + return + self.seq = [] + with open(self.seq_file, 'r') as fp: + for line in fp: + parts = line.strip().split('\t') + self.seq.append([int(parts[0]), int(parts[1])]) + self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list] + self.initialized = True + + +def load_list_file(fname): + with open(fname, 'r') as fp: + lines = fp.readlines() + result = [line.strip() for line in lines] + if len(result) > 0 and result[-1] == '': + result = result[:-1] + return result + + +class TSVDataset(object): + def __init__(self, img_file, label_file=None, hw_file=None, + linelist_file=None, imageid2idx_file=None): + """Constructor. 
+ Args: + img_file: Image file with image key and base64 encoded image str. + label_file: An optional label file with image key and label information. + A label_file is required for training and optional for testing. + hw_file: An optional file with image key and image height/width info. + linelist_file: An optional file with a list of line indexes to load samples. + It is useful to select a subset of samples or duplicate samples. + """ + self.img_file = img_file + self.label_file = label_file + self.hw_file = hw_file + self.linelist_file = linelist_file + + self.img_tsv = TSVFile(img_file) + self.label_tsv = None if label_file is None else TSVFile(label_file, generate_lineidx=True) + self.hw_tsv = None if hw_file is None else TSVFile(hw_file) + self.line_list = load_linelist_file(linelist_file) + self.imageid2idx = None + if imageid2idx_file is not None: + self.imageid2idx = json.load(open(imageid2idx_file, 'r')) + + self.transforms = None + + def __len__(self): + if self.line_list is None: + if self.imageid2idx is not None: + assert self.label_tsv is not None, "label_tsv is None!!!" + return self.label_tsv.num_rows() + return self.img_tsv.num_rows() + else: + return len(self.line_list) + + def __getitem__(self, idx): + img = self.get_image(idx) + img_size = img.size # w, h + annotations = self.get_annotations(idx) + # print(idx, annotations) + target = self.get_target_from_annotations(annotations, img_size, idx) + img, target = self.apply_transforms(img, target) + + if self.transforms is None: + return img, target, idx, 1.0 + else: + new_img_size = img.shape[1:] + scale = math.sqrt(float(new_img_size[0] * new_img_size[1]) / float(img_size[0] * img_size[1])) + return img, target, idx, scale + + def get_line_no(self, idx): + return idx if self.line_list is None else self.line_list[idx] + + def get_image(self, idx): + line_no = self.get_line_no(idx) + if self.imageid2idx is not None: + assert self.label_tsv is not None, "label_tsv is None!!!" + row = self.label_tsv.seek(line_no) + annotations = json.loads(row[1]) + imageid = annotations["img_id"] + line_no = self.imageid2idx[imageid] + row = self.img_tsv.seek(line_no) + # use -1 to support old format with multiple columns. + img = img_from_base64(row[-1]) + return img + + def get_annotations(self, idx): + line_no = self.get_line_no(idx) + if self.label_tsv is not None: + row = self.label_tsv.seek(line_no) + annotations = json.loads(row[1]) + return annotations + else: + return [] + + def get_target_from_annotations(self, annotations, img_size, idx): + # This function will be overwritten by each dataset to + # decode the labels to specific formats for each task. + return annotations + + def apply_transforms(self, image, target=None): + # This function will be overwritten by each dataset to + # apply transforms to image and targets. + return image, target + + def get_img_info(self, idx): + if self.imageid2idx is not None: + assert self.label_tsv is not None, "label_tsv is None!!!" 
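
The TSV classes above get O(1) random access to rows through a sidecar `.lineidx` file of byte offsets. A minimal usage sketch for `TSVFile`; the file name and its two-column layout are assumptions for illustration:

```python
# Sketch: random-access reads from a TSV via the precomputed .lineidx offsets.
from maskrcnn_benchmark.data.datasets.tsv import TSVFile

# "train.img.tsv" is assumed to hold one "<image_key>\t<base64-encoded image>" row per line.
# generate_lineidx=True writes "train.img.lineidx" (one byte offset per row) if it is missing.
tsv = TSVFile("train.img.tsv", generate_lineidx=True)

print(tsv.num_rows())
key, b64_image = tsv.seek(0)   # seek() jumps to the stored offset and splits the row on tabs
print(key, len(b64_image))
```
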
+ line_no = self.get_line_no(idx) + row = self.label_tsv.seek(line_no) + annotations = json.loads(row[1]) + return {"height": int(annotations["img_w"]), "width": int(annotations["img_w"])} + + if self.hw_tsv is not None: + line_no = self.get_line_no(idx) + row = self.hw_tsv.seek(line_no) + try: + # json string format with "height" and "width" being the keys + data = json.loads(row[1]) + if type(data) == list: + return data[0] + elif type(data) == dict: + return data + except ValueError: + # list of strings representing height and width in order + hw_str = row[1].split(' ') + hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])} + return hw_dict + + def get_img_key(self, idx): + line_no = self.get_line_no(idx) + # based on the overhead of reading each row. + if self.imageid2idx is not None: + assert self.label_tsv is not None, "label_tsv is None!!!" + row = self.label_tsv.seek(line_no) + annotations = json.loads(row[1]) + return annotations["img_id"] + + if self.hw_tsv: + return self.hw_tsv.seek(line_no)[0] + elif self.label_tsv: + return self.label_tsv.seek(line_no)[0] + else: + return self.img_tsv.seek(line_no)[0] + + +class TSVYamlDataset(TSVDataset): + """ TSVDataset taking a Yaml file for easy function call + """ + + def __init__(self, yaml_file, root=None, replace_clean_label=False): + print("Reading {}".format(yaml_file)) + self.cfg = load_from_yaml_file(yaml_file) + if root: + self.root = root + else: + self.root = op.dirname(yaml_file) + img_file = find_file_path_in_yaml(self.cfg['img'], self.root) + label_file = find_file_path_in_yaml(self.cfg.get('label', None), + self.root) + hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + imageid2idx_file = find_file_path_in_yaml(self.cfg.get('imageid2idx', None), + self.root) + + if replace_clean_label: + assert ("raw_label" in label_file) + label_file = label_file.replace("raw_label", "clean_label") + + super(TSVYamlDataset, self).__init__( + img_file, label_file, hw_file, linelist_file, imageid2idx_file) + + +class ODTSVDataset(TSVYamlDataset): + """ + Generic TSV dataset format for Object Detection. 
+ """ + + def __init__(self, yaml_file, extra_fields=(), transforms=None, + is_load_label=True, **kwargs): + if yaml_file is None: + return + super(ODTSVDataset, self).__init__(yaml_file) + + self.transforms = transforms + self.is_load_label = is_load_label + self.attribute_on = False + # self.attribute_on = kwargs['args'].MODEL.ATTRIBUTE_ON if "args" in kwargs else False + + if self.is_load_label: + # construct maps + jsondict_file = find_file_path_in_yaml( + self.cfg.get("labelmap", None), self.root + ) + if jsondict_file is None: + jsondict_file = find_file_path_in_yaml( + self.cfg.get("jsondict", None), self.root + ) + if "json" in jsondict_file: + jsondict = json.load(open(jsondict_file, 'r')) + if "label_to_idx" not in jsondict: + jsondict = {'label_to_idx': jsondict} + elif "tsv" in jsondict_file: + label_to_idx = {} + counter = 1 + with open(jsondict_file) as f: + for line in f: + label_to_idx[line.strip()] = counter + counter += 1 + jsondict = {'label_to_idx': label_to_idx} + else: + assert (0) + + self.labelmap = {} + self.class_to_ind = jsondict['label_to_idx'] + self.class_to_ind['__background__'] = 0 + self.ind_to_class = {v: k for k, v in self.class_to_ind.items()} + self.labelmap['class_to_ind'] = self.class_to_ind + + if self.attribute_on: + self.attribute_to_ind = jsondict['attribute_to_idx'] + self.attribute_to_ind['__no_attribute__'] = 0 + self.ind_to_attribute = {v: k for k, v in self.attribute_to_ind.items()} + self.labelmap['attribute_to_ind'] = self.attribute_to_ind + + self.label_loader = LabelLoader( + labelmap=self.labelmap, + extra_fields=extra_fields, + ) + + def get_target_from_annotations(self, annotations, img_size, idx): + if isinstance(annotations, list): + annotations = {"objects": annotations} + if self.is_load_label: + return self.label_loader(annotations['objects'], img_size) + + def apply_transforms(self, img, target=None): + if self.transforms is not None: + img, target = self.transforms(img, target) + return img, target diff --git a/maskrcnn_benchmark/data/datasets/vg.py b/maskrcnn_benchmark/data/datasets/vg.py new file mode 100644 index 0000000000000000000000000000000000000000..c94eacd3ee75346ba06a61efdb0f28ae53b82501 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/vg.py @@ -0,0 +1,267 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+import collections +import json +import os.path as op + +import numpy as np +import torch + +from .tsv import TSVYamlDataset, find_file_path_in_yaml +from .box_label_loader import BoxLabelLoader +from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV + + +class VGDetectionTSV(CocoDetectionTSV): + pass + + +def sort_key_by_val(dic): + sorted_dic = sorted(dic.items(), key=lambda kv: kv[1]) + return [kv[0] for kv in sorted_dic] + + +def bbox_overlaps(anchors, gt_boxes): + """ + anchors: (N, 4) ndarray of float + gt_boxes: (K, 4) ndarray of float + overlaps: (N, K) ndarray of overlap between boxes and query_boxes + """ + N = anchors.size(0) + K = gt_boxes.size(0) + + gt_boxes_area = ((gt_boxes[:, 2] - gt_boxes[:, 0] + 1) * + (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)).view(1, K) + + anchors_area = ((anchors[:, 2] - anchors[:, 0] + 1) * + (anchors[:, 3] - anchors[:, 1] + 1)).view(N, 1) + + boxes = anchors.view(N, 1, 4).expand(N, K, 4) + query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4) + + iw = (torch.min(boxes[:, :, 2], query_boxes[:, :, 2]) - + torch.max(boxes[:, :, 0], query_boxes[:, :, 0]) + 1) + iw[iw < 0] = 0 + + ih = (torch.min(boxes[:, :, 3], query_boxes[:, :, 3]) - + torch.max(boxes[:, :, 1], query_boxes[:, :, 1]) + 1) + ih[ih < 0] = 0 + + ua = anchors_area + gt_boxes_area - (iw * ih) + overlaps = iw * ih / ua + + return overlaps + + +# VG data loader for Danfei Xu's Scene graph focused format. +# todo: if ordering of classes, attributes, relations changed +# todo make sure to re-write the obj_classes.txt/rel_classes.txt files + +def _box_filter(boxes, must_overlap=False): + """ Only include boxes that overlap as possible relations. + If no overlapping boxes, use all of them.""" + overlaps = bbox_overlaps(boxes, boxes).numpy() > 0 + np.fill_diagonal(overlaps, 0) + + all_possib = np.ones_like(overlaps, dtype=np.bool) + np.fill_diagonal(all_possib, 0) + + if must_overlap: + possible_boxes = np.column_stack(np.where(overlaps)) + + if possible_boxes.size == 0: + possible_boxes = np.column_stack(np.where(all_possib)) + else: + possible_boxes = np.column_stack(np.where(all_possib)) + return possible_boxes + + +class VGTSVDataset(TSVYamlDataset): + """ + Generic TSV dataset format for Object Detection. 
+ """ + + def __init__(self, yaml_file, extra_fields=None, transforms=None, + is_load_label=True, filter_duplicate_rels=True, + relation_on=False, cv2_output=False, **kwargs): + if extra_fields is None: + extra_fields = [] + self.transforms = transforms + self.is_load_label = is_load_label + self.relation_on = relation_on + super(VGTSVDataset, self).__init__(yaml_file, cv2_output=cv2_output) + + ignore_attrs = self.cfg.get("ignore_attrs", None) + # construct those maps + jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root) + jsondict = json.load(open(jsondict_file, 'r')) + + # self.linelist_file + if 'train' in op.basename(self.linelist_file): + self.split = "train" + elif 'test' in op.basename(self.linelist_file) \ + or 'val' in op.basename(self.linelist_file) \ + or 'valid' in op.basename(self.linelist_file): + self.split = "test" + else: + raise ValueError("Split must be one of [train, test], but get {}!".format(self.linelist_file)) + self.filter_duplicate_rels = filter_duplicate_rels and self.split == 'train' + + self.class_to_ind = jsondict['label_to_idx'] + self.ind_to_class = jsondict['idx_to_label'] + self.class_to_ind['__background__'] = 0 + self.ind_to_class['0'] = '__background__' + self.classes = sort_key_by_val(self.class_to_ind) + assert (all([self.classes[i] == self.ind_to_class[str(i)] for i in range(len(self.classes))])) + + # writing obj classes to disk for Neural Motif model building. + obj_classes_out_fn = op.splitext(self.label_file)[0] + ".obj_classes.txt" + if not op.isfile(obj_classes_out_fn): + with open(obj_classes_out_fn, 'w') as f: + for item in self.classes: + f.write("%s\n" % item) + + self.attribute_to_ind = jsondict['attribute_to_idx'] + self.ind_to_attribute = jsondict['idx_to_attribute'] + self.attribute_to_ind['__no_attribute__'] = 0 + self.ind_to_attribute['0'] = '__no_attribute__' + self.attributes = sort_key_by_val(self.attribute_to_ind) + assert (all([self.attributes[i] == self.ind_to_attribute[str(i)] for i in range(len(self.attributes))])) + + self.relation_to_ind = jsondict['predicate_to_idx'] + self.ind_to_relation = jsondict['idx_to_predicate'] + self.relation_to_ind['__no_relation__'] = 0 + self.ind_to_relation['0'] = '__no_relation__' + self.relations = sort_key_by_val(self.relation_to_ind) + assert (all([self.relations[i] == self.ind_to_relation[str(i)] for i in range(len(self.relations))])) + + # writing rel classes to disk for Neural Motif Model building. 
+ rel_classes_out_fn = op.splitext(self.label_file)[0] + '.rel_classes.txt' + if not op.isfile(rel_classes_out_fn): + with open(rel_classes_out_fn, 'w') as f: + for item in self.relations: + f.write("%s\n" % item) + + # label map: minus one because we will add one in BoxLabelLoader + self.labelmap = {key: val - 1 for key, val in self.class_to_ind.items()} + labelmap_file = find_file_path_in_yaml(self.cfg.get("labelmap_dec"), self.root) + # self.labelmap_dec = load_labelmap_file(labelmap_file) + if self.is_load_label: + self.label_loader = BoxLabelLoader( + labelmap=self.labelmap, + extra_fields=extra_fields, + ignore_attrs=ignore_attrs + ) + + # get frequency prior for relations + if self.relation_on: + self.freq_prior_file = op.splitext(self.label_file)[0] + ".freq_prior.npy" + if self.split == 'train' and not op.exists(self.freq_prior_file): + print("Computing frequency prior matrix...") + fg_matrix, bg_matrix = self._get_freq_prior() + prob_matrix = fg_matrix.astype(np.float32) + prob_matrix[:, :, 0] = bg_matrix + prob_matrix[:, :, 0] += 1 + prob_matrix /= np.sum(prob_matrix, 2)[:, :, None] + np.save(self.freq_prior_file, prob_matrix) + + def _get_freq_prior(self, must_overlap=False): + fg_matrix = np.zeros(( + len(self.classes), + len(self.classes), + len(self.relations) + ), dtype=np.int64) + + bg_matrix = np.zeros(( + len(self.classes), + len(self.classes), + ), dtype=np.int64) + + for ex_ind in range(self.__len__()): + target = self.get_groundtruth(ex_ind) + gt_classes = target.get_field('labels').numpy() + gt_relations = target.get_field('relation_labels').numpy() + gt_boxes = target.bbox + + # For the foreground, we'll just look at everything + try: + o1o2 = gt_classes[gt_relations[:, :2]] + for (o1, o2), gtr in zip(o1o2, gt_relations[:, 2]): + fg_matrix[o1, o2, gtr] += 1 + + # For the background, get all of the things that overlap. + o1o2_total = gt_classes[np.array( + _box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)] + for (o1, o2) in o1o2_total: + bg_matrix[o1, o2] += 1 + except IndexError as e: + assert len(gt_relations) == 0 + + if ex_ind % 20 == 0: + print("processing {}/{}".format(ex_ind, self.__len__())) + + return fg_matrix, bg_matrix + + def relation_loader(self, relation_triplets, target): + # relation_triplets [list of tuples]: M*3 + # target: BoxList from label_loader + if self.filter_duplicate_rels: + # Filter out dupes! 
+ assert self.split == 'train' + all_rel_sets = collections.defaultdict(list) + for (o0, o1, r) in relation_triplets: + all_rel_sets[(o0, o1)].append(r) + relation_triplets = [(k[0], k[1], np.random.choice(v)) for k, v in all_rel_sets.items()] + + # get M*M pred_labels + relations = torch.zeros([len(target), len(target)], dtype=torch.int64) + for i in range(len(relation_triplets)): + subj_id = relation_triplets[i][0] + obj_id = relation_triplets[i][1] + pred = relation_triplets[i][2] + relations[subj_id, obj_id] = int(pred) + + relation_triplets = torch.tensor(relation_triplets) + target.add_field("relation_labels", relation_triplets) + target.add_field("pred_labels", relations) + return target + + def get_target_from_annotations(self, annotations, img_size, idx): + if self.is_load_label and annotations: + target = self.label_loader(annotations['objects'], img_size) + # make sure no boxes are removed + assert (len(annotations['objects']) == len(target)) + if self.split in ["val", "test"]: + # add the difficult field + target.add_field("difficult", torch.zeros(len(target), dtype=torch.int32)) + # load relations + if self.relation_on: + target = self.relation_loader(annotations["relations"], target) + return target + + def get_groundtruth(self, idx, call=False): + # similar to __getitem__ but without transform + img = self.get_image(idx) + if self.cv2_output: + img_size = img.shape[:2][::-1] # h, w -> w, h + else: + img_size = img.size # w, h + annotations = self.get_annotations(idx) + target = self.get_target_from_annotations(annotations, img_size, idx) + if call: + return img, target, annotations + else: + return target + + def apply_transforms(self, img, target=None): + if self.transforms is not None: + img, target = self.transforms(img, target) + return img, target + + def map_class_id_to_class_name(self, class_id): + return self.classes[class_id] + + def map_attribute_id_to_attribute_name(self, attribute_id): + return self.attributes[attribute_id] + + def map_relation_id_to_relation_name(self, relation_id): + return self.relations[relation_id] diff --git a/maskrcnn_benchmark/data/datasets/voc.py b/maskrcnn_benchmark/data/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..4288ba5231dc513ec1e33fe952527ba6433fd199 --- /dev/null +++ b/maskrcnn_benchmark/data/datasets/voc.py @@ -0,0 +1,134 @@ +import os + +import torch +import torch.utils.data +from PIL import Image +import sys + +if sys.version_info[0] == 2: + import xml.etree.cElementTree as ET +else: + import xml.etree.ElementTree as ET + + +from maskrcnn_benchmark.structures.bounding_box import BoxList + + +class PascalVOCDataset(torch.utils.data.Dataset): + + CLASSES = ( + "__background__ ", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", + ) + + def __init__(self, data_dir, split, use_difficult=False, transforms=None): + self.root = data_dir + self.image_set = split + self.keep_difficult = use_difficult + self.transforms = transforms + + self._annopath = os.path.join(self.root, "Annotations", "%s.xml") + self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg") + self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt") + + with open(self._imgsetpath % self.image_set) as f: + self.ids = f.readlines() + self.ids = [x.strip("\n") for x in self.ids] + self.id_to_img_map = {k: v for k, v in 
enumerate(self.ids)} + + cls = PascalVOCDataset.CLASSES + self.class_to_ind = dict(zip(cls, range(len(cls)))) + + def __getitem__(self, index): + img_id = self.ids[index] + img = Image.open(self._imgpath % img_id).convert("RGB") + + target = self.get_groundtruth(index) + target = target.clip_to_image(remove_empty=True) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target, index + + def __len__(self): + return len(self.ids) + + def get_groundtruth(self, index): + img_id = self.ids[index] + anno = ET.parse(self._annopath % img_id).getroot() + anno = self._preprocess_annotation(anno) + + height, width = anno["im_info"] + target = BoxList(anno["boxes"], (width, height), mode="xyxy") + target.add_field("labels", anno["labels"]) + target.add_field("difficult", anno["difficult"]) + return target + + def _preprocess_annotation(self, target): + boxes = [] + gt_classes = [] + difficult_boxes = [] + TO_REMOVE = 1 + + for obj in target.iter("object"): + difficult = int(obj.find("difficult").text) == 1 + if not self.keep_difficult and difficult: + continue + name = obj.find("name").text.lower().strip() + bb = obj.find("bndbox") + # Make pixel indexes 0-based + # Refer to "https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py#L208-L211" + box = [ + bb.find("xmin").text, + bb.find("ymin").text, + bb.find("xmax").text, + bb.find("ymax").text, + ] + bndbox = tuple( + map(lambda x: x - TO_REMOVE, list(map(int, box))) + ) + + boxes.append(bndbox) + gt_classes.append(self.class_to_ind[name]) + difficult_boxes.append(difficult) + + size = target.find("size") + im_info = tuple(map(int, (size.find("height").text, size.find("width").text))) + + res = { + "boxes": torch.tensor(boxes, dtype=torch.float32), + "labels": torch.tensor(gt_classes), + "difficult": torch.tensor(difficult_boxes), + "im_info": im_info, + } + return res + + def get_img_info(self, index): + img_id = self.ids[index] + anno = ET.parse(self._annopath % img_id).getroot() + size = anno.find("size") + im_info = tuple(map(int, (size.find("height").text, size.find("width").text))) + return {"height": im_info[0], "width": im_info[1]} + + def map_class_id_to_class_name(self, class_id): + return PascalVOCDataset.CLASSES[class_id] diff --git a/maskrcnn_benchmark/data/samplers/__init__.py b/maskrcnn_benchmark/data/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f891498f3d66c08a4840de0b12fb03b6834ba4c8 --- /dev/null +++ b/maskrcnn_benchmark/data/samplers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .distributed import DistributedSampler +from .grouped_batch_sampler import GroupedBatchSampler +from .iteration_based_batch_sampler import IterationBasedBatchSampler + +__all__ = ["DistributedSampler", "GroupedBatchSampler", "IterationBasedBatchSampler"] diff --git a/maskrcnn_benchmark/data/samplers/distributed.py b/maskrcnn_benchmark/data/samplers/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..0b2aa926f61243e77a9e959ef36826c854467fc5 --- /dev/null +++ b/maskrcnn_benchmark/data/samplers/distributed.py @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Code is copy-pasted exactly as in torch.utils.data.distributed. 
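
A minimal usage sketch for `PascalVOCDataset` above; the VOC root path is illustrative and assumes the standard `Annotations/`, `JPEGImages/`, `ImageSets/Main/` layout:

```python
# Sketch: loading one VOC sample and inspecting its BoxList target.
from maskrcnn_benchmark.data.datasets.voc import PascalVOCDataset

dataset = PascalVOCDataset(
    data_dir="datasets/voc/VOC2007",   # illustrative path
    split="train",
    use_difficult=False,
    transforms=None,
)
img, target, index = dataset[0]              # PIL image, BoxList target, sample index
print(len(dataset), target.bbox.shape)       # boxes are xyxy with 0-based pixel indices
label_id = int(target.get_field("labels")[0])
print(dataset.map_class_id_to_class_name(label_id))
```
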
+# FIXME remove this once c10d fixes the bug it has +import math +import torch +import torch.distributed as dist +from torch.utils.data.sampler import Sampler + +from maskrcnn_benchmark.utils.comm import shared_random_seed + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, use_random=False): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.use_random = use_random + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch + _seed = self.epoch + if self.use_random: + _seed = int(shared_random_seed()) + g = torch.Generator() + g.manual_seed(_seed) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset : offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py b/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7f6985b9ccef6d9a7353e11817a904d309395b82 --- /dev/null +++ b/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import itertools + +import torch +from torch.utils.data.sampler import BatchSampler +from torch.utils.data.sampler import Sampler + + +class GroupedBatchSampler(BatchSampler): + """ + Wraps another sampler to yield a mini-batch of indices. + It enforces that elements from the same group should appear in groups of batch_size. + It also tries to provide mini-batches which follows an ordering which is + as close as possible to the ordering from the original sampler. + + Arguments: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. 
+ drop_uneven (bool): If ``True``, the sampler will drop the batches whose + size is less than ``batch_size`` + + """ + + def __init__(self, sampler, group_ids, batch_size, drop_uneven=False): + if not isinstance(sampler, Sampler): + raise ValueError( + "sampler should be an instance of " + "torch.utils.data.Sampler, but got sampler={}".format(sampler) + ) + self.sampler = sampler + self.group_ids = torch.as_tensor(group_ids) + assert self.group_ids.dim() == 1 + self.batch_size = batch_size + self.drop_uneven = drop_uneven + + self.groups = torch.unique(self.group_ids).sort(0)[0] + + self._can_reuse_batches = False + + def _prepare_batches(self): + dataset_size = len(self.group_ids) + # get the sampled indices from the sampler + sampled_ids = torch.as_tensor(list(self.sampler)) + # potentially not all elements of the dataset were sampled + # by the sampler (e.g., DistributedSampler). + # construct a tensor which contains -1 if the element was + # not sampled, and a non-negative number indicating the + # order where the element was sampled. + # for example. if sampled_ids = [3, 1] and dataset_size = 5, + # the order is [-1, 1, -1, 0, -1] + order = torch.full((dataset_size,), -1, dtype=torch.int64) + order[sampled_ids] = torch.arange(len(sampled_ids)) + + # get a mask with the elements that were sampled + mask = order >= 0 + + # find the elements that belong to each individual cluster + clusters = [(self.group_ids == i) & mask for i in self.groups] + # get relative order of the elements inside each cluster + # that follows the order from the sampler + relative_order = [order[cluster] for cluster in clusters] + # with the relative order, find the absolute order in the + # sampled space + permutation_ids = [s[s.sort()[1]] for s in relative_order] + # permute each cluster so that they follow the order from + # the sampler + permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] + + # splits each cluster in batch_size, and merge as a list of tensors + splits = [c.split(self.batch_size) for c in permuted_clusters] + merged = tuple(itertools.chain.from_iterable(splits)) + + # now each batch internally has the right order, but + # they are grouped by clusters. Find the permutation between + # different batches that brings them as close as possible to + # the order that we have in the sampler. 
For that, we will consider the + # ordering as coming from the first element of each batch, and sort + # correspondingly + first_element_of_batch = [t[0].item() for t in merged] + # get and inverse mapping from sampled indices and the position where + # they occur (as returned by the sampler) + inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} + # from the first element in each batch, get a relative ordering + first_index_of_batch = torch.as_tensor( + [inv_sampled_ids_map[s] for s in first_element_of_batch] + ) + + # permute the batches so that they approximately follow the order + # from the sampler + permutation_order = first_index_of_batch.sort(0)[1].tolist() + # finally, permute the batches + batches = [merged[i].tolist() for i in permutation_order] + + if self.drop_uneven: + kept = [] + for batch in batches: + if len(batch) == self.batch_size: + kept.append(batch) + batches = kept + return batches + + def __iter__(self): + if self._can_reuse_batches: + batches = self._batches + self._can_reuse_batches = False + else: + batches = self._prepare_batches() + self._batches = batches + return iter(batches) + + def __len__(self): + if not hasattr(self, "_batches"): + self._batches = self._prepare_batches() + self._can_reuse_batches = True + return len(self._batches) diff --git a/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py b/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..431693eecd2e474dacdbc9eb805dbe2b092234cc --- /dev/null +++ b/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from torch.utils.data.sampler import BatchSampler + + +class IterationBasedBatchSampler(BatchSampler): + """ + Wraps a BatchSampler, resampling from it until + a specified number of iterations have been sampled + """ + + def __init__(self, batch_sampler, num_iterations, start_iter=0): + self.batch_sampler = batch_sampler + self.num_iterations = num_iterations + self.start_iter = start_iter + + def __iter__(self): + iteration = self.start_iter + while iteration <= self.num_iterations: + # if the underlying sampler has a set_epoch method, like + # DistributedSampler, used for making each process see + # a different split of the dataset, then set it + if hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(iteration) + for batch in self.batch_sampler: + iteration += 1 + if iteration > self.num_iterations: + break + yield batch + + def __len__(self): + return self.num_iterations diff --git a/maskrcnn_benchmark/data/transforms/__init__.py b/maskrcnn_benchmark/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94ce850056fdd7ed45f416bc4ead90f3f7da0073 --- /dev/null +++ b/maskrcnn_benchmark/data/transforms/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
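
The three samplers above are meant to be chained: `DistributedSampler` shards and shuffles indices, `GroupedBatchSampler` batches them by group id (e.g. aspect-ratio bucket), and `IterationBasedBatchSampler` repeats the result for a fixed iteration budget. A sketch with a toy dataset; the sizes, group ids, and iteration count are illustrative:

```python
# Sketch: chaining the samplers defined above for an iteration-based training loop.
import torch
from maskrcnn_benchmark.data.samplers import (
    DistributedSampler, GroupedBatchSampler, IterationBasedBatchSampler,
)

dataset = torch.utils.data.TensorDataset(torch.arange(100).float())
group_ids = [0] * 50 + [1] * 50   # one group id per sample, e.g. aspect-ratio buckets

sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size=4, drop_uneven=False)
batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=10, start_iter=0)

loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
print(len(batch_sampler))          # 10: length is the iteration budget, not the dataset size
```
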
+from .transforms import Compose +from .transforms import Resize +from .transforms import RandomHorizontalFlip +from .transforms import ToTensor +from .transforms import Normalize + +from .build import build_transforms diff --git a/maskrcnn_benchmark/data/transforms/build.py b/maskrcnn_benchmark/data/transforms/build.py new file mode 100644 index 0000000000000000000000000000000000000000..9f66c092e4ca0229b7bd5607c84e7be9ce52eb9f --- /dev/null +++ b/maskrcnn_benchmark/data/transforms/build.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from . import transforms as T + + +def build_transforms(cfg, is_train=True): + if is_train: + if len(cfg.AUGMENT.MULT_MIN_SIZE_TRAIN)>0: + min_size = cfg.AUGMENT.MULT_MIN_SIZE_TRAIN + else: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + flip_horizontal_prob = cfg.AUGMENT.FLIP_PROB_TRAIN + flip_vertical_prob = cfg.AUGMENT.VERTICAL_FLIP_PROB_TRAIN + brightness = cfg.AUGMENT.BRIGHTNESS + contrast = cfg.AUGMENT.CONTRAST + saturation = cfg.AUGMENT.SATURATION + hue = cfg.AUGMENT.HUE + + crop_prob = cfg.AUGMENT.CROP_PROB + min_ious = cfg.AUGMENT.CROP_MIN_IOUS + min_crop_size = cfg.AUGMENT.CROP_MIN_SIZE + + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + flip_horizontal_prob = 0.0 + + fix_res = cfg.INPUT.FIX_RES + if cfg.INPUT.FORMAT is not '': + input_format = cfg.INPUT.FORMAT + elif cfg.INPUT.TO_BGR255: + input_format = 'bgr255' + normalize_transform = T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, format=input_format + ) + + transform = T.Compose( + [ + T.Resize(min_size, max_size, restrict=fix_res), + T.RandomHorizontalFlip(flip_horizontal_prob), + T.ToTensor(), + normalize_transform, + ] + ) + return transform diff --git a/maskrcnn_benchmark/data/transforms/transforms.py b/maskrcnn_benchmark/data/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..3698aec52d9df8bcd9bb73cf2a80294c917898bb --- /dev/null +++ b/maskrcnn_benchmark/data/transforms/transforms.py @@ -0,0 +1,385 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
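`build_transforms` above wires the config into the detection-aware transforms exported by this package; unlike stock torchvision transforms, each one is called with an `(image, target)` pair and returns the pair. A small self-contained sketch of that calling convention, using toy stand-in transforms (the class names below are made up for illustration; the chaining mirrors the `Compose` defined in `transforms.py`):

```python
import torch

class ToFloat01:
    """Toy transform: scale a uint8 image tensor into [0, 1]."""
    def __call__(self, image, target):
        return image.float() / 255.0, target

class ClipBoxes:
    """Toy transform: stand-in for target-aware ops such as BoxList.resize()."""
    def __init__(self, max_size):
        self.max_size = max_size
    def __call__(self, image, target):
        return image, target.clamp(0, self.max_size)

class Compose:
    """Chain (image, target) transforms, as in transforms.Compose."""
    def __init__(self, transforms):
        self.transforms = transforms
    def __call__(self, image, target=None):
        for t in self.transforms:
            image, target = t(image, target)
        return image if target is None else (image, target)

pipeline = Compose([ToFloat01(), ClipBoxes(max_size=32)])
image = torch.randint(0, 256, (3, 32, 48), dtype=torch.uint8)
boxes = torch.tensor([[-2.0, 4.0, 50.0, 30.0]])  # toy xyxy boxes
image, boxes = pipeline(image, boxes)
print(image.dtype, boxes.tolist())  # torch.float32 [[0.0, 4.0, 32.0, 30.0]]
```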
+import cv2 +import random +import numpy as np +import math +import torch +import torchvision +from torchvision.transforms import functional as F + +from maskrcnn_benchmark.structures.bounding_box import BoxList + +def matrix_iou(a, b, relative=False): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + if relative: + ious = area_i / (area_b[:, np.newaxis]+1e-12) + else: + ious = area_i / (area_a[:, np.newaxis] + area_b - area_i+1e-12) + return ious + + +class RACompose(object): + def __init__(self, pre_transforms, rand_transforms, post_transforms, concurrent=2): + self.preprocess = pre_transforms + self.transforms = post_transforms + self.rand_transforms = rand_transforms + self.concurrent = concurrent + + def __call__(self, image, target): + for t in self.preprocess: + image, target = t(image, target) + for t in random.choices(self.rand_transforms, k=self.concurrent): + image = np.array(image) + image, target = t(image, target) + for t in self.transforms: + image, target = t(image, target) + + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.preprocess: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\nRandom select {0} from: (".format(self.concurrent) + for t in self.rand_transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += ")\nThen, apply:" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target=None): + for t in self.transforms: + image, target = t(image, target) + if target is None: + return image + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class Resize(object): + def __init__(self, min_size, max_size, restrict=False): + if not isinstance(min_size, (list, tuple)): + min_size = (min_size,) + self.min_size = min_size + self.max_size = max_size + self.restrict = restrict + + # modified from torchvision to add support for max size + def get_size(self, image_size): + w, h = image_size + size = random.choice(self.min_size) + max_size = self.max_size + if self.restrict: + return (size, max_size) + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def __call__(self, image, target): + if isinstance(image, np.ndarray): + image_size = self.get_size(image.shape[:2]) + image = cv2.resize(image, image_size) + new_size = image_size + else: + image = F.resize(image, self.get_size(image.size)) + new_size = image.size + if target is not None: + target = target.resize(new_size) + return image, target + + +class 
RandomHorizontalFlip(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + if isinstance(image, np.ndarray): + image = np.fliplr(image) + else: + image = F.hflip(image) + if target is not None: + target = target.transpose(0) + return image, target + + +class RandomVerticalFlip(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + if isinstance(image, np.ndarray): + image = np.flipud(image) + else: + image = F.vflip(image) + target = target.transpose(1) + return image, target + +class ToTensor(object): + def __call__(self, image, target): + return F.to_tensor(image), target + + +class Normalize(object): + def __init__(self, mean, std, format='rgb'): + self.mean = mean + self.std = std + self.format = format.lower() + + def __call__(self, image, target): + if 'bgr' in self.format: + image = image[[2, 1, 0]] + if '255' in self.format: + image = image * 255 + image = F.normalize(image, mean=self.mean, std=self.std) + return image, target + + +class ColorJitter(object): + def __init__(self, + brightness=0.0, + contrast=0.0, + saturation=0.0, + hue=0.0, + ): + self.color_jitter = torchvision.transforms.ColorJitter( + brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue,) + + def __call__(self, image, target): + image = self.color_jitter(image) + return image, target + + +class RandomCrop(object): + def __init__(self, prob=0.5, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3): + # 1: return ori img + self.prob = prob + self.sample_mode = (1, *min_ious, 0) + self.min_crop_size = min_crop_size + + def __call__(self, img, target): + if random.random() > self.prob: + return img, target + + h, w, c = img.shape + boxes = target.bbox.numpy() + labels = target.get_field('labels') + + while True: + mode = random.choice(self.sample_mode) + if mode == 1: + return img, target + + min_iou = mode + + new_w = random.uniform(self.min_crop_size * w, w) + new_h = random.uniform(self.min_crop_size * h, h) + + # h / w in [0.5, 2] + if new_h / new_w < 0.5 or new_h / new_w > 2: + continue + + left = random.uniform(0, w - new_w) + top = random.uniform(0, h - new_h) + + patch = np.array([left, top, left + new_w, top + new_h]) + overlaps = matrix_iou(patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1) + if overlaps.min() < min_iou: + continue + + # center of boxes should inside the crop img + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * ( center[:, 1] < patch[3]) + if not mask.any(): + continue + + boxes = boxes[mask] + labels = labels[mask] + + # adjust boxes + img = img[int(patch[1]):int(patch[3]), int(patch[0]):int(patch[2])] + + boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) + boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) + boxes -= np.tile(patch[:2], 2) + + new_target = BoxList(boxes, (img.shape[1], img.shape[0]), mode='xyxy') + new_target.add_field('labels', labels) + return img, new_target + + +class RandomAffine(object): + def __init__(self, prob=0.5, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2), + borderValue=(127.5, 127.5, 127.5)): + self.prob = prob + self.degrees = degrees + self.translate = translate + self.scale = scale + self.shear = shear + self.borderValue = borderValue + + def __call__(self, img, targets=None): + if random.random() > self.prob: + return img, targets + # 
torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) + # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4 + + border = 0 # width of added border (optional) + #height = max(img.shape[0], img.shape[1]) + border * 2 + height, width, _ = img.shape + bbox = targets.bbox + + # Rotation and Scale + R = np.eye(3) + a = random.random() * (self.degrees[1] - self.degrees[0]) + self.degrees[0] + # a += random.choice([-180, -90, 0, 90]) # 90deg rotations added to small rotations + s = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0] + R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s) + + # Translation + T = np.eye(3) + T[0, 2] = (random.random() * 2 - 1) * self.translate[0] * img.shape[0] + border # x translation (pixels) + T[1, 2] = (random.random() * 2 - 1) * self.translate[1] * img.shape[1] + border # y translation (pixels) + + # Shear + S = np.eye(3) + S[0, 1] = math.tan((random.random() * (self.shear[1] - self.shear[0]) + self.shear[0]) * math.pi / 180) # x shear (deg) + S[1, 0] = math.tan((random.random() * (self.shear[1] - self.shear[0]) + self.shear[0]) * math.pi / 180) # y shear (deg) + + M = S @ T @ R # Combined rotation matrix. ORDER IS IMPORTANT HERE!! + imw = cv2.warpPerspective(img, M, dsize=(width, height), flags=cv2.INTER_LINEAR, + borderValue=self.borderValue) # BGR order borderValue + + # Return warped points also + if targets: + n = bbox.shape[0] + points = bbox[:, 0:4] + area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1]) + + # warp points + xy = np.ones((n * 4, 3)) + xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + xy = (xy @ M.T)[:, :2].reshape(n, 8) + + # create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + + # apply angle-based reduction + radians = a * math.pi / 180 + reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5 + x = (xy[:, 2] + xy[:, 0]) / 2 + y = (xy[:, 3] + xy[:, 1]) / 2 + w = (xy[:, 2] - xy[:, 0]) * reduction + h = (xy[:, 3] - xy[:, 1]) * reduction + xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T + + # reject warped points outside of image + x1 = np.clip(xy[:,0], 0, width) + y1 = np.clip(xy[:,1], 0, height) + x2 = np.clip(xy[:,2], 0, width) + y2 = np.clip(xy[:,3], 0, height) + new_bbox = np.concatenate((x1, y1, x2, y2)).reshape(4, n).T + targets.bbox = torch.as_tensor(new_bbox, dtype=torch.float32) + + return imw, targets + + +class RandomErasing: + def __init__(self, prob=0.5, era_l=0.02, era_h=1/3, min_aspect=0.3, + mode='const', max_count=1, max_overlap=0.3, max_value=255): + self.prob = prob + self.era_l = era_l + self.era_h = era_h + self.min_aspect = min_aspect + self.min_count = 1 + self.max_count = max_count + self.max_overlap = max_overlap + self.max_value = max_value + self.mode = mode.lower() + assert self.mode in ['const', 'rand', 'pixel'], 'invalid erase mode: %s' % self.mode + + def _get_pixels(self, patch_size): + if self.mode == 'pixel': + return np.random.random(patch_size)*self.max_value + elif self.mode == 'rand': + return np.random.random((1, 1, patch_size[-1]))*self.max_value + else: + return np.zeros((1, 1, patch_size[-1])) + + def __call__(self, image, target): + if random.random() > self.prob: + return image, target + ih, iw, ic = image.shape + ia = ih * iw + count = 
self.min_count if self.min_count == self.max_count else \ + random.randint(self.min_count, self.max_count) + erase_boxes = [] + for _ in range(count): + for try_idx in range(10): + erase_area = random.uniform(self.era_l, self.era_h) * ia / count + aspect_ratio = math.exp(random.uniform(math.log(self.min_aspect), math.log(1/self.min_aspect))) + eh = int(round(math.sqrt(erase_area * aspect_ratio))) + ew = int(round(math.sqrt(erase_area / aspect_ratio))) + if eh < ih and ew < iw: + x = random.randint(0, iw - ew) + y = random.randint(0, ih - eh) + image[y:y+eh, x:x+ew, :] = self._get_pixels((eh, ew, ic)) + erase_boxes.append([x,y,x+ew,y+eh]) + break + + if target is not None and len(erase_boxes)>0: + boxes = target.bbox.numpy() + labels = target.get_field('labels') + overlap = matrix_iou(np.array(erase_boxes), boxes, relative=True) + mask = overlap.max(axis=0) 1: + dist.reduce(all_losses, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + all_losses /= world_size + + reduced_losses = {} + for k, v in zip(loss_names, all_losses): + if k not in reduced_losses: + reduced_losses[k] = v / len(all_loss_dict) + reduced_losses[k] += v / len(all_loss_dict) + + return reduced_losses + + +def do_train( + model, + data_loader, + optimizer, + scheduler, + checkpointer, + device, + checkpoint_period, + arguments, +): + logger = logging.getLogger("maskrcnn_benchmark.trainer") + logger.info("Start training") + meters = MetricLogger(delimiter=" ") + max_iter = min(len(task_loader) for task_loader in data_loader) + start_iter = arguments["iteration"] + model.train() + start_training_time = time.time() + end = time.time() + for iteration, task_loader in enumerate(zip(*data_loader), start_iter): + data_time = time.time() - end + iteration = iteration + 1 + arguments["iteration"] = iteration + + all_task_loss_dict = [] + for task, (images, targets, _) in enumerate(task_loader, 1): + if all(len(target) < 1 for target in targets): + logger.warning('Sampled all negative batches, skip') + continue + + images = images.to(device) + targets = [target.to(device) for target in targets] + + loss_dict = model(images, targets, task) + all_task_loss_dict.append(loss_dict) + + losses = sum(loss for loss_dict in all_task_loss_dict for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = reduce_loss_dict(all_task_loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + meters.update(loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + scheduler.step() + + batch_time = time.time() - end + end = time.time() + meters.update(time=batch_time, data=data_time) + + eta_seconds = meters.time.global_avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + + if iteration % 20 == 0 or iteration == max_iter: + logger.info( + meters.delimiter.join( + [ + "eta: {eta}", + "iter: {iter}", + "{meters}", + "lr: {lr:.6f}", + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + iter=iteration, + meters=str(meters), + lr=optimizer.param_groups[0]["lr"], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) + if iteration % checkpoint_period == 0: + checkpointer.save("model_{:07d}".format(iteration), **arguments) + if iteration == max_iter: + checkpointer.save("model_final", **arguments) + + total_training_time = time.time() - start_training_time + total_time_str = 
str(datetime.timedelta(seconds=total_training_time)) + logger.info( + "Total training time: {} ({:.4f} s / it)".format( + total_time_str, total_training_time / (max_iter) + ) + ) diff --git a/maskrcnn_benchmark/engine/evolution.py b/maskrcnn_benchmark/engine/evolution.py new file mode 100644 index 0000000000000000000000000000000000000000..40b41c3e6550de19e6d06d722121ff59e205f6ce --- /dev/null +++ b/maskrcnn_benchmark/engine/evolution.py @@ -0,0 +1,357 @@ + +import time +import pickle +import logging +import os +import numpy as np +import torch +import torch.nn as nn + + +from collections import OrderedDict +from yaml import safe_dump +from yacs.config import load_cfg, CfgNode#, _to_dict +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.engine.inference import _accumulate_predictions_from_multiple_gpus +from maskrcnn_benchmark.modeling.backbone.nas import get_layer_name +from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, get_world_size, all_gather +from maskrcnn_benchmark.data.datasets.evaluation import evaluate +from maskrcnn_benchmark.utils.flops import profile + + +choice = lambda x:x[np.random.randint(len(x))] if isinstance(x,tuple) else choice(tuple(x)) + + +def gather_candidates(all_candidates): + all_candidates = all_gather(all_candidates) + all_candidates = [cand for candidates in all_candidates for cand in candidates] + return list(set(all_candidates)) + + +def gather_stats(all_candidates): + all_candidates = all_gather(all_candidates) + reduced_statcs = {} + for candidates in all_candidates: + reduced_statcs.update(candidates) # will replace the existing key with last value if more than one exists + return reduced_statcs + + +def compute_on_dataset(model, rngs, data_loader, device=cfg.MODEL.DEVICE): + model.eval() + results_dict = {} + cpu_device = torch.device("cpu") + for _, batch in enumerate(data_loader): + images, targets, image_ids = batch + with torch.no_grad(): + output = model(images.to(device), rngs=rngs) + output = [o.to(cpu_device) for o in output] + results_dict.update( + {img_id: result for img_id, result in zip(image_ids, output)} + ) + return results_dict + + +def bn_statistic(model, rngs, data_loader, device=cfg.MODEL.DEVICE, max_iter=500): + for name, param in model.named_buffers(): + if 'running_mean' in name: + nn.init.constant_(param, 0) + if 'running_var' in name: + nn.init.constant_(param, 1) + + model.train() + for iteration, (images, targets, _) in enumerate(data_loader, 1): + images = images.to(device) + targets = [target.to(device) for target in targets] + with torch.no_grad(): + loss_dict = model(images, targets, rngs) + if iteration >= max_iter: + break + + return model + + +def inference( + model, + rngs, + data_loader, + iou_types=("bbox",), + box_only=False, + device="cuda", + expected_results=(), + expected_results_sigma_tol=4, + output_folder=None, +): + + # convert to a torch.device for efficiency + device = torch.device(device) + dataset = data_loader.dataset + predictions = compute_on_dataset(model, rngs, data_loader, device) + # wait for all processes to complete before measuring the time + synchronize() + + predictions = _accumulate_predictions_from_multiple_gpus(predictions) + if not is_main_process(): + return + + extra_args = dict( + box_only=box_only, + iou_types=iou_types, + expected_results=expected_results, + expected_results_sigma_tol=expected_results_sigma_tol, + ) + + return evaluate(dataset=dataset, + predictions=predictions, + output_folder=output_folder, + **extra_args) + + +def 
fitness(cfg, model, rngs, val_loaders): + iou_types = ("bbox",) + if cfg.MODEL.MASK_ON: + iou_types = iou_types + ("segm",) + for data_loader_val in val_loaders: + results = inference( + model, + rngs, + data_loader_val, + iou_types=iou_types, + box_only=False, + device=cfg.MODEL.DEVICE, + expected_results=cfg.TEST.EXPECTED_RESULTS, + expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, + ) + synchronize() + + return results + + +class EvolutionTrainer(object): + def __init__(self, cfg, model, flops_limit=None, is_distributed=True): + + self.log_dir = cfg.OUTPUT_DIR + self.checkpoint_name = os.path.join(self.log_dir,'evolution.pth') + self.is_distributed = is_distributed + + self.states = model.module.mix_nums if is_distributed else model.mix_nums + self.supernet_state_dict = pickle.loads(pickle.dumps(model.state_dict())) + self.flops_limit = flops_limit + self.model = model + + self.candidates = [] + self.vis_dict = {} + + self.max_epochs = cfg.SEARCH.MAX_EPOCH + self.select_num = cfg.SEARCH.SELECT_NUM + self.population_num = cfg.SEARCH.POPULATION_NUM/get_world_size() + self.mutation_num = cfg.SEARCH.MUTATION_NUM/get_world_size() + self.crossover_num = cfg.SEARCH.CROSSOVER_NUM/get_world_size() + self.mutation_prob = cfg.SEARCH.MUTATION_PROB/get_world_size() + + self.keep_top_k = {self.select_num:[], 50:[]} + self.epoch=0 + self.cfg = cfg + + def save_checkpoint(self): + if not is_main_process(): + return + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + info = {} + info['candidates'] = self.candidates + info['vis_dict'] = self.vis_dict + info['keep_top_k'] = self.keep_top_k + info['epoch'] = self.epoch + torch.save(info, self.checkpoint_name) + print('Save checkpoint to', self.checkpoint_name) + + def load_checkpoint(self): + if not os.path.exists(self.checkpoint_name): + return False + info = torch.load(self.checkpoint_name) + self.candidates = info['candidates'] + self.vis_dict = info['vis_dict'] + self.keep_top_k = info['keep_top_k'] + self.epoch = info['epoch'] + print('Load checkpoint from', self.checkpoint_name) + return True + + def legal(self, cand): + assert isinstance(cand,tuple) and len(cand)==len(self.states) + if cand in self.vis_dict: + return False + + if self.flops_limit is not None: + net = self.model.module.backbone if self.is_distributed else self.model.backbone + inp = (1, 3, 224, 224) + flops, params = profile(net, inp, extra_args={'paths': list(cand)}) + flops = flops/1e6 + print('flops:',flops) + if flops>self.flops_limit: + return False + + return True + + def update_top_k(self, candidates, *, k, key, reverse=False): + assert k in self.keep_top_k + # print('select ......') + t = self.keep_top_k[k] + t += candidates + t.sort(key=key,reverse=reverse) + self.keep_top_k[k]=t[:k] + + def eval_candidates(self, train_loader, val_loader): + for cand in self.candidates: + t0 = time.time() + + # load back supernet state dict + self.model.load_state_dict(self.supernet_state_dict) + # bn_statistic + model = bn_statistic(self.model, list(cand), train_loader) + # fitness + evals = fitness(cfg, model, list(cand), val_loader) + + if is_main_process(): + acc = evals[0].results['bbox']['AP'] + self.vis_dict[cand] = acc + print('candiate ', cand) + print('time: {}s'.format(time.time() - t0)) + print('acc ', acc) + + def stack_random_cand(self, random_func, *, batchsize=10): + while True: + cands = [random_func() for _ in range(batchsize)] + for cand in cands: + yield cand + + def random_can(self, num): + # print('random select ........') + candidates = 
[] + cand_iter = self.stack_random_cand(lambda:tuple(np.random.randint(i) for i in self.states)) + while len(candidates)0: + cand = next(cand_iter) + if not self.legal(cand): + continue + res.append(cand) + #print('mutation {}/{}'.format(len(res),mutation_num)) + max_iters-=1 + + # print('mutation_num = {}'.format(len(res))) + return res + + def get_crossover(self, k, crossover_num): + assert k in self.keep_top_k + # print('crossover ......') + res = [] + iter = 0 + max_iters = 10 * crossover_num + + def random_func(): + p1=choice(self.keep_top_k[k]) + p2=choice(self.keep_top_k[k]) + return tuple(choice([i,j]) for i,j in zip(p1,p2)) + + cand_iter = self.stack_random_cand(random_func) + while len(res)0: + cand = next(cand_iter) + if not self.legal(cand): + continue + res.append(cand) + #print('crossover {}/{}'.format(len(res),crossover_num)) + max_iters-=1 + + # print('crossover_num = {}'.format(len(res))) + return res + + def train(self, train_loader, val_loader): + logger = logging.getLogger("maskrcnn_benchmark.evolution") + + if not self.load_checkpoint(): + self.candidates = gather_candidates(self.random_can(self.population_num)) + + while self.epoch self.confidence_threshold, + and returns the predictions in descending order of score + + Arguments: + predictions (BoxList): the result of the computation by the model. + It should contain the field `scores`. + + Returns: + prediction (BoxList): the detected objects. Additional information + of the detection properties can be found in the fields of + the BoxList via `prediction.fields()` + """ + + scores = predictions.get_field("scores") + labels = predictions.get_field("labels").tolist() + thresh = scores.clone() + for i,lb in enumerate(labels): + if isinstance(self.confidence_threshold, float): + thresh[i] = self.confidence_threshold + elif len(self.confidence_threshold)==1: + thresh[i] = self.confidence_threshold[0] + else: + thresh[i] = self.confidence_threshold[lb-1] + keep = torch.nonzero(scores > thresh).squeeze(1) + predictions = predictions[keep] + + if self.exclude_region: + exlude = BoxList(self.exclude_region, predictions.size) + iou = boxlist_iou(exlude, predictions) + keep = torch.nonzero(torch.sum(iou>0.5, dim=0)==0).squeeze(1) + if len(keep)>0: + predictions = predictions[keep] + + scores = predictions.get_field("scores") + _, idx = scores.sort(0, descending=True) + return predictions[idx] + + def compute_colors_for_labels(self, labels): + """ + Simple function that adds fixed colors depending on the class + """ + colors = (30*(labels[:, None] -1)+1)*self.palette + colors = (colors % 255).numpy().astype("uint8") + return colors + + def overlay_boxes(self, image, predictions): + """ + Adds the predicted boxes on top of the image + + Arguments: + image (np.ndarray): an image as returned by OpenCV + predictions (BoxList): the result of the computation by the model. + It should contain the field `labels`. + """ + labels = predictions.get_field("labels") + boxes = predictions.bbox + + colors = self.compute_colors_for_labels(labels).tolist() + + for box, color in zip(boxes, colors): + box = box.to(torch.int64) + top_left, bottom_right = box[:2].tolist(), box[2:].tolist() + image = cv2.rectangle( + image, tuple(top_left), tuple(bottom_right), tuple(color), 2) + + return image + + def overlay_scores(self, image, predictions): + """ + Adds the predicted boxes on top of the image + + Arguments: + image (np.ndarray): an image as returned by OpenCV + predictions (BoxList): the result of the computation by the model. 
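The `select_top_predictions` / `_post_process_fixed_thresh` logic above accepts either a single confidence threshold or a per-class sequence. A small illustrative sketch of that filtering on toy scores (the function name and values are made up):

```python
import torch

def filter_by_threshold(scores, labels, confidence_threshold):
    # Per-detection threshold: a scalar, a 1-element list, or one value per class.
    thresh = scores.clone()
    for i, lb in enumerate(labels.tolist()):
        if isinstance(confidence_threshold, float):
            thresh[i] = confidence_threshold
        elif len(confidence_threshold) == 1:
            thresh[i] = confidence_threshold[0]
        else:
            thresh[i] = confidence_threshold[lb - 1]  # labels are 1-based
    return torch.nonzero(scores > thresh).squeeze(1)

scores = torch.tensor([0.90, 0.40, 0.75])
labels = torch.tensor([1, 2, 2])
print(filter_by_threshold(scores, labels, 0.5).tolist())           # [0, 2]
print(filter_by_threshold(scores, labels, [0.85, 0.70]).tolist())  # [0, 2]
```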
+ It should contain the field `labels`. + """ + scores = predictions.get_field("scores") + boxes = predictions.bbox + + for box, score in zip(boxes, scores): + box = box.to(torch.int64) + image = cv2.putText(image, '%.3f'%score, + (box[0], (box[1]+box[3])/2), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, + (255,255,255), 1) + + return image + + def overlay_cboxes(self, image, predictions): + """ + Adds the predicted boxes on top of the image + + Arguments: + image (np.ndarray): an image as returned by OpenCV + predictions (BoxList): the result of the computation by the model. + It should contain the field `labels`. + """ + scores = predictions.get_field("scores") + boxes = predictions.bbox + for box, score in zip(boxes, scores): + box = box.to(torch.int64) + top_left, bottom_right = box[:2].tolist(), box[2:].tolist() + image = cv2.rectangle( + image, tuple(top_left), tuple(bottom_right), (255,0,0), 2) + image = cv2.putText(image, '%.3f'%score, + (box[0], (box[1]+box[3])/2), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, + (255,0,0), 1) + return image + + def overlay_centers(self, image, predictions): + """ + Adds the predicted boxes on top of the image + + Arguments: + image (np.ndarray): an image as returned by OpenCV + predictions (BoxList): the result of the computation by the model. + It should contain the field `labels`. + """ + centers = predictions.get_field("centers") + + for cord in centers: + cord = cord.to(torch.int64) + image = cv2.circle(image, (cord[0].item(),cord[1].item()), + 2, (255,0,0), 20) + + return image + + def overlay_count(self, image, predictions): + """ + Adds the predicted boxes on top of the image + + Arguments: + image (np.ndarray): an image as returned by OpenCV + predictions (BoxList): the result of the computation by the model. + It should contain the field `labels`. + """ + if isinstance(predictions, int): + count = predictions + else: + count = len(predictions) + image = cv2.putText(image, 'Count: %d'%count, (0,100), cv2.FONT_HERSHEY_SIMPLEX, 3, (255,0,0), 3) + + return image + + def overlay_mask(self, image, predictions): + """ + Adds the instances contours for each predicted object. + Each label has a different color. + + Arguments: + image (np.ndarray): an image as returned by OpenCV + predictions (BoxList): the result of the computation by the model. + It should contain the field `mask` and `labels`. + """ + masks = predictions.get_field("mask").numpy() + labels = predictions.get_field("labels") + + colors = self.compute_colors_for_labels(labels).tolist() + + for mask, color in zip(masks, colors): + thresh = mask[0, :, :, None].astype(np.uint8) + contours, hierarchy = cv2_util.findContours( + thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + image = cv2.drawContours(image, contours, -1, color, 3) + + composite = image + + return composite + + def overlay_keypoints(self, image, predictions): + keypoints = predictions.get_field("keypoints") + kps = keypoints.keypoints + scores = keypoints.get_field("logits") + kps = torch.cat((kps[:, :, 0:2], scores[:, :, None]), dim=2).numpy() + for region in kps: + image = vis_keypoints(image, region.transpose((1, 0)), + names=keypoints.NAMES, connections=keypoints.CONNECTIONS) + return image + + def create_mask_montage(self, image, predictions): + """ + Create a montage showing the probability heatmaps for each one one of the + detected objects + + Arguments: + image (np.ndarray): an image as returned by OpenCV + predictions (BoxList): the result of the computation by the model. + It should contain the field `mask`. 
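The overlay helpers above pick a deterministic color per class via `compute_colors_for_labels`; the same arithmetic in isolation, on toy labels:

```python
import torch

# Multiply the label id into three large constants and wrap each channel
# into 0..254, giving a stable color triple per class label.
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
labels = torch.tensor([1, 2, 3])
colors = ((30 * (labels[:, None] - 1) + 1) * palette) % 255
print(colors.numpy().astype("uint8"))  # one color per label
```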
+ """ + masks = predictions.get_field("mask") + masks_per_dim = self.masks_per_dim + masks = L.interpolate( + masks.float(), scale_factor=1 / masks_per_dim + ).byte() + height, width = masks.shape[-2:] + max_masks = masks_per_dim ** 2 + masks = masks[:max_masks] + # handle case where we have less detections than max_masks + if len(masks) < max_masks: + masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8) + masks_padded[: len(masks)] = masks + masks = masks_padded + masks = masks.reshape(masks_per_dim, masks_per_dim, height, width) + result = torch.zeros( + (masks_per_dim * height, masks_per_dim * width), dtype=torch.uint8 + ) + for y in range(masks_per_dim): + start_y = y * height + end_y = (y + 1) * height + for x in range(masks_per_dim): + start_x = x * width + end_x = (x + 1) * width + result[start_y:end_y, start_x:end_x] = masks[y, x] + return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET) + + def overlay_class_names(self, image, predictions, names=None): + """ + Adds detected class names and scores in the positions defined by the + top-left corner of the predicted bounding box + + Arguments: + image (np.ndarray): an image as returned by OpenCV + predictions (BoxList): the result of the computation by the model. + It should contain the field `scores` and `labels`. + """ + scores = predictions.get_field("scores").tolist() + labels = predictions.get_field("labels").tolist() + if names: + labels = [names[i-1] for i in labels] + else: + labels = [self.CATEGORIES[i] for i in labels] + boxes = predictions.bbox + + template = "{}: {:.2f}" + for box, score, label in zip(boxes, scores, labels): + x, y = box[:2] + s = template.format(label, score) + cv2.putText( + image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1 + ) + + return image + +def vis_keypoints(img, kps, kp_thresh=0, alpha=0.7, names=None, connections=None): + """Visualizes keypoints (adapted from vis_one_image). + kps has shape (4, #keypoints) where 4 rows are (x, y, logit, prob). + """ + + dataset_keypoints = names + kp_lines = connections + + # simple rainbow color map implementation + blue_red_ratio = 0.8 + gx = lambda x: (6-2*blue_red_ratio)*x + blue_red_ratio + colors = [[256*max(0, (3-abs(gx(i)-4)-abs(gx(i)-5))/2), + 256*max(0, (3-abs(gx(i)-2)-abs(gx(i)-4))/2), + 256*max(0, (3-abs(gx(i)-1)-abs(gx(i)-2))/2),] for i in np.linspace(0, 1, len(kp_lines) + 2)] + + # Perform the drawing on a copy of the image, to allow for blending. + kp_mask = np.copy(img) + + # Draw mid shoulder / mid hip first for better visualization. 
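`create_mask_montage` above tiles up to `masks_per_dim ** 2` mask heatmaps into a single grid before color-mapping. The layout step on its own, with tiny made-up masks:

```python
import torch

masks_per_dim, h, w = 2, 4, 4
masks = torch.arange(3 * h * w, dtype=torch.uint8).reshape(3, 1, h, w)  # 3 toy masks

# Pad up to a full grid, then tile row-major into one image.
max_masks = masks_per_dim ** 2
padded = torch.zeros(max_masks, 1, h, w, dtype=torch.uint8)
padded[: len(masks)] = masks
grid = padded.reshape(masks_per_dim, masks_per_dim, h, w)

montage = torch.zeros(masks_per_dim * h, masks_per_dim * w, dtype=torch.uint8)
for y in range(masks_per_dim):
    for x in range(masks_per_dim):
        montage[y * h:(y + 1) * h, x * w:(x + 1) * w] = grid[y, x]
print(montage.shape)  # torch.Size([8, 8])
```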
+ mid_shoulder = ( + kps[:2, dataset_keypoints.index('right_shoulder')] + + kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0 + sc_mid_shoulder = np.minimum( + kps[2, dataset_keypoints.index('right_shoulder')], + kps[2, dataset_keypoints.index('left_shoulder')]) + nose_idx = dataset_keypoints.index('nose') + if sc_mid_shoulder > kp_thresh and kps[2, nose_idx] > kp_thresh: + cv2.line( + kp_mask, tuple(mid_shoulder), tuple(kps[:2, nose_idx]), + color=colors[len(kp_lines)], thickness=2, lineType=cv2.LINE_AA) + + if 'right_hip' in names and 'left_hip' in names: + mid_hip = ( + kps[:2, dataset_keypoints.index('right_hip')] + + kps[:2, dataset_keypoints.index('left_hip')]) / 2.0 + sc_mid_hip = np.minimum( + kps[2, dataset_keypoints.index('right_hip')], + kps[2, dataset_keypoints.index('left_hip')]) + if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh: + cv2.line( + kp_mask, tuple(mid_shoulder), tuple(mid_hip), + color=colors[len(kp_lines) + 1], thickness=2, lineType=cv2.LINE_AA) + + # Draw the keypoints. + for l in range(len(kp_lines)): + i1 = kp_lines[l][0] + i2 = kp_lines[l][1] + p1 = kps[0, i1], kps[1, i1] + p2 = kps[0, i2], kps[1, i2] + if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh: + cv2.line( + kp_mask, p1, p2, + color=colors[l], thickness=2, lineType=cv2.LINE_AA) + if kps[2, i1] > kp_thresh: + cv2.circle( + kp_mask, p1, + radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA) + if kps[2, i2] > kp_thresh: + cv2.circle( + kp_mask, p2, + radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA) + + # Blend the keypoints. + return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0) \ No newline at end of file diff --git a/maskrcnn_benchmark/engine/predictor_glip.py b/maskrcnn_benchmark/engine/predictor_glip.py new file mode 100644 index 0000000000000000000000000000000000000000..cbdfcc24b6abf711a77217d03c83bad7d6c6f442 --- /dev/null +++ b/maskrcnn_benchmark/engine/predictor_glip.py @@ -0,0 +1,471 @@ +import cv2 +import torch +import re +import numpy as np +from typing import List, Union +import nltk +import inflect +from transformers import AutoTokenizer +from torchvision import transforms as T +import pdb +from maskrcnn_benchmark.modeling.detector import build_detection_model +from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer +from maskrcnn_benchmark.structures.image_list import to_image_list +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark import layers as L +from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker +from maskrcnn_benchmark.utils import cv2_util + +engine = inflect.engine() +nltk.download('punkt') +nltk.download('averaged_perceptron_tagger') + +import timeit + + +class GLIPDemo(object): + def __init__(self, + cfg, + confidence_threshold=0.7, + min_image_size=None, + show_mask_heatmaps=False, + masks_per_dim=5, + load_model=True + ): + self.cfg = cfg.clone() + if load_model: + self.model = build_detection_model(cfg) + self.model.eval() + self.device = torch.device(cfg.MODEL.DEVICE) + self.model.to(self.device) + self.min_image_size = min_image_size + self.show_mask_heatmaps = show_mask_heatmaps + self.masks_per_dim = masks_per_dim + + save_dir = cfg.OUTPUT_DIR + if load_model: + checkpointer = DetectronCheckpointer(cfg, self.model, save_dir=save_dir) + _ = checkpointer.load(cfg.MODEL.WEIGHT) + + self.transforms = self.build_transform() + + # used to make colors for each tokens + mask_threshold = -1 
if show_mask_heatmaps else 0.5 + self.masker = Masker(threshold=mask_threshold, padding=1) + self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + self.cpu_device = torch.device("cpu") + self.confidence_threshold = confidence_threshold + + self.tokenizer = self.build_tokenizer() + + def build_transform(self): + """ + Creates a basic transformation that was used to train the models + """ + cfg = self.cfg + + # we are loading images with OpenCV, so we don't need to convert them + # to BGR, they are already! So all we need to do is to normalize + # by 255 if we want to convert to BGR255 format, or flip the channels + # if we want it to be in RGB in [0-1] range. + if cfg.INPUT.TO_BGR255: + to_bgr_transform = T.Lambda(lambda x: x * 255) + else: + to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]]) + + normalize_transform = T.Normalize( + mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD + ) + + transform = T.Compose( + [ + T.ToPILImage(), + T.Resize(self.min_image_size) if self.min_image_size is not None else lambda x: x, + T.ToTensor(), + to_bgr_transform, + normalize_transform, + ] + ) + return transform + + def build_tokenizer(self): + cfg = self.cfg + tokenizer = None + if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased": + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": + from transformers import CLIPTokenizerFast + if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: + tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", + from_slow=True, mask_token='ðŁĴij') + else: + tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", + from_slow=True) + return tokenizer + + def run_ner(self, caption): + noun_phrases = find_noun_phrases(caption) + noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases] + noun_phrases = [phrase for phrase in noun_phrases if phrase != ''] + relevant_phrases = noun_phrases + labels = noun_phrases + self.entities = labels + + tokens_positive = [] + + for entity, label in zip(relevant_phrases, labels): + try: + # search all occurrences and mark them as different entities + for m in re.finditer(entity, caption.lower()): + tokens_positive.append([[m.start(), m.end()]]) + except: + print("noun entities:", noun_phrases) + print("entity:", entity) + print("caption:", caption.lower()) + + return tokens_positive + + def inference(self, original_image, original_caption): + predictions = self.compute_prediction(original_image, original_caption) + top_predictions = self._post_process_fixed_thresh(predictions) + return top_predictions + + def run_on_web_image(self, + original_image, + original_caption, + thresh=0.5, + custom_entity=None, + alpha=0.0): + predictions = self.compute_prediction(original_image, original_caption, custom_entity) + top_predictions = self._post_process(predictions, thresh) + + result = original_image.copy() + if self.show_mask_heatmaps: + return self.create_mask_montage(result, top_predictions) + result = self.overlay_boxes(result, top_predictions) + result = self.overlay_entity_names(result, top_predictions) + if self.cfg.MODEL.MASK_ON: + result = self.overlay_mask(result, top_predictions) + return result, top_predictions + + def visualize_with_predictions(self, + original_image, + predictions, + thresh=0.5, + alpha=0.0, + box_pixel=3, + text_size=1, + text_pixel=2, + text_offset=10, + text_offset_original=4, + color=255): + self.color = color + height, width = original_image.shape[:-1] + 
predictions = predictions.resize((width, height)) + top_predictions = self._post_process(predictions, thresh) + + result = original_image.copy() + if self.show_mask_heatmaps: + return self.create_mask_montage(result, top_predictions) + result = self.overlay_boxes(result, top_predictions, alpha=alpha, box_pixel=box_pixel) + result = self.overlay_entity_names(result, top_predictions, text_size=text_size, text_pixel=text_pixel, + text_offset=text_offset, text_offset_original=text_offset_original) + if self.cfg.MODEL.MASK_ON: + result = self.overlay_mask(result, top_predictions) + return result, top_predictions + + def compute_prediction(self, original_image, original_caption, custom_entity=None): + # image + image = self.transforms(original_image) + image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY) + image_list = image_list.to(self.device) + # caption + if isinstance(original_caption, list): + # we directly provided a list of category names + caption_string = "" + tokens_positive = [] + seperation_tokens = " . " + for word in original_caption: + tokens_positive.append([len(caption_string), len(caption_string) + len(word)]) + caption_string += word + caption_string += seperation_tokens + + tokenized = self.tokenizer([caption_string], return_tensors="pt") + tokens_positive = [tokens_positive] + + original_caption = caption_string + print(tokens_positive) + else: + tokenized = self.tokenizer([original_caption], return_tensors="pt") + if custom_entity is None: + tokens_positive = self.run_ner(original_caption) + print(tokens_positive) + # process positive map + positive_map = create_positive_map(tokenized, tokens_positive) + + if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD": + plus = 1 + else: + plus = 0 + + positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus) + self.plus = plus + self.positive_map_label_to_token = positive_map_label_to_token + tic = timeit.time.perf_counter() + + # compute predictions + with torch.no_grad(): + predictions = self.model(image_list, captions=[original_caption], positive_map=positive_map_label_to_token) + predictions = [o.to(self.cpu_device) for o in predictions] + print("inference time per image: {}".format(timeit.time.perf_counter() - tic)) + + # always single image is passed at a time + prediction = predictions[0] + + # reshape prediction (a BoxList) into the original image size + height, width = original_image.shape[:-1] + prediction = prediction.resize((width, height)) + + if prediction.has_field("mask"): + # if we have masks, paste the masks in the right position + # in the image, as defined by the bounding boxes + masks = prediction.get_field("mask") + # always single image is passed at a time + masks = self.masker([masks], [prediction])[0] + prediction.add_field("mask", masks) + + return prediction + + def _post_process_fixed_thresh(self, predictions): + scores = predictions.get_field("scores") + labels = predictions.get_field("labels").tolist() + thresh = scores.clone() + for i, lb in enumerate(labels): + if isinstance(self.confidence_threshold, float): + thresh[i] = self.confidence_threshold + elif len(self.confidence_threshold) == 1: + thresh[i] = self.confidence_threshold[0] + else: + thresh[i] = self.confidence_threshold[lb - 1] + keep = torch.nonzero(scores > thresh).squeeze(1) + predictions = predictions[keep] + + scores = predictions.get_field("scores") + _, idx = scores.sort(0, descending=True) + return predictions[idx] + + def _post_process(self, predictions, 
threshold=0.5): + scores = predictions.get_field("scores") + labels = predictions.get_field("labels").tolist() + thresh = scores.clone() + for i, lb in enumerate(labels): + if isinstance(self.confidence_threshold, float): + thresh[i] = threshold + elif len(self.confidence_threshold) == 1: + thresh[i] = threshold + else: + thresh[i] = self.confidence_threshold[lb - 1] + keep = torch.nonzero(scores > thresh).squeeze(1) + predictions = predictions[keep] + + scores = predictions.get_field("scores") + _, idx = scores.sort(0, descending=True) + return predictions[idx] + + def compute_colors_for_labels(self, labels): + """ + Simple function that adds fixed colors depending on the class + """ + colors = (300 * (labels[:, None] - 1) + 1) * self.palette + colors = (colors % 255).numpy().astype("uint8") + try: + colors = (colors * 0 + self.color).astype("uint8") + except: + pass + return colors + + def overlay_boxes(self, image, predictions, alpha=0.5, box_pixel=3): + labels = predictions.get_field("labels") + boxes = predictions.bbox + + colors = self.compute_colors_for_labels(labels).tolist() + new_image = image.copy() + for box, color in zip(boxes, colors): + box = box.to(torch.int64) + top_left, bottom_right = box[:2].tolist(), box[2:].tolist() + new_image = cv2.rectangle( + new_image, tuple(top_left), tuple(bottom_right), tuple(color), box_pixel) + + # Following line overlays transparent rectangle over the image + image = cv2.addWeighted(new_image, alpha, image, 1 - alpha, 0) + + return image + + def overlay_scores(self, image, predictions): + scores = predictions.get_field("scores") + boxes = predictions.bbox + + for box, score in zip(boxes, scores): + box = box.to(torch.int64) + image = cv2.putText(image, '%.3f' % score, + (int(box[0]), int((box[1] + box[3]) / 2)), + cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA) + + return image + + def overlay_entity_names(self, image, predictions, names=None, text_size=0.7, text_pixel=2, text_offset=10, + text_offset_original=4): + scores = predictions.get_field("scores").tolist() + labels = predictions.get_field("labels").tolist() + new_labels = [] + if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD": + plus = 1 + else: + plus = 0 + self.plus = plus + if self.entities and self.plus: + for i in labels: + if i <= len(self.entities): + new_labels.append(self.entities[i - self.plus]) + else: + new_labels.append('object') + # labels = [self.entities[i - self.plus] for i in labels ] + else: + new_labels = ['object' for i in labels] + boxes = predictions.bbox + + template = "{}:{:.2f}" + previous_locations = [] + for box, score, label in zip(boxes, scores, new_labels): + x, y = box[:2] + s = template.format(label, score).replace("_", " ").replace("(", "").replace(")", "") + for x_prev, y_prev in previous_locations: + if abs(x - x_prev) < abs(text_offset) and abs(y - y_prev) < abs(text_offset): + y -= text_offset + + cv2.putText( + image, s, (int(x), int(y) - text_offset_original), cv2.FONT_HERSHEY_SIMPLEX, text_size, + (255, 255, 255), text_pixel, cv2.LINE_AA + ) + previous_locations.append((int(x), int(y))) + + return image + + def overlay_mask(self, image, predictions): + masks = predictions.get_field("mask").numpy() + labels = predictions.get_field("labels") + + colors = self.compute_colors_for_labels(labels).tolist() + + # import pdb + # pdb.set_trace() + # masks = masks > 0.1 + + for mask, color in zip(masks, colors): + thresh = mask[0, :, :, None].astype(np.uint8) + contours, hierarchy = cv2_util.findContours( + thresh, cv2.RETR_TREE, 
cv2.CHAIN_APPROX_SIMPLE + ) + image = cv2.drawContours(image, contours, -1, color, 2) + + composite = image + + return composite + + def create_mask_montage(self, image, predictions): + masks = predictions.get_field("mask") + masks_per_dim = self.masks_per_dim + masks = L.interpolate( + masks.float(), scale_factor=1 / masks_per_dim + ).byte() + height, width = masks.shape[-2:] + max_masks = masks_per_dim ** 2 + masks = masks[:max_masks] + # handle case where we have less detections than max_masks + if len(masks) < max_masks: + masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8) + masks_padded[: len(masks)] = masks + masks = masks_padded + masks = masks.reshape(masks_per_dim, masks_per_dim, height, width) + result = torch.zeros( + (masks_per_dim * height, masks_per_dim * width), dtype=torch.uint8 + ) + for y in range(masks_per_dim): + start_y = y * height + end_y = (y + 1) * height + for x in range(masks_per_dim): + start_x = x * width + end_x = (x + 1) * width + result[start_y:end_y, start_x:end_x] = masks[y, x] + + return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET), None + + +def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0): + positive_map_label_to_token = {} + for i in range(len(positive_map)): + positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist() + return positive_map_label_to_token + + +def create_positive_map(tokenized, tokens_positive): + """construct a map such that positive_map[i,j] = True iff box i is associated to token j""" + positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) + + for j, tok_list in enumerate(tokens_positive): + for (beg, end) in tok_list: + try: + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + except Exception as e: + print("beg:", beg, "end:", end) + print("token_positive:", tokens_positive) + # print("beg_pos:", beg_pos, "end_pos:", end_pos) + raise e + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + + assert beg_pos is not None and end_pos is not None + positive_map[j, beg_pos: end_pos + 1].fill_(1) + return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) + + +def find_noun_phrases(caption: str) -> List[str]: + caption = caption.lower() + tokens = nltk.word_tokenize(caption) + pos_tags = nltk.pos_tag(tokens) + + grammar = "NP: {
?*+}" + cp = nltk.RegexpParser(grammar) + result = cp.parse(pos_tags) + + noun_phrases = list() + for subtree in result.subtrees(): + if subtree.label() == 'NP': + noun_phrases.append(' '.join(t[0] for t in subtree.leaves())) + + return noun_phrases + + +def remove_punctuation(text: str) -> str: + punct = ['|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^', + '\'', '\"', '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.' + ] + for p in punct: + text = text.replace(p, '') + return text.strip() diff --git a/maskrcnn_benchmark/engine/singlepath_trainer.py b/maskrcnn_benchmark/engine/singlepath_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..303f151d11e5d2a4ab7c849cfb213e175f3469c4 --- /dev/null +++ b/maskrcnn_benchmark/engine/singlepath_trainer.py @@ -0,0 +1,141 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import datetime +import logging +import time +import random +import torch +import torch.distributed as dist +from maskrcnn_benchmark.utils.comm import get_world_size, synchronize, broadcast_data +from maskrcnn_benchmark.utils.metric_logger import MetricLogger +from maskrcnn_benchmark.utils.ema import ModelEma + + +def reduce_loss_dict(loss_dict): + """ + Reduce the loss dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + loss_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return loss_dict + with torch.no_grad(): + loss_names = [] + all_losses = [] + for k in sorted(loss_dict.keys()): + loss_names.append(k) + all_losses.append(loss_dict[k]) + all_losses = torch.stack(all_losses, dim=0) + dist.reduce(all_losses, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + all_losses /= world_size + reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} + return reduced_losses + + +def do_train( + cfg, + model, + data_loader, + optimizer, + scheduler, + checkpointer, + device, + checkpoint_period, + arguments, + rngs=None +): + logger = logging.getLogger("maskrcnn_benchmark.trainer") + logger.info("Start training") + meters = MetricLogger(delimiter=" ") + max_iter = len(data_loader) + start_iter = arguments["iteration"] + model.train() + model_ema = None + if cfg.SOLVER.MODEL_EMA>0: + model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA) + start_training_time = time.time() + end = time.time() + + for iteration, (images, targets, _) in enumerate(data_loader, start_iter): + + if any(len(target) < 1 for target in targets): + logger.error("Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}" ) + continue + data_time = time.time() - end + iteration = iteration + 1 + arguments["iteration"] = iteration + + images = images.to(device) + targets = [target.to(device) for target in targets] + + # synchronize rngs + if rngs is None: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + mix_nums = model.module.mix_nums + else: + mix_nums = model.mix_nums + rngs = [random.randint(0, mix-1) for mix in mix_nums] + rngs = broadcast_data(rngs) + + for param in model.parameters(): + param.requires_grad = False + loss_dict = model(images, targets, rngs) + + losses = sum(loss for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = reduce_loss_dict(loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + 
meters.update(loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + scheduler.step() + + if model_ema is not None: + model_ema.update(model) + arguments["model_ema"] = model_ema.state_dict() + + batch_time = time.time() - end + end = time.time() + meters.update(time=batch_time, data=data_time) + + eta_seconds = meters.time.global_avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + + if iteration % 20 == 0 or iteration == max_iter: + logger.info( + meters.delimiter.join( + [ + "eta: {eta}", + "iter: {iter}", + "{meters}", + "lr: {lr:.6f}", + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + iter=iteration, + meters=str(meters), + lr=optimizer.param_groups[0]["lr"], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) + if iteration % checkpoint_period == 0: + checkpointer.save("model_{:07d}".format(iteration), **arguments) + if iteration == max_iter: + if model_ema is not None: + model.load_state_dict(model_ema.state_dict()) + checkpointer.save("model_final", **arguments) + + total_training_time = time.time() - start_training_time + total_time_str = str(datetime.timedelta(seconds=total_training_time)) + logger.info( + "Total training time: {} ({:.4f} s / it)".format( + total_time_str, total_training_time / (max_iter) + ) + ) diff --git a/maskrcnn_benchmark/engine/stage_trainer.py b/maskrcnn_benchmark/engine/stage_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4c925d48bb8fae7ac76afd18bc5ea23a9491827c --- /dev/null +++ b/maskrcnn_benchmark/engine/stage_trainer.py @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import datetime +import logging +import time + +import torch +import torch.distributed as dist + +from maskrcnn_benchmark.utils.comm import get_world_size +from maskrcnn_benchmark.utils.metric_logger import MetricLogger + + +def reduce_loss_dict(all_loss_dict): + """ + Reduce the loss dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + loss_dict, after reduction. 
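When `SOLVER.MODEL_EMA > 0`, the trainer above keeps an exponential moving average of the weights via `ModelEma` from `maskrcnn_benchmark.utils.ema`. A minimal sketch of that kind of EMA update, assuming a plain `nn.Module`; this is an illustration, not the library's implementation:

```python
import copy
import torch

class SimpleEma:
    """Keep a decayed copy of a model's parameters (illustrative only)."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        for ema_p, p in zip(self.shadow.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

model = torch.nn.Linear(4, 2)
ema = SimpleEma(model, decay=0.9)
# ... one optimizer step on `model` would go here ...
ema.update(model)
```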
+ """ + world_size = get_world_size() + with torch.no_grad(): + loss_names = [] + all_losses = [] + for loss_dict in all_loss_dict: + for k in sorted(loss_dict.keys()): + loss_names.append(k) + all_losses.append(loss_dict[k]) + all_losses = torch.stack(all_losses, dim=0) + if world_size > 1: + dist.reduce(all_losses, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + all_losses /= world_size + + reduced_losses = {} + for k, v in zip(loss_names, all_losses): + if k not in reduced_losses: + reduced_losses[k] = v / len(all_loss_dict) + reduced_losses[k] += v / len(all_loss_dict) + + return reduced_losses + + +def do_train( + model, + data_loader, + optimizer, + scheduler, + checkpointer, + device, + checkpoint_period, + arguments, +): + logger = logging.getLogger("maskrcnn_benchmark.trainer") + logger.info("Start training") + meters = MetricLogger(delimiter=" ") + epoch_per_stage = arguments['epoch_per_stage'] + max_iter = sum(len(stage_loader) * epoch_per_stage[si] for si, stage_loader in enumerate(data_loader)) + max_iter += epoch_per_stage[-1] * min(len(stage_loader) for stage_loader in data_loader) + model.train() + start_training_time = time.time() + end = time.time() + + for stage_i, stage_loader in enumerate(data_loader): + for ep in range(epoch_per_stage[stage_i]): + start_iter = arguments["iteration"] + for iteration, (images, targets, _) in enumerate(stage_loader, start_iter): + data_time = time.time() - end + iteration = iteration + 1 + arguments["iteration"] = iteration + + scheduler[stage_i].step() + + all_stage_loss_dict = [] + images = images.to(device) + targets = [target.to(device) for target in targets] + loss_dict = model(images, targets, stage_i) + all_stage_loss_dict.append(loss_dict) + + losses = sum(loss for loss_dict in all_stage_loss_dict for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = reduce_loss_dict(all_stage_loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + meters.update(loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + + batch_time = time.time() - end + end = time.time() + meters.update(time=batch_time, data=data_time) + + eta_seconds = meters.time.global_avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + + if iteration % 20 == 0 or iteration == max_iter: + logger.info( + meters.delimiter.join( + [ + "eta: {eta}", + "iter: {iter}", + "{meters}", + "lr: {lr:.6f}", + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + iter=iteration, + meters=str(meters), + lr=optimizer.param_groups[0]["lr"], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) + if iteration % checkpoint_period == 0: + checkpointer.save("model_{:07d}".format(iteration), **arguments) + if iteration == max_iter: + checkpointer.save("model_final", **arguments) + + for ep in range(epoch_per_stage[-1]): + start_iter = arguments["iteration"] + for iteration, stage_loader in enumerate(zip(*data_loader), start_iter): + data_time = time.time() - end + iteration = iteration + 1 + arguments["iteration"] = iteration + + scheduler[-1].step() + + all_task_loss_dict = [] + for stage_i, (images, targets, _) in enumerate(stage_loader): + images = images.to(device) + targets = [target.to(device) for target in targets] + loss_dict = model(images, targets, stage_i) + all_task_loss_dict.append(loss_dict) + + losses = sum(loss for 
loss_dict in all_task_loss_dict for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = reduce_loss_dict(all_task_loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + meters.update(loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + + batch_time = time.time() - end + end = time.time() + meters.update(time=batch_time, data=data_time) + + eta_seconds = meters.time.global_avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + + if iteration % 20 == 0 or iteration == max_iter: + logger.info( + meters.delimiter.join( + [ + "eta: {eta}", + "iter: {iter}", + "{meters}", + "lr: {lr:.6f}", + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + iter=iteration, + meters=str(meters), + lr=optimizer.param_groups[0]["lr"], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) + if iteration % checkpoint_period == 0: + checkpointer.save("model_{:07d}".format(iteration), **arguments) + if iteration == max_iter: + checkpointer.save("model_final", **arguments) + + total_training_time = time.time() - start_training_time + total_time_str = str(datetime.timedelta(seconds=total_training_time)) + logger.info( + "Total training time: {} ({:.4f} s / it)".format( + total_time_str, total_training_time / (max_iter) + ) + ) diff --git a/maskrcnn_benchmark/engine/trainer.py b/maskrcnn_benchmark/engine/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2939a0443ba5612303087d750d49d724c890855c --- /dev/null +++ b/maskrcnn_benchmark/engine/trainer.py @@ -0,0 +1,360 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import datetime +import logging +import sys +import os +import math +import time + +import torch +import torch.distributed as dist + +from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank +from maskrcnn_benchmark.utils.metric_logger import MetricLogger +from maskrcnn_benchmark.utils.ema import ModelEma +from maskrcnn_benchmark.utils.amp import autocast, GradScaler +from maskrcnn_benchmark.data.datasets.evaluation import evaluate +from .inference import inference +import pdb + +def reduce_loss_dict(loss_dict): + """ + Reduce the loss dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + loss_dict, after reduction. 
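+ If world_size is 1, the input dictionary is returned unchanged.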
+ """ + world_size = get_world_size() + if world_size < 2: + return loss_dict + with torch.no_grad(): + loss_names = [] + all_losses = [] + for k in sorted(loss_dict.keys()): + loss_names.append(k) + all_losses.append(loss_dict[k]) + all_losses = torch.stack(all_losses, dim=0) + dist.reduce(all_losses, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + all_losses /= world_size + reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} + return reduced_losses + + +def do_train( + cfg, + model, + data_loader, + optimizer, + scheduler, + checkpointer, + device, + checkpoint_period, + arguments, + val_data_loader=None, + meters=None, + zero_shot=False +): + logger = logging.getLogger("maskrcnn_benchmark.trainer") + logger.info("Start training") + # meters = MetricLogger(delimiter=" ") + max_iter = len(data_loader) + start_iter = arguments["iteration"] + model.train() + model_ema = None + if cfg.SOLVER.MODEL_EMA > 0: + model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA) + start_training_time = time.time() + end = time.time() + + if cfg.SOLVER.USE_AMP: + scaler = GradScaler() + + global_rank = get_rank() + + if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1: + checkpoint_period = len(data_loader) * cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH + + if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1: + print("Iter per epoch ", len(data_loader) // cfg.SOLVER.MAX_EPOCH ) + + if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1: + patience_counter = 0 + previous_best = 0.0 + + # Adapt the weight decay + if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'): + milestone_target = 0 + for i, milstone in enumerate(list(scheduler.milestones)): + if scheduler.last_epoch >= milstone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO: + milestone_target = i+1 + for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter): + nnegative = sum(len(target) < 1 for target in targets) + nsample = len(targets) + if nsample == nnegative or nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH: + logger.info('[WARNING] Sampled {} negative in {} in a batch, greater the allowed ratio {}, skip'. 
+ format(nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH)) + continue + + data_time = time.time() - end + iteration = iteration + 1 + arguments["iteration"] = iteration + + images = images.to(device) + captions = None + try: + targets = [target.to(device) for target in targets] + captions = [t.get_field("caption") for t in targets if "caption" in t.fields()] + except: + pass + # Freeze language backbone + if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE: + if hasattr(model, "module"): + model.module.language_backbone.eval() + else: + model.language_backbone.eval() + + if cfg.SOLVER.USE_AMP: + with autocast(): + if len(captions) > 0: + loss_dict = model(images, targets, captions, positive_map, greenlight_map = greenlight_map) + else: + loss_dict = model(images, targets) + losses = sum(loss for loss in loss_dict.values()) + + # save checkpoints for further debug if nan happens + # loss_value = losses.item() + # if not math.isfinite(loss_value): + # logging.error(f'=> loss is {loss_value}, stopping training') + # logging.error("Losses are : {}".format(loss_dict)) + # time_str = time.strftime('%Y-%m-%d-%H-%M') + # fname = os.path.join(checkpointer.save_dir, f'{time_str}_states.pth') + # logging.info(f'=> save error state to {fname}') + # dict_to_save = { + # 'x': images, + # 'y': targets, + # 'loss': losses, + # 'states': model.module.state_dict() if hasattr(model, 'module') else model.state_dict() + # } + # if len(captions) > 0: + # dict_to_save['captions'] = captions + # dict_to_save['positive_map'] = positive_map + # torch.save( + # dict_to_save, + # fname + # ) + + + if torch.isnan(losses) or torch.isinf(losses): + logging.error("NaN encountered, ignoring") + losses[losses != losses] = 0 + optimizer.zero_grad() + scaler.scale(losses).backward() + scaler.step(optimizer) + scaler.update() + scheduler.step() + else: + if len(captions) > 0: + loss_dict = model(images, targets, captions, positive_map) + else: + loss_dict = model(images, targets) + losses = sum(loss for loss in loss_dict.values()) + + # loss_value = losses.item() + # if not math.isfinite(loss_value): + # logging.error(f'=> loss is {loss_value}, stopping training') + # time_str = time.strftime('%Y-%m-%d-%H-%M') + # fname = os.path.join(checkpointer.save_dir, f'{time_str}_states.pth') + # logging.info(f'=> save error state to {fname}') + # dict_to_save = { + # 'x': images, + # 'y': targets, + # 'loss': losses, + # 'states': model.module.state_dict() if hasattr(model, 'module') else model.state_dict() + # } + # if len(captions) > 0: + # dict_to_save['captions'] = captions + # dict_to_save['positive_map'] = positive_map + # torch.save( + # dict_to_save, + # fname + # ) + + + if torch.isnan(losses) or torch.isinf(losses): + losses[losses != losses] = 0 + optimizer.zero_grad() + losses.backward() + optimizer.step() + scheduler.step() + + # Adapt the weight decay: only support multiStepLR + if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'): + if milestone_target < len(scheduler.milestones): + next_milestone = list(scheduler.milestones)[milestone_target] + else: + next_milestone = float('inf') + if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO: + gamma = scheduler.gamma + logger.info("Drop the weight decay by {}!".format(gamma)) + for param in optimizer.param_groups: + if 'weight_decay' in param: + param['weight_decay'] *= gamma + # move the target forward + milestone_target += 1 + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = reduce_loss_dict(loss_dict) + 
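+ # The reduced dict only feeds the meters/log line below; the backward pass above
+ # already ran on the local, unreduced losses.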
losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + meters.update(loss=losses_reduced, **loss_dict_reduced) + if model_ema is not None: + model_ema.update(model) + arguments["model_ema"] = model_ema.state_dict() + + batch_time = time.time() - end + end = time.time() + meters.update(time=batch_time, data=data_time) + eta_seconds = meters.time.global_avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + + if iteration % 20 == 0 or iteration == max_iter: + # if iteration % 1 == 0 or iteration == max_iter: + #logger.info( + if global_rank <= 0: + print( + meters.delimiter.join( + [ + "eta: {eta}", + "iter: {iter}", + "{meters}", + "lr: {lr:.6f}", + "wd: {wd:.6f}", + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + iter=iteration, + meters=str(meters), + lr=optimizer.param_groups[0]["lr"], + wd=optimizer.param_groups[0]["weight_decay"], + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) + if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter): + if is_main_process(): + print("Evaluating") + eval_result = 0.0 + model.eval() + if cfg.SOLVER.TEST_WITH_INFERENCE: + with torch.no_grad(): + try: + _model = model.module + except: + _model = model + _result = inference( + model = _model, + data_loader = val_data_loader, + dataset_name="val", + device=device, + expected_results=cfg.TEST.EXPECTED_RESULTS, + expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, + output_folder=None, + cfg=cfg, + verbose=False + ) + if is_main_process(): + eval_result = _result[0].results['bbox']['AP'] + else: + results_dict = {} + cpu_device = torch.device("cpu") + for i, batch in enumerate(val_data_loader): + images, targets, image_ids, positive_map, *_ = batch + with torch.no_grad(): + images = images.to(device) + if positive_map is None: + output = model(images) + else: + captions = [t.get_field("caption") for t in targets if "caption" in t.fields()] + output = model(images, captions, positive_map) + output = [o.to(cpu_device) for o in output] + results_dict.update( + {img_id: result for img_id, result in zip(image_ids, output)} + ) + all_predictions = all_gather(results_dict) + if is_main_process(): + predictions = {} + for p in all_predictions: + predictions.update(p) + predictions = [predictions[i] for i in list(sorted(predictions.keys()))] + eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None, + box_only=cfg.DATASETS.CLASS_AGNOSTIC) + if cfg.DATASETS.CLASS_AGNOSTIC: + eval_result = eval_result.results['box_proposal']['AR@100'] + else: + eval_result = eval_result.results['bbox']['AP'] + model.train() + + if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR: + model_ema.ema.eval() + results_dict = {} + cpu_device = torch.device("cpu") + for i, batch in enumerate(val_data_loader): + images, targets, image_ids, positive_map, positive_map_eval = batch + with torch.no_grad(): + images = images.to(device) + if positive_map is None: + output = model_ema.ema(images) + else: + captions = [t.get_field("caption") for t in targets if "caption" in t.fields()] + output = model_ema.ema(images, captions, positive_map) + output = [o.to(cpu_device) for o in output] + results_dict.update( + {img_id: result for img_id, result in zip(image_ids, output)} + ) + all_predictions = all_gather(results_dict) + if is_main_process(): + predictions = {} + for p in all_predictions: + predictions.update(p) + predictions = [predictions[i] for i in list(sorted(predictions.keys()))] + 
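+ # Score the gathered EMA predictions with the same protocol as the raw model above:
+ # proposal recall (AR@100) for class-agnostic runs, COCO-style box AP otherwise.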
eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None, + box_only=cfg.DATASETS.CLASS_AGNOSTIC) + if cfg.DATASETS.CLASS_AGNOSTIC: + eval_result = eval_result.results['box_proposal']['AR@100'] + else: + eval_result = eval_result.results['bbox']['AP'] + + arguments.update(eval_result=eval_result) + + if cfg.SOLVER.USE_AUTOSTEP: + eval_result = all_gather(eval_result)[0] #broadcast_data([eval_result])[0] + # print("Rank {} eval result gathered".format(cfg.local_rank), eval_result) + scheduler.step(eval_result) + + if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1: + if eval_result < previous_best: + patience_counter += 1 + else: + patience_counter = 0 + previous_best = eval_result + checkpointer.save("model_best", **arguments) + print("Previous Best", previous_best, "Patience Counter", patience_counter, "Eval Result", eval_result) + if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE: + if is_main_process(): + print("\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(iteration, previous_best)) + break + + if iteration % checkpoint_period == 0: + checkpointer.save("model_{:07d}".format(iteration), **arguments) + if iteration == max_iter: + checkpointer.save("model_final", **arguments) + break + + total_training_time = time.time() - start_training_time + total_time_str = str(datetime.timedelta(seconds=total_training_time)) + logger.info( + "Total training time: {} ({:.4f} s / it)".format( + total_time_str, total_training_time / (max_iter) + ) + ) diff --git a/maskrcnn_benchmark/layers/__init__.py b/maskrcnn_benchmark/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d1db2e7b5328cd8231d8045e4bea0fc88dd934 --- /dev/null +++ b/maskrcnn_benchmark/layers/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
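+ # Re-export point for the custom ops used across the codebase (NMS variants, ROIAlign/ROIPool,
+ # deformable convolutions, DropBlock, DyHead, DyReLU, set-prediction losses); everything listed
+ # in __all__ below is importable directly from maskrcnn_benchmark.layers.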
+import torch + +from .batch_norm import FrozenBatchNorm2d, NaiveSyncBatchNorm2d +from .misc import Conv2d, _NewEmptyTensorOp +from .misc import ConvTranspose2d +from .misc import DFConv2d +from .misc import interpolate +from .misc import Scale +from .nms import nms +from .nms import ml_nms +from .nms import soft_nms +from .roi_align import ROIAlign +from .roi_align import roi_align +from .roi_align import ROIAlignV2 +from .roi_pool import ROIPool +from .roi_pool import roi_pool +from .smooth_l1_loss import smooth_l1_loss +from .sigmoid_focal_loss import SigmoidFocalLoss, TokenSigmoidFocalLoss +from .iou_loss import IOULoss, IOUWHLoss +from .deform_conv import DeformConv, ModulatedDeformConv +from .dropblock import DropBlock2D, DropBlock3D +from .evonorm import EvoNorm2d +from .dyrelu import DYReLU, swish +from .se import SELayer, SEBlock +from .dyhead import DyHead +from .set_loss import HungarianMatcher, SetCriterion + +__all__ = ["nms", "ml_nms", "soft_nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool", + "smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate", "swish", + "FrozenBatchNorm2d", "NaiveSyncBatchNorm2d", "SigmoidFocalLoss", "TokenSigmoidFocalLoss", "IOULoss", + "IOUWHLoss", "Scale", "DeformConv", "ModulatedDeformConv", "DyHead", + "DropBlock2D", "DropBlock3D", "EvoNorm2d", "DYReLU", "SELayer", "SEBlock", + "HungarianMatcher", "SetCriterion", "ROIAlignV2", "_NewEmptyTensorOp"] diff --git a/maskrcnn_benchmark/layers/batch_norm.py b/maskrcnn_benchmark/layers/batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2a83dadfa3aa52b3f854017a9fd71655c2a7c3 --- /dev/null +++ b/maskrcnn_benchmark/layers/batch_norm.py @@ -0,0 +1,117 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn + +import torch.distributed as dist +import maskrcnn_benchmark.utils.comm as comm +from torch.autograd.function import Function + +class FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters + are fixed + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def forward(self, x): + scale = self.weight * self.running_var.rsqrt() + bias = self.bias - self.running_mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + return x * scale + bias + + +class AllReduce(Function): + @staticmethod + def forward(ctx, input): + input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())] + # Use allgather instead of allreduce since I don't trust in-place operations .. + dist.all_gather(input_list, input, async_op=False) + inputs = torch.stack(input_list, dim=0) + return torch.sum(inputs, dim=0) + + @staticmethod + def backward(ctx, grad_output): + dist.all_reduce(grad_output, async_op=False) + return grad_output + + +class NaiveSyncBatchNorm2d(nn.BatchNorm2d): + """ + In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient + when the batch size on each worker is different. + (e.g., when scale augmentation is used, or when it is applied to mask head). + + This is a slower but correct alternative to `nn.SyncBatchNorm`. + + Note: + There isn't a single definition of Sync BatchNorm. + + When ``stats_mode==""``, this module computes overall statistics by using + statistics of each worker with equal weight. 
The result is true statistics + of all samples (as if they are all on one worker) only when all workers + have the same (N, H, W). This mode does not support inputs with zero batch size. + + When ``stats_mode=="N"``, this module computes overall statistics by weighting + the statistics of each worker by their ``N``. The result is true statistics + of all samples (as if they are all on one worker) only when all workers + have the same (H, W). It is slower than ``stats_mode==""``. + + Even though the result of this module may not be the true statistics of all samples, + it may still be reasonable because it might be preferrable to assign equal weights + to all workers, regardless of their (H, W) dimension, instead of putting larger weight + on larger images. From preliminary experiments, little difference is found between such + a simplified implementation and an accurate computation of overall mean & variance. + """ + + def __init__(self, *args, stats_mode="", **kwargs): + super().__init__(*args, **kwargs) + assert stats_mode in ["", "N"] + self._stats_mode = stats_mode + + def forward(self, input): + if comm.get_world_size() == 1 or not self.training: + return super().forward(input) + + B, C = input.shape[0], input.shape[1] + + mean = torch.mean(input, dim=[0, 2, 3]) + meansqr = torch.mean(input * input, dim=[0, 2, 3]) + + if self._stats_mode == "": + assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.' + vec = torch.cat([mean, meansqr], dim=0) + vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size()) + mean, meansqr = torch.split(vec, C) + momentum = self.momentum + else: + if B == 0: + vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype) + vec = vec + input.sum() # make sure there is gradient w.r.t input + else: + vec = torch.cat( + [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0 + ) + vec = AllReduce.apply(vec * B) + + total_batch = vec[-1].detach() + momentum = total_batch.clamp(max=1) * self.momentum # no update if total_batch is 0 + total_batch = torch.max(total_batch, torch.ones_like(total_batch)) # avoid div-by-zero + mean, meansqr, _ = torch.split(vec / total_batch, C) + + var = meansqr - mean * mean + invstd = torch.rsqrt(var + self.eps) + scale = self.weight * invstd + bias = self.bias - mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + + self.running_mean += momentum * (mean.detach() - self.running_mean) + self.running_var += momentum * (var.detach() - self.running_var) + return input * scale + bias \ No newline at end of file diff --git a/maskrcnn_benchmark/layers/deform_conv.py b/maskrcnn_benchmark/layers/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d78dcc10db200b7d8dae1fb4de252ba0868628 --- /dev/null +++ b/maskrcnn_benchmark/layers/deform_conv.py @@ -0,0 +1,436 @@ +import torch +import math +from torch import nn +from torch.nn import init +from torch.nn.modules.utils import _pair +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd + +from maskrcnn_benchmark import _C + +class DeformConvFunction(Function): + + @staticmethod + def forward( + ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64 + ): + if input is not None and input.dim() != 4: + raise ValueError( + "Expected 4D tensor as input, got {}D tensor instead.".format( + input.dim())) + 
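+ # Cache the convolution hyper-parameters and tensors on ctx; backward() below
+ # re-reads them when it calls the _C.deform_conv_backward_* kernels.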
ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty( + DeformConvFunction._output_size(input, weight, ctx.padding, + ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % + cur_im2col_step) == 0, 'im2col step must divide batchsize' + _C.deform_conv_forward( + input, + weight, + offset, + output, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % + cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + _C.deform_conv_backward_input( + input, + offset, + grad_output, + grad_input, + grad_offset, + weight, + ctx.bufs_[0], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step + ) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + _C.deform_conv_backward_parameters( + input, + offset, + grad_output, + grad_weight, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + 1, + cur_im2col_step + ) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + "convolution input is too small (output would be {})".format( + 'x'.join(map(str, output_size)))) + return output_size + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward( + ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1 + ): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty( + 
ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + _C.modulated_deform_conv_forward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + output, + ctx._bufs[1], + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + _C.modulated_deform_conv_backward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias + ) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - + (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - + (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False + ): + assert not bias + super(DeformConv, self).__init__() + self.with_bias = bias + + assert in_channels % groups == 0, \ + 'in_channels {} cannot be divisible by groups {}'.format( + in_channels, groups) + assert out_channels % groups == 0, \ + 'out_channels {} cannot be divisible by groups {}'.format( + out_channels, groups) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, + *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. 
/ math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, input, offset): + return deform_conv(input, offset, self.weight, self.stride, + self.padding, self.dilation, self.groups, + self.deformable_groups) + + def __repr__(self): + return "".join([ + "{}(".format(self.__class__.__name__), + "in_channels={}, ".format(self.in_channels), + "out_channels={}, ".format(self.out_channels), + "kernel_size={}, ".format(self.kernel_size), + "stride={}, ".format(self.stride), + "dilation={}, ".format(self.dilation), + "padding={}, ".format(self.padding), + "groups={}, ".format(self.groups), + "deformable_groups={}, ".format(self.deformable_groups), + "bias={})".format(self.with_bias), + ]) + +class ModulatedDeformConv(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True + ): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + + self.weight = nn.Parameter(torch.Tensor( + out_channels, + in_channels // groups, + *self.kernel_size + )) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, input, offset, mask): + return modulated_deform_conv( + input, offset, mask, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, self.deformable_groups) + + def __repr__(self): + return "".join([ + "{}(".format(self.__class__.__name__), + "in_channels={}, ".format(self.in_channels), + "out_channels={}, ".format(self.out_channels), + "kernel_size={}, ".format(self.kernel_size), + "stride={}, ".format(self.stride), + "dilation={}, ".format(self.dilation), + "padding={}, ".format(self.padding), + "groups={}, ".format(self.groups), + "deformable_groups={}, ".format(self.deformable_groups), + "bias={})".format(self.with_bias), + ]) + +class ModulatedDeformConvPack(ModulatedDeformConv): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConvPack, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, deformable_groups, bias) + + self.conv_offset_mask = nn.Conv2d( + self.in_channels // self.groups, + self.deformable_groups * 3 * self.kernel_size[0] * + self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, input): + out = self.conv_offset_mask(input) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv( + input, offset, mask, self.weight, self.bias, self.stride, + self.padding, 
self.dilation, self.groups, self.deformable_groups) diff --git a/maskrcnn_benchmark/layers/deform_pool.py b/maskrcnn_benchmark/layers/deform_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..1202e1d657c45cba7c8fa34a2684ae13e957ca30 --- /dev/null +++ b/maskrcnn_benchmark/layers/deform_pool.py @@ -0,0 +1,423 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .deform_conv import DeformConv2d + +def add_conv(in_ch, out_ch, ksize, stride, leaky=True): + """ + Add a conv2d / batchnorm / leaky ReLU block. + Args: + in_ch (int): number of input channels of the convolution layer. + out_ch (int): number of output channels of the convolution layer. + ksize (int): kernel size of the convolution layer. + stride (int): stride of the convolution layer. + Returns: + stage (Sequential) : Sequential layers composing a convolution block. + """ + stage = nn.Sequential() + pad = (ksize - 1) // 2 + stage.add_module('conv', nn.Conv2d(in_channels=in_ch, + out_channels=out_ch, kernel_size=ksize, stride=stride, + padding=pad, bias=False)) + stage.add_module('batch_norm', nn.BatchNorm2d(out_ch)) + if leaky: + stage.add_module('leaky', nn.LeakyReLU(0.1)) + else: + stage.add_module('relu6', nn.ReLU6(inplace=True)) + return stage + + +class upsample(nn.Module): + __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name'] + + def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None): + super(upsample, self).__init__() + self.name = type(self).__name__ + self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, input): + return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners) + + def extra_repr(self): + if self.scale_factor is not None: + info = 'scale_factor=' + str(self.scale_factor) + else: + info = 'size=' + str(self.size) + info += ', mode=' + self.mode + return info + +class SPPLayer(nn.Module): + def __init__(self): + super(SPPLayer, self).__init__() + + def forward(self, x): + x_1 = x + x_2 = F.max_pool2d(x, 5, stride=1, padding=2) + x_3 = F.max_pool2d(x, 9, stride=1, padding=4) + x_4 = F.max_pool2d(x, 13, stride=1, padding=6) + out = torch.cat((x_1, x_2, x_3, x_4),dim=1) + return out + +class DropBlock(nn.Module): + def __init__(self, block_size=7, keep_prob=0.9): + super(DropBlock, self).__init__() + self.block_size = block_size + self.keep_prob = keep_prob + self.gamma = None + self.kernel_size = (block_size, block_size) + self.stride = (1, 1) + self.padding = (block_size//2, block_size//2) + + def reset(self, block_size, keep_prob): + self.block_size = block_size + self.keep_prob = keep_prob + self.gamma = None + self.kernel_size = (block_size, block_size) + self.stride = (1, 1) + self.padding = (block_size//2, block_size//2) + + def calculate_gamma(self, x): + return (1-self.keep_prob) * x.shape[-1]**2/ \ + (self.block_size**2 * (x.shape[-1] - self.block_size + 1)**2) + + def forward(self, x): + if (not self.training or self.keep_prob==1): #set keep_prob=1 to turn off dropblock + return x + if self.gamma is None: + self.gamma = self.calculate_gamma(x) + if x.type() == 'torch.cuda.HalfTensor': #TODO: not fully support for FP16 now + FP16 = True + x = x.float() + else: + FP16 = False + p = torch.ones_like(x) * (self.gamma) + mask = 1 - torch.nn.functional.max_pool2d(torch.bernoulli(p), + self.kernel_size, + self.stride, + self.padding) + + out = mask * x * (mask.numel()/mask.sum()) + + if FP16: + out = 
out.half() + return out + +class resblock(nn.Module): + """ + Sequential residual blocks each of which consists of \ + two convolution layers. + Args: + ch (int): number of input and output channels. + nblocks (int): number of residual blocks. + shortcut (bool): if True, residual tensor addition is enabled. + """ + def __init__(self, ch, nblocks=1, shortcut=True): + + super().__init__() + self.shortcut = shortcut + self.module_list = nn.ModuleList() + for i in range(nblocks): + resblock_one = nn.ModuleList() + resblock_one.append(add_conv(ch, ch//2, 1, 1)) + resblock_one.append(add_conv(ch//2, ch, 3, 1)) + self.module_list.append(resblock_one) + + def forward(self, x): + for module in self.module_list: + h = x + for res in module: + h = res(h) + x = x + h if self.shortcut else h + return x + + +class RFBblock(nn.Module): + def __init__(self,in_ch,residual=False): + super(RFBblock, self).__init__() + inter_c = in_ch // 4 + self.branch_0 = nn.Sequential( + nn.Conv2d(in_channels=in_ch, out_channels=inter_c, kernel_size=1, stride=1, padding=0), + ) + self.branch_1 = nn.Sequential( + nn.Conv2d(in_channels=in_ch, out_channels=inter_c, kernel_size=1, stride=1, padding=0), + nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=3, stride=1, padding=1) + ) + self.branch_2 = nn.Sequential( + nn.Conv2d(in_channels=in_ch, out_channels=inter_c, kernel_size=1, stride=1, padding=0), + nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=3, stride=1, padding=1), + nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=3, stride=1, dilation=2, padding=2) + ) + self.branch_3 = nn.Sequential( + nn.Conv2d(in_channels=in_ch, out_channels=inter_c, kernel_size=1, stride=1, padding=0), + nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=5, stride=1, padding=2), + nn.Conv2d(in_channels=inter_c, out_channels=inter_c, kernel_size=3, stride=1, dilation=3, padding=3) + ) + self.residual= residual + + def forward(self,x): + x_0 = self.branch_0(x) + x_1 = self.branch_1(x) + x_2 = self.branch_2(x) + x_3 = self.branch_3(x) + out = torch.cat((x_0,x_1,x_2,x_3),1) + if self.residual: + out +=x + return out + + +class FeatureAdaption(nn.Module): + def __init__(self, in_ch, out_ch, n_anchors, rfb=False, sep=False): + super(FeatureAdaption, self).__init__() + if sep: + self.sep=True + else: + self.sep=False + self.conv_offset = nn.Conv2d(in_channels=2*n_anchors, + out_channels=2*9*n_anchors, groups = n_anchors, kernel_size=1,stride=1,padding=0) + self.dconv = DeformConv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, + padding=1, deformable_groups=n_anchors) + self.rfb=None + if rfb: + self.rfb = RFBblock(out_ch) + + def forward(self, input, wh_pred): + #The RFB block is added behind FeatureAdaption + #For mobilenet, we currently don't support rfb and FeatureAdaption + if self.sep: + return input + if self.rfb is not None: + input = self.rfb(input) + wh_pred_new = wh_pred.detach() + offset = self.conv_offset(wh_pred_new) + out = self.dconv(input, offset) + return out + + +class ASFFmobile(nn.Module): + def __init__(self, level, rfb=False, vis=False): + super(ASFFmobile, self).__init__() + self.level = level + self.dim = [512, 256, 128] + self.inter_dim = self.dim[self.level] + if level==0: + self.stride_level_1 = add_conv(256, self.inter_dim, 3, 2, leaky=False) + self.stride_level_2 = add_conv(128, self.inter_dim, 3, 2, leaky=False) + self.expand = add_conv(self.inter_dim, 1024, 3, 1, leaky=False) + elif level==1: + self.compress_level_0 = add_conv(512, 
self.inter_dim, 1, 1, leaky=False) + self.stride_level_2 = add_conv(128, self.inter_dim, 3, 2, leaky=False) + self.expand = add_conv(self.inter_dim, 512, 3, 1, leaky=False) + elif level==2: + self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1, leaky=False) + self.compress_level_1 = add_conv(256, self.inter_dim, 1, 1, leaky=False) + self.expand = add_conv(self.inter_dim, 256, 3, 1,leaky=False) + + compress_c = 8 if rfb else 16 #when adding rfb, we use half number of channels to save memory + + self.weight_level_0 = add_conv(self.inter_dim, compress_c, 1, 1, leaky=False) + self.weight_level_1 = add_conv(self.inter_dim, compress_c, 1, 1, leaky=False) + self.weight_level_2 = add_conv(self.inter_dim, compress_c, 1, 1, leaky=False) + + self.weight_levels = nn.Conv2d(compress_c*3, 3, kernel_size=1, stride=1, padding=0) + self.vis= vis + + + def forward(self, x_level_0, x_level_1, x_level_2): + if self.level==0: + level_0_resized = x_level_0 + level_1_resized = self.stride_level_1(x_level_1) + + level_2_downsampled_inter =F.max_pool2d(x_level_2, 3, stride=2, padding=1) + level_2_resized = self.stride_level_2(level_2_downsampled_inter) + + elif self.level==1: + level_0_compressed = self.compress_level_0(x_level_0) + level_0_resized =F.interpolate(level_0_compressed, scale_factor=2, mode='nearest') + level_1_resized =x_level_1 + level_2_resized =self.stride_level_2(x_level_2) + elif self.level==2: + level_0_compressed = self.compress_level_0(x_level_0) + level_0_resized =F.interpolate(level_0_compressed, scale_factor=4, mode='nearest') + level_1_compressed = self.compress_level_1(x_level_1) + level_1_resized =F.interpolate(level_1_compressed, scale_factor=2, mode='nearest') + level_2_resized =x_level_2 + + level_0_weight_v = self.weight_level_0(level_0_resized) + level_1_weight_v = self.weight_level_1(level_1_resized) + level_2_weight_v = self.weight_level_2(level_2_resized) + levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v),1) + levels_weight = self.weight_levels(levels_weight_v) + levels_weight = F.softmax(levels_weight, dim=1) + + fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+ \ + level_1_resized * levels_weight[:,1:2,:,:]+ \ + level_2_resized * levels_weight[:,2:,:,:] + + out = self.expand(fused_out_reduced) + + if self.vis: + return out, levels_weight, fused_out_reduced.sum(dim=1) + else: + return out + + +class ASFF(nn.Module): + def __init__(self, level, rfb=False, vis=False): + super(ASFF, self).__init__() + self.level = level + self.dim = [512, 256, 256] + self.inter_dim = self.dim[self.level] + if level==0: + self.stride_level_1 = add_conv(256, self.inter_dim, 3, 2) + self.stride_level_2 = add_conv(256, self.inter_dim, 3, 2) + self.expand = add_conv(self.inter_dim, 1024, 3, 1) + elif level==1: + self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1) + self.stride_level_2 = add_conv(256, self.inter_dim, 3, 2) + self.expand = add_conv(self.inter_dim, 512, 3, 1) + elif level==2: + self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1) + self.expand = add_conv(self.inter_dim, 256, 3, 1) + + compress_c = 8 if rfb else 16 #when adding rfb, we use half number of channels to save memory + + self.weight_level_0 = add_conv(self.inter_dim, compress_c, 1, 1) + self.weight_level_1 = add_conv(self.inter_dim, compress_c, 1, 1) + self.weight_level_2 = add_conv(self.inter_dim, compress_c, 1, 1) + + self.weight_levels = nn.Conv2d(compress_c*3, 3, kernel_size=1, stride=1, padding=0) + self.vis= vis + + + def forward(self, x_level_0, 
x_level_1, x_level_2): + if self.level==0: + level_0_resized = x_level_0 + level_1_resized = self.stride_level_1(x_level_1) + + level_2_downsampled_inter =F.max_pool2d(x_level_2, 3, stride=2, padding=1) + level_2_resized = self.stride_level_2(level_2_downsampled_inter) + + elif self.level==1: + level_0_compressed = self.compress_level_0(x_level_0) + level_0_resized =F.interpolate(level_0_compressed, scale_factor=2, mode='nearest') + level_1_resized =x_level_1 + level_2_resized =self.stride_level_2(x_level_2) + elif self.level==2: + level_0_compressed = self.compress_level_0(x_level_0) + level_0_resized =F.interpolate(level_0_compressed, scale_factor=4, mode='nearest') + level_1_resized =F.interpolate(x_level_1, scale_factor=2, mode='nearest') + level_2_resized =x_level_2 + + level_0_weight_v = self.weight_level_0(level_0_resized) + level_1_weight_v = self.weight_level_1(level_1_resized) + level_2_weight_v = self.weight_level_2(level_2_resized) + levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v),1) + levels_weight = self.weight_levels(levels_weight_v) + levels_weight = F.softmax(levels_weight, dim=1) + + fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+ \ + level_1_resized * levels_weight[:,1:2,:,:]+ \ + level_2_resized * levels_weight[:,2:,:,:] + + out = self.expand(fused_out_reduced) + + if self.vis: + return out, levels_weight, fused_out_reduced.sum(dim=1) + else: + return out + +def make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
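+ # e.g. make_divisible(36, 32) first rounds to 32, which is below 0.9 * 36 = 32.4,
+ # so the guard below bumps the result up by one divisor to 64.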
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + padding = (kernel_size - 1) // 2 + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + +def add_sepconv(in_ch, out_ch, ksize, stride): + + stage = nn.Sequential() + pad = (ksize - 1) // 2 + stage.add_module('sepconv', nn.Conv2d(in_channels=in_ch, + out_channels=in_ch, kernel_size=ksize, stride=stride, + padding=pad, groups=in_ch, bias=False)) + stage.add_module('sepbn', nn.BatchNorm2d(in_ch)) + stage.add_module('seprelu6', nn.ReLU6(inplace=True)) + stage.add_module('ptconv', nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False)) + stage.add_module('ptbn', nn.BatchNorm2d(out_ch)) + stage.add_module('ptrelu6', nn.ReLU6(inplace=True)) + return stage + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + +class ressepblock(nn.Module): + def __init__(self, ch, out_ch, in_ch=None, shortcut=True): + + super().__init__() + self.shortcut = shortcut + self.module_list = nn.ModuleList() + in_ch = ch//2 if in_ch==None else in_ch + resblock_one = nn.ModuleList() + resblock_one.append(add_conv(ch, in_ch, 1, 1, leaky=False)) + resblock_one.append(add_conv(in_ch, out_ch, 3, 1,leaky=False)) + self.module_list.append(resblock_one) + + def forward(self, x): + for module in self.module_list: + h = x + for res in module: + h = res(h) + x = x + h if self.shortcut else h + return x + diff --git a/maskrcnn_benchmark/layers/dropblock.py b/maskrcnn_benchmark/layers/dropblock.py new file mode 100644 index 0000000000000000000000000000000000000000..3210b99ec5d82d65e448363315df28c4c5f2d239 --- /dev/null +++ b/maskrcnn_benchmark/layers/dropblock.py @@ -0,0 +1,146 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class DropBlock2D(nn.Module): + r"""Randomly zeroes 2D spatial blocks of the input tensor. + + As described in the paper + `DropBlock: A regularization method for convolutional networks`_ , + dropping whole blocks of feature map allows to remove semantic + information as compared to regular dropout. + + Args: + drop_prob (float): probability of an element to be dropped. + block_size (int): size of the block to drop + + Shape: + - Input: `(N, C, H, W)` + - Output: `(N, C, H, W)` + + .. 
_DropBlock: A regularization method for convolutional networks: + https://arxiv.org/abs/1810.12890 + + """ + + def __init__(self, drop_prob, block_size): + super(DropBlock2D, self).__init__() + + self.drop_prob = drop_prob + self.block_size = block_size + + def forward(self, x): + # shape: (bsize, channels, height, width) + + assert x.dim() == 4, \ + "Expected input with 4 dimensions (bsize, channels, height, width)" + + if not self.training or self.drop_prob == 0.: + return x + else: + # get gamma value + gamma = self._compute_gamma(x) + + # sample mask + mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() + + # place mask on input device + mask = mask.to(x.device) + + # compute block mask + block_mask = self._compute_block_mask(mask) + + # apply block mask + out = x * block_mask[:, None, :, :] + + # scale output + out = out * block_mask.numel() / block_mask.sum() + + return out + + def _compute_block_mask(self, mask): + block_mask = F.max_pool2d(input=mask[:, None, :, :], + kernel_size=(self.block_size, self.block_size), + stride=(1, 1), + padding=self.block_size // 2) + + if self.block_size % 2 == 0: + block_mask = block_mask[:, :, :-1, :-1] + + block_mask = 1 - block_mask.squeeze(1) + + return block_mask + + def _compute_gamma(self, x): + return self.drop_prob / (self.block_size ** 2) + + +class DropBlock3D(DropBlock2D): + r"""Randomly zeroes 3D spatial blocks of the input tensor. + + An extension to the concept described in the paper + `DropBlock: A regularization method for convolutional networks`_ , + dropping whole blocks of feature map allows to remove semantic + information as compared to regular dropout. + + Args: + drop_prob (float): probability of an element to be dropped. + block_size (int): size of the block to drop + + Shape: + - Input: `(N, C, D, H, W)` + - Output: `(N, C, D, H, W)` + + .. 
_DropBlock: A regularization method for convolutional networks: + https://arxiv.org/abs/1810.12890 + + """ + + def __init__(self, drop_prob, block_size): + super(DropBlock3D, self).__init__(drop_prob, block_size) + + def forward(self, x): + # shape: (bsize, channels, depth, height, width) + + assert x.dim() == 5, \ + "Expected input with 5 dimensions (bsize, channels, depth, height, width)" + + if not self.training or self.drop_prob == 0.: + return x + else: + # get gamma value + gamma = self._compute_gamma(x) + + # sample mask + mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() + + # place mask on input device + mask = mask.to(x.device) + + # compute block mask + block_mask = self._compute_block_mask(mask) + + # apply block mask + out = x * block_mask[:, None, :, :, :] + + # scale output + out = out * block_mask.numel() / block_mask.sum() + + return out + + def _compute_block_mask(self, mask): + block_mask = F.max_pool3d(input=mask[:, None, :, :, :], + kernel_size=(self.block_size, self.block_size, self.block_size), + stride=(1, 1, 1), + padding=self.block_size // 2) + + if self.block_size % 2 == 0: + block_mask = block_mask[:, :, :-1, :-1, :-1] + + block_mask = 1 - block_mask.squeeze(1) + + return block_mask + + def _compute_gamma(self, x): + return self.drop_prob / (self.block_size ** 3) \ No newline at end of file diff --git a/maskrcnn_benchmark/layers/dyhead.py b/maskrcnn_benchmark/layers/dyhead.py new file mode 100644 index 0000000000000000000000000000000000000000..91fa88cb0beaef03e6459d671de843496ebe27f4 --- /dev/null +++ b/maskrcnn_benchmark/layers/dyhead.py @@ -0,0 +1,151 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from .deform_conv import ModulatedDeformConv +from .dyrelu import h_sigmoid, DYReLU + + +class Conv3x3Norm(torch.nn.Module): + def __init__(self, + in_channels, + out_channels, + stride, + deformable=False, + use_gn=False): + super(Conv3x3Norm, self).__init__() + + if deformable: + self.conv = ModulatedDeformConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) + else: + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) + + if use_gn: + self.bn = nn.GroupNorm(num_groups=16, num_channels=out_channels) + else: + self.bn = None + + def forward(self, input, **kwargs): + x = self.conv(input, **kwargs) + if self.bn: + x = self.bn(x) + return x + + +class DyConv(nn.Module): + def __init__(self, + in_channels=256, + out_channels=256, + conv_func=Conv3x3Norm, + use_dyfuse=True, + use_dyrelu=False, + use_deform=False + ): + super(DyConv, self).__init__() + + self.DyConv = nn.ModuleList() + self.DyConv.append(conv_func(in_channels, out_channels, 1)) + self.DyConv.append(conv_func(in_channels, out_channels, 1)) + self.DyConv.append(conv_func(in_channels, out_channels, 2)) + + if use_dyfuse: + self.AttnConv = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, 1, kernel_size=1), + nn.ReLU(inplace=True)) + self.h_sigmoid = h_sigmoid() + else: + self.AttnConv = None + + if use_dyrelu: + self.relu = DYReLU(in_channels, out_channels) + else: + self.relu = nn.ReLU() + + if use_deform: + self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1) + else: + self.offset = None + + self.init_weights() + + def init_weights(self): + for m in self.DyConv.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight.data, 0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + if self.AttnConv is not None: + for m in self.AttnConv.modules(): + if 
isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight.data, 0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + next_x = [] + for level, feature in enumerate(x): + + conv_args = dict() + if self.offset is not None: + offset_mask = self.offset(feature) + offset = offset_mask[:, :18, :, :] + mask = offset_mask[:, 18:, :, :].sigmoid() + conv_args = dict(offset=offset, mask=mask) + + temp_fea = [self.DyConv[1](feature, **conv_args)] + + if level > 0: + temp_fea.append(self.DyConv[2](x[level - 1], **conv_args)) + if level < len(x) - 1: + temp_fea.append(F.upsample_bilinear(self.DyConv[0](x[level + 1], **conv_args), + size=[feature.size(2), feature.size(3)])) + mean_fea = torch.mean(torch.stack(temp_fea), dim=0, keepdim=False) + + if self.AttnConv is not None: + attn_fea = [] + res_fea = [] + for fea in temp_fea: + res_fea.append(fea) + attn_fea.append(self.AttnConv(fea)) + + res_fea = torch.stack(res_fea) + spa_pyr_attn = self.h_sigmoid(torch.stack(attn_fea)) + + mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False) + + next_x.append(self.relu(mean_fea)) + + return next_x + + +class DyHead(nn.Module): + def __init__(self, cfg, in_channels): + super(DyHead, self).__init__() + self.cfg = cfg + channels = cfg.MODEL.DYHEAD.CHANNELS + use_gn = cfg.MODEL.DYHEAD.USE_GN + use_dyrelu = cfg.MODEL.DYHEAD.USE_DYRELU + use_dyfuse = cfg.MODEL.DYHEAD.USE_DYFUSE + use_deform = cfg.MODEL.DYHEAD.USE_DFCONV + + conv_func = lambda i,o,s : Conv3x3Norm(i,o,s,deformable=use_deform,use_gn=use_gn) + + dyhead_tower = [] + for i in range(cfg.MODEL.DYHEAD.NUM_CONVS): + dyhead_tower.append( + DyConv( + in_channels if i == 0 else channels, + channels, + conv_func=conv_func, + use_dyrelu=use_dyrelu, + use_dyfuse=use_dyfuse, + use_deform=use_deform + ) + ) + + self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower)) + + def forward(self, x): + dyhead_tower = self.dyhead_tower(x) + return dyhead_tower \ No newline at end of file diff --git a/maskrcnn_benchmark/layers/dyrelu.py b/maskrcnn_benchmark/layers/dyrelu.py new file mode 100644 index 0000000000000000000000000000000000000000..070b2e99df0f473faec0ef5914e1d385fda8e4f4 --- /dev/null +++ b/maskrcnn_benchmark/layers/dyrelu.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
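+ # e.g. _make_divisible(9, 8) rounds to 8, just under 0.9 * 9, so the check below
+ # returns 16; this keeps the DYReLU squeeze width from shrinking too much.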
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class h_swish(nn.Module): + def __init__(self, inplace=False): + super(h_swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 + + +class h_sigmoid(nn.Module): + def __init__(self, inplace=True, h_max=1): + super(h_sigmoid, self).__init__() + self.relu = nn.ReLU6(inplace=inplace) + self.h_max = h_max + + def forward(self, x): + return self.relu(x + 3) * self.h_max / 6 + + +class DYReLU(nn.Module): + def __init__(self, inp, oup, reduction=4, lambda_a=1.0, K2=True, use_bias=True, use_spatial=False, + init_a=[1.0, 0.0], init_b=[0.0, 0.0]): + super(DYReLU, self).__init__() + self.oup = oup + self.lambda_a = lambda_a * 2 + self.K2 = K2 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + self.use_bias = use_bias + if K2: + self.exp = 4 if use_bias else 2 + else: + self.exp = 2 if use_bias else 1 + self.init_a = init_a + self.init_b = init_b + + # determine squeeze + if reduction == 4: + squeeze = inp // reduction + else: + squeeze = _make_divisible(inp // reduction, 4) + # print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze)) + # print('init_a: {}, init_b: {}'.format(self.init_a, self.init_b)) + + self.fc = nn.Sequential( + nn.Linear(inp, squeeze), + nn.ReLU(inplace=True), + nn.Linear(squeeze, oup * self.exp), + h_sigmoid() + ) + if use_spatial: + self.spa = nn.Sequential( + nn.Conv2d(inp, 1, kernel_size=1), + nn.BatchNorm2d(1), + ) + else: + self.spa = None + + def forward(self, x): + if isinstance(x, list): + x_in = x[0] + x_out = x[1] + else: + x_in = x + x_out = x + b, c, h, w = x_in.size() + y = self.avg_pool(x_in).view(b, c) + y = self.fc(y).view(b, self.oup * self.exp, 1, 1) + if self.exp == 4: + a1, b1, a2, b2 = torch.split(y, self.oup, dim=1) + a1 = (a1 - 0.5) * self.lambda_a + self.init_a[0] # 1.0 + a2 = (a2 - 0.5) * self.lambda_a + self.init_a[1] + + b1 = b1 - 0.5 + self.init_b[0] + b2 = b2 - 0.5 + self.init_b[1] + out = torch.max(x_out * a1 + b1, x_out * a2 + b2) + elif self.exp == 2: + if self.use_bias: # bias but not PL + a1, b1 = torch.split(y, self.oup, dim=1) + a1 = (a1 - 0.5) * self.lambda_a + self.init_a[0] # 1.0 + b1 = b1 - 0.5 + self.init_b[0] + out = x_out * a1 + b1 + + else: + a1, a2 = torch.split(y, self.oup, dim=1) + a1 = (a1 - 0.5) * self.lambda_a + self.init_a[0] # 1.0 + a2 = (a2 - 0.5) * self.lambda_a + self.init_a[1] + out = torch.max(x_out * a1, x_out * a2) + + elif self.exp == 1: + a1 = y + a1 = (a1 - 0.5) * self.lambda_a + self.init_a[0] # 1.0 + out = x_out * a1 + + if self.spa: + ys = self.spa(x_in).view(b, -1) + ys = F.softmax(ys, dim=1).view(b, 1, h, w) * h * w + ys = F.hardtanh(ys, 0, 3, inplace=True)/3 + out = out * ys + + return out diff --git a/maskrcnn_benchmark/layers/evonorm.py b/maskrcnn_benchmark/layers/evonorm.py new file mode 100644 index 0000000000000000000000000000000000000000..058c0990427d0070fbb218e54119109536e80b66 --- /dev/null +++ b/maskrcnn_benchmark/layers/evonorm.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn + + +class EvoNorm2d(nn.Module): + __constants__ = ['num_features', 'eps', 'nonlinearity'] + + def __init__(self, num_features, eps=1e-5, nonlinearity=True, group=32): + super(EvoNorm2d, self).__init__() + + self.num_features = num_features + self.eps = eps + self.nonlinearity = nonlinearity + self.group = group + + self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) + self.bias = 
nn.Parameter(torch.Tensor(1, num_features, 1, 1)) + if self.nonlinearity: + self.v = nn.Parameter(torch.Tensor(1, num_features, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.nonlinearity: + nn.init.ones_(self.v) + + def group_std(self, x, groups=32): + N, C, H, W = x.shape + x = torch.reshape(x, (N, groups, C // groups, H, W)) + std = torch.std(x, (3, 4), keepdim=True) + return torch.reshape(std + self.eps, (N, C, 1, 1)) + + def forward(self, x): + if self.nonlinearity: + num = x * torch.sigmoid(self.v * x) + return num / self.group_std(x, self.group) * self.weight + self.bias + else: + return x * self.weight + self.bias \ No newline at end of file diff --git a/maskrcnn_benchmark/layers/iou_loss.py b/maskrcnn_benchmark/layers/iou_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..80703b20bfd0d443e66ee089f2d3653330238dbe --- /dev/null +++ b/maskrcnn_benchmark/layers/iou_loss.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class IOULoss(nn.Module): + def __init__(self, loss_type="iou"): + super(IOULoss, self).__init__() + self.loss_type = loss_type + + def forward(self, pred, target, weight=None): + pred_left = pred[:, 0] + pred_top = pred[:, 1] + pred_right = pred[:, 2] + pred_bottom = pred[:, 3] + + target_left = target[:, 0] + target_top = target[:, 1] + target_right = target[:, 2] + target_bottom = target[:, 3] + + target_area = (target_left + target_right) * \ + (target_top + target_bottom) + pred_area = (pred_left + pred_right) * \ + (pred_top + pred_bottom) + + w_intersect = torch.min(pred_left, target_left) + torch.min(pred_right, target_right) + g_w_intersect = torch.max(pred_left, target_left) + torch.max( + pred_right, target_right) + h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(pred_top, target_top) + g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(pred_top, target_top) + ac_uion = g_w_intersect * g_h_intersect + 1e-7 + area_intersect = w_intersect * h_intersect + area_union = target_area + pred_area - area_intersect + ious = (area_intersect + 1.0) / (area_union + 1.0) + gious = ious - (ac_uion - area_union) / ac_uion + if self.loss_type == 'iou': + losses = -torch.log(ious) + elif self.loss_type == 'linear_iou': + losses = 1 - ious + elif self.loss_type == 'giou': + losses = 1 - gious + else: + raise NotImplementedError + + if weight is not None and weight.sum() > 0: + return (losses * weight).sum() + else: + assert losses.numel() != 0 + return losses.sum() + + +class IOUWHLoss(nn.Module): # used for anchor guiding + def __init__(self, reduction='none'): + super(IOUWHLoss, self).__init__() + self.reduction = reduction + + def forward(self, pred, target): + orig_shape = pred.shape + pred = pred.view(-1, 4) + target = target.view(-1, 4) + target[:, :2] = 0 + tl = torch.max((target[:, :2] - pred[:, 2:] / 2), + (target[:, :2] - target[:, 2:] / 2)) + + br = torch.min((target[:, :2] + pred[:, 2:] / 2), + (target[:, :2] + target[:, 2:] / 2)) + + area_p = torch.prod(pred[:, 2:], 1) + area_g = torch.prod(target[:, 2:], 1) + + en = (tl < br).type(tl.type()).prod(dim=1) + area_i = torch.prod(br - tl, 1) * en + U = area_p + area_g - area_i + 1e-16 + iou = area_i / U + + loss = 1 - iou ** 2 + if self.reduction == 'mean': + loss = loss.mean() + elif self.reduction == 'sum': + loss = loss.sum() + + return loss diff --git a/maskrcnn_benchmark/layers/misc.py b/maskrcnn_benchmark/layers/misc.py new file mode 100644 index 
0000000000000000000000000000000000000000..fe175249c6494e19e3d077478025ef9e8335306d --- /dev/null +++ b/maskrcnn_benchmark/layers/misc.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +helper class that supports empty tensors on some nn functions. + +Ideally, add support directly in PyTorch to empty tensors in +those functions. + +This can be removed once https://github.com/pytorch/pytorch/issues/12013 +is implemented +""" + +import math +import torch +from torch.nn.modules.utils import _ntuple + + +class _NewEmptyTensorOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return _NewEmptyTensorOp.apply(grad, shape), None + + +class Conv2d(torch.nn.Conv2d): + def forward(self, x): + if x.numel() > 0: + return super(Conv2d, self).forward(x) + # get output shape + + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // d + 1 + for i, p, di, k, d in zip( + x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride + ) + ] + output_shape = [x.shape[0], self.weight.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) + + +class ConvTranspose2d(torch.nn.ConvTranspose2d): + def forward(self, x): + if x.numel() > 0: + return super(ConvTranspose2d, self).forward(x) + # get output shape + + output_shape = [ + (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op + for i, p, di, k, d, op in zip( + x.shape[-2:], + self.padding, + self.dilation, + self.kernel_size, + self.stride, + self.output_padding, + ) + ] + output_shape = [x.shape[0], self.bias.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) + + +class BatchNorm2d(torch.nn.BatchNorm2d): + def forward(self, x): + if x.numel() > 0: + return super(BatchNorm2d, self).forward(x) + # get output shape + output_shape = x.shape + return _NewEmptyTensorOp.apply(x, output_shape) + + +def interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + def _check_size_scale_factor(dim): + if size is None and scale_factor is None: + raise ValueError("either size or scale_factor should be defined") + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + if ( + scale_factor is not None + and isinstance(scale_factor, tuple) + and len(scale_factor) != dim + ): + raise ValueError( + "scale_factor shape must match input shape. 
" + "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) + ) + + def _output_size(dim): + _check_size_scale_factor(dim) + if size is not None: + return size + scale_factors = _ntuple(dim)(scale_factor) + # math.floor might return float in py2.7 + return [ + int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) + ] + + output_shape = tuple(_output_size(2)) + output_shape = input.shape[:-2] + output_shape + return _NewEmptyTensorOp.apply(input, output_shape) + + +class Scale(torch.nn.Module): + def __init__(self, init_value=1.0): + super(Scale, self).__init__() + self.scale = torch.nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class DFConv2d(torch.nn.Module): + """Deformable convolutional layer""" + def __init__( + self, + in_channels, + out_channels, + with_modulated_dcn=True, + kernel_size=3, + stride=1, + groups=1, + padding=1, + dilation=1, + deformable_groups=1, + bias=False + ): + super(DFConv2d, self).__init__() + if isinstance(kernel_size, (list, tuple)): + assert len(kernel_size) == 2 + offset_base_channels = kernel_size[0] * kernel_size[1] + else: + offset_base_channels = kernel_size * kernel_size + if with_modulated_dcn: + from maskrcnn_benchmark.layers import ModulatedDeformConv + offset_channels = offset_base_channels * 3 #default: 27 + conv_block = ModulatedDeformConv + else: + from maskrcnn_benchmark.layers import DeformConv + offset_channels = offset_base_channels * 2 #default: 18 + conv_block = DeformConv + self.offset = Conv2d( + in_channels, + deformable_groups * offset_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=1, + dilation=dilation + ) + for l in [self.offset, ]: + torch.nn.init.kaiming_uniform_(l.weight, a=1) + torch.nn.init.constant_(l.bias, 0.) + self.conv = conv_block( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + deformable_groups=deformable_groups, + bias=bias + ) + self.with_modulated_dcn = with_modulated_dcn + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.offset_base_channels = offset_base_channels + + def forward(self, x): + if x.numel() > 0: + if not self.with_modulated_dcn: + offset = self.offset(x) + x = self.conv(x, offset) + else: + offset_mask = self.offset(x) + split_point = self.offset_base_channels * 2 + offset = offset_mask[:, :split_point, :, :] + mask = offset_mask[:, split_point:, :, :].sigmoid() + x = self.conv(x, offset, mask) + return x + # get output shape + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // d + 1 + for i, p, di, k, d in zip( + x.shape[-2:], + self.padding, + self.dilation, + self.kernel_size, + self.stride + ) + ] + output_shape = [x.shape[0], self.conv.weight.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) diff --git a/maskrcnn_benchmark/layers/nms.py b/maskrcnn_benchmark/layers/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..12e81ad4a2183b5fca497d33f0d34d5fcc0d4ea1 --- /dev/null +++ b/maskrcnn_benchmark/layers/nms.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+from maskrcnn_benchmark import _C + +try: + import torchvision + from torchvision.ops import nms +except: + nms = _C.nms + +ml_nms = _C.ml_nms +soft_nms = _C.soft_nms + +# nms.__doc__ = """ +# This function performs Non-maximum suppresion""" diff --git a/maskrcnn_benchmark/layers/roi_align.py b/maskrcnn_benchmark/layers/roi_align.py new file mode 100644 index 0000000000000000000000000000000000000000..247397098aaa7bf71bcb652af5a6664f86265ce9 --- /dev/null +++ b/maskrcnn_benchmark/layers/roi_align.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from maskrcnn_benchmark import _C + +class _ROIAlign(Function): + @staticmethod + def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): + ctx.save_for_backward(roi) + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.sampling_ratio = sampling_ratio + ctx.input_shape = input.size() + output = _C.roi_align_forward( + input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, = ctx.saved_tensors + output_size = ctx.output_size + spatial_scale = ctx.spatial_scale + sampling_ratio = ctx.sampling_ratio + bs, ch, h, w = ctx.input_shape + grad_input = _C.roi_align_backward( + grad_output, + rois, + spatial_scale, + output_size[0], + output_size[1], + bs, + ch, + h, + w, + sampling_ratio, + ) + return grad_input, None, None, None, None + +try: + import torchvision + from torchvision.ops import roi_align +except: + roi_align = _ROIAlign.apply + +class ROIAlign(nn.Module): + def __init__(self, output_size, spatial_scale, sampling_ratio): + super(ROIAlign, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + + def forward(self, input, rois): + return roi_align( + input, rois, self.output_size, self.spatial_scale, self.sampling_ratio + ) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ")" + return tmpstr + +class ROIAlignV2(nn.Module): + def __init__(self, output_size, spatial_scale, sampling_ratio): + super(ROIAlignV2, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + + def forward(self, input, rois): + return torchvision.ops.roi_align( + input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, aligned=True + ) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ")" + return tmpstr diff --git a/maskrcnn_benchmark/layers/roi_pool.py b/maskrcnn_benchmark/layers/roi_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..b69efd75f17326a4fd8f306570624fd5ff4ef9b6 --- /dev/null +++ b/maskrcnn_benchmark/layers/roi_pool.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
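# Both ROIAlign (roi_align.py) and ROIPool below take `rois` as a [K, 5] tensor whose
# first column is the batch index and whose remaining columns are (x1, y1, x2, y2) in
# input-image coordinates; `spatial_scale` maps them onto the feature map. A sketch of
# the torchvision op that ROIAlign wraps when available (illustrative shapes):
#
#   import torch
#   from torchvision.ops import roi_align
#   feat = torch.randn(2, 256, 50, 50)                 # features at stride 16
#   rois = torch.tensor([[0., 32., 32., 160., 160.],   # box on image 0
#                        [1., 0., 0., 64., 64.]])      # box on image 1
#   out = roi_align(feat, rois, output_size=(7, 7), spatial_scale=1.0 / 16,
#                   sampling_ratio=2)                  # -> shape (2, 256, 7, 7)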
+import torch +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from maskrcnn_benchmark import _C + + +class _ROIPool(Function): + @staticmethod + def forward(ctx, input, roi, output_size, spatial_scale): + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.input_shape = input.size() + output, argmax = _C.roi_pool_forward( + input, roi, spatial_scale, output_size[0], output_size[1] + ) + ctx.save_for_backward(input, roi, argmax) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, rois, argmax = ctx.saved_tensors + output_size = ctx.output_size + spatial_scale = ctx.spatial_scale + bs, ch, h, w = ctx.input_shape + grad_input = _C.roi_pool_backward( + grad_output, + input, + rois, + argmax, + spatial_scale, + output_size[0], + output_size[1], + bs, + ch, + h, + w, + ) + return grad_input, None, None, None + + +roi_pool = _ROIPool.apply + + +class ROIPool(nn.Module): + def __init__(self, output_size, spatial_scale): + super(ROIPool, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + + def forward(self, input, rois): + return roi_pool(input, rois, self.output_size, self.spatial_scale) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ")" + return tmpstr diff --git a/maskrcnn_benchmark/layers/se.py b/maskrcnn_benchmark/layers/se.py new file mode 100644 index 0000000000000000000000000000000000000000..f10d09217270c14001fec2795b20d14dd5b73586 --- /dev/null +++ b/maskrcnn_benchmark/layers/se.py @@ -0,0 +1,52 @@ +from torch import nn + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class SEBlock(nn.Module): + def __init__(self, channels, reduction=16, + use_conv=True, mid_activation=nn.ReLU(inplace=True), out_activation=nn.Sigmoid()): + super(SEBlock, self).__init__() + self.use_conv = use_conv + mid_channels = channels // reduction + + self.pool = nn.AdaptiveAvgPool2d(output_size=1) + if use_conv: + self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, bias=True) + else: + self.fc1 = nn.Linear(channels, mid_channels) + self.activ = mid_activation + if use_conv: + self.conv2 = nn.Conv2d(mid_channels, channels, kernel_size=1, bias=True) + else: + self.fc2 = nn.Linear(mid_channels, channels) + self.sigmoid = out_activation + + def forward(self, x): + w = self.pool(x) + if not self.use_conv: + w = w.view(x.size(0), -1) + w = self.conv1(w) if self.use_conv else self.fc1(w) + w = self.activ(w) + w = self.conv2(w) if self.use_conv else self.fc2(w) + w = self.sigmoid(w) + if not self.use_conv: + w = w.unsqueeze(2).unsqueeze(3) + x = x * w + return x \ No newline at end of file diff --git a/maskrcnn_benchmark/layers/set_loss.py b/maskrcnn_benchmark/layers/set_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8f504d813708e93b6c31efa61f98d0c8fa2cf0e9 --- /dev/null +++ 
b/maskrcnn_benchmark/layers/set_loss.py @@ -0,0 +1,371 @@ +import torch +import torch.nn.functional as F +import torch.distributed as dist +from torch import nn + +from scipy.optimize import linear_sum_assignment +from torch.cuda.amp import custom_fwd, custom_bwd + + +def box_area(boxes): + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + #assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + #assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss(inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1, gamma: float = 2, reduction: str = "none"): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. + Returns: + Loss tensor with the reduction option applied. 
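    Worked example (illustrative): for a single logit x = 2.0 with target 1,
    p = sigmoid(2.0) is about 0.881, the modulating factor (1 - p_t) ** gamma is
    about 0.0142 and the raw BCE term is about 0.127, so with alpha = 0.25 the
    per-element loss is roughly 0.25 * 0.0142 * 0.127, on the order of 4.5e-4;
    confident correct predictions are strongly down-weighted.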
+ """ + p = torch.sigmoid(inputs) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = p * targets + (1 - p) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if reduction == "mean": + loss = loss.mean() + elif reduction == "sum": + loss = loss.sum() + + return loss + + +sigmoid_focal_loss_jit = torch.jit.script( + sigmoid_focal_loss +) # type: torch.jit.ScriptModule + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, + use_focal: bool = False, focal_loss_alpha: float = 0.25, focal_loss_gamma: float = 2.0, + **kwargs): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + self.use_focal = use_focal + if self.use_focal: + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" + + @torch.no_grad() + @custom_fwd(cast_inputs=torch.float32) + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + if self.use_focal: + out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + else: + out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes_xyxy"] for v in targets]) + + # 
Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + if self.use_focal: + # Compute the classification cost. + alpha = self.focal_loss_alpha + gamma = self.focal_loss_gamma + neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] + else: + cost_class = -out_prob[:, tgt_ids] + + # Compute the L1 cost between boxes + image_size_out = torch.cat([v["image_size_xyxy"].unsqueeze(0) for v in targets]) + image_size_out = image_size_out.unsqueeze(1).repeat(1, num_queries, 1).flatten(0, 1) + image_size_tgt = torch.cat([v["image_size_xyxy_tgt"] for v in targets]) + + out_bbox_ = out_bbox / image_size_out + tgt_bbox_ = tgt_bbox / image_size_tgt + cost_bbox = torch.cdist(out_bbox_, tgt_bbox_, p=1) + + # Compute the giou cost betwen boxes + # cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + cost_giou = -generalized_box_iou(out_bbox, tgt_bbox) + + # Final cost matrix + C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + C = C.view(bs, num_queries, -1).cpu() + + C[torch.isnan(C)] = 0.0 + C[torch.isinf(C)] = 0.0 + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +class SetCriterion(nn.Module): + """ + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, + use_focal, focal_loss_alpha=0.25, focal_loss_gamma=2.0): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
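        Note: with the losses defined in this file, `losses` may contain 'labels'
        and/or 'boxes' (see get_loss). The matcher returns one (pred_indices,
        tgt_indices) pair per image; as an illustrative example, a Hungarian
        assignment over the 3 x 2 cost matrix [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]
        matches prediction 0 to target 1 and prediction 1 to target 0.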
+ """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + self.use_focal = use_focal + if self.use_focal: + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + else: + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer('empty_weight', empty_weight) + + def loss_labels(self, outputs, targets, indices, num_boxes, log=False): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + if self.use_focal: + src_logits = src_logits.flatten(0, 1) + # prepare one_hot target. + target_classes = target_classes.flatten(0, 1) + pos_inds = torch.nonzero(target_classes != self.num_classes, as_tuple=True)[0] + labels = torch.zeros_like(src_logits) + labels[pos_inds, target_classes[pos_inds]] = 1 + # comp focal loss. + class_loss = sigmoid_focal_loss_jit( + src_logits, + labels, + alpha=self.focal_loss_alpha, + gamma=self.focal_loss_gamma, + reduction="sum", + ) / num_boxes + losses = {'loss_ce': class_loss} + else: + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) + losses = {'loss_ce': loss_ce} + + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes_xyxy'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + losses = {} + loss_giou = 1 - torch.diag(generalized_box_iou(src_boxes, target_boxes)) + losses['loss_giou'] = loss_giou.sum() / num_boxes + + image_size = torch.cat([v["image_size_xyxy_tgt"] for v in targets]) + src_boxes_ = src_boxes / image_size + target_boxes_ = target_boxes / image_size + + loss_bbox = F.l1_loss(src_boxes_, target_boxes_, reduction='none') + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'boxes': self.loss_boxes, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' 
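        # **kwargs is forwarded unchanged so that loss_labels can receive the
        # `log` flag that forward() passes for auxiliary outputs.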
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, outputs, targets, *argrs, **kwargs): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if dist.is_available() and dist.is_initialized(): + torch.distributed.all_reduce(num_boxes) + word_size = dist.get_world_size() + else: + word_size = 1 + num_boxes = torch.clamp(num_boxes / word_size, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs = {'log': False} + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + diff --git a/maskrcnn_benchmark/layers/sigmoid_focal_loss.py b/maskrcnn_benchmark/layers/sigmoid_focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2de1bb4c3a003ec705f721e7f1db04baf2e8b268 --- /dev/null +++ b/maskrcnn_benchmark/layers/sigmoid_focal_loss.py @@ -0,0 +1,197 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from maskrcnn_benchmark import _C + + +# TODO: Use JIT to replace CUDA implementation in the future. 
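# Note on the target encoding used by the CPU/CUDA paths below (as read from
# sigmoid_focal_loss_cpu): `targets` holds one integer label per box, where a label
# k > 0 selects logit column k - 1, label 0 is pure background (only the negative
# term applies) and negative labels are ignored entirely. An illustrative call with
# hypothetical shapes:
#
#   logits = torch.randn(4, 80)             # 4 anchors, 80 classes
#   targets = torch.tensor([0, 3, 0, 17])   # 0 = background, otherwise 1-based class id
#   loss = SigmoidFocalLoss(gamma=2.0, alpha=0.25)(logits, targets)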
+class _SigmoidFocalLoss(Function): + @staticmethod + def forward(ctx, logits, targets, gamma, alpha): + ctx.save_for_backward(logits, targets) + num_classes = logits.shape[1] + ctx.num_classes = num_classes + ctx.gamma = gamma + ctx.alpha = alpha + + losses = _C.sigmoid_focalloss_forward( + logits, targets, num_classes, gamma, alpha + ) + return losses + + @staticmethod + @once_differentiable + def backward(ctx, d_loss): + logits, targets = ctx.saved_tensors + num_classes = ctx.num_classes + gamma = ctx.gamma + alpha = ctx.alpha + d_loss = d_loss.contiguous() + d_logits = _C.sigmoid_focalloss_backward( + logits, targets, d_loss, num_classes, gamma, alpha + ) + return d_logits, None, None, None, None + + +sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply + + +def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha): + num_classes = logits.shape[1] + dtype = targets.dtype + device = targets.device + class_range = torch.arange(1, num_classes + 1, dtype=dtype, device=device).unsqueeze(0) + + t = targets.unsqueeze(1) + p = torch.sigmoid(logits) + term1 = (1 - p) ** gamma * torch.log(p) + term2 = p ** gamma * torch.log(1 - p) + return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha) + + +class SigmoidFocalLoss(nn.Module): + def __init__(self, gamma, alpha): + super(SigmoidFocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + + def forward(self, logits, targets): + if logits.is_cuda: + loss_func = sigmoid_focal_loss_cuda + else: + loss_func = sigmoid_focal_loss_cpu + + loss = loss_func(logits, targets, self.gamma, self.alpha) + return loss.sum() + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "gamma=" + str(self.gamma) + tmpstr += ", alpha=" + str(self.alpha) + tmpstr += ")" + return tmpstr + + +def token_sigmoid_softmax_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None): + # Another modification is that because we use the cross entropy version, there is no frequent or not frequent class. + # So we temporarily retired the design of alpha. + + assert (targets.dim() == 3) + assert (pred_logits.dim() == 3) # batch x from x to + + # reprocess target to become probability map ready for softmax + targets = targets.float() + target_num = targets.sum(-1) + 1e-8 # numerical stability + targets = targets / target_num.unsqueeze(-1) # T(x) + + if text_mask is not None: + # reserve the last token for non object + assert (text_mask.dim() == 2) + text_mask[:, -1] = 1 + text_mask = (text_mask > 0).unsqueeze(1).repeat(1, pred_logits.size(1), 1) # copy along the image channel + pred_logits = pred_logits.masked_fill(~text_mask, -1000000) # softmax + + out_prob = pred_logits.softmax(-1) + + filled_targets = targets.clone() + filled_targets[filled_targets == 0] = 1.0 + + weight = torch.clamp(targets - out_prob, min=0.001) / filled_targets + weight = torch.pow(weight, gamma) # weight = torch.pow(torch.clamp(target - out_prob, min=0.01), gamma) + + loss_ce = - targets * weight * pred_logits.log_softmax( + -1) # only those positives with positive target_sim will have losses. 
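    # loss_ce keeps the full batch x from x to shape here; TokenSigmoidFocalLoss.forward
    # below reduces it to a scalar with .sum().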
+ return loss_ce + + +def token_sigmoid_binary_focal_loss_v2(pred_logits, targets, alpha, gamma, text_mask=None): + assert (targets.dim() == 3) + assert (pred_logits.dim() == 3) # batch x from x to + + if text_mask is not None: + assert (text_mask.dim() == 2) + + # We convert everything into binary + out_prob = pred_logits.sigmoid() + out_prob_neg_pos = torch.stack([1 - out_prob, out_prob], dim=-1) + 1e-8 # batch x boxes x 256 x 2 + weight = torch.pow(-out_prob_neg_pos + 1.0, gamma) + + focal_zero = - weight[:, :, :, 0] * torch.log(out_prob_neg_pos[:, :, :, 0]) * ( + 1 - alpha) # negative class + focal_one = - weight[:, :, :, 1] * torch.log(out_prob_neg_pos[:, :, :, 1]) * alpha # positive class + focal = torch.stack([focal_zero, focal_one], dim=-1) + loss_ce = torch.gather(focal, index=targets.long().unsqueeze(-1), dim=-1) + return loss_ce + + +def token_sigmoid_binary_focal_loss(pred_logits, targets, alpha, gamma, text_mask=None): + # binary version of focal loss + # copied from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor with the reduction option applied. + """ + assert (targets.dim() == 3) + assert (pred_logits.dim() == 3) # batch x from x to + + bs, n, _ = pred_logits.shape + if text_mask is not None: + assert (text_mask.dim() == 2) + text_mask = (text_mask > 0).unsqueeze(1) + text_mask = text_mask.repeat(1, pred_logits.size(1), 1) # copy along the image channel dimension + pred_logits = torch.masked_select(pred_logits, text_mask) + targets = torch.masked_select(targets, text_mask) + + # print(pred_logits.shape) + # print(targets.shape) + + p = torch.sigmoid(pred_logits) + ce_loss = F.binary_cross_entropy_with_logits(pred_logits, targets, reduction="none") + p_t = p * targets + (1 - p) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss + + +class TokenSigmoidFocalLoss(nn.Module): + def __init__(self, alpha, gamma): + super(TokenSigmoidFocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + + def forward(self, logits, targets, text_masks=None, version="binary", **kwargs): + if version == "binary": + loss_func = token_sigmoid_binary_focal_loss + elif version == "softmax": + loss_func = token_sigmoid_softmax_focal_loss + elif version == "binaryv2": + loss_func = token_sigmoid_binary_focal_loss_v2 + else: + raise NotImplementedError + loss = loss_func(logits, targets, self.alpha, self.gamma, text_masks, **kwargs) + return loss.sum() + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "gamma=" + str(self.gamma) + tmpstr += ", alpha=" + str(self.alpha) + tmpstr += ")" + return tmpstr diff --git a/maskrcnn_benchmark/layers/smooth_l1_loss.py b/maskrcnn_benchmark/layers/smooth_l1_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f2866f6c15f4f301f18181179b5fb835d4d6b7e8 --- /dev/null +++ 
b/maskrcnn_benchmark/layers/smooth_l1_loss.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + + +# TODO maybe push this to nn? +def smooth_l1_loss(input, target, beta=1. / 9, size_average=True): + """ + very similar to the smooth_l1_loss from pytorch, but with + the extra beta parameter + """ + n = torch.abs(input - target) + cond = n < beta + loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta) + if size_average: + return loss.mean() + return loss.sum() diff --git a/maskrcnn_benchmark/modeling/.DS_Store b/maskrcnn_benchmark/modeling/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..57ad02856e1a722be1c2932bec7fda4b93bc20b9 Binary files /dev/null and b/maskrcnn_benchmark/modeling/.DS_Store differ diff --git a/maskrcnn_benchmark/modeling/__init__.py b/maskrcnn_benchmark/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/modeling/backbone/__init__.py b/maskrcnn_benchmark/modeling/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a8583983b31e2a7084858ae3d8bb1bb881978e0f --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/__init__.py @@ -0,0 +1,239 @@ +from collections import OrderedDict + +from torch import nn + +from maskrcnn_benchmark.modeling import registry +from maskrcnn_benchmark.modeling.make_layers import conv_with_kaiming_uniform +from maskrcnn_benchmark.layers import DropBlock2D, DyHead +from . import fpn as fpn_module +from . import bifpn +from . import resnet +from . import efficientnet +from . import efficientdet +from . import swint +from . import swint_v2 +from . import swint_vl +from . import swint_v2_vl + + +@registry.BACKBONES.register("R-50-C4") +@registry.BACKBONES.register("R-50-C5") +@registry.BACKBONES.register("R-101-C4") +@registry.BACKBONES.register("R-101-C5") +def build_resnet_backbone(cfg): + body = resnet.ResNet(cfg) + model = nn.Sequential(OrderedDict([("body", body)])) + return model + + +@registry.BACKBONES.register("R-50-RETINANET") +@registry.BACKBONES.register("R-101-RETINANET") +def build_resnet_c5_backbone(cfg): + body = resnet.ResNet(cfg) + model = nn.Sequential(OrderedDict([("body", body)])) + return model + + +@registry.BACKBONES.register("SWINT-FPN-RETINANET") +def build_retinanet_swint_fpn_backbone(cfg): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
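    The returned backbone is an nn.Sequential over an OrderedDict with stages
    "body" and "fpn", plus a trailing "dyhead" stage when cfg.MODEL.FPN.USE_DYHEAD
    is enabled (see the function body).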
+ """ + if cfg.MODEL.SWINT.VERSION == "v1": + body = swint.build_swint_backbone(cfg) + elif cfg.MODEL.SWINT.VERSION == "v2": + body = swint_v2.build_swint_backbone(cfg) + elif cfg.MODEL.SWINT.VERSION == "vl": + body = swint_vl.build_swint_backbone(cfg) + elif cfg.MODEL.SWINT.VERSION == "v2_vl": + body = swint_v2_vl.build_swint_backbone(cfg) + + in_channels_stages = cfg.MODEL.SWINT.OUT_CHANNELS + out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + in_channels_p6p7 = out_channels + fpn = fpn_module.FPN( + in_channels_list=[ + 0, + in_channels_stages[-3], + in_channels_stages[-2], + in_channels_stages[-1], + ], + out_channels=out_channels, + conv_block=conv_with_kaiming_uniform( + cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU + ), + top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels), + drop_block=DropBlock2D(cfg.MODEL.FPN.DROP_PROB, cfg.MODEL.FPN.DROP_SIZE) if cfg.MODEL.FPN.DROP_BLOCK else None, + use_spp=cfg.MODEL.FPN.USE_SPP, + use_pan=cfg.MODEL.FPN.USE_PAN, + return_swint_feature_before_fusion=cfg.MODEL.FPN.RETURN_SWINT_FEATURE_BEFORE_FUSION + ) + if cfg.MODEL.FPN.USE_DYHEAD: + dyhead = DyHead(cfg, out_channels) + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn), ("dyhead", dyhead)])) + else: + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) + return model + + +@registry.BACKBONES.register("SWINT-FPN") +def build_swint_fpn_backbone(cfg): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + if cfg.MODEL.SWINT.VERSION == "v1": + body = swint.build_swint_backbone(cfg) + elif cfg.MODEL.SWINT.VERSION == "v2": + body = swint_v2.build_swint_backbone(cfg) + elif cfg.MODEL.SWINT.VERSION == "vl": + body = swint_vl.build_swint_backbone(cfg) + elif cfg.MODEL.SWINT.VERSION == "v2_vl": + body = swint_v2_vl.build_swint_backbone(cfg) + + in_channels_stages = cfg.MODEL.SWINT.OUT_CHANNELS + out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + fpn = fpn_module.FPN( + in_channels_list=[ + in_channels_stages[-4], + in_channels_stages[-3], + in_channels_stages[-2], + in_channels_stages[-1], + ], + out_channels=out_channels, + conv_block=conv_with_kaiming_uniform( + cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU + ), + top_blocks=fpn_module.LastLevelMaxPool(), + drop_block=DropBlock2D(cfg.MODEL.FPN.DROP_PROB, cfg.MODEL.FPN.DROP_SIZE) if cfg.MODEL.FPN.DROP_BLOCK else None, + use_spp=cfg.MODEL.FPN.USE_SPP, + use_pan=cfg.MODEL.FPN.USE_PAN + ) + if cfg.MODEL.FPN.USE_DYHEAD: + dyhead = DyHead(cfg, out_channels) + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn), ("dyhead", dyhead)])) + else: + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) + return model + + +@registry.BACKBONES.register("CVT-FPN-RETINANET") +def build_retinanet_cvt_fpn_backbone(cfg): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
+ """ + body = cvt.build_cvt_backbone(cfg) + in_channels_stages = cfg.MODEL.SPEC.DIM_EMBED + out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + in_channels_p6p7 = out_channels + fpn = fpn_module.FPN( + in_channels_list=[ + 0, + in_channels_stages[-3], + in_channels_stages[-2], + in_channels_stages[-1], + ], + out_channels=out_channels, + conv_block=conv_with_kaiming_uniform( + cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU + ), + top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels), + drop_block=DropBlock2D(cfg.MODEL.FPN.DROP_PROB, cfg.MODEL.FPN.DROP_SIZE) if cfg.MODEL.FPN.DROP_BLOCK else None, + use_spp=cfg.MODEL.FPN.USE_SPP, + use_pan=cfg.MODEL.FPN.USE_PAN + ) + if cfg.MODEL.FPN.USE_DYHEAD: + dyhead = DyHead(cfg, out_channels) + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn), ("dyhead", dyhead)])) + else: + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) + return model + + +@registry.BACKBONES.register("EFFICIENT7-FPN-RETINANET") +@registry.BACKBONES.register("EFFICIENT7-FPN-FCOS") +@registry.BACKBONES.register("EFFICIENT5-FPN-RETINANET") +@registry.BACKBONES.register("EFFICIENT5-FPN-FCOS") +@registry.BACKBONES.register("EFFICIENT3-FPN-RETINANET") +@registry.BACKBONES.register("EFFICIENT3-FPN-FCOS") +def build_eff_fpn_p6p7_backbone(cfg): + version = cfg.MODEL.BACKBONE.CONV_BODY.split('-')[0] + version = version.replace('EFFICIENT', 'b') + body = efficientnet.get_efficientnet(cfg, version) + in_channels_stage = body.out_channels + out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + in_channels_p6p7 = out_channels + in_channels_stage[0] = 0 + fpn = fpn_module.FPN( + in_channels_list=in_channels_stage, + out_channels=out_channels, + conv_block=conv_with_kaiming_uniform( + cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU + ), + top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels), + drop_block=DropBlock2D(cfg.MODEL.FPN.DROP_PROB, cfg.MODEL.FPN.DROP_SIZE) if cfg.MODEL.FPN.DROP_BLOCK else None, + use_spp=cfg.MODEL.FPN.USE_SPP, + use_pan=cfg.MODEL.FPN.USE_PAN + ) + model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) + return model + + +@registry.BACKBONES.register("EFFICIENT7-BIFPN-RETINANET") +@registry.BACKBONES.register("EFFICIENT7-BIFPN-FCOS") +@registry.BACKBONES.register("EFFICIENT5-BIFPN-RETINANET") +@registry.BACKBONES.register("EFFICIENT5-BIFPN-FCOS") +@registry.BACKBONES.register("EFFICIENT3-BIFPN-RETINANET") +@registry.BACKBONES.register("EFFICIENT3-BIFPN-FCOS") +def build_eff_fpn_p6p7_backbone(cfg): + version = cfg.MODEL.BACKBONE.CONV_BODY.split('-')[0] + version = version.replace('EFFICIENT', 'b') + body = efficientnet.get_efficientnet(cfg, version) + in_channels_stage = body.out_channels + out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + bifpns = nn.ModuleList() + for i in range(cfg.MODEL.BIFPN.NUM_REPEATS): + first_time = (i==0) + fpn = bifpn.BiFPN( + in_channels_list=in_channels_stage[1:], + out_channels=out_channels, + first_time=first_time, + attention=cfg.MODEL.BIFPN.USE_ATTENTION + ) + bifpns.append(fpn) + model = nn.Sequential(OrderedDict([("body", body), ("bifpn", bifpns)])) + return model + + +@registry.BACKBONES.register("EFFICIENT-DET") +def build_efficientdet_backbone(cfg): + efficientdet.g_simple_padding = True + compound = cfg.MODEL.BACKBONE.EFFICIENT_DET_COMPOUND + start_from = cfg.MODEL.BACKBONE.EFFICIENT_DET_START_FROM + model = efficientdet.EffNetFPN( + compound_coef=compound, + start_from=start_from, + ) + if cfg.MODEL.BACKBONE.USE_SYNCBN: + import torch + model = 
torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + return model + + +def build_backbone(cfg): + assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \ + "cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format( + cfg.MODEL.BACKBONE.CONV_BODY + ) + return registry.BACKBONES[cfg.MODEL.BACKBONE.CONV_BODY](cfg) diff --git a/maskrcnn_benchmark/modeling/backbone/bifpn.py b/maskrcnn_benchmark/modeling/backbone/bifpn.py new file mode 100644 index 0000000000000000000000000000000000000000..8689c1c32e61f7984559eb78e7f3e7828b3c2abc --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/bifpn.py @@ -0,0 +1,273 @@ +import torch.nn as nn +import torch + +from maskrcnn_benchmark.layers import swish + + +class BiFPN(nn.Module): + def __init__(self, in_channels_list, out_channels, first_time=False, epsilon=1e-4, attention=True): + super(BiFPN, self).__init__() + self.epsilon = epsilon + # Conv layers + self.conv6_up = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False), + nn.Conv2d(out_channels, out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.conv5_up = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False), + nn.Conv2d(out_channels, out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.conv4_up = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False), + nn.Conv2d(out_channels, out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.conv3_up = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False), + nn.Conv2d(out_channels, out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.conv4_down = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False), + nn.Conv2d(out_channels, out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.conv5_down = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False), + nn.Conv2d(out_channels, out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.conv6_down = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False), + nn.Conv2d(out_channels, out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.conv7_down = nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, groups=out_channels, bias=False), + nn.Conv2d(out_channels, out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + + # Feature scaling layers + self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest') + + self.p4_downsample = nn.MaxPool2d(3, 2) + self.p5_downsample = nn.MaxPool2d(3, 2) + self.p6_downsample = nn.MaxPool2d(3, 2) + self.p7_downsample = nn.MaxPool2d(3, 2) + + self.swish = swish() + + self.first_time = first_time + if self.first_time: + self.p5_down_channel = nn.Sequential( + nn.Conv2d(in_channels_list[2], out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.p4_down_channel = nn.Sequential( + nn.Conv2d(in_channels_list[1], out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.p3_down_channel = nn.Sequential( + 
nn.Conv2d(in_channels_list[0], out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + + self.p5_to_p6 = nn.Sequential( + nn.Conv2d(in_channels_list[2], out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + nn.MaxPool2d(3, 2) + ) + self.p6_to_p7 = nn.Sequential( + nn.MaxPool2d(3, 2) + ) + + self.p4_down_channel_2 = nn.Sequential( + nn.Conv2d(in_channels_list[1], out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + self.p5_down_channel_2 = nn.Sequential( + nn.Conv2d(in_channels_list[2], out_channels, 1), + nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3), + ) + + # Weight + self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p6_w1_relu = nn.ReLU() + self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p5_w1_relu = nn.ReLU() + self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p4_w1_relu = nn.ReLU() + self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p3_w1_relu = nn.ReLU() + + self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p4_w2_relu = nn.ReLU() + self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p5_w2_relu = nn.ReLU() + self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p6_w2_relu = nn.ReLU() + self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p7_w2_relu = nn.ReLU() + + self.attention = attention + + def forward(self, inputs): + """ + illustration of a minimal bifpn unit + P7_0 -------------------------> P7_2 --------> + |-------------| ↑ + ↓ | + P6_0 ---------> P6_1 ---------> P6_2 --------> + |-------------|--------------↑ ↑ + ↓ | + P5_0 ---------> P5_1 ---------> P5_2 --------> + |-------------|--------------↑ ↑ + ↓ | + P4_0 ---------> P4_1 ---------> P4_2 --------> + |-------------|--------------↑ ↑ + |--------------↓ | + P3_0 -------------------------> P3_2 --------> + """ + + # downsample channels using same-padding conv2d to target phase's if not the same + # judge: same phase as target, + # if same, pass; + # elif earlier phase, downsample to target phase's by pooling + # elif later phase, upsample to target phase's by nearest interpolation + + if self.attention: + p3_out, p4_out, p5_out, p6_out, p7_out = self._forward_fast_attention(inputs) + else: + p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs) + + return p3_out, p4_out, p5_out, p6_out, p7_out + + def _forward_fast_attention(self, inputs): + if self.first_time: + p3, p4, p5 = inputs[-3:] + + p6_in = self.p5_to_p6(p5) + p7_in = self.p6_to_p7(p6_in) + + p3_in = self.p3_down_channel(p3) + p4_in = self.p4_down_channel(p4) + p5_in = self.p5_down_channel(p5) + + else: + # P3_0, P4_0, P5_0, P6_0 and P7_0 + p3_in, p4_in, p5_in, p6_in, p7_in = inputs + + # P7_0 to P7_2 + + # Weights for P6_0 and P7_0 to P6_1 + p6_w1 = self.p6_w1_relu(self.p6_w1) + weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon) + # Connections for P6_0 and P7_0 to P6_1 respectively + p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in))) + + # Weights for P5_0 and P6_1 to P5_1 + p5_w1 = self.p5_w1_relu(self.p5_w1) + weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon) + # Connections for P5_0 and P6_1 to P5_1 respectively + p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up))) + + # 
Weights for P4_0 and P5_1 to P4_1 + p4_w1 = self.p4_w1_relu(self.p4_w1) + weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon) + # Connections for P4_0 and P5_1 to P4_1 respectively + p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up))) + + # Weights for P3_0 and P4_1 to P3_2 + p3_w1 = self.p3_w1_relu(self.p3_w1) + weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon) + # Connections for P3_0 and P4_1 to P3_2 respectively + p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up))) + + if self.first_time: + p4_in = self.p4_down_channel_2(p4) + p5_in = self.p5_down_channel_2(p5) + + # Weights for P4_0, P4_1 and P3_2 to P4_2 + p4_w2 = self.p4_w2_relu(self.p4_w2) + weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon) + # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively + p4_out = self.conv4_down( + self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out))) + + # Weights for P5_0, P5_1 and P4_2 to P5_2 + p5_w2 = self.p5_w2_relu(self.p5_w2) + weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon) + # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively + p5_out = self.conv5_down( + self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out))) + + # Weights for P6_0, P6_1 and P5_2 to P6_2 + p6_w2 = self.p6_w2_relu(self.p6_w2) + weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon) + # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively + p6_out = self.conv6_down( + self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out))) + + # Weights for P7_0 and P6_2 to P7_2 + p7_w2 = self.p7_w2_relu(self.p7_w2) + weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon) + # Connections for P7_0 and P6_2 to P7_2 + p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out))) + + return p3_out, p4_out, p5_out, p6_out, p7_out + + def _forward(self, inputs): + if self.first_time: + p3, p4, p5 = inputs + + p6_in = self.p5_to_p6(p5) + p7_in = self.p6_to_p7(p6_in) + + p3_in = self.p3_down_channel(p3) + p4_in = self.p4_down_channel(p4) + p5_in = self.p5_down_channel(p5) + + else: + # P3_0, P4_0, P5_0, P6_0 and P7_0 + p3_in, p4_in, p5_in, p6_in, p7_in = inputs + + # P7_0 to P7_2 + + # Connections for P6_0 and P7_0 to P6_1 respectively + p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in))) + + # Connections for P5_0 and P6_1 to P5_1 respectively + p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up))) + + # Connections for P4_0 and P5_1 to P4_1 respectively + p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up))) + + # Connections for P3_0 and P4_1 to P3_2 respectively + p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up))) + + if self.first_time: + p4_in = self.p4_down_channel_2(p4) + p5_in = self.p5_down_channel_2(p5) + + # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively + p4_out = self.conv4_down( + self.swish(p4_in + p4_up + self.p4_downsample(p3_out))) + + # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively + p5_out = self.conv5_down( + self.swish(p5_in + p5_up + self.p5_downsample(p4_out))) + + # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively + p6_out = self.conv6_down( + self.swish(p6_in + p6_up + self.p6_downsample(p5_out))) + + # Connections for P7_0 and P6_2 to P7_2 + p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out))) + + return p3_out, p4_out, p5_out, p6_out, p7_out \ No 
newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/blocks.py b/maskrcnn_benchmark/modeling/backbone/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..eab3b74a2e129abe07fb5d30776db77c46a648dd --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/blocks.py @@ -0,0 +1,266 @@ +import torch.nn as nn +from .ops import * + + +class stem(nn.Module): + num_layer = 1 + + def __init__(self, conv, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d): + super(stem, self).__init__() + + self.conv1 = conv(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + return out + + +class basic(nn.Module): + expansion = 1 + num_layer = 2 + + def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d): + super(basic, self).__init__() + midplanes = planes if midplanes is None else midplanes + self.conv1 = conv(inplanes, midplanes, stride) + self.bn1 = norm_layer(midplanes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv(midplanes, planes) + self.bn2 = norm_layer(planes) + if stride!=1 or inplanes!=planes*self.expansion: + self.downsample = nn.Sequential( + conv1x1(inplanes, planes, stride), + norm_layer(planes), + ) + else: + self.downsample = None + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class bottleneck(nn.Module): + expansion = 4 + num_layer = 3 + + def __init__(self, conv, inplanes, planes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d): + super(bottleneck, self).__init__() + midplanes = planes if midplanes is None else midplanes + self.conv1 = conv1x1(inplanes, midplanes) + self.bn1 = norm_layer(midplanes) + self.conv2 = conv(midplanes, midplanes, stride) + self.bn2 = norm_layer(midplanes) + self.conv3 = conv1x1(midplanes, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + if stride!=1 or inplanes!=planes*self.expansion: + self.downsample = nn.Sequential( + conv1x1(inplanes, planes*self.expansion, stride), + norm_layer(planes*self.expansion), + ) + else: + self.downsample = None + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class invert(nn.Module): + def __init__(self, conv, inp, oup, stride=1, expand_ratio=1, norm_layer=nn.BatchNorm2d): + super(invert, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expand_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + + if expand_ratio == 1: + self.conv = nn.Sequential( + # dw + conv(hidden_dim, hidden_dim, stride), + norm_layer(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + norm_layer(hidden_dim), + nn.ReLU6(inplace=True), + # dw + conv(hidden_dim, hidden_dim, stride), + 
norm_layer(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +invert2 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=2, **kwargs) +invert3 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=3, **kwargs) +invert4 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=4, **kwargs) +invert6 = lambda op, inp, outp, stride, **kwargs: invert(op, inp, outp, stride, expand_ratio=6, **kwargs) + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + # flatten + x = x.view(batchsize, -1, height, width) + return x + + +class shuffle(nn.Module): + expansion = 1 + num_layer = 3 + + def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d): + super(shuffle, self).__init__() + inplanes = inplanes // 2 if stride == 1 else inplanes + midplanes = outplanes // 2 if midplanes is None else midplanes + rightoutplanes = outplanes - inplanes + if stride == 2: + self.left_branch = nn.Sequential( + # dw + conv(inplanes, inplanes, stride), + norm_layer(inplanes), + # pw-linear + conv1x1(inplanes, inplanes), + norm_layer(inplanes), + nn.ReLU(inplace=True), + ) + + self.right_branch = nn.Sequential( + # pw + conv1x1(inplanes, midplanes), + norm_layer(midplanes), + nn.ReLU(inplace=True), + # dw + conv(midplanes, midplanes, stride), + norm_layer(midplanes), + # pw-linear + conv1x1(midplanes, rightoutplanes), + norm_layer(rightoutplanes), + nn.ReLU(inplace=True), + ) + + self.reduce = stride==2 + + def forward(self, x): + if self.reduce: + out = torch.cat((self.left_branch(x), self.right_branch(x)), 1) + else: + x1 = x[:, :(x.shape[1]//2), :, :] + x2 = x[:, (x.shape[1]//2):, :, :] + out = torch.cat((x1, self.right_branch(x2)), 1) + + return channel_shuffle(out, 2) + + +class shufflex(nn.Module): + expansion = 1 + num_layer = 3 + + def __init__(self, conv, inplanes, outplanes, stride=1, midplanes=None, norm_layer=nn.BatchNorm2d): + super(shufflex, self).__init__() + inplanes = inplanes // 2 if stride == 1 else inplanes + midplanes = outplanes // 2 if midplanes is None else midplanes + rightoutplanes = outplanes - inplanes + if stride==2: + self.left_branch = nn.Sequential( + # dw + conv(inplanes, inplanes, stride), + norm_layer(inplanes), + # pw-linear + conv1x1(inplanes, inplanes), + norm_layer(inplanes), + nn.ReLU(inplace=True), + ) + + self.right_branch = nn.Sequential( + # dw + conv(inplanes, inplanes, stride), + norm_layer(inplanes), + # pw-linear + conv1x1(inplanes, midplanes), + norm_layer(midplanes), + nn.ReLU(inplace=True), + # dw + conv(midplanes, midplanes, 1), + norm_layer(midplanes), + # pw-linear + conv1x1(midplanes, midplanes), + norm_layer(midplanes), + nn.ReLU(inplace=True), + # dw + conv(midplanes, midplanes, 1), + norm_layer(midplanes), + # pw-linear + conv1x1(midplanes, rightoutplanes), + norm_layer(rightoutplanes), + nn.ReLU(inplace=True), + ) + + self.reduce = stride==2 + + def forward(self, x): + if self.reduce: + out = torch.cat((self.left_branch(x), self.right_branch(x)), 1) + else: + x1 = x[:, :(x.shape[1] // 2), :, :] + x2 = x[:, (x.shape[1] // 2):, :, :] + out = 
torch.cat((x1, self.right_branch(x2)), 1) + + return channel_shuffle(out, 2) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/efficientdet.py b/maskrcnn_benchmark/modeling/backbone/efficientdet.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5666815cd2e94c954929bd38786e89b4c19d89 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/efficientdet.py @@ -0,0 +1,1882 @@ +import torch +import re +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import logging +import cv2 +import math +import itertools +import collections +from torchvision.ops import nms + + +GlobalParams = collections.namedtuple('GlobalParams', [ + 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', + 'num_classes', 'width_coefficient', 'depth_coefficient', + 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple('BlockArgs', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) + +# https://stackoverflow.com/a/18348004 +# Change namedtuple defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + +# in the old version, g_simple_padding = False, which tries to align +# tensorflow's implementation, which is not required here. +g_simple_padding = True +class MaxPool2dStaticSamePadding(nn.Module): + """ + created by Zylo117 + The real keras/tensorflow MaxPool2d with same padding + """ + + def __init__(self, kernel_size, stride): + super().__init__() + if g_simple_padding: + self.pool = nn.MaxPool2d(kernel_size, stride, + padding=(kernel_size-1)//2) + else: + assert ValueError() + self.pool = nn.MaxPool2d(kernel_size, stride) + self.stride = self.pool.stride + self.kernel_size = self.pool.kernel_size + + if isinstance(self.stride, int): + self.stride = [self.stride] * 2 + elif len(self.stride) == 1: + self.stride = [self.stride[0]] * 2 + + if isinstance(self.kernel_size, int): + self.kernel_size = [self.kernel_size] * 2 + elif len(self.kernel_size) == 1: + self.kernel_size = [self.kernel_size[0]] * 2 + + def forward(self, x): + if g_simple_padding: + return self.pool(x) + else: + assert ValueError() + h, w = x.shape[-2:] + + h_step = math.ceil(w / self.stride[1]) + v_step = math.ceil(h / self.stride[0]) + h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1) + v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1) + + extra_h = h_cover_len - w + extra_v = v_cover_len - h + + left = extra_h // 2 + right = extra_h - left + top = extra_v // 2 + bottom = extra_v - top + + x = F.pad(x, [left, right, top, bottom]) + + x = self.pool(x) + return x + +class Conv2dStaticSamePadding(nn.Module): + """ + created by Zylo117 + The real keras/tensorflow conv2d with same padding + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs): + super().__init__() + if g_simple_padding: + assert kernel_size % 2 == 1 + assert dilation == 1 + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, + bias=bias, + groups=groups, + padding=(kernel_size - 1) // 2) + self.stride = self.conv.stride + if isinstance(self.stride, int): + self.stride = [self.stride] * 2 + elif len(self.stride) == 1: + self.stride = [self.stride[0]] * 2 + else: + self.stride = list(self.stride) + else: + assert 
ValueError() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, + bias=bias, groups=groups) + self.stride = self.conv.stride + self.kernel_size = self.conv.kernel_size + self.dilation = self.conv.dilation + + if isinstance(self.stride, int): + self.stride = [self.stride] * 2 + elif len(self.stride) == 1: + self.stride = [self.stride[0]] * 2 + + if isinstance(self.kernel_size, int): + self.kernel_size = [self.kernel_size] * 2 + elif len(self.kernel_size) == 1: + self.kernel_size = [self.kernel_size[0]] * 2 + + def forward(self, x): + if g_simple_padding: + return self.conv(x) + else: + assert ValueError() + h, w = x.shape[-2:] + + h_step = math.ceil(w / self.stride[1]) + v_step = math.ceil(h / self.stride[0]) + h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1) + v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1) + + extra_h = h_cover_len - w + extra_v = v_cover_len - h + + left = extra_h // 2 + right = extra_h - left + top = extra_v // 2 + bottom = extra_v - top + + x = F.pad(x, [left, right, top, bottom]) + + x = self.conv(x) + return x + +class SeparableConvBlock(nn.Module): + """ + created by Zylo117 + """ + + def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False): + super(SeparableConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + + # Q: whether separate conv + # share bias between depthwise_conv and pointwise_conv + # or just pointwise_conv apply bias. + # A: Confirmed, just pointwise_conv applies bias, depthwise_conv has no bias. + + self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels, + kernel_size=3, stride=1, groups=in_channels, bias=False) + self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1) + + self.norm = norm + if self.norm: + # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow + self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3) + + self.activation = activation + if self.activation: + self.swish = MemoryEfficientSwish() if not onnx_export else Swish() + + def forward(self, x): + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + + if self.norm: + x = self.bn(x) + + if self.activation: + x = self.swish(x) + + return x + + +class BiFPN(nn.Module): + """ + modified by Zylo117 + """ + + def __init__(self, num_channels, conv_channels, first_time=False, + epsilon=1e-4, onnx_export=False, attention=True, + adaptive_up=False): + """ + + Args: + num_channels: + conv_channels: + first_time: whether the input comes directly from the efficientnet, + if True, downchannel it first, and downsample P5 to generate P6 then P7 + epsilon: epsilon of fast weighted attention sum of BiFPN, not the BN's epsilon + onnx_export: if True, use Swish instead of MemoryEfficientSwish + """ + super(BiFPN, self).__init__() + self.epsilon = epsilon + # Conv layers + self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) + self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) + self.conv4_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) + self.conv3_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) + self.conv4_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) + self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) + self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) + self.conv7_down 
= SeparableConvBlock(num_channels, onnx_export=onnx_export) + + # Feature scaling layers + self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest') + self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest') + + self.adaptive_up = adaptive_up + + self.p4_downsample = MaxPool2dStaticSamePadding(3, 2) + self.p5_downsample = MaxPool2dStaticSamePadding(3, 2) + self.p6_downsample = MaxPool2dStaticSamePadding(3, 2) + self.p7_downsample = MaxPool2dStaticSamePadding(3, 2) + + self.swish = MemoryEfficientSwish() if not onnx_export else Swish() + + self.first_time = first_time + if self.first_time: + self.p5_down_channel = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), + nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), + ) + self.p4_down_channel = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), + nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), + ) + self.p3_down_channel = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[0], num_channels, 1), + nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), + ) + + if len(conv_channels) == 3: + self.p5_to_p6 = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), + nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), + MaxPool2dStaticSamePadding(3, 2) + ) + else: + assert len(conv_channels) == 4 + self.p6_down_channel = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[3], num_channels, 1), + nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), + ) + + self.p6_to_p7 = nn.Sequential( + MaxPool2dStaticSamePadding(3, 2) + ) + + self.p4_down_channel_2 = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), + nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), + ) + self.p5_down_channel_2 = nn.Sequential( + Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), + nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), + ) + + # Weight + self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p6_w1_relu = nn.ReLU() + self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p5_w1_relu = nn.ReLU() + self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p4_w1_relu = nn.ReLU() + self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p3_w1_relu = nn.ReLU() + + self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p4_w2_relu = nn.ReLU() + self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p5_w2_relu = nn.ReLU() + self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) + self.p6_w2_relu = nn.ReLU() + self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) + self.p7_w2_relu = nn.ReLU() + + self.attention = attention + + def forward(self, inputs): + """ + illustration of a minimal bifpn unit + P7_0 -------------------------> P7_2 --------> + |-------------| ↑ + ↓ | + P6_0 ---------> P6_1 ---------> P6_2 --------> + |-------------|--------------↑ ↑ + ↓ | + P5_0 ---------> P5_1 ---------> P5_2 --------> + |-------------|--------------↑ ↑ + ↓ | + P4_0 ---------> P4_1 ---------> P4_2 --------> + |-------------|--------------↑ ↑ + |--------------↓ | + P3_0 -------------------------> P3_2 --------> + """ + + # downsample channels using same-padding 
conv2d to target phase's if not the same + # judge: same phase as target, + # if same, pass; + # elif earlier phase, downsample to target phase's by pooling + # elif later phase, upsample to target phase's by nearest interpolation + if self.attention: + p3_out, p4_out, p5_out, p6_out, p7_out = self._forward_fast_attention(inputs) + else: + p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs) + + return p3_out, p4_out, p5_out, p6_out, p7_out + + def _forward_fast_attention(self, inputs): + if self.first_time: + if len(inputs) == 3: + p3, p4, p5 = inputs + p6_in = self.p5_to_p6(p5) + else: + p3, p4, p5, p6 = inputs + p6_in = self.p6_down_channel(p6) + + p7_in = self.p6_to_p7(p6_in) + + p3_in = self.p3_down_channel(p3) + p4_in = self.p4_down_channel(p4) + p5_in = self.p5_down_channel(p5) + else: + # P3_0, P4_0, P5_0, P6_0 and P7_0 + p3_in, p4_in, p5_in, p6_in, p7_in = inputs + + # P7_0 to P7_2 + + if not self.adaptive_up: + # Weights for P6_0 and P7_0 to P6_1 + p6_w1 = self.p6_w1_relu(self.p6_w1) + weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon) + # Connections for P6_0 and P7_0 to P6_1 respectively + p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in))) + + # Weights for P5_0 and P6_0 to P5_1 + p5_w1 = self.p5_w1_relu(self.p5_w1) + weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon) + # Connections for P5_0 and P6_0 to P5_1 respectively + p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up))) + + # Weights for P4_0 and P5_0 to P4_1 + p4_w1 = self.p4_w1_relu(self.p4_w1) + weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon) + # Connections for P4_0 and P5_0 to P4_1 respectively + p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up))) + + # Weights for P3_0 and P4_1 to P3_2 + p3_w1 = self.p3_w1_relu(self.p3_w1) + weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon) + # Connections for P3_0 and P4_1 to P3_2 respectively + p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up))) + else: + # Weights for P6_0 and P7_0 to P6_1 + p6_w1 = self.p6_w1_relu(self.p6_w1) + weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon) + # Connections for P6_0 and P7_0 to P6_1 respectively + p6_upsample = nn.Upsample(size=p6_in.shape[-2:]) + p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * p6_upsample(p7_in))) + + # Weights for P5_0 and P6_0 to P5_1 + p5_w1 = self.p5_w1_relu(self.p5_w1) + weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon) + # Connections for P5_0 and P6_0 to P5_1 respectively + p5_upsample = nn.Upsample(size=p5_in.shape[-2:]) + p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * p5_upsample(p6_up))) + + # Weights for P4_0 and P5_0 to P4_1 + p4_w1 = self.p4_w1_relu(self.p4_w1) + weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon) + # Connections for P4_0 and P5_0 to P4_1 respectively + p4_upsample = nn.Upsample(size=p4_in.shape[-2:]) + p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * p4_upsample(p5_up))) + + # Weights for P3_0 and P4_1 to P3_2 + p3_w1 = self.p3_w1_relu(self.p3_w1) + weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon) + p3_upsample = nn.Upsample(size=p3_in.shape[-2:]) + # Connections for P3_0 and P4_1 to P3_2 respectively + p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * p3_upsample(p4_up))) + + if self.first_time: + p4_in = self.p4_down_channel_2(p4) + p5_in = self.p5_down_channel_2(p5) + + # Weights for P4_0, P4_1 and 
P3_2 to P4_2 + p4_w2 = self.p4_w2_relu(self.p4_w2) + weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon) + # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively + p4_out = self.conv4_down( + self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out))) + + # Weights for P5_0, P5_1 and P4_2 to P5_2 + p5_w2 = self.p5_w2_relu(self.p5_w2) + weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon) + # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively + p5_out = self.conv5_down( + self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out))) + + # Weights for P6_0, P6_1 and P5_2 to P6_2 + p6_w2 = self.p6_w2_relu(self.p6_w2) + weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon) + # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively + p6_out = self.conv6_down( + self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out))) + + # Weights for P7_0 and P6_2 to P7_2 + p7_w2 = self.p7_w2_relu(self.p7_w2) + weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon) + # Connections for P7_0 and P6_2 to P7_2 + p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out))) + + return p3_out, p4_out, p5_out, p6_out, p7_out + + def _forward(self, inputs): + if self.first_time: + p3, p4, p5 = inputs + + p6_in = self.p5_to_p6(p5) + p7_in = self.p6_to_p7(p6_in) + + p3_in = self.p3_down_channel(p3) + p4_in = self.p4_down_channel(p4) + p5_in = self.p5_down_channel(p5) + + else: + # P3_0, P4_0, P5_0, P6_0 and P7_0 + p3_in, p4_in, p5_in, p6_in, p7_in = inputs + + # P7_0 to P7_2 + + # Connections for P6_0 and P7_0 to P6_1 respectively + p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in))) + + # Connections for P5_0 and P6_0 to P5_1 respectively + p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up))) + + # Connections for P4_0 and P5_0 to P4_1 respectively + p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up))) + + # Connections for P3_0 and P4_1 to P3_2 respectively + p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up))) + + if self.first_time: + p4_in = self.p4_down_channel_2(p4) + p5_in = self.p5_down_channel_2(p5) + + # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively + p4_out = self.conv4_down( + self.swish(p4_in + p4_up + self.p4_downsample(p3_out))) + + # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively + p5_out = self.conv5_down( + self.swish(p5_in + p5_up + self.p5_downsample(p4_out))) + + # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively + p6_out = self.conv6_down( + self.swish(p6_in + p6_up + self.p6_downsample(p5_out))) + + # Connections for P7_0 and P6_2 to P7_2 + p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out))) + + return p3_out, p4_out, p5_out, p6_out, p7_out + + +class Regressor(nn.Module): + """ + modified by Zylo117 + """ + + def __init__(self, in_channels, num_anchors, num_layers, onnx_export=False): + super(Regressor, self).__init__() + self.num_layers = num_layers + self.num_layers = num_layers + + self.conv_list = nn.ModuleList( + [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)]) + self.bn_list = nn.ModuleList( + [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in + range(5)]) + self.header = SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False) + self.swish = MemoryEfficientSwish() if not onnx_export else 
Swish() + + def forward(self, inputs): + feats = [] + for feat, bn_list in zip(inputs, self.bn_list): + for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list): + feat = conv(feat) + feat = bn(feat) + feat = self.swish(feat) + feat = self.header(feat) + feat = feat.permute(0, 2, 3, 1) + feat = feat.contiguous().view(feat.shape[0], -1, 4) + + feats.append(feat) + + feats = torch.cat(feats, dim=1) + + return feats + +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + if torch._C._get_tracing_state(): + return x * torch.sigmoid(x) + return SwishImplementation.apply(x) + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + +class Classifier(nn.Module): + """ + modified by Zylo117 + """ + + def __init__(self, in_channels, num_anchors, num_classes, num_layers, + onnx_export=False, prior_prob=0.01): + super(Classifier, self).__init__() + self.num_anchors = num_anchors + self.num_classes = num_classes + self.num_layers = num_layers + self.conv_list = nn.ModuleList( + [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)]) + self.bn_list = nn.ModuleList( + [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in + range(5)]) + self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False) + + prior_prob = prior_prob + bias_value = -math.log((1 - prior_prob) / prior_prob) + torch.nn.init.normal_(self.header.pointwise_conv.conv.weight, std=0.01) + torch.nn.init.constant_(self.header.pointwise_conv.conv.bias, bias_value) + + self.swish = MemoryEfficientSwish() if not onnx_export else Swish() + + def forward(self, inputs): + feats = [] + for feat, bn_list in zip(inputs, self.bn_list): + for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list): + feat = conv(feat) + feat = bn(feat) + feat = self.swish(feat) + feat = self.header(feat) + + feat = feat.permute(0, 2, 3, 1) + feat = feat.contiguous().view(feat.shape[0], feat.shape[1], feat.shape[2], self.num_anchors, + self.num_classes) + feat = feat.contiguous().view(feat.shape[0], -1, self.num_classes) + + feats.append(feat) + + feats = torch.cat(feats, dim=1) + #feats = feats.sigmoid() + + return feats + +class Conv2dDynamicSamePadding(nn.Conv2d): + """ 2D Convolutions like TensorFlow, for a dynamic image size """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + raise ValueError('tend to be deprecated') + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, self.weight, self.bias, 
self.stride, self.padding, self.dilation, self.groups) + +#TODO: it seems like the standard conv layer is good enough with proper padding +# parameters. +def get_same_padding_conv2d(image_size=None): + """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. """ + if image_size is None: + raise ValueError('not validated') + return Conv2dDynamicSamePadding + else: + from functools import partial + return partial(Conv2dStaticSamePadding, image_size=image_size) + +def round_filters(filters, global_params): + """ Calculate and round number of filters based on depth multiplier. """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + +def round_repeats(repeats, global_params): + """ Round number of filters based on depth multiplier. """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + +def drop_connect(inputs, p, training): + """ Drop connect. """ + if not training: return inputs + batch_size = inputs.shape[0] + keep_prob = 1 - p + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) + binary_tensor = torch.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + +class MBConvBlock(nn.Module): + """ + Mobile Inverted Residual Bottleneck Block + + Args: + block_args (namedtuple): BlockArgs, see above + global_params (namedtuple): GlobalParam, see above + + Attributes: + has_se (bool): Whether the block contains a Squeeze and Excitation layer. 
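+
+    Example (illustrative sketch only, not part of the original code; shapes assume the
+    first block string 'r1_k3_s11_e1_i32_o16_se0.25' decoded by the efficientnet()
+    helper in this file, with width/depth coefficients of 1.0 and image_size=224):
+
+        >>> blocks_args, global_params = efficientnet(width_coefficient=1.0,
+        ...                                           depth_coefficient=1.0,
+        ...                                           image_size=224)
+        >>> block = MBConvBlock(blocks_args[0], global_params)
+        >>> x = torch.randn(1, 32, 112, 112)   # 32 channels = 'i32' of the block string
+        >>> block(x).shape                     # 16 channels = 'o16'; stride 's11' keeps 112x112
+        torch.Size([1, 16, 112, 112])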
+ """ + + def __init__(self, block_args, global_params): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # skip connection and drop connect + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Expansion phase + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + if isinstance(s, (tuple, list)) and all([s0 == s[0] for s0 in s]): + s = s[0] + self._depthwise_conv = Conv2d( + in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise + kernel_size=k, stride=s, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + # Squeeze and Excitation layer, if desired + if self.has_se: + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) + self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) + self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) + + # Output phase + final_oup = self._block_args.output_filters + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """ + :param inputs: input tensor + :param drop_connect_rate: drop connect rate (float, between 0 and 1) + :return: output of block + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + x = self._project_conv(x) + x = self._bn2(x) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + +class BlockDecoder(object): + """ Block Decoder for readability, straight from the official TensorFlow repository """ + + @staticmethod + def _decode_block_string(block_string): + """ Gets a block through a string notation of arguments. 
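+
+        Illustrative sketch (not part of the original code) of the notation used in
+        this file, where r=num_repeat, k=kernel_size, s=stride, e=expand_ratio,
+        i=input_filters, o=output_filters and se=se_ratio:
+
+            >>> args = BlockDecoder._decode_block_string('r1_k3_s11_e1_i32_o16_se0.25')
+            >>> (args.num_repeat, args.kernel_size, args.stride, args.expand_ratio)
+            (1, 3, [1], 1)
+            >>> (args.input_filters, args.output_filters, args.se_ratio, args.id_skip)
+            (32, 16, 0.25, True)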
""" + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) + + return BlockArgs( + kernel_size=int(options['k']), + num_repeat=int(options['r']), + input_filters=int(options['i']), + output_filters=int(options['o']), + expand_ratio=int(options['e']), + id_skip=('noskip' not in block_string), + se_ratio=float(options['se']) if 'se' in options else None, + stride=[int(options['s'][0])]) + + @staticmethod + def _encode_block_string(block): + """Encodes a block to a string.""" + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """ + Decodes a list of string notations to specify blocks inside the network. + + :param string_list: a list of strings, each string is a notation of block + :return: a list of BlockArgs namedtuples of block args + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """ + Encodes a list of BlockArgs to a list of strings. + + :param blocks_args: a list of BlockArgs namedtuples of block args + :return: a list of strings, each string is a notation of block + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + +def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, + drop_connect_rate=0.2, image_size=None, num_classes=1000): + """ Creates a efficientnet model. """ + + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + dropout_rate=dropout_rate, + drop_connect_rate=drop_connect_rate, + # data_format='channels_last', # removed, this is always true in PyTorch + num_classes=num_classes, + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + depth_divisor=8, + min_depth=None, + image_size=image_size, + ) + + return blocks_args, global_params + + +def efficientnet_params(model_name): + """ Map EfficientNet model name to parameter coefficients. 
""" + params_dict = { + # Coefficients: width,depth,res,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +def get_model_params(model_name, override_params): + """ Get the block args and global params for a given model """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + # note: all models have drop connect rate = 0.2 + blocks_args, global_params = efficientnet( + width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) + else: + raise NotImplementedError('model name is not pre-defined: %s' % model_name) + if override_params: + # ValueError will be raised here if override_params has fields not included in global_params. + global_params = global_params._replace(**override_params) + return blocks_args, global_params + +url_map = { + 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth', + 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth', + 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth', + 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth', + 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth', + 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth', + 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth', + 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth', +} + +url_map_advprop = { + 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth', + 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth', + 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth', + 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth', + 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth', + 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth', + 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth', + 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth', + 'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth', +} + +def load_pretrained_weights(model, model_name, load_fc=True, advprop=False): + """ Loads pretrained weights, and downloads if loading for the first time. 
""" + # AutoAugment or Advprop (different preprocessing) + url_map_ = url_map_advprop if advprop else url_map + from torch.utils import model_zoo + state_dict = model_zoo.load_url(url_map_[model_name], map_location=torch.device('cpu')) + # state_dict = torch.load('../../weights/backbone_efficientnetb0.pth') + if load_fc: + ret = model.load_state_dict(state_dict, strict=False) + print(ret) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + res = model.load_state_dict(state_dict, strict=False) + assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' + print('Loaded pretrained weights for {}'.format(model_name)) + +class EfficientNet(nn.Module): + """ + An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods + + Args: + blocks_args (list): A list of BlockArgs to construct blocks + global_params (namedtuple): A set of GlobalParams shared between blocks + + Example: + model = EfficientNet.from_pretrained('efficientnet-b0') + + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Stem + in_channels = 3 # rgb + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Build blocks + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. 
+ self._blocks.append(MBConvBlock(block_args, self._global_params)) + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append(MBConvBlock(block_args, self._global_params)) + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_features(self, inputs): + """ Returns output of the final convolution layer """ + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """ Calls extract_features to extract features, applies final linear layer, and returns logits. """ + bs = inputs.size(0) + # Convolution layers + x = self.extract_features(inputs) + + # Pooling and final linear layer + x = self._avg_pooling(x) + x = x.view(bs, -1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, override_params=None): + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, override_params) + return cls(blocks_args, global_params) + + @classmethod + def from_pretrained(cls, model_name, load_weights=True, advprop=True, num_classes=1000, in_channels=3): + model = cls.from_name(model_name, override_params={'num_classes': num_classes}) + if load_weights: + load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop) + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) + out_channels = round_filters(32, model._global_params) + model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + return model + + @classmethod + def get_image_size(cls, model_name): + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """ Validates model name. 
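+
+        Illustrative behaviour (sketch, not part of the original code):
+
+            EfficientNet._check_model_name_is_valid('efficientnet-b0')   # accepted silently
+            EfficientNet._check_model_name_is_valid('resnet50')          # raises ValueError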
""" + valid_models = ['efficientnet-b'+str(i) for i in range(9)] + if model_name not in valid_models: + raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) + +class EfficientNetD(nn.Module): + """ + modified by Zylo117 + """ + + def __init__(self, compound_coef, load_weights=False): + super().__init__() + model = EfficientNet.from_pretrained(f'efficientnet-b{compound_coef}', load_weights) + del model._conv_head + del model._bn1 + del model._avg_pooling + del model._dropout + del model._fc + self.model = model + + def forward(self, x): + x = self.model._conv_stem(x) + x = self.model._bn0(x) + x = self.model._swish(x) + feature_maps = [] + + # TODO: temporarily storing extra tensor last_x and del it later might not be a good idea, + # try recording stride changing when creating efficientnet, + # and then apply it here. + last_x = None + for idx, block in enumerate(self.model._blocks): + drop_connect_rate = self.model._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self.model._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + + if tuple(block._depthwise_conv.stride) == (2, 2): + feature_maps.append(last_x) + elif idx == len(self.model._blocks) - 1: + feature_maps.append(x) + last_x = x + del last_x + return feature_maps[1:] + +class Anchors(nn.Module): + """ + adapted and modified from https://github.com/google/automl/blob/master/efficientdet/anchors.py by Zylo117 + """ + + def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs): + super().__init__() + from qd.qd_common import print_frame_info + print_frame_info() + self.anchor_scale = anchor_scale + + if pyramid_levels is None: + self.pyramid_levels = [3, 4, 5, 6, 7] + + self.strides = kwargs.get('strides', [2 ** x for x in self.pyramid_levels]) + self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])) + self.ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]) + + self.buffer = {} + + @torch.no_grad() + def forward(self, image, dtype=torch.float32, features=None): + """Generates multiscale anchor boxes. + + Args: + image_size: integer number of input image size. The input image has the + same dimension for width and height. The image_size should be divided by + the largest feature stride 2^max_level. + anchor_scale: float number representing the scale of size of the base + anchor to the feature stride 2^level. + anchor_configs: a dictionary with keys as the levels of anchors and + values as a list of anchor configuration. + + Returns: + anchor_boxes: a numpy array with shape [N, 4], which stacks anchors on all + feature levels. + Raises: + ValueError: input size must be the multiple of largest feature stride. 
+ """ + image_shape = image.shape[2:] + anchor_key = self.get_key('anchor', image_shape) + stride_idx_key = self.get_key('anchor_stride_index', image_shape) + + if anchor_key in self.buffer: + return {'stride_idx': self.buffer[stride_idx_key].detach(), + 'anchor': self.buffer[anchor_key].detach()} + + if dtype == torch.float16: + dtype = np.float16 + else: + dtype = np.float32 + + boxes_all = [] + all_idx_strides = [] + for idx_stride, stride in enumerate(self.strides): + boxes_level = [] + for scale, ratio in itertools.product(self.scales, self.ratios): + if features is not None: + f_h, f_w = features[idx_stride].shape[-2:] + x = np.arange(stride / 2, stride * f_w, stride) + y = np.arange(stride / 2, stride * f_h, stride) + else: + if image_shape[1] % stride != 0: + x_max = stride * ((image_shape[1] + stride - 1) // stride) + y_max = stride * ((image_shape[0] + stride - 1) // stride) + else: + x_max = image_shape[1] + y_max = image_shape[0] + x = np.arange(stride / 2, x_max, stride) + y = np.arange(stride / 2, y_max, stride) + xv, yv = np.meshgrid(x, y) + xv = xv.reshape(-1) + yv = yv.reshape(-1) + + base_anchor_size = self.anchor_scale * stride * scale + anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0 + anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0 + # y1,x1,y2,x2 + boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2, + yv + anchor_size_y_2, xv + anchor_size_x_2)) + boxes = np.swapaxes(boxes, 0, 1) + boxes_level.append(np.expand_dims(boxes, axis=1)) + # concat anchors on the same level to the reshape NxAx4 + boxes_level = np.concatenate(boxes_level, axis=1) + boxes_level = boxes_level.reshape([-1, 4]) + idx_strides = torch.tensor([idx_stride] * len(boxes_level)) + all_idx_strides.append(idx_strides) + boxes_all.append(boxes_level) + + anchor_boxes = np.vstack(boxes_all) + anchor_stride_indices = torch.cat(all_idx_strides).to(image.device) + + self.buffer[stride_idx_key] = anchor_stride_indices + + anchor_boxes = torch.from_numpy(anchor_boxes.astype(dtype)).to(image.device) + anchor_boxes = anchor_boxes.unsqueeze(0) + + # save it for later use to reduce overhead + self.buffer[anchor_key] = anchor_boxes + + return {'stride_idx': self.buffer[stride_idx_key], + 'anchor': self.buffer[anchor_key]} + + def get_key(self, hint, image_shape): + return '{}_{}'.format(hint, '_'.join(map(str, image_shape))) + +class EffNetFPN(nn.Module): + def __init__(self, compound_coef=0, start_from=3): + super().__init__() + + self.backbone_net = EfficientNetD(EfficientDetBackbone.backbone_compound_coef[compound_coef], + load_weights=False) + if start_from == 3: + conv_channel_coef = EfficientDetBackbone.conv_channel_coef[compound_coef] + else: + conv_channel_coef = EfficientDetBackbone.conv_channel_coef2345[compound_coef] + self.bifpn = nn.Sequential( + *[BiFPN(EfficientDetBackbone.fpn_num_filters[compound_coef], + conv_channel_coef, + True if _ == 0 else False, + attention=True if compound_coef < 6 else False, + adaptive_up=True) + for _ in range(EfficientDetBackbone.fpn_cell_repeats[compound_coef])]) + + self.out_channels = EfficientDetBackbone.fpn_num_filters[compound_coef] + + self.start_from = start_from + assert self.start_from in [2, 3] + + def forward(self, inputs): + if self.start_from == 3: + _, p3, p4, p5 = self.backbone_net(inputs) + + features = (p3, p4, p5) + features = self.bifpn(features) + return features + else: + p2, p3, p4, p5 = self.backbone_net(inputs) + features = (p2, p3, p4, p5) + features = self.bifpn(features) + return features + +class 
EfficientDetBackbone(nn.Module): + backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6] + fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384] + conv_channel_coef = { + # the channels of P3/P4/P5. + 0: [40, 112, 320], + 1: [40, 112, 320], + 2: [48, 120, 352], + 3: [48, 136, 384], + 4: [56, 160, 448], + 5: [64, 176, 512], + 6: [72, 200, 576], + 7: [72, 200, 576], + } + conv_channel_coef2345 = { + # the channels of P2/P3/P4/P5. + 0: [24, 40, 112, 320], + # to be determined for the following + 1: [24, 40, 112, 320], + 2: [24, 48, 120, 352], + 3: [32, 48, 136, 384], + 4: [32, 56, 160, 448], + 5: [40, 64, 176, 512], + 6: [72, 200], + 7: [72, 200], + } + fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8] + def __init__(self, num_classes=80, compound_coef=0, load_weights=False, + prior_prob=0.01, **kwargs): + super(EfficientDetBackbone, self).__init__() + self.compound_coef = compound_coef + + self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536] + self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5] + self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.] + self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]) + self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])) + + num_anchors = len(self.aspect_ratios) * self.num_scales + + self.bifpn = nn.Sequential( + *[BiFPN(self.fpn_num_filters[self.compound_coef], + self.conv_channel_coef[compound_coef], + True if _ == 0 else False, + attention=True if compound_coef < 6 else False, + adaptive_up=kwargs.get('adaptive_up')) + for _ in range(self.fpn_cell_repeats[compound_coef])]) + + self.num_classes = num_classes + self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors, + num_layers=self.box_class_repeats[self.compound_coef]) + self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors, + num_classes=num_classes, + num_layers=self.box_class_repeats[self.compound_coef], + prior_prob=prior_prob) + anchor_scale = self.anchor_scale[compound_coef] + if kwargs.get('anchor_scale'): + anchor_scale = kwargs.pop('anchor_scale') + if 'anchor_scale' in kwargs: + del kwargs['anchor_scale'] + self.anchors = Anchors(anchor_scale=anchor_scale, **kwargs) + + self.backbone_net = EfficientNetD(self.backbone_compound_coef[compound_coef], load_weights) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def forward(self, inputs): + _, p3, p4, p5 = self.backbone_net(inputs) + + features = (p3, p4, p5) + features = self.bifpn(features) + + regression = self.regressor(features) + classification = self.classifier(features) + anchors = self.anchors(inputs, inputs.dtype, features=features) + + return features, regression, classification, anchors + + def init_backbone(self, path): + state_dict = torch.load(path) + try: + ret = self.load_state_dict(state_dict, strict=False) + print(ret) + except RuntimeError as e: + print('Ignoring ' + str(e) + '"') + +def init_weights(model): + for name, module in model.named_modules(): + is_conv_layer = isinstance(module, nn.Conv2d) + + if is_conv_layer: + nn.init.kaiming_uniform_(module.weight.data) + + if module.bias is not None: + module.bias.data.zero_() + +def calc_iou(a, b): + # a(anchor) [boxes, (y1, x1, y2, x2)] + # b(gt, coco-style) [boxes, (x1, y1, x2, y2)] + + area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]) + iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0]) + ih = 
torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1]) + iw = torch.clamp(iw, min=0) + ih = torch.clamp(ih, min=0) + ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih + ua = torch.clamp(ua, min=1e-8) + intersection = iw * ih + IoU = intersection / ua + + return IoU + +class BBoxTransform(nn.Module): + def forward(self, anchors, regression): + """ + decode_box_outputs adapted from https://github.com/google/automl/blob/master/efficientdet/anchors.py + + Args: + anchors: [batchsize, boxes, (y1, x1, y2, x2)] + regression: [batchsize, boxes, (dy, dx, dh, dw)] + + Returns: + + """ + y_centers_a = (anchors[..., 0] + anchors[..., 2]) / 2 + x_centers_a = (anchors[..., 1] + anchors[..., 3]) / 2 + ha = anchors[..., 2] - anchors[..., 0] + wa = anchors[..., 3] - anchors[..., 1] + + w = regression[..., 3].exp() * wa + h = regression[..., 2].exp() * ha + + y_centers = regression[..., 0] * ha + y_centers_a + x_centers = regression[..., 1] * wa + x_centers_a + + ymin = y_centers - h / 2. + xmin = x_centers - w / 2. + ymax = y_centers + h / 2. + xmax = x_centers + w / 2. + if len(anchors.shape) == 3: + return torch.stack([xmin, ymin, xmax, ymax], dim=2) + else: + return torch.stack([xmin, ymin, xmax, ymax], dim=1) + + +class ClipBoxes(nn.Module): + + def __init__(self): + super(ClipBoxes, self).__init__() + + def forward(self, boxes, img): + batch_size, num_channels, height, width = img.shape + + boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0) + boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0) + + boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width - 1) + boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height - 1) + + return boxes + +def postprocess2(x, anchors, regression, classification, + transformed_anchors, threshold, iou_threshold, max_box): + anchors = anchors['anchor'] + all_above_th = classification > threshold + out = [] + num_image = x.shape[0] + num_class = classification.shape[-1] + + #classification = classification.cpu() + #transformed_anchors = transformed_anchors.cpu() + #all_above_th = all_above_th.cpu() + max_box_pre_nms = 1000 + for i in range(num_image): + all_rois = [] + all_class_ids = [] + all_scores = [] + for c in range(num_class): + above_th = all_above_th[i, :, c].nonzero() + if len(above_th) == 0: + continue + above_prob = classification[i, above_th, c].squeeze(1) + if len(above_th) > max_box_pre_nms: + _, idx = above_prob.topk(max_box_pre_nms) + above_th = above_th[idx] + above_prob = above_prob[idx] + transformed_anchors_per = transformed_anchors[i,above_th,:].squeeze(dim=1) + from torchvision.ops import nms + nms_idx = nms(transformed_anchors_per, above_prob, iou_threshold=iou_threshold) + if len(nms_idx) > 0: + all_rois.append(transformed_anchors_per[nms_idx]) + ids = torch.tensor([c] * len(nms_idx)) + all_class_ids.append(ids) + all_scores.append(above_prob[nms_idx]) + + if len(all_rois) > 0: + rois = torch.cat(all_rois) + class_ids = torch.cat(all_class_ids) + scores = torch.cat(all_scores) + if len(scores) > max_box: + _, idx = torch.topk(scores, max_box) + rois = rois[idx, :] + class_ids = class_ids[idx] + scores = scores[idx] + out.append({ + 'rois': rois, + 'class_ids': class_ids, + 'scores': scores, + }) + else: + out.append({ + 'rois': [], + 'class_ids': [], + 'scores': [], + }) + + return out + +def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold): + anchors = anchors['anchor'] + transformed_anchors = 
regressBoxes(anchors, regression) + transformed_anchors = clipBoxes(transformed_anchors, x) + scores = torch.max(classification, dim=2, keepdim=True)[0] + scores_over_thresh = (scores > threshold)[:, :, 0] + out = [] + for i in range(x.shape[0]): + if scores_over_thresh.sum() == 0: + out.append({ + 'rois': [], + 'class_ids': [], + 'scores': [], + }) + continue + + classification_per = classification[i, scores_over_thresh[i, :], ...].permute(1, 0) + transformed_anchors_per = transformed_anchors[i, scores_over_thresh[i, :], ...] + scores_per = scores[i, scores_over_thresh[i, :], ...] + from torchvision.ops import nms + anchors_nms_idx = nms(transformed_anchors_per, scores_per[:, 0], iou_threshold=iou_threshold) + + if anchors_nms_idx.shape[0] != 0: + scores_, classes_ = classification_per[:, anchors_nms_idx].max(dim=0) + boxes_ = transformed_anchors_per[anchors_nms_idx, :] + + out.append({ + 'rois': boxes_, + 'class_ids': classes_, + 'scores': scores_, + }) + else: + out.append({ + 'rois': [], + 'class_ids': [], + 'scores': [], + }) + + return out + +def display(preds, imgs, obj_list, imshow=True, imwrite=False): + for i in range(len(imgs)): + if len(preds[i]['rois']) == 0: + continue + + for j in range(len(preds[i]['rois'])): + (x1, y1, x2, y2) = preds[i]['rois'][j].detach().cpu().numpy().astype(np.int) + logging.info((x1, y1, x2, y2)) + cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2) + #obj = obj_list[preds[i]['class_ids'][j]] + #score = float(preds[i]['scores'][j]) + + #cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score), + #(x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, + #(255, 255, 0), 1) + #break + if imshow: + cv2.imshow('image', imgs[i]) + cv2.waitKey(0) + +def calculate_focal_loss2(classification, target_list, alpha, gamma): + from maskrcnn_benchmark.layers.sigmoid_focal_loss import sigmoid_focal_loss_cuda + cls_loss = sigmoid_focal_loss_cuda(classification, target_list.int(), gamma, alpha) + return cls_loss + +def calculate_focal_loss(classification, targets, alpha, gamma): + classification = classification.sigmoid() + device = classification.device + alpha_factor = torch.ones_like(targets) * alpha + alpha_factor = alpha_factor.to(device) + + alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor) + focal_weight = torch.where(torch.eq(targets, 1.), 1. 
- classification, classification) + focal_weight = alpha_factor * torch.pow(focal_weight, gamma) + + bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification)) + + cls_loss = focal_weight * bce + + zeros = torch.zeros_like(cls_loss) + zeros = zeros.to(device) + cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros) + return cls_loss.mean() + +def calculate_giou(pred, gt): + ax1, ay1, ax2, ay2 = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3] + bx1, by1, bx2, by2 = gt[:, 0], gt[:, 1], gt[:, 2], gt[:, 3] + a = (ax2 - ax1) * (ay2 - ay1) + b = (bx2 - bx1) * (by2 - by1) + max_x1, _ = torch.max(torch.stack([ax1, bx1], dim=1), dim=1) + max_y1, _ = torch.max(torch.stack([ay1, by1], dim=1), dim=1) + min_x2, _ = torch.min(torch.stack([ax2, bx2], dim=1), dim=1) + min_y2, _ = torch.min(torch.stack([ay2, by2], dim=1), dim=1) + inter = (min_x2 > max_x1) * (min_y2 > max_y1) + inter = inter * (min_x2 - max_x1) * (min_y2 - max_y1) + + min_x1, _ = torch.min(torch.stack([ax1, bx1], dim=1), dim=1) + min_y1, _ = torch.min(torch.stack([ay1, by1], dim=1), dim=1) + max_x2, _ = torch.max(torch.stack([ax2, bx2], dim=1), dim=1) + max_y2, _ = torch.max(torch.stack([ay2, by2], dim=1), dim=1) + cover = (max_x2 - min_x1) * (max_y2 - min_y1) + union = a + b - inter + iou = inter / (union + 1e-5) + giou = iou - (cover - union) / (cover + 1e-5) + return giou + +class FocalLoss(nn.Module): + def __init__(self, alpha=0.25, gamma=2., cls_loss_type='FL', smooth_bce_pos=0.99, + smooth_bce_neg=0.01, + reg_loss_type='L1', + at_least_1_assgin=False, + neg_iou_th=0.4, + pos_iou_th=0.5, + cls_weight=1., + reg_weight=1., + ): + super(FocalLoss, self).__init__() + from qd.qd_common import print_frame_info + print_frame_info() + self.iter = 0 + self.reg_loss_type = reg_loss_type + self.regressBoxes = BBoxTransform() + if cls_loss_type == 'FL': + from qd.layers.loss import FocalLossWithLogitsNegLoss + self.cls_loss = FocalLossWithLogitsNegLoss(alpha, gamma) + elif cls_loss_type == 'BCE': + from qd.qd_pytorch import BCEWithLogitsNegLoss + self.cls_loss = BCEWithLogitsNegLoss(reduction='sum') + elif cls_loss_type == 'SmoothBCE': + from qd.layers.loss import SmoothBCEWithLogitsNegLoss + self.cls_loss = SmoothBCEWithLogitsNegLoss( + pos=smooth_bce_pos, neg=smooth_bce_neg) + elif cls_loss_type == 'SmoothFL': + from qd.layers.loss import FocalSmoothBCEWithLogitsNegLoss + self.cls_loss = FocalSmoothBCEWithLogitsNegLoss( + alpha=alpha, gamma=2., + pos=smooth_bce_pos, neg=smooth_bce_neg) + else: + raise NotImplementedError(cls_loss_type) + self.at_least_1_assgin = at_least_1_assgin + + self.gt_total = 0 + self.gt_saved_by_at_least = 0 + + self.neg_iou_th = neg_iou_th + self.pos_iou_th = pos_iou_th + + self.cls_weight = cls_weight + self.reg_weight = reg_weight + + self.buf = {} + + def forward(self, classifications, regressions, anchor_info, annotations, **kwargs): + debug = (self.iter % 100) == 0 + self.iter += 1 + if debug: + from collections import defaultdict + debug_info = defaultdict(list) + + batch_size = classifications.shape[0] + classification_losses = [] + regression_losses = [] + anchors = anchor_info['anchor'] + anchor = anchors[0, :, :] # assuming all image sizes are the same, which it is + dtype = anchors.dtype + + anchor_widths = anchor[:, 3] - anchor[:, 1] + anchor_heights = anchor[:, 2] - anchor[:, 0] + anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths + anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights + + #anchor_widths = anchor[:, 2] - anchor[:, 0] + #anchor_heights = anchor[:, 
3] - anchor[:, 1] + #anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths + #anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights + device = classifications.device + + for j in range(batch_size): + + classification = classifications[j, :, :] + regression = regressions[j, :, :] + + bbox_annotation = annotations[j] + bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1] + + #classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4) + + if bbox_annotation.shape[0] == 0: + #cls_loss = calculate_focal_loss2(classification, + #torch.zeros(len(classification)), alpha, + #gamma) + #cls_loss = cls_loss.mean() + cls_loss = torch.tensor(0).to(dtype).to(device) + regression_losses.append(torch.tensor(0).to(dtype).to(device)) + classification_losses.append(cls_loss) + continue + + IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4]) + + IoU_max, IoU_argmax = torch.max(IoU, dim=1) + if self.at_least_1_assgin: + iou_max_gt, iou_argmax_gt = torch.max(IoU, dim=0) + curr_saved = (iou_max_gt < self.pos_iou_th).sum() + self.gt_saved_by_at_least += curr_saved + self.gt_total += len(iou_argmax_gt) + IoU_max[iou_argmax_gt] = 1. + IoU_argmax[iou_argmax_gt] = torch.arange(len(iou_argmax_gt)).to(device) + + # compute the loss for classification + targets = torch.ones_like(classification) * -1 + targets = targets.to(device) + + targets[torch.lt(IoU_max, self.neg_iou_th), :] = 0 + + positive_indices = torch.ge(IoU_max, self.pos_iou_th) + + num_positive_anchors = positive_indices.sum() + + assigned_annotations = bbox_annotation[IoU_argmax, :] + + targets[positive_indices, :] = 0 + targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1 + + if debug: + if num_positive_anchors > 0: + debug_info['pos_conf'].append(classification[ + positive_indices, + assigned_annotations[positive_indices, 4].long()].mean()) + debug_info['neg_conf'].append(classification[targets == 0].mean()) + stride_idx = anchor_info['stride_idx'] + positive_stride_idx = stride_idx[positive_indices] + pos_count_each_stride = torch.tensor( + [(positive_stride_idx == i).sum() for i in range(5)]) + if 'cum_pos_count_each_stride' not in self.buf: + self.buf['cum_pos_count_each_stride'] = pos_count_each_stride + else: + cum_pos_count_each_stride = self.buf['cum_pos_count_each_stride'] + cum_pos_count_each_stride += pos_count_each_stride + self.buf['cum_pos_count_each_stride'] = cum_pos_count_each_stride + + #cls_loss = calculate_focal_loss(classification, targets, alpha, + #gamma) + cls_loss = self.cls_loss(classification, targets) + + cls_loss = cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0) + assert cls_loss == cls_loss + classification_losses.append(cls_loss) + + if positive_indices.sum() > 0: + assigned_annotations = assigned_annotations[positive_indices, :] + if self.reg_loss_type == 'L1': + anchor_widths_pi = anchor_widths[positive_indices] + anchor_heights_pi = anchor_heights[positive_indices] + anchor_ctr_x_pi = anchor_ctr_x[positive_indices] + anchor_ctr_y_pi = anchor_ctr_y[positive_indices] + + gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0] + gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1] + gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths + gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights + + # efficientdet style + gt_widths = torch.clamp(gt_widths, min=1) + gt_heights = torch.clamp(gt_heights, min=1) + + targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi + targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi + 
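+                    # These targets invert the decode step in BBoxTransform.forward: (dy, dx) are the
+                    # center offsets normalized by the anchor height/width, and (dh, dw) below are log
+                    # ratios of the (clamped) ground-truth size to the anchor size, so decoding recovers
+                    # the box via y = dy * ha + y_a, x = dx * wa + x_a, h = exp(dh) * ha, w = exp(dw) * wa.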
targets_dw = torch.log(gt_widths / anchor_widths_pi) + targets_dh = torch.log(gt_heights / anchor_heights_pi) + + targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw)) + targets = targets.t() + + regression_diff = torch.abs(targets - regression[positive_indices, :]) + + regression_loss = torch.where( + torch.le(regression_diff, 1.0 / 9.0), + 0.5 * 9.0 * torch.pow(regression_diff, 2), + regression_diff - 0.5 / 9.0 + ).mean() + elif self.reg_loss_type == 'GIOU': + curr_regression = regression[positive_indices, :] + curr_anchors = anchor[positive_indices] + curr_pred_xyxy = self.regressBoxes(curr_anchors, + curr_regression) + regression_loss = 1.- calculate_giou(curr_pred_xyxy, assigned_annotations) + regression_loss = regression_loss.mean() + assert regression_loss == regression_loss + else: + raise NotImplementedError + regression_losses.append(regression_loss) + else: + if torch.cuda.is_available(): + regression_losses.append(torch.tensor(0).to(dtype).cuda()) + else: + regression_losses.append(torch.tensor(0).to(dtype)) + if debug: + if len(debug_info) > 0: + logging.info('pos = {}; neg = {}, saved_ratio = {}/{}={:.1f}, ' + 'stride_info = {}' + .format( + torch.tensor(debug_info['pos_conf']).mean(), + torch.tensor(debug_info['neg_conf']).mean(), + self.gt_saved_by_at_least, + self.gt_total, + 1. * self.gt_saved_by_at_least / self.gt_total, + self.buf['cum_pos_count_each_stride'], + )) + return self.cls_weight * torch.stack(classification_losses).mean(dim=0, keepdim=True), \ + self.reg_weight * torch.stack(regression_losses).mean(dim=0, keepdim=True) + +class ModelWithLoss(nn.Module): + def __init__(self, model, criterion): + super().__init__() + self.criterion = criterion + self.module = model + + def forward(self, *args): + if len(args) == 2: + imgs, annotations = args + elif len(args) == 1: + imgs, annotations = args[0][:2] + _, regression, classification, anchors = self.module(imgs) + cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations) + return {'cls_loss': cls_loss, 'reg_loss': reg_loss} + +class TorchVisionNMS(nn.Module): + def __init__(self, iou_threshold): + super().__init__() + self.iou_threshold = iou_threshold + + def forward(self, box, prob): + nms_idx = nms(box, prob, iou_threshold=self.iou_threshold) + return nms_idx + +class PostProcess(nn.Module): + def __init__(self, iou_threshold): + super().__init__() + self.nms = TorchVisionNMS(iou_threshold) + + def forward(self, x, anchors, regression, + classification, + transformed_anchors, threshold, max_box): + all_above_th = classification > threshold + out = [] + num_image = x.shape[0] + num_class = classification.shape[-1] + + #classification = classification.cpu() + #transformed_anchors = transformed_anchors.cpu() + #all_above_th = all_above_th.cpu() + max_box_pre_nms = 1000 + for i in range(num_image): + all_rois = [] + all_class_ids = [] + all_scores = [] + for c in range(num_class): + above_th = all_above_th[i, :, c].nonzero() + if len(above_th) == 0: + continue + above_prob = classification[i, above_th, c].squeeze(1) + if len(above_th) > max_box_pre_nms: + _, idx = above_prob.topk(max_box_pre_nms) + above_th = above_th[idx] + above_prob = above_prob[idx] + transformed_anchors_per = transformed_anchors[i,above_th,:].squeeze(dim=1) + nms_idx = self.nms(transformed_anchors_per, above_prob) + if len(nms_idx) > 0: + all_rois.append(transformed_anchors_per[nms_idx]) + ids = torch.tensor([c] * len(nms_idx)) + all_class_ids.append(ids) + all_scores.append(above_prob[nms_idx]) + + 
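+            # Merge the per-class survivors of NMS for this image and, if more than max_box
+            # detections remain, keep only the max_box highest-scoring ones.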
if len(all_rois) > 0: + rois = torch.cat(all_rois) + class_ids = torch.cat(all_class_ids) + scores = torch.cat(all_scores) + if len(scores) > max_box: + _, idx = torch.topk(scores, max_box) + rois = rois[idx, :] + class_ids = class_ids[idx] + scores = scores[idx] + out.append({ + 'rois': rois, + 'class_ids': class_ids, + 'scores': scores, + }) + else: + out.append({ + 'rois': [], + 'class_ids': [], + 'scores': [], + }) + + return out + +class InferenceModel(nn.Module): + def __init__(self, model): + super().__init__() + self.module = model + + self.regressBoxes = BBoxTransform() + self.clipBoxes = ClipBoxes() + self.threshold = 0.01 + self.nms_threshold = 0.5 + self.max_box = 100 + self.debug = False + self.post_process = PostProcess(self.nms_threshold) + + def forward(self, sample): + features, regression, classification, anchor_info = self.module(sample['image']) + anchors = anchor_info['anchor'] + classification = classification.sigmoid() + transformed_anchors = self.regressBoxes(anchors, regression) + transformed_anchors = self.clipBoxes(transformed_anchors, sample['image']) + + preds = self.post_process(sample['image'], anchors, regression, + classification, transformed_anchors, + self.threshold, self.max_box) + + if self.debug: + logging.info('debugging') + imgs = sample['image'] + imgs = imgs.permute(0, 2, 3, 1).cpu().numpy() + imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8) + imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs] + display(preds, imgs, list(map(str, range(80)))) + + for p, s in zip(preds, sample['scale']): + if len(p['rois']) > 0: + p['rois'] /= s + return preds + diff --git a/maskrcnn_benchmark/modeling/backbone/efficientnet.py b/maskrcnn_benchmark/modeling/backbone/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a124b0e0672c723f44d75b63d2434fcfe0f52c --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/efficientnet.py @@ -0,0 +1,691 @@ +""" + EfficientNet for ImageNet-1K, implemented in PyTorch. + Original papers: + - 'EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks,' https://arxiv.org/abs/1905.11946, + - 'Adversarial Examples Improve Image Recognition,' https://arxiv.org/abs/1911.09665. +""" + +import os +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from maskrcnn_benchmark.layers import SEBlock, swish + + +def round_channels(channels, + divisor=8): + """ + Round weighted channel number (make divisible operation). + + Parameters: + ---------- + channels : int or float + Original number of channels. + divisor : int, default 8 + Alignment value. + + Returns + ------- + int + Weighted number of channels. + """ + rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor) + if float(rounded_channels) < 0.9 * channels: + rounded_channels += divisor + return rounded_channels + + +def calc_tf_padding(x, + kernel_size, + stride=1, + dilation=1): + """ + Calculate TF-same like padding size. + + Parameters: + ---------- + x : tensor + Input tensor. + kernel_size : int + Convolution window size. + stride : int, default 1 + Strides of the convolution. + dilation : int, default 1 + Dilation value for convolution layer. + + Returns + ------- + tuple of 4 int + The size of the padding. 
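+
+    Example
+    -------
+    For a 224x224 input and a stride-2 3x3 convolution, the TF-"same" padding
+    comes out asymmetric (one extra pixel on the bottom and right):
+
+    >>> x = torch.zeros(1, 3, 224, 224)
+    >>> calc_tf_padding(x, kernel_size=3, stride=2)
+    (0, 1, 0, 1)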
+ """ + height, width = x.size()[2:] + oh = math.ceil(height / stride) + ow = math.ceil(width / stride) + pad_h = max((oh - 1) * stride + (kernel_size - 1) * dilation + 1 - height, 0) + pad_w = max((ow - 1) * stride + (kernel_size - 1) * dilation + 1 - width, 0) + return pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2 + + +class ConvBlock(nn.Module): + """ + Standard convolution block with Batch normalization and activation. + + Parameters: + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple/list of 2 int + Convolution window size. + stride : int or tuple/list of 2 int + Strides of the convolution. + padding : int, or tuple/list of 2 int, or tuple/list of 4 int + Padding value for convolution layer. + dilation : int or tuple/list of 2 int, default 1 + Dilation value for convolution layer. + groups : int, default 1 + Number of groups. + bias : bool, default False + Whether the layer uses a bias vector. + use_bn : bool, default True + Whether to use BatchNorm layer. + bn_eps : float, default 1e-5 + Small float added to variance in Batch norm. + activation : function or str or None, default nn.ReLU(inplace=True) + Activation function or name of activation function. + """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=1, + groups=1, + bias=False, + use_bn=True, + bn_eps=1e-5, + activation=nn.ReLU(inplace=True)): + super(ConvBlock, self).__init__() + self.activate = (activation is not None) + self.use_bn = use_bn + self.use_pad = (isinstance(padding, (list, tuple)) and (len(padding) == 4)) + + if self.use_pad: + self.pad = nn.ZeroPad2d(padding=padding) + padding = 0 + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + if self.use_bn: + self.bn = nn.BatchNorm2d( + num_features=out_channels, + eps=bn_eps) + if self.activate: + self.activ = activation + + def forward(self, x): + if self.use_pad: + x = self.pad(x) + x = self.conv(x) + if self.use_bn: + x = self.bn(x) + if self.activate: + x = self.activ(x) + return x + + +def conv1x1_block(in_channels, + out_channels, + stride=1, + padding=0, + groups=1, + bias=False, + use_bn=True, + bn_eps=1e-5, + activation=nn.ReLU(inplace=True)): + """ + 1x1 version of the standard convolution block. + + Parameters: + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + stride : int or tuple/list of 2 int, default 1 + Strides of the convolution. + padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 0 + Padding value for convolution layer. + groups : int, default 1 + Number of groups. + bias : bool, default False + Whether the layer uses a bias vector. + use_bn : bool, default True + Whether to use BatchNorm layer. + bn_eps : float, default 1e-5 + Small float added to variance in Batch norm. + activation : function or str or None, default nn.ReLU(inplace=True) + Activation function or name of activation function. 
+ """ + return ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=padding, + groups=groups, + bias=bias, + use_bn=use_bn, + bn_eps=bn_eps, + activation=activation) + + +def conv3x3_block(in_channels, + out_channels, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=False, + use_bn=True, + bn_eps=1e-5, + activation=nn.ReLU(inplace=True)): + """ + 3x3 version of the standard convolution block. + + Parameters: + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + stride : int or tuple/list of 2 int, default 1 + Strides of the convolution. + padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1 + Padding value for convolution layer. + dilation : int or tuple/list of 2 int, default 1 + Dilation value for convolution layer. + groups : int, default 1 + Number of groups. + bias : bool, default False + Whether the layer uses a bias vector. + use_bn : bool, default True + Whether to use BatchNorm layer. + bn_eps : float, default 1e-5 + Small float added to variance in Batch norm. + activation : function or str or None, default nn.ReLU(inplace=True) + Activation function or name of activation function. + """ + return ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + use_bn=use_bn, + bn_eps=bn_eps, + activation=activation) + + +def dwconv3x3_block(in_channels, + out_channels, + stride=1, + padding=1, + dilation=1, + bias=False, + bn_eps=1e-5, + activation=nn.ReLU(inplace=True)): + """ + 3x3 depthwise version of the standard convolution block. + + Parameters: + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + stride : int or tuple/list of 2 int, default 1 + Strides of the convolution. + padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1 + Padding value for convolution layer. + dilation : int or tuple/list of 2 int, default 1 + Dilation value for convolution layer. + bias : bool, default False + Whether the layer uses a bias vector. + bn_eps : float, default 1e-5 + Small float added to variance in Batch norm. + activation : function or str or None, default nn.ReLU(inplace=True) + Activation function or name of activation function. + """ + return ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=padding, + dilation=dilation, + groups=out_channels, + bias=bias, + use_bn=True, + bn_eps=bn_eps, + activation=activation) + + +def dwconv5x5_block(in_channels, + out_channels, + stride=1, + padding=2, + dilation=1, + bias=False, + bn_eps=1e-5, + activation=nn.ReLU(inplace=True)): + """ + 5x5 depthwise version of the standard convolution block. + + Parameters: + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + stride : int or tuple/list of 2 int, default 1 + Strides of the convolution. + padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 2 + Padding value for convolution layer. + dilation : int or tuple/list of 2 int, default 1 + Dilation value for convolution layer. + bias : bool, default False + Whether the layer uses a bias vector. + bn_eps : float, default 1e-5 + Small float added to variance in Batch norm. 
+ activation : function or str or None, default nn.ReLU(inplace=True) + Activation function or name of activation function. + """ + return ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=5, + stride=stride, + padding=padding, + dilation=dilation, + groups=out_channels, + bias=bias, + use_bn=True, + bn_eps=bn_eps, + activation=activation) + + +class EffiDwsConvUnit(nn.Module): + """ + EfficientNet specific depthwise separable convolution block/unit with BatchNorms and activations at each convolution + layers. + + Parameters: + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + stride : int or tuple/list of 2 int + Strides of the second convolution layer. + bn_eps : float + Small float added to variance in Batch norm. + activation : str + Name of activation function. + tf_mode : bool + Whether to use TF-like mode. + """ + def __init__(self, + in_channels, + out_channels, + stride, + bn_eps, + activation, + tf_mode): + super(EffiDwsConvUnit, self).__init__() + self.tf_mode = tf_mode + self.residual = (in_channels == out_channels) and (stride == 1) + + self.dw_conv = dwconv3x3_block( + in_channels=in_channels, + out_channels=in_channels, + padding=(0 if tf_mode else 1), + bn_eps=bn_eps, + activation=activation) + self.se = SEBlock( + channels=in_channels, + reduction=4, + mid_activation=activation) + self.pw_conv = conv1x1_block( + in_channels=in_channels, + out_channels=out_channels, + bn_eps=bn_eps, + activation=None) + + def forward(self, x): + if self.residual: + identity = x + if self.tf_mode: + x = F.pad(x, pad=calc_tf_padding(x, kernel_size=3)) + x = self.dw_conv(x) + x = self.se(x) + x = self.pw_conv(x) + if self.residual: + x = x + identity + return x + + +class EffiInvResUnit(nn.Module): + """ + EfficientNet inverted residual unit. + + Parameters: + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int or tuple/list of 2 int + Convolution window size. + stride : int or tuple/list of 2 int + Strides of the second convolution layer. + exp_factor : int + Factor for expansion of channels. + se_factor : int + SE reduction factor for each unit. + bn_eps : float + Small float added to variance in Batch norm. + activation : str + Name of activation function. + tf_mode : bool + Whether to use TF-like mode. 
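+
+    Notes
+    -----
+    This is the MBConv-style inverted residual unit: a 1x1 expansion to
+    in_channels * exp_factor channels, a depthwise 3x3 or 5x5 convolution, an
+    optional squeeze-and-excitation block, and a 1x1 linear projection, with an
+    identity shortcut when stride == 1 and in_channels == out_channels.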
+ """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + exp_factor, + se_factor, + bn_eps, + activation, + tf_mode): + super(EffiInvResUnit, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.tf_mode = tf_mode + self.residual = (in_channels == out_channels) and (stride == 1) + self.use_se = se_factor > 0 + mid_channels = in_channels * exp_factor + dwconv_block_fn = dwconv3x3_block if kernel_size == 3 else (dwconv5x5_block if kernel_size == 5 else None) + + self.conv1 = conv1x1_block( + in_channels=in_channels, + out_channels=mid_channels, + bn_eps=bn_eps, + activation=activation) + self.conv2 = dwconv_block_fn( + in_channels=mid_channels, + out_channels=mid_channels, + stride=stride, + padding=(0 if tf_mode else (kernel_size // 2)), + bn_eps=bn_eps, + activation=activation) + if self.use_se: + self.se = SEBlock( + channels=mid_channels, + reduction=(exp_factor * se_factor), + mid_activation=activation) + self.conv3 = conv1x1_block( + in_channels=mid_channels, + out_channels=out_channels, + bn_eps=bn_eps, + activation=None) + + def forward(self, x): + if self.residual: + identity = x + x = self.conv1(x) + if self.tf_mode: + x = F.pad(x, pad=calc_tf_padding(x, kernel_size=self.kernel_size, stride=self.stride)) + x = self.conv2(x) + if self.use_se: + x = self.se(x) + x = self.conv3(x) + if self.residual: + x = x + identity + return x + + +class EffiInitBlock(nn.Module): + """ + EfficientNet specific initial block. + + Parameters: + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + bn_eps : float + Small float added to variance in Batch norm. + activation : str + Name of activation function. + tf_mode : bool + Whether to use TF-like mode. + """ + + def __init__(self, + in_channels, + out_channels, + bn_eps, + activation, + tf_mode): + super(EffiInitBlock, self).__init__() + self.tf_mode = tf_mode + + self.conv = conv3x3_block( + in_channels=in_channels, + out_channels=out_channels, + stride=2, + padding=(0 if tf_mode else 1), + bn_eps=bn_eps, + activation=activation) + + def forward(self, x): + if self.tf_mode: + x = F.pad(x, pad=calc_tf_padding(x, kernel_size=3, stride=2)) + x = self.conv(x) + return x + + +class EfficientNet(nn.Module): + """ + EfficientNet model from 'EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks,' + https://arxiv.org/abs/1905.11946. + + Parameters: + ---------- + channels : list of list of int + Number of output channels for each unit. + init_block_channels : int + Number of output channels for initial unit. + final_block_channels : int + Number of output channels for the final block of the feature extractor. + kernel_sizes : list of list of int + Number of kernel sizes for each unit. + strides_per_stage : list int + Stride value for the first unit of each stage. + expansion_factors : list of list of int + Number of expansion factors for each unit. + dropout_rate : float, default 0.2 + Fraction of the input units to drop. Must be a number between 0 and 1. + tf_mode : bool, default False + Whether to use TF-like mode. + bn_eps : float, default 1e-5 + Small float added to variance in Batch norm. + in_channels : int, default 3 + Number of input channels. + in_size : tuple of two ints, default (224, 224) + Spatial size of the expected input image. + num_classes : int, default 1000 + Number of classification classes. 
+ """ + def __init__(self, + cfg, + channels, + init_block_channels, + kernel_sizes, + strides_per_stage, + expansion_factors, + tf_mode=False, + bn_eps=1e-5, + in_channels=3): + super(EfficientNet, self).__init__() + activation = swish() + + self.out_channels = [] + self.features = nn.Sequential() + self.stages = [] + stem = EffiInitBlock( + in_channels=in_channels, + out_channels=init_block_channels, + bn_eps=bn_eps, + activation=activation, + tf_mode=tf_mode) + self.features.add_module("init_block", stem) + self.stages.append(stem) + + in_channels = init_block_channels + for i, channels_per_stage in enumerate(channels): + kernel_sizes_per_stage = kernel_sizes[i] + expansion_factors_per_stage = expansion_factors[i] + stage = nn.Sequential() + for j, out_channels in enumerate(channels_per_stage): + kernel_size = kernel_sizes_per_stage[j] + expansion_factor = expansion_factors_per_stage[j] + stride = strides_per_stage[i] if (j == 0) else 1 + if i == 0: + stage.add_module("unit{}".format(j + 1), EffiDwsConvUnit( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bn_eps=bn_eps, + activation=activation, + tf_mode=tf_mode)) + else: + stage.add_module("unit{}".format(j + 1), EffiInvResUnit( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + exp_factor=expansion_factor, + se_factor=4, + bn_eps=bn_eps, + activation=activation, + tf_mode=tf_mode)) + in_channels = out_channels + if i>0: + self.out_channels.append(out_channels) + self.features.add_module("stage{}".format(i + 1), stage) + self.stages.append(stage) + # Optionally freeze (requires_grad=False) parts of the backbone + self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT) + + def _freeze_backbone(self, freeze_at): + if freeze_at < 0: + return + for stage_index in range(freeze_at): + m = self.stages[stage_index] + for p in m.parameters(): + p.requires_grad = False + + def forward(self, x): + res = [] + for i, stage in enumerate(self.stages): + x = stage(x) + if i>1: + res.append(x) + return res + + +def get_efficientnet(cfg, version, tf_mode = True, bn_eps=1e-5, **kwargs): + if version == "b0": + depth_factor = 1.0 + width_factor = 1.0 + elif version == "b1": + depth_factor = 1.1 + width_factor = 1.0 + elif version == "b2": + depth_factor = 1.2 + width_factor = 1.1 + elif version == "b3": + depth_factor = 1.4 + width_factor = 1.2 + elif version == "b4": + depth_factor = 1.8 + width_factor = 1.4 + elif version == "b5": + depth_factor = 2.2 + width_factor = 1.6 + elif version == "b6": + depth_factor = 2.6 + width_factor = 1.8 + elif version == "b7": + depth_factor = 3.1 + width_factor = 2.0 + elif version == "b8": + depth_factor = 3.6 + width_factor = 2.2 + else: + raise ValueError("Unsupported EfficientNet version {}".format(version)) + + init_block_channels = 32 + layers = [1, 2, 2, 3, 3, 4, 1] + downsample = [1, 1, 1, 1, 0, 1, 0] + channels_per_layers = [16, 24, 40, 80, 112, 192, 320] + expansion_factors_per_layers = [1, 6, 6, 6, 6, 6, 6] + kernel_sizes_per_layers = [3, 3, 5, 3, 5, 5, 3] + strides_per_stage = [1, 2, 2, 2, 1, 2, 1] + + layers = [int(math.ceil(li * depth_factor)) for li in layers] + channels_per_layers = [round_channels(ci * width_factor) for ci in channels_per_layers] + + from functools import reduce + channels = reduce(lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], + zip(channels_per_layers, layers, downsample), []) + kernel_sizes = reduce(lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + 
[x[-1] + [y[0]] * y[1]], + zip(kernel_sizes_per_layers, layers, downsample), []) + expansion_factors = reduce(lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], + zip(expansion_factors_per_layers, layers, downsample), []) + strides_per_stage = reduce(lambda x, y: x + [[y[0]] * y[1]] if y[2] != 0 else x[:-1] + [x[-1] + [y[0]] * y[1]], + zip(strides_per_stage, layers, downsample), []) + strides_per_stage = [si[0] for si in strides_per_stage] + + init_block_channels = round_channels(init_block_channels * width_factor) + + net = EfficientNet( + cfg, + channels=channels, + init_block_channels=init_block_channels, + kernel_sizes=kernel_sizes, + strides_per_stage=strides_per_stage, + expansion_factors=expansion_factors, + tf_mode=tf_mode, + bn_eps=bn_eps, + **kwargs) + + return net diff --git a/maskrcnn_benchmark/modeling/backbone/fbnet.py b/maskrcnn_benchmark/modeling/backbone/fbnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc2823f3bd2f06cc86b3e1bb597fb20f219817d --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/fbnet.py @@ -0,0 +1,536 @@ +""" +FBNet model builder +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import copy +import logging +import math +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.nn import BatchNorm2d, SyncBatchNorm +from maskrcnn_benchmark.layers import Conv2d, interpolate +from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d +from maskrcnn_benchmark.layers.misc import _NewEmptyTensorOp + + +logger = logging.getLogger(__name__) + + +def _py2_round(x): + return math.floor(x + 0.5) if x >= 0.0 else math.ceil(x - 0.5) + + +def _get_divisible_by(num, divisible_by, min_val): + ret = int(num) + if divisible_by > 0 and num % divisible_by != 0: + ret = int((_py2_round(num / divisible_by) or min_val) * divisible_by) + return ret + + +class Identity(nn.Module): + def __init__(self, C_in, C_out, stride): + super(Identity, self).__init__() + self.conv = ( + ConvBNRelu( + C_in, + C_out, + kernel=1, + stride=stride, + pad=0, + no_bias=1, + use_relu="relu", + bn_type="bn", + ) + if C_in != C_out or stride != 1 + else None + ) + + def forward(self, x): + if self.conv: + out = self.conv(x) + else: + out = x + return out + + +class CascadeConv3x3(nn.Sequential): + def __init__(self, C_in, C_out, stride): + assert stride in [1, 2] + ops = [ + Conv2d(C_in, C_in, 3, stride, 1, bias=False), + BatchNorm2d(C_in), + nn.ReLU(inplace=True), + Conv2d(C_in, C_out, 3, 1, 1, bias=False), + BatchNorm2d(C_out), + ] + super(CascadeConv3x3, self).__init__(*ops) + self.res_connect = (stride == 1) and (C_in == C_out) + + def forward(self, x): + y = super(CascadeConv3x3, self).forward(x) + if self.res_connect: + y += x + return y + + +class Shift(nn.Module): + def __init__(self, C, kernel_size, stride, padding): + super(Shift, self).__init__() + self.C = C + kernel = torch.zeros((C, 1, kernel_size, kernel_size), dtype=torch.float32) + ch_idx = 0 + + assert stride in [1, 2] + self.stride = stride + self.padding = padding + self.kernel_size = kernel_size + self.dilation = 1 + + hks = kernel_size // 2 + ksq = kernel_size ** 2 + + for i in range(kernel_size): + for j in range(kernel_size): + if i == hks and j == hks: + num_ch = C // ksq + C % ksq + else: + num_ch = C // ksq + kernel[ch_idx : ch_idx + num_ch, 0, i, j] = 1 + ch_idx += num_ch + + self.register_parameter("bias", None) + self.kernel = nn.Parameter(kernel, requires_grad=False) 
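+    # The kernel built above is a fixed (requires_grad=False) depthwise filter: each group of
+    # C // kernel_size**2 channels gets a single 1 at a distinct (i, j) tap, with the remainder
+    # assigned to the center tap, so the grouped convolution in forward() shifts those channels
+    # spatially instead of mixing them.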
+ + def forward(self, x): + if x.numel() > 0: + return nn.functional.conv2d( + x, + self.kernel, + self.bias, + (self.stride, self.stride), + (self.padding, self.padding), + self.dilation, + self.C, # groups + ) + + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // d + 1 + for i, p, di, k, d in zip( + x.shape[-2:], + (self.padding, self.dilation), + (self.dilation, self.dilation), + (self.kernel_size, self.kernel_size), + (self.stride, self.stride), + ) + ] + output_shape = [x.shape[0], self.C] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) + + +class ShiftBlock5x5(nn.Sequential): + def __init__(self, C_in, C_out, expansion, stride): + assert stride in [1, 2] + self.res_connect = (stride == 1) and (C_in == C_out) + + C_mid = _get_divisible_by(C_in * expansion, 8, 8) + + ops = [ + # pw + Conv2d(C_in, C_mid, 1, 1, 0, bias=False), + BatchNorm2d(C_mid), + nn.ReLU(inplace=True), + # shift + Shift(C_mid, 5, stride, 2), + # pw-linear + Conv2d(C_mid, C_out, 1, 1, 0, bias=False), + BatchNorm2d(C_out), + ] + super(ShiftBlock5x5, self).__init__(*ops) + + def forward(self, x): + y = super(ShiftBlock5x5, self).forward(x) + if self.res_connect: + y += x + return y + + +class ChannelShuffle(nn.Module): + def __init__(self, groups): + super(ChannelShuffle, self).__init__() + self.groups = groups + + def forward(self, x): + """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" + N, C, H, W = x.size() + g = self.groups + assert C % g == 0, "Incompatible group size {} for input channel {}".format( + g, C + ) + return ( + x.view(N, g, int(C / g), H, W) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(N, C, H, W) + ) + + +class ConvBNRelu(nn.Sequential): + def __init__( + self, + input_depth, + output_depth, + kernel, + stride, + pad, + no_bias, + use_relu, + bn_type, + group=1, + *args, + **kwargs + ): + super(ConvBNRelu, self).__init__() + + assert use_relu in ["relu", None] + if isinstance(bn_type, (list, tuple)): + assert len(bn_type) == 2 + assert bn_type[0] == "gn" + gn_group = bn_type[1] + bn_type = bn_type[0] + assert bn_type in ["bn", "nsbn", "sbn", "af", "gn", None] + assert stride in [1, 2, 4] + + op = Conv2d( + input_depth, + output_depth, + kernel_size=kernel, + stride=stride, + padding=pad, + bias=not no_bias, + groups=group, + *args, + **kwargs + ) + nn.init.kaiming_normal_(op.weight, mode="fan_out", nonlinearity="relu") + if op.bias is not None: + nn.init.constant_(op.bias, 0.0) + self.add_module("conv", op) + + if bn_type == "bn": + bn_op = BatchNorm2d(output_depth) + elif bn_type == "sbn": + bn_op = SyncBatchNorm(output_depth) + elif bn_type == "nsbn": + bn_op = NaiveSyncBatchNorm2d(output_depth) + elif bn_type == "gn": + bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=output_depth) + elif bn_type == "af": + bn_op = FrozenBatchNorm2d(output_depth) + if bn_type is not None: + self.add_module("bn", bn_op) + + if use_relu == "relu": + self.add_module("relu", nn.ReLU(inplace=True)) + + +class SEModule(nn.Module): + reduction = 4 + + def __init__(self, C): + super(SEModule, self).__init__() + mid = max(C // self.reduction, 8) + conv1 = Conv2d(C, mid, 1, 1, 0) + conv2 = Conv2d(mid, C, 1, 1, 0) + + self.op = nn.Sequential( + nn.AdaptiveAvgPool2d(1), conv1, nn.ReLU(inplace=True), conv2, nn.Sigmoid() + ) + + def forward(self, x): + return x * self.op(x) + + +class Upsample(nn.Module): + def __init__(self, scale_factor, mode, align_corners=None): + super(Upsample, self).__init__() + self.scale = scale_factor + self.mode = mode + 
self.align_corners = align_corners + + def forward(self, x): + return interpolate( + x, scale_factor=self.scale, mode=self.mode, + align_corners=self.align_corners + ) + + +def _get_upsample_op(stride): + assert ( + stride in [1, 2, 4] + or stride in [-1, -2, -4] + or (isinstance(stride, tuple) and all(x in [-1, -2, -4] for x in stride)) + ) + + scales = stride + ret = None + if isinstance(stride, tuple) or stride < 0: + scales = [-x for x in stride] if isinstance(stride, tuple) else -stride + stride = 1 + ret = Upsample(scale_factor=scales, mode="nearest", align_corners=None) + + return ret, stride + + +class IRFBlock(nn.Module): + def __init__( + self, + input_depth, + output_depth, + expansion, + stride, + bn_type="bn", + kernel=3, + width_divisor=1, + shuffle_type=None, + pw_group=1, + se=False, + cdw=False, + dw_skip_bn=False, + dw_skip_relu=False, + ): + super(IRFBlock, self).__init__() + + assert kernel in [1, 3, 5, 7], kernel + + self.use_res_connect = stride == 1 and input_depth == output_depth + self.output_depth = output_depth + + mid_depth = int(input_depth * expansion) + mid_depth = _get_divisible_by(mid_depth, width_divisor, width_divisor) + + # pw + self.pw = ConvBNRelu( + input_depth, + mid_depth, + kernel=1, + stride=1, + pad=0, + no_bias=1, + use_relu="relu", + bn_type=bn_type, + group=pw_group, + ) + + # negative stride to do upsampling + self.upscale, stride = _get_upsample_op(stride) + + # dw + if kernel == 1: + self.dw = nn.Sequential() + elif cdw: + dw1 = ConvBNRelu( + mid_depth, + mid_depth, + kernel=kernel, + stride=stride, + pad=(kernel // 2), + group=mid_depth, + no_bias=1, + use_relu="relu", + bn_type=bn_type, + ) + dw2 = ConvBNRelu( + mid_depth, + mid_depth, + kernel=kernel, + stride=1, + pad=(kernel // 2), + group=mid_depth, + no_bias=1, + use_relu="relu" if not dw_skip_relu else None, + bn_type=bn_type if not dw_skip_bn else None, + ) + self.dw = nn.Sequential(OrderedDict([("dw1", dw1), ("dw2", dw2)])) + else: + self.dw = ConvBNRelu( + mid_depth, + mid_depth, + kernel=kernel, + stride=stride, + pad=(kernel // 2), + group=mid_depth, + no_bias=1, + use_relu="relu" if not dw_skip_relu else None, + bn_type=bn_type if not dw_skip_bn else None, + ) + + # pw-linear + self.pwl = ConvBNRelu( + mid_depth, + output_depth, + kernel=1, + stride=1, + pad=0, + no_bias=1, + use_relu=None, + bn_type=bn_type, + group=pw_group, + ) + + self.shuffle_type = shuffle_type + if shuffle_type is not None: + self.shuffle = ChannelShuffle(pw_group) + + self.se4 = SEModule(output_depth) if se else nn.Sequential() + + self.output_depth = output_depth + + def forward(self, x): + y = self.pw(x) + if self.shuffle_type == "mid": + y = self.shuffle(y) + if self.upscale is not None: + y = self.upscale(y) + y = self.dw(y) + y = self.pwl(y) + if self.use_res_connect: + y += x + y = self.se4(y) + return y + + + +skip = lambda C_in, C_out, stride, **kwargs: Identity( + C_in, C_out, stride +) +basic_block = lambda C_in, C_out, stride, **kwargs: CascadeConv3x3( + C_in, C_out, stride +) +# layer search 2 +ir_k3_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=3, **kwargs +) +ir_k3_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 3, stride, kernel=3, **kwargs +) +ir_k3_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 6, stride, kernel=3, **kwargs +) +ir_k3_s4 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 4, stride, kernel=3, shuffle_type="mid", pw_group=4, **kwargs +) +ir_k5_e1 = lambda C_in, C_out, stride, 
**kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=5, **kwargs +) +ir_k5_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 3, stride, kernel=5, **kwargs +) +ir_k5_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 6, stride, kernel=5, **kwargs +) +ir_k5_s4 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 4, stride, kernel=5, shuffle_type="mid", pw_group=4, **kwargs +) +# layer search se +ir_k3_e1_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=3, se=True, **kwargs +) +ir_k3_e3_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 3, stride, kernel=3, se=True, **kwargs +) +ir_k3_e6_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 6, stride, kernel=3, se=True, **kwargs +) +ir_k3_s4_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, + C_out, + 4, + stride, + kernel=3, + shuffle_type=mid, + pw_group=4, + se=True, + **kwargs +) +ir_k5_e1_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=5, se=True, **kwargs +) +ir_k5_e3_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 3, stride, kernel=5, se=True, **kwargs +) +ir_k5_e6_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 6, stride, kernel=5, se=True, **kwargs +) +ir_k5_s4_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, + C_out, + 4, + stride, + kernel=5, + shuffle_type="mid", + pw_group=4, + se=True, + **kwargs +) +# layer search 3 (in addition to layer search 2) +ir_k3_s2 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=3, shuffle_type="mid", pw_group=2, **kwargs +) +ir_k5_s2 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=5, shuffle_type="mid", pw_group=2, **kwargs +) +ir_k3_s2_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, + C_out, + 1, + stride, + kernel=3, + shuffle_type="mid", + pw_group=2, + se=True, + **kwargs +) +ir_k5_s2_se = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, + C_out, + 1, + stride, + kernel=5, + shuffle_type="mid", + pw_group=2, + se=True, + **kwargs +) +# layer search 4 (in addition to layer search 3) +ir_k33_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=3, cdw=True, **kwargs +) +ir_k33_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 3, stride, kernel=3, cdw=True, **kwargs +) +ir_k33_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 6, stride, kernel=3, cdw=True, **kwargs +) +# layer search 5 (in addition to layer search 4) +ir_k7_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=7, **kwargs +) +ir_k7_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 3, stride, kernel=7, **kwargs +) +ir_k7_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 6, stride, kernel=7, **kwargs +) +ir_k7_sep_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 1, stride, kernel=7, cdw=True, **kwargs +) +ir_k7_sep_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 3, stride, kernel=7, cdw=True, **kwargs +) +ir_k7_sep_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock( + C_in, C_out, 6, stride, kernel=7, cdw=True, **kwargs +) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/fpn.py b/maskrcnn_benchmark/modeling/backbone/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..90bd853325190618d82addd46ac0d08f44742aa7 --- 
/dev/null +++ b/maskrcnn_benchmark/modeling/backbone/fpn.py @@ -0,0 +1,167 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +import torch.nn.functional as F +from torch import nn + +class FPN(nn.Module): + """ + Module that adds FPN on top of a list of feature maps. + The feature maps are currently supposed to be in increasing depth + order, and must be consecutive + """ + + def __init__( + self, in_channels_list, out_channels, conv_block, top_blocks=None, drop_block=None, use_spp=False, use_pan=False, + return_swint_feature_before_fusion=False + ): + """ + Arguments: + in_channels_list (list[int]): number of channels for each feature map that + will be fed + out_channels (int): number of channels of the FPN representation + top_blocks (nn.Module or None): if provided, an extra operation will + be performed on the output of the last (smallest resolution) + FPN output, and the result will extend the result list + """ + super(FPN, self).__init__() + self.inner_blocks = [] + self.layer_blocks = [] + self.pan_blocks = [] if use_pan else None + self.spp_block = SPPLayer() if use_spp else None + self.return_swint_feature_before_fusion = return_swint_feature_before_fusion + for idx, in_channels in enumerate(in_channels_list, 1): + inner_block = "fpn_inner{}".format(idx) + layer_block = "fpn_layer{}".format(idx) + + if in_channels == 0: + continue + if idx==len(in_channels_list) and use_spp: + in_channels = in_channels*4 + inner_block_module = conv_block(in_channels, out_channels, 1) + layer_block_module = conv_block(out_channels, out_channels, 3, 1) + self.add_module(inner_block, inner_block_module) + self.add_module(layer_block, layer_block_module) + self.inner_blocks.append(inner_block) + self.layer_blocks.append(layer_block) + + if use_pan: + pan_in_block = "pan_in_layer{}".format(idx) + pan_in_block_module = conv_block(out_channels, out_channels, 3, 2) + self.add_module(pan_in_block, pan_in_block_module) + pan_out_block = "pan_out_layer{}".format(idx) + pan_out_block_module = conv_block(out_channels, out_channels, 3, 1) + self.add_module(pan_out_block, pan_out_block_module) + self.pan_blocks.append([pan_in_block, pan_out_block]) + + self.top_blocks = top_blocks + self.drop_block = drop_block + + def forward(self, x): + """ + Arguments: + x (list[Tensor]): feature maps for each feature level. + Returns: + results (tuple[Tensor]): feature maps after FPN layers. + They are ordered from highest resolution first. 
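+        Note:
+            When the backbone is a vision-language backbone, x arrives as a
+            (visual_features, text_features) tuple; in that case the method returns
+            (fpn_outputs, x_text, swint_feature_c4) instead of only the FPN outputs.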
+ """ + if type(x) is tuple: + # for the case of VL backbone + x, x_text = x[0], x[1] + # print([v.shape for v in x]) + swint_feature_c4 = None + if self.return_swint_feature_before_fusion: + # TODO: here we only return last single scale feature map before the backbone fusion, should be more flexible + swint_feature_c4 = x[-2] + + if self.spp_block: + last_inner = getattr(self, self.inner_blocks[-1])(self.spp_block(x[-1])) + else: + last_inner = getattr(self, self.inner_blocks[-1])(x[-1]) + results = [] + results.append(getattr(self, self.layer_blocks[-1])(last_inner)) + for feature, inner_block, layer_block in zip( + x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] + ): + if not inner_block: + continue + inner_lateral = getattr(self, inner_block)(feature) + + if inner_lateral.shape[-2:] != last_inner.shape[-2:]: + # TODO: could also give size instead of + inner_top_down = F.interpolate(last_inner, size=inner_lateral.shape[-2:], mode="nearest") + else: + inner_top_down = last_inner + + # TODO use size instead of scale to make it robust to different sizes + # inner_top_down = F.upsample(last_inner, size=inner_lateral.shape[-2:], + # mode='bilinear', align_corners=False) + last_inner = inner_lateral + inner_top_down + if self.drop_block and self.training: + results.insert(0, getattr(self, layer_block)(self.drop_block(last_inner))) + else: + results.insert(0, getattr(self, layer_block)(last_inner)) + + if self.pan_blocks: + pan_results = [] + last_outer = results[0] + pan_results.append(last_outer) + for outer_top_down, pan_block in zip(results[1:], self.pan_blocks): + + if self.drop_block and self.training: + pan_lateral = getattr(self, pan_block[0])(self.drop_block(last_outer)) + else: + pan_lateral = getattr(self, pan_block[0])(last_outer) + + last_outer = getattr(self, pan_block[1])(pan_lateral + outer_top_down) + pan_results.append(last_outer) + results = pan_results + + if isinstance(self.top_blocks, LastLevelP6P7): + last_results = self.top_blocks(x[-1], results[-1]) + results.extend(last_results) + elif isinstance(self.top_blocks, LastLevelMaxPool): + last_results = self.top_blocks(results[-1]) + results.extend(last_results) + + try: + return tuple(results), x_text, swint_feature_c4 + except NameError as e: + return tuple(results) + + +class LastLevelMaxPool(nn.Module): + def forward(self, x): + return [F.max_pool2d(x, 1, 2, 0)] + + +class LastLevelP6P7(nn.Module): + """ + This module is used in RetinaNet to generate extra layers, P6 and P7. 
+ """ + def __init__(self, in_channels, out_channels): + super(LastLevelP6P7, self).__init__() + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + nn.init.kaiming_uniform_(module.weight, a=1) + nn.init.constant_(module.bias, 0) + self.use_P5 = in_channels == out_channels + + def forward(self, c5, p5): + x = p5 if self.use_P5 else c5 + p6 = self.p6(x) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +class SPPLayer(nn.Module): + def __init__(self): + super(SPPLayer, self).__init__() + + def forward(self, x): + x_1 = x + x_2 = F.max_pool2d(x, 5, stride=1, padding=2) + x_3 = F.max_pool2d(x, 9, stride=1, padding=4) + x_4 = F.max_pool2d(x, 13, stride=1, padding=6) + out = torch.cat((x_1, x_2, x_3, x_4),dim=1) + return out \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/mixer.py b/maskrcnn_benchmark/modeling/backbone/mixer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4782d50863a4da9070285a9cd3093db4fbcf6f8 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/mixer.py @@ -0,0 +1,23 @@ +import torch +from torch import nn + +class MixedOperationRandom(nn.Module): + def __init__(self, search_ops): + super(MixedOperationRandom, self).__init__() + self.ops = nn.ModuleList(search_ops) + self.num_ops = len(search_ops) + + def forward(self, x, x_path=None): + if x_path is None: + output = sum(op(x) for op in self.ops) / self.num_ops + else: + assert isinstance(x_path, (int, float)) and 0 <= x_path < self.num_ops or isinstance(x_path, torch.Tensor) + if isinstance(x_path, (int, float)): + x_path = int(x_path) + assert 0 <= x_path < self.num_ops + output = self.ops[x_path](x) + elif isinstance(x_path, torch.Tensor): + assert x_path.size(0) == x.size(0), 'batch_size should match length of y_idx' + output = torch.cat([self.ops[int(x_path[i].item())](x.narrow(0, i, 1)) + for i in range(x.size(0))], dim=0) + return output \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/ops.py b/maskrcnn_benchmark/modeling/backbone/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..23c36ccebb57d3207d97e32babd84e21b65a2ec2 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/ops.py @@ -0,0 +1,71 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv7x7(in_planes, out_planes, stride=1, groups=1, dilation=1): + """7x7 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, + padding=3*dilation, groups=groups, bias=False, dilation=dilation) + + +def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1): + """5x5 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride, + padding=2*dilation, groups=groups, bias=False, dilation=dilation) + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def maxpool(**kwargs): + return nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + +def avgpool(**kwargs): + return nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + +def dropout(prob): + return nn.Dropout(prob) + + +conv3x3sep = lambda 
i, o, s=1: conv3x3(i, o, s, groups=i) +conv3x3g2 = lambda i, o, s=1: conv3x3(i, o, s, groups=2) +conv3x3g4 = lambda i, o, s=1: conv3x3(i, o, s, groups=4) +conv3x3g8 = lambda i, o, s=1: conv3x3(i, o, s, groups=8) +conv3x3dw = lambda i, o, s=1: conv3x3(i, o, s, groups=i) + +conv3x3d2 = lambda i, o, s=1: conv3x3(i, o, s, dilation=2) +conv3x3d3 = lambda i, o, s=1: conv3x3(i, o, s, dilation=3) +conv3x3d4 = lambda i, o, s=1: conv3x3(i, o, s, dilation=4) + + +conv5x5sep = lambda i, o, s=1: conv5x5(i, o, s, groups=i) +conv5x5g2 = lambda i, o, s=1: conv5x5(i, o, s, groups=2) +conv5x5g4 = lambda i, o, s=1: conv5x5(i, o, s, groups=4) +conv5x5g8 = lambda i, o, s=1: conv5x5(i, o, s, groups=8) +conv5x5dw = lambda i, o, s=1: conv5x5(i, o, s, groups=i) + + +conv5x5d2 = lambda i, o, s=1: conv5x5(i, o, s, dilation=2) +conv5x5d3 = lambda i, o, s=1: conv5x5(i, o, s, dilation=3) +conv5x5d4 = lambda i, o, s=1: conv5x5(i, o, s, dilation=4) + +conv7x7sep = lambda i, o, s=1: conv7x7(i, o, s, groups=i) +conv7x7g2 = lambda i, o, s=1: conv7x7(i, o, s, groups=2) +conv7x7g4 = lambda i, o, s=1: conv7x7(i, o, s, groups=4) +conv7x7g8 = lambda i, o, s=1: conv7x7(i, o, s, groups=8) +conv7x7dw = lambda i, o, s=1: conv7x7(i, o, s, groups=i) + +conv7x7d2 = lambda i, o, s=1: conv7x7(i, o, s, dilation=2) +conv7x7d3 = lambda i, o, s=1: conv7x7(i, o, s, dilation=3) +conv7x7d4 = lambda i, o, s=1: conv7x7(i, o, s, dilation=4) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/resnet.py b/maskrcnn_benchmark/modeling/backbone/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f27a1edc470a0c49778369abceeefbeecc792f85 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/resnet.py @@ -0,0 +1,643 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Variant of the resnet module that takes cfg as an argument. +Example usage. Strings may be specified in the config file. + model = ResNet( + "StemWithFixedBatchNorm", + "BottleneckWithFixedBatchNorm", + "ResNet50StagesTo4", + ) +OR: + model = ResNet( + "StemWithGN", + "BottleneckWithGN", + "ResNet50StagesTo4", + ) +Custom implementations may be written in user code and hooked in via the +`register_*` functions. +""" +from collections import namedtuple + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import BatchNorm2d, SyncBatchNorm + +from maskrcnn_benchmark.layers import FrozenBatchNorm2d, NaiveSyncBatchNorm2d +from maskrcnn_benchmark.layers import Conv2d, DFConv2d, SELayer +from maskrcnn_benchmark.modeling.make_layers import group_norm +from maskrcnn_benchmark.utils.registry import Registry + + +# ResNet stage specification +StageSpec = namedtuple( + "StageSpec", + [ + "index", # Index of the stage, eg 1, 2, ..,. 
5 + "block_count", # Number of residual blocks in the stage + "return_features", # True => return the last feature map from this stage + ], +) + +# ----------------------------------------------------------------------------- +# Standard ResNet models +# ----------------------------------------------------------------------------- +# ResNet-50 (including all stages) +ResNet50StagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, False), (4, 3, True)) +) +# ResNet-50 up to stage 4 (excludes stage 5) +ResNet50StagesTo4 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, True)) +) +# ResNet-101 (including all stages) +ResNet101StagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, False), (4, 3, True)) +) +# ResNet-101 up to stage 4 (excludes stage 5) +ResNet101StagesTo4 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, True)) +) +# ResNet-50-FPN (including all stages) +ResNet50FPNStagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 6, True), (4, 3, True)) +) +# ResNet-101-FPN (including all stages) +ResNet101FPNStagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, True), (2, 4, True), (3, 23, True), (4, 3, True)) +) +# ResNet-152-FPN (including all stages) +ResNet152FPNStagesTo5 = tuple( + StageSpec(index=i, block_count=c, return_features=r) + for (i, c, r) in ((1, 3, True), (2, 8, True), (3, 36, True), (4, 3, True)) +) + +class ResNet(nn.Module): + def __init__(self, cfg): + super(ResNet, self).__init__() + + # If we want to use the cfg in forward(), then we should make a copy + # of it and store it for later use: + # self.cfg = cfg.clone() + + # Translate string names to implementations + norm_level = None + stem_module = _STEM_MODULES[cfg.MODEL.RESNETS.STEM_FUNC] + stage_specs = _STAGE_SPECS[cfg.MODEL.BACKBONE.CONV_BODY] + transformation_module = _TRANSFORMATION_MODULES[cfg.MODEL.RESNETS.TRANS_FUNC] + + if cfg.MODEL.BACKBONE.USE_BN: + stem_module = StemWithBatchNorm + transformation_module = BottleneckWithBatchNorm + norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL + elif cfg.MODEL.BACKBONE.USE_NSYNCBN: + stem_module = StemWithNaiveSyncBatchNorm + transformation_module = BottleneckWithNaiveSyncBatchNorm + norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL + elif cfg.MODEL.BACKBONE.USE_SYNCBN: + stem_module = StemWithSyncBatchNorm + transformation_module = BottleneckWithSyncBatchNorm + norm_level = cfg.MODEL.BACKBONE.NORM_LEVEL + + # Construct the stem module + self.stem = stem_module(cfg) + + # Constuct the specified ResNet stages + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + stage2_bottleneck_channels = num_groups * width_per_group + stage2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + with_se = cfg.MODEL.RESNETS.WITH_SE + + self.stages = [] + self.out_channels = [] + self.return_features = {} + for stage_spec in stage_specs: + name = "layer" + str(stage_spec.index) + stage2_relative_factor = 2 ** (stage_spec.index - 1) + bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor + out_channels = stage2_out_channels * stage2_relative_factor + stage_with_dcn = 
cfg.MODEL.RESNETS.STAGE_WITH_DCN[stage_spec.index - 1] + if cfg.MODEL.RESNETS.USE_AVG_DOWN: + avg_down_stride = 1 if stage_spec.index==1 else 2 + else: + avg_down_stride = 0 + module = _make_stage( + transformation_module, + in_channels, + bottleneck_channels, + out_channels, + stage_spec.block_count, + num_groups, + cfg.MODEL.RESNETS.STRIDE_IN_1X1, + first_stride=int(stage_spec.index > 1) + 1, + dcn_config={ + "stage_with_dcn": stage_with_dcn, + "with_modulated_dcn": cfg.MODEL.RESNETS.WITH_MODULATED_DCN, + "deformable_groups": cfg.MODEL.RESNETS.DEFORMABLE_GROUPS, + }, + norm_level=norm_level, + with_se=with_se, + avg_down_stride=avg_down_stride + ) + in_channels = out_channels + self.add_module(name, module) + self.stages.append(name) + self.out_channels.append(out_channels) + self.return_features[name] = stage_spec.return_features + + # Optionally freeze (requires_grad=False) parts of the backbone + self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT) + + def _freeze_backbone(self, freeze_at): + if freeze_at < 0: + return + for stage_index in range(freeze_at): + if stage_index == 0: + m = self.stem # stage 0 is the stem + else: + m = getattr(self, "layer" + str(stage_index)) + for p in m.parameters(): + p.requires_grad = False + + def forward(self, x): + outputs = [] + x = self.stem(x) + for stage_name in self.stages: + x = getattr(self, stage_name)(x) + if self.return_features[stage_name]: + outputs.append(x) + return outputs + + +class ResNetHead(nn.Module): + def __init__( + self, + block_module, + stages, + num_groups=1, + width_per_group=64, + stride_in_1x1=True, + stride_init=None, + res2_out_channels=256, + dilation=1, + dcn_config=None + ): + super(ResNetHead, self).__init__() + + stage2_relative_factor = 2 ** (stages[0].index - 1) + stage2_bottleneck_channels = num_groups * width_per_group + out_channels = res2_out_channels * stage2_relative_factor + in_channels = out_channels // 2 + bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor + + block_module = _TRANSFORMATION_MODULES[block_module] + + self.stages = [] + stride = stride_init + for stage in stages: + name = "layer" + str(stage.index) + if not stride: + stride = int(stage.index > 1) + 1 + module = _make_stage( + block_module, + in_channels, + bottleneck_channels, + out_channels, + stage.block_count, + num_groups, + stride_in_1x1, + first_stride=stride, + dilation=dilation, + dcn_config=dcn_config + ) + stride = None + self.add_module(name, module) + self.stages.append(name) + self.out_channels = out_channels + + def forward(self, x): + for stage in self.stages: + x = getattr(self, stage)(x) + return x + + +def _make_stage( + transformation_module, + in_channels, + bottleneck_channels, + out_channels, + block_count, + num_groups, + stride_in_1x1, + first_stride, + dilation=1, + dcn_config=None, + norm_level=None, + **kwargs +): + blocks = [] + stride = first_stride + for li in range(block_count): + if norm_level is not None: + layer_module = BottleneckWithFixedBatchNorm + if norm_level >= 1 and li == 0: + layer_module = transformation_module + if norm_level >= 2 and li == block_count - 1: + layer_module = transformation_module + if norm_level >= 3: + layer_module = transformation_module + else: + layer_module = transformation_module + + blocks.append( + layer_module( + in_channels, + bottleneck_channels, + out_channels, + num_groups, + stride_in_1x1, + stride, + dilation=dilation, + dcn_config=dcn_config, + **kwargs + ) + ) + stride = 1 + in_channels = out_channels + return 
nn.Sequential(*blocks) + + +class Bottleneck(nn.Module): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups, + stride_in_1x1, + stride, + dilation, + norm_func, + dcn_config, + with_se=False, + avg_down_stride=0, + ): + super(Bottleneck, self).__init__() + + self.downsample = None + if in_channels != out_channels: + down_stride = stride if dilation == 1 else 1 + if avg_down_stride>0: + self.downsample = nn.Sequential( + nn.AvgPool2d( + kernel_size=avg_down_stride, + stride=avg_down_stride, + ceil_mode=True, + count_include_pad=False + ), + nn.Conv2d( + in_channels, out_channels, + kernel_size=1, stride=1, bias=False + ), + norm_func(out_channels), + ) + else: + self.downsample = nn.Sequential( + Conv2d( + in_channels, out_channels, + kernel_size=1, stride=down_stride, bias=False + ), + norm_func(out_channels), + ) + for modules in [self.downsample,]: + for l in modules.modules(): + if isinstance(l, Conv2d): + nn.init.kaiming_uniform_(l.weight, a=1) + + if dilation > 1: + stride = 1 # reset to be 1 + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + ) + self.bn1 = norm_func(bottleneck_channels) + # TODO: specify init for the above + with_dcn = dcn_config.get("stage_with_dcn", False) + if with_dcn: + deformable_groups = dcn_config.get("deformable_groups", 1) + with_modulated_dcn = dcn_config.get("with_modulated_dcn", False) + self.conv2 = DFConv2d( + bottleneck_channels, + bottleneck_channels, + with_modulated_dcn=with_modulated_dcn, + kernel_size=3, + stride=stride_3x3, + groups=num_groups, + dilation=dilation, + deformable_groups=deformable_groups, + bias=False + ) + else: + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=dilation, + bias=False, + groups=num_groups, + dilation=dilation + ) + nn.init.kaiming_uniform_(self.conv2.weight, a=1) + + self.bn2 = norm_func(bottleneck_channels) + + self.conv3 = Conv2d( + bottleneck_channels, out_channels, kernel_size=1, bias=False + ) + self.bn3 = norm_func(out_channels) + + self.se = SELayer(out_channels) if with_se and not with_dcn else None + + for l in [self.conv1, self.conv3,]: + nn.init.kaiming_uniform_(l.weight, a=1) + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = F.relu_(out) + + out = self.conv2(out) + out = self.bn2(out) + out = F.relu_(out) + + out0 = self.conv3(out) + out = self.bn3(out0) + + if self.se: + out = self.se(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = F.relu_(out) + + return out + + +class BaseStem(nn.Module): + def __init__(self, cfg, norm_func): + super(BaseStem, self).__init__() + + out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + self.stem_3x3 = cfg.MODEL.RESNETS.USE_STEM3X3 + + if self.stem_3x3: + self.conv1 = Conv2d( + 3, out_channels, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = norm_func(out_channels) + self.conv2 = Conv2d( + out_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn2 = norm_func(out_channels) + for l in [self.conv1, self.conv2]: + nn.init.kaiming_uniform_(l.weight, a=1) + else: + self.conv1 = Conv2d( + 3, out_channels, kernel_size=7, 
stride=2, padding=3, bias=False + ) + self.bn1 = norm_func(out_channels) + + for l in [self.conv1,]: + nn.init.kaiming_uniform_(l.weight, a=1) + + def forward(self, x): + if self.stem_3x3: + x = self.conv1(x) + x = self.bn1(x) + x = F.relu_(x) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu_(x) + else: + x = self.conv1(x) + x = self.bn1(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + +class BottleneckWithFixedBatchNorm(Bottleneck): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups=1, + stride_in_1x1=True, + stride=1, + dilation=1, + dcn_config=None, + **kwargs + ): + super(BottleneckWithFixedBatchNorm, self).__init__( + in_channels=in_channels, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + stride_in_1x1=stride_in_1x1, + stride=stride, + dilation=dilation, + norm_func=FrozenBatchNorm2d, + dcn_config=dcn_config, + **kwargs + ) + + +class StemWithFixedBatchNorm(BaseStem): + def __init__(self, cfg): + super(StemWithFixedBatchNorm, self).__init__( + cfg, norm_func=FrozenBatchNorm2d + ) + + +class BottleneckWithBatchNorm(Bottleneck): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups=1, + stride_in_1x1=True, + stride=1, + dilation=1, + dcn_config=None, + **kwargs + ): + super(BottleneckWithBatchNorm, self).__init__( + in_channels=in_channels, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + stride_in_1x1=stride_in_1x1, + stride=stride, + dilation=dilation, + norm_func=BatchNorm2d, + dcn_config=dcn_config, + **kwargs + ) + + +class StemWithBatchNorm(BaseStem): + def __init__(self, cfg): + super(StemWithBatchNorm, self).__init__( + cfg, norm_func=BatchNorm2d + ) + + +class BottleneckWithNaiveSyncBatchNorm(Bottleneck): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups=1, + stride_in_1x1=True, + stride=1, + dilation=1, + dcn_config=None, + **kwargs + ): + super(BottleneckWithNaiveSyncBatchNorm, self).__init__( + in_channels=in_channels, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + stride_in_1x1=stride_in_1x1, + stride=stride, + dilation=dilation, + norm_func=NaiveSyncBatchNorm2d, + dcn_config=dcn_config, + **kwargs + ) + + +class StemWithNaiveSyncBatchNorm(BaseStem): + def __init__(self, cfg): + super(StemWithNaiveSyncBatchNorm, self).__init__( + cfg, norm_func=NaiveSyncBatchNorm2d + ) + + +class BottleneckWithSyncBatchNorm(Bottleneck): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups=1, + stride_in_1x1=True, + stride=1, + dilation=1, + dcn_config=None, + **kwargs + ): + super(BottleneckWithSyncBatchNorm, self).__init__( + in_channels=in_channels, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + stride_in_1x1=stride_in_1x1, + stride=stride, + dilation=dilation, + norm_func=SyncBatchNorm, + dcn_config=dcn_config, + **kwargs + ) + + +class StemWithSyncBatchNorm(BaseStem): + def __init__(self, cfg): + super(StemWithSyncBatchNorm, self).__init__( + cfg, norm_func=SyncBatchNorm + ) + + +class BottleneckWithGN(Bottleneck): + def __init__( + self, + in_channels, + bottleneck_channels, + out_channels, + num_groups=1, + stride_in_1x1=True, + stride=1, + dilation=1, + dcn_config=None, + **kwargs + ): + super(BottleneckWithGN, self).__init__( + in_channels=in_channels, + 
bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + stride_in_1x1=stride_in_1x1, + stride=stride, + dilation=dilation, + norm_func=group_norm, + dcn_config=dcn_config, + **kwargs + ) + + +class StemWithGN(BaseStem): + def __init__(self, cfg): + super(StemWithGN, self).__init__(cfg, norm_func=group_norm) + + +_TRANSFORMATION_MODULES = Registry({ + "BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm, + "BottleneckWithGN": BottleneckWithGN, +}) + +_STEM_MODULES = Registry({ + "StemWithFixedBatchNorm": StemWithFixedBatchNorm, + "StemWithGN": StemWithGN, +}) + +_STAGE_SPECS = Registry({ + "R-50-C4": ResNet50StagesTo4, + "R-50-C5": ResNet50StagesTo5, + "R-50-RETINANET": ResNet50StagesTo5, + "R-101-C4": ResNet101StagesTo4, + "R-101-C5": ResNet101StagesTo5, + "R-101-RETINANET": ResNet101StagesTo5, + "R-50-FPN": ResNet50FPNStagesTo5, + "R-50-FPN-RETINANET": ResNet50FPNStagesTo5, + "R-50-FPN-FCOS": ResNet50FPNStagesTo5, + "R-101-FPN": ResNet101FPNStagesTo5, + "R-101-FPN-RETINANET": ResNet101FPNStagesTo5, + "R-101-FPN-FCOS": ResNet101FPNStagesTo5, + "R-152-FPN": ResNet152FPNStagesTo5, +}) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/swint.py b/maskrcnn_benchmark/modeling/backbone/swint.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a162b6d28f71837a8812b3d3dfb9526451df74 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/swint.py @@ -0,0 +1,650 @@ +# -------------------------------------------------------- +# Swin Transformer +# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. 
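The `window_partition` / `window_reverse` helpers defined just above reshape a `(B, H, W, C)` feature map into non-overlapping `window_size x window_size` windows and back. A minimal round-trip sketch, with the same reshapes written inline and purely illustrative sizes (8x8 map, 96 channels, 4x4 windows), might look like:

```python
import torch

# Round-trip sketch of the window helpers above, with the reshapes written inline
# (illustrative sizes only: an 8x8 map, 96 channels, 4x4 windows).
B, H, W, C, ws = 2, 8, 8, 96, 4
x = torch.randn(B, H, W, C)

# window_partition: (B, H, W, C) -> (num_windows*B, ws, ws, C)
windows = (x.view(B, H // ws, ws, W // ws, ws, C)
            .permute(0, 1, 3, 2, 4, 5).contiguous()
            .view(-1, ws, ws, C))                      # here: (8, 4, 4, 96)

# window_reverse: (num_windows*B, ws, ws, C) -> (B, H, W, C)
restored = (windows.view(B, H // ws, W // ws, ws, ws, -1)
                   .permute(0, 1, 3, 2, 4, 5).contiguous()
                   .view(B, H, W, -1))

assert torch.equal(restored, x)   # partition followed by reverse is lossless
```

Because the round trip is exact, the Swin blocks can partition, attend within windows, and merge again without losing the spatial layout.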
+ It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. 
+ shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
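`PatchMerging` above concatenates each 2x2 neighborhood into a 4C-dimensional token and projects it down to 2C, halving the spatial resolution between stages. A minimal shape sketch, assuming the repo and its dependencies are importable and using arbitrary sizes (odd height to exercise the padding branch):

```python
import torch
# Assumes the repo (and its timm dependency) is installed and importable.
from maskrcnn_benchmark.modeling.backbone.swint import PatchMerging

B, H, W, C = 2, 7, 6, 96          # odd H triggers the internal padding
x = torch.randn(B, H * W, C)      # tokens are flattened to (B, H*W, C)

merge = PatchMerging(dim=C)
y = merge(x, H, W)

# output covers ceil(H/2) x ceil(W/2) positions with doubled channels
assert y.shape == (B, ((H + 1) // 2) * ((W + 1) // 2), 2 * C)
```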
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. 
Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + frozen_stages=-1, + use_checkpoint=False, + out_features=["stage2", "stage3", "stage4", "stage5"], + backbone_arch="SWINT-FPN-RETINANET"): + super(SwinTransformer, self).__init__() + + print("VISION BACKBONE USE GRADIENT CHECKPOINTING: ", use_checkpoint) + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + + self.out_features = out_features + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + self._out_feature_strides = {} + self._out_feature_channels = {} + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint and i_layer > self.frozen_stages - 1) + self.layers.append(layer) + + stage = f'stage{i_layer + 2}' + if stage in self.out_features: + self._out_feature_channels[stage] = embed_dim * 2 ** i_layer + self._out_feature_strides[stage] = 4 * 2 ** i_layer + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in range(self.num_layers): + stage = f'stage{i_layer + 2}' + if stage in self.out_features: + if i_layer == 0 and backbone_arch.endswith("RETINANET"): + layer = nn.Identity() + else: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def 
_freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + name = f'stage{i + 2}' + if name in self.out_features: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +def build_swint_backbone(cfg): + """ + Create a SwinT instance from config. + + Returns: + VoVNet: a :class:`VoVNet` instance. 
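The `forward` above returns one `(B, C, H, W)` map per stage listed in `out_features`, with channels doubling and spatial resolution halving from stage2 to stage5 (strides 4, 8, 16, 32). A minimal sketch with illustrative Swin-T-style hyper-parameters; these values are assumptions chosen for the example, not read from any config in this repo:

```python
import torch
# Assumes the repo (and its timm dependency) is installed and importable.
from maskrcnn_benchmark.modeling.backbone.swint import SwinTransformer

# Illustrative Swin-T-like settings; not taken from a real config file.
model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        window_size=7,
                        out_features=["stage2", "stage3", "stage4", "stage5"])
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))

for name, f in zip(["stage2", "stage3", "stage4", "stage5"], feats):
    print(name, tuple(f.shape))
# expected: (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
```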
+ """ + return SwinTransformer( + patch_size=4, + in_chans=3, + embed_dim=cfg.MODEL.SWINT.EMBED_DIM, + depths=cfg.MODEL.SWINT.DEPTHS, + num_heads=cfg.MODEL.SWINT.NUM_HEADS, + window_size=cfg.MODEL.SWINT.WINDOW_SIZE, + mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE, + norm_layer=nn.LayerNorm, + ape=cfg.MODEL.SWINT.APE, + patch_norm=True, + frozen_stages=cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT, + backbone_arch=cfg.MODEL.BACKBONE.CONV_BODY, + use_checkpoint=cfg.MODEL.BACKBONE.USE_CHECKPOINT, + out_features=cfg.MODEL.BACKBONE.OUT_FEATURES + ) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/swint_v2.py b/maskrcnn_benchmark/modeling/backbone/swint_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fe550c33f06880f67176cd03c918db06c25e6f --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/swint_v2.py @@ -0,0 +1,734 @@ +# -------------------------------------------------------- +# Swin Transformer +# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from einops import rearrange +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. 
Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + self.gamma = 1.0 + if layer_scale: + self.gamma = nn.Parameter( + 1e-4*torch.ones(dim), requires_grad=True + ) + + + def forward(self, x, mask_matrix): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(self.gamma*x) + x = x + self.drop_path(self.gamma*self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. 
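This v2 block adds an optional layer scale: with `layer_scale=True`, `self.gamma` starts at `1e-4 * ones(dim)`, so the attention and MLP branches contribute almost nothing at initialization and their weight is learned during training. A standalone sketch of that residual update (illustrative shapes, not the repo class itself):

```python
import torch
import torch.nn as nn

# Minimal sketch of the layer-scale residual used above (illustrative only).
dim = 96
gamma = nn.Parameter(1e-4 * torch.ones(dim), requires_grad=True)

shortcut = torch.randn(2, 49, dim)   # (B, tokens, C)
branch = torch.randn(2, 49, dim)     # stand-in for the W-MSA / MLP branch output

out = shortcut + gamma * branch      # per-channel scaling, broadcast over (B, tokens)
assert torch.allclose(out, shortcut, atol=1e-2)   # near-identity at initialization
```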
+ Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + layer_scale=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + layer_scale=layer_scale) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(patch_size=3, in_chans=dim, embed_dim=dim*2, + stride=2, padding=1, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
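The body of this `forward` (identical to the one in `swint.py` above, and shown again below) builds an attention mask that labels the regions created by the cyclic shift and assigns `-100` to every cross-region token pair, so the softmax effectively zeroes them out. A self-contained sketch with illustrative sizes (a 14x14 map, 7x7 windows, shift 3), inlining the window partition as a reshape:

```python
import torch

# Sketch of the SW-MSA mask construction (illustrative sizes only).
window_size, shift_size, Hp, Wp = 7, 3, 14, 14

img_mask = torch.zeros((1, Hp, Wp, 1))
cnt = 0
for h in (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)):
    for w in (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)):
        img_mask[:, h, w, :] = cnt   # label each shifted region with its own id
        cnt += 1

# inline equivalent of window_partition, flattened to (num_windows, ws*ws)
mask_windows = (img_mask.view(1, Hp // window_size, window_size, Wp // window_size, window_size, 1)
                        .permute(0, 1, 3, 2, 4, 5)
                        .reshape(-1, window_size * window_size))

# pairs of tokens from different regions get -100; same-region pairs get 0
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)   # (num_windows, 49, 49) -> (4, 49, 49)
```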
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +# class PatchEmbed(nn.Module): +# """ Image to Patch Embedding +# Args: +# patch_size (int): Patch token size. Default: 4. +# in_chans (int): Number of input image channels. Default: 3. +# embed_dim (int): Number of linear projection output channels. Default: 96. +# norm_layer (nn.Module, optional): Normalization layer. Default: None +# """ +# +# def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): +# super().__init__() +# patch_size = to_2tuple(patch_size) +# self.patch_size = patch_size +# +# self.in_chans = in_chans +# self.embed_dim = embed_dim +# +# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) +# if norm_layer is not None: +# self.norm = norm_layer(embed_dim) +# else: +# self.norm = None +# +# def forward(self, x): +# """Forward function.""" +# # padding +# _, _, H, W = x.size() +# if W % self.patch_size[1] != 0: +# x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) +# if H % self.patch_size[0] != 0: +# x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) +# +# x = self.proj(x) # B C Wh Ww +# if self.norm is not None: +# Wh, Ww = x.size(2), x.size(3) +# x = x.flatten(2).transpose(1, 2) +# x = self.norm(x) +# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) +# +# return x + + +class ConvEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__( + self, + patch_size=7, + in_chans=3, + embed_dim=64, + stride=4, + padding=2, + norm_layer=None + ): + super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding + ) + self.norm = norm_layer(embed_dim) if norm_layer else None + + def forward(self, x, H=None, W=None): + restore_hw = False + if H is None and W is None and len(x.size()) == 4: + _, _, H, W = x.size() + if W % self.patch_size != 0: + x = F.pad(x, (0, self.patch_size - W % self.patch_size)) + if H % self.patch_size != 0: + x = F.pad(x, (0, 0, 0, self.patch_size - H % self.patch_size)) + restore_hw = True + + if len(x.size()) == 3: + x = rearrange( + x, 'b (h w) c -> b c h w', + h=H, + w=W + ) + x = self.proj(x) # B C Wh Ww + B, C, Wh, Ww = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if self.norm: + x 
= self.norm(x) + + if restore_hw: + x = rearrange( + x, 'b (h w) c -> b c h w', + h=Wh, + w=Ww + ) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=7, + patch_padding=2, + patch_stride=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + frozen_stages=-1, + use_checkpoint=False, + layer_scale=False, + out_features=["stage2", "stage3", "stage4", "stage5"], + out_norm=True, + backbone_arch="SWINT-FPN-RETINANET"): + super(SwinTransformer, self).__init__() + + print("VISION BACKBONE USE GRADIENT CHECKPOINTING: ", use_checkpoint) + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + + self.out_features = out_features + self.out_norm = out_norm + + # split image into non-overlapping patches + # self.patch_embed = PatchEmbed( + # patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + # norm_layer=norm_layer if self.patch_norm else None) + self.patch_embed = ConvEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, padding=patch_padding, + norm_layer=norm_layer if self.patch_norm else None + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, 
drop_path_rate, sum(depths))] # stochastic depth decay rule + + self._out_feature_strides = {} + self._out_feature_channels = {} + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=ConvEmbed if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint and i_layer > self.frozen_stages - 1, + layer_scale=layer_scale) + self.layers.append(layer) + + stage = f'stage{i_layer + 2}' + if stage in self.out_features: + self._out_feature_channels[stage] = embed_dim * 2 ** i_layer + self._out_feature_strides[stage] = 4 * 2 ** i_layer + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + if self.out_norm: + for i_layer in range(self.num_layers): + stage = f'stage{i_layer + 2}' + if stage in self.out_features: + if i_layer == 0 and backbone_arch.endswith("RETINANET"): + layer = nn.Identity() + else: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + name = f'stage{i + 2}' + if name in self.out_features: + if self.out_norm: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +def build_swint_backbone(cfg): + """ + Create a SwinT instance from config. + + Returns: + VoVNet: a :class:`VoVNet` instance. 
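+
+        Note: the ``VoVNet`` wording above is inherited boilerplate; this helper actually
+        builds and returns a :class:`SwinTransformer` configured from ``cfg.MODEL.SWINT``
+        and ``cfg.MODEL.BACKBONE``.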
+ """ + return SwinTransformer( + patch_size=7, + patch_padding=2, + patch_stride=4, + in_chans=3, + embed_dim=cfg.MODEL.SWINT.EMBED_DIM, + depths=cfg.MODEL.SWINT.DEPTHS, + num_heads=cfg.MODEL.SWINT.NUM_HEADS, + window_size=cfg.MODEL.SWINT.WINDOW_SIZE, + mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE, + norm_layer=nn.LayerNorm, + ape=cfg.MODEL.SWINT.APE, + patch_norm=True, + frozen_stages=cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT, + backbone_arch=cfg.MODEL.BACKBONE.CONV_BODY, + use_checkpoint=cfg.MODEL.BACKBONE.USE_CHECKPOINT, + layer_scale=cfg.MODEL.SWINT.LAYER_SCALE, + out_features=cfg.MODEL.BACKBONE.OUT_FEATURES, + out_norm=cfg.MODEL.SWINT.OUT_NORM, + ) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/swint_v2_vl.py b/maskrcnn_benchmark/modeling/backbone/swint_v2_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..008fda21f1d5c82661146e30a3ff3496579035e3 --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/swint_v2_vl.py @@ -0,0 +1,861 @@ +# -------------------------------------------------------- +# Swin Transformer +# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from einops import rearrange +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + ntext=None, dim_text=None): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + if ntext is not None: + self.qkv_text = nn.Linear(dim_text, dim * 3, bias=qkv_bias) + self.proj_text = nn.Linear(dim, dim_text) + + self.i2t_relative_position_bias = nn.Parameter( + torch.zeros(2, num_heads, ntext)) # (2, nH, ntext) + self.t2t_relative_position_bias = nn.Parameter( + torch.zeros(num_heads, ntext, ntext)) # (nH, ntext, ntext) + trunc_normal_(self.i2t_relative_position_bias, std=.02) + trunc_normal_(self.t2t_relative_position_bias, std=.02) + + def forward(self, x, mask=None, x_text=None, mask_text=None): + """ Forward function. 
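+        Computes window self-attention over the image tokens and, when ``x_text`` is
+        given, image-to-text and text-to-image cross attention as well.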
+ Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + x_text: input text features with shape of (B_text, N_text, C_text) + mask_text: (0/-inf) mask with shape of (B_text, N_text) or None; TODO: support casual mask + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + if x_text is not None: + B_text, N_text, C_text = x_text.shape + nW = B_ // B_text # number of windows + assert B_text * nW == B_, "B_ is not a multiplier of B_text in window attention" + # notice that after qkv_text, the hidden dimension is C instead of C_text + qkv_text = self.qkv_text(x_text).reshape(B_text, N_text, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, + 1, 4) + q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[ + 2] # make torchscript happy (cannot use tensor as tuple) + + # image to text attention + attn_i2t = (q @ torch.repeat_interleave(k_text, nW, dim=0).transpose(-2, -1)) # B_, nH, N, N_text + # add image to text bias and text_mask + if mask_text is not None: + mask_and_i2t_bias = mask_text.view(B_text, 1, 1, N_text) + self.i2t_relative_position_bias[:1].expand( + B_text, -1, -1).unsqueeze(-2) # B_text, nH, 1, N_text + else: + mask_and_i2t_bias = self.i2t_relative_position_bias[:1].expand(B_text, -1, -1).unsqueeze( + -2) # B_text, nH, 1, N_text + attn_i2t = attn_i2t + torch.repeat_interleave(mask_and_i2t_bias, nW, dim=0) + + attn = torch.cat((attn, attn_i2t), dim=-1) # B_, nH, N, N+N_text + + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + if x_text is None: + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + else: + x = ( + attn @ torch.cat((v, torch.repeat_interleave(v_text, nW, dim=0)), dim=-2) + ).transpose(1, 2).reshape(B_, N, C) + + # compute attn_t2i + q_text = q_text * self.scale + + kv = qkv[1:].reshape(2, B_text, nW, self.num_heads, N, C // self.num_heads).transpose(2, 3) + k, v = kv[0].reshape(B_text, self.num_heads, nW * N, -1), kv[1].reshape(B_text, self.num_heads, nW * N, -1) + attn_t2i = (q_text @ k.transpose(-2, -1)) + mask_t2i = self.i2t_relative_position_bias[1:].expand(B_text, -1, -1).unsqueeze(-1) # B_text, nH, N_text, 1 + attn_t2i = attn_t2i + mask_t2i + + attn_t2t = (q_text @ k_text.transpose(-2, -1)) + # add relative positional bias + attn_t2t = attn_t2t + self.t2t_relative_position_bias.unsqueeze(0) + if mask_text is not None: + attn_t2t = attn_t2t + mask_text.view(B_text, 1, 1, N_text) + + attn_t = torch.cat((attn_t2i, attn_t2t), dim=-1) # B_text, nH, N_text, N+N_text + attn_t = self.softmax(attn_t) + attn_t = self.attn_drop(attn_t) + + x_text = ( + attn_t @ torch.cat((v, v_text), dim=-2) + ).transpose(1, 2).reshape(B_text, N_text, C) + + x_text = self.proj_text(x_text) + x_text = self.proj_drop(x_text) + + x = 
self.proj(x) + x = self.proj_drop(x) + return x, x_text + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False, ntext=None, dim_text=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + ntext=ntext, dim_text=dim_text + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + self.gamma = 1.0 + if layer_scale: + self.gamma = nn.Parameter( + 1e-4*torch.ones(dim), requires_grad=True + ) + + if dim_text is not None: + self.norm1_text = norm_layer(dim_text) + self.norm2_text = norm_layer(dim_text) + mlp_hidden_dim_text = int(dim_text * mlp_ratio) + self.mlp_text = Mlp(in_features=dim_text, hidden_features=mlp_hidden_dim_text, act_layer=act_layer, + drop=drop) + self.gamma_text = 1.0 + if layer_scale: + self.gamma_text = nn.Parameter( + 1e-4*torch.ones(dim_text), requires_grad=True + ) + + def forward(self, x, mask_matrix, x_text, mask_text): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + x_text: Input text feature, tensor size (B, L_text, C_text). L_text: Number of text tokens. + mask_text: text mask (vector right now). 
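+        Returns:
+            Tuple ``(x, x_text)``: the updated image feature of shape (B, H*W, C) and the
+            updated text feature (``x_text`` is returned unchanged as ``None`` when no
+            text is passed in).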
+ """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + if x_text is not None: + B, L_text, C_text = x_text.shape + shortcut_text = x_text + x_text = self.norm1_text(x_text) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, x_text = self.attn(x_windows, mask=attn_mask, x_text=x_text, + mask_text=mask_text) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(self.gamma*x) + x = x + self.drop_path(self.gamma*self.mlp(self.norm2(x))) + + if x_text is not None: + x_text = shortcut_text + self.drop_path(self.gamma_text*x_text) + x_text = x_text + self.drop_path(self.gamma_text*self.mlp_text(self.norm2_text(x_text))) + + return x, x_text + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 
+ drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + layer_scale=False, + ntext=None, + dim_text=None): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + layer_scale=layer_scale, + ntext=ntext, + dim_text=dim_text) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(patch_size=3, in_chans=dim, embed_dim=dim*2, + stride=2, padding=1, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W, x_text=None, mask_text=None): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + x_text: input text features with shape of (B_text, N_text, C_text) + mask_text: (0/-inf) mask with shape of (B_text, N_text) or None; + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x, x_text = checkpoint.checkpoint(blk, x, attn_mask, x_text, mask_text) + else: + x, x_text = blk(x, attn_mask, x_text, mask_text) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww, x_text + else: + return x, H, W, x, H, W, x_text + + +# class PatchEmbed(nn.Module): +# """ Image to Patch Embedding +# Args: +# patch_size (int): Patch token size. Default: 4. +# in_chans (int): Number of input image channels. Default: 3. +# embed_dim (int): Number of linear projection output channels. Default: 96. 
+# norm_layer (nn.Module, optional): Normalization layer. Default: None +# """ +# +# def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): +# super().__init__() +# patch_size = to_2tuple(patch_size) +# self.patch_size = patch_size +# +# self.in_chans = in_chans +# self.embed_dim = embed_dim +# +# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) +# if norm_layer is not None: +# self.norm = norm_layer(embed_dim) +# else: +# self.norm = None +# +# def forward(self, x): +# """Forward function.""" +# # padding +# _, _, H, W = x.size() +# if W % self.patch_size[1] != 0: +# x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) +# if H % self.patch_size[0] != 0: +# x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) +# +# x = self.proj(x) # B C Wh Ww +# if self.norm is not None: +# Wh, Ww = x.size(2), x.size(3) +# x = x.flatten(2).transpose(1, 2) +# x = self.norm(x) +# x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) +# +# return x + + +class ConvEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__( + self, + patch_size=7, + in_chans=3, + embed_dim=64, + stride=4, + padding=2, + norm_layer=None + ): + super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding + ) + self.norm = norm_layer(embed_dim) if norm_layer else None + + def forward(self, x, H=None, W=None): + restore_hw = False + if H is None and W is None and len(x.size()) == 4: + _, _, H, W = x.size() + if W % self.patch_size != 0: + x = F.pad(x, (0, self.patch_size - W % self.patch_size)) + if H % self.patch_size != 0: + x = F.pad(x, (0, 0, 0, self.patch_size - H % self.patch_size)) + restore_hw = True + + if len(x.size()) == 3: + x = rearrange( + x, 'b (h w) c -> b c h w', + h=H, + w=W + ) + x = self.proj(x) # B C Wh Ww + B, C, Wh, Ww = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if self.norm: + x = self.norm(x) + + if restore_hw: + x = rearrange( + x, 'b (h w) c -> b c h w', + h=Wh, + w=Ww + ) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. 
+ out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=7, + patch_padding=2, + patch_stride=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + frozen_stages=-1, + use_checkpoint=False, + layer_scale=False, + out_features=["stage2", "stage3", "stage4", "stage5"], + out_norm=True, + backbone_arch="SWINT-FPN-RETINANET", + max_query_len=None, + lang_dim=None): + super(SwinTransformer, self).__init__() + + print("VISION BACKBONE USE GRADIENT CHECKPOINTING: ", use_checkpoint) + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + + self.out_features = out_features + self.out_norm = out_norm + + # split image into non-overlapping patches + # self.patch_embed = PatchEmbed( + # patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + # norm_layer=norm_layer if self.patch_norm else None) + self.patch_embed = ConvEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, padding=patch_padding, + norm_layer=norm_layer if self.patch_norm else None + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + self._out_feature_strides = {} + self._out_feature_channels = {} + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + if i_layer < self.num_layers - 1: + ntext, dim_text = None, None + else: + ntext, dim_text = max_query_len, lang_dim + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=ConvEmbed if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint and i_layer > self.frozen_stages - 1, + layer_scale=layer_scale, + ntext=ntext, + dim_text=dim_text + ) + self.layers.append(layer) + + stage = f'stage{i_layer + 2}' + if stage in self.out_features: + self._out_feature_channels[stage] = embed_dim * 2 ** i_layer + self._out_feature_strides[stage] = 4 * 2 ** i_layer + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + if self.out_norm: + for i_layer in range(self.num_layers): + stage = f'stage{i_layer + 2}' + if stage in self.out_features: + if i_layer == 0 and 
backbone_arch.endswith("RETINANET"): + layer = nn.Identity() + else: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def forward(self, inputs): + """Forward function.""" + x = inputs["img"] + language_dict_features = inputs["lang"] + + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + x_text = language_dict_features['hidden'] + if "masks" in language_dict_features: + mask_text = 1.0 - language_dict_features["masks"] # (B, N_text) 0 means not to be masked out + mask_text.masked_fill_(mask_text.bool(), -float('inf')) + else: + mask_text = None + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + if i < self.num_layers - 1: + x_out, H, W, x, Wh, Ww, _ = layer(x, Wh, Ww, x_text=None, mask_text=None) + else: + x_out, H, W, x, Wh, Ww, x_text = layer(x, Wh, Ww, x_text=x_text, mask_text=mask_text) + name = f'stage{i + 2}' + if name in self.out_features: + if self.out_norm: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + # the backbone only update the "hidden" field, currently + language_dict_features['hidden'] = x_text + + return outs, language_dict_features + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +def build_swint_backbone(cfg): + """ + Create a SwinT instance from config. + + Returns: + VoVNet: a :class:`VoVNet` instance. 
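+
+        Note: the ``VoVNet`` mention above is leftover boilerplate; the function returns a
+        :class:`SwinTransformer` whose last stage fuses language features, with
+        ``max_query_len`` and ``lang_dim`` taken from ``cfg.MODEL.LANGUAGE_BACKBONE``.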
+ """ + return SwinTransformer( + patch_size=7, + patch_padding=2, + patch_stride=4, + in_chans=3, + embed_dim=cfg.MODEL.SWINT.EMBED_DIM, + depths=cfg.MODEL.SWINT.DEPTHS, + num_heads=cfg.MODEL.SWINT.NUM_HEADS, + window_size=cfg.MODEL.SWINT.WINDOW_SIZE, + mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE, + norm_layer=nn.LayerNorm, + ape=cfg.MODEL.SWINT.APE, + patch_norm=True, + frozen_stages=cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT, + backbone_arch=cfg.MODEL.BACKBONE.CONV_BODY, + use_checkpoint=cfg.MODEL.BACKBONE.USE_CHECKPOINT, + layer_scale=cfg.MODEL.SWINT.LAYER_SCALE, + out_features=cfg.MODEL.BACKBONE.OUT_FEATURES, + out_norm=cfg.MODEL.SWINT.OUT_NORM, + max_query_len=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, + lang_dim=cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM + ) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/backbone/swint_vl.py b/maskrcnn_benchmark/modeling/backbone/swint_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..97ed5705f727c26f0a5bbb21e95050d39a5348da --- /dev/null +++ b/maskrcnn_benchmark/modeling/backbone/swint_vl.py @@ -0,0 +1,774 @@ +# -------------------------------------------------------- +# Swin Transformer +# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., + ntext=None, dim_text=None): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + if ntext is not None: + self.qkv_text = nn.Linear(dim_text, dim * 3, bias=qkv_bias) + self.proj_text = nn.Linear(dim, dim_text) + + self.i2t_relative_position_bias = nn.Parameter( + torch.zeros(2, num_heads, ntext)) # (2, nH, ntext) + self.t2t_relative_position_bias = nn.Parameter( + torch.zeros(num_heads, ntext, ntext)) # (nH, ntext, ntext) + trunc_normal_(self.i2t_relative_position_bias, std=.02) + trunc_normal_(self.t2t_relative_position_bias, std=.02) + + def forward(self, x, mask=None, x_text=None, mask_text=None): + """ Forward function. 
+ Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + x_text: input text features with shape of (B_text, N_text, C_text) + mask_text: (0/-inf) mask with shape of (B_text, N_text) or None; TODO: support casual mask + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + if x_text is not None: + B_text, N_text, C_text = x_text.shape + nW = B_ // B_text # number of windows + assert B_text * nW == B_, "B_ is not a multiplier of B_text in window attention" + # notice that after qkv_text, the hidden dimension is C instead of C_text + qkv_text = self.qkv_text(x_text).reshape(B_text, N_text, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, + 1, 4) + q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[ + 2] # make torchscript happy (cannot use tensor as tuple) + + # image to text attention + attn_i2t = (q @ torch.repeat_interleave(k_text, nW, dim=0).transpose(-2, -1)) # B_, nH, N, N_text + # add image to text bias and text_mask + if mask_text is not None: + mask_and_i2t_bias = mask_text.view(B_text, 1, 1, N_text) + self.i2t_relative_position_bias[:1].expand( + B_text, -1, -1).unsqueeze(-2) # B_text, nH, 1, N_text + else: + mask_and_i2t_bias = self.i2t_relative_position_bias[:1].expand(B_text, -1, -1).unsqueeze( + -2) # B_text, nH, 1, N_text + attn_i2t = attn_i2t + torch.repeat_interleave(mask_and_i2t_bias, nW, dim=0) + + attn = torch.cat((attn, attn_i2t), dim=-1) # B_, nH, N, N+N_text + + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + if x_text is None: + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + else: + x = ( + attn @ torch.cat((v, torch.repeat_interleave(v_text, nW, dim=0)), dim=-2) + ).transpose(1, 2).reshape(B_, N, C) + + # compute attn_t2i + q_text = q_text * self.scale + + kv = qkv[1:].reshape(2, B_text, nW, self.num_heads, N, C // self.num_heads).transpose(2, 3) + k, v = kv[0].reshape(B_text, self.num_heads, nW * N, -1), kv[1].reshape(B_text, self.num_heads, nW * N, -1) + attn_t2i = (q_text @ k.transpose(-2, -1)) + mask_t2i = self.i2t_relative_position_bias[1:].expand(B_text, -1, -1).unsqueeze(-1) # B_text, nH, N_text, 1 + attn_t2i = attn_t2i + mask_t2i + + attn_t2t = (q_text @ k_text.transpose(-2, -1)) + # add relative positional bias + attn_t2t = attn_t2t + self.t2t_relative_position_bias.unsqueeze(0) + if mask_text is not None: + attn_t2t = attn_t2t + mask_text.view(B_text, 1, 1, N_text) + + attn_t = torch.cat((attn_t2i, attn_t2t), dim=-1) # B_text, nH, N_text, N+N_text + attn_t = self.softmax(attn_t) + attn_t = self.attn_drop(attn_t) + + x_text = ( + attn_t @ torch.cat((v, v_text), dim=-2) + ).transpose(1, 2).reshape(B_text, N_text, C) + + x_text = self.proj_text(x_text) + x_text = self.proj_drop(x_text) + + x = 
self.proj(x) + x = self.proj_drop(x) + return x, x_text + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, ntext=None, dim_text=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + ntext=ntext, dim_text=dim_text + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + if dim_text is not None: + self.norm1_text = norm_layer(dim_text) + self.norm2_text = norm_layer(dim_text) + mlp_hidden_dim_text = int(dim_text * mlp_ratio) + self.mlp_text = Mlp(in_features=dim_text, hidden_features=mlp_hidden_dim_text, act_layer=act_layer, + drop=drop) + + def forward(self, x, mask_matrix, x_text, mask_text): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + x_text: Input text feature, tensor size (B, L_text, C_text). L_text: Number of text tokens. + mask_text: text mask (vector right now). 
+ """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + if x_text is not None: + B, L_text, C_text = x_text.shape + shortcut_text = x_text + x_text = self.norm1_text(x_text) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, x_text = self.attn(x_windows, mask=attn_mask, x_text=x_text, + mask_text=mask_text) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + if x_text is not None: + x_text = shortcut_text + self.drop_path(x_text) + x_text = x_text + self.drop_path(self.mlp_text(self.norm2_text(x_text))) + + return x, x_text + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. 
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ntext=None, + dim_text=None): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ntext=ntext, + dim_text=dim_text) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W, x_text=None, mask_text=None): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + x_text: input text features with shape of (B_text, N_text, C_text) + mask_text: (0/-inf) mask with shape of (B_text, N_text) or None; + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x, x_text = checkpoint.checkpoint(blk, x, attn_mask, x_text, mask_text) + else: + x, x_text = blk(x, attn_mask, x_text, mask_text) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww, x_text + else: + return x, H, W, x, H, W, x_text + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
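+        max_query_len (int): Number of language tokens fed to the text branch that is
+            fused into the last stage. Default: None.
+        lang_dim (int): Channel dimension of the language features. Default: None.
+
+        Note (inferred from ``forward``): the backbone consumes a dict with keys ``"img"``
+        and ``"lang"`` and returns ``(outs, language_dict_features)``; only the ``"hidden"``
+        field of the language dict is updated.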
+ """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + frozen_stages=-1, + use_checkpoint=False, + out_features=["stage2", "stage3", "stage4", "stage5"], + backbone_arch="SWINT-FPN-RETINANET", + max_query_len=None, + lang_dim=None): + super(SwinTransformer, self).__init__() + + print("VISION BACKBONE USE GRADIENT CHECKPOINTING: ", use_checkpoint) + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + + self.out_features = out_features + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + self._out_feature_strides = {} + self._out_feature_channels = {} + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + if i_layer < self.num_layers - 1: + ntext, dim_text = None, None + else: + ntext, dim_text = max_query_len, lang_dim + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint and i_layer > self.frozen_stages - 1, + ntext=ntext, + dim_text=dim_text + ) + self.layers.append(layer) + + stage = f'stage{i_layer + 2}' + if stage in self.out_features: + self._out_feature_channels[stage] = embed_dim * 2 ** i_layer + self._out_feature_strides[stage] = 4 * 2 ** i_layer + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in range(self.num_layers): + stage = f'stage{i_layer + 2}' + if stage in self.out_features: + if i_layer == 0 and backbone_arch.endswith("RETINANET"): + layer = nn.Identity() + else: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in 
m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def forward(self, inputs): + """Forward function.""" + x = inputs["img"] + language_dict_features = inputs["lang"] + + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + x_text = language_dict_features['hidden'] + if "masks" in language_dict_features: + mask_text = 1.0 - language_dict_features["masks"] # (B, N_text) 0 means not to be masked out + mask_text.masked_fill_(mask_text.bool(), -float('inf')) + else: + mask_text = None + + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + if i < self.num_layers - 1: + x_out, H, W, x, Wh, Ww, _ = layer(x, Wh, Ww, x_text=None, mask_text=None) + else: + x_out, H, W, x, Wh, Ww, x_text = layer(x, Wh, Ww, x_text=x_text, mask_text=mask_text) + name = f'stage{i + 2}' + if name in self.out_features: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + # the backbone only update the "hidden" field, currently + language_dict_features['hidden'] = x_text + + return outs, language_dict_features + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +def build_swint_backbone(cfg): + """ + Create a SwinT instance from config. + + Returns: + VoVNet: a :class:`VoVNet` instance. + """ + return SwinTransformer( + patch_size=4, + in_chans=3, + embed_dim=cfg.MODEL.SWINT.EMBED_DIM, + depths=cfg.MODEL.SWINT.DEPTHS, + num_heads=cfg.MODEL.SWINT.NUM_HEADS, + window_size=cfg.MODEL.SWINT.WINDOW_SIZE, + mlp_ratio=cfg.MODEL.SWINT.MLP_RATIO, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=cfg.MODEL.SWINT.DROP_PATH_RATE, + norm_layer=nn.LayerNorm, + ape=cfg.MODEL.SWINT.APE, + patch_norm=True, + frozen_stages=cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT, + backbone_arch=cfg.MODEL.BACKBONE.CONV_BODY, + use_checkpoint=cfg.MODEL.BACKBONE.USE_CHECKPOINT, + out_features=cfg.MODEL.BACKBONE.OUT_FEATURES, + max_query_len=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, + lang_dim=cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM + ) diff --git a/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py b/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..388c7ea8720a77bdc93718754798fcdeb43f6383 --- /dev/null +++ b/maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py @@ -0,0 +1,68 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
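# Editorial note (not part of the original file): a minimal usage sketch of the
# sampler defined below, with illustrative numbers --
#
#     sampler = BalancedPositiveNegativeSampler(batch_size_per_image=256,
#                                               positive_fraction=0.5)
#     matched_idxs = [torch.tensor([2, 0, 0, -1, 1, 0])]   # one image
#     pos_masks, neg_masks = sampler(matched_idxs)
#     keep = pos_masks[0] | neg_masks[0]   # boolean mask over the 6 candidates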
+import torch + + +class BalancedPositiveNegativeSampler(object): + """ + This class samples batches, ensuring that they contain a fixed proportion of positives + """ + + def __init__(self, batch_size_per_image, positive_fraction): + """ + Arguments: + batch_size_per_image (int): number of elements to be selected per image + positive_fraction (float): percentace of positive elements per batch + """ + self.batch_size_per_image = batch_size_per_image + self.positive_fraction = positive_fraction + + def __call__(self, matched_idxs): + """ + Arguments: + matched idxs: list of tensors containing -1, 0 or positive values. + Each tensor corresponds to a specific image. + -1 values are ignored, 0 are considered as negatives and > 0 as + positives. + + Returns: + pos_idx (list[tensor]) + neg_idx (list[tensor]) + + Returns two lists of binary masks for each image. + The first list contains the positive elements that were selected, + and the second list the negative example. + """ + pos_idx = [] + neg_idx = [] + for matched_idxs_per_image in matched_idxs: + positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1) + negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1) + + num_pos = int(self.batch_size_per_image * self.positive_fraction) + # protect against not enough positive examples + num_pos = min(positive.numel(), num_pos) + num_neg = self.batch_size_per_image - num_pos + # protect against not enough negative examples + num_neg = min(negative.numel(), num_neg) + + # randomly select positive and negative examples + perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] + perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] + + pos_idx_per_image = positive[perm1] + neg_idx_per_image = negative[perm2] + + # create binary mask from indices + pos_idx_per_image_mask = torch.zeros_like( + matched_idxs_per_image, dtype=torch.bool + ) + neg_idx_per_image_mask = torch.zeros_like( + matched_idxs_per_image, dtype=torch.bool + ) + pos_idx_per_image_mask[pos_idx_per_image] = 1 + neg_idx_per_image_mask[neg_idx_per_image] = 1 + + pos_idx.append(pos_idx_per_image_mask) + neg_idx.append(neg_idx_per_image_mask) + + return pos_idx, neg_idx diff --git a/maskrcnn_benchmark/modeling/box_coder.py b/maskrcnn_benchmark/modeling/box_coder.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca39db1aa954be3482259797706ca12e56a77f1 --- /dev/null +++ b/maskrcnn_benchmark/modeling/box_coder.py @@ -0,0 +1,95 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import math + +import torch + + +class BoxCoder(object): + """ + This class encodes and decodes a set of bounding boxes into + the representation used for training the regressors. + """ + + def __init__(self, weights, bbox_xform_clip=math.log(1000. 
/ 16)): + """ + Arguments: + weights (4-element tuple) + bbox_xform_clip (float) + """ + self.weights = weights + self.bbox_xform_clip = bbox_xform_clip + + def encode(self, reference_boxes, proposals): + """ + Encode a set of proposals with respect to some + reference boxes + + Arguments: + reference_boxes (Tensor): reference boxes + proposals (Tensor): boxes to be encoded + """ + + TO_REMOVE = 1 # TODO remove + ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE + ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE + ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths + ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights + + gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE + gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE + gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths + gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights + + wx, wy, ww, wh = self.weights + targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights + targets_dw = ww * torch.log(gt_widths / ex_widths) + targets_dh = wh * torch.log(gt_heights / ex_heights) + + targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + return targets + + def decode(self, rel_codes, boxes): + """ + From a set of original boxes and encoded relative box offsets, + get the decoded boxes. + + Arguments: + rel_codes (Tensor): encoded boxes + boxes (Tensor): reference boxes. + """ + + boxes = boxes.to(rel_codes.dtype) + + TO_REMOVE = 1 # TODO remove + widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE + heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights + + wx, wy, ww, wh = self.weights + dx = rel_codes[:, 0::4] / wx + dy = rel_codes[:, 1::4] / wy + dw = rel_codes[:, 2::4] / ww + dh = rel_codes[:, 3::4] / wh + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=self.bbox_xform_clip) + dh = torch.clamp(dh, max=self.bbox_xform_clip) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = torch.exp(dw) * widths[:, None] + pred_h = torch.exp(dh) * heights[:, None] + + pred_boxes = torch.zeros_like(rel_codes) + # x1 + pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w + # y1 + pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h + # x2 (note: "- 1" is correct; don't be fooled by the asymmetry) + pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1 + # y2 (note: "- 1" is correct; don't be fooled by the asymmetry) + pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1 + + return pred_boxes diff --git a/maskrcnn_benchmark/modeling/detector/__init__.py b/maskrcnn_benchmark/modeling/detector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf35abda3c725249b875078a85d76ede727f9d19 --- /dev/null +++ b/maskrcnn_benchmark/modeling/detector/__init__.py @@ -0,0 +1,11 @@ +from .generalized_rcnn import GeneralizedRCNN +from .generalized_vl_rcnn import GeneralizedVLRCNN + +_DETECTION_META_ARCHITECTURES = {"GeneralizedRCNN": GeneralizedRCNN, + "GeneralizedVLRCNN": GeneralizedVLRCNN + } + + +def build_detection_model(cfg): + meta_arch = _DETECTION_META_ARCHITECTURES[cfg.MODEL.META_ARCHITECTURE] + return meta_arch(cfg) diff --git a/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py b/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..7307722d5e281e52a73bff6fb76706445c11e810 --- /dev/null +++ 
b/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Implements the Generalized R-CNN framework +""" + +import torch +from torch import nn + +from maskrcnn_benchmark.structures.image_list import to_image_list + +from ..backbone import build_backbone +from ..rpn import build_rpn +from ..roi_heads import build_roi_heads + +import timeit + +class GeneralizedRCNN(nn.Module): + """ + Main class for Generalized R-CNN. Currently supports boxes and masks. + It consists of three main parts: + - backbone + - rpn + - heads: takes the features + the proposals from the RPN and computes + detections / masks from it. + """ + + def __init__(self, cfg): + super(GeneralizedRCNN, self).__init__() + + self.backbone = build_backbone(cfg) + self.rpn = build_rpn(cfg) + self.roi_heads = build_roi_heads(cfg) + self.DEBUG = cfg.MODEL.DEBUG + self.ONNX = cfg.MODEL.ONNX + self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE + self.freeze_fpn = cfg.MODEL.FPN.FREEZE + self.freeze_rpn = cfg.MODEL.RPN.FREEZE + + if cfg.MODEL.LINEAR_PROB: + assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!" + if hasattr(self.backbone, 'fpn'): + assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!" + self.linear_prob = cfg.MODEL.LINEAR_PROB + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(GeneralizedRCNN, self).train(mode) + if self.freeze_backbone: + self.backbone.body.eval() + for p in self.backbone.body.parameters(): + p.requires_grad = False + if self.freeze_fpn: + self.backbone.fpn.eval() + for p in self.backbone.fpn.parameters(): + p.requires_grad = False + if self.freeze_rpn: + self.rpn.eval() + for p in self.rpn.parameters(): + p.requires_grad = False + if self.linear_prob: + if self.rpn is not None: + for key, value in self.rpn.named_parameters(): + if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key): + value.requires_grad = False + if self.roi_heads is not None: + for key, value in self.roi_heads.named_parameters(): + if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key): + value.requires_grad = False + + def forward(self, images, targets=None): + """ + Arguments: + images (list[Tensor] or ImageList): images to be processed + targets (list[BoxList]): ground-truth boxes present in the image (optional) + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). 
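        Example (editorial sketch, not part of the original docstring; assumes a
        config-built model and preprocessed inputs)::

            model = build_detection_model(cfg)      # sketch; cfg prepared elsewhere
            model.eval()
            with torch.no_grad():
                predictions = model(images)         # list[BoxList]
            model.train()
            loss_dict = model(images, targets)      # dict of loss tensors
            total_loss = sum(loss_dict.values())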
+ + """ + if self.training and targets is None: + raise ValueError("In training mode, targets should be passed") + + if self.DEBUG: debug_info = {} + if self.DEBUG: debug_info['input_size'] = images[0].size() + if self.DEBUG: tic = timeit.time.perf_counter() + + if self.ONNX: + features = self.backbone(images) + else: + images = to_image_list(images) + features = self.backbone(images.tensors) + + if self.DEBUG: debug_info['feat_time'] = timeit.time.perf_counter() - tic + if self.DEBUG: debug_info['feat_size'] = [feat.size() for feat in features] + if self.DEBUG: tic = timeit.time.perf_counter() + + proposals, proposal_losses = self.rpn(images, features, targets) + + if self.DEBUG: debug_info['rpn_time'] = timeit.time.perf_counter() - tic + if self.DEBUG: debug_info['#rpn'] = [prop for prop in proposals] + if self.DEBUG: tic = timeit.time.perf_counter() + + if self.roi_heads: + x, result, detector_losses = self.roi_heads(features, proposals, targets) + else: + # RPN-only models don't have roi_heads + x = features + result = proposals + detector_losses = {} + + if self.DEBUG: debug_info['rcnn_time'] = timeit.time.perf_counter() - tic + if self.DEBUG: debug_info['#rcnn'] = result + if self.DEBUG: return result, debug_info + + if self.training: + losses = {} + losses.update(detector_losses) + losses.update(proposal_losses) + return losses + + return result \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/detector/generalized_vl_rcnn.py b/maskrcnn_benchmark/modeling/detector/generalized_vl_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..01f64c6ffc6272a222777cb4deb6b2ee3d715b23 --- /dev/null +++ b/maskrcnn_benchmark/modeling/detector/generalized_vl_rcnn.py @@ -0,0 +1,466 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Implements the Generalized VL R-CNN framework +""" + +import torch +from torch import nn +import torch.nn.functional as F + +from maskrcnn_benchmark.structures.image_list import to_image_list +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist + +from ..backbone import build_backbone +from ..rpn import build_rpn +from ..roi_heads import build_roi_heads + +from ..language_backbone import build_language_backbone +from transformers import AutoTokenizer + +import random +import timeit +import pdb +from copy import deepcopy + +def random_word(input_ids, mask_token_id, vocabs, padding_token_id, greenlight_map): + """ + greenlight_map, batch_size x 256 (seq_len): + 0 means this location cannot be calculated in the MLM loss + -1 means this location cannot be masked!! 
+ 1 means this location can be masked and can be calculated in the MLM loss + """ + output_label = deepcopy(input_ids) + for j in range(input_ids.size(0)): + for i in range(input_ids.size(1)): + prob = random.random() + # mask token with probability + ratio = 0.15 + if greenlight_map is not None and greenlight_map[j,i] == -1: + output_label[j,i] = -100 + continue + + if (not input_ids[j,i] == padding_token_id) and prob < ratio: + prob /= ratio + + # 80% randomly change token to mask token + if prob < 0.8: + input_ids[j,i] = mask_token_id + + # 10% randomly change token to random token + elif prob < 0.9: + input_ids[j,i] = random.choice(vocabs) + + else: + # no masking token (will be ignored by loss function later) + output_label[j,i] = -100 + + if greenlight_map is not None and greenlight_map[j,i] != 1: + output_label[j,i] = -100 # If this location should not be masked + return input_ids, output_label + + +class GeneralizedVLRCNN(nn.Module): + """ + Main class for Generalized R-CNN. Currently supports boxes and masks. + It consists of three main parts: + - backbone + - rpn + - heads: takes the features + the proposals from the RPN and computes + detections / masks from it. + """ + + def __init__(self, cfg): + super(GeneralizedVLRCNN, self).__init__() + self.cfg = cfg + + # visual encoder + self.backbone = build_backbone(cfg) + + # language encoder + if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": + # self.tokenizer = build_tokenizer("clip") + from transformers import CLIPTokenizerFast + if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: + print("Reuse token 'ðŁĴij' (token_id = 49404) for mask token!") + self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", + from_slow=True, mask_token='ðŁĴij') + else: + self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", + from_slow=True) + else: + self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) + self.tokenizer_vocab = self.tokenizer.get_vocab() + self.tokenizer_vocab_ids = [item for key, item in self.tokenizer_vocab.items()] + + self.language_backbone = build_language_backbone(cfg) + + self.rpn = build_rpn(cfg) + self.roi_heads = build_roi_heads(cfg) + self.DEBUG = cfg.MODEL.DEBUG + + self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE + self.freeze_fpn = cfg.MODEL.FPN.FREEZE + self.freeze_rpn = cfg.MODEL.RPN.FREEZE + self.add_linear_layer = cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER + + self.force_boxes = cfg.MODEL.RPN.FORCE_BOXES + + if cfg.MODEL.LINEAR_PROB: + assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!" + if hasattr(self.backbone, 'fpn'): + assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!" 
+ self.linear_prob = cfg.MODEL.LINEAR_PROB + self.freeze_cls_logits = cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS + if cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + # disable cls_logits + if hasattr(self.rpn.head, 'cls_logits'): + for p in self.rpn.head.cls_logits.parameters(): + p.requires_grad = False + + self.freeze_language_backbone = self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE + if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE: + for p in self.language_backbone.parameters(): + p.requires_grad = False + + self.use_mlm_loss = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS + self.mlm_loss_for_only_positives = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_FOR_ONLY_POSITIVES + + if self.cfg.GLIPKNOW.KNOWLEDGE_FILE: + from maskrcnn_benchmark.data.datasets.tsv import load_from_yaml_file + self.class_name_to_knowledge = load_from_yaml_file(self.cfg.GLIPKNOW.KNOWLEDGE_FILE) + self.class_name_list = sorted([k for k in self.class_name_to_knowledge]) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(GeneralizedVLRCNN, self).train(mode) + if self.freeze_backbone: + self.backbone.body.eval() + for p in self.backbone.body.parameters(): + p.requires_grad = False + if self.freeze_fpn: + self.backbone.fpn.eval() + for p in self.backbone.fpn.parameters(): + p.requires_grad = False + if self.freeze_rpn: + if hasattr(self.rpn, 'head'): + self.rpn.head.eval() + for p in self.rpn.parameters(): + p.requires_grad = False + if self.linear_prob: + if self.rpn is not None: + for key, value in self.rpn.named_parameters(): + if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key or 'dot_product_projection_text' in key or 'head.log_scale' in key or 'head.bias_lang' in key or 'head.bias0' in key): + value.requires_grad = False + if self.roi_heads is not None: + for key, value in self.roi_heads.named_parameters(): + if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key or 'dot_product_projection_text' in key or 'head.log_scale' in key or 'head.bias_lang' in key or 'head.bias0' in key): + value.requires_grad = False + if self.freeze_cls_logits: + if hasattr(self.rpn.head, 'cls_logits'): + self.rpn.head.cls_logits.eval() + for p in self.rpn.head.cls_logits.parameters(): + p.requires_grad = False + if self.add_linear_layer: + if self.rpn is not None: + for key, p in self.rpn.named_parameters(): + if 'tunable_linear' in key: + p.requires_grad = True + + if self.freeze_language_backbone: + self.language_backbone.eval() + for p in self.language_backbone.parameters(): + p.requires_grad = False + + def forward(self, + images, + targets=None, + captions=None, + positive_map=None, + greenlight_map=None): + """ + Arguments: + images (list[Tensor] or ImageList): images to be processed + targets (list[BoxList]): ground-truth boxes present in the image (optional) + + mask_black_list: batch x 256, indicates whether or not a certain token is maskable or not + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). 
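        Example (editorial sketch, not part of the original docstring; the caption
        format and positive map are illustrative)::

            model.eval()
            captions = ["person. dog. frisbee."]    # illustrative prompt
            with torch.no_grad():
                predictions = model(images, captions=captions,
                                    positive_map=positive_map_label_to_token)
            # predictions: list[BoxList] with `scores` and `labels` fields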
+ + """ + if self.training and targets is None: + raise ValueError("In training mode, targets should be passed") + + images = to_image_list(images) + # batch_size = images.tensors.shape[0] + device = images.tensors.device + + + if self.cfg.GLIPKNOW.PARALLEL_LANGUAGE_INPUT: + language_dict_features, positive_map = self._forward_language_parallel( + captions=captions, targets=targets, device=device, + positive_map=positive_map) + else: + # language embedding + language_dict_features = {} + if captions is not None: + #print(captions[0]) + tokenized = self.tokenizer.batch_encode_plus(captions, + max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, + padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest", + return_special_tokens_mask=True, + return_tensors='pt', + truncation=True).to(device) + if self.use_mlm_loss: + if not self.mlm_loss_for_only_positives: + greenlight_map = None + input_ids, mlm_labels = random_word( + input_ids=tokenized.input_ids, + mask_token_id=self.tokenizer.mask_token_id, + vocabs=self.tokenizer_vocab_ids, + padding_token_id=self.tokenizer.pad_token_id, + greenlight_map=greenlight_map) + else: + input_ids = tokenized.input_ids + mlm_labels = None + + + tokenizer_input = {"input_ids": input_ids, + "attention_mask": tokenized.attention_mask} + + if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE: + with torch.no_grad(): + language_dict_features = self.language_backbone(tokenizer_input) + else: + language_dict_features = self.language_backbone(tokenizer_input) + + # ONE HOT + if self.cfg.DATASETS.ONE_HOT: + new_masks = torch.zeros_like(language_dict_features['masks'], + device=language_dict_features['masks'].device) + new_masks[:, :self.cfg.MODEL.DYHEAD.NUM_CLASSES] = 1 + language_dict_features['masks'] = new_masks + + # MASK ALL SPECIAL TOKENS + if self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL: + language_dict_features["masks"] = 1 - tokenized.special_tokens_mask + + language_dict_features["mlm_labels"] = mlm_labels + + # visual embedding + swint_feature_c4 = None + if 'vl' in self.cfg.MODEL.SWINT.VERSION: + # the backbone only updates the "hidden" field in language_dict_features + inputs = {"img": images.tensors, "lang": language_dict_features} + visual_features, language_dict_features, swint_feature_c4 = self.backbone(inputs) + else: + visual_features = self.backbone(images.tensors) + + # rpn force boxes + if targets: + targets = [target.to(device) + for target in targets if target is not None] + + if self.force_boxes: + proposals = [] + for t in targets: + tb = t.copy_with_fields(["labels"]) + tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device)) + proposals.append(tb) + if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES: + _, proposal_losses, fused_visual_features = self.rpn( + images, visual_features, targets, language_dict_features, + positive_map, captions, swint_feature_c4) + elif self.training: + null_loss = 0 + for key, param in self.rpn.named_parameters(): + null_loss += 0.0 * param.sum() + proposal_losses = {('rpn_null_loss', null_loss)} + else: + proposals, proposal_losses, fused_visual_features = self.rpn(images, visual_features, targets, language_dict_features, positive_map, + captions, swint_feature_c4) + if self.roi_heads: + if self.cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL"): + if self.training: + # "Only support VL mask head right now!!" + assert len(targets) == 1 and len(targets[0]) == len(positive_map), "shape match assert for mask head!!" 
+ # Not necessary but as a safe guard: + # use the binary 0/1 positive map to replace the normalized positive map + targets[0].add_field("positive_map", positive_map) + # TODO: make sure that this use of language_dict_features is correct!! Its content should be changed in self.rpn + if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES: + x, result, detector_losses = self.roi_heads( + fused_visual_features, proposals, targets, + language_dict_features=language_dict_features, + positive_map_label_to_token=positive_map if not self.training else None + ) + else: + x, result, detector_losses = self.roi_heads( + visual_features, proposals, targets, + language_dict_features=language_dict_features, + positive_map_label_to_token=positive_map if not self.training else None + ) + else: + # RPN-only models don't have roi_heads + x = visual_features + result = proposals + detector_losses = {} + + if self.training: + losses = {} + losses.update(detector_losses) + losses.update(proposal_losses) + return losses + + return result + + def _forward_language_parallel(self, captions=None, targets=None, + device=None, positive_map=None): + ktype = self.cfg.GLIPKNOW.KNOWLEDGE_TYPE + def _construct_captions_from_class_names(class_names): + captions = [] + for c in class_names: + try: + info = self.class_name_to_knowledge[c] + cap = info['clean_name'] + + # combine wiki and gpt3 knowledge + if self.cfg.GLIPKNOW.WIKI_AND_GPT3: + ktype = 'def_wiki' + know_seq = info[ktype] + + ktype = 'gpt3' + if ktype == 'gpt3' or type(info[ktype]) == list: + know_seq += ' '.join([seq for seq in info[ktype][:self.cfg.GLIPKNOW.GPT3_NUM] ]) + + cap += ': ' + know_seq + + # only one knoweldge source is used + else: + if ktype and ktype in info and info[ktype]: + if ktype == 'gpt3' or type(info[ktype]) == list: + know_seq = ' '.join([seq for seq in info[ktype][:self.cfg.GLIPKNOW.GPT3_NUM] ]) + else: + know_seq = info[ktype] + cap += ': ' + know_seq + except: + cap = c + print(f'cap {cap}, c {c}') + + + captions.append(cap) + return captions + + if self.training: + assert captions is None + assert targets is not None + + max_classes_per_batch = self.cfg.GLIPKNOW.MAX_NUM_CLASSES_PER_BATCH_TRAIN + if max_classes_per_batch >= len(self.class_name_list): + shuffled_class_names = self.class_name_list.copy() + random.shuffle(shuffled_class_names) + if max_classes_per_batch > len(shuffled_class_names): + shuffled_class_names.extend(shuffled_class_names[:max_classes_per_batch + -len(shuffled_class_names)]) + random.shuffle(shuffled_class_names) + else: + label_list = [] + label_to_idx = {} + for target_per_im in targets: + labels_per_im = target_per_im.get_field('label_names') + for label in labels_per_im: + if label not in label_to_idx: + label_to_idx[label] = len(label_list) + label_list.append(label) + + label_list = label_list[:max_classes_per_batch] + if len(label_list) < max_classes_per_batch: + all_neg_classes = [c for c in self.class_name_list if c not + in label_to_idx] + neg_label_list = random.sample(all_neg_classes, + max_classes_per_batch - len(label_list)) + label_list.extend(neg_label_list) + random.shuffle(label_list) + shuffled_class_names = label_list + + label_to_shuffled_idx = {l: i for i, l in + enumerate(shuffled_class_names)} + total_boxes = sum(len(t) for t in targets) + positive_map = torch.zeros((total_boxes, max_classes_per_batch+1), + device=device) + offset = 0 + for target_per_im in targets: + labels_per_im = target_per_im.get_field('label_names') + for label in labels_per_im: + j = label_to_shuffled_idx.get(label, -1) + if 
j >= 0: + positive_map[offset, j] = 1 + offset += 1 + captions = _construct_captions_from_class_names(shuffled_class_names) + captions.append('') # onobj at the end, onedet/modeling/rpn/loss.py:719 + batch_size = len(targets) + + else: + assert captions is not None + batch_size = 1 + assert len(captions) == 1 + class_names = captions[0] + max_classes_per_batch = len(class_names) + captions = _construct_captions_from_class_names(class_names) + captions.append('') # onobj at the end, onedet/modeling/rpn/loss.py:719 + + tokenized = self.tokenizer.batch_encode_plus(captions, + max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, + padding="longest", + return_special_tokens_mask=True, + return_tensors='pt', + truncation=True).to(device) + assert not self.use_mlm_loss + tokenizer_input = {"input_ids": tokenized.input_ids, + "attention_mask": tokenized.attention_mask} + + if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE: + with torch.no_grad(): + language_dict_features = self.language_backbone(tokenizer_input) + else: + language_dict_features = self.language_backbone(tokenizer_input) + + assert not self.cfg.DATASETS.ONE_HOT + assert not self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL + + agg_type = self.cfg.GLIPKNOW.LAN_FEATURE_AGG_TYPE + agg_feats = language_dict_features['hidden'] + agg_emb = language_dict_features['embedded'] + if agg_type == 'first': + agg_feats = agg_feats[:, 0, :] + agg_emb = agg_emb[:, 0, :] + elif agg_type == 'mean': + attn_mask = language_dict_features['masks'] + seq_len = attn_mask.sum(-1).unsqueeze(-1).float() + agg_feats = agg_feats * attn_mask.unsqueeze(-1).float() + agg_feats = agg_feats.sum(1) / seq_len + agg_emb = agg_emb * attn_mask.unsqueeze(-1).float() + agg_emb = agg_emb.sum(1) / seq_len + else: + raise ValueError('not supported GLIPKNOW.LAN_FEATURE_AGG_TYPE: {}'.format(agg_type)) + + expanded_features = agg_feats.unsqueeze(0).repeat(batch_size, 1, 1) + expanded_embedding = agg_emb.unsqueeze(0).repeat(batch_size, 1, 1) + + lang_dict = {} + lang_dict["mlm_labels"] = None + lang_dict["aggregate"] = None + lang_dict["embedded"] = expanded_embedding + lang_dict['hidden'] = expanded_features + lang_dict["masks"] = torch.ones((batch_size, max_classes_per_batch+1), + device=device, dtype=language_dict_features['masks'].dtype) + # in GLIP setting, the token at the end of seqence is usually [PAD], and is masked out + # if [noobj] is not masked out, the loss sum is very big, as most + # anchors are matched to [noobj] + lang_dict["masks"][:,-1] = 0 + return lang_dict, positive_map + diff --git a/maskrcnn_benchmark/modeling/language_backbone/__init__.py b/maskrcnn_benchmark/modeling/language_backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f78d6ab1d5b2d59007bb4c042d0fc1a5a06253da --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/__init__.py @@ -0,0 +1,6 @@ +from .backbone import build_backbone as build_language_backbone +from .build import build_tokenizer + +from .hfpt_tokenizer import HFPTTokenizer +from .simple_tokenizer import SimpleTokenizer +from .clip_model import CLIPTransformer diff --git a/maskrcnn_benchmark/modeling/language_backbone/backbone.py b/maskrcnn_benchmark/modeling/language_backbone/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..632b622092b52297c690cd9c0cebcef48b842e48 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/backbone.py @@ -0,0 +1,45 @@ +from collections import OrderedDict +import torch +from torch import nn + +from 
maskrcnn_benchmark.modeling import registry +from . import bert_model +from . import rnn_model +from . import clip_model +from . import word_utils + + +@registry.LANGUAGE_BACKBONES.register("bert-base-uncased") +def build_bert_backbone(cfg): + body = bert_model.BertEncoder(cfg) + model = nn.Sequential(OrderedDict([("body", body)])) + return model + + +@registry.LANGUAGE_BACKBONES.register("roberta-base") +def build_bert_backbone(cfg): + body = bert_model.BertEncoder(cfg) + model = nn.Sequential(OrderedDict([("body", body)])) + return model + + +@registry.LANGUAGE_BACKBONES.register("rnn") +def build_rnn_backbone(cfg): + body = rnn_model.RNNEnoder(cfg) + model = nn.Sequential(OrderedDict([("body", body)])) + return model + + +@registry.LANGUAGE_BACKBONES.register("clip") +def build_clip_backbone(cfg): + body = clip_model.CLIPTransformer(cfg) + model = nn.Sequential(OrderedDict([("body", body)])) + return model + + +def build_backbone(cfg): + assert cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE in registry.LANGUAGE_BACKBONES, \ + "cfg.MODEL.LANGUAGE_BACKBONE.TYPE: {} is not registered in registry".format( + cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE + ) + return registry.LANGUAGE_BACKBONES[cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE](cfg) diff --git a/maskrcnn_benchmark/modeling/language_backbone/bert_model.py b/maskrcnn_benchmark/modeling/language_backbone/bert_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4b69c54fc06ef600351da4addae354d971afb0e1 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/bert_model.py @@ -0,0 +1,79 @@ +from copy import deepcopy +import numpy as np +import torch +from torch import nn + +# from pytorch_pretrained_bert.modeling import BertModel +from transformers import BertConfig, RobertaConfig, RobertaModel, BertModel + + +class BertEncoder(nn.Module): + def __init__(self, cfg): + super(BertEncoder, self).__init__() + self.cfg = cfg + self.bert_name = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE + print("LANGUAGE BACKBONE USE GRADIENT CHECKPOINTING: ", self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT) + + if self.bert_name == "bert-base-uncased": + config = BertConfig.from_pretrained(self.bert_name) + config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT + self.model = BertModel.from_pretrained(self.bert_name, add_pooling_layer=False, config=config) + self.language_dim = 768 + elif self.bert_name == "roberta-base": + config = RobertaConfig.from_pretrained(self.bert_name) + config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT + self.model = RobertaModel.from_pretrained(self.bert_name, add_pooling_layer=False, config=config) + self.language_dim = 768 + else: + raise NotImplementedError + + self.num_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS + + def forward(self, x): + input = x["input_ids"] + mask = x["attention_mask"] + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + # with padding, always 256 + outputs = self.model( + input_ids=input, + attention_mask=mask, + output_hidden_states=True, + ) + # outputs has 13 layers, 1 input layer and 12 hidden layers + encoded_layers = outputs.hidden_states[1:] + features = None + features = torch.stack(encoded_layers[-self.num_layers:], 1).mean(1) + + # language embedding has shape [len(phrase), seq_len, language_dim] + features = features / self.num_layers + + embedded = features * mask.unsqueeze(-1).float() + aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float()) + + else: + # without padding, only consider 
positive_tokens + max_len = (input != 0).sum(1).max().item() + outputs = self.model( + input_ids=input[:, :max_len], + attention_mask=mask[:, :max_len], + output_hidden_states=True, + ) + # outputs has 13 layers, 1 input layer and 12 hidden layers + encoded_layers = outputs.hidden_states[1:] + + features = None + features = torch.stack(encoded_layers[-self.num_layers:], 1).mean(1) + # language embedding has shape [len(phrase), seq_len, language_dim] + features = features / self.num_layers + + embedded = features * mask[:, :max_len].unsqueeze(-1).float() + aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float()) + + ret = { + "aggregate": aggregate, + "embedded": embedded, + "masks": mask, + "hidden": encoded_layers[-1] + } + return ret diff --git a/maskrcnn_benchmark/modeling/language_backbone/bpe_simple_vocab_16e6.txt.gz b/maskrcnn_benchmark/modeling/language_backbone/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..e74ad860329b14ff6b53f3ae0b007bec308cc5af --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc496842c2d4b6e40b2bd1207a5ded6e425e6a7cf9c16afa86caa5d7d12df233 +size 1355337 diff --git a/maskrcnn_benchmark/modeling/language_backbone/build.py b/maskrcnn_benchmark/modeling/language_backbone/build.py new file mode 100644 index 0000000000000000000000000000000000000000..d5fc534df7864869d89734b7ca48ba6d56fe5a58 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/build.py @@ -0,0 +1,18 @@ +from .simple_tokenizer import SimpleTokenizer + + +def build_tokenizer(tokenizer_name): + tokenizer = None + if tokenizer_name == 'clip': + tokenizer = SimpleTokenizer() + elif 'hf_' in tokenizer_name: + from .hfpt_tokenizer import HFPTTokenizer + + tokenizer = HFPTTokenizer(pt_name=tokenizer_name[3:]) + elif 'hfc_' in tokenizer_name: + from .hfpt_tokenizer import HFPTTokenizer + tokenizer = HFPTTokenizer(pt_name=tokenizer_name[4:]) + else: + raise ValueError('Unknown tokenizer') + + return tokenizer diff --git a/maskrcnn_benchmark/modeling/language_backbone/clip_model.py b/maskrcnn_benchmark/modeling/language_backbone/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..781f4f4ac5dabd7d232741fe88d40785ee2c1919 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/clip_model.py @@ -0,0 +1,200 @@ +from collections import OrderedDict +import logging +import os + +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from maskrcnn_benchmark.config import try_to_find + +from timm.models.layers import DropPath, trunc_normal_ + +logger = logging.getLogger(__name__) + + +class LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). 
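        Editorial note: functionally this matches ``torch.nn.LayerNorm(hidden_size,
        eps=eps)`` over the last dimension, with the statistics computed in float32
        for numerical stability, e.g. (illustrative)::

            ln = LayerNorm(512)
            y = ln(torch.randn(2, 77, 512))   # normalized over the last dim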
+ """ + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + pdtype = x.dtype + x = x.float() + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x.to(pdtype) + self.bias + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None, + drop_path: float = 0.0): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \ + if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0] + + def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): + x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class CLIPTransformer(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + self.use_checkpoint = cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT + print("LANGUAGE BACKBONE USE GRADIENT CHECKPOINTING: ", self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT) + + self.context_length = self.cfg.MODEL.CLIP.CONTEXT_LENGTH + self.width = self.cfg.MODEL.CLIP.WIDTH + self.layers = self.cfg.MODEL.CLIP.LAYERS + self.heads = self.cfg.MODEL.CLIP.HEADS + self.drop_path = self.cfg.MODEL.CLIP.DROP_PATH + self.vocab_size = self.cfg.MODEL.CLIP.VOCAB_SIZE + + self.token_embedding = nn.Embedding(self.vocab_size, self.width) + + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, self.width) + ) + + # attn_mask = self.build_attention_mask() + attn_mask = None + + dpr = [x.item() for x in torch.linspace(0, self.drop_path, self.layers)] # stochastic depth decay rule + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(self.width, self.heads, attn_mask, dpr[i]) + for i in range(self.layers) + ] + ) + + self.ln_final = LayerNorm(self.width) + + trunc_normal_(self.positional_embedding, std=.02) + # nn.init.normal_(self.token_embedding, std=.02) + trunc_normal_(self.token_embedding.weight, std=.02) + self.apply(self._init_weights) + + # loading pre-trained weight from our CLIP models + if len(self.cfg.MODEL.LANGUAGE_BACKBONE.WEIGHT) > 0: + self.init_weights(pretrained=try_to_find(self.cfg.MODEL.LANGUAGE_BACKBONE.WEIGHT), + pretrained_layers=['*']) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def _init_weights(self, m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + 
trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): + nn.init.constant_(m.bias, 0) + + def resize_pos_embed_1d(self, posemb, shape_new): + # rescale the grid of position embeddings when loading from state_dict + ntok_old = posemb.shape[0] + if ntok_old > 1: + ntok_new = shape_new[0] + posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1).unsqueeze(dim=-1) + posemb_grid = F.interpolate(posemb_grid, size=[ntok_new, 1], mode='bilinear') + posemb_grid = posemb_grid.squeeze(dim=-1).permute(0, 2, 1).squeeze(dim=0) + posemb = posemb_grid + return posemb + + def init_weights(self, pretrained="", pretrained_layers=[], verbose=False): + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained, map_location="cpu") + logger.info(f'=> loading pretrained clip text model {pretrained}') + model_dict = self.state_dict() + + need_init_state_dict = {} + for k, v in pretrained_dict.items(): + need_init = ( + k.split('.')[0] in pretrained_layers + or pretrained_layers[0] is '*' + ) + if need_init: + if k.startswith('text.') and k[5:] in model_dict.keys(): + need_init_state_dict[k[5:]] = v + + # notice the context length now changes from 77 to 256, so we need to resize the positional embedding + if "positional_embedding" in need_init_state_dict.keys(): + old_pos_embed = need_init_state_dict["positional_embedding"].float() + new_pos_embed = self.resize_pos_embed_1d(old_pos_embed, + (self.cfg.MODEL.CLIP.CONTEXT_LENGTH, old_pos_embed.shape[1])) + need_init_state_dict["positional_embedding"] = new_pos_embed + self.load_state_dict(need_init_state_dict, strict=True) + + @torch.jit.ignore + def no_weight_decay(self): + return { + 'positional_embedding', + 'token_embedding', + } + + def forward(self, text): + input = text["input_ids"] + mask = text["attention_mask"] + # get extended attention mask for nn.MultiHeadAttention + key_padding_mask = (1.0 - mask).to(torch.bool) + + x = self.token_embedding(input) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + + for resblock in self.resblocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(resblock, x, key_padding_mask) + else: + x = resblock(x, key_padding_mask) + + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_final(x) + + # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] + + ret = { + "aggregate": x, + "embedded": x, + "masks": mask, + "hidden": x + } + + return ret \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/language_backbone/hfpt_tokenizer.py b/maskrcnn_benchmark/modeling/language_backbone/hfpt_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..06dce89d75e3b91ee3405dd2e449b9d48dc861f2 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/hfpt_tokenizer.py @@ -0,0 +1,99 @@ +from typing import Union, List + +from transformers import AutoTokenizer +import torch + + +class HFPTTokenizer(object): + def __init__(self, pt_name=None): + + self.pt_name = pt_name + self.added_sep_token = 0 + self.added_cls_token = 0 + self.enable_add_tokens = False + self.gpt_special_case = ((not self.enable_add_tokens) and ('gpt' in self.pt_name)) + + if (pt_name is None): + self.tokenizer = AutoTokenizer.from_pretrained('bert-base-cased') + else: + self.tokenizer = AutoTokenizer.from_pretrained(pt_name) + + # Adding tokens to GPT causing NaN training loss. + # Disable for now until further investigation. 
+ if (self.enable_add_tokens): + if (self.tokenizer.sep_token is None): + self.tokenizer.add_special_tokens({'sep_token': ''}) + self.added_sep_token = 1 + + if (self.tokenizer.cls_token is None): + self.tokenizer.add_special_tokens({'cls_token': ''}) + self.added_cls_token = 1 + + if (self.gpt_special_case): + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.sep_token = self.tokenizer.eos_token + + def get_eot_token(self): + return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)[0] + + def get_sot_token(self): + return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)[0] + + def get_eot_token_list(self): + return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False) + + def get_sot_token_list(self): + return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False) + + def get_tokenizer_obj(self): + return self.tokenizer + + # Language model needs to know if new tokens + # were added to the dictionary. + def check_added_tokens(self): + return self.added_sep_token + self.added_cls_token + + def tokenize(self, texts: Union[str, List[str]], context_length: int = 77): + if isinstance(texts, str): + texts = [texts] + + padding = 'max_length' + + seqstart = [] + seqtok = [] + seqend = [] + + max_length = context_length + + if (self.added_cls_token > 0): + seqstart = self.get_sot_token_list() + max_length = max_length - 1 + + if (self.added_sep_token > 0): + seqend = self.get_eot_token_list() + max_length = max_length - 1 + + tokens = self.tokenizer( + texts, padding=padding, + truncation=True, + max_length=max_length + )['input_ids'] + + for i in range(len(tokens)): + tokens[i] = seqstart + tokens[i] + seqend + + if (self.gpt_special_case): + for i in range(len(tokens)): + tokens[i][-1] = self.get_eot_token() + + # print(str(tokens)) + + result = torch.Tensor(tokens).type(torch.LongTensor) + + return result + + def get_vocab_size(self): + return self.tokenizer.vocab_size + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77): + return self.tokenize(texts, context_length) diff --git a/maskrcnn_benchmark/modeling/language_backbone/rnn_model.py b/maskrcnn_benchmark/modeling/language_backbone/rnn_model.py new file mode 100644 index 0000000000000000000000000000000000000000..18d60efcb08675b73bce211e42c0c180ffe8d267 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/rnn_model.py @@ -0,0 +1,115 @@ +from copy import deepcopy +import numpy as np +import torch +from torch import nn + + +class RNNEnoder(nn.Module): + def __init__(self, cfg): + super(RNNEnoder, self).__init__() + self.cfg = cfg + + self.rnn_type = cfg.MODEL.LANGUAGE_BACKBONE.RNN_TYPE + self.variable_length = cfg.MODEL.LANGUAGE_BACKBONE.VARIABLE_LENGTH + self.word_embedding_size = cfg.MODEL.LANGUAGE_BACKBONE.WORD_EMBEDDING_SIZE + self.word_vec_size = cfg.MODEL.LANGUAGE_BACKBONE.WORD_VEC_SIZE + self.hidden_size = cfg.MODEL.LANGUAGE_BACKBONE.HIDDEN_SIZE + self.bidirectional = cfg.MODEL.LANGUAGE_BACKBONE.BIDIRECTIONAL + self.input_dropout_p = cfg.MODEL.LANGUAGE_BACKBONE.INPUT_DROPOUT_P + self.dropout_p = cfg.MODEL.LANGUAGE_BACKBONE.DROPOUT_P + self.n_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS + self.corpus_path = cfg.MODEL.LANGUAGE_BACKBONE.CORPUS_PATH + self.vocab_size = cfg.MODEL.LANGUAGE_BACKBONE.VOCAB_SIZE + + # language encoder + self.embedding = nn.Embedding(self.vocab_size, self.word_embedding_size) + self.input_dropout = nn.Dropout(self.input_dropout_p) + self.mlp = 
nn.Sequential(nn.Linear(self.word_embedding_size, self.word_vec_size), nn.ReLU()) + self.rnn = getattr(nn, self.rnn_type.upper())(self.word_vec_size, + self.hidden_size, + self.n_layers, + batch_first=True, + bidirectional=self.bidirectional, + dropout=self.dropout_p) + self.num_dirs = 2 if self.bidirectional else 1 + + def forward(self, input, mask=None): + word_id = input + max_len = (word_id != 0).sum(1).max().item() + word_id = word_id[:, :max_len] # mask zero + # embedding + output, hidden, embedded, final_output = self.RNNEncode(word_id) + return { + 'hidden': hidden, + 'output': output, + 'embedded': embedded, + 'final_output': final_output, + } + + def encode(self, input_labels): + """ + Inputs: + - input_labels: Variable long (batch, seq_len) + Outputs: + - output : Variable float (batch, max_len, hidden_size * num_dirs) + - hidden : Variable float (batch, num_layers * num_dirs * hidden_size) + - embedded: Variable float (batch, max_len, word_vec_size) + """ + device = input_labels.device + if self.variable_length: + input_lengths_list, sorted_lengths_list, sort_idxs, recover_idxs = self.sort_inputs(input_labels) + input_labels = input_labels[sort_idxs] + + embedded = self.embedding(input_labels) # (n, seq_len, word_embedding_size) + embedded = self.input_dropout(embedded) # (n, seq_len, word_embedding_size) + embedded = self.mlp(embedded) # (n, seq_len, word_vec_size) + + if self.variable_length: + if self.variable_length: + embedded = nn.utils.rnn.pack_padded_sequence(embedded, \ + sorted_lengths_list, \ + batch_first=True) + # forward rnn + self.rnn.flatten_parameters() + output, hidden = self.rnn(embedded) + + # recover + if self.variable_length: + # recover embedded + embedded, _ = nn.utils.rnn.pad_packed_sequence(embedded, + batch_first=True) # (batch, max_len, word_vec_size) + embedded = embedded[recover_idxs] + + # recover output + output, _ = nn.utils.rnn.pad_packed_sequence(output, + batch_first=True) # (batch, max_len, hidden_size * num_dir) + output = output[recover_idxs] + + # recover hidden + if self.rnn_type == 'lstm': + hidden = hidden[0] # hidden state + hidden = hidden[:, recover_idxs, :] # (num_layers * num_dirs, batch, hidden_size) + hidden = hidden.transpose(0, 1).contiguous() # (batch, num_layers * num_dirs, hidden_size) + hidden = hidden.view(hidden.size(0), -1) # (batch, num_layers * num_dirs * hidden_size) + + # final output + finnal_output = [] + for ii in range(output.shape[0]): + finnal_output.append(output[ii, int(input_lengths_list[ii] - 1), :]) + finnal_output = torch.stack(finnal_output, dim=0) # (batch, number_dirs * hidden_size) + + return output, hidden, embedded, finnal_output + + def sort_inputs(self, input_labels): # sort input labels by descending + device = input_labels.device + input_lengths = (input_labels != 0).sum(1) + input_lengths_list = input_lengths.data.cpu().numpy().tolist() + sorted_input_lengths_list = np.sort(input_lengths_list)[::-1].tolist() # list of sorted input_lengths + sort_idxs = np.argsort(input_lengths_list)[::-1].tolist() + s2r = {s: r for r, s in enumerate(sort_idxs)} + recover_idxs = [s2r[s] for s in range(len(input_lengths_list))] + assert max(input_lengths_list) == input_labels.size(1) + # move to long tensor + sort_idxs = input_labels.data.new(sort_idxs).long().to(device) # Variable long + recover_idxs = input_labels.data.new(recover_idxs).long().to(device) # Variable long + return input_lengths_list, sorted_input_lengths_list, sort_idxs, recover_idxs diff --git 
a/maskrcnn_benchmark/modeling/language_backbone/simple_tokenizer.py b/maskrcnn_benchmark/modeling/language_backbone/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8653b554bce885162452067b67359f07eb022174 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/simple_tokenizer.py @@ -0,0 +1,173 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re +from typing import Union, List + +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + cs = bs[:] + n = 0 + for b in range(2 ** 8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 
1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + def get_vocab_size(self): + return 49408 + + def get_eot_token(self): + return self.encoder["<|endoftext|>"] + + def get_sot_token(self): + return self.encoder["<|startoftext|>"] + + def check_added_tokens(self): + return 0 + + def get_tokenizer_obj(self): + return None + + def tokenize(self, texts: Union[str, List[str]], context_length: int = 77): + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] + # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77): + return self.tokenize(texts, context_length) diff --git a/maskrcnn_benchmark/modeling/language_backbone/test_clip_tokenizer.py b/maskrcnn_benchmark/modeling/language_backbone/test_clip_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c63c73cfba5bd3a580e93852cd9c91fac00b35 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/test_clip_tokenizer.py @@ -0,0 +1,8 @@ +from maskrcnn_benchmark.modeling.language_backbone import build_tokenizer + +if __name__ == '__main__': + + tokenizer2 = build_tokenizer("clip") + tokenized2 = tokenizer2( + ["Detectest : fishid. jellyfishioasod. penguinasd. puffin.asd shark. starfish. round stingray"]) + print(tokenized2) diff --git a/maskrcnn_benchmark/modeling/language_backbone/word_utils.py b/maskrcnn_benchmark/modeling/language_backbone/word_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69f453ba70bf8832f3f4124d82467b3803b09af1 --- /dev/null +++ b/maskrcnn_benchmark/modeling/language_backbone/word_utils.py @@ -0,0 +1,100 @@ +""" +Language-related data loading helper functions and class wrappers. 
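Editorial sketch (not part of the original docstring) of how the Dictionary and
Corpus classes defined below are typically wired together; the vocabulary file
name is hypothetical::

    corpus = Corpus()
    corpus.load_file("referit_vocab.txt")    # hypothetical vocabulary file
    ids = corpus.tokenize("a red ball on the table", max_len=20)
    words = corpus.dictionary[ids.tolist()]  # map ids back to (padded) tokens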
+""" + +import re +import torch +import codecs + +UNK_TOKEN = '' +PAD_TOKEN = '' +END_TOKEN = '' +SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)') + + +class Dictionary(object): + def __init__(self): + self.word2idx = {} + self.idx2word = [] + + def add_word(self, word): + if word not in self.word2idx: + self.idx2word.append(word) + self.word2idx[word] = len(self.idx2word) - 1 + return self.word2idx[word] + + def __len__(self): + return len(self.idx2word) + + def __getitem__(self, a): + if isinstance(a, int): + return self.idx2word[a] + elif isinstance(a, list): + return [self.idx2word[x] for x in a] + elif isinstance(a, str): + return self.word2idx[a] + else: + raise TypeError("Query word/index argument must be int or str") + + def __contains__(self, word): + return word in self.word2idx + + +class Corpus(object): + def __init__(self): + self.dictionary = Dictionary() + + def set_max_len(self, value): + self.max_len = value + + def load_file(self, filename): + with codecs.open(filename, 'r', 'utf-8') as f: + for line in f: + line = line.strip() + self.add_to_corpus(line) + self.dictionary.add_word(UNK_TOKEN) + self.dictionary.add_word(PAD_TOKEN) + + def add_to_corpus(self, line): + """Tokenizes a text line.""" + # Add words to the dictionary + words = line.split() + # tokens = len(words) + for word in words: + word = word.lower() + self.dictionary.add_word(word) + + def tokenize(self, line, max_len=20): + # Tokenize line contents + words = SENTENCE_SPLIT_REGEX.split(line.strip()) + # words = [w.lower() for w in words if len(w) > 0] + words = [w.lower() for w in words if (len(w) > 0 and w != ' ')] ## do not include space as a token + + if words[-1] == '.': + words = words[:-1] + + if max_len > 0: + if len(words) > max_len: + words = words[:max_len] + elif len(words) < max_len: + # words = [PAD_TOKEN] * (max_len - len(words)) + words + words = words + [END_TOKEN] + [PAD_TOKEN] * (max_len - len(words) - 1) + + tokens = len(words) ## for end token + ids = torch.LongTensor(tokens) + token = 0 + for word in words: + if word not in self.dictionary: + word = UNK_TOKEN + # print(word, type(word), word.encode('ascii','ignore').decode('ascii'), type(word.encode('ascii','ignore').decode('ascii'))) + if type(word) != type('a'): + print(word, type(word), word.encode('ascii', 'ignore').decode('ascii'), + type(word.encode('ascii', 'ignore').decode('ascii'))) + word = word.encode('ascii', 'ignore').decode('ascii') + ids[token] = self.dictionary[word] + token += 1 + # ids[token] = self.dictionary[END_TOKEN] + return ids + + def __len__(self): + return len(self.dictionary) diff --git a/maskrcnn_benchmark/modeling/make_layers.py b/maskrcnn_benchmark/modeling/make_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..2216a952d04a295d0cf474d2f562903081fe0ea6 --- /dev/null +++ b/maskrcnn_benchmark/modeling/make_layers.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Miscellaneous utility functions +""" + +import torch +from torch import nn +from torch.nn import functional as F +from maskrcnn_benchmark.config import cfg +from maskrcnn_benchmark.layers import Conv2d, DYReLU +from maskrcnn_benchmark.modeling.poolers import Pooler + + +def get_group_gn(dim, dim_per_gp, num_groups): + """get number of groups used by GroupNorm, based on number of channels.""" + assert dim_per_gp == -1 or num_groups == -1, \ + "GroupNorm: can only specify G or C/G." 
+ + if dim_per_gp > 0: + assert dim % dim_per_gp == 0, \ + "dim: {}, dim_per_gp: {}".format(dim, dim_per_gp) + group_gn = dim // dim_per_gp + else: + assert dim % num_groups == 0, \ + "dim: {}, num_groups: {}".format(dim, num_groups) + group_gn = num_groups + + return group_gn + + +def group_norm(out_channels, affine=True, divisor=1): + out_channels = out_channels // divisor + dim_per_gp = cfg.MODEL.GROUP_NORM.DIM_PER_GP // divisor + num_groups = cfg.MODEL.GROUP_NORM.NUM_GROUPS // divisor + eps = cfg.MODEL.GROUP_NORM.EPSILON # default: 1e-5 + return torch.nn.GroupNorm( + get_group_gn(out_channels, dim_per_gp, num_groups), + out_channels, + eps, + affine + ) + + +def make_conv3x3( + in_channels, + out_channels, + dilation=1, + stride=1, + use_gn=False, + use_relu=False, + kaiming_init=True +): + conv = Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False if use_gn else True + ) + if kaiming_init: + nn.init.kaiming_normal_( + conv.weight, mode="fan_out", nonlinearity="relu" + ) + else: + torch.nn.init.normal_(conv.weight, std=0.01) + if not use_gn: + nn.init.constant_(conv.bias, 0) + module = [conv,] + if use_gn: + module.append(group_norm(out_channels)) + if use_relu: + module.append(nn.ReLU(inplace=True)) + if len(module) > 1: + return nn.Sequential(*module) + return conv + + +def make_fc(dim_in, hidden_dim, use_gn=False): + ''' + Caffe2 implementation uses XavierFill, which in fact + corresponds to kaiming_uniform_ in PyTorch + ''' + if use_gn: + fc = nn.Linear(dim_in, hidden_dim, bias=False) + nn.init.kaiming_uniform_(fc.weight, a=1) + return nn.Sequential(fc, group_norm(hidden_dim)) + fc = nn.Linear(dim_in, hidden_dim) + nn.init.kaiming_uniform_(fc.weight, a=1) + nn.init.constant_(fc.bias, 0) + return fc + + +def conv_with_kaiming_uniform(use_gn=False, use_relu=False, use_dyrelu=False): + def make_conv( + in_channels, out_channels, kernel_size, stride=1, dilation=1 + ): + conv = Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=dilation * (kernel_size - 1) // 2, + dilation=dilation, + bias=False if use_gn else True + ) + # Caffe2 implementation uses XavierFill, which in fact + # corresponds to kaiming_uniform_ in PyTorch + nn.init.kaiming_uniform_(conv.weight, a=1) + if not use_gn: + nn.init.constant_(conv.bias, 0) + module = [conv,] + if use_gn: + module.append(group_norm(out_channels)) + if use_relu: + module.append(nn.ReLU(inplace=True)) + if use_dyrelu: + module.append(DYReLU(out_channels, out_channels, use_spatial=True)) + if len(module) > 1: + return nn.Sequential(*module) + return conv + + return make_conv diff --git a/maskrcnn_benchmark/modeling/matcher.py b/maskrcnn_benchmark/modeling/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..d080b0546b8e3e581ced4fbac89cca4dfde78b1a --- /dev/null +++ b/maskrcnn_benchmark/modeling/matcher.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + + +class Matcher(object): + """ + This class assigns to each predicted "element" (e.g., a box) a ground-truth + element. Each predicted element will have exactly zero or one matches; each + ground-truth element may be assigned to zero or more predicted elements. + + Matching is based on the MxN match_quality_matrix, that characterizes how well + each (ground-truth, predicted)-pair match. For example, if the elements are + boxes, the matrix may contain box IoU overlap values. 
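+
+    For example, with high_threshold=0.7, low_threshold=0.3 and
+    allow_low_quality_matches=False (illustrative values): a prediction whose
+    best IoU against any ground truth is 0.8 keeps that ground-truth index,
+    0.5 is mapped to BETWEEN_THRESHOLDS (-2) and 0.1 to BELOW_LOW_THRESHOLD (-1).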
+ + The matcher returns a tensor of size N containing the index of the ground-truth + element m that matches to prediction n. If there is no match, a negative value + is returned. + """ + + BELOW_LOW_THRESHOLD = -1 + BETWEEN_THRESHOLDS = -2 + + def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): + """ + Args: + high_threshold (float): quality values greater than or equal to + this value are candidate matches. + low_threshold (float): a lower quality threshold used to stratify + matches into three levels: + 1) matches >= high_threshold + 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold) + 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold) + allow_low_quality_matches (bool): if True, produce additional matches + for predictions that have only low-quality match candidates. See + set_low_quality_matches_ for more details. + """ + assert low_threshold <= high_threshold + self.high_threshold = high_threshold + self.low_threshold = low_threshold + self.allow_low_quality_matches = allow_low_quality_matches + + def __call__(self, match_quality_matrix): + """ + Args: + match_quality_matrix (Tensor[float]): an MxN tensor, containing the + pairwise quality between M ground-truth elements and N predicted elements. + + Returns: + matches (Tensor[int64]): an N tensor where N[i] is a matched gt in + [0, M - 1] or a negative value indicating that prediction i could not + be matched. + """ + if match_quality_matrix.numel() == 0: + # empty targets or proposals not supported during training + if match_quality_matrix.shape[0] == 0: + # raise ValueError( + # "No ground-truth boxes available for one of the images " + # "during training") + length = match_quality_matrix.size(1) + device = match_quality_matrix.device + return torch.ones(length, dtype=torch.int64, device=device) * -1 + else: + raise ValueError( + "No proposal boxes available for one of the images " + "during training") + + # match_quality_matrix is M (gt) x N (predicted) + # Max over gt elements (dim 0) to find best gt candidate for each prediction + matched_vals, matches = match_quality_matrix.max(dim=0) + if self.allow_low_quality_matches: + all_matches = matches.clone() + + # Assign candidate matches with low quality to negative (unassigned) values + below_low_threshold = matched_vals < self.low_threshold + between_thresholds = (matched_vals >= self.low_threshold) & ( + matched_vals < self.high_threshold + ) + matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD + matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS + + if self.allow_low_quality_matches: + self.set_low_quality_matches_(matches, all_matches, match_quality_matrix) + + return matches + + def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): + """ + Produce additional matches for predictions that have only low-quality matches. + Specifically, for each ground-truth find the set of predictions that have + maximum overlap with it (including ties); for each prediction in that set, if + it is unmatched, then match it to the ground-truth with which it has the highest + quality value. 
+ """ + # For each gt, find the prediction with which it has highest quality + highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) + # Find highest quality match available, even if it is low, including ties + gt_pred_pairs_of_highest_quality = torch.nonzero( + match_quality_matrix == highest_quality_foreach_gt[:, None] + ) + # Example gt_pred_pairs_of_highest_quality: + # tensor([[ 0, 39796], + # [ 1, 32055], + # [ 1, 32070], + # [ 2, 39190], + # [ 2, 40255], + # [ 3, 40390], + # [ 3, 41455], + # [ 4, 45470], + # [ 5, 45325], + # [ 5, 46390]]) + # Each row is a (gt index, prediction index) + # Note how gt items 1, 2, 3, and 5 each have two ties + + pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1] + matches[pred_inds_to_update] = all_matches[pred_inds_to_update] diff --git a/maskrcnn_benchmark/modeling/poolers.py b/maskrcnn_benchmark/modeling/poolers.py new file mode 100644 index 0000000000000000000000000000000000000000..ad136731b58a97bbf3d8266ee301d1c930c8fa6e --- /dev/null +++ b/maskrcnn_benchmark/modeling/poolers.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.layers import ROIAlign, ROIAlignV2 + +from .utils import cat + + +class LevelMapper(object): + """Determine which FPN level each RoI in a set of RoIs should map to based + on the heuristic in the FPN paper. + """ + + def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): + """ + Arguments: + k_min (int) + k_max (int) + canonical_scale (int) + canonical_level (int) + eps (float) + """ + self.k_min = k_min + self.k_max = k_max + self.s0 = canonical_scale + self.lvl0 = canonical_level + self.eps = eps + + def __call__(self, boxlists): + """ + Arguments: + boxlists (list[BoxList]) + """ + # Compute level ids + s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists])) + + # Eqn.(1) in FPN paper + target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps)) + target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max) + return target_lvls.to(torch.int64) - self.k_min + + +class Pooler(nn.Module): + """ + Pooler for Detection with or without FPN. + It currently hard-code ROIAlign in the implementation, + but that can be made more generic later on. + Also, the requirement of passing the scales is not strictly necessary, as they + can be inferred from the size of the feature map / size of original image, + which is available thanks to the BoxList. + """ + + def __init__(self, output_size, scales, sampling_ratio, use_v2=False): + """ + Arguments: + output_size (list[tuple[int]] or list[int]): output size for the pooled region + scales (list[float]): scales for each Pooler + sampling_ratio (int): sampling ratio for ROIAlign + """ + super(Pooler, self).__init__() + poolers = [] + for scale in scales: + poolers.append( + ROIAlignV2( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio + ) + if use_v2 else + ROIAlign( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio + ) + ) + self.poolers = nn.ModuleList(poolers) + self.output_size = output_size + # get the levels in the feature map by leveraging the fact that the network always + # downsamples by a factor of 2 at each level. 
+ lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item() + lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item() + self.map_levels = LevelMapper(lvl_min, lvl_max) + + def convert_to_roi_format(self, boxes): + concat_boxes = cat([b.bbox for b in boxes], dim=0) + device, dtype = concat_boxes.device, concat_boxes.dtype + ids = cat( + [ + torch.full((len(b), 1), i, dtype=dtype, device=device) + for i, b in enumerate(boxes) + ], + dim=0, + ) + rois = torch.cat([ids, concat_boxes], dim=1) + return rois + + def forward(self, x, boxes): + """ + Arguments: + x (list[Tensor]): feature maps for each level + boxes (list[BoxList]): boxes to be used to perform the pooling operation. + Returns: + result (Tensor) + """ + num_levels = len(self.poolers) + rois = self.convert_to_roi_format(boxes) + if num_levels == 1: + return self.poolers[0](x[0], rois) + + levels = self.map_levels(boxes) + + num_rois = len(rois) + num_channels = x[0].shape[1] + output_size = self.output_size[0] + + dtype, device = x[0].dtype, x[0].device + result = torch.zeros( + (num_rois, num_channels, output_size, output_size), + dtype=dtype, + device=device, + ) + for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): + idx_in_level = torch.nonzero(levels == level).squeeze(1) + rois_per_level = rois[idx_in_level] + result[idx_in_level] = pooler(per_level_feature, rois_per_level) + + return result diff --git a/maskrcnn_benchmark/modeling/registry.py b/maskrcnn_benchmark/modeling/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3d828cdbb550a242a2b2a944fc1c7efccbe9da90 --- /dev/null +++ b/maskrcnn_benchmark/modeling/registry.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from maskrcnn_benchmark.utils.registry import Registry + +BACKBONES = Registry() + +LANGUAGE_BACKBONES = Registry() + +ROI_BOX_FEATURE_EXTRACTORS = Registry() +RPN_HEADS = Registry() diff --git a/maskrcnn_benchmark/modeling/roi_heads/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6b92b4e6adc2c9b592f4cdee794b36b57a4548 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/__init__.py @@ -0,0 +1,84 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +from .box_head.box_head import build_roi_box_head +from .mask_head.mask_head import build_roi_mask_head +from .keypoint_head.keypoint_head import build_roi_keypoint_head + + +class CombinedROIHeads(torch.nn.ModuleDict): + """ + Combines a set of individual heads (for box prediction or masks) into a single + head. 
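+
+    The individual heads are registered under the keys "box", "mask" and
+    "keypoint" of this ModuleDict (see build_roi_heads below).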
+ """ + + def __init__(self, cfg, heads): + super(CombinedROIHeads, self).__init__(heads) + self.cfg = cfg.clone() + if cfg.MODEL.MASK_ON and cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR: + self.mask.feature_extractor = self.box.feature_extractor + if cfg.MODEL.KEYPOINT_ON and cfg.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR: + self.keypoint.feature_extractor = self.box.feature_extractor + + def forward(self, features, proposals, targets=None, language_dict_features=None, positive_map_label_to_token=None): + losses = {} + detections = proposals + if self.cfg.MODEL.BOX_ON: + # TODO rename x to roi_box_features, if it doesn't increase memory consumption + x, detections, loss_box = self.box(features, proposals, targets) + losses.update(loss_box) + + if self.cfg.MODEL.MASK_ON: + mask_features = features + # optimization: during training, if we share the feature extractor between + # the box and the mask heads, then we can reuse the features already computed + if ( + self.training + and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR + ): + mask_features = x + # During training, self.box() will return the unaltered proposals as "detections" + # this makes the API consistent during training and testing + x, detections, loss_mask = self.mask( + mask_features, detections, targets, + language_dict_features=language_dict_features, + positive_map_label_to_token=positive_map_label_to_token) + losses.update(loss_mask) + + if self.cfg.MODEL.KEYPOINT_ON: + keypoint_features = features + # optimization: during training, if we share the feature extractor between + # the box and the mask heads, then we can reuse the features already computed + if ( + self.training + and self.cfg.MODEL.ROI_KEYPOINT_HEAD.SHARE_BOX_FEATURE_EXTRACTOR + ): + keypoint_features = x + # During training, self.box() will return the unaltered proposals as "detections" + # this makes the API consistent during training and testing + x, detections, loss_keypoint = self.keypoint(keypoint_features, detections, targets) + losses.update(loss_keypoint) + return x, detections, losses + + +def build_roi_heads(cfg): + # individually create the heads, that will be combined together + # afterwards + # if cfg.MODEL.RPN_ONLY: + # return None + + roi_heads = [] + if cfg.MODEL.BOX_ON and not cfg.MODEL.RPN_ONLY: + roi_heads.append(("box", build_roi_box_head(cfg))) + if cfg.MODEL.MASK_ON: + roi_heads.append(("mask", build_roi_mask_head(cfg))) + if cfg.MODEL.KEYPOINT_ON: + roi_heads.append(("keypoint", build_roi_keypoint_head(cfg))) + + # combine individual heads in a single module + if roi_heads: + roi_heads = CombinedROIHeads(cfg, roi_heads) + else: + roi_heads = None + + return roi_heads \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d509ee6bb0c75d51960c192f782c4b2a8178a96e --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+import torch +from torch import nn + +from .roi_box_feature_extractors import make_roi_box_feature_extractor +from .roi_box_predictors import make_roi_box_predictor +from .inference import make_roi_box_post_processor +from .loss import make_roi_box_loss_evaluator +from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd + +class ROIBoxHead(torch.nn.Module): + """ + Generic Box Head class. + """ + + def __init__(self, cfg): + super(ROIBoxHead, self).__init__() + self.feature_extractor = make_roi_box_feature_extractor(cfg) + self.predictor = make_roi_box_predictor(cfg) + self.post_processor = make_roi_box_post_processor(cfg) + self.loss_evaluator = make_roi_box_loss_evaluator(cfg) + self.onnx = cfg.MODEL.ONNX + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, features, proposals, targets=None): + """ + Arguments: + features (list[Tensor]): feature-maps from possibly several levels + proposals (list[BoxList]): proposal boxes + targets (list[BoxList], optional): the ground-truth targets. + + Returns: + x (Tensor): the result of the feature extractor + proposals (list[BoxList]): during training, the subsampled proposals + are returned. During testing, the predicted boxlists are returned + losses (dict[Tensor]): During training, returns the losses for the + head. During testing, returns an empty dict. + """ + + if self.training: + # Faster R-CNN subsamples during training the proposals with a fixed + # positive / negative ratio + with torch.no_grad(): + proposals = self.loss_evaluator.subsample(proposals, targets) + + # extract features that will be fed to the final classifier. The + # feature_extractor generally corresponds to the pooler + heads + x = self.feature_extractor(features, proposals) + # final classifier that converts the features into predictions + class_logits, box_regression = self.predictor(x) + + if self.onnx: + return x, (class_logits, box_regression, [box.bbox for box in proposals]), {} + + if not self.training: + result = self.post_processor((class_logits, box_regression), proposals) + return x, result, {} + + loss_classifier, loss_box_reg = self.loss_evaluator( + [class_logits], [box_regression] + ) + return ( + x, + proposals, + dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg), + ) + + +def build_roi_box_head(cfg): + """ + Constructs a new box head. + By default, uses ROIBoxHead, but if it turns out not to be enough, just register a new class + and make it a parameter in the config + """ + return ROIBoxHead(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2f64bb3060e22388caf57c2496c7eb6f7a4cb7f4 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py @@ -0,0 +1,177 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
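+#
+# PostProcessor below turns raw box-head outputs into final detections:
+# softmax over the class logits, decode the box deltas w.r.t. the proposals,
+# apply the per-class score threshold and per-class NMS, then keep at most
+# DETECTIONS_PER_IMG boxes per image over all classes.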
+import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd + +class PostProcessor(nn.Module): + """ + From a set of classification scores, box regression and proposals, + computes the post-processed boxes, and applies NMS to obtain the + final results + """ + + def __init__( + self, score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=None + ): + """ + Arguments: + score_thresh (float) + nms (float) + detections_per_img (int) + box_coder (BoxCoder) + """ + super(PostProcessor, self).__init__() + self.score_thresh = score_thresh + self.nms = nms + self.detections_per_img = detections_per_img + if box_coder is None: + box_coder = BoxCoder(weights=(10., 10., 5., 5.)) + self.box_coder = box_coder + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, x, boxes): + """ + Arguments: + x (tuple[tensor, tensor]): x contains the class logits + and the box_regression from the model. + boxes (list[BoxList]): bounding boxes that are used as + reference, one for ech image + + Returns: + results (list[BoxList]): one BoxList for each image, containing + the extra fields labels and scores + """ + class_logits, box_regression = x + class_prob = F.softmax(class_logits, -1) + + # TODO think about a representation of batch of boxes + image_shapes = [box.size for box in boxes] + boxes_per_image = [len(box) for box in boxes] + concat_boxes = torch.cat([a.bbox for a in boxes], dim=0) + + extra_fields = [{} for box in boxes] + if boxes[0].has_field("cbox"): + concat_cboxes = torch.cat([a.get_field('cbox').bbox for a in boxes], dim=0) + concat_cscores = torch.cat([a.get_field('cbox').get_field('scores') for a in boxes], dim=0) + for cbox, cscore, extra_field in zip(concat_cboxes.split(boxes_per_image, dim=0), + concat_cscores.split(boxes_per_image, dim=0), + extra_fields): + extra_field["cbox"] = cbox + extra_field["cscore"] = cscore + + proposals = self.box_coder.decode( + box_regression.view(sum(boxes_per_image), -1), concat_boxes + ) + + num_classes = class_prob.shape[1] + + proposals = proposals.split(boxes_per_image, dim=0) + class_prob = class_prob.split(boxes_per_image, dim=0) + + results = [] + for prob, boxes_per_img, image_shape, extra_field in zip( + class_prob, proposals, image_shapes, extra_fields + ): + boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape, extra_field) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = self.filter_results(boxlist, num_classes) + results.append(boxlist) + return results + + def prepare_boxlist(self, boxes, scores, image_shape, extra_field={}): + """ + Returns BoxList from `boxes` and adds probability scores information + as an extra field + `boxes` has shape (#detections, 4 * #classes), where each row represents + a list of predicted bounding boxes for each of the object classes in the + dataset (including the background class). The detections in each row + originate from the same object proposal. + `scores` has shape (#detection, #classes), where each row represents a list + of object detection confidence scores for each of the object classes in the + dataset (including the background class). `scores[i, j]`` corresponds to the + box at `boxes[i, j * 4:(j + 1) * 4]`. 
+ """ + boxes = boxes.reshape(-1, 4) + scores = scores.reshape(-1) + boxlist = BoxList(boxes, image_shape, mode="xyxy") + boxlist.add_field("scores", scores) + for key, val in extra_field.items(): + boxlist.add_field(key, val) + return boxlist + + def filter_results(self, boxlist, num_classes): + """Returns bounding-box detection results by thresholding on scores and + applying non-maximum suppression (NMS). + """ + # unwrap the boxlist to avoid additional overhead. + # if we had multi-class NMS, we could perform this directly on the boxlist + boxes = boxlist.bbox.reshape(-1, num_classes * 4) + scores = boxlist.get_field("scores").reshape(-1, num_classes) + if boxlist.has_field('cbox'): + cboxes = boxlist.get_field("cbox").reshape(-1, 4) + cscores = boxlist.get_field("cscore") + else: + cboxes = None + + device = scores.device + result = [] + # Apply threshold on detection probabilities and apply NMS + # Skip j = 0, because it's the background class + inds_all = scores > self.score_thresh + for j in range(1, num_classes): + inds = inds_all[:, j].nonzero().squeeze(1) + scores_j = scores[inds, j] + boxes_j = boxes[inds, j * 4 : (j + 1) * 4] + boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") + boxlist_for_class.add_field("scores", scores_j) + if cboxes is not None: + cboxes_j = cboxes[inds, :] + cscores_j = cscores[inds] + cbox_boxlist = BoxList(cboxes_j, boxlist.size, mode="xyxy") + cbox_boxlist.add_field("scores", cscores_j) + boxlist_for_class.add_field("cbox", cbox_boxlist) + + boxlist_for_class = boxlist_nms( + boxlist_for_class, self.nms, score_field="scores" + ) + num_labels = len(boxlist_for_class) + boxlist_for_class.add_field( + "labels", torch.full((num_labels,), j, dtype=torch.int64, device=device) + ) + result.append(boxlist_for_class) + + result = cat_boxlist(result) + number_of_detections = len(result) + + # Limit to max_per_image detections **over all classes** + if number_of_detections > self.detections_per_img > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = torch.kthvalue( + cls_scores.cpu(), number_of_detections - self.detections_per_img + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + return result + + +def make_roi_box_post_processor(cfg): + use_fpn = cfg.MODEL.ROI_HEADS.USE_FPN + + bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS + box_coder = BoxCoder(weights=bbox_reg_weights) + + score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH + nms_thresh = cfg.MODEL.ROI_HEADS.NMS + detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG + + postprocessor = PostProcessor( + score_thresh, nms_thresh, detections_per_img, box_coder + ) + return postprocessor diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d7592981fdf236c086d70b455967cf12ad0d275e --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+import torch +from torch.nn import functional as F + +from maskrcnn_benchmark.layers import smooth_l1_loss +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from maskrcnn_benchmark.modeling.matcher import Matcher +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import ( + BalancedPositiveNegativeSampler +) +from maskrcnn_benchmark.modeling.utils import cat +from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd + +class FastRCNNLossComputation(object): + """ + Computes the loss for Faster R-CNN. + Also supports FPN + """ + + def __init__(self, proposal_matcher, fg_bg_sampler, box_coder): + """ + Arguments: + proposal_matcher (Matcher) + fg_bg_sampler (BalancedPositiveNegativeSampler) + box_coder (BoxCoder) + """ + self.proposal_matcher = proposal_matcher + self.fg_bg_sampler = fg_bg_sampler + self.box_coder = box_coder + + def match_targets_to_proposals(self, proposal, target): + match_quality_matrix = boxlist_iou(target, proposal) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # Fast RCNN only need "labels" field for selecting the targets + target = target.copy_with_fields("labels") + # get the targets corresponding GT for each proposal + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + + if len(target): + matched_targets = target[matched_idxs.clamp(min=0)] + else: + device = target.get_field('labels').device + dtype = target.get_field('labels').dtype + labels = torch.zeros_like(matched_idxs, dtype=dtype, device=device) + matched_targets = target + matched_targets.add_field('labels', labels) + + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, proposals, targets): + labels = [] + regression_targets = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + matched_targets = self.match_targets_to_proposals( + proposals_per_image, targets_per_image + ) + matched_idxs = matched_targets.get_field("matched_idxs") + + labels_per_image = matched_targets.get_field("labels") + labels_per_image = labels_per_image.to(dtype=torch.int64) + + # Label background (below the low threshold) + bg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD + labels_per_image[bg_inds] = 0 + + # Label ignore proposals (between low and high thresholds) + ignore_inds = matched_idxs == Matcher.BETWEEN_THRESHOLDS + labels_per_image[ignore_inds] = -1 # -1 is ignored by sampler + + # compute regression targets + if not matched_targets.bbox.shape[0]: + zeros = torch.zeros_like(labels_per_image, dtype=torch.float32) + regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1) + else: + regression_targets_per_image = self.box_coder.encode(matched_targets.bbox, proposals_per_image.bbox) + + labels.append(labels_per_image) + regression_targets.append(regression_targets_per_image) + + return labels, regression_targets + + def subsample(self, proposals, targets): + """ + This method performs the positive/negative sampling, and return + the sampled proposals. + Note: this function keeps a state. 
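+        (the sampled proposals are cached in self._proposals and later read by
+        __call__ when the losses are computed).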
+ + Arguments: + proposals (list[BoxList]) + targets (list[BoxList]) + """ + + labels, regression_targets = self.prepare_targets(proposals, targets) + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) + + proposals = list(proposals) + # add corresponding label and regression_targets information to the bounding boxes + for labels_per_image, regression_targets_per_image, proposals_per_image in zip( + labels, regression_targets, proposals + ): + proposals_per_image.add_field("labels", labels_per_image) + proposals_per_image.add_field( + "regression_targets", regression_targets_per_image + ) + + # distributed sampled proposals, that were obtained on all feature maps + # concatenated via the fg_bg_sampler, into individual feature map levels + for img_idx, (pos_inds_img, neg_inds_img) in enumerate( + zip(sampled_pos_inds, sampled_neg_inds) + ): + img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1) + proposals_per_image = proposals[img_idx][img_sampled_inds] + proposals[img_idx] = proposals_per_image + + self._proposals = proposals + return proposals + + @custom_fwd(cast_inputs=torch.float32) + def __call__(self, class_logits, box_regression): + """ + Computes the loss for Faster R-CNN. + This requires that the subsample method has been called beforehand. + + Arguments: + class_logits (list[Tensor]) + box_regression (list[Tensor]) + + Returns: + classification_loss (Tensor) + box_loss (Tensor) + """ + + class_logits = cat(class_logits, dim=0) + box_regression = cat(box_regression, dim=0) + device = class_logits.device + + if not hasattr(self, "_proposals"): + raise RuntimeError("subsample needs to be called before") + + proposals = self._proposals + + labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0) + regression_targets = cat( + [proposal.get_field("regression_targets") for proposal in proposals], dim=0 + ) + + classification_loss = F.cross_entropy(class_logits, labels) + + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1) + labels_pos = labels[sampled_pos_inds_subset] + map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3], device=device) + + box_loss = smooth_l1_loss( + box_regression[sampled_pos_inds_subset[:, None], map_inds], + regression_targets[sampled_pos_inds_subset], + size_average=False, + beta=1, + ) + box_loss = box_loss / labels.numel() + + return classification_loss, box_loss + + +def make_roi_box_loss_evaluator(cfg): + matcher = Matcher( + cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD, + cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD, + allow_low_quality_matches=False, + ) + + bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS + box_coder = BoxCoder(weights=bbox_reg_weights) + + fg_bg_sampler = BalancedPositiveNegativeSampler( + cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION + ) + + loss_evaluator = FastRCNNLossComputation(matcher, fg_bg_sampler, box_coder) + + return loss_evaluator diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..8614c78d8f7a85874175f82eff042eb793e44c4b --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_feature_extractors.py @@ -0,0 +1,201 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved. +import torch +from torch import nn +from torch.nn import functional as F + +from maskrcnn_benchmark.modeling import registry +from maskrcnn_benchmark.modeling.backbone import resnet +from maskrcnn_benchmark.modeling.poolers import Pooler +from maskrcnn_benchmark.modeling.make_layers import group_norm +from maskrcnn_benchmark.modeling.make_layers import make_fc + + + +@registry.ROI_BOX_FEATURE_EXTRACTORS.register("LightheadFeatureExtractor") +class LightheadFeatureExtractor(nn.Module): + def __init__(self, cfg): + super(LightheadFeatureExtractor, self).__init__() + + resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = 10 * resolution ** 2 + representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM + use_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN + + C_in, C_mid, C_out = cfg.MODEL.BACKBONE.OUT_CHANNELS, 256, input_size + self.separable_conv_11 = nn.Conv2d(C_in, C_mid, (15, 1), 1, (7, 0)) + self.separable_conv_12 = nn.Conv2d(C_mid, C_out, (1, 15), 1, (0, 7)) + self.separable_conv_21 = nn.Conv2d(C_in, C_mid, (15, 1), 1, (7, 0)) + self.separable_conv_22 = nn.Conv2d(C_mid, C_out, (1, 15), 1, (0, 7)) + + for module in [self.separable_conv_11, self.separable_conv_12, self.separable_conv_21, self.separable_conv_22]: + # Caffe2 implementation uses XavierFill, which in fact + # corresponds to kaiming_uniform_ in PyTorch + nn.init.kaiming_uniform_(module.weight, a=1) + + self.pooler = pooler + self.fc6 = make_fc(input_size * resolution ** 2, representation_size, use_gn) # wait official repo to support psroi + + + def forward(self, x, proposals): + light = [] + for feat in x: + sc11 = self.separable_conv_11(feat) + sc12 = self.separable_conv_12(sc11) + sc21 = self.separable_conv_21(feat) + sc22 = self.separable_conv_22(sc21) + out = sc12+sc22 + light.append(out) + + x = self.pooler(light, proposals) + x = x.view(x.size(0), -1) + x = F.relu(self.fc6(x)) + + return x + + + + +@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor") +class ResNet50Conv5ROIFeatureExtractor(nn.Module): + def __init__(self, config): + super(ResNet50Conv5ROIFeatureExtractor, self).__init__() + + resolution = config.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + scales = config.MODEL.ROI_BOX_HEAD.POOLER_SCALES + sampling_ratio = config.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + + stage = resnet.StageSpec(index=4, block_count=3, return_features=False) + head = resnet.ResNetHead( + block_module=config.MODEL.RESNETS.TRANS_FUNC, + stages=(stage,), + num_groups=config.MODEL.RESNETS.NUM_GROUPS, + width_per_group=config.MODEL.RESNETS.WIDTH_PER_GROUP, + stride_in_1x1=config.MODEL.RESNETS.STRIDE_IN_1X1, + stride_init=None, + res2_out_channels=config.MODEL.RESNETS.RES2_OUT_CHANNELS, + dilation=config.MODEL.RESNETS.RES5_DILATION + ) + + self.pooler = pooler + self.head = head + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + x = self.head(x) + return x + + +@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPN2MLPFeatureExtractor") +class FPN2MLPFeatureExtractor(nn.Module): + """ + Heads for FPN for classification + """ + + def __init__(self, cfg): + super(FPN2MLPFeatureExtractor, self).__init__() + + resolution = 
cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS * resolution ** 2 + representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM + use_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN + self.pooler = pooler + self.fc6 = make_fc(input_size, representation_size, use_gn) + self.fc7 = make_fc(representation_size, representation_size, use_gn) + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + x = x.view(x.size(0), -1) + + x = F.relu(self.fc6(x)) + x = F.relu(self.fc7(x)) + + return x + + +@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPNXconv1fcFeatureExtractor") +class FPNXconv1fcFeatureExtractor(nn.Module): + """ + Heads for FPN for classification + """ + + def __init__(self, cfg): + super(FPNXconv1fcFeatureExtractor, self).__init__() + + resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + scales = cfg.MODEL.ROI_BOX_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + self.pooler = pooler + + use_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + conv_head_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_HEAD_DIM + num_stacked_convs = cfg.MODEL.ROI_BOX_HEAD.NUM_STACKED_CONVS + dilation = cfg.MODEL.ROI_BOX_HEAD.DILATION + + xconvs = [] + for ix in range(num_stacked_convs): + xconvs.append( + nn.Conv2d( + in_channels, + conv_head_dim, + kernel_size=3, + stride=1, + padding=dilation, + dilation=dilation, + bias=False if use_gn else True + ) + ) + in_channels = conv_head_dim + if use_gn: + xconvs.append(group_norm(in_channels)) + xconvs.append(nn.ReLU(inplace=True)) + + self.add_module("xconvs", nn.Sequential(*xconvs)) + for modules in [self.xconvs,]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + if not use_gn: + torch.nn.init.constant_(l.bias, 0) + + input_size = conv_head_dim * resolution ** 2 + representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM + self.fc6 = make_fc(input_size, representation_size, use_gn=False) + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + x = self.xconvs(x) + x = x.view(x.size(0), -1) + x = F.relu(self.fc6(x)) + return x + + +def make_roi_box_feature_extractor(cfg): + func = registry.ROI_BOX_FEATURE_EXTRACTORS[ + cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR + ] + return func(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py new file mode 100644 index 0000000000000000000000000000000000000000..ac03cfaece2e47900fc04b58e173f6dea6423caa --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/roi_box_predictors.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
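+#
+# Both predictors map each per-RoI feature vector to `num_classes` class
+# scores and `num_classes * 4` box deltas; e.g. with 81 classes (COCO plus
+# background, an illustrative value) the outputs have shapes (num_rois, 81)
+# and (num_rois, 324).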
+from torch import nn + + +class FastRCNNPredictor(nn.Module): + def __init__(self, config, pretrained=None): + super(FastRCNNPredictor, self).__init__() + + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = config.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES + self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7) + self.cls_score = nn.Linear(num_inputs, num_classes) + self.bbox_pred = nn.Linear(num_inputs, num_classes * 4) + + nn.init.normal_(self.cls_score.weight, mean=0, std=0.01) + nn.init.constant_(self.cls_score.bias, 0) + + nn.init.normal_(self.bbox_pred.weight, mean=0, std=0.001) + nn.init.constant_(self.bbox_pred.bias, 0) + + def forward(self, x): + x = self.avgpool(x) + x = x.view(x.size(0), -1) + cls_logit = self.cls_score(x) + bbox_pred = self.bbox_pred(x) + return cls_logit, bbox_pred + + +class FPNPredictor(nn.Module): + def __init__(self, cfg): + super(FPNPredictor, self).__init__() + num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES + representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM + + self.cls_score = nn.Linear(representation_size, num_classes) + self.bbox_pred = nn.Linear(representation_size, num_classes * 4) + + nn.init.normal_(self.cls_score.weight, std=0.01) + nn.init.normal_(self.bbox_pred.weight, std=0.001) + for l in [self.cls_score, self.bbox_pred]: + nn.init.constant_(l.bias, 0) + + def forward(self, x): + scores = self.cls_score(x) + bbox_deltas = self.bbox_pred(x) + + return scores, bbox_deltas + + +_ROI_BOX_PREDICTOR = { + "FastRCNNPredictor": FastRCNNPredictor, + "FPNPredictor": FPNPredictor, +} + + +def make_roi_box_predictor(cfg): + func = _ROI_BOX_PREDICTOR[cfg.MODEL.ROI_BOX_HEAD.PREDICTOR] + return func(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ed960ff37cf1c68ac8831fdb87b82c91203ec2 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/inference.py @@ -0,0 +1,121 @@ +import cv2 +import numpy as np +import torch +from torch import nn + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.keypoint import PersonKeypoints + + +class KeypointPostProcessor(nn.Module): + def __init__(self, keypointer=None): + super(KeypointPostProcessor, self).__init__() + self.keypointer = keypointer + + def forward(self, x, boxes): + mask_prob = x + + scores = None + if self.keypointer: + mask_prob, scores = self.keypointer(x, boxes) + + assert len(boxes) == 1, "Only non-batched inference supported for now" + boxes_per_image = [box.bbox.size(0) for box in boxes] + mask_prob = mask_prob.split(boxes_per_image, dim=0) + scores = scores.split(boxes_per_image, dim=0) + + results = [] + for prob, box, score in zip(mask_prob, boxes, scores): + bbox = BoxList(box.bbox, box.size, mode="xyxy") + for field in box.fields(): + bbox.add_field(field, box.get_field(field)) + prob = PersonKeypoints(prob, box.size) + prob.add_field("logits", score) + bbox.add_field("keypoints", prob) + results.append(bbox) + + return results + + +def heatmaps_to_keypoints(maps, rois): + """Extract predicted keypoint locations from heatmaps. Output has shape + (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob) + for each keypoint. 
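+    `maps` is expected to be an (N, K, H, W) array of per-keypoint heatmaps and
+    `rois` an (N, 4) array of boxes in (x1, y1, x2, y2) format.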
+ """ + # This function converts a discrete image coordinate in a HEATMAP_SIZE x + # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain + # consistency with keypoints_to_heatmap_labels by using the conversion from + # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a + # continuous coordinate. + offset_x = rois[:, 0] + offset_y = rois[:, 1] + + widths = rois[:, 2] - rois[:, 0] + heights = rois[:, 3] - rois[:, 1] + widths = np.maximum(widths, 1) + heights = np.maximum(heights, 1) + widths_ceil = np.ceil(widths) + heights_ceil = np.ceil(heights) + + # NCHW to NHWC for use with OpenCV + maps = np.transpose(maps, [0, 2, 3, 1]) + min_size = 0 # cfg.KRCNN.INFERENCE_MIN_SIZE + num_keypoints = maps.shape[3] + xy_preds = np.zeros((len(rois), 3, num_keypoints), dtype=np.float32) + end_scores = np.zeros((len(rois), num_keypoints), dtype=np.float32) + for i in range(len(rois)): + if min_size > 0: + roi_map_width = int(np.maximum(widths_ceil[i], min_size)) + roi_map_height = int(np.maximum(heights_ceil[i], min_size)) + else: + roi_map_width = widths_ceil[i] + roi_map_height = heights_ceil[i] + width_correction = widths[i] / roi_map_width + height_correction = heights[i] / roi_map_height + roi_map = cv2.resize( + maps[i], (roi_map_width, roi_map_height), interpolation=cv2.INTER_CUBIC + ) + # Bring back to CHW + roi_map = np.transpose(roi_map, [2, 0, 1]) + # roi_map_probs = scores_to_probs(roi_map.copy()) + w = roi_map.shape[2] + pos = roi_map.reshape(num_keypoints, -1).argmax(axis=1) + x_int = pos % w + y_int = (pos - x_int) // w + # assert (roi_map_probs[k, y_int, x_int] == + # roi_map_probs[k, :, :].max()) + x = (x_int + 0.5) * width_correction + y = (y_int + 0.5) * height_correction + xy_preds[i, 0, :] = x + offset_x[i] + xy_preds[i, 1, :] = y + offset_y[i] + xy_preds[i, 2, :] = 1 + end_scores[i, :] = roi_map[np.arange(num_keypoints), y_int, x_int] + + return np.transpose(xy_preds, [0, 2, 1]), end_scores + + +class Keypointer(object): + """ + Projects a set of masks in an image on the locations + specified by the bounding boxes + """ + + def __init__(self, padding=0): + self.padding = padding + + def __call__(self, masks, boxes): + # TODO do this properly + if isinstance(boxes, BoxList): + boxes = [boxes] + assert len(boxes) == 1 + + result, scores = heatmaps_to_keypoints( + masks.detach().cpu().numpy(), boxes[0].bbox.cpu().numpy() + ) + return torch.from_numpy(result).to(masks.device), torch.as_tensor(scores, device=masks.device) + + +def make_roi_keypoint_post_processor(cfg): + keypointer = Keypointer() + keypoint_post_processor = KeypointPostProcessor(keypointer) + return keypoint_post_processor \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1414782ab2a42bd1161c8496c434406df12619d6 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/keypoint_head.py @@ -0,0 +1,50 @@ +import torch + +from .roi_keypoint_feature_extractors import make_roi_keypoint_feature_extractor +from .roi_keypoint_predictors import make_roi_keypoint_predictor +from .inference import make_roi_keypoint_post_processor +from .loss import make_roi_keypoint_loss_evaluator + + +class ROIKeypointHead(torch.nn.Module): + def __init__(self, cfg): + super(ROIKeypointHead, self).__init__() + self.cfg = cfg.clone() + self.feature_extractor = make_roi_keypoint_feature_extractor(cfg) + 
self.predictor = make_roi_keypoint_predictor(cfg) + self.post_processor = make_roi_keypoint_post_processor(cfg) + self.loss_evaluator = make_roi_keypoint_loss_evaluator(cfg) + + def forward(self, features, proposals, targets=None): + """ + Arguments: + features (list[Tensor]): feature-maps from possibly several levels + proposals (list[BoxList]): proposal boxes + targets (list[BoxList], optional): the ground-truth targets. + + Returns: + x (Tensor): the result of the feature extractor + proposals (list[BoxList]): during training, the original proposals + are returned. During testing, the predicted boxlists are returned + with the `mask` field set + losses (dict[Tensor]): During training, returns the losses for the + head. During testing, returns an empty dict. + """ + if self.training: + with torch.no_grad(): + proposals = self.loss_evaluator.subsample(proposals, targets) + + x = self.feature_extractor(features, proposals) + kp_logits = self.predictor(x) + + if not self.training: + result = self.post_processor(kp_logits, proposals) + return x, result, {} + + loss_kp = self.loss_evaluator(proposals, kp_logits) + + return x, proposals, dict(loss_kp=loss_kp) + + +def build_roi_keypoint_head(cfg): + return ROIKeypointHead(cfg) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..53716c80281f6e8e767552f061d91d486027831e --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/loss.py @@ -0,0 +1,183 @@ +import torch +from torch.nn import functional as F + +from maskrcnn_benchmark.modeling.matcher import Matcher + +from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import ( + BalancedPositiveNegativeSampler, +) +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from maskrcnn_benchmark.modeling.utils import cat +from maskrcnn_benchmark.layers import smooth_l1_loss +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist + +from maskrcnn_benchmark.structures.keypoint import keypoints_to_heat_map + + +def project_keypoints_to_heatmap(keypoints, proposals, discretization_size): + proposals = proposals.convert("xyxy") + return keypoints_to_heat_map( + keypoints.keypoints, proposals.bbox, discretization_size + ) + + +def cat_boxlist_with_keypoints(boxlists): + assert all(boxlist.has_field("keypoints") for boxlist in boxlists) + + kp = [boxlist.get_field("keypoints").keypoints for boxlist in boxlists] + kp = cat(kp, 0) + + fields = boxlists[0].get_fields() + fields = [field for field in fields if field != "keypoints"] + + boxlists = [boxlist.copy_with_fields(fields) for boxlist in boxlists] + boxlists = cat_boxlist(boxlists) + boxlists.add_field("keypoints", kp) + return boxlists + + +def _within_box(points, boxes): + """Validate which keypoints are contained inside a given box. 
+ points: NxKx2 + boxes: Nx4 + output: NxK + """ + x_within = (points[..., 0] >= boxes[:, 0, None]) & ( + points[..., 0] <= boxes[:, 2, None] + ) + y_within = (points[..., 1] >= boxes[:, 1, None]) & ( + points[..., 1] <= boxes[:, 3, None] + ) + return x_within & y_within + + +class KeypointRCNNLossComputation(object): + def __init__(self, proposal_matcher, fg_bg_sampler, discretization_size): + """ + Arguments: + proposal_matcher (Matcher) + fg_bg_sampler (BalancedPositiveNegativeSampler) + discretization_size (int) + """ + self.proposal_matcher = proposal_matcher + self.fg_bg_sampler = fg_bg_sampler + self.discretization_size = discretization_size + + def match_targets_to_proposals(self, proposal, target): + match_quality_matrix = boxlist_iou(target, proposal) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # Keypoint RCNN needs "labels" and "keypoints "fields for creating the targets + target = target.copy_with_fields(["labels", "keypoints"]) + # get the targets corresponding GT for each proposal + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_targets = target[matched_idxs.clamp(min=0)] + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, proposals, targets): + labels = [] + keypoints = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + matched_targets = self.match_targets_to_proposals( + proposals_per_image, targets_per_image + ) + matched_idxs = matched_targets.get_field("matched_idxs") + + labels_per_image = matched_targets.get_field("labels") + labels_per_image = labels_per_image.to(dtype=torch.int64) + + # this can probably be removed, but is left here for clarity + # and completeness + # TODO check if this is the right one, as BELOW_THRESHOLD + neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD + labels_per_image[neg_inds] = 0 + + keypoints_per_image = matched_targets.get_field("keypoints") + within_box = _within_box( + keypoints_per_image.keypoints, matched_targets.bbox + ) + vis_kp = keypoints_per_image.keypoints[..., 2] > 0 + is_visible = (within_box & vis_kp).sum(1) > 0 + + labels_per_image[~is_visible] = -1 + + labels.append(labels_per_image) + keypoints.append(keypoints_per_image) + + return labels, keypoints + + def subsample(self, proposals, targets): + """ + This method performs the positive/negative sampling, and return + the sampled proposals. + Note: this function keeps a state. 
+ + Arguments: + proposals (list[BoxList]) + targets (list[BoxList]) + """ + + labels, keypoints = self.prepare_targets(proposals, targets) + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) + + proposals = list(proposals) + # add corresponding label and regression_targets information to the bounding boxes + for labels_per_image, keypoints_per_image, proposals_per_image in zip( + labels, keypoints, proposals + ): + proposals_per_image.add_field("labels", labels_per_image) + proposals_per_image.add_field("keypoints", keypoints_per_image) + + # distributed sampled proposals, that were obtained on all feature maps + # concatenated via the fg_bg_sampler, into individual feature map levels + for img_idx, (pos_inds_img, neg_inds_img) in enumerate( + zip(sampled_pos_inds, sampled_neg_inds) + ): + img_sampled_inds = torch.nonzero(pos_inds_img).squeeze(1) + proposals_per_image = proposals[img_idx][img_sampled_inds] + proposals[img_idx] = proposals_per_image + + self._proposals = proposals + return proposals + + def __call__(self, proposals, keypoint_logits): + heatmaps = [] + valid = [] + for proposals_per_image in proposals: + kp = proposals_per_image.get_field("keypoints") + heatmaps_per_image, valid_per_image = project_keypoints_to_heatmap( + kp, proposals_per_image, self.discretization_size + ) + heatmaps.append(heatmaps_per_image.view(-1)) + valid.append(valid_per_image.view(-1)) + + keypoint_targets = cat(heatmaps, dim=0) + valid = cat(valid, dim=0).to(dtype=torch.bool) + valid = torch.nonzero(valid).squeeze(1) + + # torch.mean (in binary_cross_entropy_with_logits) does'nt + # accept empty tensors, so handle it sepaartely + if keypoint_targets.numel() == 0 or len(valid) == 0: + return keypoint_logits.sum() * 0 + + N, K, H, W = keypoint_logits.shape + keypoint_logits = keypoint_logits.view(N * K, H * W) + + keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid]) + return keypoint_loss + + +def make_roi_keypoint_loss_evaluator(cfg): + matcher = Matcher( + cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD, + cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD, + allow_low_quality_matches=False, + ) + fg_bg_sampler = BalancedPositiveNegativeSampler( + cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION + ) + resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.RESOLUTION + loss_evaluator = KeypointRCNNLossComputation(matcher, fg_bg_sampler, resolution) + return loss_evaluator \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b4b90be3efebf777871399b7dca821fec60a45 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_feature_extractors.py @@ -0,0 +1,96 @@ +from torch import nn +from torch.nn import functional as F + +from maskrcnn_benchmark.modeling.poolers import Pooler + +from maskrcnn_benchmark.layers import Conv2d +from maskrcnn_benchmark.layers import ConvTranspose2d + + +class KeypointRCNNFeatureExtractor(nn.Module): + def __init__(self, cfg): + super(KeypointRCNNFeatureExtractor, self).__init__() + + resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION + scales = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + 
self.pooler = pooler + + input_features = cfg.MODEL.BACKBONE.OUT_CHANNELS + layers = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS + next_feature = input_features + self.blocks = [] + for layer_idx, layer_features in enumerate(layers, 1): + layer_name = "conv_fcn{}".format(layer_idx) + module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1) + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + nn.init.constant_(module.bias, 0) + self.add_module(layer_name, module) + next_feature = layer_features + self.blocks.append(layer_name) + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + for layer_name in self.blocks: + x = F.relu(getattr(self, layer_name)(x)) + return x + +class KeypointRCNNFeature2XZoomExtractor(nn.Module): + def __init__(self, cfg): + super(KeypointRCNNFeature2XZoomExtractor, self).__init__() + + resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION + scales = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + self.pooler = pooler + + input_features = cfg.MODEL.BACKBONE.OUT_CHANNELS + layers = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS + next_feature = input_features + self.blocks = [] + for layer_idx, layer_features in enumerate(layers, 1): + layer_name = "conv_fcn{}".format(layer_idx) + module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1) + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + nn.init.constant_(module.bias, 0) + self.add_module(layer_name, module) + if layer_idx==len(layers)//2: + deconv_kernel = 4 + kps_upsacle = ConvTranspose2d(layer_features, layer_features, deconv_kernel, + stride=2, padding=deconv_kernel//2-1) + nn.init.kaiming_normal_(kps_upsacle.weight, mode="fan_out", nonlinearity="relu") + nn.init.constant_(kps_upsacle.bias, 0) + self.add_module("conv_fcn_upscale", kps_upsacle) + self.blocks.append("conv_fcn_upscale") + + next_feature = layer_features + self.blocks.append(layer_name) + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + for layer_name in self.blocks: + x = F.relu(getattr(self, layer_name)(x)) + return x + + +_ROI_KEYPOINT_FEATURE_EXTRACTORS = { + "KeypointRCNNFeatureExtractor": KeypointRCNNFeatureExtractor, + "KeypointRCNNFeature2XZoomExtractor": KeypointRCNNFeature2XZoomExtractor +} + + +def make_roi_keypoint_feature_extractor(cfg): + func = _ROI_KEYPOINT_FEATURE_EXTRACTORS[ + cfg.MODEL.ROI_KEYPOINT_HEAD.FEATURE_EXTRACTOR + ] + return func(cfg) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff8ec3849695737580d5b2da2b411c489216a1a --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/keypoint_head/roi_keypoint_predictors.py @@ -0,0 +1,39 @@ +from torch import nn +from torch.nn import functional as F + +from maskrcnn_benchmark import layers + + +class KeypointRCNNPredictor(nn.Module): + def __init__(self, cfg): + super(KeypointRCNNPredictor, self).__init__() + input_features = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_LAYERS[-1] + num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_CLASSES + deconv_kernel = 4 + self.kps_score_lowres = layers.ConvTranspose2d( + input_features, + num_keypoints, + deconv_kernel, + stride=2, + 
padding=deconv_kernel // 2 - 1, + ) + nn.init.kaiming_normal_( + self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu" + ) + nn.init.constant_(self.kps_score_lowres.bias, 0) + self.up_scale = 2 + + def forward(self, x): + x = self.kps_score_lowres(x) + x = layers.interpolate( + x, scale_factor=self.up_scale, mode="bilinear", align_corners=False + ) + return x + + +_ROI_KEYPOINT_PREDICTOR = {"KeypointRCNNPredictor": KeypointRCNNPredictor} + + +def make_roi_keypoint_predictor(cfg): + func = _ROI_KEYPOINT_PREDICTOR[cfg.MODEL.ROI_KEYPOINT_HEAD.PREDICTOR] + return func(cfg) \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/__init__.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/hourglass.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/hourglass.py new file mode 100644 index 0000000000000000000000000000000000000000..82e81b6697536ff23f8b88f7ea1d89da8d8c28e1 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/hourglass.py @@ -0,0 +1,65 @@ +from torch import nn + +from maskrcnn_benchmark.modeling.make_layers import make_conv3x3 + + +class Residual(nn.Module): + def __init__(self, inp_dim, out_dim, use_gn=False): + super(Residual, self).__init__() + self.relu = nn.ReLU() + # self.bn1 = nn.BatchNorm2d(inp_dim) + self.conv1 = make_conv3x3(inp_dim, int(out_dim / 2), 1, use_relu=False, use_gn=use_gn) + # self.bn2 = nn.BatchNorm2d(int(out_dim / 2)) + self.conv2 = make_conv3x3(int(out_dim / 2), int(out_dim / 2), 3, use_relu=False, use_gn=use_gn) + # self.bn3 = nn.BatchNorm2d(int(out_dim / 2)) + self.conv3 = make_conv3x3(int(out_dim / 2), out_dim, 1, use_relu=False, use_gn=use_gn) + if inp_dim == out_dim: + self.need_skip = False + else: + self.need_skip = True + self.skip_layer = make_conv3x3(inp_dim, out_dim, 1, use_relu=False, use_gn=False) + + def forward(self, x): + if self.need_skip: + residual = self.skip_layer(x) + else: + residual = x + out = x + # out = self.bn1(out) + out = self.relu(out) + out = self.conv1(out) + # out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + # out = self.bn3(out) + out = self.relu(out) + out = self.conv3(out) + out += residual + return out + + +class Hourglass(nn.Module): + def __init__(self, n, f, gn=False, increase=0): + super(Hourglass, self).__init__() + nf = f + increase + self.up1 = Residual(f, f) + # Lower branch + self.pool1 = nn.MaxPool2d(2, 2) + self.low1 = Residual(f, nf) + self.n = n + # Recursive hourglass + if self.n > 1: + self.low2 = Hourglass(n-1, nf, gn=gn) + else: + self.low2 = Residual(nf, nf, gn) + self.low3 = Residual(nf, f, gn) + self.up2 = nn.Upsample(scale_factor=2, mode='nearest') + + def forward(self, x): + up1 = self.up1(x) + pool1 = self.pool1(x) + low1 = self.low1(pool1) + low2 = self.low2(low1) + low3 = self.low3(low2) + up2 = self.up2(low3) + return up1 + up2 \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2c02b78101a6d33db1917377080265206548c7cf --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py @@ -0,0 +1,224 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
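
# Illustrative standalone sketch (toy channel counts): the keypoint predictor
# above upsamples twice -- a ConvTranspose2d with kernel 4 / stride 2 /
# padding 1 doubles the spatial size ((in - 1)*2 - 2 + 4 = 2*in), then bilinear
# interpolation with scale_factor=2 doubles it again, turning a 14x14 pooled
# ROI feature into a 56x56 per-keypoint heatmap.
import torch
from torch import nn
import torch.nn.functional as F

deconv = nn.ConvTranspose2d(256, 17, kernel_size=4, stride=2, padding=4 // 2 - 1)
x = torch.randn(8, 256, 14, 14)                       # 8 ROIs
x = deconv(x)                                         # -> (8, 17, 28, 28)
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
print(x.shape)                                        # torch.Size([8, 17, 56, 56])
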
+import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from maskrcnn_benchmark.structures.bounding_box import BoxList + + +def convert_mask_grounding_to_od_logits(logits, positive_map_label_to_token, num_classes): + od_logits = torch.zeros(logits.shape[0], num_classes + 1, logits.shape[2], logits.shape[3]).to(logits.device) + for label_j in positive_map_label_to_token: + od_logits[:, label_j, :, :] = logits[:, torch.LongTensor(positive_map_label_to_token[label_j]), :, :].mean(1) + mask_prob = od_logits.sigmoid() + return mask_prob + + +# TODO check if want to return a single BoxList or a composite +# object +class MaskPostProcessor(nn.Module): + """ + From the results of the CNN, post process the masks + by taking the mask corresponding to the class with max + probability (which are of fixed size and directly output + by the CNN) and return the masks in the mask field of the BoxList. + + If a masker object is passed, it will additionally + project the masks in the image according to the locations in boxes, + """ + + def __init__(self, masker=None, mdetr_style_aggregate_class_num=None, vl_version=None): + super(MaskPostProcessor, self).__init__() + self.masker = masker + self.mdetr_style_aggregate_class_num = mdetr_style_aggregate_class_num + self.vl_version = vl_version + + def forward(self, x, boxes, positive_map_label_to_token=None): + """ + Arguments: + x (Tensor): the mask logits + boxes (list[BoxList]): bounding boxes that are used as + reference, one for ech image + + Returns: + results (list[BoxList]): one BoxList for each image, containing + the extra field mask + """ + if self.vl_version: + mask_prob = convert_mask_grounding_to_od_logits(x, positive_map_label_to_token, self.mdetr_style_aggregate_class_num) + else: + mask_prob = x.sigmoid() + + # select masks coresponding to the predicted classes + num_masks = x.shape[0] + labels = [bbox.get_field("labels") for bbox in boxes] + labels = torch.cat(labels) + if not self.vl_version: + # TODO: a hack for binary mask head + labels = (labels > 0).to(dtype=torch.int64) + + index = torch.arange(num_masks, device=labels.device) + mask_prob = mask_prob[index, labels][:, None] + + boxes_per_image = [len(box) for box in boxes] + mask_prob = mask_prob.split(boxes_per_image, dim=0) + + if self.masker: + mask_prob = self.masker(mask_prob, boxes) + + results = [] + for prob, box in zip(mask_prob, boxes): + bbox = BoxList(box.bbox, box.size, mode="xyxy") + for field in box.fields(): + bbox.add_field(field, box.get_field(field)) + bbox.add_field("mask", prob) + results.append(bbox) + + return results + + +class MaskPostProcessorCOCOFormat(MaskPostProcessor): + """ + From the results of the CNN, post process the results + so that the masks are pasted in the image, and + additionally convert the results to COCO format. 
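
# Illustrative standalone sketch (toy sizes): how MaskPostProcessor keeps, for
# every detection, only the mask channel of its predicted label. mask_prob is
# (num_masks, num_classes, M, M); indexing with (arange, labels) picks one
# M x M mask per detection and [:, None] restores a channel dimension.
import torch

num_masks, num_classes, M = 4, 3, 7
mask_prob = torch.rand(num_masks, num_classes, M, M)
labels = torch.tensor([2, 0, 1, 2])                  # predicted class per detection

index = torch.arange(num_masks)
selected = mask_prob[index, labels][:, None]
print(selected.shape)                                # torch.Size([4, 1, 7, 7])
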
+ """ + + def forward(self, x, boxes, positive_map_label_to_token=None, vl_version=None): + import pycocotools.mask as mask_util + import numpy as np + + results = super(MaskPostProcessorCOCOFormat, self).forward(x, boxes) + for result in results: + masks = result.get_field("mask").cpu() + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + result.add_field("mask", rles) + return results + + +# the next two functions should be merged inside Masker +# but are kept here for the moment while we need them +# temporarily gor paste_mask_in_image +def expand_boxes(boxes, scale): + w_half = (boxes[:, 2] - boxes[:, 0]) * .5 + h_half = (boxes[:, 3] - boxes[:, 1]) * .5 + x_c = (boxes[:, 2] + boxes[:, 0]) * .5 + y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + + w_half *= scale + h_half *= scale + + boxes_exp = torch.zeros_like(boxes) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + return boxes_exp + + +def expand_masks(mask, padding): + N = mask.shape[0] + M = mask.shape[-1] + pad2 = 2 * padding + scale = float(M + pad2) / M + padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2)) + padded_mask[:, :, padding:-padding, padding:-padding] = mask + return padded_mask, scale + + +def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1): + padded_mask, scale = expand_masks(mask[None], padding=padding) + mask = padded_mask[0, 0] + box = expand_boxes(box[None], scale)[0] + box = box.to(dtype=torch.int32) + + TO_REMOVE = 1 + w = int(box[2] - box[0] + TO_REMOVE) + h = int(box[3] - box[1] + TO_REMOVE) + w = max(w, 1) + h = max(h, 1) + + # Set shape to [batchxCxHxW] + mask = mask.expand((1, 1, -1, -1)) + + # Resize mask + mask = mask.to(torch.float32) + mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) + mask = mask[0][0] + + if thresh >= 0: + mask = mask > thresh + else: + # for visualization and debugging, we also + # allow it to return an unmodified mask + mask = (mask * 255).to(torch.bool) + + im_mask = torch.zeros((im_h, im_w), dtype=torch.bool) + x_0 = max(box[0], 0) + x_1 = min(box[2] + 1, im_w) + y_0 = max(box[1], 0) + y_1 = min(box[3] + 1, im_h) + + im_mask[y_0:y_1, x_0:x_1] = mask[ + (y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0]) + ] + return im_mask + + +class Masker(object): + """ + Projects a set of masks in an image on the locations + specified by the bounding boxes + """ + + def __init__(self, threshold=0.5, padding=1): + self.threshold = threshold + self.padding = padding + + def forward_single_image(self, masks, boxes): + boxes = boxes.convert("xyxy") + im_w, im_h = boxes.size + res = [ + paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding) + for mask, box in zip(masks, boxes.bbox) + ] + if len(res) > 0: + res = torch.stack(res, dim=0)[:, None] + else: + res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1])) + return res + + def __call__(self, masks, boxes): + if isinstance(boxes, BoxList): + boxes = [boxes] + + # Make some sanity check + assert len(boxes) == len(masks), "Masks and boxes should have the same length." + + # TODO: Is this JIT compatible? + # If not we should make it compatible. + results = [] + for mask, box in zip(masks, boxes): + assert mask.shape[0] == len(box), "Number of objects should be the same." 
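
# Illustrative standalone sketch of the core of paste_mask_in_image: resize a
# fixed-resolution M x M mask to the box size with bilinear interpolation,
# threshold it, and write it into an image-sized canvas. The padding/box
# expansion and image-border clipping of the real function are omitted here,
# and all values are toy numbers.
import torch
import torch.nn.functional as F

M, im_h, im_w = 28, 60, 80
mask = torch.rand(M, M)                              # probabilities from the mask head
x1, y1, x2, y2 = 10, 20, 39, 49                      # box in xyxy, inside the image

w, h = x2 - x1 + 1, y2 - y1 + 1
resized = F.interpolate(mask[None, None], size=(h, w),
                        mode="bilinear", align_corners=False)[0, 0]
im_mask = torch.zeros((im_h, im_w), dtype=torch.bool)
im_mask[y1:y2 + 1, x1:x2 + 1] = resized > 0.5
print(im_mask.sum())                                 # number of foreground pixels
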
+ result = self.forward_single_image(mask, box) + results.append(result) + return results + + +def make_roi_mask_post_processor(cfg): + if cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS: + mask_threshold = cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD + masker = Masker(threshold=mask_threshold, padding=1) + else: + masker = None + mdetr_style_aggregate_class_num = cfg.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM + mask_post_processor = MaskPostProcessor(masker, + mdetr_style_aggregate_class_num, + vl_version=cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL")) + return mask_post_processor diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..22edb57fd8b67e370f819ba5d8a2a37df8c2a6f7 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/loss.py @@ -0,0 +1,179 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch.nn import functional as F + +from maskrcnn_benchmark.layers import smooth_l1_loss +from maskrcnn_benchmark.modeling.matcher import Matcher +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from maskrcnn_benchmark.modeling.utils import cat + + +def project_masks_on_boxes(segmentation_masks, proposals, discretization_size): + """ + Given segmentation masks and the bounding boxes corresponding + to the location of the masks in the image, this function + crops and resizes the masks in the position defined by the + boxes. This prepares the masks for them to be fed to the + loss computation as the targets. + + Arguments: + segmentation_masks: an instance of SegmentationMask + proposals: an instance of BoxList + """ + masks = [] + M = discretization_size + device = proposals.bbox.device + proposals = proposals.convert("xyxy") + assert segmentation_masks.size == proposals.size, "{}, {}".format( + segmentation_masks, proposals + ) + # TODO put the proposals on the CPU, as the representation for the + # masks is not efficient GPU-wise (possibly several small tensors for + # representing a single instance mask) + proposals = proposals.bbox.to(torch.device("cpu")) + for segmentation_mask, proposal in zip(segmentation_masks, proposals): + # crop the masks, resize them to the desired resolution and + # then convert them to the tensor representation, + # instead of the list representation that was used + cropped_mask = segmentation_mask.crop(proposal) + scaled_mask = cropped_mask.resize((M, M)) + mask = scaled_mask.convert(mode="mask") + masks.append(mask) + if len(masks) == 0: + return torch.empty(0, dtype=torch.float32, device=device) + return torch.stack(masks, dim=0).to(device, dtype=torch.float32) + + +class MaskRCNNLossComputation(object): + def __init__(self, proposal_matcher, discretization_size, vl_version=False): + """ + Arguments: + proposal_matcher (Matcher) + discretization_size (int) + """ + self.proposal_matcher = proposal_matcher + self.discretization_size = discretization_size + self.vl_version = vl_version + + def match_targets_to_proposals(self, proposal, target): + match_quality_matrix = boxlist_iou(target, proposal) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # Mask RCNN needs "labels" and "masks "fields for creating the targets + if self.vl_version: + target = target.copy_with_fields(["positive_map", "masks"]) + else: + target = target.copy_with_fields(["labels", "masks"]) + # get the targets corresponding GT for each proposal + # NB: need to clamp 
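
# Illustrative standalone sketch of why matched_idxs is clamped before indexing:
# proposals without a good match get negative sentinel values from the Matcher
# (the comments above mention e.g. -2), and using them directly would silently
# index from the end of the target list. Clamping to 0 makes the gather safe,
# and the negatives are zeroed out afterwards via the labels. Toy values below.
import torch

gt_labels = torch.tensor([7, 3])                 # two ground-truth objects
matched_idxs = torch.tensor([0, 1, -1, -2])      # last two proposals are unmatched
matched_labels = gt_labels[matched_idxs.clamp(min=0)]
matched_labels[matched_idxs < 0] = 0             # background / ignored proposals
print(matched_labels)                            # tensor([7, 3, 0, 0])
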
the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_targets = target[matched_idxs.clamp(min=0)] + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, proposals, targets): + labels = [] + masks = [] + positive_maps = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + matched_targets = self.match_targets_to_proposals( + proposals_per_image, targets_per_image + ) + matched_idxs = matched_targets.get_field("matched_idxs") + + if self.vl_version: + positive_maps_per_image = matched_targets.get_field("positive_map") + + # this can probably be removed, but is left here for clarity + # and completeness + neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD + positive_maps_per_image[neg_inds, :] = 0 + + positive_maps.append(positive_maps_per_image) + + # TODO: make sure for the softmax [NoObj] case + labels_per_image = positive_maps_per_image.sum(dim=-1) + labels_per_image = labels_per_image.to(dtype=torch.int64) + else: + labels_per_image = matched_targets.get_field("labels") + labels_per_image = labels_per_image.to(dtype=torch.int64) + + # this can probably be removed, but is left here for clarity + # and completeness + neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD + labels_per_image[neg_inds] = 0 + + # mask scores are only computed on positive samples + positive_inds = torch.nonzero(labels_per_image > 0).squeeze(1) + + segmentation_masks = matched_targets.get_field("masks") + segmentation_masks = segmentation_masks[positive_inds] + + positive_proposals = proposals_per_image[positive_inds] + + masks_per_image = project_masks_on_boxes( + segmentation_masks, positive_proposals, self.discretization_size + ) + + labels.append(labels_per_image) + masks.append(masks_per_image) + + return labels, masks, positive_maps + + def __call__(self, proposals, mask_logits, targets): + """ + Arguments: + proposals (list[BoxList]) + mask_logits (Tensor) + targets (list[BoxList]) + + Return: + mask_loss (Tensor): scalar tensor containing the loss + """ + labels, mask_targets, positive_maps = self.prepare_targets(proposals, targets) + + labels = cat(labels, dim=0) + mask_targets = cat(mask_targets, dim=0) + + positive_inds = torch.nonzero(labels > 0).squeeze(1) + labels_pos = labels[positive_inds] + # TODO: a hack for binary mask head + labels_pos = (labels_pos > 0).to(dtype=torch.int64) + + # torch.mean (in binary_cross_entropy_with_logits) doesn't + # accept empty tensors, so handle it separately + if mask_targets.numel() == 0: + return mask_logits.sum() * 0 + + if self.vl_version: + positive_maps = cat(positive_maps, dim=0) + mask_logits_pos = [] + for positive_ind in positive_inds: + positive_map = positive_maps[positive_ind] + # TODO: make sure for the softmax [NoObj] case + mask_logit_pos = mask_logits[positive_ind][torch.nonzero(positive_map).squeeze(1)].mean(dim=0, keepdim=True) + mask_logits_pos.append(mask_logit_pos) + mask_logits_pos = cat(mask_logits_pos, dim=0) + mask_loss = F.binary_cross_entropy_with_logits( + mask_logits_pos, mask_targets + ) + else: + mask_loss = F.binary_cross_entropy_with_logits( + mask_logits[positive_inds, labels_pos], mask_targets + ) + return mask_loss + + +def make_roi_mask_loss_evaluator(cfg): + matcher = Matcher( + cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD, + cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD, + allow_low_quality_matches=False, + ) + + loss_evaluator = MaskRCNNLossComputation( + matcher, 
cfg.MODEL.ROI_MASK_HEAD.RESOLUTION, + vl_version=cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL") + ) + + return loss_evaluator diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..aff8fb640becf3a455c1eaaf72ad72b5c079a5fe --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn + +from maskrcnn_benchmark.structures.bounding_box import BoxList + +from .roi_mask_feature_extractors import make_roi_mask_feature_extractor +from .roi_mask_predictors import make_roi_mask_predictor +from .inference import make_roi_mask_post_processor +from .loss import make_roi_mask_loss_evaluator + + +def keep_only_positive_boxes(boxes): + """ + Given a set of BoxList containing the `labels` field, + return a set of BoxList for which `labels > 0`. + + Arguments: + boxes (list of BoxList) + """ + assert isinstance(boxes, (list, tuple)) + assert isinstance(boxes[0], BoxList) + assert boxes[0].has_field("labels") + positive_boxes = [] + positive_inds = [] + num_boxes = 0 + for boxes_per_image in boxes: + labels = boxes_per_image.get_field("labels") + inds_mask = labels > 0 + inds = inds_mask.nonzero().squeeze(1) + positive_boxes.append(boxes_per_image[inds]) + positive_inds.append(inds_mask) + return positive_boxes, positive_inds + + +class ROIMaskHead(torch.nn.Module): + def __init__(self, cfg): + super(ROIMaskHead, self).__init__() + self.cfg = cfg.clone() + self.feature_extractor = make_roi_mask_feature_extractor(cfg) + self.predictor = make_roi_mask_predictor(cfg) + self.post_processor = make_roi_mask_post_processor(cfg) + self.loss_evaluator = make_roi_mask_loss_evaluator(cfg) + + def forward(self, features, proposals, targets=None, + language_dict_features=None, + positive_map_label_to_token=None + ): + """ + Arguments: + features (list[Tensor]): feature-maps from possibly several levels + proposals (list[BoxList]): proposal boxes + targets (list[BoxList], optional): the ground-truth targets. + language_dict_features: language features: hidden, embedding, mask, ... + + Returns: + x (Tensor): the result of the feature extractor + proposals (list[BoxList]): during training, the original proposals + are returned. During testing, the predicted boxlists are returned + with the `mask` field set + losses (dict[Tensor]): During training, returns the losses for the + head. During testing, returns an empty dict. 
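
# Illustrative standalone sketch (toy labels): keep_only_positive_boxes above
# restricts the mask branch to foreground proposals. The same pattern on a
# plain label tensor -- a boolean mask (kept for later feature sharing) plus
# the integer indices used to slice the per-image BoxList.
import torch

labels = torch.tensor([0, 3, 0, 1, 5])     # 0 = background proposal
inds_mask = labels > 0
inds = inds_mask.nonzero().squeeze(1)
print(inds_mask)                           # tensor([False,  True, False,  True,  True])
print(inds)                                # tensor([1, 3, 4])
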
+ """ + if self.training: + # during training, only focus on positive boxes + all_proposals = proposals + proposals, positive_inds = keep_only_positive_boxes(proposals) + if self.training and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR: + x = features + x = x[torch.cat(positive_inds, dim=0)] + else: + x = self.feature_extractor(features, proposals) + if self.cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL"): + mask_logits = self.predictor(x, language_dict_features) + else: + mask_logits = self.predictor(x) + + if not self.training: + result = self.post_processor(mask_logits, proposals, positive_map_label_to_token) + return x, result, {} + + loss_mask = self.loss_evaluator(proposals, mask_logits, targets) + + return x, all_proposals, dict(loss_mask=loss_mask) + + +def build_roi_mask_head(cfg): + return ROIMaskHead(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..c891feb22703e5d47be37ec20189c4c2bbd7c14c --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_feature_extractors.py @@ -0,0 +1,117 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from torch import nn +from torch.nn import functional as F + +from .hourglass import Hourglass +from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor +from maskrcnn_benchmark.modeling.poolers import Pooler +from maskrcnn_benchmark.layers import Conv2d +from maskrcnn_benchmark.modeling.make_layers import make_conv3x3 + + + +class MaskRCNNFPNFeatureExtractor(nn.Module): + """ + Heads for FPN for classification + """ + + def __init__(self, cfg): + """ + Arguments: + num_classes (int): number of output classes + input_size (int): number of channels of the input once it's flattened + representation_size (int): size of the intermediate representation + """ + super(MaskRCNNFPNFeatureExtractor, self).__init__() + + resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION + scales = cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS + self.pooler = pooler + + use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN + layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS + dilation = cfg.MODEL.ROI_MASK_HEAD.DILATION + + next_feature = input_size + self.blocks = [] + for layer_idx, layer_features in enumerate(layers, 1): + layer_name = "mask_fcn{}".format(layer_idx) + module = make_conv3x3(next_feature, layer_features, + dilation=dilation, stride=1, use_gn=use_gn + ) + self.add_module(layer_name, module) + next_feature = layer_features + self.blocks.append(layer_name) + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + + for layer_name in self.blocks: + x = F.relu(getattr(self, layer_name)(x)) + + return x + + +class HourglassFPNFeatureExtractor(nn.Module): + """ + Heads for FPN for classification + """ + + def __init__(self, cfg): + """ + Arguments: + num_classes (int): number of output classes + input_size (int): number of channels of the input once it's flattened + representation_size (int): size of the intermediate representation + """ + super(HourglassFPNFeatureExtractor, self).__init__() + + resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION + scales = 
cfg.MODEL.ROI_MASK_HEAD.POOLER_SCALES + sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS + self.pooler = pooler + + use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN + layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS + scale = cfg.MODEL.ROI_MASK_HEAD.HG_SCALE + + assert input_size==layers[0] + self.blocks = [] + for layer_idx, layer_features in enumerate(layers, 1): + layer_name = "mask_hg{}".format(layer_idx) + module = Hourglass(scale, layer_features, gn=use_gn) + self.add_module(layer_name, module) + self.blocks.append(layer_name) + + def forward(self, x, proposals): + x = self.pooler(x, proposals) + + for layer_name in self.blocks: + x = F.relu(getattr(self, layer_name)(x)) + + return x + + +_ROI_MASK_FEATURE_EXTRACTORS = { + "ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor, + "MaskRCNNFPNFeatureExtractor": MaskRCNNFPNFeatureExtractor, + "HourglassFPNFeatureExtractor": HourglassFPNFeatureExtractor, +} + + +def make_roi_mask_feature_extractor(cfg): + func = _ROI_MASK_FEATURE_EXTRACTORS[cfg.MODEL.ROI_MASK_HEAD.FEATURE_EXTRACTOR] + return func(cfg) diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py new file mode 100644 index 0000000000000000000000000000000000000000..6c7e7ff2fb28e3ee39750cbbec39a46539f0c455 --- /dev/null +++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/roi_mask_predictors.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +from torch import nn +from torch.nn import functional as F + +from maskrcnn_benchmark.layers import Conv2d, _NewEmptyTensorOp +from maskrcnn_benchmark.layers import ConvTranspose2d +from ...utils import permute_and_flatten + + +class MaskRCNNC4Predictor(nn.Module): + def __init__(self, cfg): + super(MaskRCNNC4Predictor, self).__init__() + # TODO: a hack for binary mask head + # num_classes = cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES + num_classes = 2 + dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + + if cfg.MODEL.ROI_HEADS.USE_FPN: + num_inputs = dim_reduced + else: + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + self.mask_fcn_logits = Conv2d(dim_reduced, num_classes, 1, 1, 0) + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + def forward(self, x): + x = F.relu(self.conv5_mask(x)) + return self.mask_fcn_logits(x) + + +class VLMaskRCNNC4Predictor(nn.Module): + def __init__(self, cfg): + super(VLMaskRCNNC4Predictor, self).__init__() + dim_reduced = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS[-1] + + if cfg.MODEL.ROI_HEADS.USE_FPN: + num_inputs = dim_reduced + else: + stage_index = 4 + stage2_relative_factor = 2 ** (stage_index - 1) + res2_out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + num_inputs = res2_out_channels * stage2_relative_factor + + self.conv5_mask = ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + + # self.mask_fcn_logits = Conv2d(dim_reduced, 
num_classes, 1, 1, 0) + log_scale = cfg.MODEL.DYHEAD.LOG_SCALE + self.out_dim = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN + self.dot_product_projection_image = nn.Identity() + self.dot_product_projection_text = nn.Linear(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, + dim_reduced, bias=True) + self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True) + self.bias_lang = nn.Parameter(torch.zeros(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM), requires_grad=True) + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + def forward(self, x, language_dict_features): + x = F.relu(self.conv5_mask(x)) + if x.numel() <= 0: + output_shape = [x.shape[0], self.out_dim] + x.shape[-2:] + return _NewEmptyTensorOp.apply(x, output_shape) + + embedding = language_dict_features["hidden"] + # norm + embedding = F.normalize(embedding, p=2, dim=-1) + dot_product_proj_tokens = self.dot_product_projection_text(embedding / 2.0) + dot_product_proj_tokens_bias = torch.matmul(embedding, self.bias_lang) + + B, C, H, W = x.shape + # add bias (language) + dot_product_proj_queries = self.dot_product_projection_image(x) + dot_product_proj_queries = permute_and_flatten(dot_product_proj_queries, B, -1, C, H, W) + A = dot_product_proj_queries.shape[1] + bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(1, A, 1) + + # dot product + dot_product_logit = (torch.matmul(dot_product_proj_queries, + dot_product_proj_tokens.transpose(-1, + -2)) / self.log_scale.exp()) + bias + # clamp for stability + dot_product_logit = torch.clamp(dot_product_logit, max=50000) + dot_product_logit = torch.clamp(dot_product_logit, min=-50000) + dot_product_logit = dot_product_logit.view(B, H, W, self.out_dim).permute(0, 3, 1, 2) + return dot_product_logit + + +_ROI_MASK_PREDICTOR = {"MaskRCNNC4Predictor": MaskRCNNC4Predictor, + "VLMaskRCNNC4Predictor": VLMaskRCNNC4Predictor} + + +def make_roi_mask_predictor(cfg): + func = _ROI_MASK_PREDICTOR[cfg.MODEL.ROI_MASK_HEAD.PREDICTOR] + return func(cfg) diff --git a/maskrcnn_benchmark/modeling/rpn/__init__.py b/maskrcnn_benchmark/modeling/rpn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49e6ed6932c3011f357a1c9a97fce632ae9e6eb3 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# from .rpn import build_rpn +from .rpn import RPNModule +from .retina import RetinaNetModule +from .fcos import FCOSModule +from .atss import ATSSModule +from .dyhead import DyHeadModule +from .vldyhead import VLDyHeadModule + +_RPN_META_ARCHITECTURES = {"RPN": RPNModule, + "RETINA": RetinaNetModule, + "FCOS": FCOSModule, + "ATSS": ATSSModule, + "DYHEAD": DyHeadModule, + "VLDYHEAD": VLDyHeadModule + } + + +def build_rpn(cfg): + """ + This gives the gist of it. Not super important because it doesn't change as much + """ + rpn_arch = _RPN_META_ARCHITECTURES[cfg.MODEL.RPN_ARCHITECTURE] + return rpn_arch(cfg) diff --git a/maskrcnn_benchmark/modeling/rpn/anchor_generator.py b/maskrcnn_benchmark/modeling/rpn/anchor_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c396730280e0f8f549872ad0403de36a2a626321 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/anchor_generator.py @@ -0,0 +1,425 @@ +# Copyright (c) Facebook, Inc. 
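
# Illustrative standalone sketch (toy dimensions): the vision-language mask
# predictor above scores every mask-head pixel against every text token with a
# dot product. Pixels are flattened to (B, H*W, C), token embeddings are
# projected to the same C, and the matmul gives (B, H*W, T), reshaped into
# per-token heatmaps (B, T, H, W). The linear projection here is a random
# stand-in, and the log-scale temperature and language bias are omitted.
import torch
import torch.nn.functional as F
from torch import nn

B, C, H, W, T, lang_dim = 2, 256, 28, 28, 16, 768
pixel_feats = torch.randn(B, C, H, W)
token_embed = F.normalize(torch.randn(B, T, lang_dim), p=2, dim=-1)

proj_text = nn.Linear(lang_dim, C)
tokens = proj_text(token_embed)                            # (B, T, C)
queries = pixel_feats.flatten(2).transpose(1, 2)           # (B, H*W, C)

logits = torch.matmul(queries, tokens.transpose(-1, -2))   # (B, H*W, T)
logits = logits.clamp(min=-50000, max=50000)               # stability clamp, as above
logits = logits.view(B, H, W, T).permute(0, 3, 1, 2)       # (B, T, H, W)
print(logits.shape)                                        # torch.Size([2, 16, 28, 28])
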
and its affiliates. All Rights Reserved. +import math + +import numpy as np +import torch +from torch import nn + +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.image_list import ImageList +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist + +class BufferList(nn.Module): + """ + Similar to nn.ParameterList, but for buffers + """ + + def __init__(self, buffers=None): + super(BufferList, self).__init__() + if buffers is not None: + self.extend(buffers) + + def extend(self, buffers): + offset = len(self) + for i, buffer in enumerate(buffers): + self.register_buffer(str(offset + i), buffer) + return self + + def __len__(self): + return len(self._buffers) + + def __iter__(self): + return iter(self._buffers.values()) + + +class AnchorGenerator(nn.Module): + """ + For a set of image sizes and feature maps, computes a set + of anchors + """ + + def __init__( + self, + sizes=(128, 256, 512), + aspect_ratios=(0.5, 1.0, 2.0), + anchor_strides=(8, 16, 32), + straddle_thresh=0, + ): + super(AnchorGenerator, self).__init__() + + if len(anchor_strides) == 1: + anchor_stride = anchor_strides[0] + cell_anchors = [ + generate_anchors(anchor_stride, sizes, aspect_ratios).float() + ] + else: + if len(anchor_strides) != len(sizes): + raise RuntimeError("FPN should have #anchor_strides == #sizes") + cell_anchors = [ + generate_anchors( + anchor_stride, + size if isinstance(size, (tuple, list)) else (size,), + aspect_ratios + ).float() + for anchor_stride, size in zip(anchor_strides, sizes) + ] + self.strides = anchor_strides + self.cell_anchors = BufferList(cell_anchors) + self.straddle_thresh = straddle_thresh + + def num_anchors_per_location(self): + return [len(cell_anchors) for cell_anchors in self.cell_anchors] + + def grid_anchors(self, grid_sizes): + anchors = [] + for size, stride, base_anchors in zip( + grid_sizes, self.strides, self.cell_anchors + ): + grid_height, grid_width = size + device = base_anchors.device + shifts_x = torch.arange( + 0, grid_width * stride, step=stride, dtype=torch.float32, device=device + ) + shifts_y = torch.arange( + 0, grid_height * stride, step=stride, dtype=torch.float32, device=device + ) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + + anchors.append( + (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) + ) + + return anchors + + def add_visibility_to(self, boxlist): + image_width, image_height = boxlist.size + anchors = boxlist.bbox + if self.straddle_thresh >= 0: + inds_inside = ( + (anchors[..., 0] >= -self.straddle_thresh) + & (anchors[..., 1] >= -self.straddle_thresh) + & (anchors[..., 2] < image_width + self.straddle_thresh) + & (anchors[..., 3] < image_height + self.straddle_thresh) + ) + else: + device = anchors.device + inds_inside = torch.ones(anchors.shape[0], dtype=torch.bool, device=device) + boxlist.add_field("visibility", inds_inside) + + def forward(self, image_list, feature_maps): + grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] + anchors_over_all_feature_maps = self.grid_anchors(grid_sizes) + anchors = [] + if isinstance(image_list, ImageList): + for i, (image_height, image_width) in enumerate(image_list.image_sizes): + anchors_in_image = [] + for anchors_per_feature_map in anchors_over_all_feature_maps: + boxlist = BoxList( + anchors_per_feature_map, (image_width, image_height), mode="xyxy" + ) + 
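
# Illustrative standalone sketch (toy numbers): grid_anchors above slides the
# cell anchors over the feature map by adding an (x, y, x, y) shift per cell.
# With a 2x3 feature map, stride 16 and A base anchors, broadcasting
# (H*W, 1, 4) + (1, A, 4) yields H*W*A anchors in xyxy form.
import torch

stride, grid_h, grid_w = 16, 2, 3
base_anchors = torch.tensor([[-8., -8., 8., 8.],       # A = 2 toy cell anchors
                             [-16., -8., 16., 8.]])

shifts_x = torch.arange(0, grid_w * stride, step=stride, dtype=torch.float32)
shifts_y = torch.arange(0, grid_h * stride, step=stride, dtype=torch.float32)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shifts = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1),
                      shift_x.reshape(-1), shift_y.reshape(-1)), dim=1)

anchors = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
print(anchors.shape)        # torch.Size([12, 4]) == (grid_h * grid_w * A, 4)
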
self.add_visibility_to(boxlist) + anchors_in_image.append(boxlist) + anchors.append(anchors_in_image) + else: + image_height, image_width = [int(x) for x in image_list.size()[-2:]] + anchors_in_image = [] + for anchors_per_feature_map in anchors_over_all_feature_maps: + boxlist = BoxList( + anchors_per_feature_map, (image_width, image_height), mode="xyxy" + ) + self.add_visibility_to(boxlist) + anchors_in_image.append(boxlist) + anchors.append(anchors_in_image) + return anchors + + +def make_anchor_generator(config): + anchor_sizes = config.MODEL.RPN.ANCHOR_SIZES + aspect_ratios = config.MODEL.RPN.ASPECT_RATIOS + anchor_stride = config.MODEL.RPN.ANCHOR_STRIDE + straddle_thresh = config.MODEL.RPN.STRADDLE_THRESH + + if config.MODEL.RPN.USE_FPN: + assert len(anchor_stride) == len( + anchor_sizes + ), "FPN should have len(ANCHOR_STRIDE) == len(ANCHOR_SIZES)" + else: + assert len(anchor_stride) == 1, "Non-FPN should have a single ANCHOR_STRIDE" + anchor_generator = AnchorGenerator( + anchor_sizes, aspect_ratios, anchor_stride, straddle_thresh + ) + return anchor_generator + + +def make_anchor_generator_complex(config): + anchor_sizes = config.MODEL.RPN.ANCHOR_SIZES + aspect_ratios = config.MODEL.RPN.ASPECT_RATIOS + anchor_strides = config.MODEL.RPN.ANCHOR_STRIDE + straddle_thresh = config.MODEL.RPN.STRADDLE_THRESH + octave = config.MODEL.RPN.OCTAVE + scales_per_octave = config.MODEL.RPN.SCALES_PER_OCTAVE + + if config.MODEL.RPN.USE_FPN: + assert len(anchor_strides) == len(anchor_sizes), "Only support FPN now" + new_anchor_sizes = [] + for size in anchor_sizes: + per_layer_anchor_sizes = [] + for scale_per_octave in range(scales_per_octave): + octave_scale = octave ** (scale_per_octave / float(scales_per_octave)) + per_layer_anchor_sizes.append(octave_scale * size) + new_anchor_sizes.append(tuple(per_layer_anchor_sizes)) + else: + assert len(anchor_strides) == 1, "Non-FPN should have a single ANCHOR_STRIDE" + new_anchor_sizes = anchor_sizes + + anchor_generator = AnchorGenerator( + tuple(new_anchor_sizes), aspect_ratios, anchor_strides, straddle_thresh + ) + return anchor_generator + + +class CenterAnchorGenerator(nn.Module): + """ + For a set of image sizes and feature maps, computes a set + of anchors + """ + + def __init__( + self, + sizes=(128, 256, 512), + aspect_ratios=(0.5, 1.0, 2.0), + anchor_strides=(8, 16, 32), + straddle_thresh=0, + anchor_shift=(0.0, 0.0, 0.0, 0.0), + use_relative=False + ): + super(CenterAnchorGenerator, self).__init__() + + self.sizes = sizes + self.aspect_ratios = aspect_ratios + self.strides = anchor_strides + self.straddle_thresh = straddle_thresh + self.anchor_shift = anchor_shift + self.use_relative = use_relative + + def add_visibility_to(self, boxlist): + image_width, image_height = boxlist.size + anchors = boxlist.bbox + if self.straddle_thresh >= 0: + inds_inside = ( + (anchors[..., 0] >= -self.straddle_thresh) + & (anchors[..., 1] >= -self.straddle_thresh) + & (anchors[..., 2] < image_width + self.straddle_thresh) + & (anchors[..., 3] < image_height + self.straddle_thresh) + ) + else: + device = anchors.device + inds_inside = torch.ones(anchors.shape[0], dtype=torch.uint8, device=device) + boxlist.add_field("visibility", inds_inside) + + def forward(self, centers, image_sizes, feature_maps): + shift_left, shift_top, shift_right, shift_down = self.anchor_shift + grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] + anchors = [] + for i, ((image_height, image_width), center_bbox) in enumerate(zip(image_sizes, centers)): + center = 
center_bbox.get_field("centers") + boxlist_per_level = [] + for size, fsize in zip(self.sizes, grid_sizes): + for ratios in self.aspect_ratios: + + size_ratios = size*size / ratios + ws = np.round(np.sqrt(size_ratios)) + hs = np.round(ws * ratios) + + anchors_per_level = torch.cat( + ( + center[:,0,None] - 0.5 * (1 + shift_left) * (ws - 1), + center[:,1,None] - 0.5 * (1 + shift_top) * (hs - 1), + center[:,0,None] + 0.5 * (1 + shift_right) * (ws - 1), + center[:,1,None] + 0.5 * (1 + shift_down) * (hs - 1), + ), + dim=1 + ) + boxlist = BoxList(anchors_per_level, (image_width, image_height), mode="xyxy") + boxlist.add_field('cbox', center_bbox) + self.add_visibility_to(boxlist) + boxlist_per_level.append(boxlist) + if self.use_relative: + area = center_bbox.area() + for ratios in self.aspect_ratios: + + size_ratios = area / ratios + ws = torch.round(torch.sqrt(size_ratios)) + hs = torch.round(ws * ratios) + + anchors_per_level = torch.stack( + ( + center[:,0] - (1 + shift_left) * ws, + center[:,1] - (1 + shift_top) * hs, + center[:,0] + (1 + shift_right) * ws, + center[:,1] + (1 + shift_down) * hs, + ), + dim=1 + ) + boxlist = BoxList(anchors_per_level, (image_width, image_height), mode="xyxy") + boxlist.add_field('cbox', center_bbox) + self.add_visibility_to(boxlist) + boxlist_per_level.append(boxlist) + anchors_in_image = cat_boxlist(boxlist_per_level) + anchors.append(anchors_in_image) + return anchors + + +def make_center_anchor_generator(config): + anchor_sizes = config.MODEL.RPN.ANCHOR_SIZES + aspect_ratios = config.MODEL.RPN.ASPECT_RATIOS + anchor_strides = config.MODEL.RPN.ANCHOR_STRIDE + straddle_thresh = config.MODEL.RPN.STRADDLE_THRESH + octave = config.MODEL.RPN.OCTAVE + scales_per_octave = config.MODEL.RPN.SCALES_PER_OCTAVE + anchor_shift = config.MODEL.RPN.ANCHOR_SHIFT + use_relative = config.MODEL.RPN.USE_RELATIVE_SIZE + + if config.MODEL.RPN.USE_FPN: + assert len(anchor_strides) == len(anchor_sizes), "Only support FPN now" + new_anchor_sizes = [] + for size in anchor_sizes: + per_layer_anchor_sizes = [] + for scale_per_octave in range(scales_per_octave): + octave_scale = octave ** (scale_per_octave / float(scales_per_octave)) + per_layer_anchor_sizes.append(octave_scale * size) + new_anchor_sizes.append(tuple(per_layer_anchor_sizes)) + else: + assert len(anchor_strides) == 1, "Non-FPN should have a single ANCHOR_STRIDE" + new_anchor_sizes = anchor_sizes + + anchor_generator = CenterAnchorGenerator( + tuple(new_anchor_sizes), aspect_ratios, anchor_strides, straddle_thresh, anchor_shift, use_relative + ) + return anchor_generator + +# Copyright (c) 2017-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
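
# Illustrative standalone sketch (hypothetical config values): with OCTAVE = 2
# and SCALES_PER_OCTAVE = 3, the loop above gives every FPN level anchors at
# size, size * 2^(1/3) and size * 2^(2/3), i.e. three scales spanning one
# octave before the next level takes over at 2 * size.
anchor_sizes = (32, 64, 128)        # hypothetical per-level base sizes
octave, scales_per_octave = 2, 3

new_anchor_sizes = []
for size in anchor_sizes:
    per_layer = tuple(octave ** (i / float(scales_per_octave)) * size
                      for i in range(scales_per_octave))
    new_anchor_sizes.append(per_layer)
print(new_anchor_sizes[0])          # (32.0, ~40.3, ~50.8)
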
+############################################################################## +# +# Based on: +# -------------------------------------------------------- +# Faster R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick and Sean Bell +# -------------------------------------------------------- + + +# Verify that we compute the same anchors as Shaoqing's matlab implementation: +# +# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat +# >> anchors +# +# anchors = +# +# -83 -39 100 56 +# -175 -87 192 104 +# -359 -183 376 200 +# -55 -55 72 72 +# -119 -119 136 136 +# -247 -247 264 264 +# -35 -79 52 96 +# -79 -167 96 184 +# -167 -343 184 360 + +# array([[ -83., -39., 100., 56.], +# [-175., -87., 192., 104.], +# [-359., -183., 376., 200.], +# [ -55., -55., 72., 72.], +# [-119., -119., 136., 136.], +# [-247., -247., 264., 264.], +# [ -35., -79., 52., 96.], +# [ -79., -167., 96., 184.], +# [-167., -343., 184., 360.]]) + + +def generate_anchors( + stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2) +): + """Generates a matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors + are centered on stride / 2, have (approximate) sqrt areas of the specified + sizes, and aspect ratios as given. + """ + return _generate_anchors( + stride, + np.array(sizes, dtype=np.float) / stride, + np.array(aspect_ratios, dtype=np.float), + ) + + +def _generate_anchors(base_size, scales, aspect_ratios): + """Generate anchor (reference) windows by enumerating aspect ratios X + scales wrt a reference (0, 0, base_size - 1, base_size - 1) window. + """ + anchor = np.array([1, 1, base_size, base_size], dtype=np.float) - 1 + anchors = _ratio_enum(anchor, aspect_ratios) + anchors = np.vstack( + [_scale_enum(anchors[i, :], scales) for i in range(anchors.shape[0])] + ) + return torch.from_numpy(anchors) + + +def _whctrs(anchor): + """Return width, height, x center, and y center for an anchor (window).""" + w = anchor[2] - anchor[0] + 1 + h = anchor[3] - anchor[1] + 1 + x_ctr = anchor[0] + 0.5 * (w - 1) + y_ctr = anchor[1] + 0.5 * (h - 1) + return w, h, x_ctr, y_ctr + + +def _mkanchors(ws, hs, x_ctr, y_ctr): + """Given a vector of widths (ws) and heights (hs) around a center + (x_ctr, y_ctr), output a set of anchors (windows). 
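
# Illustrative standalone sketch (worked numbers): _ratio_enum below keeps the
# anchor area approximately fixed (up to rounding) while changing its aspect
# ratio h/w. For the 16x16 base window (area 256), ratio 0.5 produces a 23x12
# window and ratio 2 produces 11x22.
import numpy as np

size = 16 * 16
for ratio in (0.5, 1.0, 2.0):
    ws = np.round(np.sqrt(size / ratio))
    hs = np.round(ws * ratio)
    print(ratio, ws, hs, ws * hs)   # 0.5 -> 23, 12, 276 ; 1.0 -> 16, 16, 256 ; 2.0 -> 11, 22, 242
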
+ """ + ws = ws[:, np.newaxis] + hs = hs[:, np.newaxis] + anchors = np.hstack( + ( + x_ctr - 0.5 * (ws - 1), + y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), + y_ctr + 0.5 * (hs - 1), + ) + ) + return anchors + + +def _ratio_enum(anchor, ratios): + """Enumerate a set of anchors for each aspect ratio wrt an anchor.""" + w, h, x_ctr, y_ctr = _whctrs(anchor) + size = w * h + size_ratios = size / ratios + ws = np.round(np.sqrt(size_ratios)) + hs = np.round(ws * ratios) + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors + + +def _scale_enum(anchor, scales): + """Enumerate a set of anchors for each scale wrt an anchor.""" + w, h, x_ctr, y_ctr = _whctrs(anchor) + ws = w * scales + hs = h * scales + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors diff --git a/maskrcnn_benchmark/modeling/rpn/atss.py b/maskrcnn_benchmark/modeling/rpn/atss.py new file mode 100644 index 0000000000000000000000000000000000000000..f1132522ad4491477c0dd320d43b78351daf69b4 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/atss.py @@ -0,0 +1,233 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn + +from .inference import make_atss_postprocessor +from .loss import make_atss_loss_evaluator + +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.layers import Scale, DFConv2d, DYReLU, SELayer +from .anchor_generator import make_anchor_generator_complex + + +class BoxCoder(object): + + def __init__(self, cfg): + self.cfg = cfg + + def encode(self, gt_boxes, anchors): + + TO_REMOVE = 1 # TODO remove + ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE + ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE + ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2 + ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2 + + gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE + gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE + gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2 + gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2 + + wx, wy, ww, wh = (10., 10., 5., 5.) + targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights + targets_dw = ww * torch.log(gt_widths / ex_widths) + targets_dh = wh * torch.log(gt_heights / ex_heights) + targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + + return targets + + def decode(self, preds, anchors): + + anchors = anchors.to(preds.dtype) + + TO_REMOVE = 1 # TODO remove + widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE + heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE + ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2 + ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2 + + wx, wy, ww, wh = (10., 10., 5., 5.) + dx = preds[:, 0::4] / wx + dy = preds[:, 1::4] / wy + dw = preds[:, 2::4] / ww + dh = preds[:, 3::4] / wh + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=math.log(1000. / 16)) + dh = torch.clamp(dh, max=math.log(1000. 
/ 16)) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = torch.exp(dw) * widths[:, None] + pred_h = torch.exp(dh) * heights[:, None] + + pred_boxes = torch.zeros_like(preds) + pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1) + pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1) + pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1) + pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1) + + return pred_boxes + + +class ATSSHead(torch.nn.Module): + def __init__(self, cfg): + super(ATSSHead, self).__init__() + self.cfg = cfg + num_classes = cfg.MODEL.ATSS.NUM_CLASSES - 1 + num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + channels = cfg.MODEL.ATSS.CHANNELS + use_gn = cfg.MODEL.ATSS.USE_GN + use_bn = cfg.MODEL.ATSS.USE_BN + use_dcn_in_tower = cfg.MODEL.ATSS.USE_DFCONV + use_dyrelu = cfg.MODEL.ATSS.USE_DYRELU + use_se = cfg.MODEL.ATSS.USE_SE + + cls_tower = [] + bbox_tower = [] + for i in range(cfg.MODEL.ATSS.NUM_CONVS): + if use_dcn_in_tower and \ + i == cfg.MODEL.ATSS.NUM_CONVS - 1: + conv_func = DFConv2d + else: + conv_func = nn.Conv2d + + cls_tower.append( + conv_func( + in_channels if i==0 else channels, + channels, + kernel_size=3, + stride=1, + padding=1, + bias=True + ) + ) + if use_gn: + cls_tower.append(nn.GroupNorm(32, channels)) + if use_bn: + cls_tower.append(nn.BatchNorm2d(channels)) + if use_se: + cls_tower.append(SELayer(channels)) + if use_dyrelu: + cls_tower.append(DYReLU(channels, channels)) + else: + cls_tower.append(nn.ReLU()) + + bbox_tower.append( + conv_func( + in_channels if i == 0 else channels, + channels, + kernel_size=3, + stride=1, + padding=1, + bias=True + ) + ) + if use_gn: + bbox_tower.append(nn.GroupNorm(32, channels)) + if use_bn: + bbox_tower.append(nn.BatchNorm2d(channels)) + if use_se: + bbox_tower.append(SELayer(channels)) + if use_dyrelu: + bbox_tower.append(DYReLU(channels, channels)) + else: + bbox_tower.append(nn.ReLU()) + + self.add_module('cls_tower', nn.Sequential(*cls_tower)) + self.add_module('bbox_tower', nn.Sequential(*bbox_tower)) + self.cls_logits = nn.Conv2d( + channels, num_anchors * num_classes, kernel_size=3, stride=1, + padding=1 + ) + self.bbox_pred = nn.Conv2d( + channels, num_anchors * 4, kernel_size=3, stride=1, + padding=1 + ) + self.centerness = nn.Conv2d( + channels, num_anchors * 1, kernel_size=3, stride=1, + padding=1 + ) + + # initialization + for modules in [self.cls_tower, self.bbox_tower, + self.cls_logits, self.bbox_pred, + self.centerness]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + # initialize the bias for focal loss + prior_prob = cfg.MODEL.ATSS.PRIOR_PROB + bias_value = -math.log((1 - prior_prob) / prior_prob) + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + + self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)]) + + def forward(self, x): + logits = [] + bbox_reg = [] + centerness = [] + for l, feature in enumerate(x): + cls_tower = self.cls_tower(feature) + box_tower = self.bbox_tower(feature) + + logits.append(self.cls_logits(cls_tower)) + + bbox_pred = self.scales[l](self.bbox_pred(box_tower)) + bbox_reg.append(bbox_pred) + + centerness.append(self.centerness(box_tower)) + return logits, bbox_reg, centerness + + +class ATSSModule(torch.nn.Module): + + def __init__(self, cfg): + super(ATSSModule, self).__init__() + self.cfg = cfg 
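
# Illustrative standalone sketch: the (10, 10, 5, 5)-weighted box coder above is
# exactly invertible -- encoding a ground-truth box against an anchor and then
# decoding recovers the ground-truth box (up to float error). Minimal
# re-implementation of the same formulas for a single box pair, toy coordinates.
import torch

def encode(gt, anchor, weights=(10., 10., 5., 5.)):
    wx, wy, ww, wh = weights
    aw, ah = anchor[2] - anchor[0] + 1, anchor[3] - anchor[1] + 1
    acx, acy = (anchor[2] + anchor[0]) / 2, (anchor[3] + anchor[1]) / 2
    gw, gh = gt[2] - gt[0] + 1, gt[3] - gt[1] + 1
    gcx, gcy = (gt[2] + gt[0]) / 2, (gt[3] + gt[1]) / 2
    return torch.stack((wx * (gcx - acx) / aw, wy * (gcy - acy) / ah,
                        ww * torch.log(gw / aw), wh * torch.log(gh / ah)))

def decode(deltas, anchor, weights=(10., 10., 5., 5.)):
    wx, wy, ww, wh = weights
    aw, ah = anchor[2] - anchor[0] + 1, anchor[3] - anchor[1] + 1
    acx, acy = (anchor[2] + anchor[0]) / 2, (anchor[3] + anchor[1]) / 2
    cx, cy = deltas[0] / wx * aw + acx, deltas[1] / wy * ah + acy
    w, h = torch.exp(deltas[2] / ww) * aw, torch.exp(deltas[3] / wh) * ah
    return torch.stack((cx - 0.5 * (w - 1), cy - 0.5 * (h - 1),
                        cx + 0.5 * (w - 1), cy + 0.5 * (h - 1)))

anchor = torch.tensor([10., 10., 50., 60.])
gt = torch.tensor([12., 8., 44., 70.])
print(decode(encode(gt, anchor), anchor))   # ~tensor([12.,  8., 44., 70.])
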
+ self.head = ATSSHead(cfg) + box_coder = BoxCoder(cfg) + self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder) + self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True) + self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False) + self.anchor_generator = make_anchor_generator_complex(cfg) + + def forward(self, images, features, targets=None): + box_cls, box_regression, centerness = self.head(features) + anchors = self.anchor_generator(images, features) + + if self.training: + return self._forward_train(box_cls, box_regression, centerness, targets, anchors) + else: + return self._forward_test(box_cls, box_regression, centerness, anchors) + + def _forward_train(self, box_cls, box_regression, centerness, targets, anchors): + loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator( + box_cls, box_regression, centerness, targets, anchors + ) + losses = { + "loss_cls": loss_box_cls, + "loss_reg": loss_box_reg, + "loss_centerness": loss_centerness + } + if self.cfg.MODEL.RPN_ONLY: + return None, losses + else: + boxes = self.box_selector_train(box_cls, box_regression, centerness, anchors) + train_boxes = [] + for b, a in zip(boxes, anchors): + a = cat_boxlist(a) + b.add_field("visibility", torch.ones(b.bbox.shape[0], dtype=torch.bool, device=b.bbox.device)) + del b.extra_fields['scores'] + del b.extra_fields['labels'] + train_boxes.append(cat_boxlist([b, a])) + return train_boxes, losses + + def _forward_test(self, box_cls, box_regression, centerness, anchors): + boxes = self.box_selector_test(box_cls, box_regression, centerness, anchors) + return boxes, {} diff --git a/maskrcnn_benchmark/modeling/rpn/dyhead.py b/maskrcnn_benchmark/modeling/rpn/dyhead.py new file mode 100644 index 0000000000000000000000000000000000000000..e84cd3a9d28ed337bece0c87689301b865642324 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/dyhead.py @@ -0,0 +1,377 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn + +from .inference import make_atss_postprocessor +from .loss import make_atss_loss_evaluator +from .anchor_generator import make_anchor_generator_complex + +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.layers import Scale, DYReLU, SELayer, ModulatedDeformConv +from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d +from maskrcnn_benchmark.modeling.backbone.fbnet import * + + +class h_sigmoid(nn.Module): + def __init__(self, inplace=True, h_max=1): + super(h_sigmoid, self).__init__() + self.relu = nn.ReLU6(inplace=inplace) + self.h_max = h_max + + def forward(self, x): + return self.relu(x + 3) * self.h_max / 6 + + +class BoxCoder(object): + + def __init__(self, cfg): + self.cfg = cfg + + def encode(self, gt_boxes, anchors): + TO_REMOVE = 1 # TODO remove + ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE + ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE + ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2 + ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2 + + gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE + gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE + gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2 + gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2 + + wx, wy, ww, wh = (10., 10., 5., 5.) 
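
# Illustrative standalone sketch: h_sigmoid above is the "hard" sigmoid
# relu6(x + 3) * h_max / 6 -- a piecewise-linear gate that is 0 for x <= -3,
# h_max for x >= 3, and 0.5 * h_max at x = 0 (h_max = 1 here).
import torch
import torch.nn.functional as F

x = torch.tensor([-4.0, -3.0, 0.0, 3.0, 6.0])
print(F.relu6(x + 3) * 1 / 6)    # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])
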
+ targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights + targets_dw = ww * torch.log(gt_widths / ex_widths) + targets_dh = wh * torch.log(gt_heights / ex_heights) + targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + + return targets + + def decode(self, preds, anchors): + anchors = anchors.to(preds.dtype) + + TO_REMOVE = 1 # TODO remove + widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE + heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE + ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2 + ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2 + + wx, wy, ww, wh = (10., 10., 5., 5.) + dx = preds[:, 0::4] / wx + dy = preds[:, 1::4] / wy + dw = preds[:, 2::4] / ww + dh = preds[:, 3::4] / wh + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=math.log(1000. / 16)) + dh = torch.clamp(dh, max=math.log(1000. / 16)) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = torch.exp(dw) * widths[:, None] + pred_h = torch.exp(dh) * heights[:, None] + + pred_boxes = torch.zeros_like(preds) + pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1) + pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1) + pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1) + pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1) + + return pred_boxes + + +class Conv3x3Norm(torch.nn.Module): + def __init__(self, + in_channels, + out_channels, + stride, + groups=1, + deformable=False, + bn_type=None): + super(Conv3x3Norm, self).__init__() + + if deformable: + self.conv = ModulatedDeformConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, + groups=groups) + else: + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups) + + if isinstance(bn_type, (list, tuple)): + assert len(bn_type) == 2 + assert bn_type[0] == "gn" + gn_group = bn_type[1] + bn_type = bn_type[0] + + if bn_type == "bn": + bn_op = nn.BatchNorm2d(out_channels) + elif bn_type == "sbn": + bn_op = nn.SyncBatchNorm(out_channels) + elif bn_type == "nsbn": + bn_op = NaiveSyncBatchNorm2d(out_channels) + elif bn_type == "gn": + bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=out_channels) + elif bn_type == "af": + bn_op = FrozenBatchNorm2d(out_channels) + if bn_type is not None: + self.bn = bn_op + else: + self.bn = None + + def forward(self, input, **kwargs): + x = self.conv(input, **kwargs) + if self.bn: + x = self.bn(x) + return x + + +class DyConv(torch.nn.Module): + def __init__(self, + in_channels=256, + out_channels=256, + conv_func=nn.Conv2d, + use_dyfuse=True, + use_dyrelu=False, + use_deform=False + ): + super(DyConv, self).__init__() + + self.DyConv = nn.ModuleList() + self.DyConv.append(conv_func(in_channels, out_channels, 1)) + self.DyConv.append(conv_func(in_channels, out_channels, 1)) + self.DyConv.append(conv_func(in_channels, out_channels, 2)) + + if use_dyfuse: + self.AttnConv = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, 1, kernel_size=1), + nn.ReLU(inplace=True)) + self.h_sigmoid = h_sigmoid() + else: + self.AttnConv = None + + if use_dyrelu: + self.relu = DYReLU(in_channels, out_channels) + else: + self.relu = nn.ReLU() + + if use_deform: + self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1) + else: + self.offset = None + + self.init_weights() + + def init_weights(self): + for m in self.DyConv.modules(): + if isinstance(m, nn.Conv2d): + 
nn.init.normal_(m.weight.data, 0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + if self.AttnConv is not None: + for m in self.AttnConv.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight.data, 0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + next_x = [] + for level, feature in enumerate(x): + + conv_args = dict() + if self.offset is not None: + offset_mask = self.offset(feature) + offset = offset_mask[:, :18, :, :] + mask = offset_mask[:, 18:, :, :].sigmoid() + conv_args = dict(offset=offset, mask=mask) + + temp_fea = [self.DyConv[1](feature, **conv_args)] + + if level > 0: + temp_fea.append(self.DyConv[2](x[level - 1], **conv_args)) + if level < len(x) - 1: + temp_fea.append(F.upsample_bilinear(self.DyConv[0](x[level + 1], **conv_args), + size=[feature.size(2), feature.size(3)])) + mean_fea = torch.mean(torch.stack(temp_fea), dim=0, keepdim=False) + + if self.AttnConv is not None: + attn_fea = [] + res_fea = [] + for fea in temp_fea: + res_fea.append(fea) + attn_fea.append(self.AttnConv(fea)) + + res_fea = torch.stack(res_fea) + spa_pyr_attn = self.h_sigmoid(torch.stack(attn_fea)) + + mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False) + + next_x.append(mean_fea) + + next_x = [self.relu(item) for item in next_x] + return next_x + + +class DyHead(torch.nn.Module): + def __init__(self, cfg): + super(DyHead, self).__init__() + self.cfg = cfg + num_classes = cfg.MODEL.DYHEAD.NUM_CLASSES - 1 + num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + channels = cfg.MODEL.DYHEAD.CHANNELS + if cfg.MODEL.DYHEAD.USE_GN: + bn_type = ['gn', cfg.MODEL.GROUP_NORM.NUM_GROUPS] + elif cfg.MODEL.DYHEAD.USE_NSYNCBN: + bn_type = 'nsbn' + elif cfg.MODEL.DYHEAD.USE_SYNCBN: + bn_type = 'sbn' + else: + bn_type = None + + use_dyrelu = cfg.MODEL.DYHEAD.USE_DYRELU + use_dyfuse = cfg.MODEL.DYHEAD.USE_DYFUSE + use_deform = cfg.MODEL.DYHEAD.USE_DFCONV + + if cfg.MODEL.DYHEAD.CONV_FUNC: + conv_func = lambda i, o, s: eval(cfg.MODEL.DYHEAD.CONV_FUNC)(i, o, s, bn_type=bn_type) + else: + conv_func = lambda i, o, s: Conv3x3Norm(i, o, s, deformable=use_deform, bn_type=bn_type) + + dyhead_tower = [] + for i in range(cfg.MODEL.DYHEAD.NUM_CONVS): + dyhead_tower.append( + DyConv( + in_channels if i == 0 else channels, + channels, + conv_func=conv_func, + use_dyrelu=(use_dyrelu and in_channels == channels) if i == 0 else use_dyrelu, + use_dyfuse=(use_dyfuse and in_channels == channels) if i == 0 else use_dyfuse, + use_deform=(use_deform and in_channels == channels) if i == 0 else use_deform, + ) + ) + + self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower)) + if cfg.MODEL.DYHEAD.COSINE_SCALE <= 0: + self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=1) + self.cls_logits_bias = None + else: + self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=1, bias=False) + self.cls_logits_bias = nn.Parameter(torch.zeros(num_anchors * num_classes, requires_grad=True)) + self.cosine_scale = nn.Parameter(torch.ones(1) * cfg.MODEL.DYHEAD.COSINE_SCALE) + self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=1) + self.centerness = nn.Conv2d(channels, num_anchors * 1, kernel_size=1) + + # initialization + for modules in [self.cls_logits, self.bbox_pred, + self.centerness]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + if hasattr(l, 'bias') and l.bias is not None: + 
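+ # Added note: biases start at zero here; the classification bias (cls_logits.bias
+ # or cls_logits_bias) is re-initialized just below to -log((1 - p) / p), so the
+ # sigmoid outputs start near the focal-loss prior probability p = PRIOR_PROB.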
torch.nn.init.constant_(l.bias, 0) + + # initialize the bias for focal loss + prior_prob = cfg.MODEL.DYHEAD.PRIOR_PROB + bias_value = -math.log((1 - prior_prob) / prior_prob) + if self.cls_logits_bias is None: + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + else: + torch.nn.init.constant_(self.cls_logits_bias, bias_value) + + self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)]) + + def extract_feature(self, x): + output = [] + for i in range(len(self.dyhead_tower)): + x = self.dyhead_tower[i](x) + output.append(x) + return output + + def forward(self, x): + logits = [] + bbox_reg = [] + centerness = [] + + dyhead_tower = self.dyhead_tower(x) + + for l, feature in enumerate(x): + if self.cls_logits_bias is None: + logit = self.cls_logits(dyhead_tower[l]) + else: + # CosineSimOutputLayers: https://github.com/ucbdrive/few-shot-object-detection/blob/master/fsdet/modeling/roi_heads/fast_rcnn.py#L448-L464 + # normalize the input x along the `channel` dimension + x_norm = torch.norm(dyhead_tower[l], p=2, dim=1, keepdim=True).expand_as(dyhead_tower[l]) + x_normalized = dyhead_tower[l].div(x_norm + 1e-5) + # normalize weight + temp_norm = ( + torch.norm(self.cls_logits.weight.data, p=2, dim=1, keepdim=True) + .expand_as(self.cls_logits.weight.data) + ) + self.cls_logits.weight.data = self.cls_logits.weight.data.div( + temp_norm + 1e-5 + ) + cos_dist = self.cls_logits(x_normalized) + logit = self.cosine_scale * cos_dist + self.cls_logits_bias.reshape(1, len(self.cls_logits_bias), 1, 1) + logits.append(logit) + + bbox_pred = self.scales[l](self.bbox_pred(dyhead_tower[l])) + bbox_reg.append(bbox_pred) + + centerness.append(self.centerness(dyhead_tower[l])) + return logits, bbox_reg, centerness + + +class DyHeadModule(torch.nn.Module): + + def __init__(self, cfg): + super(DyHeadModule, self).__init__() + self.cfg = cfg + self.head = DyHead(cfg) + box_coder = BoxCoder(cfg) + self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder) + self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True) + self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False) + self.anchor_generator = make_anchor_generator_complex(cfg) + + def forward(self, images, features, targets=None): + box_cls, box_regression, centerness = self.head(features) + anchors = self.anchor_generator(images, features) + + if self.training: + return self._forward_train(box_cls, box_regression, centerness, targets, anchors) + else: + return self._forward_test(box_cls, box_regression, centerness, anchors) + + def _forward_train(self, box_cls, box_regression, centerness, targets, anchors): + loss_box_cls, loss_box_reg, loss_centerness, _, _, _, _ = self.loss_evaluator( + box_cls, box_regression, centerness, targets, anchors + ) + losses = { + "loss_cls": loss_box_cls, + "loss_reg": loss_box_reg, + "loss_centerness": loss_centerness + } + if self.cfg.MODEL.RPN_ONLY: + return None, losses + else: + # boxes = self.box_selector_train(box_cls, box_regression, centerness, anchors) + boxes = self.box_selector_train(box_regression, centerness, anchors, box_cls) + train_boxes = [] + # for b, a in zip(boxes, anchors): + # a = cat_boxlist(a) + # b.add_field("visibility", torch.ones(b.bbox.shape[0], dtype=torch.bool, device=b.bbox.device)) + # del b.extra_fields['scores'] + # del b.extra_fields['labels'] + # train_boxes.append(cat_boxlist([b, a])) + for b, t in zip(boxes, targets): + tb = t.copy_with_fields(["labels"]) + tb.add_field("scores", torch.ones(tb.bbox.shape[0], 
dtype=torch.bool, device=tb.bbox.device)) + train_boxes.append(cat_boxlist([b, tb])) + return train_boxes, losses + + def _forward_test(self, box_cls, box_regression, centerness, anchors): + boxes = self.box_selector_test(box_regression, centerness, anchors, box_cls) + return boxes, {} diff --git a/maskrcnn_benchmark/modeling/rpn/fcos.py b/maskrcnn_benchmark/modeling/rpn/fcos.py new file mode 100644 index 0000000000000000000000000000000000000000..c69dab0fd86d7b891ee001228368294a0fd56ae4 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/fcos.py @@ -0,0 +1,236 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.modeling import registry +from maskrcnn_benchmark.layers import Scale, DFConv2d +from .loss import make_fcos_loss_evaluator +from .anchor_generator import make_center_anchor_generator +from .inference import make_fcos_postprocessor + + +@registry.RPN_HEADS.register("FCOSHead") +class FCOSHead(torch.nn.Module): + def __init__(self, cfg): + + super(FCOSHead, self).__init__() + # TODO: Implement the sigmoid version first. + num_classes = cfg.MODEL.FCOS.NUM_CLASSES - 1 + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + use_gn = cfg.MODEL.FCOS.USE_GN + use_bn = cfg.MODEL.FCOS.USE_BN + use_dcn_in_tower = cfg.MODEL.FCOS.USE_DFCONV + self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES + self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS + self.centerness_on_reg = cfg.MODEL.FCOS.CENTERNESS_ON_REG + + cls_tower = [] + bbox_tower = [] + for i in range(cfg.MODEL.FCOS.NUM_CONVS): + if use_dcn_in_tower and \ + i == cfg.MODEL.FCOS.NUM_CONVS - 1: + conv_func = DFConv2d + else: + conv_func = nn.Conv2d + + cls_tower.append( + conv_func( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + bias=True + ) + ) + if use_gn: + cls_tower.append(nn.GroupNorm(32, in_channels)) + if use_bn: + cls_tower.append(nn.BatchNorm2d(in_channels)) + cls_tower.append(nn.ReLU()) + + bbox_tower.append( + conv_func( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + bias=True + ) + ) + if use_gn: + bbox_tower.append(nn.GroupNorm(32, in_channels)) + if use_bn: + bbox_tower.append(nn.BatchNorm2d(in_channels)) + bbox_tower.append(nn.ReLU()) + + self.add_module('cls_tower', nn.Sequential(*cls_tower)) + self.add_module('bbox_tower', nn.Sequential(*bbox_tower)) + self.cls_logits = nn.Conv2d( + in_channels, num_classes, kernel_size=3, stride=1, + padding=1 + ) + self.bbox_pred = nn.Conv2d( + in_channels, 4, kernel_size=3, stride=1, + padding=1 + ) + self.centerness = nn.Conv2d( + in_channels, 1, kernel_size=3, stride=1, + padding=1 + ) + + # initialization + for modules in [self.cls_tower, self.bbox_tower, + self.cls_logits, self.bbox_pred, + self.centerness]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + # initialize the bias for focal loss + prior_prob = cfg.MODEL.FCOS.PRIOR_PROB + bias_value = -math.log((1 - prior_prob) / prior_prob) + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + + self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)]) + + def forward(self, x): + logits = [] + bbox_reg = [] + centerness = [] + for l, feature in enumerate(x): + cls_tower = self.cls_tower(feature) + box_tower = self.bbox_tower(feature) + + logits.append(self.cls_logits(cls_tower)) + if self.centerness_on_reg: + centerness.append(self.centerness(box_tower)) + else: + centerness.append(self.centerness(cls_tower)) + + 
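+ # Added note: each FPN level has its own learnable Scale so one shared conv head
+ # can regress at every scale. With NORM_REG_TARGETS the distances are predicted in
+ # units of the level's stride (ReLU keeps them non-negative and they are multiplied
+ # by fpn_strides[l] only at test time); otherwise exp() keeps them positive.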
bbox_pred = self.scales[l](self.bbox_pred(box_tower)) + if self.norm_reg_targets: + bbox_pred = F.relu(bbox_pred) + if self.training: + bbox_reg.append(bbox_pred) + else: + bbox_reg.append(bbox_pred * self.fpn_strides[l]) + else: + bbox_reg.append(torch.exp(bbox_pred)) + return logits, bbox_reg, centerness + + +class FCOSModule(torch.nn.Module): + """ + Module for FCOS computation. Takes feature maps from the backbone and + FCOS outputs and losses. Only Test on FPN now. + """ + + def __init__(self, cfg): + super(FCOSModule, self).__init__() + + head = FCOSHead(cfg) + + box_selector_train = make_fcos_postprocessor(cfg, is_train=True) + box_selector_test = make_fcos_postprocessor(cfg, is_train=False) + + loss_evaluator = make_fcos_loss_evaluator(cfg) + + self.cfg = cfg + self.head = head + self.box_selector_train = box_selector_train + self.box_selector_test = box_selector_test + self.loss_evaluator = loss_evaluator + self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES + if not cfg.MODEL.RPN_ONLY: + self.anchor_generator = make_center_anchor_generator(cfg) + + + def forward(self, images, features, targets=None): + """ + Arguments: + images (ImageList): images for which we want to compute the predictions + features (list[Tensor]): features computed from the images that are + used for computing the predictions. Each tensor in the list + correspond to different feature levels + targets (list[BoxList): ground-truth boxes present in the image (optional) + + Returns: + boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per + image. + losses (dict[Tensor]): the losses for the model during training. During + testing, it is an empty dict. + """ + box_cls, box_regression, centerness = self.head(features) + locations = self.compute_locations(features) + if self.training and targets is not None: + return self._forward_train( + locations, box_cls, box_regression, + centerness, targets, images.image_sizes + ) + else: + return self._forward_test( + locations, box_cls, box_regression, + centerness, images.image_sizes + ) + + def _forward_train(self, locations, box_cls, box_regression, centerness, targets, image_sizes=None): + loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator( + locations, box_cls, box_regression, centerness, targets + ) + losses = { + "loss_cls": loss_box_cls, + "loss_reg": loss_box_reg, + "loss_centerness": loss_centerness + } + if self.cfg.MODEL.RPN_ONLY: + return None, losses + else: + boxes = self.box_selector_train( + locations, box_cls, box_regression, + centerness, image_sizes + ) + proposals = self.anchor_generator(boxes, image_sizes, centerness) + return proposals, losses + + def _forward_test(self, locations, box_cls, box_regression, centerness, image_sizes): + boxes = self.box_selector_test( + locations, box_cls, box_regression, + centerness, image_sizes + ) + if not self.cfg.MODEL.RPN_ONLY: + boxes = self.anchor_generator(boxes, image_sizes, centerness) + return boxes, {} + + def compute_locations(self, features): + locations = [] + for level, feature in enumerate(features): + h, w = feature.size()[-2:] + locations_per_level = self.compute_locations_per_level( + h, w, self.fpn_strides[level], + feature.device + ) + locations.append(locations_per_level) + return locations + + def compute_locations_per_level(self, h, w, stride, device): + shifts_x = torch.arange( + 0, w * stride, step=stride, + dtype=torch.float32, device=device + ) + shifts_y = torch.arange( + 0, h * stride, step=stride, + dtype=torch.float32, device=device + ) + shift_y, shift_x = 
torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 + return locations + + + + diff --git a/maskrcnn_benchmark/modeling/rpn/inference.py b/maskrcnn_benchmark/modeling/rpn/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4b1430e8add3ca92dace5ebf5bc8ba4261729c --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/inference.py @@ -0,0 +1,850 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging + +import torch + +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from maskrcnn_benchmark.structures.bounding_box import BoxList, _onnx_clip_boxes_to_image +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_ml_nms +from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes + +from ..utils import permute_and_flatten +import pdb + +class RPNPostProcessor(torch.nn.Module): + """ + Performs post-processing on the outputs of the RPN boxes, before feeding the + proposals to the heads + """ + + def __init__( + self, + pre_nms_top_n, + post_nms_top_n, + nms_thresh, + min_size, + box_coder=None, + fpn_post_nms_top_n=None, + onnx=False + ): + """ + Arguments: + pre_nms_top_n (int) + post_nms_top_n (int) + nms_thresh (float) + min_size (int) + box_coder (BoxCoder) + fpn_post_nms_top_n (int) + """ + super(RPNPostProcessor, self).__init__() + self.pre_nms_top_n = pre_nms_top_n + self.post_nms_top_n = post_nms_top_n + self.nms_thresh = nms_thresh + self.min_size = min_size + self.onnx = onnx + + if box_coder is None: + box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + self.box_coder = box_coder + + if fpn_post_nms_top_n is None: + fpn_post_nms_top_n = post_nms_top_n + self.fpn_post_nms_top_n = fpn_post_nms_top_n + + def add_gt_proposals(self, proposals, targets): + """ + Arguments: + proposals: list[BoxList] + targets: list[BoxList] + """ + # Get the device we're operating on + device = proposals[0].bbox.device + + gt_boxes = [target.copy_with_fields([]) for target in targets] + + # later cat of bbox requires all fields to be present for all bbox + # so we need to add a dummy for objectness that's missing + for gt_box in gt_boxes: + gt_box.add_field("objectness", torch.ones(len(gt_box), device=device)) + + proposals = [ + cat_boxlist((proposal, gt_box)) + for proposal, gt_box in zip(proposals, gt_boxes) + ] + + return proposals + + def forward_for_single_feature_map(self, anchors, objectness, box_regression): + """ + Arguments: + anchors: list[BoxList] + objectness: tensor of size N, A, H, W + box_regression: tensor of size N, A * 4, H, W + """ + device = objectness.device + N, A, H, W = objectness.shape + + # put in the same format as anchors + objectness = objectness.permute(0, 2, 3, 1).reshape(N, -1) + objectness = objectness.sigmoid() + box_regression = box_regression.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2) + box_regression = box_regression.reshape(N, -1, 4) + + num_anchors = A * H * W + + pre_nms_top_n = min(self.pre_nms_top_n, num_anchors) + objectness, topk_idx = objectness.topk(pre_nms_top_n, dim=1, sorted=True) + + batch_idx = torch.arange(N, device=device)[:, None] + box_regression = box_regression[batch_idx, topk_idx] + + image_shapes = [box.size for box in anchors] + concat_anchors = torch.cat([a.bbox for a in anchors], dim=0) + concat_anchors = 
concat_anchors.reshape(N, -1, 4)[batch_idx, topk_idx] + + proposals = self.box_coder.decode( + box_regression.view(-1, 4), concat_anchors.view(-1, 4) + ) + + proposals = proposals.view(N, -1, 4) + + result = [] + for proposal, score, im_shape in zip(proposals, objectness, image_shapes): + if self.onnx: + proposal = _onnx_clip_boxes_to_image(proposal, im_shape) + boxlist = BoxList(proposal, im_shape, mode="xyxy") + else: + boxlist = BoxList(proposal, im_shape, mode="xyxy") + boxlist = boxlist.clip_to_image(remove_empty=False) + + boxlist.add_field("objectness", score) + boxlist = remove_small_boxes(boxlist, self.min_size) + boxlist = boxlist_nms( + boxlist, + self.nms_thresh, + max_proposals=self.post_nms_top_n, + score_field="objectness", + ) + result.append(boxlist) + return result + + def forward(self, anchors, objectness, box_regression, targets=None): + """ + Arguments: + anchors: list[list[BoxList]] + objectness: list[tensor] + box_regression: list[tensor] + + Returns: + boxlists (list[BoxList]): the post-processed anchors, after + applying box decoding and NMS + """ + sampled_boxes = [] + num_levels = len(objectness) + anchors = list(zip(*anchors)) + for a, o, b in zip(anchors, objectness, box_regression): + sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) + + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + + if num_levels > 1: + boxlists = self.select_over_all_levels(boxlists) + + # append ground-truth bboxes to proposals + if self.training and targets is not None: + boxlists = self.add_gt_proposals(boxlists, targets) + + return boxlists + + def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + # different behavior during training and during testing: + # during training, post_nms_top_n is over *all* the proposals combined, while + # during testing, it is over the proposals for each image + # TODO resolve this difference and make it consistent. 
It should be per image, + # and not per batch + if self.training: + objectness = torch.cat( + [boxlist.get_field("objectness") for boxlist in boxlists], dim=0 + ) + box_sizes = [len(boxlist) for boxlist in boxlists] + post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness)) + _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True) + inds_mask = torch.zeros_like(objectness, dtype=torch.bool) + inds_mask[inds_sorted] = 1 + inds_mask = inds_mask.split(box_sizes) + for i in range(num_images): + boxlists[i] = boxlists[i][inds_mask[i]] + else: + for i in range(num_images): + objectness = boxlists[i].get_field("objectness") + post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness)) + _, inds_sorted = torch.topk( + objectness, post_nms_top_n, dim=0, sorted=True + ) + boxlists[i] = boxlists[i][inds_sorted] + return boxlists + + +def make_rpn_postprocessor(config, rpn_box_coder, is_train): + fpn_post_nms_top_n = config.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN + if not is_train: + fpn_post_nms_top_n = config.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST + + pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TRAIN + post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TRAIN + if not is_train: + pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TEST + post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TEST + nms_thresh = config.MODEL.RPN.NMS_THRESH + min_size = config.MODEL.RPN.MIN_SIZE + onnx = config.MODEL.ONNX + box_selector = RPNPostProcessor( + pre_nms_top_n=pre_nms_top_n, + post_nms_top_n=post_nms_top_n, + nms_thresh=nms_thresh, + min_size=min_size, + box_coder=rpn_box_coder, + fpn_post_nms_top_n=fpn_post_nms_top_n, + onnx=onnx + ) + return box_selector + + +class RetinaPostProcessor(torch.nn.Module): + """ + Performs post-processing on the outputs of the RetinaNet boxes. + This is only used in the testing. 
+ """ + + def __init__( + self, + pre_nms_thresh, + pre_nms_top_n, + nms_thresh, + fpn_post_nms_top_n, + min_size, + num_classes, + box_coder=None, + ): + """ + Arguments: + pre_nms_thresh (float) + pre_nms_top_n (int) + nms_thresh (float) + fpn_post_nms_top_n (int) + min_size (int) + num_classes (int) + box_coder (BoxCoder) + """ + super(RetinaPostProcessor, self).__init__() + self.pre_nms_thresh = pre_nms_thresh + self.pre_nms_top_n = pre_nms_top_n + self.nms_thresh = nms_thresh + self.fpn_post_nms_top_n = fpn_post_nms_top_n + self.min_size = min_size + self.num_classes = num_classes + + if box_coder is None: + box_coder = BoxCoder(weights=(10., 10., 5., 5.)) + self.box_coder = box_coder + + def forward_for_single_feature_map(self, anchors, box_cls, box_regression): + """ + Arguments: + anchors: list[BoxList] + box_cls: tensor of size N, A * C, H, W + box_regression: tensor of size N, A * 4, H, W + """ + device = box_cls.device + N, _, H, W = box_cls.shape + A = box_regression.size(1) // 4 + C = box_cls.size(1) // A + + # put in the same format as anchors + box_cls = permute_and_flatten(box_cls, N, A, C, H, W) + box_cls = box_cls.sigmoid() + + box_regression = permute_and_flatten(box_regression, N, A, 4, H, W) + box_regression = box_regression.reshape(N, -1, 4) + + num_anchors = A * H * W + + candidate_inds = box_cls > self.pre_nms_thresh + + pre_nms_top_n = candidate_inds.view(N, -1).sum(1) + pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) + + results = [] + for per_box_cls, per_box_regression, per_pre_nms_top_n, \ + per_candidate_inds, per_anchors in zip( + box_cls, + box_regression, + pre_nms_top_n, + candidate_inds, + anchors): + # Sort and select TopN + # TODO most of this can be made out of the loop for + # all images. + # TODO:Yang: Not easy to do. Because the numbers of detections are + # different in each image. Therefore, this part needs to be done + # per image. + per_box_cls = per_box_cls[per_candidate_inds] + + per_box_cls, top_k_indices = \ + per_box_cls.topk(per_pre_nms_top_n, sorted=False) + + per_candidate_nonzeros = \ + per_candidate_inds.nonzero()[top_k_indices, :] + + per_box_loc = per_candidate_nonzeros[:, 0] + per_class = per_candidate_nonzeros[:, 1] + per_class += 1 + + detections = self.box_coder.decode( + per_box_regression[per_box_loc, :].view(-1, 4), + per_anchors.bbox[per_box_loc, :].view(-1, 4) + ) + + boxlist = BoxList(detections, per_anchors.size, mode="xyxy") + boxlist.add_field("labels", per_class) + boxlist.add_field("scores", per_box_cls) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = remove_small_boxes(boxlist, self.min_size) + results.append(boxlist) + + return results + + # TODO very similar to filter_results from PostProcessor + # but filter_results is per image + # TODO Yang: solve this issue in the future. No good solution + # right now. 
+ def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + results = [] + for i in range(num_images): + scores = boxlists[i].get_field("scores") + labels = boxlists[i].get_field("labels") + boxes = boxlists[i].bbox + boxlist = boxlists[i] + result = [] + # skip the background + for j in range(1, self.num_classes): + inds = (labels == j).nonzero().view(-1) + + scores_j = scores[inds] + boxes_j = boxes[inds, :].view(-1, 4) + boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") + boxlist_for_class.add_field("scores", scores_j) + boxlist_for_class = boxlist_nms( + boxlist_for_class, self.nms_thresh, + score_field="scores" + ) + num_labels = len(boxlist_for_class) + boxlist_for_class.add_field( + "labels", torch.full((num_labels,), j, + dtype=torch.int64, + device=scores.device) + ) + result.append(boxlist_for_class) + + result = cat_boxlist(result) + number_of_detections = len(result) + + # Limit to max_per_image detections **over all classes** + if number_of_detections > self.fpn_post_nms_top_n > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = torch.kthvalue( + cls_scores.cpu(), + number_of_detections - self.fpn_post_nms_top_n + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + results.append(result) + return results + + def forward(self, anchors, objectness, box_regression, targets=None): + """ + Arguments: + anchors: list[list[BoxList]] + objectness: list[tensor] + box_regression: list[tensor] + + Returns: + boxlists (list[BoxList]): the post-processed anchors, after + applying box decoding and NMS + """ + sampled_boxes = [] + anchors = list(zip(*anchors)) + for a, o, b in zip(anchors, objectness, box_regression): + sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) + + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + + boxlists = self.select_over_all_levels(boxlists) + + return boxlists + + +def make_retina_postprocessor(config, rpn_box_coder, is_train): + pre_nms_thresh = config.MODEL.RETINANET.INFERENCE_TH + pre_nms_top_n = config.MODEL.RETINANET.PRE_NMS_TOP_N + nms_thresh = config.MODEL.RETINANET.NMS_TH + fpn_post_nms_top_n = config.MODEL.RETINANET.DETECTIONS_PER_IMG + min_size = 0 + + box_selector = RetinaPostProcessor( + pre_nms_thresh=pre_nms_thresh, + pre_nms_top_n=pre_nms_top_n, + nms_thresh=nms_thresh, + fpn_post_nms_top_n=fpn_post_nms_top_n, + min_size=min_size, + num_classes=config.MODEL.RETINANET.NUM_CLASSES, + box_coder=rpn_box_coder, + ) + + return box_selector + + +class FCOSPostProcessor(torch.nn.Module): + """ + Performs post-processing on the outputs of the RetinaNet boxes. + This is only used in the testing. 
+ """ + + def __init__( + self, + pre_nms_thresh, + pre_nms_top_n, + nms_thresh, + fpn_post_nms_top_n, + min_size, + num_classes, + bbox_aug_enabled=False + ): + """ + Arguments: + pre_nms_thresh (float) + pre_nms_top_n (int) + nms_thresh (float) + fpn_post_nms_top_n (int) + min_size (int) + num_classes (int) + box_coder (BoxCoder) + """ + super(FCOSPostProcessor, self).__init__() + self.pre_nms_thresh = pre_nms_thresh + self.pre_nms_top_n = pre_nms_top_n + self.nms_thresh = nms_thresh + self.fpn_post_nms_top_n = fpn_post_nms_top_n + self.min_size = min_size + self.num_classes = num_classes + self.bbox_aug_enabled = bbox_aug_enabled + + def forward_for_single_feature_map( + self, locations, box_cls, + box_regression, centerness, + image_sizes): + """ + Arguments: + anchors: list[BoxList] + box_cls: tensor of size N, A * C, H, W + box_regression: tensor of size N, A * 4, H, W + """ + N, C, H, W = box_cls.shape + + # put in the same format as locations + box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1) + box_cls = box_cls.reshape(N, -1, C).sigmoid() + box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1) + box_regression = box_regression.reshape(N, -1, 4) + centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1) + centerness = centerness.reshape(N, -1).sigmoid() + + candidate_inds = box_cls > self.pre_nms_thresh + pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1) + pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) + + # multiply the classification scores with centerness scores + box_cls = box_cls * centerness[:, :, None] + + results = [] + for i in range(N): + per_box_cls = box_cls[i] + per_candidate_inds = candidate_inds[i] + per_box_cls = per_box_cls[per_candidate_inds] + + per_candidate_nonzeros = per_candidate_inds.nonzero() + per_box_loc = per_candidate_nonzeros[:, 0] + per_class = per_candidate_nonzeros[:, 1] + 1 + + per_box_regression = box_regression[i] + per_box_regression = per_box_regression[per_box_loc] + per_locations = locations[per_box_loc] + + per_pre_nms_top_n = pre_nms_top_n[i] + + if per_candidate_inds.sum().item() > per_pre_nms_top_n.item(): + per_box_cls, top_k_indices = \ + per_box_cls.topk(per_pre_nms_top_n, sorted=False) + per_class = per_class[top_k_indices] + per_box_regression = per_box_regression[top_k_indices] + per_locations = per_locations[top_k_indices] + + detections = torch.stack([ + per_locations[:, 0] - per_box_regression[:, 0], + per_locations[:, 1] - per_box_regression[:, 1], + per_locations[:, 0] + per_box_regression[:, 2], + per_locations[:, 1] + per_box_regression[:, 3], + ], dim=1) + + h, w = image_sizes[i] + boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy") + boxlist.add_field('centers', per_locations) + boxlist.add_field("labels", per_class) + boxlist.add_field("scores", torch.sqrt(per_box_cls)) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = remove_small_boxes(boxlist, self.min_size) + results.append(boxlist) + + return results + + def forward(self, locations, box_cls, box_regression, centerness, image_sizes): + """ + Arguments: + anchors: list[list[BoxList]] + box_cls: list[tensor] + box_regression: list[tensor] + image_sizes: list[(h, w)] + Returns: + boxlists (list[BoxList]): the post-processed anchors, after + applying box decoding and NMS + """ + sampled_boxes = [] + for _, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)): + sampled_boxes.append( + self.forward_for_single_feature_map( + l, o, b, c, image_sizes + ) + ) + + boxlists = 
list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + if not self.bbox_aug_enabled: + boxlists = self.select_over_all_levels(boxlists) + + return boxlists + + # TODO very similar to filter_results from PostProcessor + # but filter_results is per image + # TODO Yang: solve this issue in the future. No good solution + # right now. + def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + results = [] + for i in range(num_images): + # multiclass nms + result = boxlist_ml_nms(boxlists[i], self.nms_thresh) + number_of_detections = len(result) + + # Limit to max_per_image detections **over all classes** + if number_of_detections > self.fpn_post_nms_top_n > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = torch.kthvalue( + cls_scores.cpu(), + number_of_detections - self.fpn_post_nms_top_n + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + results.append(result) + return results + + +def make_fcos_postprocessor(config, is_train=False): + pre_nms_thresh = config.MODEL.FCOS.INFERENCE_TH + if is_train: + pre_nms_thresh = config.MODEL.FCOS.INFERENCE_TH_TRAIN + pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N + fpn_post_nms_top_n = config.MODEL.FCOS.DETECTIONS_PER_IMG + if is_train: + pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N_TRAIN + fpn_post_nms_top_n = config.MODEL.FCOS.POST_NMS_TOP_N_TRAIN + nms_thresh = config.MODEL.FCOS.NMS_TH + + box_selector = FCOSPostProcessor( + pre_nms_thresh=pre_nms_thresh, + pre_nms_top_n=pre_nms_top_n, + nms_thresh=nms_thresh, + fpn_post_nms_top_n=fpn_post_nms_top_n, + min_size=0, + num_classes=config.MODEL.FCOS.NUM_CLASSES, + ) + + return box_selector + + +class ATSSPostProcessor(torch.nn.Module): + def __init__( + self, + pre_nms_thresh, + pre_nms_top_n, + nms_thresh, + fpn_post_nms_top_n, + min_size, + num_classes, + box_coder, + bbox_aug_enabled=False, + bbox_aug_vote=False, + score_agg='MEAN', + mdetr_style_aggregate_class_num=-1 + ): + super(ATSSPostProcessor, self).__init__() + self.pre_nms_thresh = pre_nms_thresh + self.pre_nms_top_n = pre_nms_top_n + self.nms_thresh = nms_thresh + self.fpn_post_nms_top_n = fpn_post_nms_top_n + self.min_size = min_size + self.num_classes = num_classes + self.bbox_aug_enabled = bbox_aug_enabled + self.box_coder = box_coder + self.bbox_aug_vote = bbox_aug_vote + self.score_agg = score_agg + self.mdetr_style_aggregate_class_num = mdetr_style_aggregate_class_num + + def forward_for_single_feature_map(self, box_regression, centerness, anchors, + box_cls=None, + token_logits=None, + dot_product_logits=None, + positive_map=None, + ): + + N, _, H, W = box_regression.shape + + A = box_regression.size(1) // 4 + + if box_cls is not None: + C = box_cls.size(1) // A + + if token_logits is not None: + T = token_logits.size(1) // A + + # put in the same format as anchors + if box_cls is not None: + #print('Classification.') + box_cls = permute_and_flatten(box_cls, N, A, C, H, W) + box_cls = box_cls.sigmoid() + + # binary focal loss version + if token_logits is not None: + #print('Token.') + token_logits = permute_and_flatten(token_logits, N, A, T, H, W) + token_logits = token_logits.sigmoid() + # turn back to original classes + scores = convert_grounding_to_od_logits(logits=token_logits, box_cls=box_cls, positive_map=positive_map, + score_agg=self.score_agg) + box_cls = scores + + # binary dot product focal version + if dot_product_logits is not None: + #print('Dot Product.') + dot_product_logits = 
dot_product_logits.sigmoid() + if self.mdetr_style_aggregate_class_num != -1: + scores = convert_grounding_to_od_logits_v2( + logits=dot_product_logits, + num_class=self.mdetr_style_aggregate_class_num, + positive_map=positive_map, + score_agg=self.score_agg, + disable_minus_one=False) + else: + scores = convert_grounding_to_od_logits(logits=dot_product_logits, box_cls=box_cls, + positive_map=positive_map, + score_agg=self.score_agg) + box_cls = scores + + box_regression = permute_and_flatten(box_regression, N, A, 4, H, W) + box_regression = box_regression.reshape(N, -1, 4) + + candidate_inds = box_cls > self.pre_nms_thresh + pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1) + pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) + + centerness = permute_and_flatten(centerness, N, A, 1, H, W) + centerness = centerness.reshape(N, -1).sigmoid() + + # multiply the classification scores with centerness scores + + box_cls = box_cls * centerness[:, :, None] + + results = [] + + for per_box_cls, per_box_regression, per_pre_nms_top_n, per_candidate_inds, per_anchors \ + in zip(box_cls, box_regression, pre_nms_top_n, candidate_inds, anchors): + per_box_cls = per_box_cls[per_candidate_inds] + + per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False) + + per_candidate_nonzeros = per_candidate_inds.nonzero()[top_k_indices, :] + + per_box_loc = per_candidate_nonzeros[:, 0] + per_class = per_candidate_nonzeros[:, 1] + 1 + + # print(per_class) + + detections = self.box_coder.decode( + per_box_regression[per_box_loc, :].view(-1, 4), + per_anchors.bbox[per_box_loc, :].view(-1, 4) + ) + + boxlist = BoxList(detections, per_anchors.size, mode="xyxy") + boxlist.add_field("labels", per_class) + boxlist.add_field("scores", torch.sqrt(per_box_cls)) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = remove_small_boxes(boxlist, self.min_size) + results.append(boxlist) + + return results + + def forward(self, box_regression, centerness, anchors, + box_cls=None, + token_logits=None, + dot_product_logits=None, + positive_map=None, + ): + sampled_boxes = [] + anchors = list(zip(*anchors)) + for idx, (b, c, a) in enumerate(zip(box_regression, centerness, anchors)): + o = None + t = None + d = None + if box_cls is not None: + o = box_cls[idx] + if token_logits is not None: + t = token_logits[idx] + if dot_product_logits is not None: + d = dot_product_logits[idx] + + sampled_boxes.append( + self.forward_for_single_feature_map(b, c, a, o, t, d, positive_map) + ) + + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + if not (self.bbox_aug_enabled and not self.bbox_aug_vote): + boxlists = self.select_over_all_levels(boxlists) + + return boxlists + + # TODO very similar to filter_results from PostProcessor + # but filter_results is per image + # TODO Yang: solve this issue in the future. No good solution + # right now. 
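+ # Added note: unlike RetinaPostProcessor above, this variant performs multiclass
+ # NMS in a single call via boxlist_ml_nms and applies the same top-k cap; scores
+ # are cast to float32 first because torch.kthvalue has no Half kernel (see the
+ # inline TODO below).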
+ def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + results = [] + for i in range(num_images): + # multiclass nms + result = boxlist_ml_nms(boxlists[i], self.nms_thresh) + number_of_detections = len(result) + + # Limit to max_per_image detections **over all classes** + if number_of_detections > self.fpn_post_nms_top_n > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = torch.kthvalue( + # TODO: confirm with Pengchuan and Xiyang, torch.kthvalue is not implemented for 'Half' + # cls_scores.cpu(), + cls_scores.cpu().float(), + number_of_detections - self.fpn_post_nms_top_n + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + results.append(result) + return results + + +def convert_grounding_to_od_logits(logits, box_cls, positive_map, score_agg=None): + scores = torch.zeros(logits.shape[0], logits.shape[1], box_cls.shape[2]).to(logits.device) + # 256 -> 80, average for each class + if positive_map is not None: + # score aggregation method + if score_agg == "MEAN": + for label_j in positive_map: + scores[:, :, label_j - 1] = logits[:, :, torch.LongTensor(positive_map[label_j])].mean(-1) + elif score_agg == "MAX": + # torch.max() returns (values, indices) + for label_j in positive_map: + scores[:, :, label_j - 1] = logits[:, :, torch.LongTensor(positive_map[label_j])].max(-1)[ + 0] + elif score_agg == "ONEHOT": + # one hot + scores = logits[:, :, :len(positive_map)] + else: + raise NotImplementedError + return scores + + +def convert_grounding_to_od_logits_v2(logits, num_class, positive_map, score_agg=None, disable_minus_one = True): + + scores = torch.zeros(logits.shape[0], logits.shape[1], num_class).to(logits.device) + # 256 -> 80, average for each class + if positive_map is not None: + # score aggregation method + if score_agg == "MEAN": + for label_j in positive_map: + locations_label_j = positive_map[label_j] + if isinstance(locations_label_j, int): + locations_label_j = [locations_label_j] + scores[:, :, label_j if disable_minus_one else label_j - 1] = logits[:, :, torch.LongTensor(locations_label_j)].mean(-1) + elif score_agg == "POWER": + for label_j in positive_map: + locations_label_j = positive_map[label_j] + if isinstance(locations_label_j, int): + locations_label_j = [locations_label_j] + + probability = torch.prod(logits[:, :, torch.LongTensor(locations_label_j)], dim=-1).squeeze(-1) + probability = torch.pow(probability, 1/len(locations_label_j)) + scores[:, :, label_j if disable_minus_one else label_j - 1] = probability + elif score_agg == "MAX": + # torch.max() returns (values, indices) + for label_j in positive_map: + scores[:, :, label_j if disable_minus_one else label_j - 1] = logits[:, :, torch.LongTensor(positive_map[label_j])].max(-1)[ + 0] + elif score_agg == "ONEHOT": + # one hot + scores = logits[:, :, :len(positive_map)] + else: + raise NotImplementedError + return scores + +def make_atss_postprocessor(config, box_coder, is_train=False): + pre_nms_thresh = config.MODEL.ATSS.INFERENCE_TH + if is_train: + pre_nms_thresh = config.MODEL.ATSS.INFERENCE_TH_TRAIN + pre_nms_top_n = config.MODEL.ATSS.PRE_NMS_TOP_N + fpn_post_nms_top_n = config.MODEL.ATSS.DETECTIONS_PER_IMG + if is_train: + pre_nms_top_n = config.MODEL.ATSS.PRE_NMS_TOP_N_TRAIN + fpn_post_nms_top_n = config.MODEL.ATSS.POST_NMS_TOP_N_TRAIN + nms_thresh = config.MODEL.ATSS.NMS_TH + score_agg = config.MODEL.DYHEAD.SCORE_AGG + + box_selector = ATSSPostProcessor( + pre_nms_thresh=pre_nms_thresh, + 
pre_nms_top_n=pre_nms_top_n, + nms_thresh=nms_thresh, + fpn_post_nms_top_n=fpn_post_nms_top_n, + min_size=0, + num_classes=config.MODEL.ATSS.NUM_CLASSES, + box_coder=box_coder, + bbox_aug_enabled=config.TEST.USE_MULTISCALE, + score_agg=score_agg, + mdetr_style_aggregate_class_num=config.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM + ) + + return box_selector diff --git a/maskrcnn_benchmark/modeling/rpn/loss.py b/maskrcnn_benchmark/modeling/rpn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..097fdc5ffd0a864110a26f6cbb0ee54b6af38f2b --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/loss.py @@ -0,0 +1,1251 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +This file contains specific functions for computing losses on the RPN +file +""" + +import torch +from torch import nn +from torch.nn import functional as F + +from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler +from ..utils import cat, concat_box_prediction_layers + +from maskrcnn_benchmark.layers import smooth_l1_loss +from maskrcnn_benchmark.modeling.matcher import Matcher +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.layers import SigmoidFocalLoss, IOULoss, TokenSigmoidFocalLoss +from maskrcnn_benchmark.utils.comm import get_world_size, reduce_sum +from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd +from maskrcnn_benchmark.utils.shallow_contrastive_loss_helper import * + +from transformers import AutoTokenizer + +INF = 1e8 + + +class RPNLossComputation(object): + """ + This class computes the RPN loss. + """ + + def __init__(self, proposal_matcher, fg_bg_sampler, box_coder): + """ + Arguments: + proposal_matcher (Matcher) + fg_bg_sampler (BalancedPositiveNegativeSampler) + box_coder (BoxCoder) + """ + # self.target_preparator = target_preparator + self.proposal_matcher = proposal_matcher + self.fg_bg_sampler = fg_bg_sampler + self.box_coder = box_coder + + def match_targets_to_anchors(self, anchor, target): + match_quality_matrix = boxlist_iou(target, anchor) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # RPN doesn't need any fields from target + # for creating the labels, so clear them all + target = target.copy_with_fields([]) + # get the targets corresponding GT for each anchor + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + + if len(target): + matched_targets = target[matched_idxs.clamp(min=0)] + else: + matched_targets = target + + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, anchors, targets): + labels = [] + regression_targets = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + matched_targets = self.match_targets_to_anchors( + anchors_per_image, targets_per_image + ) + + matched_idxs = matched_targets.get_field("matched_idxs") + labels_per_image = matched_idxs >= 0 + labels_per_image = labels_per_image.to(dtype=torch.float32) + # discard anchors that go out of the boundaries of the image + labels_per_image[~anchors_per_image.get_field("visibility")] = -1 + + # discard indices that are between thresholds + inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS + labels_per_image[inds_to_discard] = -1 + + # compute regression targets + if not matched_targets.bbox.shape[0]: + zeros = torch.zeros_like(labels_per_image) + 
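+ # Added note: with no ground-truth boxes there is nothing to encode, so all-zero
+ # regression targets are used purely as a shape placeholder; such anchors should
+ # all end up labeled as background and therefore never reach the box loss.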
regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1) + else: + regression_targets_per_image = self.box_coder.encode(matched_targets.bbox, anchors_per_image.bbox) + + labels.append(labels_per_image) + regression_targets.append(regression_targets_per_image) + + return labels, regression_targets + + @custom_fwd(cast_inputs=torch.float32) + def __call__(self, anchors, objectness, box_regression, targets): + """ + Arguments: + anchors (list[BoxList]) + objectness (list[Tensor]) + box_regression (list[Tensor]) + targets (list[BoxList]) + + Returns: + objectness_loss (Tensor) + box_loss (Tensor + """ + anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors] + labels, regression_targets = self.prepare_targets(anchors, targets) + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) + sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1) + sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1) + + sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) + + objectness_flattened = [] + box_regression_flattened = [] + # for each feature level, permute the outputs to make them be in the + # same format as the labels. Note that the labels are computed for + # all feature levels concatenated, so we keep the same representation + # for the objectness and the box_regression + for objectness_per_level, box_regression_per_level in zip( + objectness, box_regression + ): + N, A, H, W = objectness_per_level.shape + objectness_per_level = objectness_per_level.permute(0, 2, 3, 1).reshape( + N, -1 + ) + box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W) + box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2) + box_regression_per_level = box_regression_per_level.reshape(N, -1, 4) + objectness_flattened.append(objectness_per_level) + box_regression_flattened.append(box_regression_per_level) + # concatenate on the first dimension (representing the feature levels), to + # take into account the way the labels were generated (with all feature maps + # being concatenated as well) + objectness = cat(objectness_flattened, dim=1).reshape(-1) + box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4) + + labels = torch.cat(labels, dim=0) + regression_targets = torch.cat(regression_targets, dim=0) + + box_loss = smooth_l1_loss( + box_regression[sampled_pos_inds], + regression_targets[sampled_pos_inds], + beta=1.0 / 9, + size_average=False, + ) / (sampled_inds.numel()) + + objectness_loss = F.binary_cross_entropy_with_logits( + objectness[sampled_inds], labels[sampled_inds] + ) + + return objectness_loss, box_loss + + +class FocalLossComputation(object): + """ + This class computes the RetinaNet loss. 
+ """ + + def __init__(self, proposal_matcher, box_coder, + generate_labels_func, + sigmoid_focal_loss, + bbox_reg_beta=0.11, + regress_norm=1.0): + """ + Arguments: + proposal_matcher (Matcher) + box_coder (BoxCoder) + """ + self.proposal_matcher = proposal_matcher + self.box_coder = box_coder + self.box_cls_loss_func = sigmoid_focal_loss + self.bbox_reg_beta = bbox_reg_beta + self.copied_fields = ['labels'] + self.generate_labels_func = generate_labels_func + self.discard_cases = ['between_thresholds'] + self.regress_norm = regress_norm + + def match_targets_to_anchors(self, anchor, target, copied_fields=[]): + match_quality_matrix = boxlist_iou(target, anchor) + matched_idxs = self.proposal_matcher(match_quality_matrix) + # RPN doesn't need any fields from target + # for creating the labels, so clear them all + target = target.copy_with_fields(copied_fields) + # get the targets corresponding GT for each anchor + # NB: need to clamp the indices because we can have a single + # GT in the image, and matched_idxs can be -2, which goes + # out of bounds + matched_targets = target[matched_idxs.clamp(min=0)] + matched_targets.add_field("matched_idxs", matched_idxs) + return matched_targets + + def prepare_targets(self, anchors, targets): + labels = [] + regression_targets = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + matched_targets = self.match_targets_to_anchors( + anchors_per_image, targets_per_image, self.copied_fields + ) + + matched_idxs = matched_targets.get_field("matched_idxs") + labels_per_image = self.generate_labels_func(matched_targets) + labels_per_image = labels_per_image.to(dtype=torch.float32) + + # Background (negative examples) + bg_indices = matched_idxs == Matcher.BELOW_LOW_THRESHOLD + labels_per_image[bg_indices] = 0 + + # discard anchors that go out of the boundaries of the image + if "not_visibility" in self.discard_cases: + labels_per_image[~anchors_per_image.get_field("visibility")] = -1 + + # discard indices that are between thresholds + if "between_thresholds" in self.discard_cases: + inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS + labels_per_image[inds_to_discard] = -1 + + # compute regression targets + regression_targets_per_image = self.box_coder.encode( + matched_targets.bbox, anchors_per_image.bbox + ) + + labels.append(labels_per_image) + regression_targets.append(regression_targets_per_image) + + return labels, regression_targets + + @custom_fwd(cast_inputs=torch.float32) + def __call__(self, anchors, box_cls, box_regression, targets): + """ + Arguments: + anchors (list[BoxList]) + box_cls (list[Tensor]) + box_regression (list[Tensor]) + targets (list[BoxList]) + + Returns: + retinanet_cls_loss (Tensor) + retinanet_regression_loss (Tensor + """ + anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors] + labels, regression_targets = self.prepare_targets(anchors, targets) + + N = len(labels) + box_cls, box_regression = \ + concat_box_prediction_layers(box_cls, box_regression) + + labels = torch.cat(labels, dim=0) + regression_targets = torch.cat(regression_targets, dim=0) + pos_inds = torch.nonzero(labels > 0).squeeze(1) + + retinanet_regression_loss = smooth_l1_loss( + box_regression[pos_inds], + regression_targets[pos_inds], + beta=self.bbox_reg_beta, + size_average=False, + ) / (max(1, pos_inds.numel() * self.regress_norm)) + + labels = labels.int() + + retinanet_cls_loss = self.box_cls_loss_func( + box_cls, + labels + ) / (pos_inds.numel() + N) + + return retinanet_cls_loss, 
retinanet_regression_loss + + +class FCOSLossComputation(object): + """ + This class computes the FCOS losses. + """ + + def __init__(self, cfg): + self.cls_loss_func = SigmoidFocalLoss( + cfg.MODEL.FOCAL.LOSS_GAMMA, + cfg.MODEL.FOCAL.LOSS_ALPHA + ) + self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES + self.center_sampling_radius = cfg.MODEL.FCOS.CENTER_SAMPLING_RADIUS + self.iou_loss_type = cfg.MODEL.FCOS.IOU_LOSS_TYPE + self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS + self.use_gt_center = cfg.MODEL.FCOS.USE_GT_CENTER + + # we make use of IOU Loss for bounding boxes regression, + # but we found that L1 in log scale can yield a similar performance + self.box_reg_loss_func = IOULoss(self.iou_loss_type) + self.centerness_loss_func = torch.nn.BCEWithLogitsLoss(reduction="sum") + + def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1.0): + ''' + This code is from + https://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/ + maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42 + ''' + num_gts = gt.shape[0] + K = len(gt_xs) + gt = gt[None].expand(K, num_gts, 4) + center_x = (gt[..., 0] + gt[..., 2]) / 2 + center_y = (gt[..., 1] + gt[..., 3]) / 2 + center_gt = gt.new_zeros(gt.shape) + # no gt + if center_x[..., 0].sum() == 0: + return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8) + beg = 0 + for level, n_p in enumerate(num_points_per): + end = beg + n_p + stride = strides[level] * radius + xmin = center_x[beg:end] - stride + ymin = center_y[beg:end] - stride + xmax = center_x[beg:end] + stride + ymax = center_y[beg:end] + stride + # limit sample region in gt + center_gt[beg:end, :, 0] = torch.where( + xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0] + ) + center_gt[beg:end, :, 1] = torch.where( + ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1] + ) + center_gt[beg:end, :, 2] = torch.where( + xmax > gt[beg:end, :, 2], + gt[beg:end, :, 2], xmax + ) + center_gt[beg:end, :, 3] = torch.where( + ymax > gt[beg:end, :, 3], + gt[beg:end, :, 3], ymax + ) + beg = end + left = gt_xs[:, None] - center_gt[..., 0] + right = center_gt[..., 2] - gt_xs[:, None] + top = gt_ys[:, None] - center_gt[..., 1] + bottom = center_gt[..., 3] - gt_ys[:, None] + center_bbox = torch.stack((left, top, right, bottom), -1) + inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0 + return inside_gt_bbox_mask + + def prepare_targets(self, points, targets): + object_sizes_of_interest = [ + [-1, 64], + [64, 128], + [128, 256], + [256, 512], + [512, INF], + ] + expanded_object_sizes_of_interest = [] + for l, points_per_level in enumerate(points): + object_sizes_of_interest_per_level = \ + points_per_level.new_tensor(object_sizes_of_interest[l]) + expanded_object_sizes_of_interest.append( + object_sizes_of_interest_per_level[None].expand(len(points_per_level), -1) + ) + + expanded_object_sizes_of_interest = torch.cat(expanded_object_sizes_of_interest, dim=0) + num_points_per_level = [len(points_per_level) for points_per_level in points] + self.num_points_per_level = num_points_per_level + points_all_level = torch.cat(points, dim=0) + labels, reg_targets = self.compute_targets_for_locations( + points_all_level, targets, expanded_object_sizes_of_interest + ) + + for i in range(len(labels)): + labels[i] = torch.split(labels[i], num_points_per_level, dim=0) + reg_targets[i] = torch.split(reg_targets[i], num_points_per_level, dim=0) + + labels_level_first = [] + reg_targets_level_first = [] + for level in range(len(points)): + labels_level_first.append( + torch.cat([labels_per_im[level] for 
labels_per_im in labels], dim=0) + ) + + reg_targets_per_level = torch.cat([ + reg_targets_per_im[level] + for reg_targets_per_im in reg_targets + ], dim=0) + + if self.norm_reg_targets: + reg_targets_per_level = reg_targets_per_level / self.fpn_strides[level] + reg_targets_level_first.append(reg_targets_per_level) + + return labels_level_first, reg_targets_level_first + + def compute_targets_for_locations(self, locations, targets, object_sizes_of_interest): + labels = [] + reg_targets = [] + xs, ys = locations[:, 0], locations[:, 1] + + for im_i in range(len(targets)): + targets_per_im = targets[im_i] + assert targets_per_im.mode == "xyxy" + + if self.use_gt_center: + center = targets_per_im.get_field("cbox") + bboxes = center.bbox + area = center.area() + else: + bboxes = targets_per_im.bbox + area = targets_per_im.area() + labels_per_im = targets_per_im.get_field("labels") + + l = xs[:, None] - bboxes[:, 0][None] + t = ys[:, None] - bboxes[:, 1][None] + r = bboxes[:, 2][None] - xs[:, None] + b = bboxes[:, 3][None] - ys[:, None] + reg_targets_per_im = torch.stack([l, t, r, b], dim=2) + + if self.center_sampling_radius > 0: + is_in_boxes = self.get_sample_region( + bboxes, + self.fpn_strides, + self.num_points_per_level, + xs, ys, + radius=self.center_sampling_radius + ) + else: + # no center sampling, it will use all the locations within a ground-truth box + is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0 + + max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0] + # limit the regression range for each location + is_cared_in_the_level = \ + (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & \ + (max_reg_targets_per_im <= object_sizes_of_interest[:, [1]]) + + locations_to_gt_area = area[None].repeat(len(locations), 1) + locations_to_gt_area[is_in_boxes == 0] = INF + locations_to_gt_area[is_cared_in_the_level == 0] = INF + + # if there are still more than one objects for a location, + # we choose the one with minimal area + locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(dim=1) + + reg_targets_per_im = reg_targets_per_im[range(len(locations)), locations_to_gt_inds] + labels_per_im = labels_per_im[locations_to_gt_inds] + labels_per_im[locations_to_min_area == INF] = 0 + + labels.append(labels_per_im) + reg_targets.append(reg_targets_per_im) + + return labels, reg_targets + + def compute_centerness_targets(self, reg_targets): + left_right = reg_targets[:, [0, 2]] + top_bottom = reg_targets[:, [1, 3]] + centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \ + (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + return torch.sqrt(centerness) + + @custom_fwd(cast_inputs=torch.float32) + def __call__(self, locations, box_cls, box_regression, centerness, targets): + """ + Arguments: + locations (list[BoxList]) + box_cls (list[Tensor]) + box_regression (list[Tensor]) + centerness (list[Tensor]) + targets (list[BoxList]) + + Returns: + cls_loss (Tensor) + reg_loss (Tensor) + centerness_loss (Tensor) + """ + N = box_cls[0].size(0) + num_classes = box_cls[0].size(1) + labels, reg_targets = self.prepare_targets(locations, targets) + + box_cls_flatten = [] + box_regression_flatten = [] + centerness_flatten = [] + labels_flatten = [] + reg_targets_flatten = [] + for l in range(len(labels)): + box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(-1, num_classes)) + box_regression_flatten.append(box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4)) + labels_flatten.append(labels[l].reshape(-1)) + 
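+ # Added note: every FPN level is flattened to (N*H*W, C) / (N*H*W, 4) here so that,
+ # after concatenation, positive locations can be gathered with a single pos_inds
+ # index below.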
reg_targets_flatten.append(reg_targets[l].reshape(-1, 4)) + centerness_flatten.append(centerness[l].reshape(-1)) + + box_cls_flatten = torch.cat(box_cls_flatten, dim=0) + box_regression_flatten = torch.cat(box_regression_flatten, dim=0) + centerness_flatten = torch.cat(centerness_flatten, dim=0) + labels_flatten = torch.cat(labels_flatten, dim=0) + reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0) + + pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1) + + box_regression_flatten = box_regression_flatten[pos_inds] + reg_targets_flatten = reg_targets_flatten[pos_inds] + centerness_flatten = centerness_flatten[pos_inds] + + cls_loss = self.cls_loss_func( + box_cls_flatten, + labels_flatten.int() + ) / max(pos_inds.numel(), 1.0) + + if pos_inds.numel() > 0: + centerness_targets = self.compute_centerness_targets(reg_targets_flatten) + + reg_loss = self.box_reg_loss_func( + box_regression_flatten, + reg_targets_flatten, + centerness_targets + ) / centerness_targets.sum() + centerness_loss = self.centerness_loss_func( + centerness_flatten, + centerness_targets + ) / max(pos_inds.numel(), 1.0) + else: + reg_loss = box_regression_flatten.sum() + centerness_loss = centerness_flatten.sum() + + return cls_loss, reg_loss, centerness_loss + + +# class ATSSLossComputation(object): +class ATSSLossComputation(torch.nn.Module): + + def __init__(self, cfg, box_coder): + super(ATSSLossComputation, self).__init__() + + self.cfg = cfg + self.cls_loss_func = SigmoidFocalLoss(cfg.MODEL.FOCAL.LOSS_GAMMA, cfg.MODEL.FOCAL.LOSS_ALPHA) + self.centerness_loss_func = torch.nn.BCEWithLogitsLoss(reduction="sum") + self.matcher = Matcher(cfg.MODEL.FOCAL.FG_IOU_THRESHOLD, cfg.MODEL.FOCAL.BG_IOU_THRESHOLD, True) + self.box_coder = box_coder + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + self.token_loss_func = TokenSigmoidFocalLoss(cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_ALPHA, + cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_GAMMA) + + self.lang = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE + + # self.tokenizer = AutoTokenizer.from_pretrained(self.lang) + if self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": + from transformers import CLIPTokenizerFast + # self.tokenizer = build_tokenizer(self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) + if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: + print("Reuse token 'ðŁĴij' (token_id = 49404) for mask token!") + self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", + from_slow=True, mask_token='ðŁĴij') + else: + self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", + from_slow=True) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.lang) + + # if use shallow contrastive loss + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS \ + or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS: + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS: + assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS == False + channels = cfg.MODEL.DYHEAD.CHANNELS + num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE + shallow_input_dim = channels * num_anchors + elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS: + assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS == False + shallow_input_dim = cfg.MODEL.SWINT.OUT_CHANNELS[-2] + + shallow_log_scale = self.cfg.MODEL.DYHEAD.SHALLOW_LOG_SCALE + shallow_contrastive_hdim = 
cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_HIDDEN_DIM + # self.shallow_contrastive_projection_image = nn.Conv2d(channels, num_anchors * shallow_contrastive_hdim, + # kernel_size=1) + self.shallow_contrastive_projection_image = nn.Linear(shallow_input_dim, shallow_contrastive_hdim, + bias=True) + self.shallow_contrastive_projection_text = nn.Linear(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, + shallow_contrastive_hdim, bias=True) + self.shallow_log_scale = nn.Parameter(torch.Tensor([shallow_log_scale]), requires_grad=True) + + # (initialization) if use shallow contrastive loss + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS: + for modules in [self.shallow_contrastive_projection_image, self.shallow_contrastive_projection_text]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + if isinstance(l, nn.Linear): + torch.nn.init.xavier_uniform_(l.weight) + l.bias.data.fill_(0) + + def NllSoftMaxLoss(self, logits, target): + loss_ce = -target * logits.log_softmax( + -1) # basically, only the those positives with positive target_sim will have losses + return loss_ce + + def ContrastiveAlignLoss(self, logits, positive_map): + positive_logits = -logits.masked_fill(~positive_map, 0) + negative_logits = logits # .masked_fill(positive_map, -1000000) + + boxes_with_pos = positive_map.any(2) + pos_term = positive_logits.sum(2) + neg_term = negative_logits.logsumexp(2) + + nb_pos = positive_map.sum(2) + 1e-6 + + box_to_token_loss = ((pos_term / nb_pos + neg_term)).masked_fill(~boxes_with_pos, 0).sum() + + tokens_with_pos = positive_map.any(1) + pos_term = positive_logits.sum(1) + neg_term = negative_logits.logsumexp(1) + + nb_pos = positive_map.sum(1) + 1e-6 + + tokens_to_boxes_loss = ((pos_term / nb_pos + neg_term)).masked_fill(~tokens_with_pos, 0).sum() + tot_loss = (box_to_token_loss + tokens_to_boxes_loss) / 2 + + return tot_loss + + def GIoULoss(self, pred, target, anchor, weight=None): + pred_boxes = self.box_coder.decode(pred.view(-1, 4), anchor.view(-1, 4)) + pred_x1 = pred_boxes[:, 0] + pred_y1 = pred_boxes[:, 1] + pred_x2 = pred_boxes[:, 2] + pred_y2 = pred_boxes[:, 3] + pred_x2 = torch.max(pred_x1, pred_x2) + pred_y2 = torch.max(pred_y1, pred_y2) + pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1) + + gt_boxes = self.box_coder.decode(target.view(-1, 4), anchor.view(-1, 4)) + target_x1 = gt_boxes[:, 0] + target_y1 = gt_boxes[:, 1] + target_x2 = gt_boxes[:, 2] + target_y2 = gt_boxes[:, 3] + target_area = (target_x2 - target_x1) * (target_y2 - target_y1) + + x1_intersect = torch.max(pred_x1, target_x1) + y1_intersect = torch.max(pred_y1, target_y1) + x2_intersect = torch.min(pred_x2, target_x2) + y2_intersect = torch.min(pred_y2, target_y2) + area_intersect = torch.zeros(pred_x1.size()).to(pred) + mask = (y2_intersect > y1_intersect) * (x2_intersect > x1_intersect) + area_intersect[mask] = (x2_intersect[mask] - x1_intersect[mask]) * (y2_intersect[mask] - y1_intersect[mask]) + + x1_enclosing = torch.min(pred_x1, target_x1) + y1_enclosing = torch.min(pred_y1, target_y1) + x2_enclosing = torch.max(pred_x2, target_x2) + y2_enclosing = torch.max(pred_y2, target_y2) + area_enclosing = (x2_enclosing - x1_enclosing) * (y2_enclosing - y1_enclosing) + 1e-7 + + area_union = pred_area + target_area - area_intersect + 1e-7 + ious = area_intersect / area_union + gious = ious - (area_enclosing - area_union) / area_enclosing + + losses = 1 - gious + + if weight is not None and weight.sum() > 0: + 
return (losses * weight).sum() + else: + assert losses.numel() != 0 + return losses.sum() + + def prepare_targets(self, targets, anchors, tokenized=None, positive_map=None, proj_tokens=None): + cls_labels = [] + reg_targets = [] + token_labels = [] + map_labels = [] + + gold_box_od_labels = [] + od_label_of_tokens_labels = [] + positive_indices = [] + + offset = 0 + + for im_i in range(len(targets)): + targets_per_im = targets[im_i] + assert targets_per_im.mode == "xyxy" + # bboxes_per_im = targets_per_im.get_field("boxes") + bboxes_per_im = targets_per_im.bbox + labels_per_im = targets_per_im.get_field("labels") + num_gt = len(bboxes_per_im) + + if positive_map is not None: + token_per_im = positive_map[offset:offset + num_gt, :] + offset += num_gt + + # Recheck if the label matches with the positive map + # print(labels_per_im) + # print(token_per_im.nonzero()) + + # shallow contrastive + if "original_od_label" in targets_per_im.fields(): + gold_box_od_label = targets_per_im.get_field("original_od_label") + if "positive_map_for_od_labels" in targets_per_im.fields(): + od_label_of_token_per_im = targets_per_im.get_field("positive_map_for_od_labels") + + # print(gold_box_od_label) + # print(od_label_of_token_per_im) + + if positive_map is not None and proj_tokens is not None: + if "tokens_positive" in targets_per_im.fields(): + cur_tokens = targets_per_im.get_field("tokens_positive") + else: + cur_tokens = targets_per_im.get_field("tokens") + map = torch.zeros((len(cur_tokens), proj_tokens.shape[1]), dtype=torch.bool) + for j, tok_list in enumerate(cur_tokens): + for (beg, end) in tok_list: + beg_pos = tokenized.char_to_token(im_i, beg) + end_pos = tokenized.char_to_token(im_i, end - 1) + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(im_i, beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(im_i, beg + 2) + except: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(im_i, end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(im_i, end - 3) + except: + end_pos = None + if beg_pos is None or end_pos is None: + continue + + assert beg_pos is not None and end_pos is not None + map[j, beg_pos: end_pos + 1].fill_(True) + + anchors_per_im = cat_boxlist(anchors[im_i]) + + num_anchors_per_loc = len(self.cfg.MODEL.RPN.ASPECT_RATIOS) * self.cfg.MODEL.RPN.SCALES_PER_OCTAVE + num_anchors_per_level = [len(anchors_per_level.bbox) for anchors_per_level in anchors[im_i]] + ious = boxlist_iou(anchors_per_im, targets_per_im) + + gt_cx = (bboxes_per_im[:, 2] + bboxes_per_im[:, 0]) / 2.0 + gt_cy = (bboxes_per_im[:, 3] + bboxes_per_im[:, 1]) / 2.0 + gt_points = torch.stack((gt_cx, gt_cy), dim=1) + + anchors_cx_per_im = (anchors_per_im.bbox[:, 2] + anchors_per_im.bbox[:, 0]) / 2.0 + anchors_cy_per_im = (anchors_per_im.bbox[:, 3] + anchors_per_im.bbox[:, 1]) / 2.0 + anchor_points = torch.stack((anchors_cx_per_im, anchors_cy_per_im), dim=1) + + distances = (anchor_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt() + + # Selecting candidates based on the center distance between anchor box and object + candidate_idxs = [] + star_idx = 0 + for level, anchors_per_level in enumerate(anchors[im_i]): + end_idx = star_idx + num_anchors_per_level[level] + distances_per_level = distances[star_idx:end_idx, :] + topk = min(self.cfg.MODEL.ATSS.TOPK * num_anchors_per_loc, num_anchors_per_level[level]) + _, topk_idxs_per_level = distances_per_level.topk(topk, dim=0, largest=False) + candidate_idxs.append(topk_idxs_per_level + star_idx) + 
star_idx = end_idx + candidate_idxs = torch.cat(candidate_idxs, dim=0) + + # Using the sum of mean and standard deviation as the IoU threshold to select final positive samples + candidate_ious = ious[candidate_idxs, torch.arange(num_gt)] + iou_mean_per_gt = candidate_ious.mean(0) + iou_std_per_gt = candidate_ious.std(0) + iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt + is_pos = candidate_ious >= iou_thresh_per_gt[None, :] + + # Limiting the final positive samples’ center to object + anchor_num = anchors_cx_per_im.shape[0] + for ng in range(num_gt): + candidate_idxs[:, ng] += ng * anchor_num + e_anchors_cx = anchors_cx_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1) + e_anchors_cy = anchors_cy_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1) + candidate_idxs = candidate_idxs.view(-1) + l = e_anchors_cx[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 0] + t = e_anchors_cy[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 1] + r = bboxes_per_im[:, 2] - e_anchors_cx[candidate_idxs].view(-1, num_gt) + b = bboxes_per_im[:, 3] - e_anchors_cy[candidate_idxs].view(-1, num_gt) + is_in_gts = torch.stack([l, t, r, b], dim=1).min(dim=1)[0] > 0.01 + is_pos = is_pos & is_in_gts + + # if an anchor box is assigned to multiple gts, the one with the highest IoU will be selected. + ious_inf = torch.full_like(ious, -INF).t().contiguous().view(-1) + index = candidate_idxs.view(-1)[is_pos.view(-1)] + ious_inf[index] = ious.t().contiguous().view(-1)[index] + ious_inf = ious_inf.view(num_gt, -1).t() + + anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1) + # get positive anchors index from ATSS + positive_index = [i[0].item() for i in torch.nonzero(anchors_to_gt_indexs)] + cls_labels_per_im = labels_per_im[anchors_to_gt_indexs] + cls_labels_per_im[anchors_to_gt_values == -INF] = 0 + + if positive_map is not None: + token_labels_per_im = token_per_im[anchors_to_gt_indexs] + unmatched_labels = torch.zeros(token_labels_per_im.shape[1], device=token_labels_per_im.device) + # TODO: temporarially disable the [NoObj] token logic, and only restrict to binary loss + unmatched_labels[-1] = 1 # token: none object - > 256 + token_labels_per_im[anchors_to_gt_values == -INF] = unmatched_labels + # move from cpu to gpu + token_labels_per_im = token_labels_per_im.to(cls_labels_per_im.device) + + # print(token_labels_per_im[anchors_to_gt_values == -INF].shape) + # print(cls_labels_per_im[anchors_to_gt_values != -INF][0]) + # print(token_labels_per_im[anchors_to_gt_values != -INF][0].nonzero()) + + if positive_map is not None and proj_tokens is not None: + map_labels_per_im = map[anchors_to_gt_indexs] + unmatched_labels = torch.zeros(map_labels_per_im.shape[1], dtype=torch.bool, + device=map_labels_per_im.device) # map: none False + map_labels_per_im[anchors_to_gt_values == -INF] = unmatched_labels + # move from cpu to gpu + map_labels_per_im = map_labels_per_im.to(cls_labels_per_im.device) + + # print(map_labels_per_im[anchors_to_gt_values == -INF].shape) + # print(map_labels_per_im[anchors_to_gt_values != -INF][0]) + + if positive_map is not None and proj_tokens is not None: + gold_box_od_label_per_im = gold_box_od_label[anchors_to_gt_indexs] + gold_box_od_label_per_im[anchors_to_gt_values == -INF] = -100 + # move from cpu to gpu + gold_box_od_label_per_im = gold_box_od_label_per_im.to(cls_labels_per_im.device) + + # print(gold_box_od_label_per_im[anchors_to_gt_values != -INF]) + + matched_gts = bboxes_per_im[anchors_to_gt_indexs] + + reg_targets_per_im = 
self.box_coder.encode(matched_gts, anchors_per_im.bbox) + cls_labels.append(cls_labels_per_im) + reg_targets.append(reg_targets_per_im) + + if positive_map is not None: + token_labels.append(token_labels_per_im) + + if positive_map is not None and proj_tokens is not None: + map_labels.append(map_labels_per_im) + gold_box_od_labels.append(gold_box_od_label_per_im) + od_label_of_tokens_labels.append(od_label_of_token_per_im) + positive_indices.append(positive_index) + + # print([len(x) for x in positive_indices]) + + return cls_labels, reg_targets, token_labels, map_labels, gold_box_od_labels, od_label_of_tokens_labels, positive_indices + + def compute_centerness_targets(self, reg_targets, anchors): + gts = self.box_coder.decode(reg_targets, anchors) + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + l = anchors_cx - gts[:, 0] + t = anchors_cy - gts[:, 1] + r = gts[:, 2] - anchors_cx + b = gts[:, 3] - anchors_cy + left_right = torch.stack([l, r], dim=1) + top_bottom = torch.stack([t, b], dim=1) + centerness = torch.sqrt((left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \ + (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])) + assert not torch.isnan(centerness).any() + return centerness + + @custom_fwd(cast_inputs=torch.float32) + def __call__(self, box_cls, box_regression, centerness, targets, anchors, + captions=None, + positive_map=None, + token_logits=None, + proj_tokens=None, + contrastive_logits=None, + dot_product_logits=None, + text_masks=None, + shallow_img_emb_feats=None + ): + + tokenized = None + if captions is not None: + # tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt") + if self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": + tokenized = self.tokenizer.batch_encode_plus(captions, + max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, + padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest", + return_tensors='pt', + truncation=True) + else: + tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt") + + labels, reg_targets, token_labels, map_labels, gold_box_od_labels, od_label_of_tokens_labels, positive_indices = self.prepare_targets(targets, anchors, + tokenized, + positive_map, + proj_tokens + ) + + N = len(labels) + + box_regression_flatten, box_cls_flatten, token_logits_stacked = concat_box_prediction_layers( + box_regression, + box_cls, + token_logits, + ) + + # contrastive logits + if positive_map is not None and contrastive_logits is not None: + contrastive_logits = torch.cat(contrastive_logits, dim=1) + + # dot product soft token logits + if dot_product_logits is not None: + dot_product_logits = torch.cat(dot_product_logits, dim=1) + + centerness_flatten = [ct.permute(0, 2, 3, 1).reshape(N, -1, 1) for ct in centerness] + centerness_flatten = torch.cat(centerness_flatten, dim=1).reshape(-1) + + labels_flatten = torch.cat(labels, dim=0) + reg_targets_flatten = torch.cat(reg_targets, dim=0) + anchors_flatten = torch.cat([cat_boxlist(anchors_per_image).bbox for anchors_per_image in anchors], dim=0) + + if positive_map is not None: + token_labels_stacked = torch.stack(token_labels, dim=0) + + if positive_map is not None and proj_tokens is not None: + positive_map_box_to_self_text = None + shallow_positive_map = None + bs = proj_tokens.shape[0] + device = proj_tokens.device + + # NOTE: 0. 
setup env + if dist.is_dist_avail_and_initialized(): + world_size = dist.get_world_size() + rank = torch.distributed.get_rank() + else: + world_size = 1 + rank = 0 + + if contrastive_logits is not None: + positive_map_box_to_self_text = torch.stack(map_labels, dim=0) + + if shallow_img_emb_feats is not None: + ''' + Ultimate: + N*B*(max_anchor_num) x N*B*T + Final Goal: + F = B x (max_anchor_num) x N*B*T + X: B x (max_anchor_num) od_labels : [0, 20, 30, ..] + Y: N*B*T: which denotes the od_label of every token + F[i,j] = A[i] == B[j] + ''' + with torch.no_grad(): + # NOTE: 1. get X (predicted_box_od_label), which the detection label of every predicted boxes + # predicted_box_od_label: B x A + + # check memory limitation: prevent # of positive >= # of max_positive + new_positive_indices = [] + # print([len(positive_index) for positive_index in positive_indices]) + for positive_index in positive_indices: + if len(positive_index) >= self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS: + import random + positive_index = sorted(random.sample(positive_index, + self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS)) + new_positive_indices.append(positive_index) + # print([len(positive_index) for positive_index in positive_indices]) + + max_len = max([len(positive_index) for positive_index in new_positive_indices]) + max_anchor_num = max_len + + if world_size > 1: + num_anchors = torch.tensor(max_len, device=positive_map.device) + num_anchors_full = [torch.zeros_like(num_anchors) for _ in range(world_size)] + torch.distributed.all_gather(num_anchors_full, num_anchors) + max_anchor_num = max([anchor.item() for anchor in num_anchors_full]) + + new_negative_pad_indices = [] + # if not PAD_ZEROS, select random negative paddings + if not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS: + for (positive_index, old_positive_index) in zip(new_positive_indices, positive_indices): + negative_index = [i for i in range(len(cat_boxlist(anchors[0]))) if i not in old_positive_index] + import random + negative_pad_index = sorted(random.sample(negative_index, + max_anchor_num - len(positive_index))) + new_negative_pad_indices.append(negative_pad_index) + + predicted_box_od_label = [] + for i in range(bs): + predicted_box_od_label.append( + pad_tensor_given_dim_length(gold_box_od_labels[i][new_positive_indices[i]], + dim=0, + length=max_anchor_num, + padding_value=-100, + batch_first=False + )) + predicted_box_od_label = torch.stack(predicted_box_od_label, dim=0) + + # if padding, need to create image masks to filter out the paddings + image_masks = None + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS: + image_masks = torch.zeros((bs, max_anchor_num), dtype=torch.long).to(text_masks.device) + for i in range(bs): + image_masks[i, :len(new_positive_indices[i])] = 1 + + # NOTE: 2. Get Y (od_label_of_tokens) + # od_label_of_tokens: N x B x T + od_label_of_tokens = torch.stack(od_label_of_tokens_labels, dim=0).long() + od_label_of_tokens = gather_tensors(od_label_of_tokens) + + # NOTE: 3. get F + # F: B*A x N*B*T + mapping_predicted_box_to_all_text = predicted_box_od_label.view(-1).unsqueeze( + 1) == od_label_of_tokens.view(-1).unsqueeze(0) + + # NOTE: 4. 
we still need to calculate the mapping between predicted box to its corresponding text's mapping + # positive_map_box_to_self_text: B x A x T, leave this for vanilla contrastive alignment loss + positive_map_box_to_self_text = [] + for i in range(bs): + positive_map_box_to_self_text.append( + pad_tensor_given_dim_length(map_labels[i][new_positive_indices[i]], + dim=0, + length=max_anchor_num, + padding_value=False, + batch_first=False + )) + positive_map_box_to_self_text = torch.stack(positive_map_box_to_self_text, dim=0) + + # change the corresponding place in our batch + for i in range(bs): + mapping_predicted_box_to_all_text[i * max_anchor_num: (i + 1) * max_anchor_num, + (rank * bs + i) * 256: (rank * bs + i + 1) * 256] = positive_map_box_to_self_text[i] + + # NOTE: 5. communicate and get positive map + # mapping_predicted_box_to_all_text: N*B*A x N*B*T + mapping_predicted_box_to_all_text = gather_tensors(mapping_predicted_box_to_all_text).view(-1, + mapping_predicted_box_to_all_text.size( + -1)) + shallow_positive_map = mapping_predicted_box_to_all_text # This is the true positive map + shallow_positive_map = shallow_positive_map.unsqueeze(0) + + # Get text attention masks + text_attention_mask = torch.zeros((bs, 256), dtype=torch.long) # B x 256 + for i in range(bs): + text_attention_mask[i, :len(text_masks[i])] = text_masks[i] + text_attention_mask = gather_tensors( + text_attention_mask.bool().to(device)) # N x B x 256 + + # if PAD_ZEROS, get image masks + if image_masks is not None: + image_attention_mask = torch.zeros((bs, max_anchor_num), dtype=torch.long) # B x max_anchor + for i in range(bs): + image_attention_mask[i, :len(image_masks[i])] = image_masks[i] + image_attention_mask = gather_tensors( + image_attention_mask.bool().to(device)) # N x B x max_anchor + + # NOTE: 6. 
calculate shallow contrastive logits + shallow_proj_tokens = F.normalize(self.shallow_contrastive_projection_text(proj_tokens), p=2, dim=-1) + + shallow_normalized_img_embs = [] + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS: + # choice 1:use features from SWINT backbone layer (c4) before vl fusion + from maskrcnn_benchmark.layers.roi_align import ROIAlignV2 + pooler = ROIAlignV2((1, 1), 1./16, 0) + # get positive features + for i in range(bs): + rois = convert_to_roi_format(cat_boxlist(anchors[i])[new_positive_indices[i]]) + roi_feature = pooler(shallow_img_emb_feats[i].unsqueeze(0), rois) + roi_feature = roi_feature.squeeze(-1).squeeze(-1) + shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(roi_feature) + shallow_normalized_img_emb = F.normalize(shallow_contrastive_proj_queries, p=2, dim=-1) + if image_masks is not None: + # pad zeros + shallow_normalized_img_embs.append( + pad_tensor_given_dim_length(shallow_normalized_img_emb, + dim=0, + length=max_anchor_num, + padding_value=0.0, + batch_first=False + )) + else: + # pad negatives + negative_rois = convert_to_roi_format(cat_boxlist(anchors[i])[new_negative_pad_indices[i]]) + negative_roi_feature = pooler(shallow_img_emb_feats[i].unsqueeze(0), negative_rois) + negative_roi_feature = negative_roi_feature.squeeze(-1).squeeze(-1) + negative_shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(negative_roi_feature) + negative_shallow_normalized_img_emb = F.normalize(negative_shallow_contrastive_proj_queries, + p=2, dim=-1) + shallow_normalized_img_embs.append( + pad_random_negative_tensor_given_length(shallow_normalized_img_emb, + negative_shallow_normalized_img_emb, + length=max_anchor_num + ) + ) + elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS: + # choice 2:use features after FPN + shallow_img_embs = torch.cat(shallow_img_emb_feats, dim=1) + # get positive features + for i in range(bs): + shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(shallow_img_embs[i, new_positive_indices[i], :]) + shallow_normalized_img_emb = F.normalize(shallow_contrastive_proj_queries, p=2, dim=-1) + if image_masks is not None: + # pad zeros + shallow_normalized_img_embs.append( + pad_tensor_given_dim_length(shallow_normalized_img_emb, + dim=0, + length=max_anchor_num, + padding_value=0.0, + batch_first=False + )) + else: + # pad negatives + negative_shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(shallow_img_embs[i, new_negative_pad_indices[i], :]) + negative_shallow_normalized_img_emb = F.normalize(negative_shallow_contrastive_proj_queries, + p=2, dim=-1) + shallow_normalized_img_embs.append( + pad_random_negative_tensor_given_length(shallow_normalized_img_emb, + negative_shallow_normalized_img_emb, + length=max_anchor_num + ) + ) + + shallow_normalized_img_embs = torch.stack(shallow_normalized_img_embs, dim=0) + shallow_normalized_text_emb = shallow_proj_tokens + shallow_normalized_text_emb = pad_tensor_given_dim_length(shallow_normalized_text_emb, + dim=1, + length=256, + padding_value=0.0) + + gathered_shallow_normalized_img_emb = gather_tensors(shallow_normalized_img_embs) + gathered_shallow_normalized_text_emb = gather_tensors(shallow_normalized_text_emb) + gathered_shallow_normalized_img_emb = gathered_shallow_normalized_img_emb.view(-1, + gathered_shallow_normalized_img_emb.size( + -1)) + gathered_shallow_normalized_text_emb = gathered_shallow_normalized_text_emb.view(-1, + 
gathered_shallow_normalized_text_emb.size( + -1)) + shallow_contrastive_logits = ( + torch.matmul(gathered_shallow_normalized_img_emb, + gathered_shallow_normalized_text_emb.transpose(-1, + -2)) / self.shallow_log_scale.exp()) + shallow_contrastive_logits = shallow_contrastive_logits.unsqueeze(0) + + # apply text mask + text_attention_mask = text_attention_mask.view(-1).unsqueeze(0).unsqueeze(0) + text_attention_mask = text_attention_mask.repeat(1, shallow_contrastive_logits.size(1), + 1) # copy along the image feature dimension + shallow_contrastive_logits = shallow_contrastive_logits.masked_fill(~text_attention_mask, -1000000) + + # if PAD ZEROS, apply image mask + if image_masks is not None: + image_attention_mask = image_attention_mask.view(-1).unsqueeze(0).unsqueeze(-1) + image_attention_mask = image_attention_mask.repeat(1, 1, shallow_contrastive_logits.size( + 2)) # copy along the text feature dimension + shallow_contrastive_logits = shallow_contrastive_logits.masked_fill(~image_attention_mask, -1000000) + + # Note: 7. calculate image and text logits and maps + shallow_image_logits = shallow_contrastive_logits[:, + (rank * bs) * max_anchor_num: (rank * bs + bs) * max_anchor_num, :] + shallow_image_positive_map = normalized_positive_map( + shallow_positive_map[:, (rank * bs) * max_anchor_num: (rank * bs + bs) * max_anchor_num, :]) + + shallow_text_logits = shallow_contrastive_logits[:, :, + (rank * bs) * 256: (rank * bs + bs) * 256].transpose(1, + 2) + shallow_text_positive_map = normalized_positive_map( + shallow_positive_map[:, :, (rank * bs) * 256: (rank * bs + bs) * 256].transpose(1, 2)) + + pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1) + + num_gpus = get_world_size() + total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item() + num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0) + + cls_loss = self.cls_loss_func(box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu + + token_logits_loss = None + contrastive_align_loss = None + dot_product_token_loss = None + shallow_contrastive_loss = None + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS: + token_logits_loss = self.token_loss_func(token_logits_stacked, + token_labels_stacked, text_masks=text_masks, + version="binary") / num_pos_avg_per_gpu + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS: + contrastive_align_loss = self.ContrastiveAlignLoss(contrastive_logits, positive_map_box_to_self_text) / num_pos_avg_per_gpu + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + dot_product_token_loss = self.token_loss_func(dot_product_logits, + token_labels_stacked, text_masks=text_masks, + version="binary") / num_pos_avg_per_gpu + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS or \ + self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS: + box_to_token_loss = self.NllSoftMaxLoss(shallow_image_logits, shallow_image_positive_map).sum() + token_to_box_loss = self.NllSoftMaxLoss(shallow_text_logits, shallow_text_positive_map).sum() + tot_loss = (box_to_token_loss + token_to_box_loss) / 2 + shallow_contrastive_loss = tot_loss / num_pos_avg_per_gpu + + box_regression_flatten = box_regression_flatten[pos_inds] + reg_targets_flatten = reg_targets_flatten[pos_inds] + anchors_flatten = anchors_flatten[pos_inds] + centerness_flatten = centerness_flatten[pos_inds] + + if pos_inds.numel() > 0: + centerness_targets = self.compute_centerness_targets(reg_targets_flatten, anchors_flatten) + + sum_centerness_targets_avg_per_gpu = 
reduce_sum(centerness_targets.sum()).item() / float(num_gpus) + reg_loss = self.GIoULoss(box_regression_flatten, reg_targets_flatten, anchors_flatten, + weight=centerness_targets) / sum_centerness_targets_avg_per_gpu + centerness_loss = self.centerness_loss_func(centerness_flatten, centerness_targets) / num_pos_avg_per_gpu + else: + reg_loss = box_regression_flatten.sum() + reduce_sum(centerness_flatten.new_tensor([0.0])) + centerness_loss = centerness_flatten.sum() + + return cls_loss, reg_loss * self.cfg.MODEL.ATSS.REG_LOSS_WEIGHT, centerness_loss, \ + token_logits_loss, \ + contrastive_align_loss, \ + dot_product_token_loss, \ + shallow_contrastive_loss + + +def generate_anchor_labels(matched_targets): + labels_per_image = matched_targets.get_field("labels") + return labels_per_image + + +def make_focal_loss_evaluator(cfg, box_coder): + matcher = Matcher( + cfg.MODEL.FOCAL.FG_IOU_THRESHOLD, + cfg.MODEL.FOCAL.BG_IOU_THRESHOLD, + allow_low_quality_matches=True, + ) + sigmoid_focal_loss = SigmoidFocalLoss( + cfg.MODEL.FOCAL.LOSS_GAMMA, + cfg.MODEL.FOCAL.LOSS_ALPHA + ) + + loss_evaluator = FocalLossComputation( + matcher, + box_coder, + generate_anchor_labels, + sigmoid_focal_loss, + bbox_reg_beta=cfg.MODEL.FOCAL.BBOX_REG_BETA, + regress_norm=cfg.MODEL.FOCAL.BBOX_REG_WEIGHT, + ) + return loss_evaluator + + +def make_rpn_loss_evaluator(cfg, box_coder): + matcher = Matcher( + cfg.MODEL.RPN.FG_IOU_THRESHOLD, + cfg.MODEL.RPN.BG_IOU_THRESHOLD, + allow_low_quality_matches=True, + ) + + fg_bg_sampler = BalancedPositiveNegativeSampler( + cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, cfg.MODEL.RPN.POSITIVE_FRACTION + ) + + loss_evaluator = RPNLossComputation(matcher, fg_bg_sampler, box_coder) + return loss_evaluator + + +def make_fcos_loss_evaluator(cfg): + loss_evaluator = FCOSLossComputation(cfg) + return loss_evaluator + + +def make_atss_loss_evaluator(cfg, box_coder): + loss_evaluator = ATSSLossComputation(cfg, box_coder) + return loss_evaluator diff --git a/maskrcnn_benchmark/modeling/rpn/modeling_bert.py b/maskrcnn_benchmark/modeling/rpn/modeling_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..f7eda26e6e13262cb281d7a53acd2f5e515fd391 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/modeling_bert.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. 
""" + + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +import pdb +from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer + + +def clamp_values(vector, min_val = -50000, max_val = 50000): + vector = torch.clamp(vector, min = min_val, max = max_val) + return vector + + +class BertSelfAttention(nn.Module): + def __init__(self, config, clamp_min_for_underflow=False, clamp_max_for_overflow=False): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.clamp_min_for_underflow = clamp_min_for_underflow + self.clamp_max_for_overflow = clamp_max_for_overflow + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if self.clamp_min_for_underflow: + attention_scores = torch.clamp(attention_scores, min=-50000) # Do not increase -50000, data type half has quite limited range + if self.clamp_max_for_overflow: + attention_scores = torch.clamp(attention_scores, max=50000) # Do not increase 50000, data type half has quite limited range + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # 
Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # if math.isnan(attention_probs.sum().item()): + # for i in range(attention_probs.size(1)): + # for j in range(attention_probs.size(2)): + # if math.isnan(attention_probs[0, i, j].sum().item()): + # print(i, j) + # pdb.set_trace() + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, clamp_min_for_underflow=False, clamp_max_for_overflow=False): + super().__init__() + self.self = BertSelfAttention(config, clamp_min_for_underflow, clamp_max_for_overflow) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, 
hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = clamp_values(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = clamp_values(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = clamp_values(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + hidden_states = clamp_values(hidden_states) + return hidden_states + diff --git a/maskrcnn_benchmark/modeling/rpn/retina.py b/maskrcnn_benchmark/modeling/rpn/retina.py new file mode 100644 index 0000000000000000000000000000000000000000..146449c7cc930bef93d89471d021979bdea7546e --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/retina.py @@ -0,0 +1,156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import math +import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.modeling import registry +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from .loss import make_focal_loss_evaluator +from .anchor_generator import make_anchor_generator_complex +from .inference import make_retina_postprocessor + + +@registry.RPN_HEADS.register("RetinaNetHead") +class RetinaNetHead(torch.nn.Module): + """ + Adds a RetinNet head with classification and regression heads + """ + + def __init__(self, cfg): + """ + Arguments: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + """ + super(RetinaNetHead, self).__init__() + # TODO: Implement the sigmoid version first. 
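+        # Classification uses per-class sigmoid with focal loss, so no explicit
+        # background class is predicted and NUM_CLASSES is reduced by one below.
+        # With typical RetinaNet settings (3 aspect ratios x 3 scales per octave on
+        # FPN) num_anchors works out to 9 per spatial location; the actual values
+        # are taken from cfg.MODEL.RPN.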
+ num_classes = cfg.MODEL.RETINANET.NUM_CLASSES - 1 + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + if cfg.MODEL.RPN.USE_FPN: + num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE + else: + num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * len(cfg.MODEL.RPN.ANCHOR_SIZES) + + cls_tower = [] + bbox_tower = [] + for i in range(cfg.MODEL.RETINANET.NUM_CONVS): + cls_tower.append( + nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1 + ) + ) + cls_tower.append(nn.ReLU()) + bbox_tower.append( + nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1 + ) + ) + bbox_tower.append(nn.ReLU()) + + self.add_module('cls_tower', nn.Sequential(*cls_tower)) + self.add_module('bbox_tower', nn.Sequential(*bbox_tower)) + self.cls_logits = nn.Conv2d( + in_channels, num_anchors * num_classes, kernel_size=3, stride=1, + padding=1 + ) + self.bbox_pred = nn.Conv2d( + in_channels, num_anchors * 4, kernel_size=3, stride=1, + padding=1 + ) + + # Initialization + for modules in [self.cls_tower, self.bbox_tower, self.cls_logits, + self.bbox_pred]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + + # retinanet_bias_init + prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB + bias_value = -math.log((1 - prior_prob) / prior_prob) + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + + def forward(self, x): + logits = [] + bbox_reg = [] + for feature in x: + logits.append(self.cls_logits(self.cls_tower(feature))) + bbox_reg.append(self.bbox_pred(self.bbox_tower(feature))) + return logits, bbox_reg + + +class RetinaNetModule(torch.nn.Module): + """ + Module for RetinaNet computation. Takes feature maps from the backbone and + RetinaNet outputs and losses. Only Test on FPN now. + """ + + def __init__(self, cfg): + super(RetinaNetModule, self).__init__() + + self.cfg = cfg.clone() + + anchor_generator = make_anchor_generator_complex(cfg) + head = RetinaNetHead(cfg) + + box_coder = BoxCoder(weights=(10., 10., 5., 5.)) + + box_selector_test = make_retina_postprocessor(cfg, box_coder, is_train=False) + + loss_evaluator = make_focal_loss_evaluator(cfg, box_coder) + + self.anchor_generator = anchor_generator + self.head = head + self.box_selector_test = box_selector_test + self.loss_evaluator = loss_evaluator + + def forward(self, images, features, targets=None): + """ + Arguments: + images (ImageList): images for which we want to compute the predictions + features (list[Tensor]): features computed from the images that are + used for computing the predictions. Each tensor in the list + correspond to different feature levels + targets (list[BoxList): ground-truth boxes present in the image (optional) + + Returns: + boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per + image. + losses (dict[Tensor]): the losses for the model during training. During + testing, it is an empty dict. 
+ """ + box_cls, box_regression = self.head(features) + anchors = self.anchor_generator(images, features) + + if self.training: + return self._forward_train(anchors, box_cls, box_regression, targets) + else: + return self._forward_test(anchors, box_cls, box_regression) + + def _forward_train(self, anchors, box_cls, box_regression, targets): + + loss_box_cls, loss_box_reg = self.loss_evaluator( + anchors, box_cls, box_regression, targets + ) + losses = { + "loss_retina_cls": loss_box_cls, + "loss_retina_reg": loss_box_reg, + } + return anchors, losses + + def _forward_test(self, anchors, box_cls, box_regression): + boxes = self.box_selector_test(anchors, box_cls, box_regression) + return boxes, {} + + diff --git a/maskrcnn_benchmark/modeling/rpn/rpn.py b/maskrcnn_benchmark/modeling/rpn/rpn.py new file mode 100644 index 0000000000000000000000000000000000000000..7f300f67773358a11d890999a556a0dbea3bfdeb --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/rpn.py @@ -0,0 +1,171 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch +import torch.nn.functional as F +from torch import nn + +from maskrcnn_benchmark.modeling import registry +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from .loss import make_rpn_loss_evaluator +from .anchor_generator import make_anchor_generator +from .inference import make_rpn_postprocessor + + +@registry.RPN_HEADS.register("SimpleRPNHead") +class mRPNHead(nn.Module): + """ + Adds a simple RPN Head with classification and regression heads + """ + + def __init__(self, cfg, in_channels, num_anchors): + """ + Arguments: + cfg : config + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + """ + super(mRPNHead, self).__init__() + self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) + self.bbox_pred = nn.Conv2d( + in_channels, num_anchors * 4, kernel_size=1, stride=1 + ) + + for l in [self.cls_logits, self.bbox_pred]: + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + def forward(self, x): + logits = [] + bbox_reg = [] + for feature in x: + t = F.relu(feature) + logits.append(self.cls_logits(t)) + bbox_reg.append(self.bbox_pred(t)) + return logits, bbox_reg + + +@registry.RPN_HEADS.register("SingleConvRPNHead") +class RPNHead(nn.Module): + """ + Adds a simple RPN Head with classification and regression heads + """ + + def __init__(self, cfg, in_channels, num_anchors): + """ + Arguments: + cfg : config + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + """ + super(RPNHead, self).__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) + self.bbox_pred = nn.Conv2d( + in_channels, num_anchors * 4, kernel_size=1, stride=1 + ) + + for l in [self.conv, self.cls_logits, self.bbox_pred]: + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + def forward(self, x): + logits = [] + bbox_reg = [] + for feature in x: + t = F.relu(self.conv(feature)) + logits.append(self.cls_logits(t)) + bbox_reg.append(self.bbox_pred(t)) + return logits, bbox_reg + + +class RPNModule(torch.nn.Module): + """ + Module for RPN computation. Takes feature maps from the backbone and RPN + proposals and losses. Works for both FPN and non-FPN. 
+ """ + + def __init__(self, cfg): + super(RPNModule, self).__init__() + + self.cfg = cfg.clone() + + anchor_generator = make_anchor_generator(cfg) + + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + rpn_head = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD] + head = rpn_head( + cfg, in_channels, anchor_generator.num_anchors_per_location()[0] + ) + + rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + + box_selector_train = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=True) + box_selector_test = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=False) + + loss_evaluator = make_rpn_loss_evaluator(cfg, rpn_box_coder) + + self.anchor_generator = anchor_generator + self.head = head + self.box_selector_train = box_selector_train + self.box_selector_test = box_selector_test + self.loss_evaluator = loss_evaluator + + def forward(self, images, features, targets=None): + """ + Arguments: + images (ImageList): images for which we want to compute the predictions + features (list[Tensor]): features computed from the images that are + used for computing the predictions. Each tensor in the list + correspond to different feature levels + targets (list[BoxList): ground-truth boxes present in the image (optional) + + Returns: + boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per + image. + losses (dict[Tensor]): the losses for the model during training. During + testing, it is an empty dict. + """ + objectness, rpn_box_regression = self.head(features) + anchors = self.anchor_generator(images, features) + + if self.training: + return self._forward_train(anchors, objectness, rpn_box_regression, targets) + else: + return self._forward_test(anchors, objectness, rpn_box_regression) + + def _forward_train(self, anchors, objectness, rpn_box_regression, targets): + if self.cfg.MODEL.RPN_ONLY: + # When training an RPN-only model, the loss is determined by the + # predicted objectness and rpn_box_regression values and there is + # no need to transform the anchors into predicted boxes; this is an + # optimization that avoids the unnecessary transformation. + boxes = anchors + else: + # For end-to-end models, anchors must be transformed into boxes and + # sampled into a training batch. + with torch.no_grad(): + boxes = self.box_selector_train( + anchors, objectness, rpn_box_regression, targets + ) + loss_objectness, loss_rpn_box_reg = self.loss_evaluator( + anchors, objectness, rpn_box_regression, targets + ) + losses = { + "loss_objectness": loss_objectness, + "loss_rpn_box_reg": loss_rpn_box_reg, + } + return boxes, losses + + def _forward_test(self, anchors, objectness, rpn_box_regression): + boxes = self.box_selector_test(anchors, objectness, rpn_box_regression) + if self.cfg.MODEL.RPN_ONLY: + # For end-to-end models, the RPN proposals are an intermediate state + # and don't bother to sort them in decreasing score order. For RPN-only + # models, the proposals are the final output and we return them in + # high-to-low confidence order. 
+ inds = [ + box.get_field("objectness").sort(descending=True)[1] for box in boxes + ] + boxes = [box[ind] for box, ind in zip(boxes, inds)] + return boxes, {} \ No newline at end of file diff --git a/maskrcnn_benchmark/modeling/rpn/transformer.py b/maskrcnn_benchmark/modeling/rpn/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f0cd1efc216113cb3ef78896356cc3c35c6354 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/transformer.py @@ -0,0 +1,52 @@ +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +import copy +from typing import Optional, List + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None): + src2 = self.self_attn(src, src, src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src diff --git a/maskrcnn_benchmark/modeling/rpn/vldyhead.py b/maskrcnn_benchmark/modeling/rpn/vldyhead.py new file mode 100644 index 0000000000000000000000000000000000000000..2edbb5d477c80e9abe760320fb7311fcc3efdcbe --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/vldyhead.py @@ -0,0 +1,1036 @@ +import torch +import torch.nn.functional as F +from torch import nn +from collections import defaultdict + +from .inference import make_atss_postprocessor +from .loss import make_atss_loss_evaluator +from .anchor_generator import make_anchor_generator_complex + +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.layers import Scale, DYReLU, SELayer, ModulatedDeformConv +from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d +from maskrcnn_benchmark.modeling.backbone.fbnet import * +from maskrcnn_benchmark.engine.inference import create_positive_map_label_to_token_from_positive_map +from ..utils import cat, concat_box_prediction_layers, permute_and_flatten + +from maskrcnn_benchmark.utils.fuse_helper import FeatureResizer, func_attention, _make_mlp, _make_conv, _make_coord, \ + BiAttentionBlock, AttentionT2I, BiAttentionBlockForCheckpoint, BertLMPredictionHead +from transformers.models.bert.modeling_bert import BertConfig, BertAttention, BertIntermediate, BertOutput, \ + BertPreTrainedModel +from transformers.modeling_utils 
import apply_chunking_to_forward +import torch.utils.checkpoint as checkpoint +import pdb + +from maskrcnn_benchmark.modeling.language_backbone.clip_model import QuickGELU, LayerNorm, DropPath +from timm.models.layers import DropPath, trunc_normal_ + +class h_sigmoid(nn.Module): + def __init__(self, inplace=True, h_max=1): + super(h_sigmoid, self).__init__() + self.relu = nn.ReLU6(inplace=inplace) + self.h_max = h_max + + def forward(self, x): + return self.relu(x + 3) * self.h_max / 6 + + +class BoxCoder(object): + + def __init__(self, cfg): + self.cfg = cfg + + def encode(self, gt_boxes, anchors): + TO_REMOVE = 1 # TODO remove + ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE + ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE + ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2 + ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2 + + gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE + gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE + gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2 + gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2 + + wx, wy, ww, wh = (10., 10., 5., 5.) + targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights + targets_dw = ww * torch.log(gt_widths / ex_widths) + targets_dh = wh * torch.log(gt_heights / ex_heights) + targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + + return targets + + def decode(self, preds, anchors): + anchors = anchors.to(preds.dtype) + + TO_REMOVE = 1 # TODO remove + widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE + heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE + ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2 + ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2 + + wx, wy, ww, wh = (10., 10., 5., 5.) + dx = preds[:, 0::4] / wx + dy = preds[:, 1::4] / wy + dw = preds[:, 2::4] / ww + dh = preds[:, 3::4] / wh + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=math.log(1000. / 16)) + dh = torch.clamp(dh, max=math.log(1000. 
/ 16)) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = torch.exp(dw) * widths[:, None] + pred_h = torch.exp(dh) * heights[:, None] + + pred_boxes = torch.zeros_like(preds) + pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1) + pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1) + pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1) + pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1) + + return pred_boxes + + +class Conv3x3Norm(torch.nn.Module): + def __init__(self, + in_channels, + out_channels, + stride, + groups=1, + deformable=False, + bn_type=None): + super(Conv3x3Norm, self).__init__() + + if deformable: + self.conv = ModulatedDeformConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, + groups=groups) + else: + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups) + + if isinstance(bn_type, (list, tuple)): + assert len(bn_type) == 2 + assert bn_type[0] == "gn" + gn_group = bn_type[1] + bn_type = bn_type[0] + + if bn_type == "bn": + bn_op = nn.BatchNorm2d(out_channels) + elif bn_type == "sbn": + bn_op = nn.SyncBatchNorm(out_channels) + elif bn_type == "nsbn": + bn_op = NaiveSyncBatchNorm2d(out_channels) + elif bn_type == "gn": + bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=out_channels) + elif bn_type == "af": + bn_op = FrozenBatchNorm2d(out_channels) + if bn_type is not None: + self.bn = bn_op + else: + self.bn = None + + def forward(self, input, **kwargs): + x = self.conv(input, **kwargs) + if self.bn: + x = self.bn(x) + return x + + +class DyConv(torch.nn.Module): + def __init__(self, + in_channels=256, + out_channels=256, + conv_func=nn.Conv2d, + use_dyfuse=True, + use_dyrelu=False, + use_deform=False + ): + super(DyConv, self).__init__() + + self.DyConv = nn.ModuleList() + self.DyConv.append(conv_func(in_channels, out_channels, 1)) + self.DyConv.append(conv_func(in_channels, out_channels, 1)) + self.DyConv.append(conv_func(in_channels, out_channels, 2)) + + if use_dyfuse: + self.AttnConv = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, 1, kernel_size=1), + nn.ReLU(inplace=True)) + self.h_sigmoid = h_sigmoid() + else: + self.AttnConv = None + + if use_dyrelu: + self.relu = DYReLU(in_channels, out_channels) + else: + self.relu = nn.ReLU() + + if use_deform: + self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1) + else: + self.offset = None + + self.init_weights() + + def init_weights(self): + for m in self.DyConv.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight.data, 0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + if self.AttnConv is not None: + for m in self.AttnConv.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight.data, 0, 0.01) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, inputs): + visual_feats = inputs["visual"] + language_dict_features = inputs["lang"] + + next_x = [] + for level, feature in enumerate(visual_feats): + + conv_args = dict() + if self.offset is not None: + offset_mask = self.offset(feature) + offset = offset_mask[:, :18, :, :] + mask = offset_mask[:, 18:, :, :].sigmoid() + conv_args = dict(offset=offset, mask=mask) + + temp_fea = [self.DyConv[1](feature, **conv_args)] + + if level > 0: + temp_fea.append(self.DyConv[2](visual_feats[level - 1], **conv_args)) + if level < len(visual_feats) - 1: + temp_fea.append(F.upsample_bilinear(self.DyConv[0](visual_feats[level + 1], **conv_args), + 
size=[feature.size(2), feature.size(3)])) + mean_fea = torch.mean(torch.stack(temp_fea), dim=0, keepdim=False) + + if self.AttnConv is not None: + attn_fea = [] + res_fea = [] + for fea in temp_fea: + res_fea.append(fea) + attn_fea.append(self.AttnConv(fea)) + + res_fea = torch.stack(res_fea) + spa_pyr_attn = self.h_sigmoid(torch.stack(attn_fea)) + + mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False) + + next_x.append(mean_fea) + + next_x = [self.relu(item) for item in next_x] + + features_dict = {"visual": next_x, + "lang": language_dict_features} + + return features_dict + + +class BertEncoderLayer(BertPreTrainedModel): + def __init__(self, config, clamp_min_for_underflow = False, clamp_max_for_overflow = False): + super().__init__(config) + self.config = config + + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + + from maskrcnn_benchmark.modeling.rpn.modeling_bert import BertAttention, BertIntermediate, BertOutput + + self.attention = BertAttention(config, clamp_min_for_underflow, clamp_max_for_overflow) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, inputs): + language_dict_features = inputs["lang"] + hidden_states = language_dict_features["hidden"] + attention_mask = language_dict_features["masks"] + + device = hidden_states.device + input_shape = hidden_states.size()[:-1] + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + self_attention_outputs = self.attention( + hidden_states, + extended_attention_mask, + None, + output_attentions=False, + past_key_value=None, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + hidden_states = outputs[0] + + language_dict_features["hidden"] = hidden_states + + features_dict = {"visual": inputs["visual"], + "lang": language_dict_features + } + + return features_dict + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class CLIPTransformerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + d_model = self.config.MODEL.CLIP.WIDTH + n_head = self.config.MODEL.CLIP.HEADS + drop_path = self.config.MODEL.CLIP.DROP_PATH + self.context_length = self.config.MODEL.CLIP.CONTEXT_LENGTH + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = None + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): + nn.init.constant_(m.bias, 0) + + def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \ + if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0] + + def forward(self, inputs): + language_dict_features = inputs["lang"] + x = language_dict_features["hidden"] + mask = language_dict_features["masks"] + # get extended attention mask for nn.MultiHeadAttention + key_padding_mask = (1.0 - mask).to(torch.bool) + + x = x.permute(1, 0, 2) + x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + x = x.permute(1, 0, 2) + + language_dict_features["hidden"] = x + features_dict = {"visual": inputs["visual"], + "lang": language_dict_features + } + return features_dict + + +class DummyLayer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inputs): + return inputs + + +class VLFuse(torch.nn.Module): + """ + Early Fusion Module + """ + + def __init__(self, cfg): + super(VLFuse, self).__init__() + self.init_configs(cfg) + self.cfg = cfg + + self.use_checkpoint = False + if hasattr(cfg.MODEL.DYHEAD, 'USE_CHECKPOINT'): + self.use_checkpoint = cfg.MODEL.DYHEAD.USE_CHECKPOINT + self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True) + + # early fusion module + print("EARLY FUSION ON, USING {}".format(cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE)) + if cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-S": + # single-direction (text->image) + # text -> image + self.t2i_attn = AttentionT2I(q_dim=self.joint_embedding_size, + k_dim=self.lang_dim, + embed_dim=self.embed_dim, + num_heads=self.n_head, + hidden_dim=self.t2i_hidden_dim, + dropout=0.1, + drop_path=.0, + init_values=1.0 / cfg.MODEL.DYHEAD.NUM_CONVS, + mode="t2i", + use_layer_scale=cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_LAYER_SCALE, + clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW, + clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW + ) + + elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-B": + # bi-direction (text->image, image->text) + self.b_attn = BiAttentionBlockForCheckpoint(v_dim=self.joint_embedding_size, + l_dim=self.lang_dim, + embed_dim=self.embed_dim, + num_heads=self.n_head, + hidden_dim=self.i2t_hidden_dim, + dropout=0.1, + drop_path=.0, + init_values=1.0 / cfg.MODEL.DYHEAD.NUM_CONVS, + cfg=cfg + ) + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL and self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT: + self.shrink_lang = FeatureResizer(self.lang_dim * 5, + self.lang_dim, 0.1) + + + elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "SCAN": + # single-direction (text->image) + self.mapping_lang = _make_mlp(self.lang_dim, + self.joint_embedding_size, + self.joint_embedding_dropout) + self.joint_fusion = nn.ModuleList([_make_conv(self.joint_inp_dim, self.joint_out_dim, 1) \ + for _ in range(5)]) + + elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "FILM": + # single-direction (text->image) + self.mapping_lang = _make_mlp(self.lang_dim, + self.joint_embedding_size, + self.joint_embedding_dropout) + self.gamma = 
nn.ModuleList(nn.Linear(self.joint_embedding_size, self.joint_inp_dim) for _ in range(5)) + self.beta = nn.ModuleList(nn.Linear(self.joint_embedding_size, self.joint_inp_dim) for _ in range(5)) + + self.joint_fusion = nn.ModuleList([_make_conv(self.joint_inp_dim, self.joint_out_dim, 1) \ + for _ in range(5)]) + + else: + print("NO FUSION INVOLVED.") + + def init_configs(self, cfg): + # common params + self.lang_model = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE + self.joint_embedding_size = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE + self.joint_embedding_dropout = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT + self.joint_mlp_layers = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_MLP_LAYERS + + self.max_query_len = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN + self.n_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS + self.coord_dim = 8 + self.joint_inp_dim = self.coord_dim + self.joint_embedding_size + self.joint_out_dim = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_OUT_SIZE + + # mha params + self.n_head = 8 + self.embed_dim = 2048 + self.t2i_hidden_dim = 1024 # 256 * 4 + self.i2t_hidden_dim = 3072 # 768 * 4 + + if self.lang_model in ["bert-base-uncased", "roberta-base", "clip"]: + self.lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM + else: + self.lang_dim = 1024 + + def forward(self, x): + visual_features = x["visual"] + language_dict_features = x["lang"] + + batch_size = visual_features[0].shape[0] + device = visual_features[0].device + + fused_visual_features = None + fused_language_dict_features = None + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-S": + language_feature = language_dict_features['hidden'] + mask = language_dict_features['masks'] + # text -> image + if self.use_checkpoint: + q0, q1, q2, q3, q4 = checkpoint.checkpoint( + self.t2i_attn, + visual_features[0], visual_features[1], + visual_features[2], visual_features[3], + visual_features[4], + language_feature, language_feature, + mask, + self.dummy_tensor + ) + else: + q0, q1, q2, q3, q4 = self.t2i_attn( + visual_features[0], visual_features[1], + visual_features[2], visual_features[3], + visual_features[4], + language_feature, language_feature, + attention_mask=mask + ) + + fused_visual_features = [q0, q1, q2, q3, q4] + fused_language_dict_features = language_dict_features + + elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-B": + if self.use_checkpoint: + q0, q1, q2, q3, q4, l0, l1, l2, l3, l4 = checkpoint.checkpoint(self.b_attn, + visual_features[0], visual_features[1], + visual_features[2], visual_features[3], + visual_features[4], + language_dict_features['hidden'], + language_dict_features['masks'], + self.dummy_tensor + ) + else: + q0, q1, q2, q3, q4, l0, l1, l2, l3, l4 = self.b_attn( + visual_features[0], visual_features[1], + visual_features[2], visual_features[3], + visual_features[4], + language_dict_features['hidden'], + language_dict_features['masks'], + self.dummy_tensor + ) + + fused_visual_features = [q0, q1, q2, q3, q4] + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL and self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT: + language_features = self.shrink_lang(torch.cat([l0, l1, l2, l3, l4], dim = -1)) + else: + language_features = l0 + + language_dict_features['hidden'] = language_features + fused_language_dict_features = language_dict_features + + elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "SCAN": + # text -> image + language_feature = language_dict_features['aggregate'] + language_feature = self.mapping_lang(language_feature) + visu_feat = [] + for ii, feat in 
enumerate(visual_features): + attn_feat = func_attention(feat, language_feature, smooth=1, raw_feature_norm="softmax") + visu_feat.append(attn_feat) + + fused_visual_features = [fusion(feat) for feat, fusion in zip(visu_feat, self.joint_fusion)] + fused_language_dict_features = language_dict_features + + elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "FILM": + # text -> image + # relative position embedding + coord_feats = [_make_coord(batch_size, x.shape[2], x.shape[3]) for x in visual_features] + # I only use a global representation of language + # you can also use more complex modeling using word-level representations + # Usage: lang_feat = lang_feat['words'] shape [seq_len, dim] + language_feature = language_dict_features['aggregate'] + language_feature = self.mapping_lang(language_feature) + + # attention mechanism for fusion + gamma = [F.tanh(gamma(language_feature)) for gamma in self.gamma] + beta = [F.tanh(beta(language_feature)) for beta in self.beta] + + visu_feat = [] + for ii, feat in enumerate(visual_features): + coord_feat = coord_feats[ii].to(device) + feat = torch.cat([feat, coord_feat], dim=1) + b = beta[ii].view(batch_size, -1, 1, 1).expand_as(feat) + g = gamma[ii].view(batch_size, -1, 1, 1).expand_as(feat) + feat = F.relu(g * feat + b) + visu_feat.append(feat) + + fused_visual_features = [fusion(feat) for feat, fusion in zip(visu_feat, self.joint_fusion)] + fused_language_dict_features = language_dict_features + + else: + fused_visual_features = visual_features + fused_language_dict_features = language_dict_features + + features_dict = {"visual": fused_visual_features, + "lang": fused_language_dict_features} + + return features_dict + + +class VLDyHead(torch.nn.Module): + def __init__(self, cfg): + super(VLDyHead, self).__init__() + self.cfg = cfg + # bert_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE) + if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "bert-base-uncased": + lang_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE) + elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip": + lang_cfg = cfg + else: + lang_cfg = None + raise NotImplementedError + + num_classes = cfg.MODEL.DYHEAD.NUM_CLASSES - 1 + num_tokens = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN + num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE + in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS + channels = cfg.MODEL.DYHEAD.CHANNELS + + if cfg.MODEL.DYHEAD.USE_GN: + bn_type = ['gn', cfg.MODEL.GROUP_NORM.NUM_GROUPS] + elif cfg.MODEL.DYHEAD.USE_NSYNCBN: + bn_type = 'nsbn' + elif cfg.MODEL.DYHEAD.USE_SYNCBN: + bn_type = 'sbn' + else: + bn_type = None + + use_dyrelu = cfg.MODEL.DYHEAD.USE_DYRELU + use_dyfuse = cfg.MODEL.DYHEAD.USE_DYFUSE + use_deform = cfg.MODEL.DYHEAD.USE_DFCONV + + if cfg.MODEL.DYHEAD.CONV_FUNC: + conv_func = lambda i, o, s: eval(cfg.MODEL.DYHEAD.CONV_FUNC)(i, o, s, bn_type=bn_type) + else: + conv_func = lambda i, o, s: Conv3x3Norm(i, o, s, deformable=use_deform, bn_type=bn_type) + + dyhead_tower = [] + for i in range(cfg.MODEL.DYHEAD.NUM_CONVS): + if cfg.MODEL.DYHEAD.FUSE_CONFIG.EARLY_FUSE_ON: + # cross-modality fusion + dyhead_tower.append( + VLFuse(cfg) + ) + # self language path + if i < cfg.MODEL.DYHEAD.NUM_CONVS - 1 or cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT: + # dyhead_tower.append( + # BertEncoderLayer( + # bert_cfg, + # clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW, + # 
clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW) + # ) + if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "bert-base-uncased": + dyhead_tower.append( + BertEncoderLayer( + lang_cfg, + clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW, + clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW) + ) + elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip": + dyhead_tower.append( + CLIPTransformerLayer(lang_cfg) + ) + else: + raise NotImplementedError + + else: + dyhead_tower.append( + DummyLayer() + ) + + # self vision path + dyhead_tower.append( + DyConv( + in_channels if i == 0 else channels, + channels, + conv_func=conv_func, + use_dyrelu=(use_dyrelu and in_channels == channels) if i == 0 else use_dyrelu, + use_dyfuse=(use_dyfuse and in_channels == channels) if i == 0 else use_dyfuse, + use_deform=(use_deform and in_channels == channels) if i == 0 else use_deform, + ) + ) + + self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower)) + + self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=1) + self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=1) + self.centerness = nn.Conv2d(channels, num_anchors * 1, kernel_size=1) + + # initialize the bias for focal loss + prior_prob = cfg.MODEL.DYHEAD.PRIOR_PROB + bias_value = -math.log((1 - prior_prob) / prior_prob) + + log_scale = self.cfg.MODEL.DYHEAD.LOG_SCALE + + # soft token head + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS: + self.token_logits = nn.Conv2d(channels, num_anchors * num_tokens, kernel_size=1) + # ABLATION + # self.token_logits = nn.Conv2d(channels, num_anchors * num_tokens, kernel_size=1, bias=False) + # self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True) + # self.bias0 = nn.Parameter(torch.Tensor([bias_value]), requires_grad=True) + + # contrastive alignment head + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS: + assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS == False + contrastive_hdim = cfg.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_HIDDEN_DIM + self.contrastive_align_projection_image = nn.Conv2d(channels, num_anchors * contrastive_hdim, kernel_size=1) + self.contrastive_align_projection_text = nn.Linear(channels, contrastive_hdim, bias=True) + self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True) + + # dot product soft token head + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS == False + self.dot_product_projection_image = nn.Identity() + self.dot_product_projection_text = nn.Linear(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, + num_anchors * channels, bias=True) + self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True) + # DEBUG + # self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True) + self.bias_lang = nn.Parameter(torch.zeros(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM), requires_grad=True) + self.bias0 = nn.Parameter(torch.Tensor([bias_value]), requires_grad=True) + + # initialization + for modules in [self.cls_logits, self.bbox_pred, + self.centerness]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)]) + + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + + # if use soft token loss + if 
self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS: + for modules in [self.token_logits]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + torch.nn.init.constant_(self.token_logits.bias, bias_value) + # print(torch.norm(self.token_logits.weight)) + + # if use contrastive loss + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS: + for modules in [self.contrastive_align_projection_image]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + # if use dot product token loss + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + for modules in [self.dot_product_projection_image]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, bias_value) + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: + if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip": + lang_cfg = BertConfig.from_pretrained("bert-base-uncased") + lang_cfg.hidden_size = cfg.MODEL.CLIP.WIDTH + lang_cfg.vocab_size = cfg.MODEL.CLIP.VOCAB_SIZE + self.mlm_head = BertLMPredictionHead( + lang_cfg + ) #nn.Linear(hidden_size, config.vocab_size, bias=False) + + def forward(self, x, language_dict_features=None, embedding=None, swint_feature_c4=None): + logits = [] + bbox_reg = [] + centerness = [] + + feat_inputs = {"visual": x, + "lang": language_dict_features} + + dyhead_tower = self.dyhead_tower(feat_inputs) + + # soft token + t_logits = None + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS: + t_logits = [] + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT: + embedding = dyhead_tower["lang"]["hidden"] + + # MLM loss + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: + mlm_logits = self.mlm_head(embedding) + else: + mlm_logits = None + + # contrastive + contrastive_logits = None + proj_tokens = None + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS: + contrastive_logits = [] + # follow MDETR's way + proj_tokens = F.normalize( + self.contrastive_align_projection_text(embedding), p=2, dim=-1 + ) + + # dot product soft token + dot_product_logits = None + dot_product_proj_tokens = None + dot_product_proj_tokens_bias = None + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + dot_product_logits = [] + # norm + embedding = F.normalize(embedding, p=2, dim=-1) + dot_product_proj_tokens = self.dot_product_projection_text(embedding / 2.0) + # w/o norm + # dot_product_proj_tokens = self.dot_product_projection_text(embedding / 28.0) + + dot_product_proj_tokens_bias = torch.matmul(embedding, self.bias_lang) + self.bias0 + + # shallow contrastive (original feature from image & text encoder) + shallow_img_emb_feats = None + shallow_text_emb = None + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS \ + or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS: + shallow_img_emb_feats = [] + shallow_text_emb = embedding + + # print([v.shape for v in x]) + # shallow contrastive: use the feature from swint backbone + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS: + for b, feature in enumerate(swint_feature_c4): + # BF, CF, HF, WF = feat.shape + # shallow_img_emb = permute_and_flatten(feat, BF, -1, CF, HF, WF) + shallow_img_emb_feats.append(feature) + + fused_visual_features = None + if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES: + 
fused_visual_features = [] + + # use the feature from FPN + for l, feature in enumerate(x): + logits.append(self.cls_logits(dyhead_tower["visual"][l])) + + bbox_pred = self.scales[l](self.bbox_pred(dyhead_tower["visual"][l])) + bbox_reg.append(bbox_pred) + + centerness.append(self.centerness(dyhead_tower["visual"][l])) + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS: + t_logits.append(self.token_logits(dyhead_tower["visual"][l])) + + # ABLATION + # b = self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + # x = dyhead_tower["visual"][l] + # B, C, H, W = x.shape + # bias = b.repeat(B, 1, H, W) + # t_logits.append(self.token_logits(dyhead_tower["visual"][l] + bias) + self.bias0) + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS: + x = dyhead_tower["visual"][l] + B, _, H, W = x.shape + C = proj_tokens.shape[2] + proj_queries = self.contrastive_align_projection_image(dyhead_tower["visual"][l]) + proj_queries = permute_and_flatten(proj_queries, B, -1, C, H, W) + normalized_img_emb = F.normalize(proj_queries, p=2, dim=-1) + normalized_text_emb = proj_tokens + contrastive_logit = ( + torch.matmul(normalized_img_emb, normalized_text_emb.transpose(-1, -2)) / self.log_scale.exp()) + contrastive_logits.append(contrastive_logit) + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + x = dyhead_tower["visual"][l] + if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES: + fused_visual_features.append(x) + B, C, H, W = x.shape + + # add bias (language) + dot_product_proj_queries = self.dot_product_projection_image(x) + dot_product_proj_queries = permute_and_flatten(dot_product_proj_queries, B, -1, C, H, W) + + A = dot_product_proj_queries.shape[1] + bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(1, A, 1) + + dot_product_logit = (torch.matmul(dot_product_proj_queries, dot_product_proj_tokens.transpose(-1, -2)) / self.log_scale.exp()) + bias + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_DOT_PRODUCT: + dot_product_logit = torch.clamp(dot_product_logit, max=50000) + dot_product_logit = torch.clamp(dot_product_logit, min=-50000) + dot_product_logits.append(dot_product_logit) + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS: + feat = feature + BF, CF, HF, WF = feat.shape + shallow_img_emb = permute_and_flatten(feat, BF, -1, CF, HF, WF) + shallow_img_emb_feats.append(shallow_img_emb) + + # no matter the feature is from backboone or from fpn, we use shallow_img_embs all the time + if shallow_img_emb_feats is not None and shallow_text_emb is not None: + # shallow_img_embs = torch.cat(shallow_img_embs, dim=1) + proj_tokens = shallow_text_emb + return logits, bbox_reg, centerness, t_logits, proj_tokens, contrastive_logits, dot_product_logits, mlm_logits, shallow_img_emb_feats, fused_visual_features + + +class VLDyHeadModule(torch.nn.Module): + + def __init__(self, cfg): + super(VLDyHeadModule, self).__init__() + self.cfg = cfg + self.head = VLDyHead(cfg) + box_coder = BoxCoder(cfg) + self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder) + self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True) + self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False) + self.anchor_generator = make_anchor_generator_complex(cfg) + + self.lang_model = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE + self.joint_embedding_size = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE + self.joint_embedding_dropout = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT + if self.lang_model in ["bert-base-uncased", "roberta-base", "clip"]: + 
self.lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM + else: + self.lang_dim = 1024 + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS: + self.resizer = FeatureResizer( + input_feat_size=self.lang_dim, + output_feat_size=self.joint_embedding_size, + dropout=self.joint_embedding_dropout + ) + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER: + self.tunable_linear = torch.nn.Linear(self.lang_dim, 1000, bias=False) + self.tunable_linear.weight.data.fill_(0.0) + + def forward(self, images, features, targets=None, + language_dict_features=None, + positive_map=None, + captions=None, + swint_feature_c4=None + ): + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS: + # resizer needed + embedding = language_dict_features['embedded'] + embedding = self.resizer(embedding) + elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + # no resizer needed + embedding = language_dict_features['embedded'] + else: + embedding = None + + if "masks" in language_dict_features: + text_masks = language_dict_features["masks"] + else: + text_masks = None + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER: + embedding = self.tunable_linear.weight[:embedding.size(1), :].unsqueeze(0) + embedding + language_dict_features['embedded'] = embedding + language_dict_features['hidden'] = self.tunable_linear.weight[:embedding.size(1), :].unsqueeze(0) + language_dict_features['hidden'] + + box_cls, box_regression, centerness, token_logits, \ + proj_tokens, contrastive_logits, dot_product_logits, mlm_logits, shallow_img_emb_feats, fused_visual_features = self.head(features, + language_dict_features, + embedding, + swint_feature_c4 + ) + anchors = self.anchor_generator(images, features) + + if self.training: + return self._forward_train(box_cls, box_regression, centerness, targets, anchors, + captions, + positive_map, + token_logits, + proj_tokens, + contrastive_logits, + dot_product_logits, + text_masks, + mlm_logits = mlm_logits, + mlm_labels = language_dict_features["mlm_labels"], + shallow_img_emb_feats=shallow_img_emb_feats, + fused_visual_features=fused_visual_features + ) + else: + return self._forward_test(box_regression, centerness, anchors, + box_cls, + token_logits, + dot_product_logits, + positive_map, + fused_visual_features=fused_visual_features + ) + + def _forward_train(self, box_cls, box_regression, centerness, targets, anchors, + captions=None, + positive_map=None, + token_logits=None, + proj_tokens=None, + contrastive_logits=None, + dot_product_logits=None, + text_masks=None, + mlm_logits=None, + mlm_labels=None, + shallow_img_emb_feats=None, + fused_visual_features=None + ): + + loss_box_cls, loss_box_reg, loss_centerness, loss_token, loss_contrastive_align, loss_dot_product_token, loss_shallow_contrastive = self.loss_evaluator( + box_cls, box_regression, centerness, targets, anchors, + captions, + positive_map, + token_logits, + proj_tokens, + contrastive_logits, + dot_product_logits, + text_masks, + shallow_img_emb_feats + ) + + losses = { + # "loss_cls": loss_box_cls, + "loss_reg": loss_box_reg, + "loss_centerness": loss_centerness + } + + if mlm_labels is not None and mlm_logits is not None: + losses["mlm_loss"] = nn.CrossEntropyLoss(ignore_index = -100)(mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1)) * self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_COEF + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CLASSIFICATION_LOSS: + losses["loss_cls"] = loss_box_cls + else: + losses["loss_cls"] = 0.0 * loss_box_cls + + if 
self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS: + losses["loss_token"] = loss_token * self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_LOSS_WEIGHT + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS: + losses["loss_contrastive_align"] = loss_contrastive_align * \ + self.cfg.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_ALIGN_LOSS_WEIGHT + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: + losses["loss_dot_product_token"] = loss_dot_product_token * \ + self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DOT_PRODUCT_TOKEN_LOSS_WEIGHT + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS or \ + self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS: + losses["loss_shallow_contrastive"] = loss_shallow_contrastive * \ + self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_LOSS_WEIGHT + + if self.cfg.MODEL.RPN_ONLY: + return None, losses, None + else: + # Let's just use one image per batch + assert (box_regression[0].shape[0]) == 1 + positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=1) + boxes = self.box_selector_train(box_regression, centerness, anchors, + box_cls, + token_logits, + dot_product_logits, + positive_map=positive_map_label_to_token + ) + train_boxes = [] + for b, t in zip(boxes, targets): + tb = t.copy_with_fields(["labels"]) + tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device)) + train_boxes.append(cat_boxlist([b, tb])) + return train_boxes, losses, fused_visual_features + + def _forward_test(self, box_regression, centerness, anchors, + box_cls=None, + token_logits=None, + dot_product_logits=None, + positive_map=None, + fused_visual_features=None + ): + + boxes = self.box_selector_test(box_regression, centerness, anchors, + box_cls, + token_logits, + dot_product_logits, + positive_map, + ) + return boxes, {}, fused_visual_features diff --git a/maskrcnn_benchmark/modeling/utils.py b/maskrcnn_benchmark/modeling/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2834b105c8d171438a4534eb17fc0da65154d610 --- /dev/null +++ b/maskrcnn_benchmark/modeling/utils.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" +Miscellaneous utility functions +""" + +import torch + + +def cat(tensors, dim=0): + """ + Efficient version of torch.cat that avoids a copy if there is only a single element in a list + """ + assert isinstance(tensors, (list, tuple)) + if len(tensors) == 1: + return tensors[0] + return torch.cat(tensors, dim) + + +def permute_and_flatten(layer, N, A, C, H, W): + layer = layer.view(N, -1, C, H, W) + layer = layer.permute(0, 3, 4, 1, 2) + layer = layer.reshape(N, -1, C) + return layer + + +def concat_box_prediction_layers(box_regression, box_cls=None, token_logits=None): + box_regression_flattened = [] + box_cls_flattened = [] + token_logit_flattened = [] + + # for each feature level, permute the outputs to make them be in the + # same format as the labels. 
Note that the labels are computed for + # all feature levels concatenated, so we keep the same representation + # for the objectness and the box_regression + for box_cls_per_level, box_regression_per_level in zip( + box_cls, box_regression + ): + N, AxC, H, W = box_cls_per_level.shape + Ax4 = box_regression_per_level.shape[1] + A = Ax4 // 4 + C = AxC // A + box_cls_per_level = permute_and_flatten( + box_cls_per_level, N, A, C, H, W + ) + box_cls_flattened.append(box_cls_per_level) + + box_regression_per_level = permute_and_flatten( + box_regression_per_level, N, A, 4, H, W + ) + box_regression_flattened.append(box_regression_per_level) + + if token_logits is not None: + for token_logit_per_level in token_logits: + N, AXT, H, W = token_logit_per_level.shape + T = AXT // A + token_logit_per_level = permute_and_flatten( + token_logit_per_level, N, A, T, H, W + ) + token_logit_flattened.append(token_logit_per_level) + + # concatenate on the first dimension (representing the feature levels), to + # take into account the way the labels were generated (with all feature maps + # being concatenated as well) + box_cls = cat(box_cls_flattened, dim=1).reshape(-1, C) + box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4) + + token_logits_stacked = None + if token_logits is not None: + # stacked + token_logits_stacked = cat(token_logit_flattened, dim=1) + + return box_regression, box_cls, token_logits_stacked + + +def round_channels(channels, divisor=8): + rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor) + if float(rounded_channels) < 0.9 * channels: + rounded_channels += divisor + return rounded_channels diff --git a/maskrcnn_benchmark/solver/__init__.py b/maskrcnn_benchmark/solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..927668ea6f35aedcff25f779e85a8b8c27a8c797 --- /dev/null +++ b/maskrcnn_benchmark/solver/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from .build import make_optimizer +from .build import make_lr_scheduler +from .lr_scheduler import WarmupMultiStepLR diff --git a/maskrcnn_benchmark/solver/build.py b/maskrcnn_benchmark/solver/build.py new file mode 100644 index 0000000000000000000000000000000000000000..4456f914f2349d3a86642871161e95e4cd26af7d --- /dev/null +++ b/maskrcnn_benchmark/solver/build.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
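+
+# Overview (descriptive note): make_optimizer builds one parameter group per
+# parameter so that the language backbone, the visual backbone body, biases and
+# normalization layers can each get their own learning rate / weight decay,
+# controlled by the cfg.SOLVER.* factors used below; every group has the form
+# {"params": [p], "lr": lr, "weight_decay": weight_decay}.
+# make_lr_scheduler then selects a warmup multi-step, warmup cosine, or warmup
+# reduce-on-plateau schedule depending on the SOLVER flags.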
+import torch +import itertools + +from .lr_scheduler import WarmupMultiStepLR, WarmupCosineAnnealingLR, WarmupReduceLROnPlateau + + +def make_optimizer(cfg, model): + def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class + # detectron2 doesn't have full model gradient clipping now + clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE + enable = ( + cfg.SOLVER.CLIP_GRADIENTS.ENABLED + and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" + and clip_norm_val > 0.0 + ) + + class FullModelGradientClippingOptimizer(optim): + def step(self, closure=None): + all_params = itertools.chain(*[x["params"] for x in self.param_groups]) + torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) + super().step(closure=closure) + + return FullModelGradientClippingOptimizer if enable else optim + + params = [] + for key, value in model.named_parameters(): + if not value.requires_grad: + continue + lr = cfg.SOLVER.BASE_LR + weight_decay = cfg.SOLVER.WEIGHT_DECAY + + # different lr schedule + if "language_backbone" in key: + lr = cfg.SOLVER.LANG_LR + + if "backbone.body" in key and "language_backbone.body" not in key: + lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BACKBONE_BODY_LR_FACTOR + + if "bias" in key: + lr *= cfg.SOLVER.BIAS_LR_FACTOR + weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS + + if 'norm' in key or 'Norm' in key: + weight_decay *= cfg.SOLVER.WEIGHT_DECAY_NORM_FACTOR + print("Setting weight decay of {} to {}".format(key, weight_decay)) + + params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] + + if cfg.SOLVER.OPTIMIZER == "SGD": + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(params, lr, momentum=cfg.SOLVER.MOMENTUM) + elif cfg.SOLVER.OPTIMIZER == "ADAMW": + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(params, lr) + + return optimizer + + +def make_lr_scheduler(cfg, optimizer): + if cfg.SOLVER.MULTI_MAX_EPOCH: + assert len(cfg.SOLVER.MULTI_MAX_EPOCH) == len(cfg.SOLVER.STEPS) + lr_scheduler = [] + + for stage_step, stage_max_epoch in zip(cfg.SOLVER.STEPS, cfg.SOLVER.MULTI_MAX_ITER): + milestones = [] + for step in stage_step: + milestones.append(round(step * stage_max_epoch)) + lr_scheduler.append(WarmupMultiStepLR(optimizer, + milestones, + cfg.SOLVER.GAMMA, + warmup_factor=cfg.SOLVER.WARMUP_FACTOR, + warmup_iters=cfg.SOLVER.WARMUP_ITERS, + warmup_method=cfg.SOLVER.WARMUP_METHOD, ) + ) + return lr_scheduler + + elif cfg.SOLVER.USE_COSINE: + max_iters = cfg.SOLVER.MAX_ITER + return WarmupCosineAnnealingLR( + optimizer, + max_iters, + cfg.SOLVER.GAMMA, + warmup_factor=cfg.SOLVER.WARMUP_FACTOR, + warmup_iters=cfg.SOLVER.WARMUP_ITERS, + warmup_method=cfg.SOLVER.WARMUP_METHOD, + eta_min=cfg.SOLVER.MIN_LR + ) + + elif cfg.SOLVER.USE_AUTOSTEP: + max_iters = cfg.SOLVER.MAX_ITER + return WarmupReduceLROnPlateau( + optimizer, + max_iters, + cfg.SOLVER.GAMMA, + warmup_factor=cfg.SOLVER.WARMUP_FACTOR, + warmup_iters=cfg.SOLVER.WARMUP_ITERS, + warmup_method=cfg.SOLVER.WARMUP_METHOD, + eta_min=cfg.SOLVER.MIN_LR, + patience=cfg.SOLVER.STEP_PATIENCE, + verbose=True + ) + + else: + milestones = [] + for step in cfg.SOLVER.STEPS: + if step < 1: + milestones.append(round(step * cfg.SOLVER.MAX_ITER)) + else: + milestones.append(step) + return WarmupMultiStepLR( + optimizer, + milestones, + cfg.SOLVER.GAMMA, + warmup_factor=cfg.SOLVER.WARMUP_FACTOR, + warmup_iters=cfg.SOLVER.WARMUP_ITERS, + warmup_method=cfg.SOLVER.WARMUP_METHOD, + ) diff --git a/maskrcnn_benchmark/solver/lr_scheduler.py b/maskrcnn_benchmark/solver/lr_scheduler.py 
new file mode 100644 index 0000000000000000000000000000000000000000..a06a52d57a1da5433c06555a551753dfe38a0fa8 --- /dev/null +++ b/maskrcnn_benchmark/solver/lr_scheduler.py @@ -0,0 +1,164 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from bisect import bisect_right + +import math +import torch + + +# FIXME ideally this would be achieved with a CombinedLRScheduler, +# separating MultiStepLR with WarmupLR +# but the current LRScheduler design doesn't allow it +class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + optimizer, + milestones, + gamma=0.1, + warmup_factor=1.0 / 3, + warmup_iters=500, + warmup_method="linear", + last_epoch=-1, + ): + if not list(milestones) == sorted(milestones): + raise ValueError( + "Milestones should be a list of" " increasing integers. Got {}", + milestones, + ) + + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted" + "got {}".format(warmup_method) + ) + self.milestones = milestones + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + warmup_factor = 1 + if self.last_epoch < self.warmup_iters: + if self.warmup_method == "constant": + warmup_factor = self.warmup_factor + elif self.warmup_method == "linear": + alpha = float(self.last_epoch) / self.warmup_iters + warmup_factor = self.warmup_factor * (1 - alpha) + alpha + return [ + base_lr + * warmup_factor + * self.gamma ** bisect_right(self.milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + +class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + optimizer, + max_iters, + gamma=0.1, + warmup_factor=1.0 / 3, + warmup_iters=500, + warmup_method="linear", + eta_min = 0, + last_epoch=-1, + ): + + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted" + "got {}".format(warmup_method) + ) + self.max_iters = max_iters + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + self.eta_min = eta_min + super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + warmup_factor = 1 + + if self.last_epoch < self.warmup_iters: + if self.warmup_method == "constant": + warmup_factor = self.warmup_factor + elif self.warmup_method == "linear": + alpha = float(self.last_epoch) / self.warmup_iters + warmup_factor = self.warmup_factor * (1 - alpha) + alpha + return [ + base_lr + * warmup_factor + for base_lr in self.base_lrs + ] + else: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_iters) / self.max_iters)) / 2 + for base_lr in self.base_lrs + ] + +class WarmupReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): + def __init__( + self, + optimizer, + max_iters, + gamma=0.1, + warmup_factor=1.0 / 3, + warmup_iters=500, + warmup_method="linear", + eta_min = 0, + last_epoch=-1, + patience = 5, + verbose = False, + ): + + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted" + "got {}".format(warmup_method) + ) + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + self.eta_min = eta_min + + if last_epoch == -1: 
+ for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + else: + for i, group in enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + super(WarmupReduceLROnPlateau, self).__init__(optimizer, factor=gamma, patience=patience, mode='max', min_lr=eta_min, verbose = verbose) + + def step(self, metrics=None): + warmup_factor = 1 + + if self.last_epoch < self.warmup_iters: + if self.warmup_method == "constant": + warmup_factor = self.warmup_factor + elif self.warmup_method == "linear": + alpha = float(self.last_epoch) / self.warmup_iters + warmup_factor = self.warmup_factor * (1 - alpha) + alpha + + if self.last_epoch >= self.warmup_iters-1: + warmup_factor = 1.0 + + warmup_lrs = [ + base_lr + * warmup_factor + for base_lr in self.base_lrs + ] + + for param_group, lr in zip(self.optimizer.param_groups, warmup_lrs): + param_group['lr'] = lr + + self.last_epoch += 1 + elif metrics: + super().step(metrics) \ No newline at end of file diff --git a/maskrcnn_benchmark/structures/__init__.py b/maskrcnn_benchmark/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/structures/bounding_box.py b/maskrcnn_benchmark/structures/bounding_box.py new file mode 100644 index 0000000000000000000000000000000000000000..4b04683086ffad2345aed97b08d0c11ac385ba85 --- /dev/null +++ b/maskrcnn_benchmark/structures/bounding_box.py @@ -0,0 +1,321 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +# transpose +FLIP_LEFT_RIGHT = 0 +FLIP_TOP_BOTTOM = 1 + + +class BoxList(object): + """ + This class represents a set of bounding boxes. + The bounding boxes are represented as a Nx4 Tensor. + In order to uniquely determine the bounding boxes with respect + to an image, we also store the corresponding image dimensions. + They can contain extra information that is specific to each bounding box, such as + labels. 
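+
+    Illustrative example (shapes and values are arbitrary):
+        boxes = BoxList(torch.tensor([[0., 0., 10., 10.]]), (100, 100), mode="xyxy")
+        boxes.add_field("labels", torch.tensor([1]))
+        boxes = boxes.resize((50, 50)).convert("xywh")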
+ """ + + def __init__(self, bbox, image_size, mode="xyxy"): + device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu") + # only do as_tensor if isn't a "no-op", because it hurts JIT tracing + if (not isinstance(bbox, torch.Tensor) + or bbox.dtype != torch.float32 or bbox.device != device): + bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device) + if bbox.ndimension() != 2: + raise ValueError( + "bbox should have 2 dimensions, got {}".format(bbox.ndimension()) + ) + if bbox.size(-1) != 4: + raise ValueError( + "last dimenion of bbox should have a " + "size of 4, got {}".format(bbox.size(-1)) + ) + if mode not in ("xyxy", "xywh"): + raise ValueError("mode should be 'xyxy' or 'xywh'") + + self.bbox = bbox + self.size = image_size # (image_width, image_height) + self.mode = mode + self.extra_fields = {} + + # note: _jit_wrap/_jit_unwrap only work if the keys and the sizes don't change in between + def _jit_unwrap(self): + return (self.bbox,) + tuple(f for f in (self.get_field(field) + for field in sorted(self.fields())) + if isinstance(f, torch.Tensor)) + + def _jit_wrap(self, input_stream): + self.bbox = input_stream[0] + num_consumed = 1 + for f in sorted(self.fields()): + if isinstance(self.extra_fields[f], torch.Tensor): + self.extra_fields[f] = input_stream[num_consumed] + num_consumed += 1 + return self, input_stream[num_consumed:] + + def add_field(self, field, field_data): + self.extra_fields[field] = field_data + + def get_field(self, field): + return self.extra_fields[field] + + def has_field(self, field): + return field in self.extra_fields + + def fields(self): + return list(self.extra_fields.keys()) + + def _copy_extra_fields(self, bbox): + for k, v in bbox.extra_fields.items(): + self.extra_fields[k] = v + + def convert(self, mode): + if mode not in ("xyxy", "xywh"): + raise ValueError("mode should be 'xyxy' or 'xywh'") + if mode == self.mode: + return self + # we only have two modes, so don't need to check + # self.mode + xmin, ymin, xmax, ymax = self._split_into_xyxy() + if mode == "xyxy": + bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1) + bbox = BoxList(bbox, self.size, mode=mode) + else: + TO_REMOVE = 1 + # NOTE: explicitly specify dim to avoid tracing error in GPU + bbox = torch.cat( + (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=1 + ) + bbox = BoxList(bbox, self.size, mode=mode) + bbox._copy_extra_fields(self) + return bbox + + def _split_into_xyxy(self): + if self.mode == "xyxy": + xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1) + return xmin, ymin, xmax, ymax + elif self.mode == "xywh": + TO_REMOVE = 1 + xmin, ymin, w, h = self.bbox.split(1, dim=-1) + return ( + xmin, + ymin, + xmin + (w - TO_REMOVE).clamp(min=0), + ymin + (h - TO_REMOVE).clamp(min=0), + ) + else: + raise RuntimeError("Should not be here") + + def resize(self, size, *args, **kwargs): + """ + Returns a resized copy of this bounding box + + :param size: The requested size in pixels, as a 2-tuple: + (width, height). 
+ """ + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) + if ratios[0] == ratios[1]: + ratio = ratios[0] + scaled_box = self.bbox * ratio + bbox = BoxList(scaled_box, size, mode=self.mode) + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.resize(size, *args, **kwargs) + bbox.add_field(k, v) + return bbox + + ratio_width, ratio_height = ratios + xmin, ymin, xmax, ymax = self._split_into_xyxy() + scaled_xmin = xmin * ratio_width + scaled_xmax = xmax * ratio_width + scaled_ymin = ymin * ratio_height + scaled_ymax = ymax * ratio_height + scaled_box = torch.cat( + (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1 + ) + bbox = BoxList(scaled_box, size, mode="xyxy") + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.resize(size, *args, **kwargs) + bbox.add_field(k, v) + + return bbox.convert(self.mode) + + def transpose(self, method): + """ + Transpose bounding box (flip or rotate in 90 degree steps) + :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`, + :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`, + :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`, + :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`. + """ + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + image_width, image_height = self.size + xmin, ymin, xmax, ymax = self._split_into_xyxy() + if method == FLIP_LEFT_RIGHT: + TO_REMOVE = 1 + transposed_xmin = image_width - xmax - TO_REMOVE + transposed_xmax = image_width - xmin - TO_REMOVE + transposed_ymin = ymin + transposed_ymax = ymax + elif method == FLIP_TOP_BOTTOM: + transposed_xmin = xmin + transposed_xmax = xmax + transposed_ymin = image_height - ymax + transposed_ymax = image_height - ymin + + transposed_boxes = torch.cat( + (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1 + ) + bbox = BoxList(transposed_boxes, self.size, mode="xyxy") + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.transpose(method) + bbox.add_field(k, v) + return bbox.convert(self.mode) + + def crop(self, box): + """ + Cropss a rectangular region from this bounding box. The box is a + 4-tuple defining the left, upper, right, and lower pixel + coordinate. + """ + xmin, ymin, xmax, ymax = self._split_into_xyxy() + w, h = box[2] - box[0], box[3] - box[1] + cropped_xmin = (xmin - box[0]).clamp(min=0, max=w) + cropped_ymin = (ymin - box[1]).clamp(min=0, max=h) + cropped_xmax = (xmax - box[0]).clamp(min=0, max=w) + cropped_ymax = (ymax - box[1]).clamp(min=0, max=h) + + # TODO should I filter empty boxes here? 
+ cropped_box = torch.cat( + (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1 + ) + bbox = BoxList(cropped_box, (w, h), mode="xyxy") + # bbox._copy_extra_fields(self) + for k, v in self.extra_fields.items(): + if not isinstance(v, torch.Tensor): + v = v.crop(box) + bbox.add_field(k, v) + return bbox.convert(self.mode) + + # Tensor-like methods + + def to(self, device): + bbox = BoxList(self.bbox.to(device), self.size, self.mode) + for k, v in self.extra_fields.items(): + if hasattr(v, "to"): + v = v.to(device) + bbox.add_field(k, v) + return bbox + + def __getitem__(self, item): + bbox = BoxList(self.bbox[item], self.size, self.mode) + for k, v in self.extra_fields.items(): + bbox.add_field(k, v[item]) + return bbox + + def __len__(self): + return self.bbox.shape[0] + + def clip_to_image(self, remove_empty=True): + TO_REMOVE = 1 + x1s = self.bbox[:, 0].clamp(min=0, max=self.size[0] - TO_REMOVE) + y1s = self.bbox[:, 1].clamp(min=0, max=self.size[1] - TO_REMOVE) + x2s = self.bbox[:, 2].clamp(min=0, max=self.size[0] - TO_REMOVE) + y2s = self.bbox[:, 3].clamp(min=0, max=self.size[1] - TO_REMOVE) + self.bbox = torch.stack((x1s, y1s, x2s, y2s), dim=-1) + if remove_empty: + box = self.bbox + keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) + return self[keep] + return self + + def area(self): + if self.mode == 'xyxy': + TO_REMOVE = 1 + box = self.bbox + area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE) + elif self.mode == 'xywh': + box = self.bbox + area = box[:, 2] * box[:, 3] + else: + raise RuntimeError("Should not be here") + + return area + + def copy_with_fields(self, fields): + bbox = BoxList(self.bbox, self.size, self.mode) + if not isinstance(fields, (list, tuple)): + fields = [fields] + for field in fields: + bbox.add_field(field, self.get_field(field)) + return bbox + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_boxes={}, ".format(len(self)) + s += "image_width={}, ".format(self.size[0]) + s += "image_height={}, ".format(self.size[1]) + s += "mode={})".format(self.mode) + return s + + @staticmethod + def concate_box_list(list_of_boxes): + boxes = torch.cat([i.bbox for i in list_of_boxes], dim = 0) + extra_fields_keys = list(list_of_boxes[0].extra_fields.keys()) + extra_fields = {} + for key in extra_fields_keys: + extra_fields[key] = torch.cat([i.extra_fields[key] for i in list_of_boxes], dim = 0) + + final = list_of_boxes[0].copy_with_fields(extra_fields_keys) + + final.bbox = boxes + final.extra_fields = extra_fields + return final + +@torch.jit.unused +def _onnx_clip_boxes_to_image(boxes, size): + # type: (Tensor, Tuple[int, int]) + """ + Clip boxes so that they lie inside an image of size `size`. + Clip's min max are traced as constants. 
Use torch.min/max to WAR this issue + Arguments: + boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format + size (Tuple[height, width]): size of the image + Returns: + clipped_boxes (Tensor[N, 4]) + """ + TO_REMOVE = 1 + device = boxes.device + dim = boxes.dim() + boxes_x = boxes[..., 0::2] + boxes_y = boxes[..., 1::2] + + boxes_x = torch.max(boxes_x, torch.tensor(0., dtype=torch.float).to(device)) + boxes_x = torch.min(boxes_x, torch.tensor(size[1] - TO_REMOVE, dtype=torch.float).to(device)) + boxes_y = torch.max(boxes_y, torch.tensor(0., dtype=torch.float).to(device)) + boxes_y = torch.min(boxes_y, torch.tensor(size[0] - TO_REMOVE, dtype=torch.float).to(device)) + + clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim) + return clipped_boxes.reshape(boxes.shape) + + +if __name__ == "__main__": + bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10)) + s_bbox = bbox.resize((5, 5)) + print(s_bbox) + print(s_bbox.bbox) + + t_bbox = bbox.transpose(0) + print(t_bbox) + print(t_bbox.bbox) diff --git a/maskrcnn_benchmark/structures/boxlist_ops.py b/maskrcnn_benchmark/structures/boxlist_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..85eb081ca64bf4464d6e523759b82882498bf4da --- /dev/null +++ b/maskrcnn_benchmark/structures/boxlist_ops.py @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +from .bounding_box import BoxList + +from maskrcnn_benchmark.layers import nms as _box_nms +from maskrcnn_benchmark.layers import ml_nms as _box_ml_nms + + +def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="score"): + """ + Performs non-maximum suppression on a boxlist, with scores specified + in a boxlist field via score_field. + + Arguments: + boxlist(BoxList) + nms_thresh (float) + max_proposals (int): if > 0, then only the top max_proposals are kept + after non-maxium suppression + score_field (str) + """ + if nms_thresh <= 0: + return boxlist + mode = boxlist.mode + boxlist = boxlist.convert("xyxy") + boxes = boxlist.bbox + score = boxlist.get_field(score_field) + keep = _box_nms(boxes, score, nms_thresh) + if max_proposals > 0: + keep = keep[: max_proposals] + boxlist = boxlist[keep] + return boxlist.convert(mode) + + +def boxlist_ml_nms(boxlist, nms_thresh, max_proposals=-1, + score_field="scores", label_field="labels"): + """ + Performs non-maximum suppression on a boxlist, with scores specified + in a boxlist field via score_field. 
+ + Arguments: + boxlist(BoxList) + nms_thresh (float) + max_proposals (int): if > 0, then only the top max_proposals are kept + after non-maximum suppression + score_field (str) + """ + if nms_thresh <= 0: + return boxlist + mode = boxlist.mode + boxlist = boxlist.convert("xyxy") + boxes = boxlist.bbox + scores = boxlist.get_field(score_field) + labels = boxlist.get_field(label_field) + + if boxes.device==torch.device("cpu"): + keep = [] + unique_labels = torch.unique(labels) + print(unique_labels) + for j in unique_labels: + inds = (labels == j).nonzero().view(-1) + + scores_j = scores[inds] + boxes_j = boxes[inds, :].view(-1, 4) + keep_j = _box_nms(boxes_j, scores_j, nms_thresh) + + keep += keep_j + else: + keep = _box_ml_nms(boxes, scores, labels.float(), nms_thresh) + + if max_proposals > 0: + keep = keep[: max_proposals] + boxlist = boxlist[keep] + + return boxlist.convert(mode) + + +def remove_small_boxes(boxlist, min_size): + """ + Only keep boxes with both sides >= min_size + + Arguments: + boxlist (Boxlist) + min_size (int) + """ + # WORK AROUND: work around unbind using split + squeeze. + xywh_boxes = boxlist.convert("xywh").bbox + _, _, ws, hs = xywh_boxes.split(1, dim=1) + ws = ws.squeeze(1) + hs = hs.squeeze(1) + keep = ((ws >= min_size) & (hs >= min_size)).nonzero().squeeze(1) + return boxlist[keep] + + +# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py +# with slight modifications +def boxlist_iou(boxlist1, boxlist2): + """Compute the intersection over union of two set of boxes. + The box order must be (xmin, ymin, xmax, ymax). + + Arguments: + box1: (BoxList) bounding boxes, sized [N,4]. + box2: (BoxList) bounding boxes, sized [M,4]. + + Returns: + (tensor) iou, sized [N,M]. + + Reference: + https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py + """ + if boxlist1.size != boxlist2.size: + raise RuntimeError( + "boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2)) + + N = len(boxlist1) + M = len(boxlist2) + + area1 = boxlist1.area() + area2 = boxlist2.area() + + box1, box2 = boxlist1.bbox, boxlist2.bbox + + lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] + rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2] + + TO_REMOVE = 1 + + wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + iou = inter / (area1[:, None] + area2 - inter) + return iou + + +# TODO redundant, remove +def _cat(tensors, dim=0): + """ + Efficient version of torch.cat that avoids a copy if there is only a single element in a list + """ + assert isinstance(tensors, (list, tuple)) + if len(tensors) == 1: + return tensors[0] + if isinstance(tensors[0], torch.Tensor): + return torch.cat(tensors, dim) + else: + return cat_boxlist(tensors) + +def cat_boxlist(bboxes): + """ + Concatenates a list of BoxList (having the same image size) into a + single BoxList + + Arguments: + bboxes (list[BoxList]) + """ + assert isinstance(bboxes, (list, tuple)) + assert all(isinstance(bbox, BoxList) for bbox in bboxes) + + size = bboxes[0].size + assert all(bbox.size == size for bbox in bboxes) + + mode = bboxes[0].mode + assert all(bbox.mode == mode for bbox in bboxes) + + fields = set(bboxes[0].fields()) + assert all(set(bbox.fields()) == fields for bbox in bboxes) + + cat_boxes = BoxList(_cat([bbox.bbox for bbox in bboxes], dim=0), size, mode) + + for field in fields: + data = _cat([bbox.get_field(field) for bbox in bboxes], dim=0) + cat_boxes.add_field(field, data) + + return 
cat_boxes + + +def getUnionBBox(aBB, bBB, margin = 10): + assert aBB.size==bBB.size + assert aBB.mode==bBB.mode + ih, iw = aBB.size + union_boxes = torch.cat([(torch.min(aBB.bbox[:,[0,1]], bBB.bbox[:,[0,1]]) - margin).clamp(min=0), \ + (torch.max(aBB.bbox[:,[2]], bBB.bbox[:,[2]]) + margin).clamp(max=iw), \ + (torch.max(aBB.bbox[:,[3]], bBB.bbox[:,[3]]) + margin).clamp(max=ih)], dim=1) + return BoxList(union_boxes, aBB.size, mode=aBB.mode) diff --git a/maskrcnn_benchmark/structures/image_list.py b/maskrcnn_benchmark/structures/image_list.py new file mode 100644 index 0000000000000000000000000000000000000000..e24df46e95ba39476fdce9f748c0e0f4fb94be98 --- /dev/null +++ b/maskrcnn_benchmark/structures/image_list.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from __future__ import division + +import torch + + +class ImageList(object): + """ + Structure that holds a list of images (of possibly + varying sizes) as a single tensor. + This works by padding the images to the same size, + and storing in a field the original sizes of each image + """ + + def __init__(self, tensors, image_sizes): + """ + Arguments: + tensors (tensor) + image_sizes (list[tuple[int, int]]) + """ + self.tensors = tensors + self.image_sizes = image_sizes + + def to(self, *args, **kwargs): + cast_tensor = self.tensors.to(*args, **kwargs) + return ImageList(cast_tensor, self.image_sizes) + + +def to_image_list(tensors, size_divisible=0): + """ + tensors can be an ImageList, a torch.Tensor or + an iterable of Tensors. It can't be a numpy array. + When tensors is an iterable of Tensors, it pads + the Tensors with zeros so that they have the same + shape + """ + if isinstance(tensors, torch.Tensor) and size_divisible > 0: + tensors = [tensors] + + if isinstance(tensors, ImageList): + return tensors + elif isinstance(tensors, torch.Tensor): + # single tensor shape can be inferred + assert tensors.dim() == 4 + image_sizes = [tensor.shape[-2:] for tensor in tensors] + return ImageList(tensors, image_sizes) + elif isinstance(tensors, (tuple, list)): + max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) + + # TODO Ideally, just remove this and let me model handle arbitrary + # input sizs + if size_divisible > 0: + import math + + stride = size_divisible + max_size = list(max_size) + max_size[1] = int(math.ceil(max_size[1] / stride) * stride) + max_size[2] = int(math.ceil(max_size[2] / stride) * stride) + max_size = tuple(max_size) + + batch_shape = (len(tensors),) + max_size + batched_imgs = tensors[0].new(*batch_shape).zero_() + for img, pad_img in zip(tensors, batched_imgs): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + + image_sizes = [im.shape[-2:] for im in tensors] + + return ImageList(batched_imgs, image_sizes) + else: + raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) diff --git a/maskrcnn_benchmark/structures/keypoint.py b/maskrcnn_benchmark/structures/keypoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f0d74c536f94bb8ba3f07f435769450c8971244a --- /dev/null +++ b/maskrcnn_benchmark/structures/keypoint.py @@ -0,0 +1,212 @@ +import torch +from maskrcnn_benchmark.config import cfg + +# transpose +FLIP_LEFT_RIGHT = 0 +FLIP_TOP_BOTTOM = 1 + + +class Keypoints(object): + def __init__(self, keypoints, size, mode=None): + # FIXME remove check once we have better integration with device + # in my version this would consistently return a CPU tensor + device = keypoints.device if 
isinstance(keypoints, torch.Tensor) else torch.device('cpu') + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + # TODO should I split them? + # self.visibility = keypoints[..., 2] + self.keypoints = keypoints # [..., :2] + + self.size = size + self.mode = mode + self.extra_fields = {} + + def crop(self, box): + raise NotImplementedError() + + def resize(self, size, *args, **kwargs): + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) + ratio_w, ratio_h = ratios + resized_data = self.keypoints.clone() + resized_data[..., 0] *= ratio_w + resized_data[..., 1] *= ratio_h + keypoints = type(self)(resized_data, size, self.mode) + for k, v in self.extra_fields.items(): + keypoints.add_field(k, v) + return keypoints + + def transpose(self, method): + if method not in (FLIP_LEFT_RIGHT,): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT implemented") + + flip_inds = self.FLIP_INDS + flipped_data = self.keypoints[:, flip_inds] + width = self.size[0] + TO_REMOVE = 1 + # Flip x coordinates + flipped_data[..., 0] = width - flipped_data[..., 0] - TO_REMOVE + + # Maintain COCO convention that if visibility == 0, then x, y = 0 + inds = flipped_data[..., 2] == 0 + flipped_data[inds] = 0 + + keypoints = type(self)(flipped_data, self.size, self.mode) + for k, v in self.extra_fields.items(): + keypoints.add_field(k, v) + return keypoints + + def to(self, *args, **kwargs): + keypoints = type(self)(self.keypoints.to(*args, **kwargs), self.size, self.mode) + for k, v in self.extra_fields.items(): + if hasattr(v, "to"): + v = v.to(*args, **kwargs) + keypoints.add_field(k, v) + return keypoints + + def __getitem__(self, item): + keypoints = type(self)(self.keypoints[item], self.size, self.mode) + for k, v in self.extra_fields.items(): + keypoints.add_field(k, v[item]) + return keypoints + + def add_field(self, field, field_data): + self.extra_fields[field] = field_data + + def get_field(self, field): + return self.extra_fields[field] + + def __repr__(self): + s = self.__class__.__name__ + '(' + s += 'num_instances={}, '.format(len(self.keypoints)) + s += 'image_width={}, '.format(self.size[0]) + s += 'image_height={})'.format(self.size[1]) + return s + + +class PersonKeypoints(Keypoints): + _NAMES = [ + 'nose', + 'left_eye', + 'right_eye', + 'left_ear', + 'right_ear', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'left_hip', + 'right_hip', + 'left_knee', + 'right_knee', + 'left_ankle', + 'right_ankle' + ] + _FLIP_MAP = { + 'left_eye': 'right_eye', + 'left_ear': 'right_ear', + 'left_shoulder': 'right_shoulder', + 'left_elbow': 'right_elbow', + 'left_wrist': 'right_wrist', + 'left_hip': 'right_hip', + 'left_knee': 'right_knee', + 'left_ankle': 'right_ankle' + } + + def __init__(self, *args, **kwargs): + super(PersonKeypoints, self).__init__(*args, **kwargs) + if len(cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME)>0: + self.NAMES = cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME + self.FLIP_MAP = {l:r for l,r in PersonKeypoints._FLIP_MAP.items() if l in cfg.MODEL.ROI_KEYPOINT_HEAD.KEYPOINT_NAME} + else: + self.NAMES = PersonKeypoints._NAMES + self.FLIP_MAP = PersonKeypoints._FLIP_MAP + + self.FLIP_INDS = self._create_flip_indices(self.NAMES, self.FLIP_MAP) + self.CONNECTIONS = self._kp_connections(self.NAMES) + + def to_coco_format(self): + coco_result = [] + for i in range(self.keypoints.shape[0]): + 
coco_kps = [0]*len(PersonKeypoints._NAMES)*3 + for ki, name in enumerate(self.NAMES): + coco_kps[3*PersonKeypoints._NAMES.index(name)] = self.keypoints[i,ki,0].item() + coco_kps[3*PersonKeypoints._NAMES.index(name)+1] = self.keypoints[i,ki,1].item() + coco_kps[3*PersonKeypoints._NAMES.index(name)+2] = self.keypoints[i,ki,2].item() + coco_result.append(coco_kps) + return coco_result + + def _create_flip_indices(self, names, flip_map): + full_flip_map = flip_map.copy() + full_flip_map.update({v: k for k, v in flip_map.items()}) + flipped_names = [i if i not in full_flip_map else full_flip_map[i] for i in names] + flip_indices = [names.index(i) for i in flipped_names] + return torch.tensor(flip_indices) + + + def _kp_connections(self, keypoints): + CONNECTIONS = [ + ['left_eye', 'right_eye'], + ['left_eye', 'nose'], + ['right_eye', 'nose'], + ['right_eye', 'right_ear'], + ['left_eye', 'left_ear'], + ['right_shoulder', 'right_elbow'], + ['right_elbow', 'right_wrist'], + ['left_shoulder', 'left_elbow'], + ['left_elbow', 'left_wrist'], + ['right_hip', 'right_knee'], + ['right_knee', 'right_ankle'], + ['left_hip', 'left_knee'], + ['left_knee', 'left_ankle'], + ['right_shoulder', 'left_shoulder'], + ['right_hip', 'left_hip'], + ] + + kp_lines = [[keypoints.index(conn[0]), keypoints.index(conn[1])] for conn in CONNECTIONS + if conn[0] in self.NAMES and conn[1] in self.NAMES] + return kp_lines + + + +# TODO make this nicer, this is a direct translation from C2 (but removing the inner loop) +def keypoints_to_heat_map(keypoints, rois, heatmap_size): + if rois.numel() == 0: + return rois.new().long(), rois.new().long() + offset_x = rois[:, 0] + offset_y = rois[:, 1] + scale_x = heatmap_size / (rois[:, 2] - rois[:, 0]) + scale_y = heatmap_size / (rois[:, 3] - rois[:, 1]) + + offset_x = offset_x[:, None] + offset_y = offset_y[:, None] + scale_x = scale_x[:, None] + scale_y = scale_y[:, None] + + x = keypoints[..., 0] + y = keypoints[..., 1] + + x_boundary_inds = x == rois[:, 2][:, None] + y_boundary_inds = y == rois[:, 3][:, None] + + x = (x - offset_x) * scale_x + x = x.floor().long() + y = (y - offset_y) * scale_y + y = y.floor().long() + + x[x_boundary_inds] = heatmap_size - 1 + y[y_boundary_inds] = heatmap_size - 1 + + valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size) + vis = keypoints[..., 2] > 0 + valid = (valid_loc & vis).long() + + lin_ind = y * heatmap_size + x + heatmaps = lin_ind * valid + + return heatmaps, valid \ No newline at end of file diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..3a05e8ff93a352100e8463e074ee888d76e5b451 --- /dev/null +++ b/maskrcnn_benchmark/structures/segmentation_mask.py @@ -0,0 +1,214 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
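+#
+# Structures for instance segmentation masks: ``Mask`` wraps dense binary
+# masks, ``Polygons`` holds the polygon list of a single instance, and
+# ``SegmentationMask`` groups the polygons of all instances in an image.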
+import torch + +import pycocotools.mask as mask_utils + +# transpose +FLIP_LEFT_RIGHT = 0 +FLIP_TOP_BOTTOM = 1 + + +class Mask(object): + """ + This class is unfinished and not meant for use yet + It is supposed to contain the mask for an object as + a 2d tensor + """ + + def __init__(self, masks, size, mode): + self.masks = masks + self.size = size + self.mode = mode + + def transpose(self, method): + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + width, height = self.size + if method == FLIP_LEFT_RIGHT: + dim = width + idx = 2 + elif method == FLIP_TOP_BOTTOM: + dim = height + idx = 1 + + flip_idx = list(range(dim)[::-1]) + flipped_masks = self.masks.index_select(dim, flip_idx) + return Mask(flipped_masks, self.size, self.mode) + + def crop(self, box): + w, h = box[2] - box[0], box[3] - box[1] + + cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]] + return Mask(cropped_masks, size=(w, h), mode=self.mode) + + def resize(self, size, *args, **kwargs): + pass + + +class Polygons(object): + """ + This class holds a set of polygons that represents a single instance + of an object mask. The object can be represented as a set of + polygons + """ + + def __init__(self, polygons, size, mode): + # assert isinstance(polygons, list), '{}'.format(polygons) + if isinstance(polygons, list): + polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons] + elif isinstance(polygons, Polygons): + polygons = polygons.polygons + + self.polygons = polygons + self.size = size + self.mode = mode + + def transpose(self, method): + if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): + raise NotImplementedError( + "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" + ) + + flipped_polygons = [] + width, height = self.size + if method == FLIP_LEFT_RIGHT: + dim = width + idx = 0 + elif method == FLIP_TOP_BOTTOM: + dim = height + idx = 1 + + for poly in self.polygons: + p = poly.clone() + TO_REMOVE = 1 + p[idx::2] = dim - poly[idx::2] - TO_REMOVE + flipped_polygons.append(p) + + return Polygons(flipped_polygons, size=self.size, mode=self.mode) + + def crop(self, box): + w, h = box[2] - box[0], box[3] - box[1] + + # TODO chck if necessary + w = max(w, 1) + h = max(h, 1) + + cropped_polygons = [] + for poly in self.polygons: + p = poly.clone() + p[0::2] = p[0::2] - box[0] # .clamp(min=0, max=w) + p[1::2] = p[1::2] - box[1] # .clamp(min=0, max=h) + cropped_polygons.append(p) + + return Polygons(cropped_polygons, size=(w, h), mode=self.mode) + + def resize(self, size, *args, **kwargs): + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) + if ratios[0] == ratios[1]: + ratio = ratios[0] + scaled_polys = [p * ratio for p in self.polygons] + return Polygons(scaled_polys, size, mode=self.mode) + + ratio_w, ratio_h = ratios + scaled_polygons = [] + for poly in self.polygons: + p = poly.clone() + p[0::2] *= ratio_w + p[1::2] *= ratio_h + scaled_polygons.append(p) + + return Polygons(scaled_polygons, size=size, mode=self.mode) + + def convert(self, mode): + width, height = self.size + if mode == "mask": + rles = mask_utils.frPyObjects( + [p.detach().numpy() for p in self.polygons], height, width + ) + rle = mask_utils.merge(rles) + mask = mask_utils.decode(rle) + mask = torch.from_numpy(mask) + # TODO add squeeze? 
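+            # mask_utils.decode on the merged RLE yields a (height, width)
+            # uint8 array with ones inside the instance and zeros elsewhere.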
+            return mask
+
+    def __repr__(self):
+        s = self.__class__.__name__ + "("
+        s += "num_polygons={}, ".format(len(self.polygons))
+        s += "image_width={}, ".format(self.size[0])
+        s += "image_height={}, ".format(self.size[1])
+        s += "mode={})".format(self.mode)
+        return s
+
+
+class SegmentationMask(object):
+    """
+    This class stores the segmentations for all objects in the image
+    """
+
+    def __init__(self, polygons, size, mode=None):
+        """
+        Arguments:
+            polygons: a list of list of lists of numbers. The first
+                level of the list corresponds to individual instances,
+                the second level to all the polygons that compose the
+                object, and the third level to the polygon coordinates.
+        """
+        assert isinstance(polygons, list)
+
+        self.polygons = [Polygons(p, size, mode) for p in polygons]
+        self.size = size
+        self.mode = mode
+
+    def transpose(self, method):
+        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
+            raise NotImplementedError(
+                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
+            )
+
+        flipped = []
+        for polygon in self.polygons:
+            flipped.append(polygon.transpose(method))
+        return SegmentationMask(flipped, size=self.size, mode=self.mode)
+
+    def crop(self, box):
+        w, h = box[2] - box[0], box[3] - box[1]
+        cropped = []
+        for polygon in self.polygons:
+            cropped.append(polygon.crop(box))
+        return SegmentationMask(cropped, size=(w, h), mode=self.mode)
+
+    def resize(self, size, *args, **kwargs):
+        scaled = []
+        for polygon in self.polygons:
+            scaled.append(polygon.resize(size, *args, **kwargs))
+        return SegmentationMask(scaled, size=size, mode=self.mode)
+
+    def to(self, *args, **kwargs):
+        return self
+
+    def __getitem__(self, item):
+        if isinstance(item, (int, slice)):
+            selected_polygons = [self.polygons[item]]
+        else:
+            # advanced indexing on a single dimension
+            selected_polygons = []
+            if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
+                item = item.nonzero()
+                item = item.squeeze(1) if item.numel() > 0 else item
+                item = item.tolist()
+            for i in item:
+                selected_polygons.append(self.polygons[i])
+        return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)
+
+    def __iter__(self):
+        return iter(self.polygons)
+
+    def __repr__(self):
+        s = self.__class__.__name__ + "("
+        s += "num_instances={}, ".format(len(self.polygons))
+        s += "image_width={}, ".format(self.size[0])
+        s += "image_height={})".format(self.size[1])
+        return s
diff --git a/maskrcnn_benchmark/utils/README.md b/maskrcnn_benchmark/utils/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3c35e560d1b3e3fb6cfc5e5a5653a283b1c603e3
--- /dev/null
+++ b/maskrcnn_benchmark/utils/README.md
@@ -0,0 +1,5 @@
+# Utility functions
+
+This folder contains utility functions that are not used in the
+core library, but are useful for building models or training
+code using the config system.
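+
+As a small illustrative sketch (assuming `cfg`, `model`, `optimizer`,
+`scheduler` and an existing `output` directory are already set up elsewhere),
+the checkpointer defined in `checkpoint.py` can save and resume training
+state:
+
+```python
+from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
+
+checkpointer = DetectronCheckpointer(
+    cfg, model, optimizer, scheduler, save_dir="output", save_to_disk=True
+)
+# load() accepts catalog:// names, http(s) URLs, Caffe2 .pkl files and native .pth files.
+extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
+checkpointer.save("model_final")
+```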
diff --git a/maskrcnn_benchmark/utils/__init__.py b/maskrcnn_benchmark/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/maskrcnn_benchmark/utils/amp.py b/maskrcnn_benchmark/utils/amp.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b1a4f5bd5baaf888829aca231af445b4600650 --- /dev/null +++ b/maskrcnn_benchmark/utils/amp.py @@ -0,0 +1,14 @@ +from contextlib import contextmanager + +@contextmanager +def nullcontext(enter_result=None, **kwargs): + yield enter_result + +try: + from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd +except: + print('[Warning] Library for automatic mixed precision is not found, AMP is disabled!!') + GradScaler = nullcontext + autocast = nullcontext + custom_fwd = nullcontext + custom_bwd = nullcontext \ No newline at end of file diff --git a/maskrcnn_benchmark/utils/big_model_loading.py b/maskrcnn_benchmark/utils/big_model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..25dc5429f2b771a96edd402c569bf140dac7fc33 --- /dev/null +++ b/maskrcnn_benchmark/utils/big_model_loading.py @@ -0,0 +1,80 @@ +import numpy as np +import torch +import torch.nn as nn + +from collections import OrderedDict + + +def tf2th(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + +def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg): + import re + layer_keys = sorted(state_dict.keys()) + for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1): + if not stage_with_dcn: + continue + for old_key in layer_keys: + pattern = ".*block{}.*conv2.*".format(ix) + r = re.match(pattern, old_key) + if r is None: + continue + for param in ["weight", "bias"]: + if old_key.find(param) is -1: + continue + if 'unit01' in old_key: + continue + new_key = old_key.replace( + "conv2.{}".format(param), "conv2.conv.{}".format(param) + ) + print("pattern: {}, old_key: {}, new_key: {}".format( + pattern, old_key, new_key + )) + # Calculate SD conv weight + w = state_dict[old_key] + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-10) + + state_dict[new_key] = w + del state_dict[old_key] + return state_dict + + +def load_big_format(cfg, f): + model = OrderedDict() + weights = np.load(f) + + cmap = {'a':1, 'b':2, 'c':3} + for key, val in weights.items(): + old_key = key.replace('resnet/', '') + if 'root_block' in old_key: + new_key = 'root.conv.weight' + elif '/proj/standardized_conv2d/kernel' in old_key: + key_pattern = old_key.replace('/proj/standardized_conv2d/kernel', '').replace('resnet/', '') + bname, uname, cidx = key_pattern.split('/') + new_key = '{}.downsample.{}.conv{}.weight'.format(bname,uname,cmap[cidx]) + elif '/standardized_conv2d/kernel' in old_key: + key_pattern = old_key.replace('/standardized_conv2d/kernel', '').replace('resnet/', '') + bname, uname, cidx = key_pattern.split('/') + new_key = '{}.{}.conv{}.weight'.format(bname,uname,cmap[cidx]) + elif '/group_norm/gamma' in old_key: + key_pattern = old_key.replace('/group_norm/gamma', '').replace('resnet/', '') + bname, uname, cidx = key_pattern.split('/') + new_key = '{}.{}.gn{}.weight'.format(bname,uname,cmap[cidx]) + elif '/group_norm/beta' in old_key: + key_pattern = old_key.replace('/group_norm/beta', '').replace('resnet/', '') + bname, uname, cidx = key_pattern.split('/') 
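+            # cmap (defined above) maps the BiT conv suffixes a/b/c to indices
+            # 1/2/3, so a group_norm beta becomes the bias of the matching
+            # gn1/gn2/gn3 layer in this repo's naming.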
+ new_key = '{}.{}.gn{}.bias'.format(bname,uname,cmap[cidx]) + else: + print('Unknown key {}'.format(old_key)) + continue + print('Map {} -> {}'.format(key, new_key)) + model[new_key] = tf2th(val) + + model = _rename_conv_weights_for_deformable_conv_layers(model, cfg) + + return dict(model=model) diff --git a/maskrcnn_benchmark/utils/c2_model_loading.py b/maskrcnn_benchmark/utils/c2_model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..e51eea3a16aba9d1f392ac10a1602b1023938c30 --- /dev/null +++ b/maskrcnn_benchmark/utils/c2_model_loading.py @@ -0,0 +1,207 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import pickle +from collections import OrderedDict + +import torch + +from maskrcnn_benchmark.utils.model_serialization import load_state_dict +from maskrcnn_benchmark.utils.registry import Registry + + +def _rename_basic_resnet_weights(layer_keys): + layer_keys = [k.replace("_", ".") for k in layer_keys] + layer_keys = [k.replace(".w", ".weight") for k in layer_keys] + layer_keys = [k.replace(".bn", "_bn") for k in layer_keys] + layer_keys = [k.replace(".b", ".bias") for k in layer_keys] + layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys] + layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys] + layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys] + layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys] + layer_keys = [k.replace("res.conv1_", "conv1_") for k in layer_keys] + + # RPN / Faster RCNN + layer_keys = [k.replace(".biasbox", ".bbox") for k in layer_keys] + layer_keys = [k.replace("conv.rpn", "rpn.conv") for k in layer_keys] + layer_keys = [k.replace("rpn.bbox.pred", "rpn.bbox_pred") for k in layer_keys] + layer_keys = [k.replace("rpn.cls.logits", "rpn.cls_logits") for k in layer_keys] + + # Affine-Channel -> BatchNorm enaming + layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys] + + # Make torchvision-compatible + layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys] + + layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys] + layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys] + layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys] + layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys] + + layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys] + layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys] + layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys] + layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys] + layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys] + layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys] + + layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys] + layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys] + + # GroupNorm + layer_keys = [k.replace("conv1.gn.s", "bn1.weight") for k in layer_keys] + layer_keys = [k.replace("conv1.gn.bias", "bn1.bias") for k in layer_keys] + layer_keys = [k.replace("conv2.gn.s", "bn2.weight") for k in layer_keys] + layer_keys = [k.replace("conv2.gn.bias", "bn2.bias") for k in layer_keys] + layer_keys = [k.replace("conv3.gn.s", "bn3.weight") for k in layer_keys] + layer_keys = [k.replace("conv3.gn.bias", "bn3.bias") for k in layer_keys] + layer_keys = [k.replace("downsample.0.gn.s", "downsample.1.weight") \ + for k in layer_keys] + layer_keys = 
[k.replace("downsample.0.gn.bias", "downsample.1.bias") \ + for k in layer_keys] + + return layer_keys + +def _rename_fpn_weights(layer_keys, stage_names): + for mapped_idx, stage_name in enumerate(stage_names, 1): + suffix = "" + if mapped_idx < 4: + suffix = ".lateral" + layer_keys = [ + k.replace("fpn.inner.layer{}.sum{}".format(stage_name, suffix), "fpn_inner{}".format(mapped_idx)) for k in layer_keys + ] + layer_keys = [k.replace("fpn.layer{}.sum".format(stage_name), "fpn_layer{}".format(mapped_idx)) for k in layer_keys] + + + layer_keys = [k.replace("rpn.conv.fpn2", "rpn.conv") for k in layer_keys] + layer_keys = [k.replace("rpn.bbox_pred.fpn2", "rpn.bbox_pred") for k in layer_keys] + layer_keys = [ + k.replace("rpn.cls_logits.fpn2", "rpn.cls_logits") for k in layer_keys + ] + + return layer_keys + + +def _rename_weights_for_resnet(weights, stage_names): + original_keys = sorted(weights.keys()) + layer_keys = sorted(weights.keys()) + + # for X-101, rename output to fc1000 to avoid conflicts afterwards + layer_keys = [k if k != "pred_b" else "fc1000_b" for k in layer_keys] + layer_keys = [k if k != "pred_w" else "fc1000_w" for k in layer_keys] + + # performs basic renaming: _ -> . , etc + layer_keys = _rename_basic_resnet_weights(layer_keys) + + # FPN + layer_keys = _rename_fpn_weights(layer_keys, stage_names) + + # Mask R-CNN + layer_keys = [k.replace("mask.fcn.logits", "mask_fcn_logits") for k in layer_keys] + layer_keys = [k.replace(".[mask].fcn", "mask_fcn") for k in layer_keys] + layer_keys = [k.replace("conv5.mask", "conv5_mask") for k in layer_keys] + + # Keypoint R-CNN + layer_keys = [k.replace("kps.score.lowres", "kps_score_lowres") for k in layer_keys] + layer_keys = [k.replace("kps.score", "kps_score") for k in layer_keys] + layer_keys = [k.replace("conv.fcn", "conv_fcn") for k in layer_keys] + + # Rename for our RPN structure + layer_keys = [k.replace("rpn.", "rpn.head.") for k in layer_keys] + + key_map = {k: v for k, v in zip(original_keys, layer_keys)} + + logger = logging.getLogger(__name__) + logger.info("Remapping C2 weights") + max_c2_key_size = max([len(k) for k in original_keys if "_momentum" not in k]) + + new_weights = OrderedDict() + for k in original_keys: + v = weights[k] + if "_momentum" in k: + continue + if 'weight_order' in k: + continue + # if 'fc1000' in k: + # continue + w = torch.from_numpy(v) + # if "bn" in k: + # w = w.view(1, -1, 1, 1) + logger.info("C2 name: {: <{}} mapped name: {}".format(k, max_c2_key_size, key_map[k])) + new_weights[key_map[k]] = w + + return new_weights + + +def _load_c2_pickled_weights(file_path): + with open(file_path, "rb") as f: + if torch._six.PY3: + data = pickle.load(f, encoding="latin1") + else: + data = pickle.load(f) + if "blobs" in data: + weights = data["blobs"] + else: + weights = data + return weights + + +def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg): + import re + logger = logging.getLogger(__name__) + logger.info("Remapping conv weights for deformable conv weights") + layer_keys = sorted(state_dict.keys()) + for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1): + if not stage_with_dcn: + continue + for old_key in layer_keys: + pattern = ".*layer{}.*conv2.*".format(ix) + r = re.match(pattern, old_key) + if r is None: + continue + for param in ["weight", "bias"]: + if old_key.find(param) is -1: + continue + new_key = old_key.replace( + "conv2.{}".format(param), "conv2.conv.{}".format(param) + ) + logger.info("pattern: {}, old_key: {}, new_key: {}".format( + pattern, 
old_key, new_key + )) + state_dict[new_key] = state_dict[old_key] + del state_dict[old_key] + return state_dict + + +_C2_STAGE_NAMES = { + "R-50": ["1.2", "2.3", "3.5", "4.2"], + "R-101": ["1.2", "2.3", "3.22", "4.2"], +} + +C2_FORMAT_LOADER = Registry() + + +@C2_FORMAT_LOADER.register("R-50-C4") +@C2_FORMAT_LOADER.register("R-50-C5") +@C2_FORMAT_LOADER.register("R-101-C4") +@C2_FORMAT_LOADER.register("R-101-C5") +@C2_FORMAT_LOADER.register("R-50-FPN") +@C2_FORMAT_LOADER.register("R-50-FPN-RETINANET") +@C2_FORMAT_LOADER.register("R-50-FPN-FCOS") +@C2_FORMAT_LOADER.register("R-101-FPN") +@C2_FORMAT_LOADER.register("R-101-FPN-RETINANET") +@C2_FORMAT_LOADER.register("R-101-FPN-FCOS") +def load_resnet_c2_format(cfg, f): + state_dict = _load_c2_pickled_weights(f) + conv_body = cfg.MODEL.BACKBONE.CONV_BODY + arch = conv_body.replace("-C4", "").replace("-C5", "").replace("-FPN", "").replace("-RETINANET", "").replace("-FCOS", "") + stages = _C2_STAGE_NAMES[arch] + state_dict = _rename_weights_for_resnet(state_dict, stages) + # *********************************** + # for deformable convolutional layer + state_dict = _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg) + # *********************************** + return dict(model=state_dict) + + +def load_c2_format(cfg, f): + return C2_FORMAT_LOADER[cfg.MODEL.BACKBONE.CONV_BODY](cfg, f) diff --git a/maskrcnn_benchmark/utils/checkpoint.py b/maskrcnn_benchmark/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..c35a8f6e7b6da8478d0100f8c240e5ee1d50ccba --- /dev/null +++ b/maskrcnn_benchmark/utils/checkpoint.py @@ -0,0 +1,163 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import os + +import torch + +from maskrcnn_benchmark.utils.model_serialization import load_state_dict +from maskrcnn_benchmark.utils.c2_model_loading import load_c2_format +from maskrcnn_benchmark.utils.big_model_loading import load_big_format +from maskrcnn_benchmark.utils.pretrain_model_loading import load_pretrain_format +from maskrcnn_benchmark.utils.imports import import_file +from maskrcnn_benchmark.utils.model_zoo import cache_url + + +class Checkpointer(object): + def __init__( + self, + model, + optimizer=None, + scheduler=None, + save_dir="", + save_to_disk=None, + logger=None, + ): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.save_dir = save_dir + self.save_to_disk = save_to_disk + if logger is None: + logger = logging.getLogger(__name__) + self.logger = logger + + def save(self, name, **kwargs): + if not self.save_dir: + return + + if not self.save_to_disk: + return + + data = {} + data["model"] = self.model.state_dict() + if self.optimizer is not None: + data["optimizer"] = self.optimizer.state_dict() + if self.scheduler is not None: + if isinstance(self.scheduler, list): + data["scheduler"] = [scheduler.state_dict() for scheduler in self.scheduler] + else: + data["scheduler"] = self.scheduler.state_dict() + data.update(kwargs) + + save_file = os.path.join(self.save_dir, "{}.pth".format(name)) + self.logger.info("Saving checkpoint to {}".format(save_file)) + torch.save(data, save_file) + # self.tag_last_checkpoint(save_file) + # use relative path name to save the checkpoint + self.tag_last_checkpoint("{}.pth".format(name)) + + def load(self, f=None, force=False, keyword="model", skip_optimizer =False): + resume = False + if self.has_checkpoint() and not force: + # override argument with existing checkpoint + f = self.get_checkpoint_file() + 
# get the absolute path + f = os.path.join(self.save_dir, f) + resume = True + if not f: + # no checkpoint could be found + self.logger.info("No checkpoint found. Initializing model from scratch") + return {} + self.logger.info("Loading checkpoint from {}".format(f)) + checkpoint = self._load_file(f) + self._load_model(checkpoint, keyword=keyword) + # if resume training, load optimizer and scheduler, + # otherwise use the specified LR in config yaml for fine-tuning + if resume and not skip_optimizer: + if "optimizer" in checkpoint and self.optimizer: + self.logger.info("Loading optimizer from {}".format(f)) + self.optimizer.load_state_dict(checkpoint.pop("optimizer")) + if "scheduler" in checkpoint and self.scheduler: + self.logger.info("Loading scheduler from {}".format(f)) + if isinstance(self.scheduler, list): + for scheduler, state_dict in zip(self.scheduler, checkpoint.pop("scheduler")): + scheduler.load_state_dict(state_dict) + else: + self.scheduler.load_state_dict(checkpoint.pop("scheduler")) + + # return any further checkpoint data + return checkpoint + else: + return {} + + def has_checkpoint(self): + save_file = os.path.join(self.save_dir, "last_checkpoint") + return os.path.exists(save_file) + + def get_checkpoint_file(self): + save_file = os.path.join(self.save_dir, "last_checkpoint") + try: + with open(save_file, "r") as f: + last_saved = f.read() + last_saved = last_saved.strip() + except IOError: + # if file doesn't exist, maybe because it has just been + # deleted by a separate process + last_saved = "" + return last_saved + + def tag_last_checkpoint(self, last_filename): + save_file = os.path.join(self.save_dir, "last_checkpoint") + with open(save_file, "w") as f: + f.write(last_filename) + + def _load_file(self, f): + return torch.load(f, map_location=torch.device("cpu")) + + def _load_model(self, checkpoint, keyword="model"): + load_state_dict(self.model, checkpoint.pop(keyword)) + + +class DetectronCheckpointer(Checkpointer): + def __init__( + self, + cfg, + model, + optimizer=None, + scheduler=None, + save_dir="", + save_to_disk=None, + logger=None, + ): + super(DetectronCheckpointer, self).__init__( + model, optimizer, scheduler, save_dir, save_to_disk, logger + ) + self.cfg = cfg.clone() + + def _load_file(self, f): + # catalog lookup + if f.startswith("catalog://"): + paths_catalog = import_file( + "maskrcnn_benchmark.config.paths_catalog", self.cfg.PATHS_CATALOG, True + ) + catalog_f = paths_catalog.ModelCatalog.get(f[len("catalog://") :]) + self.logger.info("{} points to {}".format(f, catalog_f)) + f = catalog_f + # download url files + if f.startswith("http"): + # if the file is a url path, download it and cache it + cached_f = cache_url(f) + self.logger.info("url {} cached in {}".format(f, cached_f)) + f = cached_f + # convert Caffe2 checkpoint from pkl + if f.endswith(".pkl"): + return load_c2_format(self.cfg, f) + if f.endswith(".big"): + return load_big_format(self.cfg, f) + if f.endswith(".pretrain"): + return load_pretrain_format(self.cfg, f) + # load native detectron.pytorch checkpoint + loaded = super(DetectronCheckpointer, self)._load_file(f) + if "model" not in loaded: + loaded = dict(model=loaded) + return loaded diff --git a/maskrcnn_benchmark/utils/collect_env.py b/maskrcnn_benchmark/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..d93d6164aed31b783c58581cc85c183e1f1805be --- /dev/null +++ b/maskrcnn_benchmark/utils/collect_env.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved. +import PIL + +from torch.utils.collect_env import get_pretty_env_info + + +def get_pil_version(): + return "\n Pillow ({})".format(PIL.__version__) + + +def collect_env_info(): + env_str = get_pretty_env_info() + env_str += get_pil_version() + return env_str diff --git a/maskrcnn_benchmark/utils/comm.py b/maskrcnn_benchmark/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..f1222d2b36d83edb659973cf2253e4d5201d823c --- /dev/null +++ b/maskrcnn_benchmark/utils/comm.py @@ -0,0 +1,157 @@ +""" +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import pickle +import time +import functools +import logging +import torch +import torch.distributed as dist +import numpy as np + + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.LongTensor([tensor.numel()]).to("cuda") + size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def broadcast_data(data): + if not torch.distributed.is_initialized(): + return data + rank = dist.get_rank() + if rank == 0: + data_tensor = torch.tensor(data + [0], device="cuda") + else: + data_tensor = torch.tensor(data + [1], device="cuda") + torch.distributed.broadcast(data_tensor, 0) + while data_tensor.cpu().numpy()[-1] == 1: + time.sleep(1) + + return data_tensor.cpu().numpy().tolist()[:-1] + + +def reduce_sum(tensor): + if get_world_size() <= 1: + return tensor + + tensor = tensor.clone() + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2 ** 31) + all_ints = all_gather(ints) + return all_ints[0] \ No newline at end of file diff --git a/maskrcnn_benchmark/utils/cv2_util.py b/maskrcnn_benchmark/utils/cv2_util.py new file mode 100644 index 0000000000000000000000000000000000000000..268db9e5be1dc8094c39a1fddc1bfd7a89a7ca47 --- /dev/null +++ b/maskrcnn_benchmark/utils/cv2_util.py @@ -0,0 +1,24 @@ +""" +Module for cv2 utility functions and maintaining version compatibility +between 3.x and 4.x +""" +import cv2 + + +def findContours(*args, **kwargs): + """ + Wraps cv2.findContours to maintain compatiblity between versions + 3 and 4 + + Returns: + contours, hierarchy + """ + if cv2.__version__.startswith('4'): + contours, hierarchy = cv2.findContours(*args, **kwargs) + elif cv2.__version__.startswith('3'): + _, contours, hierarchy = cv2.findContours(*args, **kwargs) + else: + raise AssertionError( + 'cv2 must be either version 3 or 4 to call this method') + + return contours, hierarchy diff --git a/maskrcnn_benchmark/utils/dist.py b/maskrcnn_benchmark/utils/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..de7ac00c0eed1acc723df95f79367af82f79ddb0 --- /dev/null +++ b/maskrcnn_benchmark/utils/dist.py @@ -0,0 +1,228 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities related to distributed mode. + +By default, the reduce of metrics and such are done on GPU, since it's more straightforward (we reuse the NCCL backend) +If you want to reduce on CPU instead (required for big datasets like GQA), use the env variable MDETR_CPU_REDUCE=1 +""" +import functools +import io +import os + +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. 
+ """ + + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + + return dist.group.WORLD + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + + world_size = get_world_size() + if world_size == 1: + return [data] + + cpu_group = None + if os.getenv("MDETR_CPU_REDUCE") == "1": + cpu_group = _get_global_gloo_group() + + buffer = io.BytesIO() + torch.save(data, buffer) + data_view = buffer.getbuffer() + device = "cuda" if cpu_group is None else "cpu" + tensor = torch.ByteTensor(data_view).to(device) + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) + size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)] + if cpu_group is None: + dist.all_gather(size_list, local_size) + else: + print("gathering on cpu") + dist.all_gather(size_list, local_size, group=cpu_group) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + assert isinstance(local_size.item(), int) + local_size = int(local_size.item()) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device) + tensor = torch.cat((tensor, padding), dim=0) + if cpu_group is None: + dist.all_gather(tensor_list, tensor) + else: + dist.all_gather(tensor_list, tensor, group=cpu_group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] + buffer = io.BytesIO(tensor.cpu().numpy()) + obj = torch.load(buffer) + data_list.append(obj) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + """ + Returns: + True if distributed training is enabled + """ + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + """ + Returns: + The number of processes in the process group + """ + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + """ + Returns: + The rank of the current process within the global process group. + """ + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. 
+ """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process(): + """Return true if the current process is the main one""" + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + """Utility function to save only from the main process""" + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + """Initialize distributed training, if appropriate""" + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + + dist.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + dist.barrier() + setup_for_distributed(args.rank == 0) diff --git a/maskrcnn_benchmark/utils/ema.py b/maskrcnn_benchmark/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..1da65bc07a0365bf950aac5232ccbe666ae85741 --- /dev/null +++ b/maskrcnn_benchmark/utils/ema.py @@ -0,0 +1,46 @@ +from copy import deepcopy +from collections import OrderedDict +import torch + + +class ModelEma: + def __init__(self, model, decay=0.9999, device=''): + self.ema = deepcopy(model) + self.ema.eval() + self.decay = decay + self.device = device + if device: + self.ema.to(device=device) + self.ema_is_dp = hasattr(self.ema, 'module') + for p in self.ema.parameters(): + p.requires_grad_(False) + + def load_checkpoint(self, checkpoint): + if isinstance(checkpoint, str): + checkpoint = torch.load(checkpoint) + + assert isinstance(checkpoint, dict) + if 'model_ema' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['model_ema'].items(): + if self.ema_is_dp: + name = k if k.startswith('module') else 'module.' + k + else: + name = k.replace('module.', '') if k.startswith('module') else k + new_state_dict[name] = v + self.ema.load_state_dict(new_state_dict) + + def state_dict(self): + return self.ema.state_dict() + + def update(self, model): + pre_module = hasattr(model, 'module') and not self.ema_is_dp + with torch.no_grad(): + curr_msd = model.state_dict() + for k, ema_v in self.ema.state_dict().items(): + k = 'module.' + k if pre_module else k + model_v = curr_msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + diff --git a/maskrcnn_benchmark/utils/env.py b/maskrcnn_benchmark/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e19c760c076c3dfdb89cf2bf34b7ed8866a019 --- /dev/null +++ b/maskrcnn_benchmark/utils/env.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import os + +from maskrcnn_benchmark.utils.imports import import_file + + +def setup_environment(): + """Perform environment setup work. The default setup is a no-op, but this + function allows the user to specify a Python source file that performs + custom setup work that may be necessary to their computing environment. 
+ """ + custom_module_path = os.environ.get("TORCH_DETECTRON_ENV_MODULE") + if custom_module_path: + setup_custom_environment(custom_module_path) + else: + # The default setup is a no-op + pass + + +def setup_custom_environment(custom_module_path): + """Load custom environment setup from a Python source file and run the setup + function. + """ + module = import_file("maskrcnn_benchmark.utils.env.custom_module", custom_module_path) + assert hasattr(module, "setup_environment") and callable( + module.setup_environment + ), ( + "Custom environment module defined in {} does not have the " + "required callable attribute 'setup_environment'." + ).format( + custom_module_path + ) + module.setup_environment() + + +# Force environment setup when this module is imported +setup_environment() diff --git a/maskrcnn_benchmark/utils/flops.py b/maskrcnn_benchmark/utils/flops.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e3d72ff32c3d2824099517356067ad55c722a2 --- /dev/null +++ b/maskrcnn_benchmark/utils/flops.py @@ -0,0 +1,249 @@ +import argparse +import logging +import torch +import torch.nn as nn +import timeit + +from maskrcnn_benchmark.layers import * +from maskrcnn_benchmark.modeling.backbone.resnet_big import StdConv2d +from maskrcnn_benchmark.modeling.backbone.fpn import * +from maskrcnn_benchmark.modeling.rpn.inference import * +from maskrcnn_benchmark.modeling.roi_heads.box_head.inference import PostProcessor +from maskrcnn_benchmark.modeling.rpn.anchor_generator import BufferList + + +def profile(model, input_size, custom_ops={}, device="cpu", verbose=False, extra_args={}, return_time=False): + handler_collection = [] + + def add_hooks(m): + if len(list(m.children())) > 0: + return + + m.register_buffer('total_ops', torch.zeros(1)) + m.register_buffer('total_params', torch.zeros(1)) + + for p in m.parameters(): + m.total_params += torch.Tensor([p.numel()]) + + m_type = type(m) + fn = None + + if m_type in custom_ops: + fn = custom_ops[m_type] + elif m_type in register_hooks: + fn = register_hooks[m_type] + else: + print("Not implemented for ", m) + + if fn is not None: + if verbose: + print("Register FLOP counter for module %s" % str(m)) + handler = m.register_forward_hook(fn) + handler_collection.append(handler) + + original_device = model.parameters().__next__().device + training = model.training + + model.eval().to(device) + model.apply(add_hooks) + + x = torch.zeros(input_size).to(device) + with torch.no_grad(): + tic = timeit.time.perf_counter() + model(x, **extra_args) + toc = timeit.time.perf_counter() + total_time = toc-tic + + total_ops = 0 + total_params = 0 + for m in model.modules(): + if len(list(m.children())) > 0: # skip for non-leaf module + continue + total_ops += m.total_ops + total_params += m.total_params + + total_ops = total_ops.item() + total_params = total_params.item() + + model.train(training).to(original_device) + for handler in handler_collection: + handler.remove() + + if return_time: + return total_ops, total_params, total_time + else: + return total_ops, total_params + + +multiply_adds = 1 +def count_conv2d(m, x, y): + x = x[0] + cin = m.in_channels + cout = m.out_channels + kh, kw = m.kernel_size + batch_size = x.size()[0] + out_h = y.size(2) + out_w = y.size(3) + # ops per output element + # kernel_mul = kh * kw * cin + # kernel_add = kh * kw * cin - 1 + kernel_ops = multiply_adds * kh * kw * cin // m.groups + bias_ops = 1 if m.bias is not None else 0 + ops_per_element = kernel_ops + bias_ops + # total ops + # num_out_elements = 
y.numel() + output_elements = batch_size * out_w * out_h * cout + total_ops = output_elements * ops_per_element + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_convtranspose2d(m, x, y): + x = x[0] + cin = m.in_channels + cout = m.out_channels + kh, kw = m.kernel_size + batch_size = x.size()[0] + out_h = y.size(2) + out_w = y.size(3) + # ops per output element + # kernel_mul = kh * kw * cin + # kernel_add = kh * kw * cin - 1 + kernel_ops = multiply_adds * kh * kw * cin // m.groups + bias_ops = 1 if m.bias is not None else 0 + ops_per_element = kernel_ops + bias_ops + # total ops + # num_out_elements = y.numel() + # output_elements = batch_size * out_w * out_h * cout + ops_per_element = m.weight.nelement() + output_elements = y.nelement() + total_ops = output_elements * ops_per_element + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_bn(m, x, y): + x = x[0] + nelements = x.numel() + # subtract, divide, gamma, beta + total_ops = 4*nelements + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_relu(m, x, y): + x = x[0] + nelements = x.numel() + total_ops = nelements + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_softmax(m, x, y): + x = x[0] + batch_size, nfeatures = x.size() + total_exp = nfeatures + total_add = nfeatures - 1 + total_div = nfeatures + total_ops = batch_size * (total_exp + total_add + total_div) + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_maxpool(m, x, y): + kernel_ops = torch.prod(torch.Tensor([m.kernel_size])) + num_elements = y.numel() + total_ops = kernel_ops * num_elements + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_adap_maxpool(m, x, y): + kernel = torch.Tensor([*(x[0].shape[2:])])//torch.Tensor(list((m.output_size,))).squeeze() + kernel_ops = torch.prod(kernel) + num_elements = y.numel() + total_ops = kernel_ops * num_elements + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_avgpool(m, x, y): + total_add = torch.prod(torch.Tensor([m.kernel_size])) + total_div = 1 + kernel_ops = total_add + total_div + num_elements = y.numel() + total_ops = kernel_ops * num_elements + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_adap_avgpool(m, x, y): + kernel = torch.Tensor([*(x[0].shape[2:])])//torch.Tensor(list((m.output_size,))).squeeze() + total_add = torch.prod(kernel) + total_div = 1 + kernel_ops = total_add + total_div + num_elements = y.numel() + total_ops = kernel_ops * num_elements + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_linear(m, x, y): + # per output element + total_mul = m.in_features + total_add = m.in_features - 1 + num_elements = y.numel() + total_ops = (total_mul + total_add) * num_elements + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_LastLevelMaxPool(m, x, y): + num_elements = y[-1].numel() + total_ops = num_elements + m.total_ops = torch.Tensor([int(total_ops)]) + + +def count_ROIAlign(m, x, y): + num_elements = y.numel() + total_ops = num_elements*4 + m.total_ops = torch.Tensor([int(total_ops)]) + + +register_hooks = { + Scale: None, + Conv2d: count_conv2d, + nn.Conv2d: count_conv2d, + ModulatedDeformConv: count_conv2d, + StdConv2d: count_conv2d, + + nn.BatchNorm1d: count_bn, + nn.BatchNorm2d: count_bn, + nn.BatchNorm3d: count_bn, + FrozenBatchNorm2d: count_bn, + nn.GroupNorm: count_bn, + NaiveSyncBatchNorm2d: count_bn, + + nn.ReLU: count_relu, + nn.ReLU6: count_relu, + swish: None, + + nn.ConstantPad2d: None, + SPPLayer: count_LastLevelMaxPool, + LastLevelMaxPool: count_LastLevelMaxPool, + nn.MaxPool1d: 
count_maxpool, + nn.MaxPool2d: count_maxpool, + nn.MaxPool3d: count_maxpool, + nn.AdaptiveMaxPool1d: count_adap_maxpool, + nn.AdaptiveMaxPool2d: count_adap_maxpool, + nn.AdaptiveMaxPool3d: count_adap_maxpool, + nn.AvgPool1d: count_avgpool, + nn.AvgPool2d: count_avgpool, + nn.AvgPool3d: count_avgpool, + nn.AdaptiveAvgPool1d: count_adap_avgpool, + nn.AdaptiveAvgPool2d: count_adap_avgpool, + nn.AdaptiveAvgPool3d: count_adap_avgpool, + nn.Linear: count_linear, + nn.Upsample: None, + nn.Dropout: None, + nn.Sigmoid: None, + DropBlock2D: None, + + ROIAlign: count_ROIAlign, + RPNPostProcessor: None, + PostProcessor: None, + BufferList: None, + RetinaPostProcessor: None, + FCOSPostProcessor: None, + ATSSPostProcessor: None, +} \ No newline at end of file diff --git a/maskrcnn_benchmark/utils/fuse_helper.py b/maskrcnn_benchmark/utils/fuse_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9ea03f9f69c9d1a4f9a49c90436d540dc612e5 --- /dev/null +++ b/maskrcnn_benchmark/utils/fuse_helper.py @@ -0,0 +1,608 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import pdb +import math +from maskrcnn_benchmark.modeling.utils import cat, concat_box_prediction_layers, permute_and_flatten +from timm.models.layers import DropPath + +from transformers.activations import ACT2FN +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + +class FeatureResizer(nn.Module): + """ + This class takes as input a set of embeddings of dimension C1 and outputs a set of + embedding of dimension C2, after a linear transformation, dropout and normalization (LN). 
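+    (It is also used below as ``shrink_lang`` to project the concatenation of the five
+    per-level language features back down to ``l_dim``.)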
+ """ + + def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): + super().__init__() + self.do_ln = do_ln + # Object feature encoding + self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) + self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) + self.dropout = nn.Dropout(dropout) + + def forward(self, encoder_features): + x = self.fc(encoder_features) + if self.do_ln: + x = self.layer_norm(x) + output = self.dropout(x) + return output + + +def _make_conv(input_dim, output_dim, k, stride=1): + pad = (k - 1) // 2 + return nn.Sequential( + nn.Conv2d(input_dim, output_dim, (k, k), padding=(pad, pad), stride=(stride, stride)), + nn.BatchNorm2d(output_dim), + nn.ReLU(inplace=True) + ) + + +def _make_mlp(input_dim, output_dim, drop): + return nn.Sequential(nn.Linear(input_dim, output_dim), + nn.BatchNorm1d(output_dim), + nn.ReLU(inplace=True), + nn.Dropout(drop), + nn.Linear(output_dim, output_dim), + nn.BatchNorm1d(output_dim), + nn.ReLU(inplace=True)) + + +def _make_coord(batch, height, width): + # relative position encoding + xv, yv = torch.meshgrid([torch.arange(0, height), torch.arange(0, width)]) + xv_min = (xv.float() * 2 - width) / width + yv_min = (yv.float() * 2 - height) / height + xv_max = ((xv + 1).float() * 2 - width) / width + yv_max = ((yv + 1).float() * 2 - height) / height + xv_ctr = (xv_min + xv_max) / 2 + yv_ctr = (yv_min + yv_max) / 2 + hmap = torch.ones(height, width) * (1. / height) + wmap = torch.ones(height, width) * (1. / width) + coord = torch.autograd.Variable(torch.cat([xv_min.unsqueeze(0), yv_min.unsqueeze(0), \ + xv_max.unsqueeze(0), yv_max.unsqueeze(0), \ + xv_ctr.unsqueeze(0), yv_ctr.unsqueeze(0), \ + hmap.unsqueeze(0), wmap.unsqueeze(0)], dim=0)) + coord = coord.unsqueeze(0).repeat(batch, 1, 1, 1) + return coord + + +def l1norm(X, dim, eps=1e-8): + """L1-normalize columns of X + """ + norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps + X = torch.div(X, norm) + return X + + +def l2norm(X, dim, eps=1e-8): + """L2-normalize columns of X + """ + norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps + X = torch.div(X, norm) + return X + + +def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8): + """ + query: (n_context, queryL, d) + context: (n_context, sourceL, d) + """ + batch_size_q, queryL = query.size(0), query.size(1) + batch_size, sourceL = context.size(0), context.size(1) + + # Get attention + # --> (batch, d, queryL) + queryT = torch.transpose(query, 1, 2) + + # (batch, sourceL, d)(batch, d, queryL) + # --> (batch, sourceL, queryL) + attn = torch.bmm(context, queryT) + if raw_feature_norm == "softmax": + # --> (batch*sourceL, queryL) + attn = attn.view(batch_size * sourceL, queryL) + attn = nn.Softmax()(attn) + # --> (batch, sourceL, queryL) + attn = attn.view(batch_size, sourceL, queryL) + elif raw_feature_norm == "l2norm": + attn = l2norm(attn, 2) + elif raw_feature_norm == "clipped_l2norm": + attn = nn.LeakyReLU(0.1)(attn) + attn = l2norm(attn, 2) + else: + raise ValueError("unknown first norm type:", raw_feature_norm) + # --> (batch, queryL, sourceL) + attn = torch.transpose(attn, 1, 2).contiguous() + # --> (batch*queryL, sourceL) + attn = attn.view(batch_size * queryL, sourceL) + attn = nn.Softmax()(attn * smooth) + # --> (batch, queryL, sourceL) + attn = attn.view(batch_size, queryL, sourceL) + # --> (batch, sourceL, queryL) + attnT = torch.transpose(attn, 1, 2).contiguous() + + # --> (batch, d, sourceL) + contextT = torch.transpose(context, 1, 2) + # (batch x d x 
sourceL)(batch x sourceL x queryL) + # --> (batch, d, queryL) + weightedContext = torch.bmm(contextT, attnT) + # --> (batch, queryL, d) + weightedContext = torch.transpose(weightedContext, 1, 2) + + return weightedContext, attnT + + +class BiMultiHeadAttention(nn.Module): + def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None): + super(BiMultiHeadAttention, self).__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.v_dim = v_dim + self.l_dim = l_dim + + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + self.scale = self.head_dim ** (-0.5) + self.dropout = dropout + + self.v_proj = nn.Linear(self.v_dim, self.embed_dim) + self.l_proj = nn.Linear(self.l_dim, self.embed_dim) + self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim) + self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim) + + self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim) + self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim) + + self.stable_softmax_2d = cfg.MODEL.DYHEAD.FUSE_CONFIG.STABLE_SOFTMAX_2D + self.clamp_min_for_underflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW + self.clamp_max_for_overflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW + + self._reset_parameters() + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def _reset_parameters(self): + nn.init.xavier_uniform_(self.v_proj.weight) + self.v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.l_proj.weight) + self.l_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.values_v_proj.weight) + self.values_v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.values_l_proj.weight) + self.values_l_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_v_proj.weight) + self.out_v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_l_proj.weight) + self.out_l_proj.bias.data.fill_(0) + + def forward(self, v, l, attention_mask_l=None): + bsz, tgt_len, embed_dim = v.size() + + query_states = self.v_proj(v) * self.scale + key_states = self._shape(self.l_proj(l), -1, bsz) + value_v_states = self._shape(self.values_v_proj(v), -1, bsz) + value_l_states = self._shape(self.values_l_proj(l), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_v_states = value_v_states.view(*proj_shape) + value_l_states = value_l_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + ) + + # attn_weights_l = nn.functional.softmax(attn_weights.transpose(1, 2), dim=-1) + + if self.stable_softmax_2d: + attn_weights = attn_weights - attn_weights.max() + + if self.clamp_min_for_underflow: + attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range + if self.clamp_max_for_overflow: + attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range + + attn_weights_T = 
attn_weights.transpose(1, 2) + attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[ + 0]) + if self.clamp_min_for_underflow: + attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range + if self.clamp_max_for_overflow: + attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range + + attn_weights_l = attn_weights_l.softmax(dim=-1) + + if attention_mask_l is not None: + assert (attention_mask_l.dim() == 2) + attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) + attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) + attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15) + + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights_v = nn.functional.softmax(attn_weights, dim=-1) + + attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training) + attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training) + + attn_output_v = torch.bmm(attn_probs_v, value_l_states) + attn_output_l = torch.bmm(attn_probs_l, value_v_states) + + + if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}" + ) + + if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim): + raise ValueError( + f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}" + ) + + attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output_v = attn_output_v.transpose(1, 2) + attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim) + + attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim) + attn_output_l = attn_output_l.transpose(1, 2) + attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim) + + attn_output_v = self.out_v_proj(attn_output_v) + attn_output_l = self.out_l_proj(attn_output_l) + + return attn_output_v, attn_output_l + + +# Bi-Direction MHA (text->image, image->text) +class BiAttentionBlock(nn.Module): + def __init__(self, v_dim, l_dim, embed_dim, num_heads, hidden_dim=None, dropout=0.1, + drop_path=.0, init_values=1e-4, cfg=None): + """ + Inputs: + embed_dim - Dimensionality of input and attention feature vectors + hidden_dim - Dimensionality of hidden layer in feed-forward network + (usually 2-4x larger than embed_dim) + num_heads - Number of heads to use in the Multi-Head Attention block + dropout - Amount of dropout to apply in the feed-forward network + """ + super(BiAttentionBlock, self).__init__() + + # pre layer norm + self.layer_norm_v = nn.LayerNorm(v_dim) + self.layer_norm_l = nn.LayerNorm(l_dim) + self.attn = BiMultiHeadAttention(v_dim=v_dim, + l_dim=l_dim, + embed_dim=embed_dim, + num_heads=num_heads, + dropout=dropout, + cfg=cfg) + + # add layer scale for training stability + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True) + self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True) + + def forward(self, v, l, attention_mask_l=None, dummy_tensor=None): + v = self.layer_norm_v(v) + l = self.layer_norm_l(l) + delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l) + # v, l = v + delta_v, l + delta_l + v = v + self.drop_path(self.gamma_v * delta_v) + l = l + self.drop_path(self.gamma_l * delta_l) + return v, l + +class BiAttentionBlockForCheckpoint(nn.Module): + def __init__(self, v_dim, l_dim, embed_dim, num_heads, hidden_dim=None, dropout=0.1, + drop_path=.0, init_values=1e-4, cfg=None): + """ + Inputs: + embed_dim - Dimensionality of input and attention feature vectors + hidden_dim - Dimensionality of hidden layer in feed-forward network + (usually 2-4x larger than embed_dim) + num_heads - Number of heads to use in the Multi-Head Attention block + dropout - Amount of dropout to apply in the feed-forward network + """ + super(BiAttentionBlockForCheckpoint, self).__init__() + + # pre layer norm + self.layer_norm_v = nn.LayerNorm(v_dim) + self.layer_norm_l = nn.LayerNorm(l_dim) + self.attn = BiMultiHeadAttention(v_dim=v_dim, + l_dim=l_dim, + embed_dim=embed_dim, + num_heads=num_heads, + dropout=dropout, + cfg=cfg) + + # add layer scale for training stability + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True) + self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True) + + self.cfg = cfg + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL: + if not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT: + self.shrink_lang = FeatureResizer(l_dim * 5, l_dim, 0.1) + + def forward(self, q0, q1, q2, q3, q4, l, attention_mask_l=None, dummy_tensor=None): + + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL: + visu_feat = [] + lang_feat = [] + for ii, feat in enumerate([q0, q1, q2, q3, q4]): + bs, _, h, w = feat.shape + q = feat.flatten(2).transpose(1, 2) + + new_v, new_l = self.single_attention_call(q, l, attention_mask_l=attention_mask_l) + new_v = new_v.transpose(1, 2).contiguous().view(bs, -1, h, w) + lang_feat.append(new_l) + visu_feat.append(new_v) + if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT: + pass + else: + lang_feat = self.shrink_lang(torch.cat(lang_feat, dim = -1)) # From multiple dimensions + lang_feat = [lang_feat, None, None, None, None] + else: + visu_feat = [] + size_per_level, visual_features_flatten = [], [] + for ii, feat_per_level in enumerate([q0, q1, q2, q3, q4]): + bs, c, h, w = feat_per_level.shape + size_per_level.append([h, w]) + feat = permute_and_flatten(feat_per_level, bs, 1, c, h, w) + visual_features_flatten.append(feat) + visual_features_flatten = cat(visual_features_flatten, dim=1) + new_v, new_l = self.single_attention_call(visual_features_flatten, l, attention_mask_l=attention_mask_l) + # [bs, N, C] -> [bs, C, N] + new_v = new_v.transpose(1, 2).contiguous() + + start = 0 + for (h, w) in size_per_level: + new_v_per_level = new_v[:, :, start:start + h * w].view(bs, -1, h, w).contiguous() + visu_feat.append(new_v_per_level) + start += h * w + + lang_feat = [new_l, None, None, None, None] + + return visu_feat[0], visu_feat[1], visu_feat[2], visu_feat[3], visu_feat[4], lang_feat[0], lang_feat[1], lang_feat[2], lang_feat[3], lang_feat[4] + + + def 
single_attention_call(self, v, l, attention_mask_l=None, dummy_tensor=None): + v = self.layer_norm_v(v) + l = self.layer_norm_l(l) + delta_v, delta_l = self.attn(v, l, attention_mask_l=attention_mask_l) + # v, l = v + delta_v, l + delta_l + v = v + self.drop_path(self.gamma_v * delta_v) + l = l + self.drop_path(self.gamma_l * delta_l) + return v, l + + +# Single Direction MHA +class MultiHeadAttention(nn.Module): + """ + Multi-head attention module for both image and text + """ + + def __init__(self, q_dim, k_dim, embed_dim, num_heads, dropout=0.1, + clamp_min_for_underflow = False, clamp_max_for_overflow = False): + super(MultiHeadAttention, self).__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.q_dim = q_dim + self.k_dim = k_dim + + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + self.scale = self.head_dim ** (-0.5) + self.dropout = dropout + + self.q_proj = nn.Linear(self.q_dim, self.embed_dim) + self.k_proj = nn.Linear(self.k_dim, self.embed_dim) + self.v_proj = nn.Linear(self.k_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.q_dim) + self.clamp_min_for_underflow = clamp_min_for_underflow + self.clamp_max_for_overflow = clamp_max_for_overflow + + self._reset_parameters() + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def _reset_parameters(self): + nn.init.xavier_uniform_(self.q_proj.weight) + self.q_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.k_proj.weight) + self.k_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.v_proj.weight) + self.v_proj.bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_proj.weight) + self.out_proj.bias.data.fill_(0) + + def forward(self, q, k, v, attention_mask=None, return_attention=False): + bsz, tgt_len, embed_dim = q.size() + + query_states = self.q_proj(q) * self.scale + key_states = self._shape(self.k_proj(k), -1, bsz) + value_states = self._shape(self.v_proj(v), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + ) + + if self.clamp_min_for_underflow: + attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range + if self.clamp_max_for_overflow: + attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range + + if attention_mask is not None: + # [bsz, src_len] + assert (attention_mask.dim() == 2) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) + attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15) + + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, 
tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if return_attention: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + + return attn_output, attn_weights + + +class AttentionMLP(nn.Module): + def __init__(self, q_dim, hidden_dim, dropout=0.1): + super(AttentionMLP, self).__init__() + self.hidden_dim = hidden_dim + self.activation_fn = nn.GELU() + self.fc1 = nn.Linear(q_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, q_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class AttentionT2I(nn.Module): + def __init__(self, q_dim, k_dim, embed_dim, num_heads, hidden_dim=None, dropout=0.1, + drop_path=.0, init_values=1e-4, mode="i2t", use_layer_scale = False, + clamp_min_for_underflow = False, clamp_max_for_overflow = False): + """ + Inputs: + embed_dim - Dimensionality of input and attention feature vectors + hidden_dim - Dimensionality of hidden layer in feed-forward network + (usually 2-4x larger than embed_dim) + num_heads - Number of heads to use in the Multi-Head Attention block + dropout - Amount of dropout to apply in the feed-forward network + """ + super(AttentionT2I, self).__init__() + + # pre_layer norm + self.layer_norm_q_1 = nn.LayerNorm(q_dim) + self.layer_norm_k_1 = nn.LayerNorm(k_dim) + self.attn = MultiHeadAttention(q_dim=q_dim, + k_dim=k_dim, + embed_dim=embed_dim, + num_heads=num_heads, + clamp_min_for_underflow=clamp_min_for_underflow, + clamp_max_for_overflow=clamp_max_for_overflow) + self.mode = mode + + # add layer scale for training stability + self.use_layer_scale = use_layer_scale + if self.use_layer_scale: + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.gamma = nn.Parameter(init_values * torch.ones((q_dim)), requires_grad=True) + + + def forward(self, q0, q1, q2, q3, q4, k, v, attention_mask, dummy_arg=None): + qs = [] + for q_index, q in enumerate([q0, q1, q2, q3, q4]): + bs, _, h, w = q.shape + # (batch, seq_len, embed_size) + q = q.flatten(2).transpose(1, 2) + q = self.layer_norm_q_1(q) + k, v = self.layer_norm_k_1(k), self.layer_norm_k_1(v) + delta_q = self.attn(q, k, v, attention_mask=attention_mask)[0] + if self.use_layer_scale: + q = q + self.drop_path(self.gamma * delta_q) + else: + q = q + delta_q + q = q.transpose(1, 2).contiguous().view(bs, -1, h, w) + qs.append(q) + + + return qs[0], qs[1], qs[2], qs[3], qs[4] diff --git a/maskrcnn_benchmark/utils/imports.py b/maskrcnn_benchmark/utils/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..081e5556f74f0068957f4514593ca7446652d546 --- /dev/null +++ b/maskrcnn_benchmark/utils/imports.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import torch + +if torch._six.PY37: + import importlib + import importlib.util + import sys + + + # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa + def import_file(module_name, file_path, make_importable=False): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if make_importable: + sys.modules[module_name] = module + return module +else: + import imp + + def import_file(module_name, file_path, make_importable=None): + module = imp.load_source(module_name, file_path) + return module diff --git a/maskrcnn_benchmark/utils/logger.py b/maskrcnn_benchmark/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..a30fa8d49a67111cb7d8d47e7db1ece98134aa8e --- /dev/null +++ b/maskrcnn_benchmark/utils/logger.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import os +import sys + + +def setup_logger(name, save_dir, distributed_rank): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + # don't log results for the non-master process + if distributed_rank > 0: + return logger + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + if save_dir: + fh = logging.FileHandler(os.path.join(save_dir, "log.txt")) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger diff --git a/maskrcnn_benchmark/utils/mdetr_dist.py b/maskrcnn_benchmark/utils/mdetr_dist.py new file mode 100644 index 0000000000000000000000000000000000000000..af8f19fd511db7b871e78abf0e64d1225994406d --- /dev/null +++ b/maskrcnn_benchmark/utils/mdetr_dist.py @@ -0,0 +1,229 @@ +# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities related to distributed mode. 
+ +By default, the reduce of metrics and such are done on GPU, since it's more straightforward (we reuse the NCCL backend) +If you want to reduce on CPU instead (required for big datasets like GQA), use the env variable MDETR_CPU_REDUCE=1 +""" +import functools +import io +import os +import datetime + +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + + return dist.group.WORLD + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + + world_size = get_world_size() + if world_size == 1: + return [data] + + cpu_group = None + if os.getenv("MDETR_CPU_REDUCE") == "1": + cpu_group = _get_global_gloo_group() + + buffer = io.BytesIO() + torch.save(data, buffer) + data_view = buffer.getbuffer() + device = "cuda" if cpu_group is None else "cpu" + tensor = torch.ByteTensor(data_view).to(device) + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) + size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)] + if cpu_group is None: + dist.all_gather(size_list, local_size) + else: + print("gathering on cpu") + dist.all_gather(size_list, local_size, group=cpu_group) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + assert isinstance(local_size.item(), int) + local_size = int(local_size.item()) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device) + tensor = torch.cat((tensor, padding), dim=0) + if cpu_group is None: + dist.all_gather(tensor_list, tensor) + else: + dist.all_gather(tensor_list, tensor, group=cpu_group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] + buffer = io.BytesIO(tensor.cpu().numpy()) + obj = torch.load(buffer) + data_list.append(obj) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. 
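+    Example (sketch, with a hypothetical ``loss_dict``):
+        loss_dict_reduced = reduce_dict(loss_dict)  # every rank now sees the same averaged values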
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + """ + Returns: + True if distributed training is enabled + """ + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + """ + Returns: + The number of processes in the process group + """ + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + """ + Returns: + The rank of the current process within the global process group. + """ + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process(): + """Return true if the current process is the main one""" + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + """Utility function to save only from the main process""" + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + """Initialize distributed training, if appropriate""" + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + + dist.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank, + timeout=datetime.timedelta(0, 7200) + ) + dist.barrier() + setup_for_distributed(args.debug or args.rank == 0) diff --git a/maskrcnn_benchmark/utils/metric_logger.py b/maskrcnn_benchmark/utils/metric_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..b506f40f1d7e912389c9c27dafd4d340552a6a9c --- /dev/null +++ b/maskrcnn_benchmark/utils/metric_logger.py @@ -0,0 +1,130 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
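+# Logging helpers used during training:
+#   * SmoothedValue     - windowed median / average of a scalar series
+#   * MetricLogger      - aggregates named SmoothedValues and formats them for printing
+#   * TensorboardLogger - additionally writes each scalar to tensorboardX (main process only)
+#
+# Hypothetical usage inside a training loop:
+#   meters = MetricLogger(delimiter="  ")
+#   meters.update(loss=losses.item(), lr=optimizer.param_groups[0]["lr"])
+#   logger.info(str(meters))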
+from collections import defaultdict +from collections import deque + +import torch +import time +from datetime import datetime +from .comm import is_main_process + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20): + self.deque = deque(maxlen=window_size) + # self.series = [] + self.total = 0.0 + self.count = 0 + + def update(self, value): + self.deque.append(value) + # self.series.append(value) + self.count += 1 + if value != value: + value = 0 + self.total += value + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque)) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) + ) + return self.delimiter.join(loss_str) + + +# haotian added tensorboard support +class TensorboardLogger(MetricLogger): + def __init__(self, + log_dir, + start_iter=0, + delimiter='\t' + ): + super(TensorboardLogger, self).__init__(delimiter) + self.iteration = start_iter + self.writer = self._get_tensorboard_writer(log_dir) + + @staticmethod + def _get_tensorboard_writer(log_dir): + try: + from tensorboardX import SummaryWriter + except ImportError: + raise ImportError( + 'To use tensorboard please install tensorboardX ' + '[ pip install tensorflow tensorboardX ].' + ) + + if is_main_process(): + # timestamp = datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H:%M') + tb_logger = SummaryWriter('{}'.format(log_dir)) + return tb_logger + else: + return None + + def update(self, **kwargs): + super(TensorboardLogger, self).update(**kwargs) + if self.writer: + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.writer.add_scalar(k, v, self.iteration) + + self.iteration += 1 diff --git a/maskrcnn_benchmark/utils/miscellaneous.py b/maskrcnn_benchmark/utils/miscellaneous.py new file mode 100644 index 0000000000000000000000000000000000000000..0169648926c729c217520442cd59a9214975a3bb --- /dev/null +++ b/maskrcnn_benchmark/utils/miscellaneous.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
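+# Small filesystem helpers:
+#   mkdir(path)            - os.makedirs that ignores the error if the directory already exists
+#   save_config(cfg, path) - writes cfg.dump() to path, from the main process only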
+import errno +import os +from .comm import is_main_process + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def save_config(cfg, path): + if is_main_process(): + with open(path, 'w') as f: + f.write(cfg.dump()) diff --git a/maskrcnn_benchmark/utils/model_serialization.py b/maskrcnn_benchmark/utils/model_serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..b8707ceb92f045d78e756c0a00df7e0192c39f1e --- /dev/null +++ b/maskrcnn_benchmark/utils/model_serialization.py @@ -0,0 +1,157 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from collections import OrderedDict, defaultdict +import logging +import math +import torch + +from maskrcnn_benchmark.utils.imports import import_file + +def resize_2d(posemb, shape_new): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + ntok_new = shape_new[0] + gs_old = int(math.sqrt(len(posemb))) # 2 * w - 1 + gs_new = int(math.sqrt(ntok_new)) # 2 * w - 1 + posemb_grid = posemb.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(gs_new * gs_new, -1) + return posemb_grid + +def align_and_update_state_dicts(model_state_dict, loaded_state_dict, reshape_keys=['pos_bias_table'], use_weightmap=False): + """ + Strategy: suppose that the models that we will create will have prefixes appended + to each of its keys, for example due to an extra level of nesting that the original + pre-trained weights from ImageNet won't contain. For example, model.state_dict() + might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains + res2.conv1.weight. We thus want to match both parameters together. + For that, we look for each model weight, look among all loaded keys if there is one + that is a suffix of the current weight name, and use it if that's the case. + If multiple matches exist, take the one with longest size + of the corresponding name. For example, for the same model as before, the pretrained + weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, + we want to match backbone[0].body.conv1.weight to conv1.weight, and + backbone[0].body.res2.conv1.weight to res2.conv1.weight. 
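+    When a matched pair differs in shape, keys containing one of ``reshape_keys`` (e.g.
+    position bias tables) are resized with ``resize_2d``; ``cls_logits`` can optionally be
+    remapped through a COCO-in-Objects365 label map; all other mismatches are skipped.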
+ """ + current_keys = sorted(list(model_state_dict.keys())) + loaded_keys = sorted(list(loaded_state_dict.keys())) + # get a matrix of string matches, where each (i, j) entry correspond to the size of the + # loaded_key string, if it matches + match_matrix = [ + len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys + ] + match_matrix = torch.as_tensor(match_matrix).view( + len(current_keys), len(loaded_keys) + ) + max_match_size, idxs = match_matrix.max(1) + # remove indices that correspond to no-match + idxs[max_match_size == 0] = -1 + + matched_keys = [] + # used for logging + max_size = max([len(key) for key in current_keys]) if current_keys else 1 + max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 + log_str_template = "{: <{}} loaded from {: <{}} of shape {}" + logger = logging.getLogger(__name__) + for idx_new, idx_old in enumerate(idxs.tolist()): + if idx_old == -1: + continue + key = current_keys[idx_new] + key_old = loaded_keys[idx_old] + if model_state_dict[key].shape != loaded_state_dict[key_old].shape: + if any([k in key_old for k in reshape_keys]): + new_shape = model_state_dict[key].shape + logger.warning('Reshaping {} -> {}. \n'.format(key_old, key)) + model_state_dict[key] = resize_2d(loaded_state_dict[key_old], new_shape) + elif use_weightmap and 'cls_logits' in key: + coco_in_objects365_inds = [ + 227, 26, 55, 202, 2, 44, 338, 346, 32, 336, 118, 299, 218, + 25, 361, 59, 95, 161, 278, 82, 110, 22, 364, 134, 9, 350, + 152, 323, 304, 130, 285, 289, 16, 172, 17, 18, 283, 305, + 321, 35, 362, 88, 127, 174, 292, 37, 11, 6, 267, 212, 41, + 58, 162, 237, 98, 48, 63, 81, 247, 23, 94, 326, 349, 178, + 203, 259, 171, 60, 198, 213, 325, 282, 258, 33, 71, 353, + 273, 318, 148, 330 + ] + logger.info("Use coco_in_objects365_inds labelmap for COCO detection because of size mis-match, " + "Reshaping {} -> {}. \n".format(key_old, key)) + new_shape = model_state_dict[key].shape + assert new_shape[0] == len(coco_in_objects365_inds) + weight_inds_old = torch.as_tensor(coco_in_objects365_inds).to(loaded_state_dict[key_old].device) + model_state_dict[key] = loaded_state_dict[key_old][weight_inds_old].to(model_state_dict[key].device) + else: + logger.info('Skip due to size mismatch: {} -> {}. 
\n'.format(key_old, key)) + continue + else: + model_state_dict[key] = loaded_state_dict[key_old] + matched_keys.append(key) + logger.info( + log_str_template.format( + key, + max_size, + key_old, + max_size_loaded, + tuple(loaded_state_dict[key_old].shape), + ) + ) + missing_keys = set(current_keys)-set(matched_keys) + if len(missing_keys): + groups = _group_checkpoint_keys(missing_keys) + msg_per_group = sorted(k + _group_to_str(v) for k, v in groups.items()) + msg = '\n'.join(sorted(msg_per_group)) + logger.warning('Some layers unloaded with pre-trained weight: \n' + msg) + +def strip_prefix_if_present(state_dict, prefix): + keys = sorted(state_dict.keys()) + if not all(key.startswith(prefix) for key in keys): + return state_dict + stripped_state_dict = OrderedDict() + for key, value in state_dict.items(): + stripped_state_dict[key.replace(prefix, "", 1)] = value + return stripped_state_dict + +def load_state_dict(model, loaded_state_dict): + model_state_dict = model.state_dict() + # if the state_dict comes from a model that was wrapped in a + # DataParallel or DistributedDataParallel during serialization, + # remove the "module" prefix before performing the matching + loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") + align_and_update_state_dicts(model_state_dict, loaded_state_dict) + + # use strict loading + model.load_state_dict(model_state_dict) + +def _group_checkpoint_keys(keys): + """ + Group keys based on common prefixes. A prefix is the string up to the final + "." in each key. + Args: + keys (list[str]): list of parameter names, i.e. keys in the model + checkpoint dict. + Returns: + dict[list]: keys with common prefixes are grouped into lists. + """ + groups = defaultdict(list) + for key in keys: + pos = key.rfind(".") + if pos >= 0: + head, tail = key[:pos], [key[pos + 1 :]] + else: + head, tail = key, [] + groups[head].extend(tail) + return groups + +def _group_to_str(group): + """ + Format a group of parameter name suffixes into a loggable string. + Args: + group (list[str]): list of parameter name suffixes. + Returns: + str: formated string. + """ + if len(group) == 0: + return "" + + if len(group) == 1: + return "." + group[0] + + return ".{" + ", ".join(sorted(group)) + "}" \ No newline at end of file diff --git a/maskrcnn_benchmark/utils/model_zoo.py b/maskrcnn_benchmark/utils/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..96aef6fda6cec21c074c33fd1e3934cf52088e60 --- /dev/null +++ b/maskrcnn_benchmark/utils/model_zoo.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import os +import sys + +try: + from torch.hub import _download_url_to_file + from torch.hub import urlparse + from torch.hub import HASH_REGEX +except ImportError: + from torch.utils.model_zoo import _download_url_to_file + from torch.utils.model_zoo import urlparse + from torch.utils.model_zoo import HASH_REGEX + +from maskrcnn_benchmark.utils.comm import is_main_process +from maskrcnn_benchmark.utils.comm import synchronize + + +# very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py +# but with a few improvements and modifications +def cache_url(url, model_dir='model', progress=True): + r"""Loads the Torch serialized object at the given URL. + If the object is already present in `model_dir`, it's deserialized and + returned. 
The filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. + The default value of `model_dir` is ``$TORCH_HOME/models`` where + ``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be + overridden with the ``$TORCH_MODEL_ZOO`` environment variable. + Args: + url (string): URL of the object to download + model_dir (string, optional): directory in which to save the object + progress (bool, optional): whether or not to display a progress bar to stderr + Example: + >>> cached_file = maskrcnn_benchmark.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') + """ + if model_dir is None: + torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch")) + model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models")) + if not os.path.exists(model_dir): + os.makedirs(model_dir, exist_ok=True) + parts = urlparse(url) + filename = os.path.basename(parts.path) + if filename == "model_final.pkl": + # workaround as pre-trained Caffe2 models from Detectron have all the same filename + # so make the full path the filename by replacing / with _ + filename = parts.path.replace("/", "_") + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = HASH_REGEX.search(filename) + if hash_prefix is not None: + hash_prefix = hash_prefix.group(1) + # workaround: Caffe2 models don't have a hash, but follow the R-50 convention, + # which matches the hash PyTorch uses. So we skip the hash matching + # if the hash_prefix is less than 6 characters + if len(hash_prefix) < 6: + hash_prefix = None + _download_url_to_file(url, cached_file, hash_prefix, progress=progress) + synchronize() + return cached_file diff --git a/maskrcnn_benchmark/utils/pretrain_model_loading.py b/maskrcnn_benchmark/utils/pretrain_model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..a45f05a5e3baa225a9e628ddcc3bfc0a0eefc238 --- /dev/null +++ b/maskrcnn_benchmark/utils/pretrain_model_loading.py @@ -0,0 +1,49 @@ +import numpy as np +import torch +import torch.nn as nn + +from collections import OrderedDict + +def _remove_bn_statics(state_dict): + layer_keys = sorted(state_dict.keys()) + remove_list = [] + for key in layer_keys: + if 'running_mean' in key or 'running_var' in key or 'num_batches_tracked' in key: + remove_list.append(key) + for key in remove_list: + del state_dict[key] + return state_dict + +def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg): + import re + layer_keys = sorted(state_dict.keys()) + for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1): + if not stage_with_dcn: + continue + for old_key in layer_keys: + pattern = ".*layer{}.*conv2.*".format(ix) + r = re.match(pattern, old_key) + if r is None: + continue + for param in ["weight", "bias"]: + if old_key.find(param) is -1: + continue + if 'unit01' in old_key: + continue + new_key = old_key.replace( + "conv2.{}".format(param), "conv2.conv.{}".format(param) + ) + print("pattern: {}, old_key: {}, new_key: {}".format( + pattern, old_key, new_key + )) + state_dict[new_key] = state_dict[old_key] + del state_dict[old_key] + return state_dict + + +def load_pretrain_format(cfg, f): + model = torch.load(f) + model = 
_remove_bn_statics(model) + model = _rename_conv_weights_for_deformable_conv_layers(model, cfg) + + return dict(model=model) diff --git a/maskrcnn_benchmark/utils/registry.py b/maskrcnn_benchmark/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..ae82dfb879f19ca0c3d9056abdb440b0863cb912 --- /dev/null +++ b/maskrcnn_benchmark/utils/registry.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + + +def _register_generic(module_dict, module_name, module): + assert module_name not in module_dict + module_dict[module_name] = module + + +class Registry(dict): + ''' + A helper class for managing registering modules, it extends a dictionary + and provides a register functions. + + Eg. creeting a registry: + some_registry = Registry({"default": default_module}) + + There're two ways of registering new modules: + 1): normal way is just calling register function: + def foo(): + ... + some_registry.register("foo_module", foo) + 2): used as decorator when declaring the module: + @some_registry.register("foo_module") + @some_registry.register("foo_modeul_nickname") + def foo(): + ... + + Access of module is just like using a dictionary, eg: + f = some_registry["foo_modeul"] + ''' + def __init__(self, *args, **kwargs): + super(Registry, self).__init__(*args, **kwargs) + + def register(self, module_name, module=None): + # used as function call + if module is not None: + _register_generic(self, module_name, module) + return + + # used as decorator + def register_fn(fn): + _register_generic(self, module_name, fn) + return fn + + return register_fn diff --git a/maskrcnn_benchmark/utils/shallow_contrastive_loss_helper.py b/maskrcnn_benchmark/utils/shallow_contrastive_loss_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..027fb4598529c0072f670a4776f2c825968f5caf --- /dev/null +++ b/maskrcnn_benchmark/utils/shallow_contrastive_loss_helper.py @@ -0,0 +1,58 @@ +import torch +import maskrcnn_benchmark.utils.dist as dist + + +def normalized_positive_map(positive_map): + positive_map = positive_map.float() + positive_map_num_pos = positive_map.sum(2) + positive_map_num_pos[positive_map_num_pos == 0] = 1e-6 + positive_map = positive_map / positive_map_num_pos.unsqueeze(-1) + return positive_map + + +def pad_tensor_given_dim_length(tensor, dim, length, padding_value=0, batch_first=True): + new_size = list(tensor.size()[:dim]) + [length] + list(tensor.size()[dim + 1:]) + out_tensor = tensor.data.new(*new_size).fill_(padding_value) + if batch_first: + out_tensor[:, :tensor.size(1), ...] = tensor + else: + out_tensor[:tensor.size(0), ...] = tensor + return out_tensor + + +def pad_random_negative_tensor_given_length(positive_tensor, negative_padding_tensor, length=None): + assert positive_tensor.shape[0] + negative_padding_tensor.shape[0] == length + return torch.cat((positive_tensor, negative_padding_tensor), dim=0) + + +def gather_tensors(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
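+    To keep a gradient path on the local rank, the locally computed tensor is re-inserted
+    at its own rank index after the gather.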
+ """ + if not dist.is_dist_avail_and_initialized(): + return torch.stack([tensor], dim=0) + + total = dist.get_world_size() + rank = torch.distributed.get_rank() + # gathered_normalized_img_emb = [torch.zeros_like(normalized_img_emb) for _ in range(total)] + # torch.distributed.all_gather(gathered_normalized_img_emb, normalized_img_emb) + + tensors_gather = [ + torch.zeros_like(tensor) + for _ in range(total) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + # need to do this to restore propagation of the gradients + tensors_gather[rank] = tensor + output = torch.stack(tensors_gather, dim=0) + return output + + +def convert_to_roi_format(boxes): + concat_boxes = boxes.bbox + device, dtype = concat_boxes.device, concat_boxes.dtype + ids = torch.full((len(boxes), 1), 0, dtype=dtype, device=device) + rois = torch.cat([ids, concat_boxes], dim=1) + return rois \ No newline at end of file diff --git a/maskrcnn_benchmark/utils/stats.py b/maskrcnn_benchmark/utils/stats.py new file mode 100644 index 0000000000000000000000000000000000000000..ae04f1e20b44d2774f73b698ea91a00e7b0ce690 --- /dev/null +++ b/maskrcnn_benchmark/utils/stats.py @@ -0,0 +1,510 @@ +''' +Copyright (C) 2019 Sovrasov V. - All Rights Reserved + * You may use, distribute and modify this code under the + * terms of the MIT license. + * You should have received a copy of the MIT license with + * this file. If not visit https://opensource.org/licenses/MIT +''' + +import sys +from functools import partial + +import numpy as np +import torch +import torch.nn as nn + +from maskrcnn_benchmark.layers import * + +def get_model_complexity_info(model, input_res, + print_per_layer_stat=True, + as_strings=True, + input_constructor=None, ost=sys.stdout, + verbose=False, ignore_modules=[], + custom_modules_hooks={}): + assert type(input_res) is tuple + assert len(input_res) >= 1 + assert isinstance(model, nn.Module) + global CUSTOM_MODULES_MAPPING + CUSTOM_MODULES_MAPPING = custom_modules_hooks + flops_model = add_flops_counting_methods(model) + flops_model.eval() + flops_model.start_flops_count(ost=ost, verbose=verbose, + ignore_list=ignore_modules) + if input_constructor: + input = input_constructor(input_res) + _ = flops_model(**input) + else: + try: + batch = torch.ones(()).new_empty((1, *input_res), + dtype=next(flops_model.parameters()).dtype, + device=next(flops_model.parameters()).device) + except StopIteration: + batch = torch.ones(()).new_empty((1, *input_res)) + + _ = flops_model(batch) + + flops_count, params_count = flops_model.compute_average_flops_cost() + if print_per_layer_stat: + print_model_with_flops(flops_model, flops_count, params_count, ost=ost) + flops_model.stop_flops_count() + CUSTOM_MODULES_MAPPING = {} + + if as_strings: + return flops_to_string(flops_count), params_to_string(params_count) + + return flops_count, params_count + + +def flops_to_string(flops, units='GMac', precision=2): + if units is None: + if flops // 10**9 > 0: + return str(round(flops / 10.**9, precision)) + ' GMac' + elif flops // 10**6 > 0: + return str(round(flops / 10.**6, precision)) + ' MMac' + elif flops // 10**3 > 0: + return str(round(flops / 10.**3, precision)) + ' KMac' + else: + return str(flops) + ' Mac' + else: + if units == 'GMac': + return str(round(flops / 10.**9, precision)) + ' ' + units + elif units == 'MMac': + return str(round(flops / 10.**6, precision)) + ' ' + units + elif units == 'KMac': + return str(round(flops / 10.**3, precision)) + ' ' + units + else: + return str(flops) + ' Mac' + + +def 
params_to_string(params_num, units=None, precision=2): + if units is None: + if params_num // 10 ** 6 > 0: + return str(round(params_num / 10 ** 6, 2)) + ' M' + elif params_num // 10 ** 3: + return str(round(params_num / 10 ** 3, 2)) + ' k' + else: + return str(params_num) + else: + if units == 'M': + return str(round(params_num / 10.**6, precision)) + ' ' + units + elif units == 'K': + return str(round(params_num / 10.**3, precision)) + ' ' + units + else: + return str(params_num) + + +def accumulate_flops(self): + if is_supported_instance(self): + return self.__flops__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_flops() + return sum + + +def print_model_with_flops(model, total_flops, total_params, units='GMac', + precision=3, ost=sys.stdout): + + def accumulate_params(self): + if is_supported_instance(self): + return self.__params__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_params() + return sum + + def flops_repr(self): + accumulated_params_num = self.accumulate_params() + accumulated_flops_cost = self.accumulate_flops() / model.__batch_counter__ + return ', '.join([params_to_string(accumulated_params_num, + units='M', precision=precision), + '{:.3%} Params'.format(accumulated_params_num / total_params), + flops_to_string(accumulated_flops_cost, + units=units, precision=precision), + '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), + self.original_extra_repr()]) + + def add_extra_repr(m): + m.accumulate_flops = accumulate_flops.__get__(m) + m.accumulate_params = accumulate_params.__get__(m) + flops_extra_repr = flops_repr.__get__(m) + if m.extra_repr != flops_extra_repr: + m.original_extra_repr = m.extra_repr + m.extra_repr = flops_extra_repr + assert m.extra_repr != m.original_extra_repr + + def del_extra_repr(m): + if hasattr(m, 'original_extra_repr'): + m.extra_repr = m.original_extra_repr + del m.original_extra_repr + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + model.apply(add_extra_repr) + print(repr(model), file=ost) + model.apply(del_extra_repr) + + +def get_model_parameters_number(model): + params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params_num + + +def add_flops_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) + net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) + net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) + net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( + net_main_module) + + net_main_module.reset_flops_count() + + return net_main_module + + +def compute_average_flops_cost(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Returns current mean flops consumption per image. + + """ + + for m in self.modules(): + m.accumulate_flops = accumulate_flops.__get__(m) + + flops_sum = self.accumulate_flops() + + for m in self.modules(): + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + params_sum = get_model_parameters_number(self) + return flops_sum / self.__batch_counter__, params_sum + + +def start_flops_count(self, **kwargs): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. 
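The two string helpers above are also usable on their own; for instance (illustrative numbers):

from maskrcnn_benchmark.utils.stats import flops_to_string, params_to_string

print(flops_to_string(4.2e9))                   # '4.2 GMac'  (default unit)
print(flops_to_string(31_500_000, units=None))  # picks the unit automatically: '31.5 MMac'
print(params_to_string(25_557_032))             # '25.56 M'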
+ + Activates the computation of mean flops consumption per image. + Call it before you run the network. + + """ + add_batch_counter_hook_function(self) + + seen_types = set() + + def add_flops_counter_hook_function(module, ost, verbose, ignore_list): + if type(module) in ignore_list: + seen_types.add(type(module)) + if is_supported_instance(module): + module.__params__ = 0 + elif is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + return + if type(module) in CUSTOM_MODULES_MAPPING: + handle = module.register_forward_hook( + CUSTOM_MODULES_MAPPING[type(module)]) + elif getattr(module, 'compute_macs', False): + handle = module.register_forward_hook( + module.compute_macs + ) + else: + handle = module.register_forward_hook(MODULES_MAPPING[type(module)]) + module.__flops_handle__ = handle + seen_types.add(type(module)) + else: + if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and \ + not type(module) in seen_types: + print('Warning: module ' + type(module).__name__ + + ' is treated as a zero-op.', file=ost) + seen_types.add(type(module)) + + self.apply(partial(add_flops_counter_hook_function, **kwargs)) + + +def stop_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Stops computing the mean flops consumption per image. + Call whenever you want to pause the computation. + + """ + remove_batch_counter_hook_function(self) + self.apply(remove_flops_counter_hook_function) + + +def reset_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Resets statistics computed so far. + + """ + add_batch_counter_variables_or_reset(self) + self.apply(add_flops_counter_variable_or_reset) + + +# ---- Internal functions +def empty_flops_counter_hook(module, input, output): + module.__flops__ += 0 + + +def upsample_flops_counter_hook(module, input, output): + output_size = output[0] + batch_size = output_size.shape[0] + output_elements_count = batch_size + for val in output_size.shape[1:]: + output_elements_count *= val + module.__flops__ += int(output_elements_count) + + +def relu_flops_counter_hook(module, input, output): + active_elements_count = output.numel() + module.__flops__ += int(active_elements_count) + + +def linear_flops_counter_hook(module, input, output): + input = input[0] + # pytorch checks dimensions, so here we don't care much + output_last_dim = output.shape[-1] + bias_flops = output_last_dim if module.bias is not None else 0 + module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops) + + +def pool_flops_counter_hook(module, input, output): + input = input[0] + module.__flops__ += int(np.prod(input.shape)) + + +def bn_flops_counter_hook(module, input, output): + input = input[0] + + batch_flops = np.prod(input.shape) + if module.affine: + batch_flops *= 2 + module.__flops__ += int(batch_flops) + + +def conv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + input = input[0] + + batch_size = input.shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(conv_module.kernel_size) + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = int(np.prod(kernel_dims)) * \ + in_channels * filters_per_channel + + active_elements_count = batch_size * int(np.prod(output_dims)) + + 
overall_conv_flops = conv_per_position_flops * active_elements_count + + bias_flops = 0 + + if conv_module.bias is not None: + + bias_flops = out_channels * active_elements_count + + overall_flops = overall_conv_flops + bias_flops + + conv_module.__flops__ += int(overall_flops) + + +def batch_counter_hook(module, input, output): + batch_size = 1 + if len(input) > 0: + # Can have multiple inputs, getting the first one + input = input[0] + batch_size = len(input) + else: + pass + print('Warning! No positional inputs found for a module,' + ' assuming batch size is 1.') + module.__batch_counter__ += batch_size + + +def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): + # matrix matrix mult ih state and internal state + flops += w_ih.shape[0]*w_ih.shape[1] + # matrix matrix mult hh state and internal state + flops += w_hh.shape[0]*w_hh.shape[1] + if isinstance(rnn_module, (nn.RNN, nn.RNNCell)): + # add both operations + flops += rnn_module.hidden_size + elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)): + # hadamard of r + flops += rnn_module.hidden_size + # adding operations from both states + flops += rnn_module.hidden_size*3 + # last two hadamard product and add + flops += rnn_module.hidden_size*3 + elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)): + # adding operations from both states + flops += rnn_module.hidden_size*4 + # two hadamard product and add for C state + flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + # final hadamard + flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + return flops + + +def rnn_flops_counter_hook(rnn_module, input, output): + """ + Takes into account batch goes at first position, contrary + to pytorch common rule (but actually it doesn't matter). 
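To make the convolution hook above concrete: per output position it charges kernel_area x in_channels x (out_channels // groups) MACs, plus one MAC per output element for the bias. A stand-alone sanity check with made-up layer sizes:

import numpy as np
import torch
import torch.nn as nn

conv = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
x = torch.randn(1, 64, 56, 56)
y = conv(x)

per_position = int(np.prod(conv.kernel_size)) * conv.in_channels * (conv.out_channels // conv.groups)
active = x.shape[0] * int(np.prod(y.shape[2:]))              # batch * H_out * W_out
flops = per_position * active + conv.out_channels * active   # conv MACs + bias adds
print(flops)                                                 # ~2.3e8 for this layer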
+ IF sigmoid and tanh are made hard, only a comparison FLOPS should be accurate + """ + flops = 0 + # input is a tuple containing a sequence to process and (optionally) hidden state + inp = input[0] + batch_size = inp.shape[0] + seq_length = inp.shape[1] + num_layers = rnn_module.num_layers + + for i in range(num_layers): + w_ih = rnn_module.__getattr__('weight_ih_l' + str(i)) + w_hh = rnn_module.__getattr__('weight_hh_l' + str(i)) + if i == 0: + input_size = rnn_module.input_size + else: + input_size = rnn_module.hidden_size + flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size) + if rnn_module.bias: + b_ih = rnn_module.__getattr__('bias_ih_l' + str(i)) + b_hh = rnn_module.__getattr__('bias_hh_l' + str(i)) + flops += b_ih.shape[0] + b_hh.shape[0] + + flops *= batch_size + flops *= seq_length + if rnn_module.bidirectional: + flops *= 2 + rnn_module.__flops__ += int(flops) + + +def rnn_cell_flops_counter_hook(rnn_cell_module, input, output): + flops = 0 + inp = input[0] + batch_size = inp.shape[0] + w_ih = rnn_cell_module.__getattr__('weight_ih') + w_hh = rnn_cell_module.__getattr__('weight_hh') + input_size = inp.shape[1] + flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size) + if rnn_cell_module.bias: + b_ih = rnn_cell_module.__getattr__('bias_ih') + b_hh = rnn_cell_module.__getattr__('bias_hh') + flops += b_ih.shape[0] + b_hh.shape[0] + + flops *= batch_size + rnn_cell_module.__flops__ += int(flops) + + +def add_batch_counter_variables_or_reset(module): + + module.__batch_counter__ = 0 + + +def add_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + return + + handle = module.register_forward_hook(batch_counter_hook) + module.__batch_counter_handle__ = handle + + +def remove_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + module.__batch_counter_handle__.remove() + del module.__batch_counter_handle__ + + +def add_flops_counter_variable_or_reset(module): + if is_supported_instance(module): + if hasattr(module, '__flops__') or hasattr(module, '__params__'): + print('Warning: variables __flops__ or __params__ are already ' + 'defined for the module' + type(module).__name__ + + ' ptflops can affect your code!') + module.__flops__ = 0 + module.__params__ = get_model_parameters_number(module) + + +CUSTOM_MODULES_MAPPING = {} + +MODULES_MAPPING = { + # convolutions + nn.Conv1d: conv_flops_counter_hook, + nn.Conv2d: conv_flops_counter_hook, + nn.Conv3d: conv_flops_counter_hook, + Conv2d: conv_flops_counter_hook, + ModulatedDeformConv: conv_flops_counter_hook, + # activations + nn.ReLU: relu_flops_counter_hook, + nn.PReLU: relu_flops_counter_hook, + nn.ELU: relu_flops_counter_hook, + nn.LeakyReLU: relu_flops_counter_hook, + nn.ReLU6: relu_flops_counter_hook, + # poolings + nn.MaxPool1d: pool_flops_counter_hook, + nn.AvgPool1d: pool_flops_counter_hook, + nn.AvgPool2d: pool_flops_counter_hook, + nn.MaxPool2d: pool_flops_counter_hook, + nn.MaxPool3d: pool_flops_counter_hook, + nn.AvgPool3d: pool_flops_counter_hook, + nn.AdaptiveMaxPool1d: pool_flops_counter_hook, + nn.AdaptiveAvgPool1d: pool_flops_counter_hook, + nn.AdaptiveMaxPool2d: pool_flops_counter_hook, + nn.AdaptiveAvgPool2d: pool_flops_counter_hook, + nn.AdaptiveMaxPool3d: pool_flops_counter_hook, + nn.AdaptiveAvgPool3d: pool_flops_counter_hook, + # BNs + nn.BatchNorm1d: bn_flops_counter_hook, + nn.BatchNorm2d: bn_flops_counter_hook, + nn.BatchNorm3d: bn_flops_counter_hook, + nn.GroupNorm : bn_flops_counter_hook, + # FC + nn.Linear: 
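The RNN hooks above are wired in through add_flops_counting_methods; a minimal sketch of driving the counter directly on a batch-first LSTM (the layer sizes are arbitrary):

import torch
import torch.nn as nn
from maskrcnn_benchmark.utils.stats import add_flops_counting_methods

lstm = add_flops_counting_methods(
    nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True))
lstm.start_flops_count(ost=None, verbose=False, ignore_list=[])
_ = lstm(torch.randn(4, 10, 32))            # 4 sequences of length 10
flops, params = lstm.compute_average_flops_cost()
lstm.stop_flops_count()
print(int(flops), params)                   # per-sample MACs and trainable parameter count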
linear_flops_counter_hook, + # Upscale + nn.Upsample: upsample_flops_counter_hook, + # Deconvolution + nn.ConvTranspose1d: conv_flops_counter_hook, + nn.ConvTranspose2d: conv_flops_counter_hook, + nn.ConvTranspose3d: conv_flops_counter_hook, + ConvTranspose2d: conv_flops_counter_hook, + # RNN + nn.RNN: rnn_flops_counter_hook, + nn.GRU: rnn_flops_counter_hook, + nn.LSTM: rnn_flops_counter_hook, + nn.RNNCell: rnn_cell_flops_counter_hook, + nn.LSTMCell: rnn_cell_flops_counter_hook, + nn.GRUCell: rnn_cell_flops_counter_hook +} + + +def is_supported_instance(module): + if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING \ + or getattr(module, 'compute_macs', False): + return True + return False + + +def remove_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + module.__flops_handle__.remove() + del module.__flops_handle__ \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/blip.py b/models/blip.py new file mode 100644 index 0000000000000000000000000000000000000000..38678f65ea2c276b351c2c97d429ebc2525ddcf7 --- /dev/null +++ b/models/blip.py @@ -0,0 +1,238 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import warnings +warnings.filterwarnings("ignore") + +from models.vit import VisionTransformer, interpolate_pos_embed +from models.med import BertConfig, BertModel, BertLMHeadModel +from transformers import BertTokenizer + +import torch +from torch import nn +import torch.nn.functional as F + +import os +from urllib.parse import urlparse +from timm.models.hub import download_cached_file + +class BLIP_Base(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 224, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + + def forward(self, image, caption, mode): + + assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" + text = self.tokenizer(caption, return_tensors="pt").to(image.device) + + if mode=='image': + # return image features + image_embeds = self.visual_encoder(image) + return image_embeds + + elif mode=='text': + # return text features + text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + return text_output.last_hidden_state + + elif mode=='multimodal': + # return multimodel features + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text.input_ids[:,0] = self.tokenizer.enc_token_id + output = self.text_encoder(text.input_ids, + attention_mask = text.attention_mask, + 
encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + return output.last_hidden_state + + + +class BLIP_Decoder(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 384, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + prompt = 'a picture of ', + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_decoder = BertLMHeadModel(config=med_config) + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 + + + def forward(self, image, caption): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) + + text.input_ids[:,0] = self.tokenizer.bos_token_id + + decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) + decoder_targets[:,:self.prompt_length] = -100 + + decoder_output = self.text_decoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + labels = decoder_targets, + return_dict = True, + ) + loss_lm = decoder_output.loss + + return loss_lm + + def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): + image_embeds = self.visual_encoder(image) + + if not sample: + image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) + + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} + + prompt = [self.prompt] * image.size(0) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) + input_ids[:,0] = self.tokenizer.bos_token_id + input_ids = input_ids[:, :-1] + + if sample: + #nucleus sampling + outputs = self.text_decoder.generate(input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + #beam search + outputs = self.text_decoder.generate(input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + captions = [] + for output in outputs: + caption = self.tokenizer.decode(output, skip_special_tokens=True) + captions.append(caption[len(self.prompt):]) + return captions + + +def blip_decoder(pretrained='',**kwargs): + model = BLIP_Decoder(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + assert(len(msg.missing_keys)==0) + return model + +def blip_feature_extractor(pretrained='',**kwargs): + model = BLIP_Base(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + 
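A minimal sketch of querying the three feature modes of BLIP_Base through blip_feature_extractor. No pretrained checkpoint is loaded here, so the weights are random and the outputs are only shape-checked; the default configs/med_config.json and the bert-base-uncased tokenizer are assumed to be available:

import torch
from models.blip import blip_feature_extractor

model = blip_feature_extractor(image_size=224, vit='base')   # pass pretrained=... for real features
model.eval()

image = torch.randn(1, 3, 224, 224)      # in practice: a normalized RGB crop
caption = 'a woman riding a horse'

with torch.no_grad():
    image_feat = model(image, caption, mode='image')        # [1, 197, 768] patch tokens + [CLS]
    text_feat = model(image, caption, mode='text')          # [1, seq_len, 768]
    mm_feat = model(image, caption, mode='multimodal')      # text tokens cross-attending to the image
print(image_feat.shape, text_feat.shape, mm_feat.shape)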
assert(len(msg.missing_keys)==0) + return model + +def init_tokenizer(): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + tokenizer.add_special_tokens({'bos_token':'[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + +def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): + + assert vit in ['base', 'large'], "vit parameter must be base or large" + if vit=='base': + vision_width = 768 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, + num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate + ) + elif vit=='large': + vision_width = 1024 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, + num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate + ) + return visual_encoder, vision_width + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + +def load_checkpoint(model,url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) + if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): + state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], + model.visual_encoder_m) + for key in model.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape!=model.state_dict()[key].shape: + del state_dict[key] + + msg = model.load_state_dict(state_dict,strict=False) + print('load checkpoint from %s'%url_or_filename) + return model,msg + diff --git a/models/blip_itm.py b/models/blip_itm.py new file mode 100644 index 0000000000000000000000000000000000000000..cf354c829564bf5a1f56089a2d745093d51e0fa2 --- /dev/null +++ b/models/blip_itm.py @@ -0,0 +1,76 @@ +from models.med import BertConfig, BertModel +from transformers import BertTokenizer + +import torch +from torch import nn +import torch.nn.functional as F + +from models.blip import create_vit, init_tokenizer, load_checkpoint + +class BLIP_ITM(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 384, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + embed_dim = 256, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + text_width = self.text_encoder.config.hidden_size + + self.vision_proj = nn.Linear(vision_width, embed_dim) + 
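And a corresponding sketch for the captioning path: blip_decoder builds BLIP_Decoder, and generate runs beam search (or nucleus sampling when sample=True). Without a pretrained checkpoint the weights are random, so the caption text is meaningless; the call pattern is the point:

import torch
from models.blip import blip_decoder

model = blip_decoder(image_size=384, vit='base')   # pretrained='<path or URL>' to load real weights
model.eval()

image = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    beam_caps = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
    nucleus_caps = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
print(beam_caps[0], nucleus_caps[0])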
self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + + def forward(self, image, caption, match_head='itm'): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, + return_tensors="pt").to(image.device) + + + if match_head=='itm': + output = self.text_encoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + itm_output = self.itm_head(output.last_hidden_state[:,0,:]) + return itm_output + + elif match_head=='itc': + text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) + text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) + + sim = image_feat @ text_feat.t() + return sim + + +def blip_itm(pretrained='',**kwargs): + model = BLIP_ITM(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + assert(len(msg.missing_keys)==0) + return model + \ No newline at end of file diff --git a/models/blip_nlvr.py b/models/blip_nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..84837167bfa6874d3c3e41fb9b37271113910b7f --- /dev/null +++ b/models/blip_nlvr.py @@ -0,0 +1,103 @@ +from models.med import BertConfig +from models.nlvr_encoder import BertModel +from models.vit import interpolate_pos_embed +from models.blip import create_vit, init_tokenizer, is_url + +from timm.models.hub import download_cached_file + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import BertTokenizer +import numpy as np + +class BLIP_NLVR(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 480, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + self.cls_head = nn.Sequential( + nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), + nn.ReLU(), + nn.Linear(self.text_encoder.config.hidden_size, 2) + ) + + def forward(self, image, text, targets, train=True): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) + + text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) + text.input_ids[:,0] = self.tokenizer.enc_token_id + + output = self.text_encoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = [image0_embeds,image1_embeds], + encoder_attention_mask = [image_atts[:image0_embeds.size(0)], + image_atts[image0_embeds.size(0):]], + return_dict = True, + ) + hidden_state = output.last_hidden_state[:,0,:] + prediction = 
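A usage sketch for the two matching heads above: 'itm' returns binary match/no-match logits from the cross-attended [CLS] state, while 'itc' returns the similarity of the projected unimodal embeddings. Weights are random here unless pretrained is given:

import torch
import torch.nn.functional as F
from models.blip_itm import blip_itm

model = blip_itm(image_size=384, vit='base')
model.eval()

image = torch.randn(1, 3, 384, 384)
caption = 'a dog playing with a ball'

with torch.no_grad():
    itm_logits = model(image, caption, match_head='itm')   # shape [1, 2]
    match_prob = F.softmax(itm_logits, dim=1)[:, 1]        # probability the pair matches
    itc_sim = model(image, caption, match_head='itc')      # shape [1, 1] similarity score
print(float(match_prob), float(itc_sim))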
self.cls_head(hidden_state) + + if train: + loss = F.cross_entropy(prediction, targets) + return loss + else: + return prediction + +def blip_nlvr(pretrained='',**kwargs): + model = BLIP_NLVR(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + print("missing keys:") + print(msg.missing_keys) + return model + + +def load_checkpoint(model,url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) + + for key in list(state_dict.keys()): + if 'crossattention.self.' in key: + new_key0 = key.replace('self','self0') + new_key1 = key.replace('self','self1') + state_dict[new_key0] = state_dict[key] + state_dict[new_key1] = state_dict[key] + elif 'crossattention.output.dense.' in key: + new_key0 = key.replace('dense','dense0') + new_key1 = key.replace('dense','dense1') + state_dict[new_key0] = state_dict[key] + state_dict[new_key1] = state_dict[key] + + msg = model.load_state_dict(state_dict,strict=False) + print('load checkpoint from %s'%url_or_filename) + return model,msg + \ No newline at end of file diff --git a/models/blip_pretrain.py b/models/blip_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..e42ce5f998b0a51e6f731ee6b5c8bae6d02a8664 --- /dev/null +++ b/models/blip_pretrain.py @@ -0,0 +1,339 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +from models.med import BertConfig, BertModel, BertLMHeadModel +from transformers import BertTokenizer +import transformers +transformers.logging.set_verbosity_error() + +import torch +from torch import nn +import torch.nn.functional as F + +from models.blip import create_vit, init_tokenizer, load_checkpoint + +class BLIP_Pretrain(nn.Module): + def __init__(self, + med_config = 'configs/bert_config.json', + image_size = 224, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + embed_dim = 256, + queue_size = 57600, + momentum = 0.995, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0) + + if vit=='base': + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", + map_location="cpu", check_hash=True) + state_dict = checkpoint["model"] + msg = self.visual_encoder.load_state_dict(state_dict,strict=False) + elif vit=='large': + from timm.models.helpers import load_custom_pretrained + from timm.models.vision_transformer import default_cfgs + load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k']) + + self.tokenizer = init_tokenizer() + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = vision_width + self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False) + self.text_encoder.resize_token_embeddings(len(self.tokenizer)) + + text_width = self.text_encoder.config.hidden_size + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + # create momentum encoders + self.visual_encoder_m, vision_width = create_vit(vit,image_size) + self.vision_proj_m = nn.Linear(vision_width, embed_dim) + self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False) + self.text_proj_m = nn.Linear(text_width, embed_dim) + + self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], + [self.vision_proj,self.vision_proj_m], + [self.text_encoder,self.text_encoder_m], + [self.text_proj,self.text_proj_m], + ] + self.copy_params() + + # create the queue + self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + self.image_queue = nn.functional.normalize(self.image_queue, dim=0) + self.text_queue = nn.functional.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + self.momentum = momentum + self.temp = nn.Parameter(0.07*torch.ones([])) + + # create the decoder + decoder_config = BertConfig.from_json_file(med_config) + decoder_config.encoder_width = vision_width + self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config) + self.text_decoder.resize_token_embeddings(len(self.tokenizer)) + tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention') + + + def forward(self, image, caption, alpha): + with torch.no_grad(): + 
self.temp.clamp_(0.001,0.5) + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) + + text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30, + return_tensors="pt").to(image.device) + text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m(image) + image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) + image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) + + text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) + text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) + + sim_i2t_m = image_feat_m @ text_feat_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_all / self.temp + + sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) + sim_targets.fill_diagonal_(1) + + sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets + sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets + + sim_i2t = image_feat @ text_feat_all / self.temp + sim_t2i = text_feat @ image_feat_all / self.temp + + loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() + loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() + + loss_ita = (loss_i2t+loss_t2i)/2 + + self._dequeue_and_enqueue(image_feat_m, text_feat_m) + + ###============== Image-text Matching ===================### + encoder_input_ids = text.input_ids.clone() + encoder_input_ids[:,0] = self.tokenizer.enc_token_id + + # forward the positve image-text pair + bs = image.size(0) + output_pos = self.text_encoder(encoder_input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + with torch.no_grad(): + weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4 + weights_t2i.fill_diagonal_(0) + weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4 + weights_i2t.fill_diagonal_(0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg,dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(encoder_input_ids[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg,dim=0) + text_atts_neg = torch.stack(text_atts_neg,dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) + + image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) + image_atts_all = torch.cat([image_atts,image_atts],dim=0) + + output_neg = self.text_encoder(text_ids_all, + attention_mask = text_atts_all, + encoder_hidden_states = 
image_embeds_all, + encoder_attention_mask = image_atts_all, + return_dict = True, + ) + + vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) + vl_output = self.itm_head(vl_embeddings) + + itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], + dim=0).to(image.device) + loss_itm = F.cross_entropy(vl_output, itm_labels) + + ##================= LM ========================## + decoder_input_ids = text.input_ids.clone() + decoder_input_ids[:,0] = self.tokenizer.bos_token_id + decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100) + + decoder_output = self.text_decoder(decoder_input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + labels = decoder_targets, + return_dict = True, + ) + + loss_lm = decoder_output.loss + return loss_ita, loss_itm, loss_lm + + + + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): + param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) + + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat): + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + ptr = (ptr + batch_size) % self.queue_size # move pointer + + self.queue_ptr[0] = ptr + + +def blip_pretrain(**kwargs): + model = BLIP_Pretrain(**kwargs) + return model + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +from typing import List +def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." 
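The image-text contrastive targets in BLIP_Pretrain.forward above blend a one-hot diagonal with the momentum encoder's softmax (momentum distillation). A compact, self-contained restatement with made-up 2x2 similarity scores:

import torch
import torch.nn.functional as F

alpha = 0.4                                   # distillation weight
sim_i2t_m = torch.tensor([[5.0, 1.0],         # momentum model: image-to-text similarities / temp
                          [0.5, 4.0]])
sim_targets = torch.eye(2)                    # matched pairs sit on the diagonal
soft_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets

sim_i2t = torch.tensor([[4.0, 2.0],           # online model: similarities / temp
                        [1.0, 3.0]])
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * soft_targets, dim=1).mean()
print(float(loss_i2t))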
+ ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" + if hasattr(decoder_pointer, "weight") and skip_key not in module_name: + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + encoder_pointer.bias = decoder_pointer.bias + print(module_name+' is tied') + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( + encoder_modules + ) != len(decoder_modules): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 
+ ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key) diff --git a/models/blip_retrieval.py b/models/blip_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..1debe7e2e664f8dd603f8d4c537e3599c68638d7 --- /dev/null +++ b/models/blip_retrieval.py @@ -0,0 +1,319 @@ +from models.med import BertConfig, BertModel +from transformers import BertTokenizer + +import torch +from torch import nn +import torch.nn.functional as F + +from models.blip import create_vit, init_tokenizer, load_checkpoint + +class BLIP_Retrieval(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 384, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + embed_dim = 256, + queue_size = 57600, + momentum = 0.995, + negative_all_rank = False, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + text_width = self.text_encoder.config.hidden_size + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + # create momentum encoders + self.visual_encoder_m, vision_width = create_vit(vit,image_size) + self.vision_proj_m = nn.Linear(vision_width, embed_dim) + self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False) + self.text_proj_m = nn.Linear(text_width, embed_dim) + + self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], + [self.vision_proj,self.vision_proj_m], + [self.text_encoder,self.text_encoder_m], + [self.text_proj,self.text_proj_m], + ] + self.copy_params() + + # create the queue + self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("idx_queue", torch.full((1,queue_size),-100)) + self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long)) + + self.image_queue = nn.functional.normalize(self.image_queue, dim=0) + self.text_queue = nn.functional.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + self.momentum = momentum + self.temp = nn.Parameter(0.07*torch.ones([])) + + self.negative_all_rank = negative_all_rank + + + def forward(self, image, caption, alpha, idx): + with torch.no_grad(): + self.temp.clamp_(0.001,0.5) + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) + + text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, + return_tensors="pt").to(image.device) + + text_output = 
self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) + + ###============== Image-text Contrastive Learning ===================### + idx = idx.view(-1,1) + idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m(image) + image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) + image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) + + text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) + text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) + + sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp + + sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets + sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets + + sim_i2t = image_feat @ text_feat_m_all / self.temp + sim_t2i = text_feat @ image_feat_m_all / self.temp + + loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() + loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() + + loss_ita = (loss_i2t+loss_t2i)/2 + + idxs = concat_all_gather(idx) + self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs) + + ###============== Image-text Matching ===================### + encoder_input_ids = text.input_ids.clone() + encoder_input_ids[:,0] = self.tokenizer.enc_token_id + + # forward the positve image-text pair + bs = image.size(0) + output_pos = self.text_encoder(encoder_input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + + + if self.negative_all_rank: + # compute sample similarity + with torch.no_grad(): + mask = torch.eq(idx, idxs.t()) + + image_feat_world = concat_all_gather(image_feat) + text_feat_world = concat_all_gather(text_feat) + + sim_i2t = image_feat @ text_feat_world.t() / self.temp + sim_t2i = text_feat @ image_feat_world.t() / self.temp + + weights_i2t = F.softmax(sim_i2t,dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = F.softmax(sim_t2i,dim=1) + weights_t2i.masked_fill_(mask, 0) + + image_embeds_world = all_gather_with_grad(image_embeds) + + # select a negative image (from all ranks) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg,dim=0) + + # select a negative text (from all ranks) for each image + input_ids_world = concat_all_gather(encoder_input_ids) + att_mask_world = concat_all_gather(text.attention_mask) + + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(input_ids_world[neg_idx]) + text_atts_neg.append(att_mask_world[neg_idx]) + + else: + with torch.no_grad(): + mask = torch.eq(idx, idx.t()) + + sim_i2t = image_feat @ text_feat.t() / self.temp + sim_t2i = 
text_feat @ image_feat.t() / self.temp + + weights_i2t = F.softmax(sim_i2t,dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = F.softmax(sim_t2i,dim=1) + weights_t2i.masked_fill_(mask, 0) + + # select a negative image (from same rank) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg,dim=0) + + # select a negative text (from same rank) for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(encoder_input_ids[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg,dim=0) + text_atts_neg = torch.stack(text_atts_neg,dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) + + image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) + image_atts_all = torch.cat([image_atts,image_atts],dim=0) + + output_neg = self.text_encoder(text_ids_all, + attention_mask = text_atts_all, + encoder_hidden_states = image_embeds_all, + encoder_attention_mask = image_atts_all, + return_dict = True, + ) + + + vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) + vl_output = self.itm_head(vl_embeddings) + + itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], + dim=0).to(image.device) + loss_itm = F.cross_entropy(vl_output, itm_labels) + + return loss_ita, loss_itm + + + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): + param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) + + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs): + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + + + batch_size = image_feats.shape[0] + + ptr = int(self.ptr_queue) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + self.idx_queue[:, ptr:ptr + batch_size] = idxs.T + ptr = (ptr + batch_size) % self.queue_size # move pointer + + self.ptr_queue[0] = ptr + + +def blip_retrieval(pretrained='',**kwargs): + model = BLIP_Retrieval(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + print("missing keys:") + print(msg.missing_keys) + return model + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + + +def all_gather_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) diff --git a/models/blip_vqa.py b/models/blip_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef8641cab1badd32e00abea352d764f6165faae --- /dev/null +++ b/models/blip_vqa.py @@ -0,0 +1,223 @@ +from models.med import BertConfig, BertModel, BertLMHeadModel +from models.blip import create_vit, init_tokenizer, load_checkpoint + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import BertTokenizer +import numpy as np + +class BLIP_VQA(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 480, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) + self.tokenizer = init_tokenizer() + + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = vision_width + self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) + + decoder_config = BertConfig.from_json_file(med_config) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + self.itm_head = nn.Linear(768, 2) + + def forward(self, image, question, answer=None, n=None, weights=None, mode='inference', inference='rank', k_test=128): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, + return_tensors="pt").to(image.device) + question.input_ids[:,0] = self.tokenizer.enc_token_id + + if mode == 'train': + ''' + n: number of answers for each question + weights: weight for each answer + ''' + answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) + answer.input_ids[:,0] = self.tokenizer.bos_token_id + answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) + + question_output = self.text_encoder(question.input_ids, + attention_mask = question.attention_mask, + 
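A small sketch contrasting the two gather helpers in models/blip_retrieval.py above: concat_all_gather uses plain all_gather (no gradient), while all_gather_with_grad routes through GatherLayer so gradients still reach the local shard. A single-process gloo group is spun up only to make the calls legal; real training would be launched with multiple ranks:

import os
import torch
import torch.distributed as dist
from models.blip_retrieval import all_gather_with_grad, concat_all_gather

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

x = torch.randn(2, 4, requires_grad=True)

y = all_gather_with_grad(x)          # graph preserved; with world_size == 1 this is x itself
y.sum().backward()
print(x.grad is not None)            # True

z = concat_all_gather(x.detach())    # plain all_gather, wrapped in @torch.no_grad()
print(z.shape)                       # torch.Size([2, 4]) when world_size == 1

dist.destroy_process_group()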
encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True) + + question_states = [] + question_atts = [] + for b, n in enumerate(n): + question_states += [question_output.last_hidden_state[b]]*n + question_atts += [question.attention_mask[b]]*n + question_states = torch.stack(question_states,0) + question_atts = torch.stack(question_atts,0) + + answer_output = self.text_decoder(answer.input_ids, + attention_mask = answer.attention_mask, + encoder_hidden_states = question_states, + encoder_attention_mask = question_atts, + labels = answer_targets, + return_dict = True, + reduction = 'none', + ) + + loss = weights * answer_output.loss + loss = loss.sum()/image.size(0) + + return loss + + elif mode == 'gradcam': + question_output = self.text_encoder(question.input_ids, + attention_mask = question.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True) + + vl_embeddings = question_output.last_hidden_state[:,0,:] + vl_output = self.itm_head(vl_embeddings) + + if inference=='generate': + num_beams = 3 + question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) + question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) + model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} + + bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) + + outputs = self.text_decoder.generate(input_ids=bos_ids, + max_length=10, + min_length=1, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + **model_kwargs) + + answers = [] + for output in outputs: + answer = self.tokenizer.decode(output, skip_special_tokens=True) + answers.append(answer) + return answers, vl_output, question + + elif inference=='rank': + max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, + answer.input_ids, answer.attention_mask, k_test) + return max_ids, vl_output, question + + else: + question_output = self.text_encoder(question.input_ids, + attention_mask = question.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True) + + if inference=='generate': + num_beams = 3 + question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) + question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) + model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} + + bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) + + outputs = self.text_decoder.generate(input_ids=bos_ids, + max_length=10, + min_length=1, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + **model_kwargs) + + answers = [] + for output in outputs: + answer = self.tokenizer.decode(output, skip_special_tokens=True) + answers.append(answer) + return answers + + elif inference=='rank': + max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, + answer.input_ids, answer.attention_mask, k_test) + return max_ids + + + + def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): + + num_ques = question_states.size(0) + start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token + + start_output = 
self.text_decoder(start_ids, + encoder_hidden_states = question_states, + encoder_attention_mask = question_atts, + return_dict = True, + reduction = 'none') + logits = start_output.logits[:,0,:] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:,1] + prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) + topk_probs, topk_ids = prob_first_token.topk(k,dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids,dim=0) + input_atts = torch.cat(input_atts,dim=0) + + targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) + + # repeat encoder's output for top-k answers + question_states = tile(question_states, 0, k) + question_atts = tile(question_atts, 0, k) + + output = self.text_decoder(input_ids, + attention_mask = input_atts, + encoder_hidden_states = question_states, + encoder_attention_mask = question_atts, + labels = targets_ids, + return_dict = True, + reduction = 'none') + + log_probs_sum = -output.loss + log_probs_sum = log_probs_sum.view(num_ques,k) + + max_topk_ids = log_probs_sum.argmax(dim=1) + max_ids = topk_ids[max_topk_ids>=0,max_topk_ids] + + return max_ids + + +def blip_vqa(pretrained='',**kwargs): + model = BLIP_VQA(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) +# assert(len(msg.missing_keys)==0) + return model + + +def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) + return torch.index_select(x, dim, order_index.to(x.device)) + + \ No newline at end of file diff --git a/models/med.py b/models/med.py new file mode 100644 index 0000000000000000000000000000000000000000..99b0abab574a850320cc784aef4cc016f2b174c1 --- /dev/null +++ b/models/med.py @@ -0,0 +1,955 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +''' + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / 
config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
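+        # query_layer and key_layer are [batch_size, num_heads, seq_len, head_size] after
+        # transpose_for_scores, so the matmul below gives raw scores of shape
+        # [batch_size, num_heads, query_len, key_len]; for cross-attention, key_len is the
+        # length of encoder_hidden_states.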
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
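+        # When save_attention is enabled, the pre-dropout cross-attention probabilities were
+        # stored above (with a gradient hook) for later Grad-CAM visualization; only the
+        # dropped copy below is used to weight the value vectors.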
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = 
self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode=='multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
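+        # A 3D mask [batch_size, from_seq_length, to_seq_length] is used as-is (broadcast over
+        # heads); a 2D padding mask [batch_size, seq_length] is expanded below and, when
+        # is_decoder=True, additionally combined with a lower-triangular causal mask.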
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
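+        # get_extended_attention_mask also folds a causal mask into the padding mask when
+        # is_decoder=True, i.e. on the text-decoder path used by BertLMHeadModel.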
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + 
return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + if reduction=='none': + lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/models/nlvr_encoder.py b/models/nlvr_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1946bb4a300f75afa4848f6622839445903c34a9 --- /dev/null +++ b/models/nlvr_encoder.py @@ -0,0 +1,843 @@ +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as 
F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, 
"position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config, twin=False, merge=False): + super().__init__() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if twin: + self.dense0 = nn.Linear(config.hidden_size, config.hidden_size) + self.dense1 = nn.Linear(config.hidden_size, config.hidden_size) + else: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if merge: + self.act = ACT2FN[config.hidden_act] + self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.merge = True + else: + self.merge = False + + def forward(self, hidden_states, input_tensor): + if type(hidden_states) == list: + hidden_states0 = self.dense0(hidden_states[0]) + hidden_states1 = self.dense1(hidden_states[1]) + if self.merge: + #hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1))) + hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1)) + else: + hidden_states = (hidden_states0+hidden_states1)/2 + else: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_num=-1): + super().__init__() + if is_cross_attention: + self.self0 = BertSelfAttention(config, is_cross_attention) + self.self1 = BertSelfAttention(config, is_cross_attention) + else: + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6)) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + if type(encoder_hidden_states)==list: + self_outputs0 = self.self0( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[0], + encoder_attention_mask[0], + past_key_value, + output_attentions, + ) + self_outputs1 = 
self.self1( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[1], + encoder_attention_mask[1], + past_key_value, + output_attentions, + ) + attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states) + + outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them + else: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode=='multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs 
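+        # present_key_value (the self-attention key/value cache) is appended last so that
+        # BertEncoder can collect it into next_decoder_cache when use_cache=True.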
+ + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. 
+ """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + 
past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + diff --git a/models/vit.py b/models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..cec3d8e08ed4451d65392feb2e9f4848d1ef3899 --- /dev/null +++ b/models/vit.py @@ -0,0 +1,305 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on timm code base + * https://github.com/rwightman/pytorch-image-models/tree/master/timm +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.vision_transformer import _cfg, PatchEmbed +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath +from timm.models.helpers import named_apply, adapt_input_conv + +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + 
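+        # When register_hook is True, the attention map and a hook on its
+        # gradient are stored on the module (see save_attention_map /
+        # save_attn_gradients above) so they can be read back later,
+        # e.g. for attention visualisation.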
+ if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, register_hook=False): + x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, + use_grad_checkpointing=False, ckpt_layer=0): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) 
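+        # Positional and class-token embeddings are initialised from a truncated
+        # normal distribution (std 0.02) below, and the remaining sub-modules are
+        # then initialised via self.apply(self._init_weights).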
+ + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:,:x.size(1),:] + x = self.pos_drop(x) + + for i,blk in enumerate(self.blocks): + x = blk(x, register_blk==i) + x = self.norm(x) + + return x + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + 
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) +# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: +# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) +# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) +# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: +# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) +# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + + if orig_size!=new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) + + return new_pos_embed + else: + return pos_embed_checkpoint \ No newline at end of file diff --git a/models/xbert.py b/models/xbert.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1f7774524bacc0c91a15ec66a8063de3f332a2 --- /dev/null +++ b/models/xbert.py @@ -0,0 +1,1916 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model. """ + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +import transformers +transformers.logging.set_verbosity_error() + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if token_type_ids is None: + token_type_ids = 
torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
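+        # In the cross-attention case the key/value projections were built with
+        # input width config.encoder_width (see __init__ above), so they consume
+        # the encoder_hidden_states (e.g. image features) directly.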
+ is_cross_attention = encoder_hidden_states is not None + + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = 
self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + + self.has_cross_attention = (layer_num >= config.fusion_layer) + if self.has_cross_attention: + self.layer_num = layer_num + self.crossattention = BertAttention(config, is_cross_attention=True) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if self.has_cross_attention: + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + if type(encoder_hidden_states) == list: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states[(self.layer_num-self.config.fusion_layer)%len(encoder_hidden_states)], + encoder_attention_mask[(self.layer_num-self.config.fusion_layer)%len(encoder_hidden_states)], + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] + + else: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multi_modal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + + if mode=='text': + start_layer = 0 + 
output_layer = self.config.fusion_layer + + elif mode=='fusion': + start_layer = self.config.fusion_layer + output_layer = self.config.num_hidden_layers + + elif mode=='multi_modal': + start_layer = 0 + output_layer = self.config.num_hidden_layers + + for i in range(start_layer, output_layer): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
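+        # Weight tying between the decoder projection below and the input word
+        # embeddings is typically handled by the PreTrainedModel weight-tying
+        # machinery (get_output_embeddings / tie_weights); only the output-only
+        # bias is created here.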
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.BertForPreTraining`. + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +BERT_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. 
+ output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="bert-base-uncased", + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
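+        # Shape handling below: a 3D mask [batch, from_seq, to_seq] is expanded
+        # to [batch, 1, from_seq, to_seq]; a 2D padding mask [batch, seq] becomes
+        # [batch, 1, 1, seq], optionally combined with a causal mask when decoding.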
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multi_modal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
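+        # Shape sketch (hypothetical numbers): during cached decoding with past_key_values holding
+        # 5 previous tokens (past_key_values_length == 5) and a new input of shape
+        # [batch_size=2, seq_length=1], the default attention_mask built above has shape [2, 6],
+        # so the single new query position can attend to the 5 cached key/value positions as well
+        # as to itself.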
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. 
+ """, + BERT_START_DOCSTRING, +) +class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + Returns: + Example:: + >>> from transformers import BertTokenizer, BertForPreTraining + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForPreTraining.from_pretrained('bert-base-uncased') + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Bert 
Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING +) +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=True, + reduction='mean', + mode='multi_modal', + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
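+        is_decoder (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether to run the underlying :class:`BertModel` as a decoder (causal masking and cached decoding).
+        reduction (:obj:`str`, `optional`, defaults to :obj:`'mean'`):
+            Reduction passed to the :obj:`CrossEntropyLoss` used for the language modeling loss.
+        soft_labels (:obj:`torch.FloatTensor`, `optional`):
+            Soft target distribution over the vocabulary; when provided, a distillation term is added to the
+            language modeling loss.
+        alpha (:obj:`float`, `optional`, defaults to 0):
+            Interpolation weight between the language modeling loss and the distillation term; the final loss is
+            ``(1 - alpha) * lm_loss + alpha * loss_distill``.
+        return_logits (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If :obj:`True`, return only the shifted prediction logits instead of the usual model output.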
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction) + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) + + if soft_labels is not None: + loss_distill = -torch.sum(F.log_softmax(shifted_prediction_scores, dim=-1)*soft_labels,dim=-1) + loss_distill = (loss_distill * (labels!=-100)).sum(1) + lm_loss = (1-alpha)*lm_loss + alpha*loss_distill + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top. 
""", BERT_START_DOCSTRING) +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="bert-base-uncased", + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multi_modal', + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_embeds=encoder_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if soft_labels is not None: + loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=-1)*soft_labels,dim=-1) + loss_distill = loss_distill[labels!=-100].mean() + masked_lm_loss = (1-alpha)*masked_lm_loss + alpha*loss_distill + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), 
self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top. """, + BERT_START_DOCSTRING, +) +class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see ``input_ids`` docstring). Indices should be in ``[0, 1]``: + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + Returns: + Example:: + >>> from transformers import BertTokenizer, BertForNextSentencePrediction + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased') + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """, + BERT_START_DOCSTRING, +) +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="bert-base-uncased", + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + BERT_START_DOCSTRING, +) +class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="bert-base-uncased", + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the multiple choice classification loss. 
Indices should be in ``[0, ..., + num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See + :obj:`input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BERT_START_DOCSTRING, +) +class BertForTokenClassification(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="bert-base-uncased", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. 
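+            For example, with :obj:`attention_mask` = ``[1, 1, 0]`` the label of the third (padded) token is
+            internally replaced with the loss function's ignore index (-100), so padding positions do not
+            contribute to the loss.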
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BERT_START_DOCSTRING, +) +class BertForQuestionAnswering(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="bert-base-uncased", + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. 
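+            For example, with a :obj:`sequence_length` of 384, a :obj:`start_positions` value of 500 is clamped to
+            384, which is used as the ignore index, so that span contributes nothing to the loss; the returned loss
+            is the mean of the start- and end-position cross-entropy terms.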
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d47697c44a93aa9ee6f091e0be56aa73a38d8d77 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,152 @@ +absl-py==1.2.0 +aiohttp==3.8.3 +aiosignal==1.2.0 +anyio==3.6.1 +asttokens==2.0.8 +async-timeout==4.0.2 +attrs==22.1.0 +av==9.2.0 +backcall==0.2.0 +bcrypt==4.0.0 +cachetools==5.2.0 +certifi==2022.9.14 +cffi==1.15.1 +charset-normalizer==2.1.1 +click==8.1.3 +cloudpickle==2.2.0 +configobj==5.0.6 +contourpy==1.0.5 +cryptography==38.0.1 +cycler==0.11.0 +cytoolz==0.12.0 +debugpy==1.6.3 +decorator==5.1.1 +decord==0.6.0 +easydict==1.10 +einops==0.4.1 +entrypoints==0.4 +executing==1.1.0 +fairscale==0.4.12 +fastapi==0.85.0 +ffmpy==0.3.0 +filelock==3.8.0 +fonttools==4.37.3 +frozenlist==1.3.1 +fsspec==2022.8.2 +ftfy==6.1.1 +google-auth==2.12.0 +google-auth-oauthlib==0.4.6 +gradio==3.4.0 +grpcio==1.49.1 +h11==0.12.0 +httpcore==0.15.0 +httpx==0.23.0 +huggingface-hub==0.9.1 +idna==3.4 +imageio==2.22.1 +importlib-metadata==5.0.0 +inflect==6.0.0 +ipdb==0.13.9 +ipykernel==6.16.0 +ipython==8.5.0 +jedi==0.18.1 +Jinja2==3.1.2 +joblib==1.2.0 +jupyter-core==4.11.1 +jupyter_client==7.3.5 +kiwisolver==1.4.4 +linkify-it-py==1.0.3 +lmdb==1.3.0 +lz4==4.0.2 +Markdown==3.4.1 +markdown-it-py==2.1.0 +MarkupSafe==2.1.1 +matplotlib==3.6.0 +matplotlib-inline==0.1.6 +mdit-py-plugins==0.3.1 +mdurl==0.1.2 +msgpack==1.0.4 +msgpack-numpy==0.4.8 +multidict==6.0.2 +nest-asyncio==1.5.6 +networkx==2.8.7 +nltk==3.7 +numpy==1.23.3 +oauthlib==3.2.1 +opencv-python==4.6.0.66 +orjson==3.8.0 +packaging==21.3 +pandas==1.5.0 +paramiko==2.11.0 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==9.2.0 +Pillow-SIMD==9.0.0.post1 +prettytable==3.4.1 +prompt-toolkit==3.0.31 +protobuf==3.19.6 +psutil==5.9.2 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 
+pycocotools==2.0.5 +pycparser==2.21 +pycryptodome==3.15.0 +pydantic==1.10.2 +pydub==0.25.1 +Pygments==2.13.0 +pymongo==4.2.0 +PyNaCl==1.5.0 +pyparsing==3.0.9 +python-dateutil==2.8.2 +python-multipart==0.0.5 +pytz==2022.4 +PyWavelets==1.4.1 +PyYAML==6.0 +pyzmq==24.0.1 +regex==2022.9.13 +requests==2.28.1 +requests-oauthlib==1.3.1 +rfc3986==1.5.0 +rsa==4.9 +ruamel.yaml==0.17.21 +ruamel.yaml.base==0.3.0 +ruamel.yaml.clib==0.2.6 +ruamel.yaml.cmd==0.6.3 +ruamel.yaml.convert==0.3.2 +sacremoses==0.0.53 +scikit-image==0.19.3 +scipy==1.9.1 +Shapely==1.8.4 +six==1.16.0 +sniffio==1.3.0 +stack-data==0.5.1 +starlette==0.20.4 +tensorboard==2.10.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +tensorboardX==2.5.1 +tifffile==2022.8.12 +timm==0.6.7 +tokenizers==0.10.3 +toml==0.10.2 +toolz==0.12.0 +torch==1.10.0+cu113 +torchvision==0.11.0+cu113 +tornado==6.2 +tqdm==4.64.1 +traitlets==5.4.0 +transformers==4.11.3 +typing_extensions==4.3.0 +uc-micro-py==1.0.1 +ujson==5.5.0 +urllib3==1.26.12 +uvicorn==0.18.3 +wcwidth==0.2.5 +websockets==10.3 +Werkzeug==2.2.2 +yacs==0.1.8 +yarl==1.8.1 +zipp==3.9.0 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..837c2cd15f4624f630540ef6993dcb9123adb39b --- /dev/null +++ b/setup.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#!/usr/bin/env python + +import glob +import os + +import torch +from setuptools import find_packages +from setuptools import setup +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +requirements = ["torch", "torchvision"] + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "maskrcnn_benchmark", "csrc") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + + extra_compile_args = {"cxx": []} + define_macros = [] + + if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + + sources = [os.path.join(extensions_dir, s) for s in sources] + + include_dirs = [extensions_dir] + + ext_modules = [ + extension( + "maskrcnn_benchmark._C", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + + return ext_modules + + +setup( + name="maskrcnn_benchmark", + version="0.1", + author="fmassa", + url="https://github.com/facebookresearch/maskrcnn-benchmark", + description="object detection in pytorch", + packages=find_packages(exclude=("configs", "tests",)), + # install_requires=requirements, + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/vqa.py b/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..f90b1e5469705a89755fb2bebe93ea966f36dcea --- /dev/null +++ b/vqa.py @@ -0,0 +1,127 @@ +import sys +from PIL import Image +import torch +from torchvision import transforms +from torchvision.transforms.functional import 
InterpolationMode +from models.blip_vqa import blip_vqa +import cv2 +import numpy as np +import matplotlib.image as mpimg + +from skimage import transform as skimage_transform +from scipy.ndimage import filters +from matplotlib import pyplot as plt + + +import torch +from torch import nn +from torchvision import transforms + +import json +import traceback + +class VQA: + def __init__(self, model_path, image_size=480): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base') + self.block_num = 9 + self.model.eval() + self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.save_attention = True + + self.model = self.model.to(self.device) + def getAttMap(self, img, attMap, blur = True, overlap = True): + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + attMap = skimage_transform.resize(attMap, (img.shape[:2]), order = 3, mode = 'constant') + if blur: + attMap = filters.gaussian_filter(attMap, 0.02*max(img.shape[:2])) + attMap -= attMap.min() + attMap /= attMap.max() + cmap = plt.get_cmap('jet') + attMapV = cmap(attMap) + attMapV = np.delete(attMapV, 3, 2) + if overlap: + attMap = 1*(1-attMap**0.7).reshape(attMap.shape + (1,))*img + (attMap**0.7).reshape(attMap.shape+(1,)) * attMapV + return attMap + + def gradcam(self, text_input, image_path, image): + mask = text_input.attention_mask.view(text_input.attention_mask.size(0),1,-1,1,1) + grads = self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.get_attn_gradients() + cams = self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.get_attention_map() + cams = cams[:, :, :, 1:].reshape(image.size(0), 12, -1, 30, 30) * mask + grads = grads[:, :, :, 1:].clamp(0).reshape(image.size(0), 12, -1, 30, 30) * mask + gradcam = cams * grads + gradcam = gradcam[0].mean(0).cpu().detach() + + num_image = len(text_input.input_ids[0]) + num_image -= 1 + fig, ax = plt.subplots(num_image, 1, figsize=(15,15*num_image)) + + rgb_image = cv2.imread(image_path)[:, :, ::-1] + rgb_image = np.float32(rgb_image) / 255 + ax[0].imshow(rgb_image) + ax[0].set_yticks([]) + ax[0].set_xticks([]) + ax[0].set_xlabel("Image") + + for i,token_id in enumerate(text_input.input_ids[0][1:-1]): + word = self.model.tokenizer.decode([token_id]) + gradcam_image = self.getAttMap(rgb_image, gradcam[i+1]) + ax[i+1].imshow(gradcam_image) + ax[i+1].set_yticks([]) + ax[i+1].set_xticks([]) + ax[i+1].set_xlabel(word) + + plt.show() + + + def load_demo_image(self, image_size, img_path, device): + raw_image = Image.open(img_path).convert('RGB') + w,h = raw_image.size + transform = transforms.Compose([ + transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + ]) + image = transform(raw_image).unsqueeze(0).to(device) + return raw_image, image + + def vqa(self, img_path, question): + raw_image, image = self.load_demo_image(image_size=480, img_path=img_path, device=self.device) + answer, vl_output, que = self.model(image, question, mode='gradcam', inference='generate') + loss = vl_output[:,1].sum() + self.model.zero_grad() + loss.backward() + + with torch.no_grad(): + self.gradcam(que, img_path, image) + + return answer[0] + + def vqa_demo(self, image, question): + image_size = 480 + transform = 
transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+            ])
+        image = transform(image).unsqueeze(0).to(self.device)
+        answer = self.model(image, question, mode='inference', inference='generate')
+
+        return answer[0]
+
+
+if __name__=="__main__":
+    if not len(sys.argv) == 3:
+        print('Format: python3 vqa.py <image_path> "<question>"')
+        print('Sample: python3 vqa.py sample.jpg "What is the color of the horse?"')
+
+    else:
+        model_path = 'checkpoints/model_base_vqa_capfilt_large.pth'
+        vqa_object = VQA(model_path=model_path)
+        img_path = sys.argv[1]
+        question = sys.argv[2]
+        answer = vqa_object.vqa(img_path, question)
+        print('Question: {} | Answer: {}'.format(question, answer))
\ No newline at end of file