diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..20377d5ced56af1e710d41c6f2f32d5af8ae9234 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.ipynb_checkpoints/* diff --git a/datasets/common_30k.model b/datasets/common_30k.model new file mode 100644 index 0000000000000000000000000000000000000000..7989b7f953ff46b3ac48d3c735506edca6d499f4 Binary files /dev/null and b/datasets/common_30k.model differ diff --git a/virtex/CHANGELOG.md b/virtex/CHANGELOG.md new file mode 100644 index 0000000000000000000000000000000000000000..9e54814cdf13b6404c9da2c41300455be981b9a1 --- /dev/null +++ b/virtex/CHANGELOG.md @@ -0,0 +1,41 @@ +ArXiv v1 -> v2 CHANGELOG +========================= + +[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is out CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0). + +While the core motivation and approach is the same, we have made some minor changes in our experiments and evaluation setup. These slightly improve model performances across the board (within decimals). New models are available in [`v1.0` model zoo](http://kdexd.github.io/virtex/virtex/usage/model_zoo.html), however links to old models in `v0.9` will be active till June 30, 2021. We encourage you to use the new models! + +We have updated the experiment config files for all changes described below. + +Experiment Changes +------------------ + +### New Feature: + +Add a new pretraining task for BERT-style _Masked Language Modeling_. Pre-trained model released in Model Zoo. + +### Pre-training: + +- The only change during pre-training is that we do not apply weight decay to LayerNorm and biases in input embedding and transformer layers. We apply weight decay to the biases in output linear layer (before softmax). + +- Other factors that could affect results: + - Use official [albumentations.ColorJitter transform](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter) that mimics torchvision ColorJitter transform. Earlier I implemented [my own ColorJitter](https://github.com/kdexd/virtex/blob/c19e7fc9b98e98af82286ed1537b6f588eaeac44/virtex/data/transforms.py#L156) because albumentations didn't have one. + - Use PyTorch Native AMP (Automatic Mixed Precision) instead of NVIDIA Apex. + +### Downstream Evaluations: + +1. **PASCAL VOC 2007 Linear Classification:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-b4405dd4879a48ef1e5b1e2801035909584a5f1f32f63d5e793fb50dee077b97) + - Instead of training linear SVMs on 8192-dimensional average pooled features from ResNet-50 (7x7x2048 —> 2x2x2048), like [(Misra et al. 2019)](https://arxiv.org/abs/1905.01235), we directly train SVMs on 2048-dimensional global average pooled features, following recent works like [SwAV (Caron et al. 2020)](https://arxiv.org/abs/2006.09882). + - We change the pre-processing: resize shortest edge to 256 pixels, and take center crop of 224 pixels. + - These improve VOC mAP by 1-2 points everywhere, and makes SVM training faster. Since we select best checkpoint based on this metric, all results on other downstream tasks also change in `ArXiv v2` (But the trends remain same.) + +2. **ImageNet Linear Evaluation:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-d3dea1e7bf97d0cfca4b59a47c0a9bb81e78b8827654fe0258df9ce2c3f5f41c) + - Changed random resized crop scale from (20-100%) to (8-100%) for consistency with evaluations in SSL works like MoCo and SwAV. + - Use cosine LR decay instead of step decay, following SwAV. Improves accuracy by up to 1%. + +3. **iNaturalist Fine-tuning:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-09096da78cfcde3a604ce22d80313f0800225d928cce5ef7334b89a382adfe4d) + - This evaluation is left unchanged across ArXiv versions, but we fixd a typo in image pre-processing step, present in publicly released config. + +4. **Detectron2 tasks (COCO and LVIS Instance Segmentation, VOC Detection):** + - Heavily simplified the script. Updated Detectron2 uses a more memory-efficient SyncBatchNorm and supports AMP. + diff --git a/virtex/LICENSE b/virtex/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e909de7a0a9528ffc9a95e854842315713a971a3 --- /dev/null +++ b/virtex/LICENSE @@ -0,0 +1,16 @@ +Copyright (c) 2020, Karan Desai. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, publish, distribute, +sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial +portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES +OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/virtex/README.md b/virtex/README.md new file mode 100644 index 0000000000000000000000000000000000000000..720ce5e0559be640430fe8c783b4c7bbf17c1da3 --- /dev/null +++ b/virtex/README.md @@ -0,0 +1,92 @@ +VirTex: Learning Visual Representations from Textual Annotations +================================================================ + +

+Karan Desai and Justin Johnson +
+ +University of Michigan + +

+
+ +**CVPR 2021** [arxiv.org/abs/2006.06666][1] + +**Model Zoo, Usage Instructions and API docs:** [kdexd.github.io/virtex](https://kdexd.github.io/virtex) + +VirTex is a pretraining approach which uses semantically dense captions to +learn visual representations. We train CNN + Transformers from scratch on +COCO Captions, and transfer the CNN to downstream vision tasks including +image classification, object detection, and instance segmentation. +VirTex matches or outperforms models which use ImageNet for pretraining -- +both supervised or unsupervised -- despite using up to 10x fewer images. + +![virtex-model](docs/_static/system_figure.jpg) + + +Get the pretrained ResNet-50 visual backbone from our best performing VirTex +model in one line *without any installation*! + +```python +import torch + +# That's it, this one line only requires PyTorch. +model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True) +``` + +### Note (For returning users before January 2021): + +The pretrained models in our model zoo have changed from [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0) onwards. +They are slightly better tuned than older models, and reproduce the results in our +CVPR 2021 accepted paper ([arXiv v2](https://arxiv.org/abs/2006.06666v2)). +Some training and evaluation hyperparams are changed since [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9). +Please refer [`CHANGELOG.md`](https://github.com/kdexd/virtex/blob/master/CHANGELOG.md) + + +Usage Instructions +------------------ + +1. [How to setup this codebase?][2] +2. [VirTex Model Zoo][3] +3. [How to train your VirTex model?][4] +4. [How to evaluate on downstream tasks?][5] + +Full documentation is available at [kdexd.github.io/virtex](https://kdexd.github.io/virtex). + + +Citation +-------- + +If you find this code useful, please consider citing: + +```text +@inproceedings{desai2021virtex, + title={{VirTex: Learning Visual Representations from Textual Annotations}}, + author={Karan Desai and Justin Johnson}, + booktitle={CVPR}, + year={2021} +} +``` + +Acknowledgments +--------------- + +We thank Harsh Agrawal, Mohamed El Banani, Richard Higgins, Nilesh Kulkarni +and Chris Rockwell for helpful discussions and feedback on the paper. We thank +Ishan Misra for discussions regarding PIRL evaluation protocol; Saining Xie for +discussions about replicating iNaturalist evaluation as MoCo; Ross Girshick and +Yuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting +the Instance Segmentation pretraining task ablation; and Stefan Lee for +suggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs +during project development; and UMich ARC-TS team for support with GPU cluster +management. Finally, we thank all the Starbucks outlets in Ann Arbor for many +hours of free WiFi. This work was partially supported by the Toyota Research +Institute (TRI). However, note that this article solely reflects the opinions +and conclusions of its authors and not TRI or any other Toyota entity. + + +[1]: https://arxiv.org/abs/2006.06666 +[2]: https://kdexd.github.io/virtex/virtex/usage/setup_dependencies.html +[3]: https://kdexd.github.io/virtex/virtex/usage/model_zoo.html +[4]: https://kdexd.github.io/virtex/virtex/usage/pretrain.html +[5]: https://kdexd.github.io/virtex/virtex/usage/downstream.html diff --git a/virtex/configs/_base_bicaptioning_R_50_L1_H1024.yaml b/virtex/configs/_base_bicaptioning_R_50_L1_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab40b92b6560f88547d3e12952c79a4fa71448f8 --- /dev/null +++ b/virtex/configs/_base_bicaptioning_R_50_L1_H1024.yaml @@ -0,0 +1,66 @@ +# ----------------------------------------------------------------------------- +# Base config: VirTex pretraining for our "base" bicaptioning model: +# ResNet-50 + (L = 1, H = 1024) transformer trained for 500K iterations. +# ----------------------------------------------------------------------------- +RANDOM_SEED: 0 +AMP: true +CUDNN_BENCHMARK: true +CUDNN_DETERMINISTIC: false + +DATA: + ROOT: "datasets/coco" + TOKENIZER_MODEL: "datasets/vocab/coco_10k.model" + VOCAB_SIZE: 10000 + UNK_INDEX: 0 + SOS_INDEX: 1 + EOS_INDEX: 2 + MASK_INDEX: 3 + + IMAGE_CROP_SIZE: 224 + MAX_CAPTION_LENGTH: 30 + + IMAGE_TRANSFORM_TRAIN: + - "random_resized_crop" + - "horizontal_flip" + - "color_jitter" + - "normalize" + + IMAGE_TRANSFORM_VAL: + - "smallest_resize" + - "center_crop" + - "normalize" + + USE_PERCENTAGE: 100.0 + USE_SINGLE_CAPTION: false + +MODEL: + NAME: "virtex" + VISUAL: + NAME: "torchvision::resnet50" + PRETRAINED: false + FROZEN: false + TEXTUAL: + NAME: "transdec_postnorm::L1_H1024_A16_F4096" + DROPOUT: 0.1 + +OPTIM: + OPTIMIZER_NAME: "sgd" + SGD_MOMENTUM: 0.9 + WEIGHT_DECAY: 0.0001 + + LOOKAHEAD: + USE: true + ALPHA: 0.5 + STEPS: 5 + + BATCH_SIZE: 256 + CNN_LR: 0.2 + LR: 0.001 + NUM_ITERATIONS: 500000 + + WARMUP_STEPS: 10000 + LR_DECAY_NAME: "cosine" + + NO_DECAY: ".*textual.(embedding|transformer).*(norm.*|bias)" + CLIP_GRAD_NORM: 10.0 + diff --git a/virtex/configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml b/virtex/configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3db670cd5b9e4bfa0f1da6c668b7cf90cf80d23d --- /dev/null +++ b/virtex/configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + VISUAL: + NAME: "torchvision::resnet101" diff --git a/virtex/configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml b/virtex/configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e89bb9e3cdb3ceacbc94ad10829b0b5d4c409d34 --- /dev/null +++ b/virtex/configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + VISUAL: + NAME: "torchvision::wide_resnet50_2" diff --git a/virtex/configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml b/virtex/configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d004bb1a991185d067b68d361e854273cb2738a --- /dev/null +++ b/virtex/configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml @@ -0,0 +1 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" diff --git a/virtex/configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml b/virtex/configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d004bb1a991185d067b68d361e854273cb2738a --- /dev/null +++ b/virtex/configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml @@ -0,0 +1 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" diff --git a/virtex/configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml b/virtex/configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c391a26bbce5217484cd41bbadc37ce9a6b0309 --- /dev/null +++ b/virtex/configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + TEXTUAL: + NAME: "transdec_postnorm::L2_H1024_A16_F4096" diff --git a/virtex/configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml b/virtex/configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aeb89ca98d97cdff802f1800eb32531357781177 --- /dev/null +++ b/virtex/configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + TEXTUAL: + NAME: "transdec_postnorm::L3_H1024_A16_F4096" diff --git a/virtex/configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml b/virtex/configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6bde4aca414e76c89243311916bf00bdacbafac2 --- /dev/null +++ b/virtex/configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + TEXTUAL: + NAME: "transdec_postnorm::L4_H1024_A16_F4096" diff --git a/virtex/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml b/virtex/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml new file mode 100644 index 0000000000000000000000000000000000000000..639cb01322588d5f7329d792964847413259e60f --- /dev/null +++ b/virtex/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml @@ -0,0 +1,49 @@ +# ---------------------------------------------------------------------------- +# Train a Faster R-CNN with ResNet-50 and C4 backbone. This config follows +# Detectron2 format; and is unrelated with our VirTex configs. Params here +# replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722). +# ---------------------------------------------------------------------------- + +INPUT: + # Input format will always be RGB, consistent with torchvision. + FORMAT: "RGB" + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) + MIN_SIZE_TEST: 800 + +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + + # Train all layers end-to-end by default. + BACKBONE: + NAME: build_resnet_backbone + FREEZE_AT: 0 + + # Fine-tune with SyncBN. + # STRIDE_IN_1X1 is False for torchvision-like models. + RESNETS: + DEPTH: 50 + NORM: SyncBN + STRIDE_IN_1X1: False + + RPN: + PRE_NMS_TOPK_TEST: 6000 + POST_NMS_TOPK_TEST: 1000 + + # ROI head with extra BN layer after res5 stage. + ROI_HEADS: + NAME: "Res5ROIHeadsExtraNorm" + + # ImageNet color mean for torchvision-like models (RGB order). + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + +SOLVER: + # This is for 8 GPUs, apply linear scaling for 4 GPUs. + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + +TEST: + PRECISE_BN: + ENABLED: True + +VERSION: 2 diff --git a/virtex/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml b/virtex/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml new file mode 100644 index 0000000000000000000000000000000000000000..efb1f40f6c5c13ea95f4b3cb758bc20ef42983c1 --- /dev/null +++ b/virtex/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml @@ -0,0 +1,75 @@ +# ---------------------------------------------------------------------------- +# Train a Mask R-CNN with ResNet-50 and FPN backbone. This config follows +# Detectron2 format; and is unrelated with our VirTex configs. Params here +# replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722). +# ---------------------------------------------------------------------------- + +INPUT: + # Input format will always be RGB, consistent with torchvision. + FORMAT: "RGB" + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) + MIN_SIZE_TEST: 800 + +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + + # Train all layers end-to-end by default. + BACKBONE: + NAME: "build_resnet_fpn_backbone" + FREEZE_AT: 0 + + # Fine-tune with SyncBN. + # STRIDE_IN_1X1 is False for torchvision-like models. + RESNETS: + DEPTH: 50 + NORM: "SyncBN" + STRIDE_IN_1X1: False + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + + ANCHOR_GENERATOR: + # One size for each in feature map + SIZES: [[32], [64], [128], [256], [512]] + # Three aspect ratios (same for all in feature maps) + ASPECT_RATIOS: [[0.5, 1.0, 2.0]] + + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 + PRE_NMS_TOPK_TEST: 1000 + + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 + + # ImageNet color mean for torchvision-like models (RGB order). + # These are in [0-255] range as expected by Detectron2. Rest of our codebase + # uses [0-1] range; but both are equivalent and consistent. + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + +SOLVER: + # This is for 8 GPUs, apply linear scaling for 4 GPUs. + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + +TEST: + PRECISE_BN: + ENABLED: True + +VERSION: 2 diff --git a/virtex/configs/detectron2/coco_segm_default_init_2x.yaml b/virtex/configs/detectron2/coco_segm_default_init_2x.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4d897fa532405d753ee3e9396616831326c89404 --- /dev/null +++ b/virtex/configs/detectron2/coco_segm_default_init_2x.yaml @@ -0,0 +1,24 @@ +# ----------------------------------------------------------------------------- +# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation with any of +# these weight init: random, imagenet (torchvision), virtex or MoCo. +# ----------------------------------------------------------------------------- +_BASE_: "_base_mask_rcnn_R_50_FPN.yaml" + +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) + +MODEL: + MASK_ON: True + # FPN also has SyncBN, as opposed to no norm (usually). + FPN: + NORM: "SyncBN" + + # This will be ignored, weights will be loaded manually in the script. + WEIGHTS: "" + +SOLVER: + STEPS: (120000, 160000) + MAX_ITER: 180000 + +VERSION: 2 diff --git a/virtex/configs/detectron2/lvis_segm_default_init_2x.yaml b/virtex/configs/detectron2/lvis_segm_default_init_2x.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1df4dc373e03aff0ae3fbcbd783329ec485d605 --- /dev/null +++ b/virtex/configs/detectron2/lvis_segm_default_init_2x.yaml @@ -0,0 +1,36 @@ +# ----------------------------------------------------------------------------- +# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation with any of +# these weight init: random, virtex or MoCo. (ImageNet init config is separate) +# ----------------------------------------------------------------------------- +_BASE_: "_base_mask_rcnn_R_50_FPN.yaml" + +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) + +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 + +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300. + +MODEL: + MASK_ON: True + # FPN also has SyncBN, as opposed to no norm (usually). + FPN: + NORM: "SyncBN" + + ROI_HEADS: + NUM_CLASSES: 1203 + SCORE_THRESH_TEST: 0.0001 + + # This will be ignored, weights will be loaded manually in the script. + WEIGHTS: "" + +SOLVER: + STEPS: (120000, 160000) + MAX_ITER: 180000 + +VERSION: 2 + diff --git a/virtex/configs/detectron2/lvis_segm_imagenet_init_2x.yaml b/virtex/configs/detectron2/lvis_segm_imagenet_init_2x.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5751f83ddf67bb0bb7b9bc1b6e992b72676fceea --- /dev/null +++ b/virtex/configs/detectron2/lvis_segm_imagenet_init_2x.yaml @@ -0,0 +1,38 @@ +# ----------------------------------------------------------------------------- +# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation +# with weights initialized from supervised ImageNet pretraining (torchvision). +# Key difference is that fine-tuning here happens with BN frozen. +# ----------------------------------------------------------------------------- +_BASE_: "_base_mask_rcnn_R_50_FPN.yaml" + +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) + +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 + +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300. + +MODEL: + MASK_ON: True + RESNETS: + NORM: "FrozenBN" + + # Do not tune with SyncBN for ImageNet init from LVIS. + ROI_HEADS: + NUM_CLASSES: 1203 + SCORE_THRESH_TEST: 0.0001 + + # This will be ignored, weights will be loaded manually in the script. + WEIGHTS: "" + +SOLVER: + STEPS: (120000, 160000) + MAX_ITER: 180000 + +VERSION: 2 + + diff --git a/virtex/configs/detectron2/voc_det_default_init_24k.yaml b/virtex/configs/detectron2/voc_det_default_init_24k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..97b9fdad1305dec93504f4b288adab44f2239fb9 --- /dev/null +++ b/virtex/configs/detectron2/voc_det_default_init_24k.yaml @@ -0,0 +1,28 @@ +# ----------------------------------------------------------------------------- +# Train a Faster R-CNN with R50-C4 backbone on VOC07+12 detection with any of +# these weight init: random, imagenet (torchvision), virtex or MoCo. +# ----------------------------------------------------------------------------- +_BASE_: "_base_faster_rcnn_R_50_C4_BN.yaml" + +DATASETS: + TRAIN: ("voc_2007_trainval", "voc_2012_trainval") + TEST: ("voc_2007_test",) + +INPUT: + MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) + MIN_SIZE_TEST: 800 + +MODEL: + MASK_ON: False + ROI_HEADS: + NUM_CLASSES: 20 + + # This will be ignored, weights will be loaded manually in the script. + WEIGHTS: "" + +SOLVER: + STEPS: (18000, 22000) + MAX_ITER: 24000 + WARMUP_ITERS: 100 + +VERSION: 2 diff --git a/virtex/configs/downstream/imagenet_clf.yaml b/virtex/configs/downstream/imagenet_clf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..895de3f251ea76945fa483f3b75f69a2303b47c2 --- /dev/null +++ b/virtex/configs/downstream/imagenet_clf.yaml @@ -0,0 +1,33 @@ +RANDOM_SEED: 0 +# Don't need AMP to train a tiny linear layer. +AMP: false +CUDNN_BENCHMARK: true +CUDNN_DETERMINISTIC: false + +DATA: + ROOT: "datasets/imagenet" + IMAGE_TRANSFORM_TRAIN: + - "random_resized_crop::{'scale': (0.08, 1.0)}" + - "horizontal_flip" + - "normalize" + IMAGE_TRANSFORM_VAL: + - "smallest_resize" + - "center_crop" + - "normalize" + +MODEL: + VISUAL: + FROZEN: true + +OPTIM: + BATCH_SIZE: 256 + SGD_MOMENTUM: 0.9 + WEIGHT_DECAY: 0.0 + NO_DECAY: "none" + LOOKAHEAD: + USE: false + + LR: 0.3 + WARMUP_STEPS: 0 + LR_DECAY_NAME: "cosine" + NUM_ITERATIONS: 500500 # 100 epochs diff --git a/virtex/configs/downstream/inaturalist_clf.yaml b/virtex/configs/downstream/inaturalist_clf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eeb5b13ce31e4ba79918a881fcd23528b0f0c905 --- /dev/null +++ b/virtex/configs/downstream/inaturalist_clf.yaml @@ -0,0 +1,36 @@ +RANDOM_SEED: 0 +AMP: true +CUDNN_BENCHMARK: true +CUDNN_DETERMINISTIC: false + +DATA: + ROOT: "datasets/inaturalist" + IMAGE_TRANSFORM_TRAIN: + - "random_resized_crop::{'scale': (0.08, 1.0)}" + - "horizontal_flip" + - "normalize" + IMAGE_TRANSFORM_VAL: + - "smallest_resize" + - "center_crop" + - "normalize" + +MODEL: + VISUAL: + FROZEN: false + +OPTIM: + BATCH_SIZE: 256 + SGD_MOMENTUM: 0.9 + WEIGHT_DECAY: 0.0001 + NO_DECAY: "none" + LOOKAHEAD: + USE: false + + LR: 0.025 + WARMUP_STEPS: 0 + LR_DECAY_NAME: multistep + LR_GAMMA: 0.1 + LR_STEPS: + - 119700 # 70 epochs + - 153900 # 90 epochs + NUM_ITERATIONS: 171000 # 100 epochs diff --git a/virtex/configs/downstream/voc07_clf.yaml b/virtex/configs/downstream/voc07_clf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac3b029e3969662ca0811b87a2224fcfb6bd7ac0 --- /dev/null +++ b/virtex/configs/downstream/voc07_clf.yaml @@ -0,0 +1,15 @@ +RANDOM_SEED: 0 +DATA: + ROOT: datasets/VOC2007 + IMAGE_TRANSFORM_TRAIN: + - smallest_resize + - center_crop + - normalize + IMAGE_TRANSFORM_VAL: + - smallest_resize + - center_crop + - normalize + +OPTIM: + # Only used for feature extraction, doesn't mean much. + BATCH_SIZE: 128 diff --git a/virtex/configs/redcaps/gcc_R_50_L6_H512.yaml b/virtex/configs/redcaps/gcc_R_50_L6_H512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b5d9ae7621dad07390d889ef685d102e387ad3f --- /dev/null +++ b/virtex/configs/redcaps/gcc_R_50_L6_H512.yaml @@ -0,0 +1,35 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +AMP: True + +DATA: + ROOT: "datasets/gcc/tarfiles/*.tar" + TOKENIZER_MODEL: "datasets/vocab/common_30k.model" + VOCAB_SIZE: 30000 + UNK_INDEX: 0 + SOS_INDEX: 1 + EOS_INDEX: 2 + MASK_INDEX: 3 + + MAX_CAPTION_LENGTH: 50 + +MODEL: + NAME: "virtex_web" + TEXTUAL: + NAME: "transdec_prenorm::L6_H512_A8_F2048" + + LABEL_SMOOTHING: 0.1 + +OPTIM: + OPTIMIZER_NAME: "adamw" + WEIGHT_DECAY: 0.01 + LOOKAHEAD: + USE: false + + BATCH_SIZE: 256 + CNN_LR: 0.0005 + LR: 0.0005 + NUM_ITERATIONS: 1500000 + + WARMUP_STEPS: 10000 + LR_DECAY_NAME: "cosine" diff --git a/virtex/configs/redcaps/miniclip_sbu_R_50_L12_H512.yaml b/virtex/configs/redcaps/miniclip_sbu_R_50_L12_H512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ebc42885bc1348a9883947076b8db95d7ed4677 --- /dev/null +++ b/virtex/configs/redcaps/miniclip_sbu_R_50_L12_H512.yaml @@ -0,0 +1,35 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +AMP: True + +DATA: + ROOT: "datasets/sbu/tarfiles/*.tar" + TOKENIZER_MODEL: "datasets/vocab/common_30k.model" + VOCAB_SIZE: 30000 + UNK_INDEX: 0 + SOS_INDEX: 1 + EOS_INDEX: 2 + MASK_INDEX: 3 + + MAX_CAPTION_LENGTH: 50 + +MODEL: + NAME: "miniclip_web" + TEXTUAL: + NAME: "transenc_prenorm::L12_H512_A8_F2048" + LABEL_SMOOTHING: 0.1 + +OPTIM: + OPTIMIZER_NAME: "adamw" + WEIGHT_DECAY: 0.01 + + LOOKAHEAD: + USE: false + + BATCH_SIZE: 256 + CNN_LR: 0.0005 + LR: 0.0005 + NUM_ITERATIONS: 1500000 + + WARMUP_STEPS: 10000 + LR_DECAY_NAME: "cosine" diff --git a/virtex/configs/redcaps/redcaps_2020_R_50_L6_H512.yaml b/virtex/configs/redcaps/redcaps_2020_R_50_L6_H512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c95e9f507df982f448c4c47898c9d2bb70bfb6f --- /dev/null +++ b/virtex/configs/redcaps/redcaps_2020_R_50_L6_H512.yaml @@ -0,0 +1,35 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +AMP: True + +DATA: + ROOT: "datasets/redcaps/tarfiles/*_2020_*.tar" + TOKENIZER_MODEL: "datasets/vocab/common_30k.model" + VOCAB_SIZE: 30000 + UNK_INDEX: 0 + SOS_INDEX: 1 + EOS_INDEX: 2 + MASK_INDEX: 3 + + MAX_CAPTION_LENGTH: 50 + +MODEL: + NAME: "virtex_web" + TEXTUAL: + NAME: "transdec_prenorm::L6_H512_A8_F2048" + LABEL_SMOOTHING: 0.1 + +OPTIM: + OPTIMIZER_NAME: "adamw" + WEIGHT_DECAY: 0.01 + + LOOKAHEAD: + USE: false + + BATCH_SIZE: 256 + CNN_LR: 0.0005 + LR: 0.0005 + NUM_ITERATIONS: 1500000 + + WARMUP_STEPS: 10000 + LR_DECAY_NAME: "cosine" diff --git a/virtex/configs/redcaps/redcaps_all_R_50_L6_H512.yaml b/virtex/configs/redcaps/redcaps_all_R_50_L6_H512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e5249782d69c8ce79d68f15e9a7fd06e83b06ae9 --- /dev/null +++ b/virtex/configs/redcaps/redcaps_all_R_50_L6_H512.yaml @@ -0,0 +1,35 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +AMP: True + +DATA: + ROOT: "datasets/redcaps/tarfiles/*.tar" + TOKENIZER_MODEL: "datasets/vocab/common_30k.model" + VOCAB_SIZE: 30000 + UNK_INDEX: 0 + SOS_INDEX: 1 + EOS_INDEX: 2 + MASK_INDEX: 3 + + MAX_CAPTION_LENGTH: 50 + +MODEL: + NAME: "virtex_web" + TEXTUAL: + NAME: "transdec_prenorm::L6_H512_A8_F2048" + LABEL_SMOOTHING: 0.1 + +OPTIM: + OPTIMIZER_NAME: "adamw" + WEIGHT_DECAY: 0.01 + + LOOKAHEAD: + USE: false + + BATCH_SIZE: 256 + CNN_LR: 0.0005 + LR: 0.0005 + NUM_ITERATIONS: 1500000 + + WARMUP_STEPS: 10000 + LR_DECAY_NAME: "cosine" diff --git a/virtex/configs/redcaps/sbu_R_50_L6_H512.yaml b/virtex/configs/redcaps/sbu_R_50_L6_H512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1834f85ebfcfcf7a11361d95e99efc651577f9dd --- /dev/null +++ b/virtex/configs/redcaps/sbu_R_50_L6_H512.yaml @@ -0,0 +1,35 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +AMP: True + +DATA: + ROOT: "datasets/sbu/tarfiles/*.tar" + TOKENIZER_MODEL: "datasets/vocab/common_30k.model" + VOCAB_SIZE: 30000 + UNK_INDEX: 0 + SOS_INDEX: 1 + EOS_INDEX: 2 + MASK_INDEX: 3 + + MAX_CAPTION_LENGTH: 50 + +MODEL: + NAME: "virtex_web" + TEXTUAL: + NAME: "transdec_prenorm::L6_H512_A8_F2048" + LABEL_SMOOTHING: 0.1 + +OPTIM: + OPTIMIZER_NAME: "adamw" + WEIGHT_DECAY: 0.01 + + LOOKAHEAD: + USE: false + + BATCH_SIZE: 256 + CNN_LR: 0.0005 + LR: 0.0005 + NUM_ITERATIONS: 1500000 + + WARMUP_STEPS: 10000 + LR_DECAY_NAME: "cosine" diff --git a/virtex/configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml b/virtex/configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a132a630675820261f09afcf3128b9684034c630 --- /dev/null +++ b/virtex/configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + TEXTUAL: + NAME: "transdec_postnorm::L1_H2048_A32_F8192" diff --git a/virtex/configs/task_ablations/captioning_R_50_L1_H2048.yaml b/virtex/configs/task_ablations/captioning_R_50_L1_H2048.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82f159f9203cbd323c71e183b44630f1b44c558d --- /dev/null +++ b/virtex/configs/task_ablations/captioning_R_50_L1_H2048.yaml @@ -0,0 +1,6 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + NAME: "captioning" + TEXTUAL: + NAME: "transdec_postnorm::L1_H2048_A32_F8192" diff --git a/virtex/configs/task_ablations/masked_lm_R_50_L1_H2048.yaml b/virtex/configs/task_ablations/masked_lm_R_50_L1_H2048.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14a11155ae1d2aadfcc41f56bbf580a890a8e83b --- /dev/null +++ b/virtex/configs/task_ablations/masked_lm_R_50_L1_H2048.yaml @@ -0,0 +1,6 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + NAME: "masked_lm" + TEXTUAL: + NAME: "transdec_postnorm::L1_H2048_A32_F8192" diff --git a/virtex/configs/task_ablations/multilabel_classification_R_50.yaml b/virtex/configs/task_ablations/multilabel_classification_R_50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8768b9c34bcc6c8078b7324062246f73eb346b0 --- /dev/null +++ b/virtex/configs/task_ablations/multilabel_classification_R_50.yaml @@ -0,0 +1,12 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +DATA: + VOCAB_SIZE: 81 + +MODEL: + NAME: "multilabel_classification" + TEXTUAL: + NAME: "none" + +OPTIM: + NO_DECAY: "none" diff --git a/virtex/configs/task_ablations/token_classification_R_50.yaml b/virtex/configs/task_ablations/token_classification_R_50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c31ee4d08300c4033231003c0940bb4940276073 --- /dev/null +++ b/virtex/configs/task_ablations/token_classification_R_50.yaml @@ -0,0 +1,9 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + NAME: "token_classification" + TEXTUAL: + NAME: "none" + +OPTIM: + NO_DECAY: "none" diff --git a/virtex/configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml b/virtex/configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d004bb1a991185d067b68d361e854273cb2738a --- /dev/null +++ b/virtex/configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml @@ -0,0 +1 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" diff --git a/virtex/configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml b/virtex/configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a132a630675820261f09afcf3128b9684034c630 --- /dev/null +++ b/virtex/configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + TEXTUAL: + NAME: "transdec_postnorm::L1_H2048_A32_F8192" diff --git a/virtex/configs/width_ablations/bicaptioning_R_50_L1_H512.yaml b/virtex/configs/width_ablations/bicaptioning_R_50_L1_H512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0b23d0c5ebcc2aae31e599ebd7bd49c923e4fe23 --- /dev/null +++ b/virtex/configs/width_ablations/bicaptioning_R_50_L1_H512.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + TEXTUAL: + NAME: "transdec_postnorm::L1_H512_A8_F2048" diff --git a/virtex/configs/width_ablations/bicaptioning_R_50_L1_H768.yaml b/virtex/configs/width_ablations/bicaptioning_R_50_L1_H768.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7882e204fba9f05febc204d7b053d8cb4dfe344f --- /dev/null +++ b/virtex/configs/width_ablations/bicaptioning_R_50_L1_H768.yaml @@ -0,0 +1,5 @@ +_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" + +MODEL: + TEXTUAL: + NAME: "transdec_postnorm::L1_H768_A12_F3072" diff --git a/virtex/docs/Makefile b/virtex/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..a33ba2ab28931acc202130e69db1104b883fb578 --- /dev/null +++ b/virtex/docs/Makefile @@ -0,0 +1,19 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = . +BUILDDIR = ../../virtex-sphinx + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/virtex/docs/_static/custom.css b/virtex/docs/_static/custom.css new file mode 100644 index 0000000000000000000000000000000000000000..02df8f2a58a2a283273083b58a1cb9d0ac10f7f7 --- /dev/null +++ b/virtex/docs/_static/custom.css @@ -0,0 +1,115 @@ +body { + padding: 40px 0 0 0; + font-size: 12pt; + font-family: Inconsolata !important; +} + +/* Monospace everywhere */ +h1, h2, h3, h4, div.sphinxsidebar h1, div.sphinxsidebar h2, +div.sphinxsidebar h3, div.sphinxsidebar h4, div.body h1, +div.body h2, div.body h3, div.body h4, .admonition-title { + font-family: monospace !important; +} + +/* Make main content wider */ +div.document { + margin: auto; + width: 65%; +} + +/* Make sidebar slightly wider. */ +div.sphinxsidebar { + width: 250px; +} + +div.bodywrapper { + margin: 0 0 0 250px; +} + +div.body { + color: black; + max-width: 100% +} + +/* Darker headings */ +h1, h2, h3, h4, div.sphinxsidebar h1, div.sphinxsidebar h2, +div.sphinxsidebar h3, div.sphinxsidebar h4, div.body h1, +div.body h2, div.body h3, div.body h4 { + color: black; +} + +@media screen and (max-width: 875px) { + div.sphinxsidebar { + background-color: white; + } +} + +/* Darker bold words */ +strong { + color: #252525; +} + +/* TOC tree tag, view source link & permalink anchor styling. */ +div.sphinxsidebar a, .viewcode-link, a.reference { + color: darkgreen; + text-decoration: none; + border-bottom: 1px dashed green; + text-underline-position: under; +} +a.headerlink { + color: black; +} + +/* TOC tree tag, view source link & permalink anchor styling. */ +div.sphinxsidebar a:hover, .viewcode-link:hover, a.reference:hover, +a.headerlink:hover { + font-weight: 700; + border-bottom: 1px solid green; +} + +/* Add a light background to class signatures. */ +dl.class > dt:first-of-type, dl.function > dt:first-of-type, +dl.method > dt:first-of-type, dl.classmethod > dt:first-of-type, +dl.attribute > dt:first-of-type, dl.data > dt:first-of-type { + font-size: 14pt; + background-color: #d8f6e9; + padding: 10px 20px 10px 10px; + border: 1px solid #1b5e20; +} + +/* Add lightgrey background to code snippets. */ +pre { + background-color: #eeeeee !important; + border: 1pt solid #999999; + border-radius: 5px; +} + +/* Dark orange-red comments in code snippets. */ +.highlight .c1 { + color: #dd4533; +} + +.admonition, .note { + background-color: #fed8b1 !important; + border: 1pt solid #ff7700; + border-radius: 5px; +} + +/* Make "Parameters" subsection wider - display heading and content vertically. */ +dl.field-list { + display: block; +} + +/* Increase font size of subsection headings ("Parameters", "Examples" etc.) */ +.rubric, dl.field-list > dt.field-odd, dl.field-list > dt.field-even { + color: black; + font-size: 18pt; + font-weight: bold; + padding: 0px; + margin: 20px 0px 20px 0px; +} + +/* Add margins around methods and properties. */ +.py { + margin: 20px 0px 20px 0px; +} diff --git a/virtex/docs/_static/system_figure.jpg b/virtex/docs/_static/system_figure.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8ac5d31c263df121e1a005d141390dee0a6b344a Binary files /dev/null and b/virtex/docs/_static/system_figure.jpg differ diff --git a/virtex/docs/_templates/layout.html b/virtex/docs/_templates/layout.html new file mode 100644 index 0000000000000000000000000000000000000000..66497fed98d5eb668d0781cbcc28e23147bd72bb --- /dev/null +++ b/virtex/docs/_templates/layout.html @@ -0,0 +1,19 @@ +{% extends "!layout.html" %} + +{% block htmltitle %} + + + + + + + + +{{ super() }} +{% endblock %} diff --git a/virtex/docs/conf.py b/virtex/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd9cafe341546765f5aa38074d79832655da75a --- /dev/null +++ b/virtex/docs/conf.py @@ -0,0 +1,173 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import inspect +import os +import sys + +sys.path.insert(0, os.path.abspath("../")) + + +# -- Project information ----------------------------------------------------- + +project = "virtex" +copyright = "2021, Karan Desai and Justin Johnson" +author = "Karan Desai" + +# The full version, including alpha/beta/rc tags +release = "1.1" + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.coverage", + "sphinx.ext.doctest", + "sphinx.ext.linkcode", + "sphinx.ext.autosummary", + "sphinx.ext.coverage", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx_copybutton", + "numpydoc", +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = ".rst" + +# The master toctree document. +master_doc = "index" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# This version is used underneath the title on the index page. +version = "1.1" +# The following is used if you need to also include a more detailed version. +release = "1.1" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = "en" + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ["_build"] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + +numpydoc_show_class_members = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "alabaster" + +# html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {"collapse_navigation": False, "display_version": True} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + + +# -- Autodoc configuration ------------------------------------------------ + +autodoc_default_options = { + "members": True, + "member-order": "bysource", + "private-members": True, + "show-inheritance": True, +} + + +# -- Intersphinx configuration -------------------------------------------- + +intersphinx_mapping = { + "torch": ("https://pytorch.org/docs/stable/", None), + "albumentations": ("https://albumentations.readthedocs.io/en/latest/", None), +} + +# -- Miscellaneous Extra Tweaks ------------------------------------------- + +# make github links resolve +def linkcode_resolve(domain, info): + """ + Determine the URL corresponding to Python object + This code is from + https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L290 + and https://github.com/Lasagne/Lasagne/pull/262 + """ + if domain != "py": + return None + + modname = info["module"] + fullname = info["fullname"] + + submod = sys.modules.get(modname) + if submod is None: + return None + + obj = submod + for part in fullname.split("."): + try: + obj = getattr(obj, part) + except: # noqa: E722 + return None + + try: + fn = inspect.getsourcefile(obj) + except: # noqa: E722 + fn = None + if not fn: + return None + + try: + source, lineno = inspect.getsourcelines(obj) + except: # noqa: E722 + lineno = None + + if lineno: + linespec = "#L%d-L%d" % (lineno, lineno + len(source) - 1) + else: + linespec = "" + + filename = info["module"].replace(".", "/") + return f"https://github.com/kdexd/virtex/blob/master/{filename}.py{linespec}" diff --git a/virtex/docs/index.rst b/virtex/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..f0866e98c0b79ebdbd962fb3f1ac2e595ddc2397 --- /dev/null +++ b/virtex/docs/index.rst @@ -0,0 +1,122 @@ +.. raw:: html + +

+ VirTex: Learning Visual Representations from Textual Annotations +

+

+ Karan Desai and Justin Johnson +
+ + University of Michigan + +

+
+ +

+ Abstract +

+ +

+ The de-facto approach to many vision tasks is to start from pretrained + visual representations, typically learned via supervised training on + ImageNet. Recent methods have explored unsupervised pretraining to scale to + vast quantities of unlabeled images. In contrast, we aim to learn + high-quality visual representations from fewer images. To this end we + revisit supervised pretraining, and seek data-efficient alternatives to + classification-based pretraining. We propose VirTex -- a pretraining + approach using semantically dense captions to learn visual representations. + We train convolutional networks from scratch on COCO Captions, and transfer + them to downstream recognition tasks including image classification, object + detection, and instance segmentation. On all tasks, VirTex yields features + that match or exceed those learned on ImageNet -- supervised or unsupervised + -- despite using up to ten times fewer images. +

+ +**CVPR 2021. Paper available at:** `arxiv.org/abs/2006.06666 `_. + +**Code available at:** `github.com/kdexd/virtex `_. + +.. image:: _static/system_figure.jpg + + +Get the pretrained ResNet-50 visual backbone from our best performing VirTex +model in one line *without any installation*! + +.. code-block:: python + + import torch + + # That's it, this one line only requires PyTorch. + model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True) + + +More details in :doc:`virtex/usage/model_zoo`. Next, dive deeper into our +code with User Guide and API References! + + +User Guide +---------- + +.. toctree:: + :maxdepth: 2 + + virtex/usage/setup_dependencies + virtex/usage/model_zoo + virtex/usage/pretrain + virtex/usage/downstream + + +API Reference +------------- + +.. toctree:: + :maxdepth: 2 + + virtex/config + virtex/factories + virtex/data + virtex/models + virtex/modules + virtex/optim + virtex/utils + virtex/model_zoo + + +Citation +-------- + +If you find this code useful, please consider citing: + +.. code-block:: text + + @inproceedings{desai2021virtex, + title={{VirTex: Learning Visual Representations from Textual Annotations}}, + author={Karan Desai and Justin Johnson}, + booktitle={CVPR}, + year={2021} + } + + +Acknowledgments +--------------- + +We thank Harsh Agrawal, Mohamed El Banani, Richard Higgins, Nilesh Kulkarni +and Chris Rockwell for helpful discussions and feedback on the paper. We thank +Ishan Misra for discussions regarding PIRL evaluation protocol; Saining Xie for +discussions about replicating iNaturalist evaluation as MoCo; Ross Girshick and +Yuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting +the Instance Segmentation pretraining task ablation; and Stefan Lee for +suggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs +during project development; and UMich ARC-TS team for support with GPU cluster +management. Finally, we thank all the Starbucks outlets in Ann Arbor for many +hours of free WiFi. This work was partially supported by the Toyota Research +Institute (TRI). However, note that this article solely reflects the opinions +and conclusions of its authors and not TRI or any other Toyota entity. + + +Indices and Tables +------------------ + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/virtex/docs/virtex/config.rst b/virtex/docs/virtex/config.rst new file mode 100644 index 0000000000000000000000000000000000000000..585a5042aa961964722e4184f56718b7682f16fd --- /dev/null +++ b/virtex/docs/virtex/config.rst @@ -0,0 +1,18 @@ +virtex.config +============= + +.. raw:: html + +
+ +.. automodule:: virtex.config + + +Config References +----------------- + +.. literalinclude:: ../../virtex/config.py + :language: python + :linenos: + :lines: 46-206 + :dedent: 8 diff --git a/virtex/docs/virtex/data.datasets.rst b/virtex/docs/virtex/data.datasets.rst new file mode 100644 index 0000000000000000000000000000000000000000..686a974d2d5c4db3a937270719053de6df0ade67 --- /dev/null +++ b/virtex/docs/virtex/data.datasets.rst @@ -0,0 +1,20 @@ +virtex.data.datasets +==================== + +.. raw:: html + +
+ +Pretraining Datasets +-------------------- + +.. automodule:: virtex.data.datasets.captioning + +.. automodule:: virtex.data.datasets.classification + +------------------------------------------------------------------------------ + +Downstream Datasets +------------------- + +.. automodule:: virtex.data.datasets.downstream diff --git a/virtex/docs/virtex/data.readers.rst b/virtex/docs/virtex/data.readers.rst new file mode 100644 index 0000000000000000000000000000000000000000..f65a8327103cc6d0a203838b17960201735f6885 --- /dev/null +++ b/virtex/docs/virtex/data.readers.rst @@ -0,0 +1,8 @@ +virtex.data.readers +=================== + +.. raw:: html + +
+ +.. automodule:: virtex.data.readers diff --git a/virtex/docs/virtex/data.rst b/virtex/docs/virtex/data.rst new file mode 100644 index 0000000000000000000000000000000000000000..882d69accc4a25e7275b933788a4b07ca0d964fd --- /dev/null +++ b/virtex/docs/virtex/data.rst @@ -0,0 +1,14 @@ +virtex.data +=========== + +.. raw:: html + +
+ + +.. toctree:: + + data.readers + data.datasets + data.tokenizers + data.transforms diff --git a/virtex/docs/virtex/data.tokenizers.rst b/virtex/docs/virtex/data.tokenizers.rst new file mode 100644 index 0000000000000000000000000000000000000000..59594dd805010eff4a8201a1797baf3488c0e33d --- /dev/null +++ b/virtex/docs/virtex/data.tokenizers.rst @@ -0,0 +1,8 @@ +virtex.data.tokenizers +====================== + +.. raw:: html + +
+ +.. automodule:: virtex.data.tokenizers diff --git a/virtex/docs/virtex/data.transforms.rst b/virtex/docs/virtex/data.transforms.rst new file mode 100644 index 0000000000000000000000000000000000000000..7d9b0299f2187112e5ea51b54d55eac18f0717c4 --- /dev/null +++ b/virtex/docs/virtex/data.transforms.rst @@ -0,0 +1,8 @@ +virtex.data.transforms +====================== + +.. raw:: html + +
+ +.. automodule:: virtex.data.transforms diff --git a/virtex/docs/virtex/factories.rst b/virtex/docs/virtex/factories.rst new file mode 100644 index 0000000000000000000000000000000000000000..078afc5bacd486b8b449d76b82624de09679f916 --- /dev/null +++ b/virtex/docs/virtex/factories.rst @@ -0,0 +1,56 @@ +virtex.factories +================ + +.. raw:: html + +
+ +.. First only include the top-level module, and base class docstrings. + +.. automodule:: virtex.factories + :no-members: + +.. autoclass:: virtex.factories.Factory + + +------------------------------------------------------------------------------ + +Dataloading-related Factories +----------------------------- + +.. autoclass:: virtex.factories.TokenizerFactory + :members: from_config + +.. autoclass:: virtex.factories.ImageTransformsFactory + :members: from_config + +.. autoclass:: virtex.factories.PretrainingDatasetFactory + :members: from_config + +.. autoclass:: virtex.factories.DownstreamDatasetFactory + :members: from_config + +------------------------------------------------------------------------------ + +Modeling-related Factories +-------------------------- + +.. autoclass:: virtex.factories.VisualBackboneFactory + :members: from_config + +.. autoclass:: virtex.factories.TextualHeadFactory + :members: from_config + +.. autoclass:: virtex.factories.PretrainingModelFactory + :members: from_config + +------------------------------------------------------------------------------ + +Optimization-related Factories +------------------------------ + +.. autoclass:: virtex.factories.OptimizerFactory + :members: from_config + +.. autoclass:: virtex.factories.LRSchedulerFactory + :members: from_config diff --git a/virtex/docs/virtex/model_zoo.rst b/virtex/docs/virtex/model_zoo.rst new file mode 100644 index 0000000000000000000000000000000000000000..ebdb81863704d6d4c85d5c1b580240ea317d45c7 --- /dev/null +++ b/virtex/docs/virtex/model_zoo.rst @@ -0,0 +1,8 @@ +virtex.model_zoo +================ + +.. raw:: html + +
+ +.. automodule:: virtex.model_zoo.model_zoo diff --git a/virtex/docs/virtex/models.rst b/virtex/docs/virtex/models.rst new file mode 100644 index 0000000000000000000000000000000000000000..83ab5751e65294b5071a4eebc543b6f70f9566d9 --- /dev/null +++ b/virtex/docs/virtex/models.rst @@ -0,0 +1,16 @@ +virtex.models +============= + +.. raw:: html + +
+ +.. automodule:: virtex.models.classification + +------------------------------------------------------------------------------- + +.. automodule:: virtex.models.captioning + +------------------------------------------------------------------------------- + +.. automodule:: virtex.models.masked_lm diff --git a/virtex/docs/virtex/modules.embedding.rst b/virtex/docs/virtex/modules.embedding.rst new file mode 100644 index 0000000000000000000000000000000000000000..6125716a3d2c964f1c8e402a475d24750ee42fa3 --- /dev/null +++ b/virtex/docs/virtex/modules.embedding.rst @@ -0,0 +1,8 @@ +virtex.modules.embedding +======================== + +.. raw:: html + +
+ +.. automodule:: virtex.modules.embedding diff --git a/virtex/docs/virtex/modules.rst b/virtex/docs/virtex/modules.rst new file mode 100644 index 0000000000000000000000000000000000000000..f623cfd865184240057f249c4cced5b8d11793c2 --- /dev/null +++ b/virtex/docs/virtex/modules.rst @@ -0,0 +1,12 @@ +virtex.modules +============== + +.. raw:: html + +
+ +.. toctree:: + + modules.embedding + modules.visual_backbones + modules.textual_heads diff --git a/virtex/docs/virtex/modules.textual_heads.rst b/virtex/docs/virtex/modules.textual_heads.rst new file mode 100644 index 0000000000000000000000000000000000000000..ddbc68d1c0bd8d1c6b8a997030b48050aec09ea9 --- /dev/null +++ b/virtex/docs/virtex/modules.textual_heads.rst @@ -0,0 +1,8 @@ +virtex.modules.textual_heads +============================ + +.. raw:: html + +
+ +.. automodule:: virtex.modules.textual_heads diff --git a/virtex/docs/virtex/modules.visual_backbones.rst b/virtex/docs/virtex/modules.visual_backbones.rst new file mode 100644 index 0000000000000000000000000000000000000000..8aff72132cf9ddddc5de04d6d68975cc5086e262 --- /dev/null +++ b/virtex/docs/virtex/modules.visual_backbones.rst @@ -0,0 +1,8 @@ +virtex.modules.visual_backbones +=============================== + +.. raw:: html + +
+ +.. automodule:: virtex.modules.visual_backbones diff --git a/virtex/docs/virtex/optim.lookahead.rst b/virtex/docs/virtex/optim.lookahead.rst new file mode 100644 index 0000000000000000000000000000000000000000..63030fd060386bec339f6cc88c7edde6b523ffb5 --- /dev/null +++ b/virtex/docs/virtex/optim.lookahead.rst @@ -0,0 +1,8 @@ +virtex.optim.lookahead +====================== + +.. raw:: html + +
+ +.. automodule:: virtex.optim.lookahead diff --git a/virtex/docs/virtex/optim.lr_scheduler.rst b/virtex/docs/virtex/optim.lr_scheduler.rst new file mode 100644 index 0000000000000000000000000000000000000000..62a0596e86ca1ef624ba0314c4f673f0c7829a66 --- /dev/null +++ b/virtex/docs/virtex/optim.lr_scheduler.rst @@ -0,0 +1,8 @@ +virtex.optim.lr_scheduler +========================= + +.. raw:: html + +
+ +.. automodule:: virtex.optim.lr_scheduler diff --git a/virtex/docs/virtex/optim.rst b/virtex/docs/virtex/optim.rst new file mode 100644 index 0000000000000000000000000000000000000000..cf31a85cc8d92ef62e8686d2209e6b1e6f18c172 --- /dev/null +++ b/virtex/docs/virtex/optim.rst @@ -0,0 +1,11 @@ +virtex.optim +============ + +.. raw:: html + +
+ +.. toctree:: + + optim.lookahead + optim.lr_scheduler diff --git a/virtex/docs/virtex/usage/downstream.rst b/virtex/docs/virtex/usage/downstream.rst new file mode 100644 index 0000000000000000000000000000000000000000..c7278b0015cd46904c4e02d0d7e4e15cb6ef00f7 --- /dev/null +++ b/virtex/docs/virtex/usage/downstream.rst @@ -0,0 +1,216 @@ +How to evaluate on downstream tasks? +==================================== + +In our paper, we evaluate our pretrained VirTex models on seven different +downstream tasks. Our codebase supports all of these evaluations. Throughout +this documentation, we consider a specific example of our VirTex pretrained +model being evaluated for ensuring filepath uniformity in the following example +command snippets. Paths can be trivially adjusted for any other VirTex model; +evaluating the baselines (MoCo, ImageNet-supervised, Random Init) require +additional changes in commands, explained in the last sub-section. + +As an example, consider a pretraining job for our best performing VirTex model +(``width_ablations/bicaptioning_R_50_L1_H2048.yaml``). The serialization +directory might look something like this: + +.. code-block:: text + + /tmp/bicaptioning_R_50_L1_H2048 + pretrain_config.yaml + log-rank0.txt # stdout/stderr per GPU process + log-rank1.txt + ... + log-rank7.txt + checkpoint_2000.pth + checkpoint_4000.pth + ... + checkpoint_498000.pth + checkpoint_500000.pth # serialized checkpoints + train_captioning_forward/ + events.out.* ... # tensorboard logs + ... + +We evaluate all checkpoints on **PASCAL VOC 2007 Linear Classification**, and +then evaluate the best checkpoint (here, it was iteration 500000) on all other +downstream tasks. + + +PASCAL VOC 2007 Linear Classification +------------------------------------- + +Evaluate a single VirTex pretrained checkpoint on VOC 2007 ``trainval`` split: + +.. code-block:: shell + + python scripts/clf_voc07.py \ + --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ + --down-config configs/downstream/voc07_clf.yaml \ + --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ + --weight-init virtex \ + --num-gpus-per-machine 1 \ + --cpu-workers 4 \ + --serialization-dir /tmp/bicaptioning_R_50_L1_H2048 + +To evaluate recent 100 checkpoints in the sub-directory, this command can be +looped over as follows: + +.. code-block:: shell + + for ((iter = 300000; iter <= 500000; iter+=2000)); do + # add command with `checkpoint_$iter.pth` + done + +This script write metric to tensorboard logs in the same pretraining directory, +all VOC07 mAP curves appear together with pretraining loss curves. + +------------------------------------------------------------------------------- + +ImageNet Linear Classification +------------------------------ + +We train a linear classifier on 2048-dimensional global average pooled features +extracted from a frozen visual backbone. Evaluate a checkpoint (for example, +iteration 500000) on this task as: + +.. code-block:: shell + + python scripts/clf_linear.py \ + --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ + --down-config configs/downstream/imagenet_clf.yaml \ + --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ + --weight-init virtex \ + --num-gpus-per-machine 8 \ + --cpu-workers 4 \ + --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/imagenet_500000 \ + --checkpoint-every 5005 # 1 epoch of ImageNet + +------------------------------------------------------------------------------- + +Instance Segmentation (and Object Detection) on COCO +---------------------------------------------------- + +Train a Mask R-CNN with FPN backbone for COCO Instance Segmentation (and Object +Detection, because it also has a box head) by initializing the backbone from +VirTex pretrained weights: + +.. code-block:: shell + + python scripts/eval_detectron2.py \ + --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ + --d2-config configs/detectron2/coco_segm_default_init_2x.yaml \ + --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ + --weight-init virtex \ + --num-gpus-per-machine 8 \ + --cpu-workers 2 \ + --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/coco_segm_500000 \ + --checkpoint-every 5000 + +.. note:: + + 1. This script periodically serializes checkpoints but skips validation + step during training for saving time; to evaluate a serialized checkpoint + and write results to tensorboard, provide it as ``--checkpoint-path`` and + additional flags ``--resume --eval-only``. + + 2. Note that ``--d2-config`` here is in Detectron2 format, and not our + package :class:`~virtex.config.Config`. + + These points are applicable for all tasks described below. + +------------------------------------------------------------------------------- + +Instance Segmentation on LVIS +----------------------------- + +Train a Mask R-CNN with FPN backbone for LVIS Instance Segmentation by +initializing the backbone from VirTex pretrained weights: + +.. code-block:: shell + + python scripts/eval_detectron2.py \ + --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ + --d2-config configs/detectron2/lvis_segm_default_init_2x.yaml \ + --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ + --weight-init virtex \ + --num-gpus-per-machine 8 \ + --cpu-workers 2 \ + --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/lvis_segm_500000 \ + --checkpoint-every 5000 + +------------------------------------------------------------------------------- + +Object Detection on PASCAL VOC 2007+12 +-------------------------------------- + +Train a Faster R-CNN with C4 backbone for PASCAL VOC 2007+12 Object Detection +by initializing the backbone from VirTex pretrained weights: + +.. code-block:: shell + + python scripts/eval_detectron2.py \ + --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ + --d2-config configs/detectron2/voc_det_default_init_24k.yaml \ + --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ + --weight-init virtex \ + --num-gpus-per-machine 8 \ + --cpu-workers 2 \ + --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/voc_det_500000 \ + --checkpoint-every 2500 + +------------------------------------------------------------------------------- + +iNaturalist 2018 Fine-Grained Classification +-------------------------------------------- + +Fine-tune the VirTex pretrained visual backbone end-to-end on iNaturalist 2018 +dataset: + +.. code-block:: shell + + python scripts/clf_linear.py \ + --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ + --down-config configs/downstream/inaturalist_clf.yaml \ + --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ + --weight-init virtex \ + --num-gpus-per-machine 8 \ + --cpu-workers 4 \ + --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/inaturalist_500000 \ + --checkpoint-every 1710 # 1 epoch of iNaturalist + +------------------------------------------------------------------------------- + +Image Captioning on COCO Captions val2017 +----------------------------------------- + +Evaluate a pretrained VirTex model on image captioning for COCO Captions val2017 +split (reporting CIDEr and SPICE metics): + +.. code-block:: shell + + python scripts/eval_captioning.py \ + --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ + --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ + --calc-metrics \ + --num-gpus-per-machine 1 \ + --cpu-workers 4 + +------------------------------------------------------------------------------- + +Running Image Captioning Inference on Arbitrary Images +------------------------------------------------------ + +The above script can be used for generating captions for any images in a directory. +Replace certain commands as follows: + +.. code-block:: shell + + python scripts/eval_captioning.py \ + --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ + --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ + --data-root /path/to/images_dir \ + --output /path/to/save/predictions.json \ + --num-gpus-per-machine 1 \ + --cpu-workers 4 + +This script will save predictions in JSON format. Since our goal is to not +improve image captioning, these models may not generate the best captions. diff --git a/virtex/docs/virtex/usage/model_zoo.rst b/virtex/docs/virtex/usage/model_zoo.rst new file mode 100644 index 0000000000000000000000000000000000000000..daee9ee44c56ae08f5e2444b552f8233d28b66cb --- /dev/null +++ b/virtex/docs/virtex/usage/model_zoo.rst @@ -0,0 +1,234 @@ +VirTex Model Zoo +================ + +We provide a collection of pretrained model weights and corresponding config +names in this model zoo. Tables contain partial paths to config files for each +model, download link for pretrained weights and for reference -- VOC07 mAP and +ImageNet top-1 accuracy. + +The simplest way to download and use a *full* pretrained model (including both, +the visual backbone and the textual head) is through :doc:`../model_zoo` API as +follows. This code snippet works from anywhere, and does not require to be +executed from project root. + +.. code-block:: python + + # Get our full best performing VirTex model: + import virtex.model_zoo as mz + model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True) + + # Optionally extract the torchvision-like visual backbone (with ``avgpool`` + # and ``fc`` layers replaced with ``nn.Identity`` module). + cnn = model.visual.cnn + +Alternatively, weights can be manually downloaded from links below, and this +can be executed from the project root: + +.. code-block:: python + + from virtex.config import Config + from virtex.factories import PretrainingModelFactory + from virtex.utils.checkpointing import CheckpointManager + + # Get the best performing VirTex model: + _C = Config("configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml") + model = PretrainingModelFactory.from_config(_C) + + CheckpointManager(model=model).load("/path/to/downloaded/weights.pth") + + # Optionally extract the torchvision-like visual backbone (with ``avgpool`` + # and ``fc`` layers replaced with ``nn.Identity`` module). + cnn = model.visual.cnn + + +The pretrained ResNet-50 visual backbone of our best performing model +(``width_ablations/bicaptioning_R_50_L1_H2048.yaml``) can be loaded in a single +line, *without following any installation steps* (only requires PyTorch v1.5): + +.. code-block:: python + + import torch + + model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True) + + # This is a torchvision-like resnet50 model, with ``avgpool`` and ``fc`` + # layers replaced with ``nn.Identity`` module. + image_batch = torch.randn(1, 3, 224, 224) # batch tensor of one image. + features_batch = model(image_batch) # shape: (1, 2048, 7, 7) + +------------------------------------------------------------------------------- + +Pretraining Task Ablations +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. raw:: html + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model Config NameVOC07
mAP
ImageNet
Top-1 Acc.
Model URL
task_ablations/bicaptioning_R_50_L1_H2048.yaml88.753.8model
task_ablations/captioning_R_50_L1_H2048.yaml88.650.8model
task_ablations/token_classification_R_50.yaml88.848.6model
task_ablations/multilabel_classification_R_50.yaml86.246.2model
task_ablations/masked_lm_R_50_L1_H2048.yaml86.446.7model
+ + +Width Ablations +^^^^^^^^^^^^^^^ + +.. raw:: html + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model Config NameVOC07
mAP
ImageNet
Top-1 Acc.
Model URL
width_ablations/bicaptioning_R_50_L1_H512.yaml88.451.8model
width_ablations/bicaptioning_R_50_L1_H768.yaml88.352.3model
width_ablations/bicaptioning_R_50_L1_H1024.yaml88.353.2model
width_ablations/bicaptioning_R_50_L1_H2048.yaml88.753.8model
+ + +Depth Ablations +^^^^^^^^^^^^^^^ + +.. raw:: html + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model Config NameVOC07
mAP
ImageNet
Top-1 Acc.
Model URL
depth_ablations/bicaptioning_R_50_L1_H1024.yaml88.353.2model
depth_ablations/bicaptioning_R_50_L2_H1024.yaml88.853.8model
depth_ablations/bicaptioning_R_50_L3_H1024.yaml88.753.9model
depth_ablations/bicaptioning_R_50_L4_H1024.yaml88.753.9model
+ + +Backbone Ablations +^^^^^^^^^^^^^^^^^^ + +.. raw:: html + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + +
Model Config NameVOC07
mAP
ImageNet
Top-1 Acc.
Model URL
backbone_ablations/bicaptioning_R_50_L1_H1024.yaml88.353.2model
backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml88.552.9model
backbone_ablations/bicaptioning_R_101_L1_H1024.yaml88.752.1model
diff --git a/virtex/docs/virtex/usage/pretrain.rst b/virtex/docs/virtex/usage/pretrain.rst new file mode 100644 index 0000000000000000000000000000000000000000..2f14305f2152afdede708d45cbe5b2d165e9246a --- /dev/null +++ b/virtex/docs/virtex/usage/pretrain.rst @@ -0,0 +1,100 @@ +How to train your VirTex model? +=============================== + +We provide training scripts for all type of VirTex models from the paper; +including our best-performing model and other ablations. +Our training jobs are specified by config files (YAML). +Execute all commands from project root to use the provided config files. + + +Training the base VirTex model +------------------------------ + +Train the base VirTex model with ResNet-50 visual backbone; and a textual head +with ``L = 1, H = 1024`` using all default optimization hyperparameters. + +.. code-block:: + + python scripts/pretrain_virtex.py \ + --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \ + --num-gpus-per-machine 8 \ + --cpu-workers 4 \ + --serialization-dir /tmp/VIRTEX_R_50_L1_H1024 + # Default: --checkpoint-every 2000 --log-every 20 + +Training job will save checkpoints, tensorboard logs (loss curves and metrics), +and back up the config in ``--serialization-dir``. Use ``tensorboard --logdir +`` to view training curves, validation metrics etc. directly +on tensorboard. + +We recommend training with 8 GPUs on the same machine, although training with +multiple GPUs across machines (see: ``--num-machines`` and ``--machine-rank``), +single GPU (``--num-gpus-per-machine 1``) as well as CPU +(``--num-gpus-per-machine 0``) is also supported. Using multiple GPUs for +interactive debugging with PDB is not supported, as PDB and ``multiprocessing`` +module do not play nice. + +------------------------------------------------------------------------------- + +Reproducing all VirTex ablations +-------------------------------- + +To reproduce all ablations from the `paper `_, +replace the ``--config`` argument in above command with the following (all +assumed to be relative to project root): + +Pretraining Task Ablations +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +1. **Bicaptioning:** configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml +2. **Forward Captioning:** configs/task_ablations/captioning_R_50_L1_H2048.yaml +3. **Token Classification:** configs/task_ablations/token_classification_R_50.yaml +4. **Multilabel Classification:** configs/task_ablations/multilabel_classification_R_50.yaml +5. **Masked Language Modeling:** configs/task_ablations/masked_lm_R_50_L1_H2048.yaml + +Transformer Size Ablations +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +1. **Width (H = 512):** configs/width_ablations/bicaptioning_R_50_L1_H512.yaml +2. **Width (H = 768):** configs/width_ablations/bicaptioning_R_50_L1_H768.yaml +3. **Width (H = 1024):** configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml +4. **Width (H = 2048):** configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml +5. **Depth (L = 1):** configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml +6. **Depth (L = 2):** configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml +7. **Depth (L = 3):** configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml +8. **Depth (L = 4):** configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml + +Backbone Ablations +^^^^^^^^^^^^^^^^^^ + +1. **ResNet-50:** configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml +2. **ResNet-50 w2x:** configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml +3. **ResNet-101:** configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml + +.. note:: + + **Pretraining Task Ablations** (1), **Transformer Size Ablations** (3 and 5) + and **Backbone Ablations** (1) are all the same exact model. + +Data Efficiency Experiments +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +These are VirTex models trained on a subset of COCO Captions dataset. For example, +train a base VirTex model on randomly selected ``50%`` of COCO Captions: + +.. code-block:: + + python scripts/pretrain_virtex.py \ + --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \ + --config-override DATA.USE_PERCENTAGE 50.0 \ + --num-gpus-per-machine 8 \ + --cpu-workers 4 \ + --serialization-dir /tmp/VIRTEX_R_50_L1_H1024_PERCENT_50 + # Default: --checkpoint-every 2000 --log-every 20 + +COCO Captions provides five captions per image. To train with one fixed caption +per image, add ``DATA.USE_SINGLE_CAPTION True`` in ``--config-override``. + +The randomly selected subset is deterministic across runs based on random seed +(``RANDOM_SEED`` in config). When training on less than ``50%`` dataset size, we +recommend using multiple random seeds (results will have a variance of ``±1%``). diff --git a/virtex/docs/virtex/usage/setup_dependencies.rst b/virtex/docs/virtex/usage/setup_dependencies.rst new file mode 100644 index 0000000000000000000000000000000000000000..b4ece3964148f977154c367de2dfb84c57a86053 --- /dev/null +++ b/virtex/docs/virtex/usage/setup_dependencies.rst @@ -0,0 +1,153 @@ +How to setup this codebase? +=========================== + +.. raw:: html + +
+ +This codebase requires Python 3.6+ or higher. We recommend using Anaconda or +Miniconda. We walk through installation and data preprocessing here. + + +Install Dependencies +-------------------- + +For these steps to install through Anaconda (or Miniconda). + +1. Install Anaconda or Miniconda distribution based on Python 3+ from their + `downloads site `_. + + +2. Clone the repository first. + + .. code-block:: shell + + git clone https://www.github.com/kdexd/virtex + + +3. Create a conda environment and install all the dependencies. + + .. code-block:: shell + + cd virtex + conda create -n virtex python=3.6 + conda activate virtex + pip install -r requirements.txt + + +4. Install this codebase as a package in development version. + + .. code-block:: shell + + python setup.py develop + +Now you can ``import virtex`` from anywhere as long as you have this conda +environment activated. + +------------------------------------------------------------------------------- + + +Setup Datasets +-------------- + +Datasets are assumed to exist in ``./datasets`` directory (relative to the +project root) following the structure specified below. COCO is used for +pretraining, and rest of the datasets (including COCO) are used for downstream +tasks. This structure is compatible when using +`Detectron2 `_ for downstream +tasks. + +COCO +^^^^ +.. code-block:: + + datasets/coco/ + annotations/ + captions_{train,val}2017.json + instances_{train,val}2017.json + train2017/ + # images in train2017 split + val2017/ + # images in val2017 split + +LVIS +^^^^ +.. code-block:: + + datasets/coco/ + train2017/ + val2017/ + datasets/lvis/ + lvis_v1.0_{train,val}.json + +PASCAL VOC +^^^^^^^^^^ +.. code-block:: + + datasets/VOC2007/ + Annotations/ + ImageSets/ + Main/ + trainval.txt + test.txt + JPEGImages/ + + datasets/VOC2012/ + # Same as VOC2007 above + +ImageNet +^^^^^^^^ +.. code-block:: + + datasets/imagenet/ + train/ + # One directory per category with images in it + val/ + # One directory per category with images in it + ILSVRC2012_devkit_t12.tar.gz + +iNaturalist 2018 +^^^^^^^^^^^^^^^^ +.. code-block:: + + datasets/inaturalist/ + train_val2018/ + annotations/ + train2018.json + val2018.json + +------------------------------------------------------------------------------- + + +Preprocess Data +--------------- + +1. Build a vocabulary out of COCO Captions ``train2017`` split. + + .. code-block:: shell + + python scripts/preprocess/build_vocabulary.py \ + --captions datasets/coco/annotations/captions_train2017.json \ + --vocab-size 10000 \ + --output-prefix datasets/vocab/coco_10k \ + --do-lower-case + + +2. Serialize COCO Captions (``train2017`` and ``val2017`` splits) into LMDB + files. These are faster for data reading during pretraining. + + .. code-block:: shell + + python scripts/preprocess/preprocess_coco.py \ + --data-root datasets/coco \ + --split train \ + --output datasets/coco/serialized_train.lmdb + + .. code-block:: shell + + python scripts/preprocess/preprocess_coco.py \ + --data-root datasets/coco \ + --split val \ + --output datasets/coco/serialized_val.lmdb + +That's it! You are all set to use this codebase. diff --git a/virtex/docs/virtex/utils.beam_search.rst b/virtex/docs/virtex/utils.beam_search.rst new file mode 100644 index 0000000000000000000000000000000000000000..a04811e9c89a0c093e1ffb373467eb6ba9b81b87 --- /dev/null +++ b/virtex/docs/virtex/utils.beam_search.rst @@ -0,0 +1,8 @@ +virtex.utils.beam_search +======================== + +.. raw:: html + +
+ +.. automodule:: virtex.utils.beam_search diff --git a/virtex/docs/virtex/utils.checkpointing.rst b/virtex/docs/virtex/utils.checkpointing.rst new file mode 100644 index 0000000000000000000000000000000000000000..1b3719bf7e330c13835dc57457a3bef238c29b0e --- /dev/null +++ b/virtex/docs/virtex/utils.checkpointing.rst @@ -0,0 +1,8 @@ +virtex.utils.checkpointing +========================== + +.. raw:: html + +
+ +.. automodule:: virtex.utils.checkpointing diff --git a/virtex/docs/virtex/utils.common.rst b/virtex/docs/virtex/utils.common.rst new file mode 100644 index 0000000000000000000000000000000000000000..cadd36d26a01f03b4457f1caed1c0c03dc58a9ef --- /dev/null +++ b/virtex/docs/virtex/utils.common.rst @@ -0,0 +1,8 @@ +virtex.utils.common +=================== + +.. raw:: html + +
+ +.. automodule:: virtex.utils.common diff --git a/virtex/docs/virtex/utils.distributed.rst b/virtex/docs/virtex/utils.distributed.rst new file mode 100644 index 0000000000000000000000000000000000000000..e6a44d674ecb8a96d2568b1cd4072dd1e38f2a9d --- /dev/null +++ b/virtex/docs/virtex/utils.distributed.rst @@ -0,0 +1,8 @@ +virtex.utils.distributed +======================== + +.. raw:: html + +
+ +.. automodule:: virtex.utils.distributed diff --git a/virtex/docs/virtex/utils.metrics.rst b/virtex/docs/virtex/utils.metrics.rst new file mode 100644 index 0000000000000000000000000000000000000000..75234d5e4d230adf20192af77849b1a9c3f059d1 --- /dev/null +++ b/virtex/docs/virtex/utils.metrics.rst @@ -0,0 +1,8 @@ +virtex.utils.metrics +==================== + +.. raw:: html + +
+ +.. automodule:: virtex.utils.metrics diff --git a/virtex/docs/virtex/utils.rst b/virtex/docs/virtex/utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..9d021d9c4e1e255554130264d12abad06cc53911 --- /dev/null +++ b/virtex/docs/virtex/utils.rst @@ -0,0 +1,15 @@ +virtex.utils +============ + +.. raw:: html + +
+ +.. toctree:: + + utils.common + utils.distributed + utils.timer + utils.checkpointing + utils.beam_search + utils.metrics diff --git a/virtex/docs/virtex/utils.timer.rst b/virtex/docs/virtex/utils.timer.rst new file mode 100644 index 0000000000000000000000000000000000000000..c2ddcdb3459f519d9a98766a6ddbef2adefa072d --- /dev/null +++ b/virtex/docs/virtex/utils.timer.rst @@ -0,0 +1,8 @@ +virtex.utils.timer +================== + +.. raw:: html + +
+ +.. automodule:: virtex.utils.timer diff --git a/virtex/hubconf.py b/virtex/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..f85d01d371151f0716680397b1c955d3c4dd42d7 --- /dev/null +++ b/virtex/hubconf.py @@ -0,0 +1,35 @@ +dependencies = ["torch"] + +import torch +import torchvision + + +def resnet50(pretrained: bool = False, **kwargs): + r""" + ResNet-50 visual backbone from the best performing VirTex model: pretrained + for bicaptioning on COCO Captions, with textual head ``L = 1, H = 2048``. + + This is a torchvision-like model, with the last ``avgpool`` and `fc`` + modules replaced with ``nn.Identity()`` modules. Given a batch of image + tensors with size ``(B, 3, 224, 224)``, this model computes spatial image + features of size ``(B, 7, 7, 2048)``, where B = batch size. + + pretrained (bool): Whether to load model with pretrained weights. + """ + + # Create a torchvision resnet50 with randomly initialized weights. + model = torchvision.models.resnet50(pretrained=False, **kwargs) + + # Replace global average pooling and fully connected layers with identity + # modules. + model.avgpool = torch.nn.Identity() + model.fc = torch.nn.Identity() + + if pretrained: + model.load_state_dict( + torch.hub.load_state_dict_from_url( + "https://umich.box.com/shared/static/gsjqm4i4fm1wpzi947h27wweljd8gcpy.pth", + progress=False, + )["model"] + ) + return model diff --git a/virtex/requirements.txt b/virtex/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa85aa188b9d1cd1340aefe2eb27007672981740 --- /dev/null +++ b/virtex/requirements.txt @@ -0,0 +1,18 @@ +albumentations>=0.5.0 +Cython>=0.25 +ftfy==5.8 +future==0.18.0 +lmdb==0.97 +loguru==0.3.2 +mypy_extensions==0.4.1 +lvis==0.5.3 +numpy>=1.17 +opencv-python==4.1.2.30 +scikit-learn==0.21.3 +sentencepiece>=0.1.90 +torch==1.7.0 +torchvision==0.8 +tqdm>=4.50.0 +wordsegment==1.3.1 +git+git://github.com/facebookresearch/fvcore.git#egg=fvcore +git+git://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI diff --git a/virtex/scripts/clf_linear.py b/virtex/scripts/clf_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..52ab5f22d974cf4e523f174aab09143d7d19b005 --- /dev/null +++ b/virtex/scripts/clf_linear.py @@ -0,0 +1,302 @@ +import argparse +import os + +from loguru import logger +import torch +from torch import nn +from torch.cuda import amp +from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.tensorboard import SummaryWriter + +from virtex.config import Config +from virtex.factories import ( + DownstreamDatasetFactory, + PretrainingModelFactory, + OptimizerFactory, + LRSchedulerFactory, +) +from virtex.utils.checkpointing import CheckpointManager +from virtex.utils.common import common_parser, common_setup, cycle +import virtex.utils.distributed as dist +from virtex.utils.metrics import TopkAccuracy +from virtex.utils.timer import Timer + + +# fmt: off +parser = common_parser( + description="""Do image classification with linear models and frozen + feature extractor, or fine-tune the feature extractor end-to-end.""" +) +group = parser.add_argument_group("Downstream config arguments.") +group.add_argument( + "--down-config", metavar="FILE", help="Path to a downstream config file." +) +group.add_argument( + "--down-config-override", nargs="*", default=[], + help="A list of key-value pairs to modify downstream config params.", +) + +parser.add_argument_group("Checkpointing and Logging") +parser.add_argument( + "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"], + default="virtex", help="""How to initialize weights: + 1. 'random' initializes all weights randomly + 2. 'imagenet' initializes backbone weights from torchvision model zoo + 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path + - with 'torchvision', state dict would be from PyTorch's training + script. + - with 'virtex' it should be for our full pretrained model.""" +) +parser.add_argument( + "--log-every", type=int, default=50, + help="""Log training curves to tensorboard after every these many iterations + only master process logs averaged loss values across processes.""", +) +parser.add_argument( + "--checkpoint-path", + help="""Path to load checkpoint and run downstream task evaluation. The + name of checkpoint file is required to be `model_*.pth`, where * is + iteration number from which the checkpoint was serialized.""" +) +parser.add_argument( + "--checkpoint-every", type=int, default=5000, + help="""Serialize model to a checkpoint after every these many iterations. + For ImageNet, (5005 iterations = 1 epoch); for iNaturalist (1710 iterations + = 1 epoch).""", +) +# fmt: on + + +def main(_A: argparse.Namespace): + + if _A.num_gpus_per_machine == 0: + # Set device as CPU if num_gpus_per_machine = 0. + device = torch.device("cpu") + else: + # Get the current device as set for current distributed process. + # Check `launch` function in `virtex.utils.distributed` module. + device = torch.cuda.current_device() + + # Create a downstream config object (this will be immutable) and perform + # common setup such as logging and setting up serialization directory. + _DOWNC = Config(_A.down_config, _A.down_config_override) + common_setup(_DOWNC, _A, job_type="downstream") + + # Create a (pretraining) config object and backup in serializaion directory. + _C = Config(_A.config, _A.config_override) + _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml")) + + # Get dataset name for tensorboard logging. + DATASET = _DOWNC.DATA.ROOT.split("/")[-1] + + # Set number of output classes according to dataset: + NUM_CLASSES_MAPPING = {"imagenet": 1000, "inaturalist": 8142} + NUM_CLASSES = NUM_CLASSES_MAPPING[DATASET] + + # ------------------------------------------------------------------------- + # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER + # ------------------------------------------------------------------------- + train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="train") + train_dataloader = DataLoader( + train_dataset, + batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(), + num_workers=_A.cpu_workers, + sampler=DistributedSampler( + train_dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=True, + ), + drop_last=False, + pin_memory=True, + collate_fn=train_dataset.collate_fn, + ) + val_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="val") + val_dataloader = DataLoader( + val_dataset, + batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(), + num_workers=_A.cpu_workers, + sampler=DistributedSampler( + val_dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False, + ), + pin_memory=True, + drop_last=False, + collate_fn=val_dataset.collate_fn, + ) + # Initialize model using pretraining config. + pretrained_model = PretrainingModelFactory.from_config(_C) + + # Load weights according to the init method, do nothing for `random`, and + # `imagenet` is already taken care of. + if _A.weight_init == "virtex": + CheckpointManager(model=pretrained_model).load(_A.checkpoint_path) + elif _A.weight_init == "torchvision": + # Keep strict=False because this state dict may have weights for + # last fc layer. + pretrained_model.visual.cnn.load_state_dict( + torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], + strict=False, + ) + + # Pull out the CNN (torchvision-like) from our pretrained model and add + # back the FC layer - this is exists in torchvision models, and is set to + # `nn.Identity()` during pretraining. + model = pretrained_model.visual.cnn # type: ignore + model.fc = nn.Linear(_DOWNC.MODEL.VISUAL.FEATURE_SIZE, NUM_CLASSES).to(device) + model = model.to(device) + + # Re-initialize the FC layer. + torch.nn.init.normal_(model.fc.weight.data, mean=0.0, std=0.01) + torch.nn.init.constant_(model.fc.bias.data, 0.0) + + # Freeze all layers except FC as per config param. + if _DOWNC.MODEL.VISUAL.FROZEN: + # Set model to eval mode to prevent BatchNorm from updating running + # mean and std. With only a linear layer, being in eval mode when + # training will not matter anyway. + model.eval() + + for name, param in model.named_parameters(): + if "fc" not in name: + param.requires_grad = False + + # Cross entropy loss and accuracy meter. + criterion = nn.CrossEntropyLoss() + top1 = TopkAccuracy(top_k=1) + + optimizer = OptimizerFactory.from_config(_DOWNC, model.named_parameters()) + scheduler = LRSchedulerFactory.from_config(_DOWNC, optimizer) + del pretrained_model + + # ------------------------------------------------------------------------- + # BEFORE TRAINING STARTS + # ------------------------------------------------------------------------- + + # Create a gradient scaler for automatic mixed precision. + scaler = amp.GradScaler(enabled=_DOWNC.AMP) + + # Create an iterator from dataloader to sample batches perpetually. + train_dataloader_iter = cycle(train_dataloader, device) + + if dist.get_world_size() > 1: + dist.synchronize() + model = nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True + ) + + if dist.is_master_process(): + checkpoint_manager = CheckpointManager( + _A.serialization_dir, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) + + # Keep track of time per iteration and ETA. + timer = Timer(start_from=1, total_iterations=_DOWNC.OPTIM.NUM_ITERATIONS) + + # ------------------------------------------------------------------------- + # TRAINING LOOP + # ------------------------------------------------------------------------- + for iteration in range(1, _DOWNC.OPTIM.NUM_ITERATIONS + 1): + timer.tic() + optimizer.zero_grad() + batch = next(train_dataloader_iter) + + with amp.autocast(enabled=_DOWNC.AMP): + logits = model(batch["image"]) + loss = criterion(logits, batch["label"]) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + scheduler.step() + timer.toc() + + if iteration % _A.log_every == 0 and dist.is_master_process(): + logger.info( + f"{timer.stats} | Loss: {loss:.3f} | GPU: {dist.gpu_mem_usage()} MB" + ) + tensorboard_writer.add_scalar(f"{DATASET}/train_loss", loss, iteration) + tensorboard_writer.add_scalar( + f"{DATASET}/learning_rate", + optimizer.param_groups[0]["lr"], + iteration, + ) + + # --------------------------------------------------------------------- + # VALIDATION + # --------------------------------------------------------------------- + if iteration % _A.checkpoint_every == 0: + torch.set_grad_enabled(False) + model.eval() + + total_val_loss = torch.tensor(0.0).to(device) + + for val_iteration, batch in enumerate(val_dataloader, start=1): + for key in batch: + batch[key] = batch[key].to(device) + + logits = model(batch["image"]) + loss = criterion(logits, batch["label"]) + top1(logits, batch["label"]) + total_val_loss += loss + + # Divide each loss component by number of val batches per GPU. + total_val_loss = total_val_loss / val_iteration + dist.average_across_processes(total_val_loss) + + # Get accumulated Top-1 accuracy for logging across GPUs. + acc = top1.get_metric(reset=True) + dist.average_across_processes(acc) + + torch.set_grad_enabled(True) + + # Set model back to train mode only when fine-tuning end-to-end. + if not _DOWNC.MODEL.VISUAL.FROZEN: + model.train() + + # Save recent checkpoint and best checkpoint based on accuracy. + if dist.is_master_process(): + checkpoint_manager.step(iteration) + + if iteration % _A.checkpoint_every == 0 and dist.is_master_process(): + logger.info(f"Iter: {iteration} | Top-1 accuracy: {acc})") + tensorboard_writer.add_scalar( + f"{DATASET}/val_loss", total_val_loss, iteration + ) + # This name scoping will result in Tensorboard displaying all metrics + # (VOC07, caption, etc.) together. + tensorboard_writer.add_scalars( + f"metrics/{DATASET}", {"top1": acc}, iteration + ) + + # All processes will wait till master process is done logging. + dist.synchronize() + + +if __name__ == "__main__": + _A = parser.parse_args() + + # Add an arg in config override if `--weight-init` is imagenet. + if _A.weight_init == "imagenet": + _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True]) + + if _A.num_gpus_per_machine == 0: + main(_A) + else: + # This will launch `main` and set appropriate CUDA device (GPU ID) as + # per process (accessed in the beginning of `main`). + dist.launch( + main, + num_machines=_A.num_machines, + num_gpus_per_machine=_A.num_gpus_per_machine, + machine_rank=_A.machine_rank, + dist_url=_A.dist_url, + args=(_A,), + ) diff --git a/virtex/scripts/clf_voc07.py b/virtex/scripts/clf_voc07.py new file mode 100644 index 0000000000000000000000000000000000000000..0e382c1ac49a3c9c254ab9c97f14652ed664fbf6 --- /dev/null +++ b/virtex/scripts/clf_voc07.py @@ -0,0 +1,272 @@ +import argparse +import multiprocessing as mp +import os +from typing import Any, List + +from loguru import logger +import numpy as np +from sklearn.svm import LinearSVC +from sklearn.metrics import average_precision_score +from sklearn.model_selection import cross_val_score +import torch +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +from virtex.config import Config +from virtex.factories import PretrainingModelFactory, DownstreamDatasetFactory +from virtex.utils.checkpointing import CheckpointManager +from virtex.utils.common import common_parser, common_setup + + +parser = common_parser( + description="Train SVMs for VOC2007 classification on a pretrained model." +) +group = parser.add_argument_group("Downstream config arguments.") +group.add_argument( + "--down-config", metavar="FILE", help="Path to a downstream config file." +) +group.add_argument( + "--down-config-override", + nargs="*", + default=[], + help="A list of key-value pairs to modify downstream config params.", +) + +# fmt: off +parser.add_argument_group("Checkpointing") +parser.add_argument( + "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"], + default="virtex", help="""How to initialize weights: + 1. 'random' initializes all weights randomly + 2. 'imagenet' initializes backbone weights from torchvision model zoo + 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path + - with 'torchvision', state dict would be from PyTorch's training + script. + - with 'virtex' it should be for our full pretrained model.""" +) +parser.add_argument( + "--checkpoint-path", + help="Path to load checkpoint and run downstream task evaluation." +) +# fmt: on + + +def train_test_single_svm(args): + + feats_train, tgts_train, feats_test, tgts_test, cls_name = args + SVM_COSTS = [0.01, 0.1, 1.0, 10.0] + + cls_labels = np.copy(tgts_train) + # Meaning of labels in VOC/COCO original loaded target files: + # label 0 = not present, set it to -1 as svm train target + # label 1 = present. Make the svm train target labels as -1, 1. + cls_labels[np.where(cls_labels == 0)] = -1 + + # See which cost maximizes the AP for this class. + best_crossval_ap: float = 0.0 + best_crossval_clf = None + best_cost: float = 0.0 + + # fmt: off + for cost in SVM_COSTS: + clf = LinearSVC( + C=cost, class_weight={1: 2, -1: 1}, penalty="l2", + loss="squared_hinge", max_iter=2000, + ) + ap_scores = cross_val_score( + clf, feats_train, cls_labels, cv=3, scoring="average_precision", + ) + clf.fit(feats_train, cls_labels) + + # Keep track of best SVM (based on cost) for each class. + if ap_scores.mean() > best_crossval_ap: + best_crossval_ap = ap_scores.mean() + best_crossval_clf = clf + best_cost = cost + + logger.info(f"Best SVM {cls_name}: cost {best_cost}, mAP {best_crossval_ap * 100}") + # fmt: on + + # ------------------------------------------------------------------------- + # TEST THE TRAINED SVM (PER CLASS) + # ------------------------------------------------------------------------- + predictions = best_crossval_clf.decision_function(feats_test) + evaluate_data_inds = tgts_test != -1 + eval_preds = predictions[evaluate_data_inds] + + cls_labels = np.copy(tgts_test) + eval_cls_labels = cls_labels[evaluate_data_inds] + eval_cls_labels[np.where(eval_cls_labels == 0)] = -1 + + # Binarize class labels to make AP targets. + targets = eval_cls_labels > 0 + return average_precision_score(targets, eval_preds) + + +def main(_A: argparse.Namespace): + + if _A.num_gpus_per_machine == 0: + # Set device as CPU if num_gpus_per_machine = 0. + device = torch.device("cpu") + else: + # Get the current device (this will be zero here by default). + device = torch.cuda.current_device() + + # Create a downstream config object (this will be immutable) and perform + # common setup such as logging and setting up serialization directory. + _DOWNC = Config(_A.down_config, _A.down_config_override) + common_setup(_DOWNC, _A, job_type="downstream") + + # Create a (pretraining) config object and backup in serialization directory. + _C = Config(_A.config, _A.config_override) + _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml")) + + # ------------------------------------------------------------------------- + # INSTANTIATE DATALOADER, MODEL, AND FEATURE EXTRACTOR + # ------------------------------------------------------------------------- + + train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="trainval") + train_dataloader = DataLoader( + train_dataset, + batch_size=_DOWNC.OPTIM.BATCH_SIZE, + num_workers=_A.cpu_workers, + pin_memory=True, + ) + test_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="test") + test_dataloader = DataLoader( + test_dataset, + batch_size=_DOWNC.OPTIM.BATCH_SIZE, + num_workers=_A.cpu_workers, + pin_memory=True, + ) + NUM_CLASSES = len(train_dataset.class_names) + + # Initialize from a checkpoint, but only keep the visual module. + model = PretrainingModelFactory.from_config(_C) + + # Load weights according to the init method, do nothing for `random`, and + # `imagenet` is already taken care of. + if _A.weight_init == "virtex": + ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path) + elif _A.weight_init == "torchvision": + # Keep strict=False because this state dict may have weights for + # last fc layer. + model.visual.cnn.load_state_dict( + torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], + strict=False, + ) + # Set ``ITERATION`` to a dummy value. + ITERATION = 0 + + # Transfer model to GPU and set to eval mode. This is a torchvision model + # and it returns features as ``(batch_size, 2048, 7, 7)``. + model = model.visual.cnn.to(device).eval() + + # ------------------------------------------------------------------------- + # EXTRACT FEATURES FOR TRAINING SVMs + # ------------------------------------------------------------------------- + + features_train: List[torch.Tensor] = [] + targets_train: List[torch.Tensor] = [] + + features_test: List[torch.Tensor] = [] + targets_test: List[torch.Tensor] = [] + + # VOC07 is small, extract all features and keep them in memory. + with torch.no_grad(): + for batch in tqdm(train_dataloader, desc="Extracting train features:"): + features = model(batch["image"].to(device)) + + # Global average pool features. Assume the tensor is in NCHW format. + if len(features.size()) > 2: + features = features.view(features.size(0), features.size(1), -1) + + # shape: (batch_size, visual_feature_size) + features = features.mean(dim=-1) + + # shape: (batch_size, visual_feature_size) + features = features.view(features.size(0), -1) + + # L2-normalize the global average pooled features. + features = features / torch.norm(features, dim=-1).unsqueeze(-1) + + features_train.append(features.cpu()) + targets_train.append(batch["label"]) + + # Similarly extract test features. + for batch in tqdm(test_dataloader, desc="Extracting test features:"): + features = model(batch["image"].to(device)) + + if len(features.size()) > 2: + features = features.view(features.size(0), features.size(1), -1) + features = features.mean(dim=-1) + + features = features.view(features.size(0), -1) + features = features / torch.norm(features, dim=-1).unsqueeze(-1) + + features_test.append(features.cpu()) + targets_test.append(batch["label"]) + + # Convert batches of features/targets to one large numpy array + features_train = torch.cat(features_train, dim=0).numpy() + targets_train = torch.cat(targets_train, dim=0).numpy().astype(np.int32) + + features_test = torch.cat(features_test, dim=0).numpy() + targets_test = torch.cat(targets_test, dim=0).numpy().astype(np.int32) + + # ------------------------------------------------------------------------- + # TRAIN AND TEST SVMs WITH EXTRACTED FEATURES + # ------------------------------------------------------------------------- + + input_args: List[Any] = [] + + # Iterate over all VOC07 classes and train one-vs-all linear SVMs. + for cls_idx in range(NUM_CLASSES): + # fmt: off + input_args.append(( + features_train, targets_train[:, cls_idx], + features_test, targets_test[:, cls_idx], + train_dataset.class_names[cls_idx], + )) + # fmt: on + + pool = mp.Pool(processes=_A.cpu_workers) + pool_output = pool.map(train_test_single_svm, input_args) + + # ------------------------------------------------------------------------- + # TENSORBOARD LOGGING (RELEVANT MAINLY FOR weight_init=checkpoint) + # ------------------------------------------------------------------------- + + # Tensorboard writer for logging mAP scores. This is useful especially + # when weight_init=checkpoint (which maybe be coming from a training job). + tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) + + # Test set mAP for each class, for features from every layer. + test_map = torch.tensor(pool_output).mean() + logger.info(f"Iteration: {ITERATION}, mAP: {test_map * 100}") + tensorboard_writer.add_scalars( + "metrics/voc07_clf", {f"voc07_mAP": test_map * 100}, ITERATION + ) + + # NOTE: for copy-pasting to spreadsheet. + logger.info( + f"{_C.DATA.ROOT.split('/')[1]},{_C.DATA.TOKENIZER_MODEL.split('/')[-1][:-6]}," + f"{_C.DATA.VOCAB_SIZE},{_C.MODEL.NAME},{_C.MODEL.VISUAL.NAME},{_C.MODEL.TEXTUAL.NAME}," + f"{_C.MODEL.LABEL_SMOOTHING},{_C.OPTIM.OPTIMIZER_NAME},{_C.OPTIM.BATCH_SIZE}," + f"{_C.OPTIM.NUM_ITERATIONS},{_C.OPTIM.LR},{_C.OPTIM.WEIGHT_DECAY}," + f"{ITERATION},{test_map * 100:.3f}" + ) + +if __name__ == "__main__": + _A = parser.parse_args() + + if _A.num_gpus_per_machine > 1: + raise ValueError("Using multiple GPUs is not supported for this script.") + + # Add an arg in config override if `--weight-init` is imagenet. + if _A.weight_init == "imagenet": + _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True]) + + # No distributed training here, just a single process. + main(_A) diff --git a/virtex/scripts/eval_captioning.py b/virtex/scripts/eval_captioning.py new file mode 100644 index 0000000000000000000000000000000000000000..8da98284f1726027536e38b72b4a82ba04bea396 --- /dev/null +++ b/virtex/scripts/eval_captioning.py @@ -0,0 +1,114 @@ +import argparse +import json +import os +from typing import Any, Dict, List + +from loguru import logger +import torch +from torch.utils.data import DataLoader + +from virtex.config import Config +from virtex.data import ImageDirectoryDataset +from virtex.factories import TokenizerFactory, PretrainingModelFactory +from virtex.utils.checkpointing import CheckpointManager +from virtex.utils.common import common_parser +from virtex.utils.metrics import CocoCaptionsEvaluator + + +# fmt: off +parser = common_parser( + description="""Run image captioning inference on a pretrained model, and/or + evaluate pretrained model on COCO Captions val2017 split.""" +) +parser.add_argument( + "--data-root", default=None, + help="""Path to a directory containing image files to generate captions for. + Default: COCO val2017 image directory as expected relative to project root.""" +) +parser.add_argument( + "--checkpoint-path", required=True, + help="Path to load checkpoint and run captioning evaluation." +) +parser.add_argument( + "--output", default=None, + help="Path to save predictions as a JSON file." +) +parser.add_argument( + "--calc-metrics", action="store_true", + help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions. + This flag should not be set when running inference on arbitrary images.""" +) +# fmt: on + + +def main(_A: argparse.Namespace): + + if _A.num_gpus_per_machine == 0: + # Set device as CPU if num_gpus_per_machine = 0. + device = torch.device("cpu") + else: + # Get the current device (this will be zero here by default). + device = torch.cuda.current_device() + + _C = Config(_A.config, _A.config_override) + + tokenizer = TokenizerFactory.from_config(_C) + + if _A.data_root is None: + _A.data_root = os.path.join(_C.DATA.ROOT, "val2017") + + val_dataloader = DataLoader( + ImageDirectoryDataset(_A.data_root), + batch_size=_C.OPTIM.BATCH_SIZE, + num_workers=_A.cpu_workers, + pin_memory=True, + ) + # Initialize model from a checkpoint. + model = PretrainingModelFactory.from_config(_C).to(device) + ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path) + model.eval() + + # Make a list of predictions to evaluate. + predictions: List[Dict[str, Any]] = [] + + for val_iteration, val_batch in enumerate(val_dataloader, start=1): + + val_batch["image"] = val_batch["image"].to(device) + with torch.no_grad(): + output_dict = model(val_batch) + + # Make a dictionary of predictions in COCO format. + for image_id, caption in zip( + val_batch["image_id"], output_dict["predictions"] + ): + predictions.append( + { + # Convert image id to int if possible (mainly for COCO eval). + "image_id": int(image_id) if image_id.isdigit() else image_id, + "caption": tokenizer.decode(caption.tolist()), + } + ) + + # Save predictions as a JSON file if specified. + if _A.output is not None: + os.makedirs(os.path.dirname(_A.output), exist_ok=True) + json.dump(predictions, open(_A.output, "w")) + logger.info(f"Saved predictions to {_A.output}") + + # Calculate CIDEr and SPICE metrics using ground truth COCO Captions. This + # should be skipped when running inference on arbitrary images. + if _A.calc_metrics: + # Assume ground truth (COCO val2017 annotations) exist. + gt = os.path.join(_C.DATA.ROOT, "annotations", "captions_val2017.json") + + metrics = CocoCaptionsEvaluator(gt).evaluate(predictions) + logger.info(f"Iter: {ITERATION} | Metrics: {metrics}") + + +if __name__ == "__main__": + _A = parser.parse_args() + if _A.num_gpus_per_machine > 1: + raise ValueError("Using multiple GPUs is not supported for this script.") + + # No distributed training here, just a single process. + main(_A) diff --git a/virtex/scripts/eval_detectron2.py b/virtex/scripts/eval_detectron2.py new file mode 100644 index 0000000000000000000000000000000000000000..b79147080f8c56313e1a809b9f1a791ecd380e11 --- /dev/null +++ b/virtex/scripts/eval_detectron2.py @@ -0,0 +1,248 @@ +""" +Finetune a pre-trained model on a downstream task, one of those available in +Detectron2. +Supported downstream: + - LVIS Instance Segmentation + - COCO Instance Segmentation + - Pascal VOC 2007+12 Object Detection + +Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py +Thanks to the developers of Detectron2! +""" +import argparse +import os +import re +from typing import Any, Dict, Union + +import torch +from torch.utils.tensorboard import SummaryWriter + +import detectron2 as d2 +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.engine import DefaultTrainer, default_setup +from detectron2.evaluation import ( + LVISEvaluator, + PascalVOCDetectionEvaluator, + COCOEvaluator, +) +from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads + +from virtex.config import Config +from virtex.factories import PretrainingModelFactory +from virtex.utils.checkpointing import CheckpointManager +from virtex.utils.common import common_parser +import virtex.utils.distributed as dist + +# fmt: off +parser = common_parser( + description="Train object detectors from pretrained visual backbone." +) +parser.add_argument( + "--d2-config", required=True, + help="Path to a detectron2 config for downstream task finetuning." +) +parser.add_argument( + "--d2-config-override", nargs="*", default=[], + help="""Key-value pairs from Detectron2 config to override from file. + Some keys will be ignored because they are set from other args: + [DATALOADER.NUM_WORKERS, SOLVER.EVAL_PERIOD, SOLVER.CHECKPOINT_PERIOD, + TEST.EVAL_PERIOD, OUTPUT_DIR]""", +) + +parser.add_argument_group("Checkpointing and Logging") +parser.add_argument( + "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"], + default="virtex", help="""How to initialize weights: + 1. 'random' initializes all weights randomly + 2. 'imagenet' initializes backbone weights from torchvision model zoo + 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path + - with 'torchvision', state dict would be from PyTorch's training + script. + - with 'virtex' it should be for our full pretrained model.""" +) +parser.add_argument( + "--checkpoint-path", + help="Path to load checkpoint and run downstream task evaluation." +) +parser.add_argument( + "--resume", action="store_true", help="""Specify this flag when resuming + training from a checkpoint saved by Detectron2.""" +) +parser.add_argument( + "--eval-only", action="store_true", + help="Skip training and evaluate checkpoint provided at --checkpoint-path.", +) +parser.add_argument( + "--checkpoint-every", type=int, default=5000, + help="Serialize model to a checkpoint after every these many iterations.", +) +# fmt: on + + +@ROI_HEADS_REGISTRY.register() +class Res5ROIHeadsExtraNorm(Res5ROIHeads): + r""" + ROI head with ``res5`` stage followed by a BN layer. Used with Faster R-CNN + C4/DC5 backbones for VOC detection. + """ + + def _build_res5_block(self, cfg): + seq, out_channels = super()._build_res5_block(cfg) + norm = d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, out_channels) + seq.add_module("norm", norm) + return seq, out_channels + + +def build_detectron2_config(_C: Config, _A: argparse.Namespace): + r"""Build detectron2 config based on our pre-training config and args.""" + _D2C = d2.config.get_cfg() + + # Override some default values based on our config file. + _D2C.merge_from_file(_A.d2_config) + _D2C.merge_from_list(_A.d2_config_override) + + # Set some config parameters from args. + _D2C.DATALOADER.NUM_WORKERS = _A.cpu_workers + _D2C.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every + _D2C.OUTPUT_DIR = _A.serialization_dir + + # Set ResNet depth to override in Detectron2's config. + _D2C.MODEL.RESNETS.DEPTH = int( + re.search(r"resnet(\d+)", _C.MODEL.VISUAL.NAME).group(1) + if "torchvision" in _C.MODEL.VISUAL.NAME + else re.search(r"_R_(\d+)", _C.MODEL.VISUAL.NAME).group(1) + if "detectron2" in _C.MODEL.VISUAL.NAME + else 0 + ) + return _D2C + + +class DownstreamTrainer(DefaultTrainer): + r""" + Extension of detectron2's ``DefaultTrainer``: custom evaluator and hooks. + + Parameters + ---------- + cfg: detectron2.config.CfgNode + Detectron2 config object containing all config params. + weights: Union[str, Dict[str, Any]] + Weights to load in the initialized model. If ``str``, then we assume path + to a checkpoint, or if a ``dict``, we assume a state dict. This will be + an ``str`` only if we resume training from a Detectron2 checkpoint. + """ + + def __init__(self, cfg, weights: Union[str, Dict[str, Any]]): + + super().__init__(cfg) + + # Load pre-trained weights before wrapping to DDP because `ApexDDP` has + # some weird issue with `DetectionCheckpointer`. + # fmt: off + if isinstance(weights, str): + # weights are ``str`` means ImageNet init or resume training. + self.start_iter = ( + DetectionCheckpointer( + self._trainer.model, + optimizer=self._trainer.optimizer, + scheduler=self.scheduler + ).resume_or_load(weights, resume=True).get("iteration", -1) + 1 + ) + elif isinstance(weights, dict): + # weights are a state dict means our pretrain init. + DetectionCheckpointer(self._trainer.model)._load_model(weights) + # fmt: on + + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_folder=None): + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + evaluator_list = [] + evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type + if evaluator_type == "pascal_voc": + return PascalVOCDetectionEvaluator(dataset_name) + elif evaluator_type == "coco": + return COCOEvaluator(dataset_name, cfg, True, output_folder) + elif evaluator_type == "lvis": + return LVISEvaluator(dataset_name, cfg, True, output_folder) + + def test(self, cfg=None, model=None, evaluators=None): + r"""Evaluate the model and log results to stdout and tensorboard.""" + cfg = cfg or self.cfg + model = model or self.model + + tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR) + results = super().test(cfg, model) + flat_results = d2.evaluation.testing.flatten_results_dict(results) + for k, v in flat_results.items(): + tensorboard_writer.add_scalar(k, v, self.start_iter) + + +def main(_A: argparse.Namespace): + + # Get the current device as set for current distributed process. + # Check `launch` function in `virtex.utils.distributed` module. + device = torch.cuda.current_device() + + # Local process group is needed for detectron2. + pg = list(range(dist.get_world_size())) + d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(pg) + + # Create a config object (this will be immutable) and perform common setup + # such as logging and setting up serialization directory. + if _A.weight_init == "imagenet": + _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True]) + _C = Config(_A.config, _A.config_override) + + # We use `default_setup` from detectron2 to do some common setup, such as + # logging, setting up serialization etc. For more info, look into source. + _D2C = build_detectron2_config(_C, _A) + default_setup(_D2C, _A) + + # Prepare weights to pass in instantiation call of trainer. + if _A.weight_init in {"virtex", "torchvision"}: + if _A.resume: + # If resuming training, let detectron2 load weights by providing path. + model = None + weights = _A.checkpoint_path + else: + # Load backbone weights from VirTex pretrained checkpoint. + model = PretrainingModelFactory.from_config(_C) + if _A.weight_init == "virtex": + CheckpointManager(model=model).load(_A.checkpoint_path) + else: + model.visual.cnn.load_state_dict( + torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], + strict=False, + ) + weights = model.visual.detectron2_backbone_state_dict() + else: + # If random or imagenet init, just load weights after initializing model. + model = PretrainingModelFactory.from_config(_C) + weights = model.visual.detectron2_backbone_state_dict() + + # Back up pretrain config and model checkpoint (if provided). + _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml")) + if _A.weight_init == "virtex" and not _A.resume: + torch.save( + model.state_dict(), + os.path.join(_A.serialization_dir, "pretrain_model.pth"), + ) + + del model + trainer = DownstreamTrainer(_D2C, weights) + trainer.test() if _A.eval_only else trainer.train() + + +if __name__ == "__main__": + _A = parser.parse_args() + + # This will launch `main` and set appropriate CUDA device (GPU ID) as + # per process (accessed in the beginning of `main`). + dist.launch( + main, + num_machines=_A.num_machines, + num_gpus_per_machine=_A.num_gpus_per_machine, + machine_rank=_A.machine_rank, + dist_url=_A.dist_url, + args=(_A, ), + ) diff --git a/virtex/scripts/preprocess/build_redcaps_vocab.py b/virtex/scripts/preprocess/build_redcaps_vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..fd28f1b8d72fde036f631032a539c4fe16d169f2 --- /dev/null +++ b/virtex/scripts/preprocess/build_redcaps_vocab.py @@ -0,0 +1,107 @@ +import argparse +import glob +import json +import os +import re +import tempfile +from functools import lru_cache +from typing import List + +import ftfy +import sentencepiece as sp +import wordsegment as ws +from tqdm import tqdm + + +ws.load() + +# fmt: off +parser = argparse.ArgumentParser( + description="""Build a vocabulary out of captions corpus. This vocabulary + would be a file which our tokenizer can understand. + """ +) +parser.add_argument( + "-f", "--files", nargs="+", default="datasets/redcaps/annotations/*.json", + help="Path(s) to SBU, Conceptual, or RedCaps annotation files.", +) +parser.add_argument( + "-s", "--vocab-size", type=int, default=32000, + help="Total desired size of our vocabulary.", +) +parser.add_argument( + "-o", "--output-prefix", default="datasets/vocab/redcaps_32k", + help="Prefix of the files to be saved. Two files will be saved: " + "[prefix].model and [prefix].vocab", +) +# fmt: on + + +def read_captions_from_file(annotations_path: str) -> List[str]: + r""" + Given a path to annotation file, read it and return a list of captions. + + Parameters + ---------- + annotations_path: str + Path to an annotations file containing captions. + + Returns + ------- + List[str] + List of captions from this annotation file. + """ + + _annotations = json.load(open(annotations_path)) + + captions: List[str] = [] + for ann in tqdm(_annotations["annotations"], desc=annotations_path): + + # This field only exists in RedCaps. Perform word segmentation on the + # subreddit name to add appropriae whitespaces. + if "subreddit" in ann: + subreddit_seg = _segment_subreddit(ann["subreddit"].lower()) + caption = f"{subreddit_seg} {ann['caption']}" + else: + caption = ann["caption"] + + captions.append(caption.lower()) + return captions + + +@lru_cache(maxsize=10) +def _segment_subreddit(subreddit): + return " ".join(ws.segment(ws.clean(subreddit))) + + +if __name__ == "__main__": + _A = parser.parse_args() + + all_filepaths: List[str] = [] + for f in _A.files: + all_filepaths.extend(glob.glob(f)) + + captions: List[str] = [] + for path in tqdm(all_filepaths, desc="Reading captions"): + captions.extend(read_captions_from_file(path)) + + # Create a temporary directory and dump the captions corpus as a text file + # with one caption per line. That's how sentencepiece wants its input. + tmpdir_path = tempfile.mkdtemp() + + with open(os.path.join(tmpdir_path, "captions.txt"), "w") as captions_file: + for caption in captions: + captions_file.write(caption + "\n") + + # Padding/out-of-vocab token will be "" and ID 0 by default. + # Add [SOS],[EOS] and [SEP] tokens. [SEP] will not be used during + # captioning, but good to have to reuse vocabulary across pretext tasks. + sp.SentencePieceTrainer.train( + f" --input={os.path.join(tmpdir_path, 'captions.txt')}" + f" --vocab_size={_A.vocab_size}" + f" --model_prefix={_A.output_prefix}" + " --model_type=bpe --character_coverage=1.0" + " --bos_id=-1 --eos_id=-1" + " --control_symbols=[SOS],[EOS],[SEP]" + " --user_defined_symbols=" + ) diff --git a/virtex/scripts/preprocess/build_vocabulary.py b/virtex/scripts/preprocess/build_vocabulary.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7a592b40d8044919279dc8116ca03dce20b5d1 --- /dev/null +++ b/virtex/scripts/preprocess/build_vocabulary.py @@ -0,0 +1,100 @@ +import argparse +import json +import os +import tempfile +import unicodedata +from typing import List + +import sentencepiece as sp + + +# fmt: off +parser = argparse.ArgumentParser( + description="""Build a vocabulary out of captions corpus. This vocabulary + would be a file which our tokenizer can understand. + """ +) +parser.add_argument( + "-c", "--captions", default="datasets/coco/annotations/captions_train2017.json", + help="Path to caption annotations file in COCO format.", +) +parser.add_argument( + "-s", "--vocab-size", type=int, default=10000, + help="Total desired size of our vocabulary.", +) +parser.add_argument( + "-o", "--output-prefix", default="datasets/vocab/coco_10k", + help="Prefix of the files to be saved. Two files will be saved: " + "[prefix].model and [prefix].vocab", +) +parser.add_argument( + "-l", "--do-lower-case", action="store_true", + help="Whether to lower case the captions before forming vocabulary.", +) +parser.add_argument( + "-a", "--keep-accents", action="store_true", + help="Whether to keep accents before forming vocabulary (dropped by default).", +) +# fmt: on + + +def _read_captions(annotations_path: str) -> List[str]: + r""" + Given a path to annotation file, read it and return a list of captions. + These are not processed by any means, returned from the file as-is. + + Parameters + ---------- + annotations_path: str + Path to an annotations file containing captions. + + Returns + ------- + List[str] + List of captions from this annotation file. + """ + + _annotations = json.load(open(annotations_path)) + + captions: List[str] = [] + for ann in _annotations["annotations"]: + captions.append(ann["caption"]) + + return captions + + +if __name__ == "__main__": + _A = parser.parse_args() + captions: List[str] = _read_captions(_A.captions) + + # Lower case the captions and remove accents according to arguments. + for i, caption in enumerate(captions): + caption = caption.lower() if _A.do_lower_case else caption + + if not _A.keep_accents: + caption = unicodedata.normalize("NFKD", caption) + caption = "".join( + [chr for chr in caption if not unicodedata.combining(chr)] + ) + + captions[i] = caption + + # Create a temporary directory and dump the captions corpus as a text file + # with one caption per line. That's how sentencepiece wants its input. + tmpdir_path = tempfile.mkdtemp() + + with open(os.path.join(tmpdir_path, "captions.txt"), "w") as captions_file: + for caption in captions: + captions_file.write(caption + "\n") + + # Padding/out-of-vocab token will be "" and ID 0 by default. + # Add [SOS],[EOS] and [MASK] tokens. [MASK] will not be used during + # captioning, but good to have to reuse vocabulary across pretext tasks. + sp.SentencePieceTrainer.train( + f" --input={os.path.join(tmpdir_path, 'captions.txt')}" + f" --vocab_size={_A.vocab_size}" + f" --model_prefix={_A.output_prefix}" + " --model_type=bpe --character_coverage=1.0" + " --bos_id=-1 --eos_id=-1" + " --control_symbols=[SOS],[EOS],[MASK]" + ) diff --git a/virtex/scripts/preprocess/preprocess_coco.py b/virtex/scripts/preprocess/preprocess_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..abf768d94d494a2e8397b596ca7993a638d2d840 --- /dev/null +++ b/virtex/scripts/preprocess/preprocess_coco.py @@ -0,0 +1,100 @@ +import argparse +import os +import pickle +import platform +from typing import Any, List + +import albumentations as alb +import lmdb +from tqdm import tqdm +from torch.utils.data import DataLoader + +from virtex.data.readers import SimpleCocoCaptionsReader + + +# fmt: off +parser = argparse.ArgumentParser("Serialize a COCO Captions split to LMDB.") +parser.add_argument( + "-d", "--data-root", default="datasets/coco", + help="Path to the root directory of COCO dataset.", +) +parser.add_argument( + "-s", "--split", choices=["train", "val"], + help="Which split to process, either `train` or `val`.", +) +parser.add_argument( + "-b", "--batch-size", type=int, default=16, + help="Batch size to process and serialize data. Set as per CPU memory.", +) +parser.add_argument( + "-j", "--cpu-workers", type=int, default=4, + help="Number of CPU workers for data loading.", +) +parser.add_argument( + "-e", "--short-edge-size", type=int, default=None, + help="""Resize shorter edge to this size (keeping aspect ratio constant) + before serializing. Useful for saving disk memory, and faster read. + If None, no images are resized.""" +) +parser.add_argument( + "-o", "--output", default="datasets/serialized/coco_train2017.lmdb", + help="Path to store the file containing serialized dataset.", +) + + +def collate_fn(instances: List[Any]): + r"""Collate function for data loader to return list of instances as-is.""" + return instances + + +if __name__ == "__main__": + + _A = parser.parse_args() + os.makedirs(os.path.dirname(_A.output), exist_ok=True) + + dloader = DataLoader( + SimpleCocoCaptionsReader(_A.data_root, _A.split), + batch_size=_A.batch_size, + num_workers=_A.cpu_workers, + shuffle=False, + drop_last=False, + collate_fn=collate_fn + ) + # Open an LMDB database. + # Set a sufficiently large map size for LMDB (based on platform). + map_size = 1099511627776 * 2 if platform.system() == "Linux" else 1280000 + db = lmdb.open( + _A.output, map_size=map_size, subdir=False, meminit=False, map_async=True + ) + + # Transform to resize shortest edge and keep aspect ratio same. + if _A.short_edge_size is not None: + resize = alb.SmallestMaxSize(max_size=_A.short_edge_size, always_apply=True) + + # Serialize each instance (as a dictionary). Use `pickle.dumps`. Key will + # be an integer (cast as string) starting from `0`. + INSTANCE_COUNTER: int = 0 + + for idx, batch in enumerate(tqdm(dloader)): + + txn = db.begin(write=True) + + for instance in batch: + image = instance["image"] + width, height, channels = image.shape + + # Resize image from instance and convert instance to tuple. + if _A.short_edge_size is not None and min(width, height) > _A.short_edge_size: + image = resize(image=image)["image"] + + instance = (instance["image_id"], instance["image"], instance["captions"]) + txn.put( + f"{INSTANCE_COUNTER}".encode("ascii"), + pickle.dumps(instance, protocol=-1) + ) + INSTANCE_COUNTER += 1 + + txn.commit() + + db.sync() + db.close() diff --git a/virtex/scripts/preprocess/preprocess_redcaps.py b/virtex/scripts/preprocess/preprocess_redcaps.py new file mode 100644 index 0000000000000000000000000000000000000000..abb4e9e71e272797e2caf3eed304bc4cdc98f85e --- /dev/null +++ b/virtex/scripts/preprocess/preprocess_redcaps.py @@ -0,0 +1,102 @@ +import argparse +import json +import os +import tarfile +import tempfile +from typing import Dict, List + +from loguru import logger +from tqdm import tqdm + + +# fmt: off +parser = argparse.ArgumentParser( + description="""Pre-process RedCaps dataset for training VirTex models - make + small shards of TAR files containing images and captions.""" +) +parser.add_argument( + "-a", "--annotations", required=True, help="Path to a RedCaps annotation file." +) +parser.add_argument( + "-i", "--images", default="datasets/redcaps/images", + help="""Path to RedCaps image directory. This directory is expected to have + subreddit specific sub-directories containing images.""", +) +parser.add_argument( + "-z", "--shard-size", type=int, default=1000, + help="Maximum number of RedCaps instances in a single TAR file shard.", +) +parser.add_argument( + "-o", "--output-prefix", required=True, + help="Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` " + "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ...", +) +# fmt: on + + +def main(_A: argparse.Namespace): + r""" + Make TAR files containing images and annotations from a single RedCaps + annotations file. These TAR files are arranged in a way that + `WebDataset `_ can understand. + """ + + ANNOTATIONS: List[Dict] = json.load(open(_A.annotations))["annotations"] + + # Keep track of the current index of TAR file shard and dataset index. + SHARD_INDEX: int = 0 + DATASET_INDEX: int = 0 + + # Create TAR file handle for the initial shard. + tar_handle = tarfile.open(f"{_A.output_prefix}_{SHARD_INDEX:0>d}.tar", "w") + + # Keep a count of submissions that were skipped because their image was + # not downloaded (not present in image dir). + SKIPPED: int = 0 + + for ann in tqdm(ANNOTATIONS): + + image_path = os.path.join( + _A.images, ann["subreddit"], f"{ann['image_id']}.jpg" + ) + # Add current image in shard if it exists. + if os.path.exists(image_path): + + tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg") + + # Save subreddit name and caption as a JSON file. + subreddit_and_caption = { + "subreddit": ann["subreddit"], "caption": ann["caption"] + } + tmpfile = tempfile.NamedTemporaryFile("w+") + tmpfile.write(json.dumps(subreddit_and_caption)) + tmpfile.seek(0) + tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json") + tmpfile.close() + + DATASET_INDEX += 1 + + # Create new shard if current shard is full. + if DATASET_INDEX % _A.shard_size == 0 and DATASET_INDEX > 0: + tar_handle.close() + logger.success( + f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar" + ) + SHARD_INDEX += 1 + + # Open new TAR file shard. + tar_handle = tarfile.open( + f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w" + ) + else: + SKIPPED += 1 + + # Close the file handle to properly save it. + tar_handle.close() + logger.success(f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar\n") + logger.info(f"Skipped {SKIPPED} annotations due to missing images.") + + +if __name__ == "__main__": + _A = parser.parse_args() + main(_A) diff --git a/virtex/scripts/pretrain_virtex.py b/virtex/scripts/pretrain_virtex.py new file mode 100644 index 0000000000000000000000000000000000000000..73e36ed3428c6899876fca9961dbeb81dcb2bd0c --- /dev/null +++ b/virtex/scripts/pretrain_virtex.py @@ -0,0 +1,239 @@ +import argparse +from collections import Counter +from typing import Any + +from loguru import logger +import torch +from torch import nn +from torch.cuda import amp +from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.tensorboard import SummaryWriter + +# fmt: off +from virtex.config import Config +from virtex.factories import ( + PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory, + LRSchedulerFactory, +) +from virtex.utils.checkpointing import CheckpointManager +from virtex.utils.common import common_parser, common_setup, cycle +import virtex.utils.distributed as dist +from virtex.utils.timer import Timer + + +parser = common_parser( + description="Train a VirTex model (CNN + Transformer) on COCO Captions." +) +group = parser.add_argument_group("Checkpointing and Logging") +group.add_argument( + "--resume-from", default=None, + help="Path to a checkpoint to resume training from (if provided)." +) +group.add_argument( + "--checkpoint-every", type=int, default=2000, + help="Serialize model to a checkpoint after every these many iterations.", +) +group.add_argument( + "--log-every", type=int, default=20, + help="""Log training curves to tensorboard after every these many iterations + only master process logs averaged loss values across processes.""", +) +# fmt: on + + +def main(_A: argparse.Namespace): + + if _A.num_gpus_per_machine == 0: + # Set device as CPU if num_gpus_per_machine = 0. + device: Any = torch.device("cpu") + else: + # Get the current device as set for current distributed process. + # Check `launch` function in `virtex.utils.distributed` module. + device = torch.cuda.current_device() + + # Create a config object (this will be immutable) and perform common setup + # such as logging and setting up serialization directory. + _C = Config(_A.config, _A.config_override) + common_setup(_C, _A) + + # ------------------------------------------------------------------------- + # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER + # ------------------------------------------------------------------------- + train_dataset = PretrainingDatasetFactory.from_config(_C, split="train") + val_dataset = PretrainingDatasetFactory.from_config(_C, split="val") + + # Make `DistributedSampler`s to shard datasets across GPU processes. + # Skip this if training on CPUs. + train_sampler = ( + DistributedSampler(train_dataset, shuffle=True) # type: ignore + if _A.num_gpus_per_machine > 0 + else None + ) + val_sampler = ( + DistributedSampler(val_dataset, shuffle=False) # type: ignore + if _A.num_gpus_per_machine > 0 + else None + ) + train_dataloader = DataLoader( + train_dataset, + batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(), + sampler=train_sampler, + shuffle=train_sampler is None, + num_workers=_A.cpu_workers, + pin_memory=True, + drop_last=True, + collate_fn=train_dataset.collate_fn, + ) + val_dataloader = DataLoader( + val_dataset, + batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(), + sampler=val_sampler, + shuffle=False, + num_workers=_A.cpu_workers, + pin_memory=True, + drop_last=False, + collate_fn=val_dataset.collate_fn, + ) + + model = PretrainingModelFactory.from_config(_C).to(device) + optimizer = OptimizerFactory.from_config(_C, model.named_parameters()) + scheduler = LRSchedulerFactory.from_config(_C, optimizer) + + # ------------------------------------------------------------------------- + # BEFORE TRAINING STARTS + # ------------------------------------------------------------------------- + + # Create a gradient scaler for automatic mixed precision. + scaler = amp.GradScaler(enabled=_C.AMP) + + # Load checkpoint to resume training if specified. + if _A.resume_from is not None: + start_iteration = CheckpointManager( + model=model, optimizer=optimizer, scheduler=scheduler, scaler=scaler, + ).load(_A.resume_from) + else: + start_iteration = 0 + + # Create an iterator from dataloader to sample batches perpetually. + train_dataloader_iter = cycle(train_dataloader, device, start_iteration) + + # Wrap model in DDP if using more than one processes. + if dist.get_world_size() > 1: + dist.synchronize() + model = nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True + ) + + # Keep track of time per iteration and ETA. + timer = Timer( + start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS + ) + # Create tensorboard writer and checkpoint manager (only in master process). + if dist.is_master_process(): + tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) + tensorboard_writer.add_text("config", f"```\n{_C}\n```") + + checkpoint_manager = CheckpointManager( + _A.serialization_dir, + model=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + ) + + # ------------------------------------------------------------------------- + # TRAINING LOOP + # ------------------------------------------------------------------------- + for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1): + timer.tic() + optimizer.zero_grad() + batch = next(train_dataloader_iter) + + with amp.autocast(enabled=_C.AMP): + output_dict = model(batch) + loss = output_dict["loss"] + + scaler.scale(loss).backward() + + # First clip norm of gradients, and then perform optimizer step. + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM) + scaler.step(optimizer) + + scaler.update() + scheduler.step() + timer.toc() + + # --------------------------------------------------------------------- + # LOGGING + # --------------------------------------------------------------------- + if iteration % _A.log_every == 0: + logger.info( + f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]" + ) + if dist.is_master_process(): + tensorboard_writer.add_scalars( + "learning_rate", + { + "visual": optimizer.param_groups[0]["lr"], + "common": optimizer.param_groups[-1]["lr"], + }, + iteration, + ) + tensorboard_writer.add_scalars( + "train", output_dict["loss_components"], iteration + ) + + # --------------------------------------------------------------------- + # VALIDATION + # --------------------------------------------------------------------- + if iteration % _A.checkpoint_every == 0: + if dist.is_master_process(): + checkpoint_manager.step(iteration) + + # All processes will wait till master process is done serializing. + dist.synchronize() + + torch.set_grad_enabled(False) + model.eval() + + # Accumulate different val loss components according to the type of + # pretraining model. + val_loss_counter: Counter = Counter() + + for val_iteration, val_batch in enumerate(val_dataloader, start=1): + for key in val_batch: + val_batch[key] = val_batch[key].to(device) + output_dict = model(val_batch) + + val_loss_counter.update(output_dict["loss_components"]) + + # Divide each loss component by number of val batches per GPU. + val_loss_dict = { + k: v / val_iteration for k, v in dict(val_loss_counter).items() + } + dist.average_across_processes(val_loss_dict) + torch.set_grad_enabled(True) + model.train() + + logger.info(f"Iteration: {iteration} [Val loss: {val_loss_dict}]") + if dist.is_master_process(): + tensorboard_writer.add_scalars("val", val_loss_dict, iteration) + + +if __name__ == "__main__": + _A = parser.parse_args() + + if _A.num_gpus_per_machine == 0: + main(_A) + else: + # This will launch `main` and set appropriate CUDA device (GPU ID) as + # per process (accessed in the beginning of `main`). + dist.launch( + main, + num_machines=_A.num_machines, + num_gpus_per_machine=_A.num_gpus_per_machine, + machine_rank=_A.machine_rank, + dist_url=_A.dist_url, + args=(_A, ), + ) diff --git a/virtex/scripts/redcaps_caption_decode.py b/virtex/scripts/redcaps_caption_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..d63b69ac6a13dc235a3ba4980dba582a9cd75be6 --- /dev/null +++ b/virtex/scripts/redcaps_caption_decode.py @@ -0,0 +1,140 @@ +import argparse +import json +import os +from typing import Any, Dict, List + +from loguru import logger +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +import wordsegment as ws + +from virtex.config import Config +from virtex.data import ImageDirectoryDataset +from virtex.factories import TokenizerFactory, PretrainingModelFactory +from virtex.utils.checkpointing import CheckpointManager +from virtex.utils.common import common_parser + +ws.load() + +# fmt: off +parser = common_parser( + description="Decode captions using a RedCaps-pretrained VirTex model." +) +parser.add_argument( + "--images", required=True, + help="Path to a directory containing image files to generate captions for." +) +parser.add_argument( + "--checkpoint-path", required=True, + help="Path to load checkpoint and run captioning evaluation." +) +parser.add_argument( + "--output", required=True, + help="Path to save predictions as a JSON file." +) +parser.add_argument( + "--subreddit-prompt", default=None, + help="Optional subreddit prompt for controllable subreddit-style captioning." +) +# fmt: on + + +def main(_A: argparse.Namespace): + + if _A.num_gpus_per_machine == 0: + # Set device as CPU if num_gpus_per_machine = 0. + device = torch.device("cpu") + else: + # Get the current device (this will be zero here by default). + device = torch.cuda.current_device() + + _C = Config(_A.config, _A.config_override) + + tokenizer = TokenizerFactory.from_config(_C) + + val_dataloader = DataLoader( + ImageDirectoryDataset(_A.images), + batch_size=_C.OPTIM.BATCH_SIZE, + num_workers=_A.cpu_workers, + pin_memory=True, + ) + # Initialize model from a checkpoint. + model = PretrainingModelFactory.from_config(_C).to(device) + CheckpointManager(model=model).load(_A.checkpoint_path) + model.eval() + + # Prepare subreddit prompt for the model if provided. + if _A.subreddit_prompt is not None: + + # Remove "r/" if provided. + _A.subreddit_prompt = _A.subreddit_prompt.replace("r/", "") + + # Word segmenting (e.g. "itookapicture" -> "i took a picture"). + _segments = " ".join(ws.segment(ws.clean(_A.subreddit_prompt))) + subreddit_tokens = ( + [model.sos_index] + + tokenizer.encode(_segments) + + [tokenizer.token_to_id("[SEP]")] + ) + else: + # Just seed the model with [SOS] + subreddit_tokens = [model.sos_index] + + # Shift the subreddit prompt to appropriate device. + subreddit_tokens = torch.tensor(subreddit_tokens, device=device).long() + + # Make a list of predictions to evaluate. + predictions: List[Dict[str, Any]] = [] + + for val_batch in tqdm(val_dataloader): + val_batch["image"] = val_batch["image"].to(device) + + # Add the subreddit tokens as decoding prompt to batch. + val_batch["decode_prompt"] = subreddit_tokens + + with torch.no_grad(): + output_dict = model(val_batch) + + for idx, (image_id, caption) in enumerate( + zip(val_batch["image_id"], output_dict["predictions"]) + ): + caption = caption.tolist() + + # Replace [SOS] index with "::" temporarily so it gets decoded. + if tokenizer.token_to_id("[SEP]") in caption: + sos_index = caption.index(tokenizer.token_to_id("[SEP]")) + caption[sos_index] = tokenizer.token_to_id("::") + + caption = tokenizer.decode(caption) + + # Separate out subreddit from the rest of caption. + if "::" in caption: + subreddit, rest_of_caption = caption.split("::") + subreddit = "".join(subreddit.split()) + rest_of_caption = rest_of_caption.strip() + else: + subreddit, rest_of_caption = "", caption + + predictions.append( + {"image_id": image_id, "subreddit": subreddit, "caption": rest_of_caption} + ) + + logger.info("Displaying first 25 caption predictions:") + for pred in predictions[:25]: + logger.info(f"{pred['image_id']} - r/{pred['subreddit']}:: {pred['caption']}") + + # Save predictions as a JSON file. + os.makedirs(os.path.dirname(_A.output), exist_ok=True) + json.dump(predictions, open(_A.output, "w")) + logger.info(f"Saved predictions to {_A.output}") + + +if __name__ == "__main__": + _A = parser.parse_args() + if _A.num_gpus_per_machine > 1: + raise ValueError("Using multiple GPUs is not supported for this script.") + + # No distributed training here, just a single process. + main(_A) diff --git a/virtex/scripts/redcaps_train.py b/virtex/scripts/redcaps_train.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c63010361c80ffcdca873a8166448aa2f359ef --- /dev/null +++ b/virtex/scripts/redcaps_train.py @@ -0,0 +1,172 @@ +import argparse +import os +import tempfile +from typing import Any + +from loguru import logger +import torch +from torch import nn +from torch.cuda import amp +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +# fmt: off +from virtex.config import Config +from virtex.factories import ( + PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory, + LRSchedulerFactory, +) +from virtex.utils.checkpointing import CheckpointManager +from virtex.utils.common import common_parser, common_setup, cycle +import virtex.utils.distributed as dist +from virtex.utils.timer import Timer + + +parser = common_parser( + description="Train a VirTex model (CNN + Transformer) on COCO Captions." +) +group = parser.add_argument_group("Checkpointing and Logging") +group.add_argument( + "--resume-from", default=None, + help="Path to a checkpoint to resume training from (if provided)." +) +group.add_argument( + "--checkpoint-every", type=int, default=2000, + help="Serialize model to a checkpoint after every these many iterations.", +) +group.add_argument( + "--log-every", type=int, default=50, + help="""Log training curves to tensorboard after every these many iterations + only master process logs averaged loss values across processes.""", +) +# fmt: on + + +def main(_A: argparse.Namespace): + + if _A.num_gpus_per_machine == 0: + # Set device as CPU if num_gpus_per_machine = 0. + device: Any = torch.device("cpu") + else: + # Get the current device as set for current distributed process. + # Check `launch` function in `virtex.utils.distributed` module. + device = torch.cuda.current_device() + + # Create a config object (this will be immutable) and perform common setup + # such as logging and setting up serialization directory. + _C = Config(_A.config, _A.config_override) + common_setup(_C, _A) + + # ------------------------------------------------------------------------- + # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER + # ------------------------------------------------------------------------- + + # fmt: off + train_dataset = PretrainingDatasetFactory.from_config(_C) + train_dataloader = DataLoader( + train_dataset, batch_size=None, shuffle=False, + num_workers=_A.cpu_workers, pin_memory=True, + ) + # fmt: on + + model = PretrainingModelFactory.from_config(_C).to(device) + optimizer = OptimizerFactory.from_config(_C, model.named_parameters()) + scheduler = LRSchedulerFactory.from_config(_C, optimizer) + + # ------------------------------------------------------------------------- + # BEFORE TRAINING STARTS + # ------------------------------------------------------------------------- + + # Create a gradient scaler for automatic mixed precision. + scaler = amp.GradScaler(enabled=_C.AMP) + + # Load checkpoint to resume training if specified. + if _A.resume_from is not None: + start_iteration = CheckpointManager( + model=model, optimizer=optimizer, scheduler=scheduler, + ).load(_A.resume_from) + else: + start_iteration = 0 + + # Create an iterator from dataloader to sample batches perpetually. + train_dataloader_iter = cycle(train_dataloader, device, start_iteration) + + # Wrap model in DDP if using more than one processes. + if dist.get_world_size() > 1: + dist.synchronize() + model = nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True + ) + + # Keep track of time per iteration and ETA. + timer = Timer( + start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS + ) + # Create tensorboard writer and checkpoint manager (only in master process). + if dist.is_master_process(): + tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) + tensorboard_writer.add_text("config", f"```\n{_C}\n```") + + checkpoint_manager = CheckpointManager( + _A.serialization_dir, + model=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + ) + + # ------------------------------------------------------------------------- + # TRAINING LOOP + # ------------------------------------------------------------------------- + for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1): + timer.tic() + optimizer.zero_grad() + batch = next(train_dataloader_iter) + + with amp.autocast(enabled=_C.AMP): + output_dict = model(batch) + loss = output_dict["loss"] + + scaler.scale(loss).backward() + + # First clip norm of gradients, and then perform optimizer step. + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM) + scaler.step(optimizer) + + scaler.update() + scheduler.step() + timer.toc() + + # --------------------------------------------------------------------- + # LOGGING + # --------------------------------------------------------------------- + if iteration % _A.log_every == 0: + logger.info( + f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]" + ) + if dist.is_master_process(): + tensorboard_writer.add_scalars( + "train", output_dict["loss_components"], iteration + ) + + if iteration % _A.checkpoint_every == 0 and dist.is_master_process(): + checkpoint_manager.step(iteration) + + +if __name__ == "__main__": + _A = parser.parse_args() + + if _A.num_gpus_per_machine == 0: + main(_A) + else: + # This will launch `main` and set appropriate CUDA device (GPU ID) as + # per process (accessed in the beginning of `main`). + dist.launch( + main, + num_machines=_A.num_machines, + num_gpus_per_machine=_A.num_gpus_per_machine, + machine_rank=_A.machine_rank, + dist_url=_A.dist_url, + args=(_A, ), + ) diff --git a/virtex/scripts/zero_shot_classification.py b/virtex/scripts/zero_shot_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8523f6012aab90a86f77c1ba235ad740848fe2 --- /dev/null +++ b/virtex/scripts/zero_shot_classification.py @@ -0,0 +1,171 @@ +import argparse +import json +import os +import random +from typing import Any, Dict, List + +from loguru import logger +import torch +from torch.utils.data import DataLoader, DistributedSampler +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm + +import wordsegment as ws + +from virtex.config import Config +from virtex.data import ZeroShotDataset + +from virtex.data.tokenizers import SentencePieceBPETokenizer + +from virtex.factories import TokenizerFactory, VisualBackboneFactory,TextualHeadFactory +from virtex.utils.checkpointing import CheckpointManager +from virtex.utils.common import common_parser +from virtex.utils.metrics import TopkAccuracy +import virtex.utils.distributed as dist + + +#importing classifier +from virtex.models.zero_shot_classification_eval import ZeroShotClassifier + +ws.load() + +# fmt: off +parser = common_parser( + description="""Run image captioning inference on a pretrained model, and/or + evaluate pretrained model on COCO Captions val2017 split.""" +) +parser.add_argument( + "--data-root", default=None, + help="""Path to a directory containing image files to generate captions for imagenet. + Default: COCO val2017 image directory as expected relative to project root.""" +) +parser.add_argument( + "--checkpoint-path", required=False, + help="Path to load checkpoint and run captioning evaluation." +) +parser.add_argument( + "--output", default=None, + help="Path to save predictions as a JSON file." +) +parser.add_argument( + "--calc-metrics", action="store_true", + help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions. + This flag should not be set when running inference on arbitrary images.""" +) + +parser.add_argument( + "--idx_label_dict", default=None, required=False, + help="""a dictionary that maps from lable index to label string for classification""" +) +parser.add_argument( + "--is_redcaps", default=None, required=False, + help="""a dictionary that maps from lable index to label string for""" +) +parser.add_argument( + "--prompt_cls_sos", default=None, required=False, + help="""a dictionary that maps from lable index to label string for""" +) +parser.add_argument( + "--prompt_sos_eos", default=None, required=False, + help="""a dictionary that maps from lable index to label string for""" +) +# fmt: on + +print("###########") +print(os.getcwd() ) +print("###########") + +tokenizer = SentencePieceBPETokenizer("datasets_1/vocab/common_32k.model") + +def main(_A: argparse.Namespace): + if _A.num_gpus_per_machine == 0: + # Set device as CPU if num_gpus_per_machine = 0. + device = torch.device("cpu") + else: + # Get the current device (this will be zero here by default). + device = torch.cuda.current_device() + + _C = Config(_A.config, _A.config_override) + + #tokenizer = TokenizerFactory.from_config(_C) + + if _A.data_root is None: + _A.data_root = os.path.join(_C.DATA.ROOT, "val2017") + + if _A.is_redcaps == 1: + model_dataset = 'redcaps' + else: + model_dataset = 'gcc or sbu' + + print(_A.idx_label_dict) + + val_dataset = ZeroShotDataset(data_root=_A.data_root, + split="test/", + label_map=_A.idx_label_dict, + tokenizer=tokenizer, + prompt_cls_sos=_A.prompt_cls_sos.replace("_", " "), + prompt_sos_eos=_A.prompt_sos_eos.replace("_", " ")) + + val_dataloader = DataLoader( + val_dataset, + batch_size= _C.OPTIM.BATCH_SIZE // dist.get_world_size(), + num_workers=_A.cpu_workers, + sampler=DistributedSampler( + val_dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + ), + pin_memory=True, + drop_last=False, + collate_fn=val_dataset.collate_fn, + ) + + # Initialize model from a checkpoint + visual = VisualBackboneFactory.from_config(_C) + textual = TextualHeadFactory.from_config(_C) + model = ZeroShotClassifier(visual,textual) + ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path) + model.to(device).eval() + + ## setup distributed training + if dist.get_world_size() > 1: + dist.synchronize() + model = nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True + ) + + top_1 = TopkAccuracy(top_k=1) + top_5 = TopkAccuracy(top_k=5) + batch_num = 0 + + + for val_iteration, val_batch in tqdm(enumerate(val_dataloader, start=1)): + val_batch["image"] = val_batch["image"].to(device) + val_batch["caption_tokens"] = val_batch["caption_tokens"].to(device) + val_batch["noitpac_tokens"] = val_batch["noitpac_tokens"] .to(device) + val_batch["caption_lengths"] = val_batch["caption_lengths"].to(device) + val_batch["label"] = val_batch["label"].to(device) + + with torch.no_grad(): + classification_losses = model(val_batch) + + batch_num+=1 + top_1(classification_losses, val_batch["label"]) + top_1_acc = top_1.get_metric(reset=False) + dist.average_across_processes(top_1_acc) + + top_5(classification_losses, val_batch["label"]) + top_5_acc = top_5.get_metric(reset=False) + dist.average_across_processes(top_5_acc) + + logger.info(f"Iter: {val_iteration} | Top-1 accuracy: {top_1_acc} | Top-5 accuracy: {top_5_acc}") + + + +if __name__ == "__main__": + _A = parser.parse_args() + #if _A.num_gpus_per_machine > 1: + # raise ValueError("Using multiple GPUs is not supported for this script.") + + # No distributed training here, just a single process. + main(_A) diff --git a/virtex/setup.py b/virtex/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..fc715695a0b1e6eb83a52205c9fec3224131bb21 --- /dev/null +++ b/virtex/setup.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +import glob +import os +from setuptools import setup +import shutil +from typing import List + + +def get_model_zoo_configs() -> List[str]: + """ + Return a list of configs to include in package for model zoo. Copy over + these configs inside virtex/model_zoo. + """ + + # Use absolute paths while symlinking. + source_configs_dir = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs" + ) + destination = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "virtex", "model_zoo", "configs" + ) + # Symlink the config directory inside package to have a cleaner pip install. + + # Remove stale symlink/directory from a previous build. + if os.path.exists(source_configs_dir): + if os.path.islink(destination): + os.unlink(destination) + elif os.path.isdir(destination): + shutil.rmtree(destination) + + if not os.path.exists(destination): + try: + os.symlink(source_configs_dir, destination) + except OSError: + # Fall back to copying if symlink fails: ex. on Windows. + shutil.copytree(source_configs_dir, destination) + + config_paths = glob.glob("configs/**/*.yaml", recursive=True) + return config_paths + + +setup( + name="virtex", + version="1.1.0", + author="Karan Desai and Justin Johnson", + description="VirTex: Learning Visual Representations with Textual Annotations", + package_data={"virtex.model_zoo": get_model_zoo_configs()}, + python_requires=">=3.6", + license="MIT", + zip_safe=True, +) diff --git a/virtex/virtex/__init__.py b/virtex/virtex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/virtex/virtex/config.py b/virtex/virtex/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2e42bafb758db6a09e507b43cd0faa82cd9c7259 --- /dev/null +++ b/virtex/virtex/config.py @@ -0,0 +1,260 @@ +from typing import Any, List, Optional + +from fvcore.common.config import CfgNode as CN + + +class Config(object): + r""" + This class provides package-wide configuration management. It is a + nested dict-like structure with nested keys accessible as attributes. It + contains sensible default values, which can be modified by (first) a YAML + file and (second) a list of attributes and values. + + An instantiated object is immutable: modifying any attribute is illegal. + You must override required parameter values either through ``config_file`` + or ``override_list`` arguments. For adding more parameters at runtime + (based on existing parameters), modify :meth:`add_derived_params`. + + Parameters + ---------- + config_file: str + Path to a YAML file containing configuration parameters to override. + config_override: List[Any], optional (default = []) + A list of sequential attributes and values of parameters to override. + This happens after overriding from YAML file. + + Examples + -------- + Let a YAML file named "config.yaml" specify these parameters to override:: + + OPTIM: + BATCH_SIZE: 512 + LR: 0.01 + + >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 1024]) + >>> _C.LR # default: 0.001 + 0.01 + >>> _C.OPTIM.BATCH_SIZE # default: 256, file: 512 + 1024 + """ + + def __init__( + self, config_file: Optional[str] = None, override_list: List[Any] = [] + ): + _C = CN() + + # Random seed for NumPy and PyTorch, important for reproducibility. + _C.RANDOM_SEED = 0 + # Train with Automatic Mixed Precision (native PyTorch). + _C.AMP = True + # Set CUDNN deterministic flag (torch.backends.cudnn.deterministic). + # Setting this will ensure exact results on every run at the cost of + # little slowdown. Good for debugging. + _C.CUDNN_DETERMINISTIC = False + # Set CUDNN benchmark flag (torch.backends.cudnn.benchmark). Enables + # CUDNN to select fastest implementation for operations based on GPU. + # May change results (in decimals) on different hardware, but faster + # to train. Turn off while debugging. + _C.CUDNN_BENCHMARK = True + + # --------------------------------------------------------------------- + # Data paths and parameters related to dataloading. + # --------------------------------------------------------------------- + _C.DATA = CN() + + # Path to the dataset root, which structure as per README. Path is + # assumed to be relative to project root. + _C.DATA.ROOT = "datasets/coco" + # Path to .model file generated by ``sentencepiece``. + _C.DATA.TOKENIZER_MODEL = "datasets/vocab/coco_10k.model" + + # Handy config params for vocab size and indices of special tokens. + # While these can be picked up from the tokenizer, having these in + # the config makes it easy to create a model without instantiating too + # many tokenizer instances (especially when not needed, e.g. model zoo). + # These must match according to what's present in ``TOKENIZER_VOCAB`` + # and ``TOKENIZER_MODEL`` above. + _C.DATA.VOCAB_SIZE = 10000 + # Index of out-of-vocabulary (and padding) token. + _C.DATA.UNK_INDEX = 0 + # Index of the start-of-sentence [SOS] token. + _C.DATA.SOS_INDEX = 1 + # Index of the end-of-sentence [EOS] token. + _C.DATA.EOS_INDEX = 2 + # Index of the word masking token. While not used for captioning, having + # this extra token makes it possible to train an MLM model without + # re-creating a new vocab mapping. + _C.DATA.MASK_INDEX = 3 + + # Size of the image (square) to crop from original input image. + _C.DATA.IMAGE_CROP_SIZE = 224 + # Maximum length of input caption (number of tokens). + # Longer captions will be truncated up to this length. + _C.DATA.MAX_CAPTION_LENGTH = 30 + + # COCO Captions has five captions per image. If ``True``, training will + # use one random caption per image (data efficiency ablations). + _C.DATA.USE_SINGLE_CAPTION = False + # Percentage of dataset to use for training (data efficiency ablations). + _C.DATA.USE_PERCENTAGE = 100.0 + + # List of image transforms (pre-processing and data augmentation) to be + # applied sequentially (always or randomly) during training and + # validation. Refer ``virtex/facetories.py`` for all possible transforms. + _C.DATA.IMAGE_TRANSFORM_TRAIN = [ + "random_resized_crop", + "horizontal_flip", + "color_jitter", + "normalize", + ] + _C.DATA.IMAGE_TRANSFORM_VAL = [ + "smallest_resize", + "center_crop", + "normalize", + ] + + # Hyper-parameters for masked LM pretraining task. These are only used + # when ``MODEL.NAME`` is "masked_lm". + _C.DATA.MASKED_LM = CN() + # Fraction of tokens to choose for masking, this must be less than 1. + _C.DATA.MASKED_LM.MASK_PROPORTION = 0.15 + # Probability to replace chosen tokens with [MASK] token. + _C.DATA.MASKED_LM.MASK_PROBABILITY = 0.85 + # Probability to replace chosen tokens with a random token. + _C.DATA.MASKED_LM.REPLACE_PROBABILITY = 0.10 + + # --------------------------------------------------------------------- + # Model architecture: visual backbone and textual head. + # --------------------------------------------------------------------- + _C.MODEL = CN() + + # Name of model, based on pretraining task. + # Possible choices: {"token_classification", "multilabel_classification", + # "captioning", "bicaptioning", "masked_lm", "virtex"} + _C.MODEL.NAME = "virtex" + + _C.MODEL.VISUAL = CN() + # Name of visual backbone. Possible choices: {"blind", "torchvision"} + # Models from torchvision can be specified as shown below. + _C.MODEL.VISUAL.NAME = "torchvision::resnet50" + # Number of channels in pooled spatial features of visual backbone. + _C.MODEL.VISUAL.FEATURE_SIZE = 2048 + # Whether to load ImageNet pretrained weights into visual backbone. + _C.MODEL.VISUAL.PRETRAINED = False + # Whether to keep visual backbone frozen and train only textual head. + _C.MODEL.VISUAL.FROZEN = False + + _C.MODEL.TEXTUAL = CN() + # Name of textual head. Set to "none" for MODEL.NAME = "*_classification". + # Possible choices: {"transdec_postnorm", "transdec_prenorm"}. + # Architectural hyper-parameters are specified as shown above. + _C.MODEL.TEXTUAL.NAME = "transdec_postnorm::L1_H2048_A32_F8192" + # L = Number of layers in the transformer. + # H = Hidden size of the transformer (embeddings, attention features). + # A = Number of attention heads in the transformer. + # F = Size of feedforward layers in the transformer. + # Typically, we have (A = H / 64) and (F = 4 * H). + + # Dropout probability for embedding, hidden features in textual head. + _C.MODEL.TEXTUAL.DROPOUT = 0.1 + + # Apply label smoothing to targets for (cross entropy) loss computation. + _C.MODEL.LABEL_SMOOTHING = 0.0 + + _C.MODEL.DECODER = CN() + # What algorithm to use for decoding. Supported values: {"beam_search", + # "nucleus_sampling"}. + _C.MODEL.DECODER.NAME = "beam_search" + # Number of beams to decode (1 = greedy decoding). Ignored when decoding + # through nucleus sampling. + _C.MODEL.DECODER.BEAM_SIZE = 5 + # Size of nucleus for sampling predictions. Ignored when decoding through + # beam search. + _C.MODEL.DECODER.NUCLEUS_SIZE = 0.9 + # Maximum length of decoded caption. Decoding may end earlier when [EOS] + # token is sampled. + _C.MODEL.DECODER.MAX_DECODING_STEPS = _C.DATA.MAX_CAPTION_LENGTH + + # --------------------------------------------------------------------- + # Optimization hyper-parameters, default values are for pretraining + # our best model on bicaptioning task (COCO Captions). + # --------------------------------------------------------------------- + _C.OPTIM = CN() + + # Name of optimizer to use. Supported values: {"sgd", "adamw"}. + # AdamW uses default (beta1, beta2) values from PyTorch. + _C.OPTIM.OPTIMIZER_NAME = "sgd" + # Momentum co-efficient for SGD. Ignored for AdamW. + _C.OPTIM.SGD_MOMENTUM = 0.9 + # Weight decay co-efficient for the optimizer. + _C.OPTIM.WEIGHT_DECAY = 0.0001 + # Regex pattern of params for which there will be no weight decay. + _C.OPTIM.NO_DECAY = ".*textual.(embedding|transformer).*(norm.*|bias)" + # Max gradient norm for clipping to avoid exploding gradients. + _C.OPTIM.CLIP_GRAD_NORM = 10.0 + + # Wrap our optimizer with Lookahead (https://arxiv.org/abs/1907.08610). + _C.OPTIM.LOOKAHEAD = CN() + _C.OPTIM.LOOKAHEAD.USE = True + _C.OPTIM.LOOKAHEAD.ALPHA = 0.5 + _C.OPTIM.LOOKAHEAD.STEPS = 5 + + # We set different learning rates for CNN (visual backbone) and rest of + # the model. CNN LR is typically much higher for training from scratch. + # Both LRs undergo same warmup-decay schedules. + + # Total batch size (will be distributed evenly across GPUs). + _C.OPTIM.BATCH_SIZE = 256 + # Max learning rate for CNN (visual backbone). + _C.OPTIM.CNN_LR = 0.2 + # Max learning rate for rest of the model. + _C.OPTIM.LR = 0.001 + # Number of iterations to train for, batches are randomly sampled. + _C.OPTIM.NUM_ITERATIONS = 500000 + + # Number of steps at the start of training for linear LR warmup. + _C.OPTIM.WARMUP_STEPS = 10000 + # Learning rate annealing schedule for decay after warmup. + # Possible choices: {"none", "linear", "cosine", "multistep"}. + _C.OPTIM.LR_DECAY_NAME = "cosine" + # Steps to decay LR for "multistep" schedule. + _C.OPTIM.LR_STEPS = [] + # Factor to multiply with LR for "multistep" schedule. + _C.OPTIM.LR_GAMMA = 0.1 + + # Override parameter values from YAML file first, then from override + # list, then add derived params. + self._C = _C + if config_file is not None: + self._C.merge_from_file(config_file) + self._C.merge_from_list(override_list) + + self.add_derived_params() + + # Make an instantiated object of this class immutable. + self._C.freeze() + + def add_derived_params(self): + r"""Add parameters with values derived from existing parameters.""" + + # We don't have any such cases so far. + pass + + def dump(self, file_path: str): + r"""Save config at the specified file path. + + Parameters + ---------- + file_path: str + (YAML) path to save config at. + """ + self._C.dump(stream=open(file_path, "w")) + + def __getattr__(self, attr: str): + return self._C.__getattr__(attr) + + def __str__(self): + return self._C.__str__() + + def __repr__(self): + return self._C.__repr__() diff --git a/virtex/virtex/data/__init__.py b/virtex/virtex/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91942e1e2b2e44e40433ae73109b5869b721c743 --- /dev/null +++ b/virtex/virtex/data/__init__.py @@ -0,0 +1,26 @@ +from .datasets.captioning import CaptioningDataset +from .datasets.classification import ( + TokenClassificationDataset, + MultiLabelClassificationDataset, +) +from .datasets.masked_lm import MaskedLmDataset +from .datasets.downstream import ( + ImageNetDataset, + INaturalist2018Dataset, + VOC07ClassificationDataset, + ImageDirectoryDataset, +) +from .datasets.redcaps import TarfileDataset + + +__all__ = [ + "CaptioningDataset", + "TokenClassificationDataset", + "MultiLabelClassificationDataset", + "MaskedLmDataset", + "ImageDirectoryDataset", + "ImageNetDataset", + "INaturalist2018Dataset", + "VOC07ClassificationDataset", + "TarfileDataset", +] diff --git a/virtex/virtex/data/datasets/captioning.py b/virtex/virtex/data/datasets/captioning.py new file mode 100644 index 0000000000000000000000000000000000000000..def6d4b2124ff04722f93dadc4798895604aec07 --- /dev/null +++ b/virtex/virtex/data/datasets/captioning.py @@ -0,0 +1,123 @@ +import os +import random +from typing import Callable, Dict, List + +import albumentations as alb +import numpy as np +import torch +from torch.utils.data import Dataset + +from virtex.data.readers import LmdbReader +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.data import transforms as T + + +class CaptioningDataset(Dataset): + r""" + A dataset which provides image-caption (forward and backward) pairs from + a serialized LMDB file (COCO Captions in this codebase). This is used for + pretraining tasks which use captions - bicaptioning, forward captioning and + token classification. + + This dataset also supports training on a randomly selected subset of the + full dataset. + + Parameters + ---------- + data_root: str, optional (default = "datasets/coco") + Path to the dataset root directory. This must contain the serialized + LMDB files (for COCO ``train2017`` and ``val2017`` splits). + split: str, optional (default = "train") + Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``. + tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer + A tokenizer which has the mapping between word tokens and their + integer IDs. + image_tranform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM) + A list of transformations, from either `albumentations + `_ or :mod:`virtex.data.transforms` + to be applied on the image. + max_caption_length: int, optional (default = 30) + Maximum number of tokens to keep in output caption tokens. Extra tokens + will be trimmed from the right end of the token list. + use_single_caption: bool, optional (default = False) + COCO Captions provides five captions per image. If this is True, only + one fixed caption per image is use fo training (used for an ablation). + percentage: float, optional (default = 100.0) + Randomly sample this much percentage of full dataset for training. + """ + + def __init__( + self, + data_root: str, + split: str, + tokenizer: SentencePieceBPETokenizer, + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + max_caption_length: int = 30, + use_single_caption: bool = False, + percentage: float = 100.0, + ): + lmdb_path = os.path.join(data_root, f"serialized_{split}.lmdb") + self.reader = LmdbReader(lmdb_path, percentage=percentage) + + self.image_transform = image_transform + self.caption_transform = alb.Compose( + [ + T.NormalizeCaption(), + T.TokenizeCaption(tokenizer), + T.TruncateCaptionTokens(max_caption_length), + ] + ) + self.use_single_caption = use_single_caption + self.padding_idx = tokenizer.token_to_id("") + + def __len__(self): + return len(self.reader) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + + image_id, image, captions = self.reader[idx] + + # Pick a random caption or first caption and process (transform) it. + if self.use_single_caption: + caption = captions[0] + else: + caption = random.choice(captions) + + # Transform image-caption pair and convert image from HWC to CHW format. + # Pass in caption to image_transform due to paired horizontal flip. + # Caption won't be tokenized/processed here. + image_caption = self.image_transform(image=image, caption=caption) + image, caption = image_caption["image"], image_caption["caption"] + image = np.transpose(image, (2, 0, 1)) + + caption_tokens = self.caption_transform(caption=caption)["caption"] + return { + "image_id": torch.tensor(image_id, dtype=torch.long), + "image": torch.tensor(image, dtype=torch.float), + "caption_tokens": torch.tensor(caption_tokens, dtype=torch.long), + "noitpac_tokens": torch.tensor(caption_tokens, dtype=torch.long).flip(0), + "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long), + } + + def collate_fn( + self, data: List[Dict[str, torch.Tensor]] + ) -> Dict[str, torch.Tensor]: + + # Pad `caption_tokens` and `masked_labels` up to this length. + caption_tokens = torch.nn.utils.rnn.pad_sequence( + [d["caption_tokens"] for d in data], + batch_first=True, + padding_value=self.padding_idx, + ) + noitpac_tokens = torch.nn.utils.rnn.pad_sequence( + [d["noitpac_tokens"] for d in data], + batch_first=True, + padding_value=self.padding_idx, + ) + return { + "image_id": torch.stack([d["image_id"] for d in data], dim=0), + "image": torch.stack([d["image"] for d in data], dim=0), + "caption_tokens": caption_tokens, + "noitpac_tokens": noitpac_tokens, + "caption_lengths": torch.stack([d["caption_lengths"] for d in data]), + } diff --git a/virtex/virtex/data/datasets/classification.py b/virtex/virtex/data/datasets/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..886033e9e0f22a72135520dfe0475491072cfdb5 --- /dev/null +++ b/virtex/virtex/data/datasets/classification.py @@ -0,0 +1,202 @@ +from collections import defaultdict +import glob +import json +import os +import random +from typing import Any, Callable, Dict, List, Tuple + +import albumentations as alb +import cv2 +import numpy as np +import torch +from torch.utils.data import Dataset + +from virtex.data.readers import LmdbReader +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.data import transforms as T + + +class TokenClassificationDataset(Dataset): + r""" + A dataset which provides image-labelset pairs from a serialized LMDB file + (COCO Captions in this codebase). the set of caption tokens (unordered) + is treated as a labelset. Used for token classification pretraining task. + + Parameters + ---------- + data_root: str, optional (default = "datasets/coco") + Path to the dataset root directory. This must contain the serialized + LMDB files (for COCO ``train2017`` and ``val2017`` splits). + split: str, optional (default = "train") + Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``. + tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer + A tokenizer which has the mapping between word tokens and their + integer IDs. + image_tranform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM) + A list of transformations, from either `albumentations + `_ or :mod:`virtex.data.transforms` + to be applied on the image. + max_caption_length: int, optional (default = 30) + Maximum number of tokens to keep in output caption tokens. Extra tokens + will be trimmed from the right end of the token list. + """ + + def __init__( + self, + data_root: str, + split: str, + tokenizer: SentencePieceBPETokenizer, + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + max_caption_length: int = 30, + ): + lmdb_path = os.path.join(data_root, f"serialized_{split}.lmdb") + self.reader = LmdbReader(lmdb_path) + + self.image_transform = image_transform + self.caption_transform = alb.Compose( + [ + T.NormalizeCaption(), + T.TokenizeCaption(tokenizer), + T.TruncateCaptionTokens(max_caption_length), + ] + ) + self.padding_idx = tokenizer.token_to_id("") + + def __len__(self): + return len(self.reader) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + + image_id, image, captions = self.reader[idx] + + # Pick a random caption and then transform it. + caption = random.choice(captions) + + # Transform image-caption pair and convert image from HWC to CHW format. + # Pass in caption to image_transform due to paired horizontal flip. + # Caption won't be tokenized/processed here. + image_caption = self.image_transform(image=image, caption=caption) + image, caption = image_caption["image"], image_caption["caption"] + image = np.transpose(image, (2, 0, 1)) + + caption_tokens = self.caption_transform(caption=caption)["caption"] + return { + "image_id": torch.tensor(image_id, dtype=torch.long), + "image": torch.tensor(image, dtype=torch.float), + "labels": torch.tensor(caption_tokens, dtype=torch.long), + } + + def collate_fn( + self, data: List[Dict[str, torch.Tensor]] + ) -> Dict[str, torch.Tensor]: + + labels = torch.nn.utils.rnn.pad_sequence( + [d["labels"] for d in data], + batch_first=True, + padding_value=self.padding_idx, + ) + return { + "image_id": torch.stack([d["image_id"] for d in data], dim=0), + "image": torch.stack([d["image"] for d in data], dim=0), + "labels": labels, + } + + +class MultiLabelClassificationDataset(Dataset): + r""" + A dataset which provides image-labelset pairs from COCO instance annotation + files. This is used for multilabel classification pretraining task. + + Parameters + ---------- + data_root: str, optional (default = "datasets/coco") + Path to the dataset root directory. This must contain images and + annotations (``train2017``, ``val2017`` and ``annotations`` directories). + split: str, optional (default = "train") + Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``. + image_tranform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM) + A list of transformations, from either `albumentations + `_ or :mod:`virtex.data.transforms` + to be applied on the image. + """ + + def __init__( + self, + data_root: str, + split: str, + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + ): + self.image_transform = image_transform + + # Make a tuple of image id and its filename, get image_id from its + # filename (assuming directory has images with names in COCO 2017 format). + image_filenames = glob.glob(os.path.join(data_root, f"{split}2017", "*.jpg")) + self.id_filename: List[Tuple[int, str]] = [ + (int(os.path.basename(name)[:-4]), name) for name in image_filenames + ] + # Load the instance (bounding box and mask) annotations. + _annotations = json.load( + open(os.path.join(data_root, "annotations", f"instances_{split}2017.json")) + ) + # Make a mapping between COCO category id and its index, to make IDs + # consecutive, else COCO has 80 classes with IDs 1-90. Start index from 1 + # as 0 is reserved for background (padding idx). + _category_ids = { + ann["id"]: index + 1 for index, ann in enumerate(_annotations["categories"]) + } + # Mapping from image ID to list of unique category IDs (indices as above) + # in corresponding image. + self._labels: Dict[str, Any] = defaultdict(list) + + for ann in _annotations["annotations"]: + self._labels[ann["image_id"]].append(_category_ids[ann["category_id"]]) + + # De-duplicate and drop empty labels, we only need to do classification. + self._labels = { + _id: list(set(lbl)) for _id, lbl in self._labels.items() if len(lbl) > 0 + } + # Filter out image IDs which didn't have any labels. + self.id_filename = [ + (t[0], t[1]) for t in self.id_filename if t[0] in self._labels + ] + # Padding while forming a batch, because images may have variable number + # of instances. We do not need padding index from tokenizer: COCO has + # category ID 0 as background, conventionally. + self.padding_idx = 0 + + def __len__(self): + return len(self.id_filename) + + def __getitem__(self, idx: int): + # Get image ID and filename. + image_id, filename = self.id_filename[idx] + + # Open image from path and apply transformation, convert to CHW format. + image = cv2.imread(filename) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = self.image_transform(image=image)["image"] + image = np.transpose(image, (2, 0, 1)) + + # Get a list of instances present in the image. + labels = self._labels[image_id] + + return { + "image_id": torch.tensor(image_id, dtype=torch.long), + "image": torch.tensor(image, dtype=torch.float), + "labels": torch.tensor(labels, dtype=torch.long), + } + + def collate_fn( + self, data: List[Dict[str, torch.Tensor]] + ) -> Dict[str, torch.Tensor]: + + labels = torch.nn.utils.rnn.pad_sequence( + [d["labels"] for d in data], + batch_first=True, + padding_value=self.padding_idx, + ) + return { + "image_id": torch.stack([d["image_id"] for d in data], dim=0), + "image": torch.stack([d["image"] for d in data], dim=0), + "labels": labels, + } diff --git a/virtex/virtex/data/datasets/downstream.py b/virtex/virtex/data/datasets/downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..45625a5abc3be717d4bcf9f0fd63d94905b4c2ae --- /dev/null +++ b/virtex/virtex/data/datasets/downstream.py @@ -0,0 +1,286 @@ +from collections import defaultdict +import glob +import json +import os +from typing import Callable, Dict, List, Tuple + +import cv2 +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision.datasets import ImageNet + +from virtex.data import transforms as T + + +class ImageNetDataset(ImageNet): + r""" + Simple wrapper over torchvision's ImageNet dataset with a feature to support + restricting dataset size for semi-supervised learning setup (data-efficiency + ablations). + + We also handle image transform here instead of passing to super class. + + Parameters + ---------- + data_root: str, optional (default = "datasets/imagenet") + Path to the dataset root directory. This must contain directories + ``train``, ``val`` with per-category sub-directories. + split: str, optional (default = "train") + Which split to read from. One of ``{"train", "val"}``. + image_tranform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM) + A list of transformations, from either `albumentations + `_ or :mod:`virtex.data.transforms` + to be applied on the image. + percentage: int, optional (default = 100) + Percentage of dataset to keep. This dataset retains first K% of images + per class to retain same class label distribution. This is 100% by + default, and will be ignored if ``split`` is ``val``. + """ + + def __init__( + self, + data_root: str = "datasets/imagenet", + split: str = "train", + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + percentage: float = 100, + ): + super().__init__(data_root, split) + assert percentage > 0, "Cannot load dataset with 0 percent original size." + + self.image_transform = image_transform + + # Super class has `imgs` list and `targets` list. Make a dict of + # class ID to index of instances in these lists and pick first K%. + if split == "train" and percentage < 100: + label_to_indices: Dict[int, List[int]] = defaultdict(list) + for index, target in enumerate(self.targets): + label_to_indices[target].append(index) + + # Trim list of indices per label. + for label in label_to_indices: + retain = int(len(label_to_indices[label]) * (percentage / 100)) + label_to_indices[label] = label_to_indices[label][:retain] + + # Trim `self.imgs` and `self.targets` as per indices we have. + retained_indices: List[int] = [ + index + for indices_per_label in label_to_indices.values() + for index in indices_per_label + ] + # Shorter dataset with size K% of original dataset, but almost same + # class label distribution. super class will handle the rest. + self.imgs = [self.imgs[i] for i in retained_indices] + self.targets = [self.targets[i] for i in retained_indices] + self.samples = self.imgs + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + image, label = super().__getitem__(idx) + + # Apply transformation to image and convert to CHW format. + image = self.image_transform(image=np.array(image))["image"] + image = np.transpose(image, (2, 0, 1)) + return { + "image": torch.tensor(image, dtype=torch.float), + "label": torch.tensor(label, dtype=torch.long), + } + + @staticmethod + def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + return { + "image": torch.stack([d["image"] for d in data], dim=0), + "label": torch.stack([d["label"] for d in data], dim=0), + } + + +class INaturalist2018Dataset(Dataset): + r""" + A dataset which provides image-label pairs from the iNaturalist 2018 dataset. + + Parameters + ---------- + data_root: str, optional (default = "datasets/inaturalist") + Path to the dataset root directory. This must contain images and + annotations (``train2018``, ``val2018`` and ``annotations`` directories). + split: str, optional (default = "train") + Which split to read from. One of ``{"train", "val"}``. + image_tranform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM) + A list of transformations, from either `albumentations + `_ or :mod:`virtex.data.transforms` + to be applied on the image. + """ + + def __init__( + self, + data_root: str = "datasets/inaturalist", + split: str = "train", + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + ): + self.split = split + self.image_transform = image_transform + + annotations = json.load( + open(os.path.join(data_root, "annotations", f"{split}2018.json")) + ) + # Make a list of image IDs to file paths. + self.image_id_to_file_path = { + ann["id"]: os.path.join(data_root, ann["file_name"]) + for ann in annotations["images"] + } + # For a list of instances: (image_id, category_id) tuples. + self.instances = [ + (ann["image_id"], ann["category_id"]) + for ann in annotations["annotations"] + ] + + def __len__(self): + return len(self.instances) + + def __getitem__(self, idx: int): + image_id, label = self.instances[idx] + image_path = self.image_id_to_file_path[image_id] + + # Open image from path and apply transformation, convert to CHW format. + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = self.image_transform(image=image)["image"] + image = np.transpose(image, (2, 0, 1)) + + return { + "image": torch.tensor(image, dtype=torch.float), + "label": torch.tensor(label, dtype=torch.long), + } + + @staticmethod + def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + return { + "image": torch.stack([d["image"] for d in data], dim=0), + "label": torch.stack([d["label"] for d in data], dim=0), + } + + +class VOC07ClassificationDataset(Dataset): + r""" + A dataset which provides image-label pairs from the PASCAL VOC 2007 dataset. + + Parameters + ---------- + data_root: str, optional (default = "datasets/VOC2007") + Path to the dataset root directory. This must contain directories + ``Annotations``, ``ImageSets`` and ``JPEGImages``. + split: str, optional (default = "trainval") + Which split to read from. One of ``{"trainval", "test"}``. + image_tranform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM) + A list of transformations, from either `albumentations + `_ or :mod:`virtex.data.transforms` + to be applied on the image. + """ + + def __init__( + self, + data_root: str = "datasets/VOC2007", + split: str = "trainval", + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + ): + self.split = split + self.image_transform = image_transform + + ann_paths = sorted( + glob.glob(os.path.join(data_root, "ImageSets", "Main", f"*_{split}.txt")) + ) + # A list like; ["aeroplane", "bicycle", "bird", ...] + self.class_names = [ + os.path.basename(path).split("_")[0] for path in ann_paths + ] + + # We will construct a map for image name to a list of + # shape: (num_classes, ) and values as one of {-1, 0, 1}. + # 1: present, -1: not present, 0: ignore. + image_names_to_labels: Dict[str, torch.Tensor] = defaultdict( + lambda: -torch.ones(len(self.class_names), dtype=torch.int32) + ) + for cls_num, ann_path in enumerate(ann_paths): + with open(ann_path, "r") as fopen: + for line in fopen: + img_name, orig_label_str = line.strip().split() + orig_label = int(orig_label_str) + + # In VOC data, -1 (not present): set to 0 as train target + # In VOC data, 0 (ignore): set to -1 as train target. + orig_label = ( + 0 if orig_label == -1 else -1 if orig_label == 0 else 1 + ) + image_names_to_labels[img_name][cls_num] = orig_label + + # Convert the dict to a list of tuples for easy indexing. + # Replace image name with full image path. + self.instances: List[Tuple[str, torch.Tensor]] = [ + ( + os.path.join(data_root, "JPEGImages", f"{image_name}.jpg"), + label.tolist(), + ) + for image_name, label in image_names_to_labels.items() + ] + + def __len__(self): + return len(self.instances) + + def __getitem__(self, idx: int): + image_path, label = self.instances[idx] + + # Open image from path and apply transformation, convert to CHW format. + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = self.image_transform(image=image)["image"] + image = np.transpose(image, (2, 0, 1)) + + return { + "image": torch.tensor(image, dtype=torch.float), + "label": torch.tensor(label, dtype=torch.long), + } + + @staticmethod + def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + return { + "image": torch.stack([d["image"] for d in data], dim=0), + "label": torch.stack([d["label"] for d in data], dim=0), + } + + +class ImageDirectoryDataset(Dataset): + r""" + A dataset which reads images from any directory. This class is useful to + run image captioning inference on our models with any arbitrary images. + + Parameters + ---------- + data_root: str + Path to a directory containing images. + image_tranform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM) + A list of transformations, from either `albumentations + `_ or :mod:`virtex.data.transforms` + to be applied on the image. + """ + + def __init__( + self, data_root: str, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM + ): + self.image_paths = glob.glob(os.path.join(data_root, "*")) + self.image_transform = image_transform + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx: int): + image_path = self.image_paths[idx] + # Remove extension from image name to use as image_id. + image_id = os.path.splitext(os.path.basename(image_path))[0] + + # Open image from path and apply transformation, convert to CHW format. + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = self.image_transform(image=image)["image"] + image = np.transpose(image, (2, 0, 1)) + + # Return image id as string so collate_fn does not cast to torch.tensor. + return {"image_id": str(image_id), "image": torch.tensor(image)} diff --git a/virtex/virtex/data/datasets/masked_lm.py b/virtex/virtex/data/datasets/masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..7721321a8dce4e35ed962f894d6c4f6b4c22b8cf --- /dev/null +++ b/virtex/virtex/data/datasets/masked_lm.py @@ -0,0 +1,132 @@ +import math +import os +import random +from typing import Callable, Dict, List + +import albumentations as alb +import numpy as np +import torch +from torch.utils.data import Dataset + +from virtex.data.readers import LmdbReader +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.data import transforms as T + + +class MaskedLmDataset(Dataset): + def __init__( + self, + data_root: str, + split: str, + tokenizer: SentencePieceBPETokenizer, + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + mask_proportion: float = 0.15, + mask_probability: float = 0.80, + replace_probability: float = 0.10, + max_caption_length: int = 30, + use_single_caption: bool = False, + percentage: float = 100.0, + ): + lmdb_path = os.path.join(data_root, f"serialized_{split}.lmdb") + self.reader = LmdbReader(lmdb_path, percentage=percentage) + + self.image_transform = image_transform + self.caption_transform = alb.Compose( + [ + T.NormalizeCaption(), + T.TokenizeCaption(tokenizer), + T.TruncateCaptionTokens(max_caption_length), + ] + ) + self.use_single_caption = use_single_caption + self.padding_idx = tokenizer.token_to_id("") + + # Handles to commonly used variables for word masking. + self._vocab_size = tokenizer.get_vocab_size() + self._mask_index = tokenizer.token_to_id("[MASK]") + self._mask_proportion = mask_proportion + self._mask_prob = mask_probability + self._repl_prob = replace_probability + + def __len__(self): + return len(self.reader) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + + image_id, image, captions = self.reader[idx] + + # Pick a random caption or first caption and process (transform) it. + if self.use_single_caption: + caption = captions[0] + else: + caption = random.choice(captions) + + # Transform image-caption pair and convert image from HWC to CHW format. + # Pass in caption to image_transform due to paired horizontal flip. + # Caption won't be tokenized/processed here. + image_caption = self.image_transform(image=image, caption=caption) + image, caption = image_caption["image"], image_caption["caption"] + image = np.transpose(image, (2, 0, 1)) + + caption_tokens = self.caption_transform(caption=caption)["caption"] + + # --------------------------------------------------------------------- + # Mask some tokens randomly. + # --------------------------------------------------------------------- + masked_labels = [self.padding_idx] * len(caption_tokens) + + # Indices in `caption_tokens` list to mask (minimum 1 index). + # Leave out first and last indices (boundary tokens). + tokens_to_mask: List[int] = random.sample( + list(range(1, len(caption_tokens) - 1)), + math.ceil((len(caption_tokens) - 2) * self._mask_proportion), + ) + for i in tokens_to_mask: + # Whether to replace with [MASK] or random word. + # If only one token, always [MASK]. + if len(tokens_to_mask) == 1: + masked_labels[i] = caption_tokens[i] + caption_tokens[i] = self._mask_index + else: + _flag: float = random.random() + if _flag <= self._mask_prob + self._repl_prob: + if _flag <= self._mask_prob: + masked_labels[i] = caption_tokens[i] + caption_tokens[i] = self._mask_index + else: + caption_tokens[i] = self._random_token_index() + # --------------------------------------------------------------------- + + return { + "image_id": torch.tensor(image_id, dtype=torch.long), + "image": torch.tensor(image, dtype=torch.float), + "caption_tokens": torch.tensor(caption_tokens, dtype=torch.long), + "masked_labels": torch.tensor(masked_labels, dtype=torch.long), + "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long), + } + + def collate_fn( + self, data: List[Dict[str, torch.Tensor]] + ) -> Dict[str, torch.Tensor]: + + # Pad `caption_tokens` and `masked_labels` up to this length. + caption_tokens = torch.nn.utils.rnn.pad_sequence( + [d["caption_tokens"] for d in data], + batch_first=True, + padding_value=self.padding_idx, + ) + masked_labels = torch.nn.utils.rnn.pad_sequence( + [d["masked_labels"] for d in data], + batch_first=True, + padding_value=self.padding_idx, + ) + return { + "image_id": torch.stack([d["image_id"] for d in data], dim=0), + "image": torch.stack([d["image"] for d in data], dim=0), + "caption_tokens": caption_tokens, + "masked_labels": masked_labels, + "caption_lengths": torch.stack([d["caption_lengths"] for d in data]), + } + + def _random_token_index(self) -> int: + return random.randint(0, self._vocab_size - 1) diff --git a/virtex/virtex/data/datasets/redcaps.py b/virtex/virtex/data/datasets/redcaps.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e0711512aaca55d3848a9b71dd979e182688e6 --- /dev/null +++ b/virtex/virtex/data/datasets/redcaps.py @@ -0,0 +1,129 @@ +import glob +import os +import random +from typing import Callable + +import numpy as np +import torch +from torch.utils.data import IterableDataset +import webdataset as wds +import wordsegment as ws + +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.data import transforms as T +import virtex.utils.distributed as dist + +ws.load() + + +class TarfileDataset(IterableDataset): + def __init__( + self, + data_root: str, + batch_size: int, + tokenizer: SentencePieceBPETokenizer, + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + shuffle_buffer_size: int = 3000, # Set -1 to turn off shuffle. + max_caption_length: int = 50, + ): + super().__init__() + + self.tokenizer = tokenizer + self.image_transform = image_transform + self.max_caption_length = max_caption_length + + self.padding_idx = tokenizer.token_to_id("") + self.sos_idx = tokenizer.token_to_id("[SOS]") + self.eos_idx = tokenizer.token_to_id("[EOS]") + self.sep_idx = tokenizer.token_to_id("[SEP]") + + # Glob expand all paths in data root. + all_data_paths = [] + for dr in data_root.split(" "): + all_data_paths.extend(glob.glob(dr)) + + # Deterministic shuffle across GPU process. + all_data_paths = sorted(all_data_paths) + random.Random(0).shuffle(all_data_paths) + + # Shard the data paths as per gpu process. + all_data_paths = all_data_paths[dist.get_rank()::dist.get_world_size()] + + self._dset = ( + wds.WebDataset(all_data_paths) + .shuffle(shuffle_buffer_size, initial=shuffle_buffer_size) + .decode("rgb8", handler=wds.warn_and_continue) + .map(self._preprocess) + .batched(batch_size) + ) + # Perform word-segmentation of all subreddit names (that's how the + # tokenizer was prepared). Subreddit names can be obtained from + # TAR file names: `{subreddit}_{year}_{index}.tar`. + if "redcaps" in data_root: + self.subreddit_segs = { + sub: " ".join(ws.segment(ws.clean(sub))) for sub in + set([os.path.basename(p).split("_")[0] for p in all_data_paths]) + } + + def _preprocess(self, annotation): + image, caption = annotation["jpg"], annotation["json"]["caption"] + + # Transform image-caption pair and convert image from HWC to CHW format. + # Pass in caption to image_transform due to paired horizontal flip. + # Caption won't be tokenized/processed here. + image_caption = self.image_transform(image=image, caption=caption) + image, caption = image_caption["image"], image_caption["caption"] + image = np.transpose(image, (2, 0, 1)) + + # Tokenize caption. + _caption_tokens = self.tokenizer.encode(caption) + + # Get subreddit name if it exists, and tokenize it. Only for RedCaps. + if "subreddit" in annotation["json"]: + subreddit = annotation["json"]["subreddit"].lower() + subreddit = self.subreddit_segs[subreddit] + + # Add special [SEP] token after subreddit. + _subreddit_tokens = self.tokenizer.encode(subreddit) + [self.sep_idx] + else: + _subreddit_tokens = [] + + # Create forward and backward caption with subreddit token at the start. + caption_tokens = ( + [self.sos_idx] + _subreddit_tokens + _caption_tokens + [self.eos_idx] + )[: self.max_caption_length] + + noitpac_tokens = ( + [self.eos_idx] + _subreddit_tokens + _caption_tokens[::-1] + [self.sos_idx] + )[: self.max_caption_length] + + return image, caption_tokens, noitpac_tokens, len(caption_tokens) + + def __len__(self): + raise NotImplementedError + + def __iter__(self): + + for batch in iter(self._dset): + # Collate the batch properly here. `image` and `caption_lengths` + # are already tensors. + image, caption_tokens, noitpac_tokens, caption_lengths = batch + + # Pad `caption_tokens` and `masked_labels` up to this length. + caption_tokens = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(c, dtype=torch.long) for c in caption_tokens], + batch_first=True, + padding_value=self.padding_idx, + ) + noitpac_tokens = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(c, dtype=torch.long) for c in noitpac_tokens], + batch_first=True, + padding_value=self.padding_idx, + ) + caption_lengths = torch.tensor(caption_lengths, dtype=torch.long) + yield { + "image": torch.tensor(image, dtype=torch.float), + "caption_tokens": caption_tokens, + "noitpac_tokens": noitpac_tokens, + "caption_lengths": caption_lengths, + } diff --git a/virtex/virtex/data/datasets/zero_shot.py b/virtex/virtex/data/datasets/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..8e20100c15255d7707289786aea03853f6db5a50 --- /dev/null +++ b/virtex/virtex/data/datasets/zero_shot.py @@ -0,0 +1,125 @@ +from collections import defaultdict +import glob +import json +import os +from typing import Callable, Dict, List, Tuple + +import cv2 +import numpy as np +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence + +from virtex.data import transforms as T + +class ZeroShotDataset(Dataset): + def __init__( + self, + data_root: str = "datasets/inaturalist", + split: str = "train", + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + label_map: str = None, + tokenizer = None, + model_dataset = 'redcaps', + prompt_cls_sos = None, + prompt_sos_eos = None + ): + + self.data_root = data_root + self.split = split + self.label_map = json.load(open(label_map)) + self.tokenizer = tokenizer + self.image_transform = image_transform + self.model_dataset = model_dataset + self.prompt_cls_sos = prompt_cls_sos + self.prompt_sos_eos = prompt_sos_eos + + im_id = 0 + + self.image_id_to_file_path = {} + self.instances = [] + + for folder_name,labelname in self.label_map.items(): + image_folder = self.data_root + self.split + folder_name + "/" + for image_file in [x for x in os.listdir(image_folder) if x[-4:]=='.jpg']: + path = image_folder + image_file + self.image_id_to_file_path[im_id] = path + self.instances.append((im_id,labelname[1])) + im_id+=1 + + + im_net_list = [x[0].replace('_',' ').lower() for x in sorted(self.label_map.values(),key=lambda x: x[1])] + + print(im_net_list) + + cls_token = [tokenizer.token_to_id("[CLS]")] + sos_token = [tokenizer.token_to_id("[SOS]")] + eos_token =[tokenizer.token_to_id("[EOS]")] + + a_an_dets = [ " an " if cat[0].lower() in ["a","e","i","o","u"] else " a " for cat in im_net_list ] + imagenet_tensors = [cls_token + +tokenizer.encode("i took a picture") + +sos_token + +tokenizer.encode("itap of "+a_an_dets[i]+im_net_list[i]) + +eos_token + for i in range(len(im_net_list))] + + imagenet_tensors_backward = [cls_token + +tokenizer.encode("i took a picture") + +eos_token + +tokenizer.encode("itap of "+a_an_dets[i]+im_net_list[i])[::-1] + +sos_token + for i in range(len(im_net_list))] + + + tensor_lengths = torch.tensor([len(x) for x in imagenet_tensors]) + imagenet_tensors_forward = [torch.tensor(x) for x in imagenet_tensors] + imagenet_tensors_backward = [torch.tensor(x) for x in imagenet_tensors_backward] + imagenet_tensors_forward = pad_sequence(imagenet_tensors_forward,batch_first=True) + imagenet_tensors_backward = pad_sequence(imagenet_tensors_backward,batch_first=True) + + + print("imagenet_tensors_forward.shape: ", imagenet_tensors_forward.shape) + print("imagenet_tensors_backward.shape: ", imagenet_tensors_backward.shape) + print("tensor_lengths.shape: ", tensor_lengths.shape) + + self.imagenet_tensors_forward = imagenet_tensors_forward + self.imagenet_tensors_backward = imagenet_tensors_backward + self.tensor_lengths = tensor_lengths.long() + + def __len__(self): + return len(self.instances) + + def __getitem__(self, idx: int): + + image_id, label = self.instances[idx] + image_path = self.image_id_to_file_path[image_id] + try: + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = self.image_transform(image=image)["image"] + image = np.transpose(image, (2, 0, 1)) + except: + print("$#%@#$%#image_path$@%:",image_path) + image = np.random.rand(234, 325, 3) + #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = self.image_transform(image=image)["image"] + image = np.transpose(image, (2, 0, 1)) + + return { + "image": torch.tensor(image, dtype=torch.float), + "label": torch.tensor(label, dtype=torch.long), + "caption_tokens": self.imagenet_tensors_forward, + "noitpac_tokens": self.imagenet_tensors_backward, + "caption_lengths": self.tensor_lengths + } + + @staticmethod + def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + return { + "image": torch.stack([d["image"] for d in data], dim=0), + "label": torch.stack([d["label"] for d in data], dim=0), + "caption_tokens": data[0]['caption_tokens'], + "noitpac_tokens": data[0]['noitpac_tokens'], + "caption_lengths": data[0]['caption_lengths'] + } \ No newline at end of file diff --git a/virtex/virtex/data/readers.py b/virtex/virtex/data/readers.py new file mode 100644 index 0000000000000000000000000000000000000000..6915329db05fdd8160f6a63806a2d939f7915c74 --- /dev/null +++ b/virtex/virtex/data/readers.py @@ -0,0 +1,180 @@ +r""" +A *Reader* is a PyTorch :class:`~torch.utils.data.Dataset` which simply reads +data from disk and returns it almost as is. Readers defined here are used by +datasets in :mod:`virtex.data.datasets`. +""" +from collections import defaultdict +import glob +import json +import os +import pickle +import random +from typing import Dict, List, Tuple + +import cv2 +import lmdb +from loguru import logger +from torch.utils.data import Dataset + + +# Some simplified type renaming for better readability +ImageID = int +Captions = List[str] + + +class SimpleCocoCaptionsReader(Dataset): + r""" + A reader interface to read COCO Captions dataset and directly from official + annotation files and return it unprocessed. We only use this for serializing + the dataset to LMDB files, and use :class:`~virtex.data.readers.LmdbReader` + in rest of the datasets. + + Parameters + ---------- + root: str, optional (default = "datasets/coco") + Path to the COCO dataset root directory. + split: str, optional (default = "train") + Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``. + """ + def __init__(self, root: str = "datasets/coco", split: str = "train"): + + image_dir = os.path.join(root, f"{split}2017") + + # Make a tuple of image id and its filename, get image_id from its + # filename (assuming directory has images with names in COCO2017 format). + image_filenames = glob.glob(os.path.join(image_dir, "*.jpg")) + self.id_filename: List[Tuple[ImageID, str]] = [ + (int(os.path.basename(name)[:-4]), name) for name in image_filenames + ] + + # Make a mapping between image_id and its captions. + _captions = json.load( + open(os.path.join(root, "annotations", f"captions_{split}2017.json")) + ) + self._id_to_captions: Dict[ImageID, Captions] = defaultdict(list) + + for ann in _captions["annotations"]: + self._id_to_captions[ann["image_id"]].append(ann["caption"]) + + def __len__(self): + return len(self.id_filename) + + def __getitem__(self, idx: int): + image_id, filename = self.id_filename[idx] + + # shape: (height, width, channels), dtype: uint8 + image = cv2.imread(filename) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + captions = self._id_to_captions[image_id] + + return {"image_id": image_id, "image": image, "captions": captions} + + +class LmdbReader(Dataset): + r""" + A reader interface to read datapoints from a serialized LMDB file containing + ``(image_id, image, caption)`` tuples. Optionally, one may specify a + partial percentage of datapoints to use. + + .. note:: + + When training in distributed setting, make sure each worker has SAME + random seed because there is some randomness in selecting keys for + training with partial dataset. If you wish to use a different seed for + each worker, select keys manually outside of this class and use + :meth:`set_keys`. + + .. note:: + + Similar to :class:`~torch.utils.data.distributed.DistributedSampler`, + this reader can shuffle the dataset deterministically at the start of + epoch. Use :meth:`set_shuffle_seed` manually from outside to change the + seed at every epoch. + + Parameters + ---------- + lmdb_path: str + Path to LMDB file with datapoints. + shuffle: bool, optional (default = True) + Whether to shuffle or not. If this is on, there will be one deterministic + shuffle based on epoch before sharding the dataset (to workers). + percentage: float, optional (default = 100.0) + Percentage of datapoints to use. If less than 100.0, keys will be + shuffled and first K% will be retained and use throughout training. + Make sure to set this only for training, not validation. + """ + + def __init__(self, lmdb_path: str, shuffle: bool = True, percentage: float = 100): + self.lmdb_path = lmdb_path + self.shuffle = shuffle + + assert percentage > 0, "Cannot load dataset with 0 percent original size." + self.percentage = percentage + + # fmt: off + # Create an LMDB transaction right here. It will be aborted when this + # class goes out of scope. + env = lmdb.open( + self.lmdb_path, subdir=False, readonly=True, lock=False, + readahead=False, map_size=1099511627776 * 2, + ) + self.db_txn = env.begin() + + # Form a list of LMDB keys numbered from 0 (as binary strings). + self._keys = [ + f"{i}".encode("ascii") for i in range(env.stat()["entries"]) + ] + # fmt: on + + # If data percentage < 100%, randomly retain K% keys. This will be + # deterministic based on random seed. + if percentage < 100.0: + retain_k: int = int(len(self._keys) * percentage / 100.0) + random.shuffle(self._keys) + self._keys = self._keys[:retain_k] + logger.info(f"Retained {retain_k} datapoints for training!") + + # A seed to deterministically shuffle at the start of epoch. This is + # set externally through `set_shuffle_seed`. + self.shuffle_seed = 0 + + def set_shuffle_seed(self, seed: int): + r"""Set random seed for shuffling data.""" + self.shuffle_seed = seed + + def get_keys(self) -> List[bytes]: + r"""Return list of keys, useful while saving checkpoint.""" + return self._keys + + def set_keys(self, keys: List[bytes]): + r"""Set list of keys, useful while loading from checkpoint.""" + self._keys = keys + + def __getstate__(self): + r""" + This magic method allows an object of this class to be pickable, useful + for dataloading with multiple CPU workers. :attr:`db_txn` is not + pickable, so we remove it from state, and re-instantiate it in + :meth:`__setstate__`. + """ + state = self.__dict__ + state["db_txn"] = None + return state + + def __setstate__(self, state): + self.__dict__ = state + + env = lmdb.open( + self.lmdb_path, subdir=False, readonly=True, lock=False, + readahead=False, map_size=1099511627776 * 2, + ) + self.db_txn = env.begin() + + def __len__(self): + return len(self._keys) + + def __getitem__(self, idx: int): + datapoint_pickled = self.db_txn.get(self._keys[idx]) + image_id, image, captions = pickle.loads(datapoint_pickled) + + return image_id, image, captions diff --git a/virtex/virtex/data/tokenizers.py b/virtex/virtex/data/tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..8a860b54261efd31b501bc2fa19990e8d579e3a7 --- /dev/null +++ b/virtex/virtex/data/tokenizers.py @@ -0,0 +1,61 @@ +import csv +from typing import Any, Dict, List + +import sentencepiece as sp + + +class SentencePieceBPETokenizer(object): + r""" + A tokenizer based on `SentencePiece `_ + with BPE sub-routine. It encodes caption strings into list of tokens. + + Parameters + ---------- + model_path: str + Path to the ``.model`` file trained by SentencePiece. + """ + SP_SPACE = u"▁" + + def __init__(self, model_path: str): + self.model_path = model_path + + # Load pretrained tokenizer model. + self.model = sp.SentencePieceProcessor() + self.model.Load(model_path) + + def __getstate__(self): + r""" + This magic method, along with ``__setstate__`` makes an object of this + class picklable (and usable while data loading with multiple workers). + """ + state_dict = self.__dict__.copy() + state_dict["model"] = None + return state_dict + + def __setstate__(self, state_dict: Dict[str, Any]): + self.__dict__ = state_dict + + self.model = sp.SentencePieceProcessor() + self.model.Load(self.model_path) + + def get_vocab_size(self) -> int: + r"""Return number of tokens in vocabulary (including special tokens).""" + return len(self.model) + + def token_to_id(self, token: str) -> int: + r"""Get integer ID of a string token (```` if does not exist).""" + # Since tokenizer uses subword regularization, one token may break down to multiple IDs. + # Keep trying till we get a single ID. + return self.model.piece_to_id(token) + + def id_to_token(self, token_id: int) -> str: + r"""Get string token of an integer ID (```` if does not exist).""" + return self.model.id_to_piece(token_id) + + def encode(self, text: str) -> List[int]: + r"""Convert a text string to a list of integer token ids.""" + return self.model.EncodeAsIds(text) + + def decode(self, token_ids: List[int]) -> str: + r"""Convert a sequence of token IDs to a text string.""" + return self.model.DecodeIds(token_ids) diff --git a/virtex/virtex/data/transforms.py b/virtex/virtex/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d4141c6292584e76cdb279439cca1807f6e866fd --- /dev/null +++ b/virtex/virtex/data/transforms.py @@ -0,0 +1,231 @@ +import random +from typing import List +import unicodedata + +import albumentations as alb +import cv2 + +from virtex.data.tokenizers import SentencePieceBPETokenizer + + +class CaptionOnlyTransform(alb.BasicTransform): + r""" + A base class for custom `albumentations `_ + transform, which can transform captions. Captions may be ``str``, or tokens + (``List[int]``) as per implementation of :meth:`apply_to_caption`. These + transforms will have consistent API as other transforms from albumentations. + """ + + @property + def targets(self): + return {"caption": self.apply_to_caption} + + def apply_to_caption(self, caption, **params): + raise NotImplementedError + + def update_params(self, params, **kwargs): + # Super class adds "width" and "height" but we don't have image here. + return params + + +class ImageCaptionTransform(alb.BasicTransform): + r""" + Similar to :class:`~virtex.data.transforms.CaptionOnlyTransform`, this + extends super class to work on ``(image, caption)`` pair together. + """ + + @property + def targets(self): + return {"image": self.apply, "caption": self.apply_to_caption} + + def apply_to_caption(self): + raise NotImplementedError + + +class NormalizeCaption(CaptionOnlyTransform): + r""" + Perform common normalization with caption: lowercase, trim leading and + trailing whitespaces, NFKD normalization and strip accents. + + Examples + -------- + >>> normalize = NormalizeCaption(always_apply=True) + >>> out = normalize(caption="Some caption input here.") # keys: {"caption"} + """ + + def __init__(self): + # `always_apply = True` because this is essential part of pipeline. + super().__init__(always_apply=True) + + def apply_to_caption(self, caption: str, **params) -> str: + caption = caption.lower() + caption = unicodedata.normalize("NFKD", caption) + caption = "".join([chr for chr in caption if not unicodedata.combining(chr)]) + return caption + + +class TokenizeCaption(CaptionOnlyTransform): + r""" + Tokenize a caption (``str``) to list of tokens (``List[int]``) by the + mapping defined in :attr:`tokenizer`. + + Parameters + ---------- + tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer + A :class:`~virtex.data.tokenizers.SentencePieceBPETokenizer` which encodes + a caption into tokens. + add_boundaries: bool, optional (defalult = True) + Whether to add ``[SOS]`` and ``[EOS]`` boundary tokens from tokenizer. + + Examples + -------- + >>> tokenizer = SentencePieceBPETokenizer("coco.vocab", "coco.model") + >>> tokenize = TokenizeCaption(tokenizer, always_apply=True) + >>> out = tokenize(caption="Some caption input here.") # keys: {"caption"} + """ + + def __init__(self, tokenizer: SentencePieceBPETokenizer): + # `always_apply = True` because this is essential part of pipeline. + super().__init__(always_apply=True) + self.tokenizer = tokenizer + + def apply_to_caption(self, caption: str, **params) -> List[int]: + token_indices: List[int] = self.tokenizer.encode(caption) + + # Add boundary tokens. + token_indices.insert(0, self.tokenizer.token_to_id("[SOS]")) + token_indices.append(self.tokenizer.token_to_id("[EOS]")) + return token_indices + + def get_transform_init_args_names(self): + return ("tokenizer",) + + +class TruncateCaptionTokens(CaptionOnlyTransform): + r""" + Truncate a list of caption tokens (``List[int]``) to maximum length. + + Parameters + ---------- + max_caption_length: int, optional (default = 30) + Maximum number of tokens to keep in output caption tokens. Extra tokens + will be trimmed from the right end of the token list. + + Examples + -------- + >>> truncate = TruncateCaptionTokens(max_caption_length=5, always_apply=True) + >>> out = truncate(caption=[2, 35, 41, 67, 98, 50, 3]) + >>> out["caption"] + [2, 35, 41, 67, 98] + """ + + def __init__(self, max_caption_length: int = 30): + # `always_apply = True` because this is essential part of pipeline. + super().__init__(always_apply=True) + self.max_caption_length = max_caption_length + + def apply_to_caption(self, caption: List[int], **params) -> List[int]: + return caption[: self.max_caption_length] + + def get_transform_init_args_names(self): + return ("max_caption_length",) + + +class HorizontalFlip(ImageCaptionTransform): + r""" + Flip the image horizontally randomly (equally likely) and replace the + word "left" with "right" in the caption. + + .. note:: + + This transform can also work on images only (without the captions). + Its behavior will be same as albumentations + :class:`~albumentations.augmentations.transforms.HorizontalFlip`. + + Examples + -------- + >>> flip = HorizontalFlip(p=0.5) + >>> out1 = flip(image=image, caption=caption) # keys: {"image", "caption"} + >>> # Also works with images (without caption). + >>> out2 = flip(image=image) # keys: {"image"} + + """ + + def apply(self, img, **params): + return cv2.flip(img, 1) + + def apply_to_caption(self, caption, **params): + caption = ( + caption.replace("left", "[TMP]") + .replace("right", "left") + .replace("[TMP]", "right") + ) + return caption + + +class RandomResizedSquareCrop(alb.RandomResizedCrop): + r""" + A variant of :class:`albumentations.augmentations.transforms.RandomResizedCrop` + which assumes a square crop (width = height). Everything else is same. + + Parameters + ---------- + size: int + Dimension of the width and height of the cropped image. + """ + + def __init__(self, size: int, *args, **kwargs): + super().__init__(height=size, width=size, *args, **kwargs) + + +class CenterSquareCrop(alb.CenterCrop): + r""" + A variant of :class:`albumentations.augmentations.transforms.CenterCrop` which + assumes a square crop (width = height). Everything else is same. + + Parameters + ---------- + size: int + Dimension of the width and height of the cropped image. + """ + + def __init__(self, size: int, *args, **kwargs): + super().__init__(height=size, width=size, *args, **kwargs) + + +class SquareResize(alb.Resize): + r""" + A variant of :class:`albumentations.augmentations.transforms.Resize` which + assumes a square resize (width = height). Everything else is same. + + Parameters + ---------- + size: int + Dimension of the width and height of the resized image. + """ + + def __init__(self, size: int, *args, **kwargs): + super().__init__(height=size, width=size, *args, **kwargs) + + +# ============================================================================= +# SOME COMMON CONSTANTS AND IMAGE TRANSFORMS: +# These serve as references here, and are used as default params in many +# dataset class constructors. +# ----------------------------------------------------------------------------- + +IMAGENET_COLOR_MEAN = (0.485, 0.456, 0.406) +r"""ImageNet color normalization mean in RGB format (values in 0-1).""" + +IMAGENET_COLOR_STD = (0.229, 0.224, 0.225) +r"""ImageNet color normalization std in RGB format (values in 0-1).""" + +DEFAULT_IMAGE_TRANSFORM = alb.Compose( + [ + alb.SmallestMaxSize(256, p=1.0), + CenterSquareCrop(224, p=1.0), + alb.Normalize(mean=IMAGENET_COLOR_MEAN, std=IMAGENET_COLOR_STD, p=1.0), + ] +) +r"""Default transform without any data augmentation (during pretraining).""" +# ============================================================================= diff --git a/virtex/virtex/factories.py b/virtex/virtex/factories.py new file mode 100644 index 0000000000000000000000000000000000000000..67636b202eef931829321d757673a5f65fbf95c6 --- /dev/null +++ b/virtex/virtex/factories.py @@ -0,0 +1,638 @@ +r""" +This module is a collection of *factories* for creating objects of datasets, +models, optimizers and other useful components. For example, a ResNet-50 +visual backbone can be created as: + + .. code-block:: python + + >>> # Explicitly by name, args and kwargs: + >>> backbone = VisualBackboneFactory.create( + ... "torchvision::resnet50", pretrained=False + ... ) + >>> # Directly from a config object: + >>> _C = Config(override_list=["MODEL.VISUAL.NAME", "torchvision::resnet50"]) + >>> backbone = VisualBackboneFactory.from_config(_C) + +Creating directly from :class:`~virtex.config.Config` is fast and simple, and +ensures minimal changes throughout the codebase upon any change in the call +signature of underlying class; or config hierarchy. Refer description of +specific factories for more details. +""" +import re +from functools import partial +from typing import Any, Callable, Dict, Iterable, List + +import albumentations as alb +from torch import nn, optim + +import virtex.data as vdata +import virtex.models as vmodels +import virtex.utils.distributed as dist +from virtex.config import Config +from virtex.data import transforms as T +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.modules import visual_backbones, textual_heads +from virtex.optim import Lookahead, lr_scheduler +from virtex.utils.beam_search import AutoRegressiveBeamSearch +from virtex.utils.nucleus_sampling import AutoRegressiveNucleusSampling + +class Factory(object): + r""" + Base class for all factories. All factories must inherit this base class + and follow these guidelines for a consistent behavior: + + * Factory objects cannot be instantiated, doing ``factory = SomeFactory()`` + is illegal. Child classes should not implement ``__init__`` methods. + * All factories must have an attribute named ``PRODUCTS`` of type + ``Dict[str, Callable]``, which associates each class with a unique string + name which can be used to create it. + * All factories must implement one classmethod, :meth:`from_config` which + contains logic for creating an object directly by taking name and other + arguments directly from :class:`~virtex.config.Config`. They can use + :meth:`create` already implemented in this base class. + * :meth:`from_config` should not use too many extra arguments than the + config itself, unless necessary (such as model parameters for optimizer). + """ + + PRODUCTS: Dict[str, Callable] = {} + + def __init__(self): + raise ValueError( + f"""Cannot instantiate {self.__class__.__name__} object, use + `create` classmethod to create a product from this factory. + """ + ) + + @classmethod + def create(cls, name: str, *args, **kwargs) -> Any: + r"""Create an object by its name, args and kwargs.""" + if name not in cls.PRODUCTS: + raise KeyError(f"{cls.__class__.__name__} cannot create {name}.") + + return cls.PRODUCTS[name](*args, **kwargs) + + @classmethod + def from_config(cls, config: Config) -> Any: + r"""Create an object directly from config.""" + raise NotImplementedError + + +class TokenizerFactory(Factory): + r""" + Factory to create text tokenizers. This codebase ony supports one tokenizer + for now, but having a dedicated factory makes it easy to add more if needed. + + Possible choices: ``{"SentencePieceBPETokenizer"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + "SentencePieceBPETokenizer": SentencePieceBPETokenizer + } + + @classmethod + def from_config(cls, config: Config) -> SentencePieceBPETokenizer: + r""" + Create a tokenizer directly from config. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + """ + + _C = config + + tokenizer = cls.create( + "SentencePieceBPETokenizer", model_path=_C.DATA.TOKENIZER_MODEL + ) + return tokenizer + + +class ImageTransformsFactory(Factory): + r""" + Factory to create image transformations for common preprocessing and data + augmentations. These are a mix of default transformations from + `albumentations `_ and + some extended ones defined in :mod:`virtex.data.transforms`. + + It uses sensible default values, however they can be provided with the name + in dict syntax. Example: ``random_resized_crop::{'scale': (0.08, 1.0)}`` + + .. note:: + + This factory does not implement :meth:`from_config` method. It is only + used by :class:`PretrainingDatasetFactory` and + :class:`DownstreamDatasetFactory`. + + Possible choices: ``{"center_crop", "horizontal_flip", "random_resized_crop", + "normalize", "global_resize", "color_jitter", "smallest_resize"}``. + """ + + # fmt: off + PRODUCTS: Dict[str, Callable] = { + # Input resize transforms: whenever selected, these are always applied. + # These transforms require one position argument: image dimension. + "random_resized_crop": partial( + T.RandomResizedSquareCrop, scale=(0.2, 1.0), ratio=(0.75, 1.333), p=1.0 + ), + "center_crop": partial(T.CenterSquareCrop, p=1.0), + "smallest_resize": partial(alb.SmallestMaxSize, p=1.0), + "global_resize": partial(T.SquareResize, p=1.0), + + # Keep hue limits small in color jitter because it changes color drastically + # and captions often mention colors. Apply with higher probability. + "color_jitter": partial( + alb.ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8 + ), + "horizontal_flip": partial(T.HorizontalFlip, p=0.5), + + # Color normalization: whenever selected, always applied. This accepts images + # in [0, 255], requires mean and std in [0, 1] and normalizes to `N(0, 1)`. + "normalize": partial( + alb.Normalize, mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, p=1.0 + ), + } + # fmt: on + + @classmethod + def create(cls, name: str, *args, **kwargs) -> Any: + r"""Create an object by its name, args and kwargs.""" + + if "::" in name: + name, __kwargs = name.split("::") + _kwargs = eval(__kwargs) + else: + _kwargs = {} + + _kwargs.update(kwargs) + return super().create(name, *args, **_kwargs) + + @classmethod + def from_config(cls, config: Config): + r"""Augmentations cannot be created from config, only :meth:`create`.""" + raise NotImplementedError + + +class PretrainingDatasetFactory(Factory): + r""" + Factory to create :class:`~torch.utils.data.Dataset` s for pretraining + VirTex models. Datasets are created depending on pretraining task used. + Typically these datasets either provide image-caption pairs, or only images + from COCO Captions dataset (serialized to an LMDB file). + + As an exception, the dataset for ``multilabel_classification`` provides + COCO images and labels of their bounding box annotations. + + Possible choices: ``{"bicaptioning", "captioning", "masked_lm", + "token_classification", "multilabel_classification"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + "virtex": vdata.CaptioningDataset, + "bicaptioning": vdata.CaptioningDataset, + "captioning": vdata.CaptioningDataset, + "masked_lm": vdata.MaskedLmDataset, + "token_classification": vdata.TokenClassificationDataset, + "multilabel_classification": vdata.MultiLabelClassificationDataset, + "virtex_web": vdata.TarfileDataset, + "miniclip_web": vdata.TarfileDataset, + } + + @classmethod + def from_config(cls, config: Config, split: str = "train"): + r""" + Create a dataset directly from config. Names in this factory match with + names in :class:`PretrainingModelFactory` because both use same config + parameter ``MODEL.NAME`` to create objects. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + split: str, optional (default = "train") + Which split to load for the dataset. One of ``{"train", "val"}``. + """ + + _C = config + # Every dataset needs these two args. + kwargs = {"data_root": _C.DATA.ROOT, "split": split} + + # Create a list of image transformations based on transform names. + image_transform_list: List[Callable] = [] + + for name in getattr(_C.DATA, f"IMAGE_TRANSFORM_{split.upper()}"): + # Pass dimensions if cropping / resizing, else rely on the defaults + # as per `ImageTransformsFactory`. + if "resize" in name or "crop" in name: + image_transform_list.append( + ImageTransformsFactory.create(name, _C.DATA.IMAGE_CROP_SIZE) + ) + else: + image_transform_list.append(ImageTransformsFactory.create(name)) + + kwargs["image_transform"] = alb.Compose(image_transform_list) + + tokenizer = TokenizerFactory.from_config(_C) + + if _C.MODEL.NAME in {"virtex", "bicaptioning", "captioning"}: + kwargs.update( + tokenizer=tokenizer, + max_caption_length=_C.DATA.MAX_CAPTION_LENGTH, + use_single_caption=_C.DATA.USE_SINGLE_CAPTION, + percentage=_C.DATA.USE_PERCENTAGE if split == "train" else 100.0, + ) + + elif _C.MODEL.NAME == "token_classification": + kwargs.update( + tokenizer=tokenizer, max_caption_length=_C.DATA.MAX_CAPTION_LENGTH + ) + + elif _C.MODEL.NAME == "masked_lm": + kwargs.update( + tokenizer=tokenizer, + max_caption_length=_C.DATA.MAX_CAPTION_LENGTH, + use_single_caption=_C.DATA.USE_SINGLE_CAPTION, + percentage=_C.DATA.USE_PERCENTAGE if split == "train" else 100.0, + mask_proportion=_C.DATA.MASKED_LM.MASK_PROPORTION, + mask_probability=_C.DATA.MASKED_LM.MASK_PROBABILITY, + replace_probability=_C.DATA.MASKED_LM.REPLACE_PROBABILITY, + ) + + elif _C.MODEL.NAME in {"virtex_web", "miniclip_web"}: + # Remove "split" argument, not necessary. + _ = kwargs.pop("split") + kwargs.update( + batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(), + tokenizer=tokenizer, + max_caption_length=_C.DATA.MAX_CAPTION_LENGTH, + ) + + # Dataset names match with model names (and ofcourse pretext names). + return cls.create(_C.MODEL.NAME, **kwargs) + + +class DownstreamDatasetFactory(Factory): + r""" + Factory to create :class:`~torch.utils.data.Dataset` s for evaluating + VirTex models on downstream tasks. + + Possible choices: ``{"datasets/VOC2007", "datasets/imagenet"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + "datasets/VOC2007": vdata.VOC07ClassificationDataset, + "datasets/imagenet": vdata.ImageNetDataset, + "datasets/inaturalist": vdata.INaturalist2018Dataset, + } + + @classmethod + def from_config(cls, config: Config, split: str = "train"): + r""" + Create a dataset directly from config. Names in this factory are paths + of dataset directories (relative to the project directory), because + config parameter ``DATA.ROOT`` is used to create objects. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + split: str, optional (default = "train") + Which split to load for the dataset. One of ``{"trainval", "test"}`` + for VOC2007, or one of ``{"train", "val"}`` for ImageNet. + """ + + _C = config + # Every dataset needs these two args. + kwargs = {"data_root": _C.DATA.ROOT, "split": split} + + # For VOC2007, `IMAGE_TRANSFORM_TRAIN` is used for "trainval" split and + # `IMAGE_TRANSFORM_VAL` is used fo "test" split. + image_transform_names: List[str] = list( + _C.DATA.IMAGE_TRANSFORM_TRAIN + if "train" in split + else _C.DATA.IMAGE_TRANSFORM_VAL + ) + # Create a list of image transformations based on names. + image_transform_list: List[Callable] = [] + + for name in image_transform_names: + # Pass dimensions for resize/crop, else rely on the defaults. + if name.split("::")[0] in { + "random_resized_crop", + "center_crop", + "global_resize", + }: + transform = ImageTransformsFactory.create(name, 224) + elif name.split("::")[0] in {"smallest_resize"}: + transform = ImageTransformsFactory.create(name, 256) + else: + transform = ImageTransformsFactory.create(name) + + image_transform_list.append(transform) + + kwargs["image_transform"] = alb.Compose(image_transform_list) + + return cls.create(_C.DATA.ROOT, **kwargs) + + +class VisualBackboneFactory(Factory): + r""" + Factory to create :mod:`~virtex.modules.visual_backbones`. This factory + supports any ResNet-like model from + `Torchvision `_. + Use the method name for model as in torchvision, for example, + ``torchvision::resnet50``, ``torchvision::wide_resnet50_2`` etc. + + Possible choices: ``{"torchvision", "timm"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + "torchvision": visual_backbones.TorchvisionVisualBackbone, + "timm": visual_backbones.TimmVisualBackbone, + } + + @classmethod + def from_config(cls, config: Config) -> visual_backbones.VisualBackbone: + r""" + Create a visual backbone directly from config. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + """ + + _C = config + kwargs = {"visual_feature_size": _C.MODEL.VISUAL.FEATURE_SIZE} + + # Check the name for models from torchvision or timm. + package_name, cnn_name = _C.MODEL.VISUAL.NAME.split("::") + kwargs["pretrained"] = _C.MODEL.VISUAL.PRETRAINED + kwargs["frozen"] = _C.MODEL.VISUAL.FROZEN + + return cls.create(package_name, cnn_name, **kwargs) + + +class TextualHeadFactory(Factory): + r""" + Factory to create :mod:`~virtex.modules.textual_heads`. Architectural + hyperparameters for transformers can be specified as ``name::*``. + For example, ``transdec_postnorm::L1_H1024_A16_F4096`` would create a + transformer textual head with ``L = 1`` layers, ``H = 1024`` hidden size, + ``A = 16`` attention heads, and ``F = 4096`` size of feedforward layers. + + Textual head should be ``"none"`` for pretraining tasks which do not + involve language modeling, such as ``"token_classification"``. + + Possible choices: ``{"transdec_postnorm", "transdec_prenorm", "none"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + "transdec_prenorm": partial( + textual_heads.TransformerDecoderTextualHead, norm_type="pre" + ), + "transdec_postnorm": partial( + textual_heads.TransformerDecoderTextualHead, norm_type="post" + ), + "transenc_postnorm": partial( + textual_heads.TransformerEncoderTextualHead, norm_type="post" + ), + "transenc_prenorm": partial( + textual_heads.TransformerEncoderTextualHead, norm_type="pre" + ), + "none": textual_heads.LinearTextualHead, + } + + @classmethod + def from_config(cls, config: Config) -> nn.Module: + r""" + Create a textual head directly from config. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + """ + + _C = config + name = _C.MODEL.TEXTUAL.NAME + kwargs = { + "visual_feature_size": _C.MODEL.VISUAL.FEATURE_SIZE, + "vocab_size": _C.DATA.VOCAB_SIZE, + } + + if "trans" in _C.MODEL.TEXTUAL.NAME: + # Get architectural hyper-params as per name by matching regex. + name, architecture = name.split("::") + architecture = re.match(r"L(\d+)_H(\d+)_A(\d+)_F(\d+)", architecture) + + num_layers = int(architecture.group(1)) + hidden_size = int(architecture.group(2)) + attention_heads = int(architecture.group(3)) + feedforward_size = int(architecture.group(4)) + + # Mask the future tokens for autoregressive captioning. + mask_future = _C.MODEL.NAME in {"virtex", "virtex_web", "captioning", "bicaptioning"} + + kwargs.update( + hidden_size=hidden_size, + num_layers=num_layers, + attention_heads=attention_heads, + feedforward_size=feedforward_size, + dropout=_C.MODEL.TEXTUAL.DROPOUT, + mask_future_positions=mask_future, + max_caption_length=_C.DATA.MAX_CAPTION_LENGTH, + padding_idx=_C.DATA.UNK_INDEX, + ) + return cls.create(name, **kwargs) + + +class PretrainingModelFactory(Factory): + r""" + Factory to create :mod:`~virtex.models` for different pretraining tasks. + + Possible choices: ``{"bicaptioning", "captioning", "masked_lm", + "token_classification", "multilabel_classification"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + # First two are basically the same. Added for shorthand notation. + "virtex": vmodels.VirTexModel, + "bicaptioning": vmodels.BidirectionalCaptioningModel, + "captioning": vmodels.ForwardCaptioningModel, + "masked_lm": vmodels.MaskedLMModel, + "token_classification": vmodels.TokenClassificationModel, + "multilabel_classification": vmodels.MultiLabelClassificationModel, + "virtex_web": vmodels.VirTexModel, + "miniclip_web": vmodels.ImageTextContrastiveModel, + } + + @classmethod + def from_config(cls, config: Config) -> nn.Module: + r""" + Create a model directly from config. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + """ + + _C = config + + # Build visual and textual streams based on config. + visual = VisualBackboneFactory.from_config(_C) + textual = TextualHeadFactory.from_config(_C) + + # Add model specific kwargs. Refer call signatures of specific models + # for matching kwargs here. + if _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning", "virtex_web"}: + kwargs = { + "sos_index": _C.DATA.SOS_INDEX, + "eos_index": _C.DATA.EOS_INDEX, + "label_smoothing": _C.MODEL.LABEL_SMOOTHING, + "decoder": CaptionDecoderFactory.from_config(_C), + } + elif _C.MODEL.NAME in {"miniclip_web"}: + kwargs = {"label_smoothing": _C.MODEL.LABEL_SMOOTHING} + + elif _C.MODEL.NAME == "token_classification": + kwargs = { + "ignore_indices": [ + _C.DATA.UNK_INDEX, + _C.DATA.SOS_INDEX, + _C.DATA.EOS_INDEX, + _C.DATA.MASK_INDEX, + ] + } + elif _C.MODEL.NAME == "multilabel_classification": + kwargs = {"ignore_indices": [0]} # background index + else: + kwargs = {} + + return cls.create(_C.MODEL.NAME, visual, textual, **kwargs) + + +class CaptionDecoderFactory(Factory): + r""" + Factory to create decoders from predicting captions from VirTex model. + + Possible choices: ``{"beam_search", "nucleus_sampling"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + "beam_search": AutoRegressiveBeamSearch, + "nucleus_sampling": AutoRegressiveNucleusSampling, + } + + @classmethod + def from_config(cls, config: Config) -> nn.Module: + r""" + Create a model directly from config. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + """ + + _C = config + kwargs = { + "eos_index": _C.DATA.EOS_INDEX, + "max_steps": _C.MODEL.DECODER.MAX_DECODING_STEPS, + } + if _C.MODEL.DECODER.NAME == "beam_search": + kwargs["beam_size"] = _C.MODEL.DECODER.BEAM_SIZE + elif _C.MODEL.DECODER.NAME == "nucleus_sampling": + kwargs["nucleus_size"] = _C.MODEL.DECODER.NUCLEUS_SIZE + + return cls.create(_C.MODEL.DECODER.NAME, **kwargs) + + +class OptimizerFactory(Factory): + r"""Factory to create optimizers. Possible choices: ``{"sgd", "adamw"}``.""" + + PRODUCTS: Dict[str, Callable] = {"sgd": optim.SGD, "adamw": optim.AdamW} + + @classmethod + def from_config( + cls, config: Config, named_parameters: Iterable[Any] + ) -> optim.Optimizer: + r""" + Create an optimizer directly from config. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + named_parameters: Iterable + Named parameters of model (retrieved by ``model.named_parameters()``) + for the optimizer. We use named parameters to set different LR and + turn off weight decay for certain parameters based on their names. + """ + + _C = config + + # Set different learning rate for CNN and rest of the model during + # pretraining. This doesn't matter for downstream evaluation because + # there are no modules with "cnn" in their name. + # Also turn off weight decay for layer norm and bias in textual stream. + param_groups = [] + for name, param in named_parameters: + wd = 0.0 if re.match(_C.OPTIM.NO_DECAY, name) else _C.OPTIM.WEIGHT_DECAY + lr = _C.OPTIM.CNN_LR if "cnn" in name else _C.OPTIM.LR + param_groups.append({"params": [param], "lr": lr, "weight_decay": wd}) + + if _C.OPTIM.OPTIMIZER_NAME == "sgd": + kwargs = {"momentum": _C.OPTIM.SGD_MOMENTUM} + else: + kwargs = {} + + optimizer = cls.create(_C.OPTIM.OPTIMIZER_NAME, param_groups, **kwargs) + if _C.OPTIM.LOOKAHEAD.USE: + optimizer = Lookahead( + optimizer, k=_C.OPTIM.LOOKAHEAD.STEPS, alpha=_C.OPTIM.LOOKAHEAD.ALPHA + ) + return optimizer + + +class LRSchedulerFactory(Factory): + r""" + Factory to create LR schedulers. All schedulers have a built-in LR warmup + schedule before actual LR scheduling (decay) starts. + + Possible choices: ``{"none", "multistep", "linear", "cosine"}``. + """ + + PRODUCTS: Dict[str, Callable] = { + "none": lr_scheduler.LinearWarmupNoDecayLR, + "multistep": lr_scheduler.LinearWarmupMultiStepLR, + "linear": lr_scheduler.LinearWarmupLinearDecayLR, + "cosine": lr_scheduler.LinearWarmupCosineAnnealingLR, + } + + @classmethod + def from_config( + cls, config: Config, optimizer: optim.Optimizer + ) -> optim.lr_scheduler.LambdaLR: + r""" + Create an LR scheduler directly from config. + + Parameters + ---------- + config: virtex.config.Config + Config object with all the parameters. + optimizer: torch.optim.Optimizer + Optimizer on which LR scheduling would be performed. + """ + + _C = config + kwargs = { + "total_steps": _C.OPTIM.NUM_ITERATIONS, + "warmup_steps": _C.OPTIM.WARMUP_STEPS, + } + # Multistep LR requires multiplicative factor and milestones. + if _C.OPTIM.LR_DECAY_NAME == "multistep": + kwargs.update(gamma=_C.OPTIM.LR_GAMMA, milestones=_C.OPTIM.LR_STEPS) + + return cls.create(_C.OPTIM.LR_DECAY_NAME, optimizer, **kwargs) diff --git a/virtex/virtex/model_zoo/__init__.py b/virtex/virtex/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ada912fc95eb52edc418d1dbad7c4392295f30 --- /dev/null +++ b/virtex/virtex/model_zoo/__init__.py @@ -0,0 +1,3 @@ +from .model_zoo import get + +__all__ = ["get"] diff --git a/virtex/virtex/model_zoo/model_zoo.py b/virtex/virtex/model_zoo/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..333b68a25c72090669468c2c6da6b51dfaf757f7 --- /dev/null +++ b/virtex/virtex/model_zoo/model_zoo.py @@ -0,0 +1,105 @@ +r""" +A utility module to easily load common VirTex models (optionally with pretrained +weights) using a single line of code. + +Get our full best performing VirTex model (with pretrained weights as): + +>>> import virtex.model_zoo as mz +>>> model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True) + +Any config available in ``configs/`` directory under project root can be +specified here, although this command need not be executed from project root. +For more details on available models, refer :doc:`usage/model_zoo`. + +Part of this code is adapted from Detectron2's model zoo; which was originally +implemented by the developers of this codebase, with reviews and further +changes by Detectron2 developers. +""" +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import os +import pkg_resources + +from fvcore.common.download import download +import torch + +from virtex.config import Config +from virtex.factories import PretrainingModelFactory +from virtex.utils.checkpointing import CheckpointManager + + +class _ModelZooUrls(object): + r"""Mapping from config names to URL suffixes of pretrained weights.""" + + URL_PREFIX = "https://umich.box.com/shared/static" + + CONFIG_PATH_TO_URL_SUFFIX = { + + # Pretraining Task Ablations + "task_ablations/bicaptioning_R_50_L1_H2048.yaml": "zu8zxtxrron29icd76owgjzojmfcgdk3.pth", + "task_ablations/captioning_R_50_L1_H2048.yaml": "1q9qh1cj2u4r5laj7mefd2mlzwthnga7.pth", + "task_ablations/token_classification_R_50.yaml": "idvoxjl60pzpcllkbvadqgvwazil2mis.pth", + "task_ablations/multilabel_classification_R_50.yaml": "yvlflmo0klqy3m71p6ug06c6aeg282hy.pth", + "task_ablations/masked_lm_R_50_L1_H2048.yaml": "x3eij00eslse9j35t9j9ijyj8zkbkizh.pth", + + # Width Ablations + "width_ablations/bicaptioning_R_50_L1_H512.yaml": "wtk18v0vffws48u5yrj2qjt94wje1pit.pth", + "width_ablations/bicaptioning_R_50_L1_H768.yaml": "e94n0iexdvksi252bn7sm2vqjnyt9okf.pth", + "width_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth", + "width_ablations/bicaptioning_R_50_L1_H2048.yaml": "zu8zxtxrron29icd76owgjzojmfcgdk3.pth", + + # Depth Ablations + "depth_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth", + "depth_ablations/bicaptioning_R_50_L2_H1024.yaml": "9e88f6l13a9r8wq5bbe8qnoh9zenanq3.pth", + "depth_ablations/bicaptioning_R_50_L3_H1024.yaml": "4cv8052xiq91h7lyx52cp2a6m7m9qkgo.pth", + "depth_ablations/bicaptioning_R_50_L4_H1024.yaml": "bk5w4471mgvwa5mv6e4c7htgsafzmfm0.pth", + + # Backbone Ablations + "backbone_ablations/bicaptioning_R_50_L1_H1024.yaml": "1so9cu9y06gy27rqbzwvek4aakfd8opf.pth", + "backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml": "19vcaf1488945836kir9ebm5itgtugaw.pth", + "backbone_ablations/bicaptioning_R_101_L1_H1024.yaml": "nptbh4jsj0c0kjsnc2hw754fkikpgx9v.pth", + } + + +def get(config_path, pretrained: bool = False): + r""" + Get a model specified by relative path under Detectron2's official + ``configs/`` directory. + + Parameters + ---------- + config_path: str + Name of config file relative to ``configs/`` directory under project + root. (For example, ``width_ablations/bicaptioning_R_50_L1_H2048.yaml``) + pretrained: bool, optional (default = False) + If ``True``, will initialize the model with the pretrained weights. If + ``False``, the weights will be initialized randomly. + """ + + # Get the original path to config file (shipped with inside the package). + _pkg_config_path = pkg_resources.resource_filename( + "virtex.model_zoo", os.path.join("configs", config_path) + ) + if not os.path.exists(_pkg_config_path): + raise RuntimeError("{} not available in Model Zoo!".format(config_path)) + + _C = Config(_pkg_config_path) + model = PretrainingModelFactory.from_config(_C) + + if pretrained: + # Get URL for the checkpoint for this config path. + if config_path in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX: + url_suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX[config_path] + checkpoint_url = f"{_ModelZooUrls.URL_PREFIX}/{url_suffix}" + else: + raise RuntimeError("{} not available in Model Zoo!".format(config_path)) + + # Download the pretrained model weights and save with a sensible name. + # This will be downloaded only if it does not exist. + checkpoint_path = download( + checkpoint_url, + dir=os.path.expanduser("~/.torch/virtex_cache"), + filename=os.path.basename(config_path).replace(".yaml", ".pth") + ) + CheckpointManager(model=model).load(checkpoint_path) + + return model diff --git a/virtex/virtex/models/__init__.py b/virtex/virtex/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58b32e9e44a5feb015822614ac4aaa5c830ea712 --- /dev/null +++ b/virtex/virtex/models/__init__.py @@ -0,0 +1,22 @@ +from .captioning import ( + ForwardCaptioningModel, + BidirectionalCaptioningModel, + VirTexModel +) +from .masked_lm import MaskedLMModel +from .classification import ( + MultiLabelClassificationModel, + TokenClassificationModel, +) +from .contrastive import ImageTextContrastiveModel + + +__all__ = [ + "VirTexModel", + "BidirectionalCaptioningModel", + "ForwardCaptioningModel", + "MaskedLMModel", + "MultiLabelClassificationModel", + "TokenClassificationModel", + "ImageTextContrastiveModel", +] diff --git a/virtex/virtex/models/captioning.py b/virtex/virtex/models/captioning.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8e94248c45c9222369e3d11b8f9de0c82e9bd7 --- /dev/null +++ b/virtex/virtex/models/captioning.py @@ -0,0 +1,316 @@ +import copy +import functools +from typing import Any, Dict + +import torch +from torch import nn + +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing +from virtex.modules.textual_heads import TextualHead +from virtex.modules.visual_backbones import VisualBackbone + + +class CaptioningModel(nn.Module): + r""" + A model to perform image captioning (in both forward and backward directions + independently, only in forward direction). It is composed of a + :class:`~virtex.modules.visual_backbones.VisualBackbone` and a + :class:`~virtex.modules.textual_heads.TextualHead` on top of it. + + During training, it maximizes the likelihood of ground truth caption + conditioned on image features. During inference, it predicts a caption for + an input image through beam search decoding. + + Parameters + ---------- + visual: virtex.modules.visual_backbones.VisualBackbone + A :class:`~virtex.modules.visual_backbones.VisualBackbone` which + computes visual features from an input image. + textual: virtex.modules.textual_heads.TextualHead + A :class:`~virtex.modules.textual_heads.TextualHead` which + makes final predictions conditioned on visual features. + sos_index: int, optional (default = 1) + The index of the end token (``[SOS]``) in vocabulary. + eos_index: int, optional (default = 2) + The index of the end token (``[EOS]``) in vocabulary. + caption_backward: bool, optional (default = False) + Whether to *also* perform captioning in backward direction. Default is + ``False`` -- only forward captioning is performed. When ``True``, a + clone of textual head is created, which does not share weights with + "forward" model except input and output embeddings. + decoder: Any, optional (default = None) + An instance of :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch` + or :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling` + for decoding captions during inference (unused during training). + """ + + def __init__( + self, + visual: VisualBackbone, + textual: TextualHead, + caption_backward: bool = False, + sos_index: int = 1, + eos_index: int = 2, + label_smoothing: float = 0.0, + decoder: Any = None, + ): + super().__init__() + self.visual = visual + self.textual = textual + self.padding_idx = self.textual.padding_idx + self.caption_backward = caption_backward + + # Clone the textual module for backward direction if doing captioning + # in both directions (separately). + if self.caption_backward: + self.backward_textual = copy.deepcopy(self.textual) + + # Share weights for visual projection, and input/output embeddings. + self.backward_textual.visual_projection = self.textual.visual_projection + self.backward_textual.embedding = self.textual.embedding + self.backward_textual.output = self.textual.output + + # These boundary indices are needed for beam search. + self.sos_index = sos_index + self.eos_index = eos_index + self.decoder = decoder + self.loss = CrossEntropyLossWithLabelSmoothing( + label_smoothing, ignore_index=self.padding_idx + ) + + def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: + r""" + Given a batch of images and captions, compute log likelihood loss per + caption token during training. During inference (with images), predict + a caption through either beam search decoding or nucleus sampling. + + Parameters + ---------- + batch: Dict[str, torch.Tensor] + Training or inference batch. During training, a batch would at least + contain keys ``{"image", "caption_tokens", "caption_lengths"}`` and + also ``"noitpac_tokens"`` for bicaptioning. + During inference, a batch would contain key ``{"image"}`` and + optionally ``"decode_prompt"`` as a partial sequence for decoding. + + Returns + ------- + Dict[str, Any] + + A dict with the following structure, containing loss for optimization, + loss components to log directly to tensorboard, and optionally + predictions. + + .. code-block:: + + { + "loss": torch.Tensor, + "loss_components": { + "captioning_forward": torch.Tensor, + "captioning_backward": torch.Tensor, (optional) + }, + "predictions": torch.Tensor + } + """ + + # shape: (batch_size, channels, height, width) + visual_features = self.visual(batch["image"]) + batch_size = visual_features.size(0) + + if "caption_tokens" in batch: + caption_tokens = batch["caption_tokens"] + caption_lengths = batch["caption_lengths"] + + # shape: (batch_size, max_caption_length, vocab_size) + output_logits = self.textual( + visual_features, caption_tokens, caption_lengths + ) + loss = self.loss( + output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size), + caption_tokens[:, 1:].contiguous().view(-1), + ) + output_dict: Dict[str, Any] = { + "loss": loss, + # Single scalar per batch for logging in training script. + "loss_components": {"captioning_forward": loss.clone().detach()}, + } + # Do captioning in backward direction if specified. + if self.caption_backward: + backward_caption_tokens = batch["noitpac_tokens"] + + backward_output_logits = self.backward_textual( + visual_features, backward_caption_tokens, caption_lengths + ) + backward_loss = self.loss( + backward_output_logits[:, :-1] + .contiguous() + .view(-1, self.textual.vocab_size), + backward_caption_tokens[:, 1:].contiguous().view(-1), + ) + output_dict["loss"] += backward_loss + + # Single scalar per batch for logging in training script. + output_dict["loss_components"].update( + captioning_backward=backward_loss.clone().detach() + ) + + if not self.training: + # During validation (while pretraining), get best prediction + # at every timestep. + output_dict["predictions"] = torch.argmax(output_logits, dim=-1) + else: + if self.decoder is None: + raise ValueError("Decoder for predicting captions is missing!") + + # During inference, decode captions from forward transformer model. + # Check if the batch contains decoding prompt. + if "decode_prompt" in batch: + + # shape: (batch_size, prompt_length) + start_predictions = torch.unsqueeze(batch["decode_prompt"], 0) + start_predictions = start_predictions.repeat(batch_size, 1) + else: + # shape: (batch_size, ) + start_predictions = torch.full( + (batch_size,), self.sos_index, device=visual_features.device + ).long() + + # Add image features as a default argument to match callable + # signature accepted by beam search class (partial captions only). + decoding_step = functools.partial(self.decoding_step, visual_features) + + predicted_caption, _ = self.decoder.search( + start_predictions, decoding_step + ) + output_dict = {"predictions": predicted_caption} + + return output_dict + + def decoding_step( + self, visual_features: torch.Tensor, partial_captions: torch.Tensor + ) -> torch.Tensor: + r""" + Given visual features and a batch of (assumed) partial captions, predict + the logits over output vocabulary tokens for next timestep. This method + is used by :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch` + and :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`. + + .. note:: + + For nucleus sampling, ``beam_size`` will always be 1 (not relevant). + + Parameters + ---------- + projected_visual_features: torch.Tensor + A tensor of shape ``(batch_size, ..., textual_feature_size)`` + with visual features already projected to ``textual_feature_size``. + partial_captions: torch.Tensor + A tensor of shape ``(batch_size * beam_size, timesteps)`` + containing tokens predicted so far -- one for each beam. We need all + prior predictions because our model is auto-regressive. + + Returns + ------- + torch.Tensor + A tensor of shape ``(batch_size * beam_size, vocab_size)`` -- logits + over output vocabulary tokens for next timestep. + """ + + # Expand and repeat image features while doing beam search. + batch_size, channels, height, width = visual_features.size() + beam_size = int(partial_captions.size(0) / batch_size) + if beam_size > 1: + # shape: (batch_size * beam_size, channels, height, width) + visual_features = visual_features.unsqueeze(1).repeat(1, beam_size, 1, 1, 1) + visual_features = visual_features.view( + batch_size * beam_size, channels, height, width + ) + + # Provide caption lengths as current length (irrespective of predicted + # EOS/padding tokens). shape: (batch_size, ) + caption_lengths = torch.ones_like(partial_captions) + if len(caption_lengths.size()) == 2: + caption_lengths = caption_lengths.sum(1) + else: + # Add a timestep. shape: (batch_size, 1) + partial_captions = partial_captions.unsqueeze(1) + + # shape: (batch_size * beam_size, partial_caption_length, vocab_size) + logits = self.textual(visual_features, partial_captions, caption_lengths) + # Return logits from the last timestep. + return logits[:, -1, :] + + def log_predictions( + self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer + ) -> str: + + self.eval() + with torch.no_grad(): + predictions = self.forward(batch)["predictions"] + self.train() + + predictions_str = "" + for tokens, preds in zip(batch["caption_tokens"], predictions): + predictions_str += f""" + Caption tokens : {" ".join(tokens.tolist())} + Predictions (f): {" ".join(preds.tolist())} + + """ + return predictions_str + + +class ForwardCaptioningModel(CaptioningModel): + r""" + Convenient extension of :class:`~virtex.models.captioning.CaptioningModel` + for better readability: this passes ``caption_backward=False`` to super class. + """ + + def __init__( + self, + visual: VisualBackbone, + textual: TextualHead, + sos_index: int = 1, + eos_index: int = 2, + label_smoothing: float = 0.0, + decoder: Any = None, + ): + super().__init__( + visual, + textual, + sos_index=sos_index, + eos_index=eos_index, + caption_backward=False, + label_smoothing=label_smoothing, + decoder=decoder, + ) + + +class BidirectionalCaptioningModel(CaptioningModel): + r""" + Convenient extension of :class:`~virtex.models.captioning.CaptioningModel` + for better readability: this passes ``caption_backward=True`` to super class. + """ + + def __init__( + self, + visual: VisualBackbone, + textual: TextualHead, + sos_index: int = 1, + eos_index: int = 2, + label_smoothing: float = 0.0, + decoder: Any = None, + ): + super().__init__( + visual, + textual, + sos_index=sos_index, + eos_index=eos_index, + caption_backward=True, + label_smoothing=label_smoothing, + decoder=decoder, + ) + + +# Convenient handle for our main model. +VirTexModel = BidirectionalCaptioningModel diff --git a/virtex/virtex/models/classification.py b/virtex/virtex/models/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..6a9dd60bc7108905d89ec1816beb51773682ef18 --- /dev/null +++ b/virtex/virtex/models/classification.py @@ -0,0 +1,183 @@ +from typing import Any, Dict, List + +import torch +from torch import nn +from torch.nn import functional as F + +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.modules.textual_heads import TextualHead +from virtex.modules.visual_backbones import VisualBackbone + + +class ClassificationModel(nn.Module): + r""" + A model to perform classification (generally, with multiple targets). It is + composed of a :class:`~virtex.modules.visual_backbones.VisualBackbone` and a + :class:`~virtex.modules.textual_heads.TextualHead` on top of it. + + .. note:: + + As with currently available textual heads, only one textual head is + supported here: :class:`~virtex.modules.textual_heads.LinearTextualHead`. + + During training, it minimizes the KL-divergence loss with a K-hot vector, + with values ``1/K``, where K are the number of unique labels to classify. + + Parameters + ---------- + visual: virtex.modules.visual_backbones.VisualBackbone + A :class:`~virtex.modules.visual_backbones.VisualBackbone` which + computes visual features from an input image. + textual: virtex.modules.textual_heads.TextualHead + A :class:`~virtex.modules.textual_heads.TextualHead` which + makes final predictions conditioned on visual features. + ignore_indices: List[int] + Ignore a set of token indices while computing KL-divergence loss. These + are usually the special tokens such as ``[SOS]``, ``[EOS]`` etc. + """ + + def __init__( + self, visual: VisualBackbone, textual: TextualHead, ignore_indices: List[int] + ): + super().__init__() + self.visual = visual + self.textual = textual + self.ignore_indices = ignore_indices + + def forward(self, batch: Dict[str, torch.Tensor]): + r""" + Given a batch of images and set of labels, perform classification with + multiple targets by minimizing a KL-divergence loss. + + Parameters + ---------- + batch: Dict[str, torch.Tensor] + A batch of images and labels. Possible set of keys: + ``{"image_id", "image", "labels"}`` + + Returns + ------- + Dict[str, Any] + + A dict with the following structure, containing loss for optimization, + loss components to log directly to tensorboard, and optionally + predictions. + + .. code-block:: + + { + "loss": torch.Tensor, + "loss_components": { + "classification": torch.Tensor, + }, + "predictions": torch.Tensor + } + """ + + # shape: (batch_size, visual_feature_size, ...) + visual_features = self.visual(batch["image"]) + batch_size = visual_features.size(0) + + # Get logits and further log-probabilities. + # shape: (batch_size, vocab_size) + logits = self.textual(visual_features) + logprobs = F.log_softmax(logits, dim=1) + + # Average log-probs per unique token in associated caption to compute + # loss. This is simply cross-entropy with target-vector as a K-hot + # vector. Do in a for-loop, there isn't a straightforward vectorized way. + loss = torch.tensor(0.0, device=logprobs.device) + + for index in range(batch_size): + # Get unique labels for particular instance. + unique_labels = batch["labels"][index].unique() + + # Ignore indices of special tokens such as [SOS], [EOS] etc. and + # any other token specified. + unique_labels = [l for l in unique_labels if l not in self.ignore_indices] + # Get log-probabilities corresponding to these tokens. + instance_logprobs = logprobs[index, unique_labels].mean() + + # Accumulate negative log-probability for this instance in loss. + loss = loss - instance_logprobs + + # Average loss across instances. + output_dict: Dict[str, Any] = {"loss": loss / batch_size} + + # Single scalar per batch for logging to tensorboard in training script. + output_dict["loss_components"] = { + "classification": loss.clone().detach() / batch_size + } + # Return top-10 tokens according to log-probabilities during validation. + # Useful for logging. + if not self.training: + top_logprobs, top_tokens = logprobs.topk(k=10, dim=1) + output_dict["predictions"] = top_tokens + + return output_dict + + +class TokenClassificationModel(ClassificationModel): + r""" + Convenient extension of :class:`~virtex.models.classification.ClassificationModel` + for better readability (this only modifies the tensorboard logging logic). + + Ground truth targets here are a set of unique caption tokens (ignoring the + special tokens like ``[SOS]``, ``[EOS]`` etc.). + """ + + def log_predictions( + self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer + ) -> str: + + self.eval() + with torch.no_grad(): + predictions = self.forward(batch)["predictions"] + self.train() + + predictions_str = "" + for tokens, preds in zip(batch["caption_tokens"], predictions): + # Predictions here are individual tokens, and do not have any order + # like captions, so decode them separately so we don't strip off + # metaspace character and special tokens if any. + preds = [tokenizer.id_to_token(p) for p in preds.tolist()] + predictions_str += f""" + Caption tokens : {tokenizer.decode(tokens.tolist())} + Predictions (f): {" ".join(preds)} + + """ + return predictions_str + + +class MultiLabelClassificationModel(ClassificationModel): + r""" + Convenient extension of :class:`~virtex.models.classification.ClassificationModel` + for better readability (this only modifies the tensorboard logging logic). + + Ground truth targets here are a set of unique instances in images (ignoring + the special background token, category id = 0 in COCO). + """ + + def log_predictions( + self, + batch: Dict[str, torch.Tensor], + tokenizer: SentencePieceBPETokenizer = None, + ) -> str: + # We accept `tokenizer` for having consistent API but don't use it here. + self.eval() + with torch.no_grad(): + predictions = self.forward(batch)["predictions"] + self.train() + + predictions_str = "" + for tokens, preds in zip(batch["caption_tokens"], predictions): + # Predictions here are COCO category IDs, let them be as is. + # Sorted ground truth, remove background tokens. + tokens = sorted([t for t in tokens.tolist() if t != 0]) + preds = sorted(preds.tolist()[: len(tokens)]) + predictions_str += f""" + COCO Instance IDs (GT) : {tokens} + COCO Instance IDs (Pred) : {preds} + + """ + return predictions_str diff --git a/virtex/virtex/models/contrastive.py b/virtex/virtex/models/contrastive.py new file mode 100644 index 0000000000000000000000000000000000000000..a3db11a8155aa9d68579cbc0f2663d5c7e3f587b --- /dev/null +++ b/virtex/virtex/models/contrastive.py @@ -0,0 +1,119 @@ +from typing import Any, Dict + +import torch +from torch import nn +import torch.distributed as dist + +from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing +from virtex.modules.textual_heads import TextualHead +from virtex.modules.visual_backbones import VisualBackbone + + +class ImageTextContrastiveModel(nn.Module): + def __init__( + self, + visual: VisualBackbone, + textual: TextualHead, + label_smoothing: float = 0.0 + ): + super().__init__() + self.visual = visual + self.textual = textual + self.padding_idx = self.textual.padding_idx + + self.visual_projection = nn.Linear( + self.visual.visual_feature_size, + self.textual.textual_feature_size, + bias=False, + ) + self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07))) + self.loss = CrossEntropyLossWithLabelSmoothing( + label_smoothing, ignore_index=self.padding_idx + ) + + def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: + + # Check if logit_scale needs to be clipped from last iteration. + self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 3.912) + # 50 times + + # shape: (batch_size, channels, height, width) + visual_features = self.visual(batch["image"]) + batch_size = visual_features.size(0) + + # shape: (batch_size, channels) + visual_features = visual_features.mean(dim=[2, 3]).view(batch_size, -1) + + # shape: (batch_size, textual_feature_size) + visual_features = self.visual_projection(visual_features) + + caption_tokens = batch["caption_tokens"] + caption_lengths = batch["caption_lengths"] + + # shape: (batch_size, max_caption_length, hidden_size) + textual_features = self.textual(caption_tokens, caption_lengths) + + # Take features from the first time-step (as BERT-* models do). + # shape: (batch_size, hidden_size) + textual_features = textual_features[:, 0, :] + + # Normalize visual and textual features. + # shape: (batch_size, textual_feature_size) + visual_features = visual_features / visual_features.norm(dim=-1, keepdim=True) + textual_features = textual_features / textual_features.norm( + dim=-1, keepdim=True + ) + # Gather textual features from all processes into one large tensor to + # increase negative samples for contrastive learning. + gathered_textual_features = [ + torch.zeros_like(textual_features) for _ in range(dist.get_world_size()) + ] + dist.all_gather(gathered_textual_features, textual_features) + + # Shift features of current rank to zeroth index for easy implementation. + gathered_textual_features[0], gathered_textual_features[dist.get_rank()] = ( + gathered_textual_features[dist.get_rank()], + gathered_textual_features[0], + ) + # shape: (batch_size * world_size, textual_feature_size) + gathered_textual_features = torch.cat(gathered_textual_features, dim=0) + + # Calculate pairwise cosine similarity as logits. + logit_scale = self.logit_scale.exp() + visual_logits = logit_scale * visual_features @ gathered_textual_features.t() + + # Targets are an identity matrix (image [i] should match with caption [i]) + visual_loss = self.loss( + visual_logits, torch.arange(visual_logits.size(0)).to(visual_logits.device) + ) + + # Do the same thing for visual features. + gathered_visual_features = [ + torch.zeros_like(visual_features) for _ in range(dist.get_world_size()) + ] + dist.all_gather(gathered_visual_features, visual_features) + + gathered_visual_features[0], gathered_visual_features[dist.get_rank()] = ( + gathered_visual_features[dist.get_rank()], + gathered_visual_features[0], + ) + # shape: (batch_size * world_size, textual_feature_size) + gathered_visual_features = torch.cat(gathered_visual_features, dim=0) + + # Calculate pairwise cosine similarity as logits. + logit_scale = self.logit_scale.exp() + textual_logits = logit_scale * textual_features @ gathered_visual_features.t() + + # Targets are an identity matrix (image [i] should match with caption [i]) + textual_loss = self.loss( + textual_logits, + torch.arange(textual_logits.size(0)).to(textual_logits.device), + ) + loss = 0.5 * (visual_loss + textual_loss) + output_dict: Dict[str, Any] = { + "loss": loss, + # Single scalar per batch for logging in training script. + "loss_components": {"contrastive": loss.clone().detach()}, + } + + return output_dict diff --git a/virtex/virtex/models/masked_lm.py b/virtex/virtex/models/masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..398d4cfb9ab8c7db7c83129eb24057313bc28fee --- /dev/null +++ b/virtex/virtex/models/masked_lm.py @@ -0,0 +1,114 @@ +from typing import Any, Dict + +import torch +from torch import nn + +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.modules.textual_heads import TextualHead +from virtex.modules.visual_backbones import VisualBackbone + + +class MaskedLMModel(nn.Module): + r""" + A model to perform BERT-like masked language modeling. It is composed of a + :class:`~virtex.modules.visual_backbones.VisualBackbone` and a + :class:`~virtex.modules.textual_heads.TextualHead` on top of it. + + During training, the model received caption tokens with certain tokens + replaced by ``[MASK]`` token, and it predicts these masked tokens based on + surrounding context. + + Parameters + ---------- + visual: virtex.modules.visual_backbones.VisualBackbone + A :class:`~virtex.modules.visual_backbones.VisualBackbone` which + computes visual features from an input image. + textual: virtex.modules.textual_heads.TextualHead + A :class:`~virtex.modules.textual_heads.TextualHead` which + makes final predictions conditioned on visual features. + """ + + def __init__(self, visual: VisualBackbone, textual: TextualHead): + super().__init__() + self.visual = visual + self.textual = textual + self.padding_idx = self.textual.padding_idx + self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx) + + def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: + r""" + Given a batch of images and captions with certain masked tokens, + predict the tokens at masked positions. + + Parameters + ---------- + batch: Dict[str, torch.Tensor] + A batch of images, ground truth caption tokens and masked labels. + Possible set of keys: ``{"image_id", "image", "caption_tokens", + "masked_labels", "caption_lengths"}``. + + Returns + ------- + Dict[str, Any] + + A dict with the following structure, containing loss for optimization, + loss components to log directly to tensorboard, and optionally + predictions. + + .. code-block:: + + { + "loss": torch.Tensor, + "loss_components": {"masked_lm": torch.Tensor}, + "predictions": torch.Tensor + } + """ + + # shape: (batch_size, channels, height, width) + visual_features = self.visual(batch["image"]) + + caption_tokens = batch["caption_tokens"] + caption_lengths = batch["caption_lengths"] + masked_labels = batch["masked_labels"] + + # shape: (batch_size, num_caption_tokens, vocab_size) + output_logits = self.textual(visual_features, caption_tokens, caption_lengths) + output_dict: Dict[str, Any] = { + "loss": self.loss( + output_logits.view(-1, output_logits.size(-1)), masked_labels.view(-1) + ) + } + # Single scalar per batch for logging in training script. + output_dict["loss_components"] = { + "masked_lm": output_dict["loss"].clone().detach() + } + # During evaluation, get predictions from logits. Useful for logging. + # Only the predictions at [MASK]ed positions are relevant. + if not self.training: + predictions = torch.argmax(output_logits, dim=-1) + redundant_positions = masked_labels == self.padding_idx + predictions[redundant_positions] = self.padding_idx + + output_dict["predictions"] = predictions + + return output_dict + + def log_predictions( + self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer + ) -> str: + + self.eval() + with torch.no_grad(): + predictions = self.forward(batch)["predictions"] + self.train() + + predictions_str = "" + for tokens, labels, preds in zip( + batch["caption_tokens"], batch["masked_labels"], predictions + ): + predictions_str += f""" + Caption tokens : {tokenizer.decode(tokens.tolist())} + Masked Labels : {tokenizer.decode(labels.tolist())} + Predictions : {tokenizer.decode(preds.tolist())} + """ + return predictions_str diff --git a/virtex/virtex/models/zero_shot_classification_eval.py b/virtex/virtex/models/zero_shot_classification_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..85893593c53227811253fca40ff8abf6d443ba56 --- /dev/null +++ b/virtex/virtex/models/zero_shot_classification_eval.py @@ -0,0 +1,112 @@ +import copy +import functools +from typing import Any, Dict + +import json + +import torch +from torch import nn + +from virtex.data.tokenizers import SentencePieceBPETokenizer +from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing +from virtex.modules.textual_heads import TextualHead +from virtex.modules.visual_backbones import VisualBackbone + + +class ZeroShotClassifier(nn.Module): + def __init__( + self, + visual: VisualBackbone, + textual: TextualHead, + ): + super().__init__() + self.visual = visual + self.textual = textual + self.padding_idx = self.textual.padding_idx + + # Clone the textual module for backward direction if doing captioning + # in both directions (separately). + self.backward_textual = copy.deepcopy(self.textual) + + # Share weights for visual projection, and input/output embeddings. + self.backward_textual.visual_projection = self.textual.visual_projection + self.backward_textual.embedding = self.textual.embedding + self.backward_textual.output = self.textual.output + + self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx,reduction='none') + + def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: + + # shape: (batch_size, channels, height, width) + visual_features = self.visual(batch["image"]) + batch_size = visual_features.size(0) + + classification_losses = [] + + #catagories shape: (1000, 20) + + caption_tokens = batch["caption_tokens"] + backward_caption_tokens = batch["noitpac_tokens"] + caption_lengths = batch["caption_lengths"] + print + + for i in range(caption_tokens.shape[0]): + # shape : (batch size, 20) + catagory_caption_tokens = caption_tokens[i,:].unsqueeze(0).repeat(batch_size,1) + # shape : (batch size, 20) + catagory_backward_caption_tokens = backward_caption_tokens[i,:].unsqueeze(0).repeat(batch_size,1) + # shape : (batch size) + catagory_caption_lengths = caption_lengths[i].unsqueeze(0).repeat(batch_size) + + #print("caption_tokens.shape:",caption_tokens.shape) + #print("backward_caption_tokens.shape:",backward_caption_tokens.shape) + #print("caption_lengths.shape:",caption_lengths.shape) + + #print("catagory_caption_tokens.shape:",catagory_caption_tokens.shape) + #print("catagory_backward_caption_tokens.shape:",catagory_backward_caption_tokens.shape) + #print("catagory_caption_lengths.shape:",catagory_caption_lengths.shape) + + output_logits = self.textual( + visual_features, catagory_caption_tokens, catagory_caption_lengths + ) + + + loss = self.loss( + output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size), + catagory_caption_tokens[:, 1:].contiguous().view(-1) + ) + + # Do captioning in backward direction if specified. + backward_output_logits = self.backward_textual( + visual_features, catagory_backward_caption_tokens, catagory_caption_lengths + ) + + + backward_loss = self.loss( + backward_output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size), + catagory_backward_caption_tokens[:, 1:].contiguous().view(-1), + ) + loss = loss.view(batch_size,-1).sum(dim=1) + backward_loss = backward_loss.view(batch_size,-1).sum(dim=1) + + total_scores = (-loss - backward_loss)/catagory_caption_lengths + + + #print("loss.shape:",loss.shape) + #print("backward_loss.shape:",backward_loss.shape) + #print("loss.shape:",loss.shape) + + #scores_caption = [torch.sum(x) for x in torch.chunk(loss, batch_size)] + #scores_noipac = [torch.sum(x) for x in torch.chunk(backward_loss, batch_size)] + + #total_scores = [(scores_caption[j]+scores_noipac[j]).item() for j in range(batch_size)] + + classification_losses.append(total_scores) + + + #classification_losses = torch.tensor(classification_losses) + classification_losses = torch.stack(classification_losses).t() + + return classification_losses + + diff --git a/virtex/virtex/modules/embedding.py b/virtex/virtex/modules/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e578f4327693aa17e03f86482597fbd203c90f5b --- /dev/null +++ b/virtex/virtex/modules/embedding.py @@ -0,0 +1,96 @@ +import functools + +import torch +from torch import nn + + +class WordAndPositionalEmbedding(nn.Module): + r""" + A :class:`~torch.nn.Module` for learned word embeddings and position + embeddings for input tokens. Each token is mapped to a fixed dimensional + word embedding; and corresponding positional embedding based on its index. + These are summed together followed by layer normalization and an optional + dropout. + + Parameters + ---------- + vocab_size: int + Size of token vocabulary. + hidden_size: int + Size of token embedding vectors. + dropout: float, optional (default = 0.1) + Dropout probability for final dropout applied after layer normalization. + max_caption_length: int, optional (default = 30) + Maximum length of input captions; this is used to create a fixed + positional embedding lookup table. + padding_idx: int, optional (default = 0) + Token index of ``[PAD]`` token, word embedding for these tokens will + be a vector of zeroes (and not trainable). + """ + def __init__( + self, + vocab_size: int, + hidden_size: int, + dropout: float = 0.0, + max_caption_length: int = 30, + padding_idx: int = 0, + ): + super().__init__() + self.vocab_size = vocab_size + self.padding_idx = padding_idx + + self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx) + + # We provide no "padding index" for positional embeddings. We zero out + # the positional embeddings of padded positions as a post-processing. + self.positions = nn.Embedding(max_caption_length, hidden_size) + self.layer_norm = nn.LayerNorm( + hidden_size, eps=1e-8, elementwise_affine=True + ) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + r""" + Get combined word and positional embeddings for input tokens. + + Parameters + ---------- + tokens: torch.Tensor + A tensor of shape ``(batch_size, max_caption_length)`` containing + a batch of caption tokens, with values in ``[0, vocab_size)``. + + Returns + ------- + torch.Tensor + A tensor of shape ``(batch_size, max_caption_length, hidden_size)`` + containing corresponding token embeddings. + """ + position_indices = self._create_position_indices(tokens) + + # shape: (batch_size, max_caption_length, hidden_size) + word_embeddings = self.words(tokens) + position_embeddings = self.positions(position_indices) + + # shape: (batch_size, max_caption_length, hidden_size) + embeddings = self.layer_norm(word_embeddings + position_embeddings) + embeddings = self.dropout(embeddings) + + # Zero-out embeddings for positions which have padding tokens. + # shape: (batch_size, max_caption_length, 1) + token_mask = (tokens != self.padding_idx).unsqueeze(-1) + + # shape: (batch_size, max_caption_length, hidden_size) + embeddings = embeddings * token_mask.type(embeddings.dtype) + return embeddings + + @functools.lru_cache(maxsize=128) + def _create_position_indices(self, tokens: torch.Tensor): + + # Create position indices of the same size as token indices. + batch_size, max_caption_length = tokens.size() + positions = torch.arange( + max_caption_length, dtype=tokens.dtype, device=tokens.device + ) + # shape: (batch_size, max_caption_length) + positions = positions.unsqueeze(0).expand(batch_size, max_caption_length) + return positions diff --git a/virtex/virtex/modules/label_smoothing.py b/virtex/virtex/modules/label_smoothing.py new file mode 100644 index 0000000000000000000000000000000000000000..ce62ce62b7315b4ec65a8eae5dc614ad0778661e --- /dev/null +++ b/virtex/virtex/modules/label_smoothing.py @@ -0,0 +1,58 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class CrossEntropyLossWithLabelSmoothing(nn.Module): + r""" + PyTorch :class:`~torch.nn.CrossEntropyLoss` with label smoothing. Quoting + documentation from original PyTorch module: + + It is useful when training a classification problem with ``C`` classes. + The ``inputs`` is expected to contain raw, unnormalized scores for each class. + + ``inputs`` has to be a Tensor of size either ``(N, C)``. This criterion + expects a class index in the range ``[0, C - 1]`` as the ``targets`` for each + value of a 1D tensor of size ``minibatch``; if ``ignore_index`` is specified, + this criterion also accepts this class index (this index may not necessarily be + in the class range). + + Parameters + ---------- + smoothing: float, optional (default = 0.1) + Label smoothing value. It sets target weights as ``(1 - smoothing)`` + and all other weights as ``smoothing / (C - 1)``. Setting this to + zero will default to vanilla cross entropy loss. + """ + + def __init__(self, smoothing: float = 0.0, ignore_index: int = -100): + super().__init__() + self.smoothing = smoothing + self.ignore_index = ignore_index + + def forward(self, inputs, targets): + + if self.smoothing == 0.0: + # Use PyTorch cross entropy when smoothing is 0. This is slightly + # faster than what we are doing manually below. + return F.cross_entropy( + inputs, targets, ignore_index=self.ignore_index, reduction="mean" + ) + + # Remove entries matching ``ignore_index``. + if self.ignore_index >= 0: + _targets = targets[targets != self.ignore_index] + _inputs = inputs[targets != self.ignore_index] + + # shape: (batch_size, num_classes) + logprobs = F.log_softmax(_inputs, dim=-1) + + # shape: (batch_size, num_classes) + weights = ( + torch.ones_like(_inputs) * self.smoothing / (_inputs.size(-1) - 1.0) + ) + weights.scatter_(-1, _targets.unsqueeze(-1), (1.0 - self.smoothing)) + + # shape: (batch_size, ) + loss = (- weights * logprobs).sum(dim=-1) + return loss.mean() diff --git a/virtex/virtex/modules/textual_heads.py b/virtex/virtex/modules/textual_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..5a5efb266e361e72ac6ca0e037116eed624da56b --- /dev/null +++ b/virtex/virtex/modules/textual_heads.py @@ -0,0 +1,450 @@ +r""" +A textual head accepts visual features from the visual backbone, and performs +task specific modeling (captioning, classification etc.) to predict an output +distribution over vocabulary tokens for one or multiple time-steps in the batch. +""" +import torch +from torch import nn +from typing import Optional + +from virtex.modules.embedding import WordAndPositionalEmbedding +from virtex.modules.transformer import ( + PreNormTransformerEncoderLayer, + PreNormTransformerDecoderLayer, +) + + +class TextualHead(nn.Module): + r""" + Base class for all textual heads. All child classes can simply inherit + from :class:`~torch.nn.Module`, however this is kept here for uniform + type annotations. + + Parameters + ---------- + visual_feature_size: int + Size (number of channels) of the input features from the visual backbone. + vocab_size: int + Number of tokens in the output vocabulary. + hidden_size: int + Size of the token embedding vectors, or hidden state vector of the + language model. + """ + + def __init__(self, visual_feature_size: int, vocab_size: int, hidden_size: int): + super().__init__() + self.visual_feature_size = visual_feature_size + self.vocab_size = vocab_size + self.hidden_size = hidden_size + + @property + def textual_feature_size(self): + r""" + Size of the last dimension of output right before the output linear + layer (which predicts a distribution over vocabulary tokens). This is + typically same as :attr:`hidden_size` for most modules. This property + is used to add more modules on top of this. + """ + return self.hidden_size + + +class LinearTextualHead(TextualHead): + r""" + A textual head containing a single linear layer projecting from the visual + feature size to the output vocabulary size. + + Parameters + ---------- + visual_feature_size: int + Size (number of channels) of the input features from the visual backbone. + vocab_size: int + Number of tokens in the output vocabulary. + """ + + def __init__(self, visual_feature_size: int, vocab_size: int, **kwargs): + # For API consistency. + hidden_size = visual_feature_size + super().__init__(visual_feature_size, vocab_size, hidden_size) + self.output = nn.Linear(visual_feature_size, vocab_size) + + def forward( + self, + visual_features: torch.Tensor, + caption_tokens: Optional[torch.Tensor] = None, + caption_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Project visual features directly to predict a distribution over + vocabulary tokens through a single linear layer. This textual head + ignores arguments ``caption_tokens`` and ``caption_lengths``, they + are here for API consistency. + + Parameters + ---------- + visual_features: torch.Tensor + A tensor of shape ``(batch_size, channels, height, width)`` containing + features from visual backbone. + + Returns + ------- + torch.Tensor + A tensor of shape ``(batch_size, vocab_size)`` containing output + vocabulary logits. + """ + + # Convert to NHWC and project visual features to textual feature size. + batch_size, channels, height, width = visual_features.size() + visual_features = visual_features.view(batch_size, channels, -1) + visual_features = visual_features.permute(0, 2, 1) + + # Perform global average pooling of visual features. + # shape: (batch_size, channels) + visual_features = visual_features.mean(dim=1) + + # shape: (batch_size, max_caption_length, vocab_size) + output_logits = self.output(visual_features) + return output_logits + + +class TransformerDecoderTextualHead(TextualHead): + r""" + A textual head composed of four main modules: (1) input projection (linear + layer) for visual features to match size with textual features, (2) word + and positional embedding for input captions, (3) a unidirectional transformer + decoder, and (4) and output projection (linear layer) to predict a + distribution over vocabulary tokens. The word embedding weights are tied + with output projection; the latter still has its own learnable bias. + + .. note:: + + For the "bicaptioning" pretraining task, our *textual head* (as defined + in the paper) must have two transformer decoders: one each to decode + caption in either direction. This class however will always have one + transformer per object. + + Refer :class:`~virtex.models.captioning.BidirectionalCaptioningModel` + source to understand how an object of this class is cloned, along with + tying embedding and output weights, for bicaptioning. + + Hence, while there are *two objects* of this class, it is pragmatically + a *single* textual head as a whole, according to the terminology used + in paper. + + Parameters + ---------- + visual_feature_size: int + Size (number of channels) of the input features from the visual backbone. + vocab_size: int + Number of tokens in the output vocabulary. + hidden_size: int + Size of the token embedding vectors, or hidden state vector of the + language model. + num_layers: int + Number of layers in the transformer. + attention_heads: int + Number of attention heads in the transformer. + feedforward_size: int + Size of feedforward layers in the transformer. + dropout: float, optional (default = 0.1) + Dropout probability for transformer (applied after layer normalization). + norm_type: str, optional (default = "post") + Type of transformer layer: pre-normalization (like GPT-2) or + post-normalization (like BERT). One of ``{"pre", "post"}``. + mask_future_positions: bool, optional (default = True) + Whether to mask future positions for self-attention over caption tokens. + This must be ``True`` for captioning (and bicaptioning) tasks to prevent + the language model from cheating, and ``False`` for masked language + modeling, as the self-attention should consider all tokens. + max_caption_length: int, optional (default = 30) + Maximum length of input captions; this is used to create a fixed + positional embedding lookup table. + padding_idx: int, optional (default = 0) + Token index of ``[PAD]`` token, word embedding for these tokens will + be a vector of zeroes (and not trainable). + """ + + def __init__( + self, + visual_feature_size: int, + vocab_size: int, + hidden_size: int, + num_layers: int, + attention_heads: int, + feedforward_size: int, + dropout: float = 0.1, + norm_type: str = "post", + mask_future_positions: bool = True, + max_caption_length: int = 30, + padding_idx: int = 0, + ): + super().__init__(visual_feature_size, vocab_size, hidden_size) + self.num_layers = num_layers + self.attention_heads = attention_heads + self.feedforward_size = feedforward_size + self.dropout = dropout + self.mask_future_positions = mask_future_positions + self.padding_idx = padding_idx + + self.visual_projection = nn.Linear( + visual_feature_size, self.textual_feature_size + ) + self.embedding = WordAndPositionalEmbedding( + self.vocab_size, + self.textual_feature_size, + dropout=dropout, + max_caption_length=max_caption_length, + padding_idx=padding_idx, + ) + # Make decoder layer depending on whether it's a Pre-Norm or Post-Norm. + LayerClass = ( + nn.TransformerDecoderLayer + if norm_type == "post" + else PreNormTransformerDecoderLayer + ) + _layer = LayerClass( + self.textual_feature_size, + self.attention_heads, + dim_feedforward=self.feedforward_size, + dropout=dropout, + activation="gelu", + ) + self.transformer = nn.TransformerDecoder(_layer, self.num_layers) + self.apply(self._init_weights) + + # Create an output linear layer and tie the input and output word + # embeddings to reduce parameters. + self.output = nn.Linear(self.textual_feature_size, vocab_size) + self.output.weight = self.embedding.words.weight + + @staticmethod + def _init_weights(module): + r"""Initialize weights like BERT - N(0.0, 0.02), bias = 0.""" + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.MultiheadAttention): + module.in_proj_weight.data.normal_(mean=0.0, std=0.02) + module.out_proj.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def forward( + self, + visual_features: torch.Tensor, + caption_tokens: torch.Tensor, + caption_lengths: torch.Tensor, + ) -> torch.Tensor: + r""" + Given (projected) visual features from visual backbone and caption + tokens, predict the output logits for next time-step. + + Parameters + ---------- + visual_features: torch.Tensor + A tensor of shape ``(batch_size, channels, height, width)`` containing + features from visual backbone. + caption_tokens: torch.Tensor + A tensor of shape ``(batch_size, max_caption_length)`` of caption + tokens padded to the right by ``padding_idx``. + caption_lengths: torch.Tensor + A tensor of shape ``(batch_size, )`` containing lengths of caption + tokens in the batch. + + Returns + ------- + torch.Tensor + A tensor of shape ``(batch_size, max_caption_length, vocab_size)`` + containing output vocabulary logits for each time-step. + """ + + # Convert to NHWC and project visual features to textual feature size. + batch_size, channels, height, width = visual_features.size() + visual_features = visual_features.view(batch_size, channels, -1) + visual_features = visual_features.permute(0, 2, 1) + + # shape: (batch_size, height * width, textual_feature_size) + projected_visual_features = self.visual_projection(visual_features) + # Now visual and textual features are of same size. + + # Note that `max_caption_length` here may be less than the + # `max_caption_length` passed in `__init__`, but it does not matter. + batch_size, max_caption_length = caption_tokens.size() + + # Create a mask based on caption lengths, shape: (batch_size, ) + # Form a binary mask: it is True for padding positions. + # These positions will be ignored for multi-headed attention. + ones = torch.ones_like(caption_tokens) + caption_mask = caption_lengths.unsqueeze(1) < ones.cumsum(dim=1) + + # shape: (batch_size, max_caption_length, textual_feature_size) + caption_embeddings = self.embedding(caption_tokens) + + if self.mask_future_positions: + # An additive mask for masking the future (one direction). + unidirectional_mask = self._generate_future_mask( + max_caption_length, caption_embeddings.dtype, caption_embeddings.device + ) + else: + unidirectional_mask = None + + # We transpose the first two dimensions of tokens embeddings and visual + # features, as required by decoder. + caption_embeddings = caption_embeddings.transpose(0, 1) + projected_visual_features = projected_visual_features.transpose(0, 1) + + # shape: (max_caption_length, batch_size, hidden_size) + textual_features = self.transformer( + caption_embeddings, + projected_visual_features, + tgt_mask=unidirectional_mask, + tgt_key_padding_mask=caption_mask, + ) + # Undo the transpose and bring batch to dim 0. + # shape: (batch_size, max_caption_length, hidden_size) + textual_features = textual_features.transpose(0, 1) + + # shape: (batch_size, max_caption_length, vocab_size) + output_logits = self.output(textual_features) + return output_logits + + def _generate_future_mask( + self, size: int, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + r""" + Generate a mask for "future" positions, useful when using this module + for language modeling. + + Parameters + ---------- + size: int + """ + # Default mask is for forward direction. Flip for backward direction. + mask = torch.triu( + torch.ones(size, size, device=device, dtype=dtype), diagonal=1 + ) + mask = mask.masked_fill(mask == 1, float("-inf")) + return mask + + +class TransformerEncoderTextualHead(TextualHead): + def __init__( + self, + visual_feature_size: int, + vocab_size: int, + hidden_size: int, + num_layers: int, + attention_heads: int, + feedforward_size: int, + dropout: float = 0.1, + norm_type: str = "pre", + mask_future_positions: bool = True, + max_caption_length: int = 30, + padding_idx: int = 0, + ): + super().__init__(visual_feature_size, vocab_size, hidden_size) + self.num_layers = num_layers + self.attention_heads = attention_heads + self.feedforward_size = feedforward_size + self.dropout = dropout + self.mask_future_positions = mask_future_positions + self.padding_idx = padding_idx + + self.embedding = WordAndPositionalEmbedding( + self.vocab_size, + self.textual_feature_size, + dropout=dropout, + max_caption_length=max_caption_length, + padding_idx=padding_idx, + ) + # Make decoder layer depending on whether it's a Pre-Norm or Post-Norm. + LayerClass = ( + nn.TransformerEncoderLayer + if norm_type == "post" + else PreNormTransformerEncoderLayer + ) + _layer = LayerClass( + self.textual_feature_size, + self.attention_heads, + dim_feedforward=self.feedforward_size, + dropout=dropout, + activation="gelu", + ) + self.transformer = nn.TransformerEncoder(_layer, self.num_layers) + + self.final_ln = nn.LayerNorm(self.textual_feature_size) + self._init_weights() + + def _init_weights(self): + nn.init.normal_(self.embedding.words.weight, std=0.02) + nn.init.normal_(self.embedding.positions.weight, std=0.01) + + proj_std = (self.hidden_size ** -0.5) * ((2 * self.num_layers) ** -0.5) + for layer in self.transformer.layers: + nn.init.normal_(layer.self_attn.in_proj_weight, std=self.hidden_size ** -0.5) + nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std) + nn.init.normal_(layer.linear1.weight, std=(2 * self.hidden_size) ** -0.5) + nn.init.normal_(layer.linear2.weight, std=proj_std) + + def forward( + self, + caption_tokens: torch.Tensor, + caption_lengths: torch.Tensor, + ) -> torch.Tensor: + + # Note that `max_caption_length` here may be less than the + # `max_caption_length` passed in `__init__`, but it does not matter. + batch_size, max_caption_length = caption_tokens.size() + + # Create a mask based on caption lengths, shape: (batch_size, ) + # Form a binary mask: it is True for padding positions. + # These positions will be ignored for multi-headed attention. + ones = torch.ones_like(caption_tokens) + caption_mask = caption_lengths.unsqueeze(1) < ones.cumsum(dim=1) + + # shape: (batch_size, max_caption_length, textual_feature_size) + caption_embeddings = self.embedding(caption_tokens) + + if self.mask_future_positions: + # An additive mask for masking the future (one direction). + unidirectional_mask = self._generate_future_mask( + max_caption_length, caption_embeddings.dtype, caption_embeddings.device + ) + else: + unidirectional_mask = None + + # We transpose the first two dimensions of tokens embeddings and visual + # features, as required by decoder. + caption_embeddings = caption_embeddings.transpose(0, 1) + + # shape: (max_caption_length, batch_size, hidden_size) + textual_features = self.transformer( + caption_embeddings, + mask=unidirectional_mask, + src_key_padding_mask=caption_mask, + ) + # Undo the transpose and bring batch to dim 0. + # shape: (batch_size, max_caption_length, hidden_size) + textual_features = textual_features.transpose(0, 1) + textual_features = self.final_ln(textual_features) + return textual_features + + @staticmethod + def _generate_future_mask( + size: int, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + r""" + Generate a mask for "future" positions, useful when using this module + for language modeling. + + Parameters + ---------- + size: int + """ + # Default mask is for forward direction. Flip for backward direction. + mask = torch.triu( + torch.ones(size, size, device=device, dtype=dtype), diagonal=1 + ) + mask = mask.masked_fill(mask == 1, float("-inf")) + return mask diff --git a/virtex/virtex/modules/transformer.py b/virtex/virtex/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..65fea63cac683bda24f050c6844cbba30519c271 --- /dev/null +++ b/virtex/virtex/modules/transformer.py @@ -0,0 +1,72 @@ +from typing import Optional + +import torch +from torch import nn + + +class PreNormTransformerEncoderLayer(nn.TransformerEncoderLayer): + r""" + A variant of :class:`torch.nn.TransformerEncoderLayer` where layer + normalization is included inside the residual branch, and performed before + self-attention and feedforward layers. + + Refer documentation of :class:`torch.nn.TransformerEncoderLayer` for more + details on the API. + """ + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + # fmt: off + # We use the members (modules) from super-class, just the order of + # operations is changed here. First layernorm, then attention. + src2 = self.norm1(src) + src2 = self.self_attn(src2, src2, src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + + # Layernorm first, then transformation through feedforward network. + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + +class PreNormTransformerDecoderLayer(nn.TransformerDecoderLayer): + r""" + A variant of :class:`torch.nn.TransformerDecoderLayer` where layer + normalization is included inside the residual branch, and performed before + self-attention and feedforward layers. + + Refer documentation of :class:`torch.nn.TransformerDecoderLayer` for more + details on the API. + """ + + def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, + tgt_key_padding_mask=None, memory_key_padding_mask=None): + # fmt: off + # We use the members (modules) from super-class, just the order of + # operations is changed here. First layernorm, then attention. + tgt2 = self.norm1(tgt) + tgt2, _ = self.self_attn( + tgt2, tgt2, tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask + ) + tgt = tgt + self.dropout1(tgt2) + + # Layernorm first, then decoder attention. + tgt2 = self.norm2(tgt) + tgt2, _ = self.multihead_attn( + tgt2, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask + ) + tgt = tgt + self.dropout2(tgt2) + + # Layernorm first, then transformation through feedforward network. + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt diff --git a/virtex/virtex/modules/visual_backbones.py b/virtex/virtex/modules/visual_backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..4f1f9e24f8ab9b0db7cc579e496a33b88fe3796f --- /dev/null +++ b/virtex/virtex/modules/visual_backbones.py @@ -0,0 +1,198 @@ +from typing import Any, Dict + +import torch +from torch import nn +import torchvision + + +class VisualBackbone(nn.Module): + r""" + Base class for all visual backbones. All child classes can simply inherit + from :class:`~torch.nn.Module`, however this is kept here for uniform + type annotations. + """ + + def __init__(self, visual_feature_size: int): + super().__init__() + self.visual_feature_size = visual_feature_size + + +class TorchvisionVisualBackbone(VisualBackbone): + r""" + A visual backbone from `Torchvision model zoo + `_. Any model can + be specified using corresponding method name from the model zoo. + + Parameters + ---------- + name: str, optional (default = "resnet50") + Name of the model from Torchvision model zoo. + visual_feature_size: int, optional (default = 2048) + Size of the channel dimension of output visual features from forward pass. + pretrained: bool, optional (default = False) + Whether to load ImageNet pretrained weights from Torchvision. + frozen: float, optional (default = False) + Whether to keep all weights frozen during training. + """ + + def __init__( + self, + name: str = "resnet50", + visual_feature_size: int = 2048, + pretrained: bool = False, + frozen: bool = False, + ): + super().__init__(visual_feature_size) + + self.cnn = getattr(torchvision.models, name)( + pretrained, zero_init_residual=True + ) + # Do nothing after the final residual stage. + self.cnn.fc = nn.Identity() + + # Freeze all weights if specified. + if frozen: + for param in self.cnn.parameters(): + param.requires_grad = False + self.cnn.eval() + + def forward(self, image: torch.Tensor) -> torch.Tensor: + r""" + Compute visual features for a batch of input images. + + Parameters + ---------- + image: torch.Tensor + Batch of input images. A tensor of shape + ``(batch_size, 3, height, width)``. + + Returns + ------- + torch.Tensor + A tensor of shape ``(batch_size, channels, height, width)``, for + example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50. + """ + + for idx, (name, layer) in enumerate(self.cnn.named_children()): + out = layer(image) if idx == 0 else layer(out) + + # These are the spatial features we need. + if name == "layer4": + # shape: (batch_size, channels, height, width) + return out + + def detectron2_backbone_state_dict(self) -> Dict[str, Any]: + r""" + Return state dict of visual backbone which can be loaded with + `Detectron2 `_. + This is useful for downstream tasks based on Detectron2 (such as + object detection and instance segmentation). This method renames + certain parameters from Torchvision-style to Detectron2-style. + + Returns + ------- + Dict[str, Any] + A dict with three keys: ``{"model", "author", "matching_heuristics"}``. + These are necessary keys for loading this state dict properly with + Detectron2. + """ + # Detectron2 backbones have slightly different module names, this mapping + # lists substrings of module names required to be renamed for loading a + # torchvision model into Detectron2. + DETECTRON2_RENAME_MAPPING: Dict[str, str] = { + "layer1": "res2", + "layer2": "res3", + "layer3": "res4", + "layer4": "res5", + "bn1": "conv1.norm", + "bn2": "conv2.norm", + "bn3": "conv3.norm", + "downsample.0": "shortcut", + "downsample.1": "shortcut.norm", + } + # Populate this dict by renaming module names. + d2_backbone_dict: Dict[str, torch.Tensor] = {} + + for name, param in self.cnn.state_dict().items(): + for old, new in DETECTRON2_RENAME_MAPPING.items(): + name = name.replace(old, new) + + # First conv and bn module parameters are prefixed with "stem.". + if not name.startswith("res"): + name = f"stem.{name}" + + d2_backbone_dict[name] = param + + return { + "model": d2_backbone_dict, + "__author__": "Karan Desai", + "matching_heuristics": True, + } + + +class TimmVisualBackbone(VisualBackbone): + r""" + A visual backbone from `Timm model zoo + `_. + This class is a generic wrapper over the ``timm`` library, and supports + all models provided by the library. Check ``timm.list_models()`` for all + supported model names. + + Parameters + ---------- + name: str, optional (default = "resnet50") + Name of the model from Timm model zoo. + visual_feature_size: int, optional (default = 2048) + Size of the channel dimension of output visual features from forward pass. + pretrained: bool, optional (default = False) + Whether to load ImageNet pretrained weights from Torchvision. + frozen: float, optional (default = False) + Whether to keep all weights frozen during training. + """ + + def __init__( + self, + name: str = "resnet50", + visual_feature_size: int = 2048, + pretrained: bool = False, + frozen: bool = False, + ): + super().__init__(visual_feature_size) + + # Limit the scope of library import inside class definition. + import timm + + # Create the model without any global pooling and softmax classifier. + self.cnn = timm.create_model( + name, pretrained=pretrained, num_classes=0, global_pool="" + ) + # Freeze all weights if specified. + if frozen: + for param in self.cnn.parameters(): + param.requires_grad = False + self.cnn.eval() + + def forward(self, image: torch.Tensor) -> torch.Tensor: + r""" + Compute visual features for a batch of input images. + + Parameters + ---------- + image: torch.Tensor + Batch of input images. A tensor of shape + ``(batch_size, 3, height, width)``. + + Returns + ------- + torch.Tensor + A tensor of shape ``(batch_size, channels, height, width)``, for + example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50. + """ + # shape: (batch_size, channels, height, width) + return self.cnn(image) + + def detectron2_backbone_state_dict(self) -> Dict[str, Any]: + + # Detectron2 may not support all timm models out of the box. These + # backbones won't be transferred to downstream detection tasks anyway. + raise NotImplementedError diff --git a/virtex/virtex/optim/__init__.py b/virtex/virtex/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc624411a49e98c6c24735db7f3c14d076c3a27 --- /dev/null +++ b/virtex/virtex/optim/__init__.py @@ -0,0 +1,3 @@ +from .lookahead import Lookahead + +__all__ = ["Lookahead"] diff --git a/virtex/virtex/optim/lookahead.py b/virtex/virtex/optim/lookahead.py new file mode 100644 index 0000000000000000000000000000000000000000..408aa1f786109c762656e8ae3243bff12412d072 --- /dev/null +++ b/virtex/virtex/optim/lookahead.py @@ -0,0 +1,128 @@ +r""" +`Lookahead Optimizer: k steps forward, 1 step back `_. + +This implementation is adapted with minimal modifications from the +`authors' implementation `_. + +If you take it from here, please cite them: + +.. code-block:: text + + @inproceedings{zhang2019lookahead, + title={Lookahead Optimizer: k steps forward, 1 step back}, + author={Zhang, Michael R and Lucas, James and Hinton, Geoffrey and Ba, Jimmy}, + journal={NeurIPS}, + year={2019} + } +""" +from collections import defaultdict +from typing import Any, Callable, Dict + +import torch +from torch.optim.optimizer import Optimizer + + +class Lookahead(Optimizer): + r""" + Implements Lookahead optimizer. + + Parameters + ---------- + optimizer: torch.optim.Optimizer + Wrapper inner optimizer. The weights it manages will be the "fast" + weights. + k: int, optional (default = 5) + Number of lookahead steps before updating "slow" weights. + alpha: float, optional (default = 0.8) + Linear interpolation factor, 1.0 recovers inner optimizer. + """ + + def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.8): + self.optimizer = optimizer + self.k = k + self.alpha = alpha + + # Counter for inner optimizer. + self._k_counter = 0 + + # Cache the current optimizer parameters + self.state: Dict[str, Any] = defaultdict(dict) + for group in optimizer.param_groups: + for p in group["params"]: + param_state = self.state[p] + param_state["slow_params"] = torch.zeros_like(p.data) + param_state["slow_params"].copy_(p.data) + + def __getstate__(self): + return { + "state": self.state, + "optimizer": self.optimizer, + "alpha": self.alpha, + "k": self.k, + "_k_counter": self._k_counter, + } + + @property + def param_groups(self): + return self.optimizer.param_groups + + def zero_grad(self): + r"""Clear all grad buffers at the start of new forward pass.""" + self.optimizer.zero_grad() + + def state_dict(self): + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict: Dict[str, Any]): + self.optimizer.load_state_dict(state_dict) + + def step(self, closure: Callable = None): + r""" + Perform a single Lookahead optimization step. + + Parameters + ---------- + closure: Callable, optional (default = None) + A callable that re-evaluates the model and returns the loss. + """ + loss = self.optimizer.step(closure) + self._k_counter += 1 + + if self._k_counter >= self.k: + self._k_counter = 0 + # Lookahead and cache the current optimizer parameters + for group in self.optimizer.param_groups: + for p in group["params"]: + param_state = self.state[p] + p.data.mul_(self.alpha).add_( + param_state["slow_params"], alpha=1.0 - self.alpha + ) + param_state["slow_params"].copy_(p.data) + return loss + + def load_slow_weights(self): + r""" + Load slow weights from Lookahead optimizer. Useful for performing + evaluation on the slow weights (which typically generalize better). + + This method backs up fast weights to load them after evaluation. No + need to call this method if evaluation happens just after a lookahead + step. + """ + for group in self.optimizer.param_groups: + for p in group["params"]: + param_state = self.state[p] + param_state["backup_params"] = torch.zeros_like(p.data) + param_state["backup_params"].copy_(p.data) + p.data.copy_(param_state["slow_params"]) + + def restore_fast_weights(self): + r""" + Restore fast weights for optimization. Call this after evaluation if + :meth:`load_slow_weights` was called. + """ + for group in self.optimizer.param_groups: + for p in group["params"]: + param_state = self.state[p] + p.data.copy_(param_state["backup_params"]) + del param_state["backup_params"] diff --git a/virtex/virtex/optim/lr_scheduler.py b/virtex/virtex/optim/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..23ce61f0ba686cbb9e963a162204262afc28c49e --- /dev/null +++ b/virtex/virtex/optim/lr_scheduler.py @@ -0,0 +1,202 @@ +import bisect +import math +from typing import List + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + + +class LinearWarmupNoDecayLR(LambdaLR): + r""" + A learning rate scheduler which linearly increases learning rate from 0 + LR, and further keeps it constant throughout training. + + Parameters + ---------- + optimizer: torch.optim.Optimizer + Wrapper optimizer. + total_steps: int + Total epochs (or iterations) for training. + warmup_steps: int + Number of first few steps to do linear warmup. + last_epoch: int, optional (default = -1) + The index of last step (epoch or iteration). We named it ``last_epoch`` + instead of ``last_step`` to keep the naming consistent with other LR + schedulers in PyTorch. + """ + + def __init__( + self, + optimizer: Optimizer, + total_steps: int, + warmup_steps: int, + last_epoch: int = -1, + ): + assert ( + warmup_steps < total_steps + ), "Warmup steps should be less than total steps." + + self.tsteps = total_steps + self.wsteps = warmup_steps + super().__init__(optimizer, self._lr_multiplier, last_epoch) + + def _lr_multiplier(self, step: int) -> float: + multiplier = step / float(max(1, self.wsteps)) if step < self.wsteps else 1 + return max(0, multiplier) + + +class LinearWarmupMultiStepLR(LambdaLR): + r""" + A learning rate scheduler which linearly increases learning rate from 0 + LR, and further decreases it by gamma once the number of steps reaches one + of the milestones. + + Parameters + ---------- + optimizer: torch.optim.Optimizer + Wrapper optimizer. + total_steps: int + Total epochs (or iterations) for training. + warmup_steps: int + Number of first few steps to do linear warmup. + milestones: List[int] + List of step indices (epochs or iterations depending on context). Must + be increasing. + gamma: float, optional (default = 0.1) + Multiplicative factor of learning rate decay. + last_epoch: int, optional (default = -1) + The index of last step (epoch or iteration). We named it ``last_epoch`` + instead of ``last_step`` to keep the naming consistent with other LR + schedulers in PyTorch. + """ + + def __init__( + self, + optimizer: Optimizer, + total_steps: int, + warmup_steps: int, + milestones: List[int], + gamma: float = 0.1, + last_epoch: int = -1, + ): + self.wsteps = warmup_steps + self.milestones = milestones + self.gamma = gamma + + # Keep a track of number of milestones encountered. + self.milestones_so_far = 0 + + # Common sanity checks. + assert milestones == sorted(milestones), "milestones must be increasing" + assert milestones[0] > warmup_steps, "first milestone must be after warmup" + assert ( + milestones[-1] < total_steps + ), "last milestone must be less than total steps" + + super().__init__(optimizer, self._lr_multiplier, last_epoch) + + def _lr_multiplier(self, step: int) -> float: + if step < self.wsteps: + # Linear warmup. + multiplier = step / float(max(1, self.wsteps)) + else: + # Step decay based on milestones. + multiplier = self.gamma ** bisect.bisect_right(self.milestones, step) + + # Avoid negative learning rate. + return max(0, multiplier) + + +class LinearWarmupLinearDecayLR(LambdaLR): + r""" + A learning rate scheduler which linearly increases learning rate from 0 + LR, and further decreases it linearly to zero. + + Parameters + ---------- + optimizer: torch.optim.Optimizer + Wrapper optimizer. + total_steps: int + Total epochs (or iterations) for training. + warmup_steps: int + Number of first few steps to do linear warmup. + last_epoch: int, optional (default = -1) + The index of last step (epoch or iteration). We named it ``last_epoch`` + instead of ``last_step`` to keep the naming consistent with other LR + schedulers in PyTorch. + """ + + def __init__( + self, + optimizer: Optimizer, + total_steps: int, + warmup_steps: int, + last_epoch: int = -1, + ): + assert ( + warmup_steps < total_steps + ), "Warmup steps should be less than total steps." + + self.tsteps = total_steps + self.wsteps = warmup_steps + super().__init__(optimizer, self._lr_multiplier, last_epoch) + + def _lr_multiplier(self, step: int) -> float: + if step < self.wsteps: + # Linear warmup. + multiplier = step / float(max(1, self.wsteps)) + else: + # Linear decay. + multiplier = (self.tsteps - step) / (self.tsteps - self.wsteps) + # Avoid negative learning rate. + return max(0, multiplier) + + +class LinearWarmupCosineAnnealingLR(LambdaLR): + r""" + A learning rate scheduler which linearly increases learning rate from 0 + LR, and further decreases it to zero by cosine decay. After linear warmup, + the LR decays as: + + .. math:: + \eta_t = \eta_{max}\cos^2(\frac{T_{cur} - T_{warm}}{T_{max} - T_{warm}}\frac{\pi}{2}) + + Parameters + ---------- + optimizer: torch.optim.Optimizer + Wrapper optimizer. + total_steps: int + Total epochs (or iterations) for training. + warmup_steps: int + Number of first few steps to do linear warmup. + last_epoch: int, optional (default = -1) + The index of last step (epoch or iteration). We named it ``last_epoch`` + instead of ``last_step`` to keep the naming consistent with other LR + schedulers in PyTorch. + """ + + def __init__( + self, + optimizer: Optimizer, + total_steps: int, + warmup_steps: int, + last_epoch: int = -1, + ): + assert ( + warmup_steps < total_steps + ), "Warmup steps should be less than total steps." + + self.tsteps = total_steps + self.wsteps = warmup_steps + super().__init__(optimizer, self._lr_multiplier, last_epoch) + + def _lr_multiplier(self, step: int) -> float: + if step < self.wsteps: + # Linear warmup. + multiplier = step / float(max(1, self.wsteps)) + else: + # Cosine annealing decay. + cos_factor = (step - self.wsteps) / (self.tsteps - self.wsteps) + multiplier = math.cos(cos_factor * (math.pi / 2)) ** 2 + # Avoid negative learning rate. + return max(0, multiplier) diff --git a/virtex/virtex/utils/assets/download_spice.sh b/virtex/virtex/utils/assets/download_spice.sh new file mode 100644 index 0000000000000000000000000000000000000000..ebbc35deb6b7695e5c430e0d82e91b8327f2ede7 --- /dev/null +++ b/virtex/virtex/utils/assets/download_spice.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env sh +# This script downloads the Stanford CoreNLP models. + +CORENLP=stanford-corenlp-full-2015-12-09 +SPICELIB=SPICE-1.0/lib + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd $DIR + +echo "Downloading..." + +wget https://panderson.me/images/SPICE-1.0.zip +wget http://nlp.stanford.edu/software/$CORENLP.zip +wget http://nlp.stanford.edu/software/stanford-corenlp-full-2014-08-27.zip + +echo "Unzipping..." + +unzip SPICE-1.0.zip + +unzip $CORENLP.zip -d $SPICELIB/ +mv $SPICELIB/$CORENLP/stanford-corenlp-3.6.0.jar $SPICELIB/ +mv $SPICELIB/$CORENLP/stanford-corenlp-3.6.0-models.jar $SPICELIB/ +rm -f stanford-corenlp-full-2015-12-09.zip +rm -rf $SPICELIB/$CORENLP/ + +rm -rf SPICE-1.0.zip diff --git a/virtex/virtex/utils/beam_search.py b/virtex/virtex/utils/beam_search.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa4226d072819ab4ccd8c8fc35e2410b6493cdd --- /dev/null +++ b/virtex/virtex/utils/beam_search.py @@ -0,0 +1,266 @@ +r""" +This Beam Search implementation is adapted with minor modifications from +`AllenNLP `_. + +Thanks to the developers of AllenNLP! +""" +from typing import Callable, List, Tuple +import warnings + +import torch +from torch.nn import functional as F + +class AutoRegressiveBeamSearch(object): + r""" + Implements the beam search algorithm for decoding the most likely captions. + This only works for auto-regressive models (Transformer-like) and not + recurrent models (LSTM-like). + + Parameters + ---------- + eos_index: int + The index of the end token (``[EOS]``) in vocabulary. + max_steps: int, optional (default = 50) + The maximum number of decoding steps. + beam_size: int, optional (default = 5) + The width of the beam used. + per_node_beam_size: int, optional (default = 2) + The maximum number of candidates to consider per node, at each step in + the search. Setting this parameter to a number smaller than `beam_size` + may give better results, as it can introduce more diversity into the + search. See `Beam Search Strategies for Neural Machine Translation. + Freitag and Al-Onaizan, 2017 `_. + """ + + def __init__( + self, + eos_index: int, + max_steps: int = 50, + beam_size: int = 5, + per_node_beam_size: int = 2, + ): + self._eos_index = eos_index + self.max_steps = max_steps + self.beam_size = beam_size + self.per_node_beam_size = per_node_beam_size or beam_size + + def search( + self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Given a starting state and a step function, apply beam search to find + the most likely target captions. + + Parameters + ---------- + start_predictions : torch.Tensor + Tensor containing the initial predictions, shape ``(batch_size, )``. + Usually the initial predictions are just the index of the start + token (``[SOS]``) in the vocabulary. + step : Callable[..., torch.Tensor] + A function that is responsible for computing the next most likely + tokens, given the past predictions. Predictions from all previous + timesteps are required, not just the last timestep, because our + model is auto-regressive instead of recurrent. The function should + The function is expected to return a tensor of shape + ``(group_size, target_vocab_size)`` containing + the logits of the tokens for the next step. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Tuple of ``(predictions, logprobs)``, where ``predictions`` + has shape ``(batch_size, beam_size, max_steps)`` and ``logprobs`` + has shape ``(batch_size, beam_size)``. + """ + batch_size = start_predictions.size()[0] + + # List of `(batch_size, beam_size)` tensors. One for each time step. + # Does not include the start symbols, which are implicit. + predictions: List[torch.Tensor] = [] + + # List of (batch_size, beam_size) tensors. One for each time step. None + # for the first. Stores the index n for the parent prediction, i.e. + # predictions[t-1][i][n], that it came from. + backpointers: List[torch.Tensor] = [] + + # Calculate the first timestep. This is done outside the main loop + # because we are going from a single decoder input (the output from the + # encoder) to the top `beam_size` decoder outputs. On the other hand, + # within the main loop we are going from the `beam_size` elements of the + # beam to `beam_size`^2 candidates from which we will select the top + # `beam_size` elements for the next iteration. + # shape: (batch_size, num_classes) + start_class_logits = step(start_predictions) + + # Convert logits to logprobs. + # shape: (batch_size * beam_size, vocab_size) + start_class_logprobs = F.log_softmax(start_class_logits, dim=1) + + num_classes = start_class_logprobs.size()[1] + + # Make sure `per_node_beam_size` is not larger than `num_classes`. + if self.per_node_beam_size > num_classes: + raise ValueError( + f"Target vocab size ({num_classes:d}) too small " + f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n" + f"Please decrease beam_size or per_node_beam_size." + ) + + # shape: (batch_size, beam_size), (batch_size, beam_size) + start_top_logprobs, start_predicted_classes = start_class_logprobs.topk( + self.beam_size + ) + if ( + self.beam_size == 1 + and (start_predicted_classes == self._eos_index).all() + ): + warnings.warn( + "Empty captions predicted. You may want to increase beam " + "size or ensure your step function is working properly.", + RuntimeWarning, + ) + return start_predicted_classes.unsqueeze(-1), start_top_logprobs + + # The log probs for the last time step. + # shape: (batch_size, beam_size) + last_logprobs = start_top_logprobs + + # shape: [(batch_size, beam_size)] + predictions.append(start_predicted_classes) + + # Log probability tensor that mandates that the end token is selected. + # shape: (batch_size * beam_size, num_classes) + logprobs_after_end = start_class_logprobs.new_full( + (batch_size * self.beam_size, num_classes), float("-inf") + ) + logprobs_after_end[:, self._eos_index] = 0.0 + + for timestep in range(self.max_steps - 1): + # shape: (batch_size * beam_size,) + last_predictions = predictions[-1].reshape(batch_size * self.beam_size) + + # If every predicted token from the last step is `self._eos_index`, + # then we can stop early. + if (last_predictions == self._eos_index).all(): + break + + # Take a step. This get the predicted log probs of the next classes. + predictions_so_far = torch.stack(predictions).permute(1, 2, 0).view( + batch_size * self.beam_size, -1 + ) + # shape: (batch_size * beam_size, num_classes) + class_logits = step(predictions_so_far) + + # Convert logits to logprobs. + # shape: (batch_size * beam_size, vocab_size) + class_logprobs = F.log_softmax(class_logits, dim=1) + + # Set logprobs of last predicted tokens as high negative value to avoid + # repetition in caption. + for index in range(batch_size * self.beam_size): + class_logprobs[index, predictions_so_far[index, -1]] = -10000 + + # shape: (batch_size * beam_size, num_classes) + last_predictions_expanded = last_predictions.unsqueeze(-1).expand( + batch_size * self.beam_size, num_classes + ) + # Here we are finding any beams where we predicted the end token in + # the previous timestep and replacing the distribution with a + # one-hot distribution, forcing the beam to predict the end token + # this timestep as well. + # shape: (batch_size * beam_size, num_classes) + cleaned_logprobs = torch.where( + last_predictions_expanded == self._eos_index, + logprobs_after_end, + class_logprobs, + ) + # shape (both): (batch_size * beam_size, per_node_beam_size) + top_logprobs, predicted_classes = cleaned_logprobs.topk( + self.per_node_beam_size + ) + # Here we expand the last log probs to `(batch_size * beam_size, + # per_node_beam_size)` so that we can add them to the current log + # probs for this timestep. This lets us maintain the log + # probability of each element on the beam. + # shape: (batch_size * beam_size, per_node_beam_size) + expanded_last_logprobs = ( + last_logprobs.unsqueeze(2) + .expand(batch_size, self.beam_size, self.per_node_beam_size) + .reshape(batch_size * self.beam_size, self.per_node_beam_size) + ) + # shape: (batch_size * beam_size, per_node_beam_size) + summed_top_logprobs = top_logprobs + expanded_last_logprobs + + # shape: (batch_size, beam_size * per_node_beam_size) + reshaped_summed = summed_top_logprobs.reshape( + batch_size, self.beam_size * self.per_node_beam_size + ) + # shape: (batch_size, beam_size * per_node_beam_size) + reshaped_predicted_classes = predicted_classes.reshape( + batch_size, self.beam_size * self.per_node_beam_size + ) + # Keep only the top `beam_size` beam indices. + # shape: (batch_size, beam_size), (batch_size, beam_size) + restricted_beam_logprobs, restricted_beam_indices = reshaped_summed.topk( + self.beam_size + ) + # Use the beam indices to extract the corresponding classes. + # shape: (batch_size, beam_size) + restricted_predicted_classes = reshaped_predicted_classes.gather( + 1, restricted_beam_indices + ) + predictions.append(restricted_predicted_classes) + + # shape: (batch_size, beam_size) + last_logprobs = restricted_beam_logprobs + + # The beam indices come from a `beam_size * per_node_beam_size` + # dimension where the indices with a common ancestor are grouped + # together. Hence dividing by `per_node_beam_size` gives the + # ancestor. (Note that this is integer division as the tensor is a + # LongTensor.) + # shape: (batch_size, beam_size) + backpointer = restricted_beam_indices // self.per_node_beam_size + + backpointers.append(backpointer) + + if not torch.isfinite(last_logprobs).all(): + warnings.warn( + "Infinite log probs encountered. Some final captions may not " + "make sense. This can happen when the beam size is larger than" + " the number of valid (non-zero probability) transitions that " + "the step function produces.", + RuntimeWarning, + ) + + # Reconstruct the captions. + # shape: [(batch_size, beam_size, 1)] + reconstructed_predictions = [predictions[-1].unsqueeze(2)] + + # shape: (batch_size, beam_size) + cur_backpointers = backpointers[-1] + + for timestep in range(len(predictions) - 2, 0, -1): + # shape: (batch_size, beam_size, 1) + cur_preds = ( + predictions[timestep].gather(1, cur_backpointers).unsqueeze(2) + ) + reconstructed_predictions.append(cur_preds) + + # shape: (batch_size, beam_size) + cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers) + + # shape: (batch_size, beam_size, 1) + final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2) + + reconstructed_predictions.append(final_preds) + + # shape: (batch_size, beam_size, max_steps) + all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2) + + # Select the top-beam and its logprobs. + all_predictions = all_predictions[:, 0, :] + last_logprobs = last_logprobs[:, 0] + + return all_predictions, last_logprobs diff --git a/virtex/virtex/utils/checkpointing.py b/virtex/virtex/utils/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..dd696d5a92f1a174adac35fae5fa7de059765941 --- /dev/null +++ b/virtex/virtex/utils/checkpointing.py @@ -0,0 +1,184 @@ +import copy +import pathlib +from typing import Any, Dict, List, Optional + +from loguru import logger +import torch +from torch import nn + +import virtex.utils.distributed as dist + + +class CheckpointManager(object): + r""" + A helper class to periodically serialize models and other checkpointable + objects (optimizers, LR schedulers etc., which implement ``state_dict`` + method) during training, and optionally record best performing checkpoint + based on an observed metric. + + .. note:: + + For :class:`~torch.nn.parallel.DistributedDataParallel` objects, + ``state_dict`` of internal model is serialized. + + .. note:: + + The observed metric for keeping best checkpoint is assumed "higher is + better", flip the sign if otherwise. + + Parameters + ---------- + serialization_dir: str + Path to a directory to save checkpoints. + keep_recent: int, optional (default = 100) + Number of recent ``k`` checkpoints to keep on disk. Older checkpoints + will be removed. Set to a very large value for keeping all checkpoints. + checkpointables: Any + Keyword arguments with any checkpointable objects, for example: model, + optimizer, learning rate scheduler. + + Examples + -------- + >>> model = torch.nn.Linear(10, 2) + >>> optimizer = torch.optim.Adam(model.parameters()) + >>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer) + >>> num_epochs = 20 + >>> for epoch in range(num_epochs): + ... train(model) + ... val_loss = validate(model) + ... ckpt_manager.step(- val_loss, epoch) + """ + + def __init__( + self, + serialization_dir: str = "/tmp", + keep_recent: int = 200, + **checkpointables: Any, + ): + self.serialization_dir = pathlib.Path(serialization_dir) + self.keep_recent = keep_recent + + # Shallow copy, keeps references to tensors as original objects. + self.checkpointables = copy.copy(checkpointables) + + # Initialize members to hold state dict of best checkpoint and its + # performance. + self._best_metric: float = -1e-12 + self._best_ckpt: Dict[str, Any] = {} + + # Keep epoch/iteration numbers of recently saved 'k' checkpoints. + self._recent_iterations: List[int] = [] + + def step(self, iteration: int, metric: Optional[float] = None): + r""" + Serialize checkpoint and update best checkpoint based on metric. Keys + in serialized checkpoint match those in :attr:`checkpointables`. + + Parameters + ---------- + iteration: int + Current training iteration. Will be saved with other checkpointables. + metric: float, optional (default = None) + Observed metric (higher is better) for keeping track of best + checkpoint. If this is ``None``, best chckpoint will not be + recorded/updated. + """ + + checkpointable_state_dict: Dict[str, Any] = self._state_dict() + + # We also checkpoint current iteration. + checkpointable_state_dict["iteration"] = iteration + + # Update the best checkpoint based on metric, if provided. + if metric is not None and metric > self._best_metric: + self._best_metric = metric + self._best_ckpt = copy.copy(checkpointable_state_dict) + + # Serialize checkpoint corresponding to current iteration. + torch.save( + checkpointable_state_dict, + self.serialization_dir / f"checkpoint_{iteration}.pth", + ) + if self._best_metric != -1e-12: + # Serialize best performing checkpoint observed so far. + torch.save( + self._best_ckpt, self.serialization_dir / "checkpoint_best.pth" + ) + + # Remove earliest checkpoint if there are more on disk. + self._recent_iterations.append(iteration) + if len(self._recent_iterations) > self.keep_recent: + self.remove_earliest_checkpoint() + + def _state_dict(self): + r"""Return a dict containing state dict of all checkpointables.""" + + __state_dict: Dict[str, Any] = {} + for key in self.checkpointables: + if isinstance( + self.checkpointables[key], nn.parallel.DistributedDataParallel + ): + __state_dict[key] = self.checkpointables[key].module.state_dict() + else: + __state_dict[key] = self.checkpointables[key].state_dict() + + return __state_dict + + def remove_earliest_checkpoint(self): + r"""Remove earliest serialized checkpoint from disk.""" + + earliest_iteration = self._recent_iterations.pop(0) + (self.serialization_dir / f"checkpoint_{earliest_iteration}.pth").unlink() + + def load(self, checkpoint_path: str): + r""" + Load a serialized checkpoint from a path. This method will try to find + each of :attr:`checkpointables` in the file and load its state dict. + Since our checkpointables are held as references, this method does not + return them. + + Parameters + ---------- + checkpoint_path: str + Path to a checkpoint serialized by :meth:`step`. + + Returns + ------- + int + Iteration corresponding to the loaded checkpoint. Useful for + resuming training. This will be -1 in case of best checkpoint, + or if info does not exist. + """ + + # Each process will log a message after loading checkpoint. + rank = dist.get_rank() + + logger.info(f"Rank {rank}: Loading checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location="cpu") + iteration = checkpoint.pop("iteration", -1) + + # Keep flags of all checkpointables to lo which ones were not loaded. + is_loaded = {key: False for key in self.checkpointables} + + # Load each checkpointable from checkpoint. + for key in checkpoint: + if key in self.checkpointables: + logger.info(f"Rank {rank}: Loading {key} from {checkpoint_path}") + + if isinstance( + self.checkpointables[key], nn.parallel.DistributedDataParallel + ): + self.checkpointables[key].module.load_state_dict(checkpoint[key]) + else: + self.checkpointables[key].load_state_dict(checkpoint[key]) + + is_loaded[key] = True + else: + logger.info(f"Rank {rank}: {key} not found in `checkpointables`.") + + not_loaded: List[str] = [key for key in is_loaded if not is_loaded[key]] + if len(not_loaded) > 0: + logger.info( + f"Rank {rank}: Checkpointables not found in file: {not_loaded}" + ) + return iteration diff --git a/virtex/virtex/utils/common.py b/virtex/virtex/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..987e2158b985a7c8c35598a32309a12269996bdb --- /dev/null +++ b/virtex/virtex/utils/common.py @@ -0,0 +1,162 @@ +import argparse +import os +import random +import sys + +from loguru import logger +import numpy as np +import torch + +from virtex.config import Config +import virtex.utils.distributed as dist + + +def cycle(dataloader, device, start_iteration: int = 0): + r""" + A generator to yield batches of data from dataloader infinitely. + + Internally, it sets the ``epoch`` for dataloader sampler to shuffle the + examples. One may optionally provide the starting iteration to make sure + the shuffling seed is different and continues naturally. + """ + iteration = start_iteration + + while True: + if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler): + # Set the `epoch` of DistributedSampler as current iteration. This + # is a way of determinisitic shuffling after every epoch, so it is + # just a seed and need not necessarily be the "epoch". + logger.info(f"Beginning new epoch, setting shuffle seed {iteration}") + dataloader.sampler.set_epoch(iteration) + + for batch in dataloader: + for key in batch: + batch[key] = batch[key].to(device) + yield batch + iteration += 1 + + +def common_setup(_C: Config, _A: argparse.Namespace, job_type: str = "pretrain"): + r""" + Setup common stuff at the start of every pretraining or downstream + evaluation job, all listed here to avoid code duplication. Basic steps: + + 1. Fix random seeds and other PyTorch flags. + 2. Set up a serialization directory and loggers. + 3. Log important stuff such as config, process info (useful during + distributed training). + 4. Save a copy of config to serialization directory. + + .. note:: + + It is assumed that multiple processes for distributed training have + already been launched from outside. Functions from + :mod:`virtex.utils.distributed` module ae used to get process info. + + Parameters + ---------- + _C: virtex.config.Config + Config object with all the parameters. + _A: argparse.Namespace + Command line arguments. + job_type: str, optional (default = "pretrain") + Type of job for which setup is to be done; one of ``{"pretrain", + "downstream"}``. + """ + + # Get process rank and world size (assuming distributed is initialized). + RANK = dist.get_rank() + WORLD_SIZE = dist.get_world_size() + + # For reproducibility - refer https://pytorch.org/docs/stable/notes/randomness.html + torch.manual_seed(_C.RANDOM_SEED) + torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC + torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK + random.seed(_C.RANDOM_SEED) + np.random.seed(_C.RANDOM_SEED) + + # Create serialization directory and save config in it. + os.makedirs(_A.serialization_dir, exist_ok=True) + _C.dump(os.path.join(_A.serialization_dir, f"{job_type}_config.yaml")) + + # Remove default logger, create a logger for each process which writes to a + # separate log-file. This makes changes in global scope. + logger.remove(0) + if dist.get_world_size() > 1: + logger.add( + os.path.join(_A.serialization_dir, f"log-rank{RANK}.txt"), + format="{time} {level} {message}", + ) + + # Add a logger for stdout only for the master process. + if dist.is_master_process(): + logger.add( + sys.stdout, format="{time}: {message}", colorize=True + ) + + # Print process info, config and args. + logger.info(f"Rank of current process: {RANK}. World size: {WORLD_SIZE}") + logger.info(str(_C)) + + logger.info("Command line args:") + for arg in vars(_A): + logger.info("{:<20}: {}".format(arg, getattr(_A, arg))) + + +def common_parser(description: str = "") -> argparse.ArgumentParser: + r""" + Create an argument parser some common arguments useful for any pretraining + or downstream evaluation scripts. + + Parameters + ---------- + description: str, optional (default = "") + Description to be used with the argument parser. + + Returns + ------- + argparse.ArgumentParser + A parser object with added arguments. + """ + parser = argparse.ArgumentParser(description=description) + + # fmt: off + parser.add_argument( + "--config", metavar="FILE", help="Path to a pretraining config file." + ) + parser.add_argument( + "--config-override", nargs="*", default=[], + help="A list of key-value pairs to modify pretraining config params.", + ) + parser.add_argument( + "--serialization-dir", default="/tmp/virtex", + help="Path to a directory to serialize checkpoints and save job logs." + ) + + group = parser.add_argument_group("Compute resource management arguments.") + group.add_argument( + "--cpu-workers", type=int, default=0, + help="Number of CPU workers per GPU to use for data loading.", + ) + group.add_argument( + "--num-machines", type=int, default=1, + help="Number of machines used in distributed training." + ) + group.add_argument( + "--num-gpus-per-machine", type=int, default=0, + help="""Number of GPUs per machine with IDs as (0, 1, 2 ...). Set as + zero for single-process CPU training.""", + ) + group.add_argument( + "--machine-rank", type=int, default=0, + help="""Rank of the machine, integer in [0, num_machines). Default 0 + for training with a single machine.""", + ) + group.add_argument( + "--dist-url", default=f"tcp://127.0.0.1:23456", + help="""URL of the master process in distributed training, it defaults + to localhost for single-machine training.""", + ) + # fmt: on + + return parser diff --git a/virtex/virtex/utils/distributed.py b/virtex/virtex/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..50f2fd36f9ced800bf6d10f08e367c0567337c76 --- /dev/null +++ b/virtex/virtex/utils/distributed.py @@ -0,0 +1,179 @@ +r""" +A collection of common utilities for distributed training. These are a bunch of +wrappers over utilities from :mod:`torch.distributed` module, but they do not +raise exceptions in absence of distributed training / CPU-only training, and +fall back to sensible default behavior. +""" +from typing import Callable, Dict, Tuple, Union + +from loguru import logger +import torch +from torch import distributed as dist +from torch import multiprocessing as mp + + +def launch( + job_fn: Callable, + num_machines: int = 1, + num_gpus_per_machine: int = 1, + machine_rank: int = 0, + dist_url: str = "tcp://127.0.0.1:23456", + args=(), +): + r""" + Launch a job in a distributed fashion: given ``num_machines`` machines, + each with ``num_gpus_per_machine`` GPUs, this utility will launch one + process per GPU. This wrapper uses :func:`torch.multiprocessing.spawn`. + + The user has to launch one job on each machine, manually specifying a + machine rank (incrementing integers from 0), this utility will adjust + process ranks per machine. One process on ``machine_rank = 0`` will be + refered as the *master process*, and the IP + a free port on this machine + will serve as the distributed process communication URL. + + Default arguments imply one machine with one GPU, and communication URL + as ``localhost``. + + .. note:: + + This utility assumes same number of GPUs per machine with IDs as + ``(0, 1, 2 ...)``. If you do not wish to use all GPUs on a machine, + set ``CUDA_VISIBLE_DEVICES`` environment variable (for example, + ``CUDA_VISIBLE_DEVICES=5,6``, which restricts to GPU 5 and 6 and + re-assigns their IDs to 0 and 1 in this job scope). + + Parameters + ---------- + job_fn: Callable + A callable object to launch. Pass your main function doing training, + validation etc. here. + num_machines: int, optional (default = 1) + Number of machines used, each with ``num_gpus_per_machine`` GPUs. + num_gpus_per_machine: int, optional (default = 1) + Number of GPUs per machine, with IDs as ``(0, 1, 2 ...)``. + machine_rank: int, optional (default = 0) + A manually specified rank of the machine, serves as a unique identifier + and useful for assigning global ranks to processes. + dist_url: str, optional (default = "tcp://127.0.0.1:23456") + Disributed process communication URL as ``tcp://x.x.x.x:port``. Set + this as the IP (and a free port) of machine with rank 0. + args: Tuple + Arguments to be passed to ``job_fn``. + """ + + assert ( + torch.cuda.is_available() + ), "CUDA not available, Cannot launch distributed processes." + + world_size = num_machines * num_gpus_per_machine + + # Spawn ``num_gpus_per_machine``` processes per machine, and provide + # "local process rank" (GPU ID) as the first arg to ``_dist_worker``. + # fmt: off + if world_size > 1: + mp.spawn( + _job_worker, + nprocs=num_gpus_per_machine, + args=( + job_fn, world_size, num_gpus_per_machine, machine_rank, dist_url, args + ), + daemon=False, + ) + else: + # Default to single machine, single GPU, with ID 0. + _job_worker(0, job_fn, 1, 1, 0, dist_url, args) + # fmt: on + + +def _job_worker( + local_rank: int, + job_fn: Callable, + world_size: int, + num_gpus_per_machine: int, + machine_rank: int, + dist_url: str, + args: Tuple, +): + r""" + Single distibuted process worker. This should never be used directly, + only used by :func:`launch`. + """ + + # Adjust global rank of process based on its machine rank. + global_rank = machine_rank * num_gpus_per_machine + local_rank + try: + dist.init_process_group( + backend="NCCL", + init_method=dist_url, + world_size=world_size, + rank=global_rank, + ) + except Exception as e: + logger.error(f"Error launching processes, dist URL: {dist_url}") + raise e + + synchronize() + # Set GPU ID for each process according to its rank. + torch.cuda.set_device(local_rank) + job_fn(*args) + + +def synchronize() -> None: + r"""Synchronize (barrier) all processes in a process group.""" + if dist.is_initialized(): + dist.barrier() + + +def get_world_size() -> int: + r"""Return number of processes in the process group, each uses 1 GPU.""" + return dist.get_world_size() if dist.is_initialized() else 1 + + +def get_rank() -> int: + r"""Return rank of current process in the process group.""" + return dist.get_rank() if dist.is_initialized() else 0 + + +def is_master_process() -> bool: + r""" + Check whether current process is the master process. This check is useful + to restrict logging and checkpointing to master process. It will always + return ``True`` for single machine, single GPU execution. + """ + return get_rank() == 0 + + +def average_across_processes(t: Union[torch.Tensor, Dict[str, torch.Tensor]]): + r""" + Averages a tensor, or a dict of tensors across all processes in a process + group. Objects in all processes will finally have same mean value. + + .. note:: + + Nested dicts of tensors are not supported. + + Parameters + ---------- + t: torch.Tensor or Dict[str, torch.Tensor] + A tensor or dict of tensors to average across processes. + """ + if dist.is_initialized(): + if isinstance(t, torch.Tensor): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + t /= get_world_size() + elif isinstance(t, dict): + for k in t: + dist.all_reduce(t[k], op=dist.ReduceOp.SUM) + t[k] /= dist.get_world_size() + + +def gpu_mem_usage() -> int: + r""" + Return gpu memory usage (in megabytes). If not using GPU, return 0 without + raising any exceptions. + """ + if torch.cuda.is_available(): + # This will be in bytes, so we divide by (1024 * 1024). + return torch.cuda.max_memory_allocated() // 1048576 + else: + return 0 diff --git a/virtex/virtex/utils/metrics.py b/virtex/virtex/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..352208ff4d7fc735fb4976066d49011fdcbe9ea1 --- /dev/null +++ b/virtex/virtex/utils/metrics.py @@ -0,0 +1,319 @@ +r""" +This module is a collection of metrics commonly used during pretraining and +downstream evaluation. Two main classes here are: + +- :class:`TopkAccuracy` used for ImageNet linear classification evaluation. +- :class:`CocoCaptionsEvaluator` used for caption evaluation (CIDEr and SPICE). + +Parts of this module (:meth:`tokenize`, :meth:`cider` and :meth:`spice`) are +adapted from `coco-captions evaluation code `_. +""" +from collections import defaultdict +import json +import os +from subprocess import Popen, PIPE, check_call +import tempfile +from typing import Any, Dict, List + +import numpy as np +import torch + + +class TopkAccuracy(object): + r""" + An accumulator for Top-K classification accuracy. This accumulates per-batch + accuracy during training/validation, which can retrieved at the end. Assumes + integer labels and predictions. + + .. note:: + + If used in :class:`~torch.nn.parallel.DistributedDataParallel`, results + need to be aggregated across GPU processes outside this class. + + Parameters + ---------- + top_k: int, optional (default = 1) + ``k`` for computing Top-K accuracy. + """ + + def __init__(self, top_k: int = 1): + self._top_k = top_k + self.reset() + + def reset(self): + r"""Reset counters; to be used at the start of new epoch/validation.""" + self.num_total = 0.0 + self.num_correct = 0.0 + + def __call__(self, predictions: torch.Tensor, ground_truth: torch.Tensor): + r""" + Update accumulated accuracy using the current batch. + + Parameters + ---------- + ground_truth: torch.Tensor + A tensor of shape ``(batch_size, )``, an integer label per example. + predictions : torch.Tensor + Predicted logits or log-probabilities of shape + ``(batch_size, num_classes)``. + """ + + if self._top_k == 1: + top_k = predictions.max(-1)[1].unsqueeze(-1) + else: + top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1] + + correct = top_k.eq(ground_truth.unsqueeze(-1)).float() + + self.num_total += ground_truth.numel() + self.num_correct += correct.sum() + + def get_metric(self, reset: bool = False): + r"""Get accumulated accuracy so far (and optionally reset counters).""" + if self.num_total > 1e-12: + accuracy = float(self.num_correct) / float(self.num_total) + else: + accuracy = 0.0 + if reset: + self.reset() + return accuracy + + +class CocoCaptionsEvaluator(object): + r"""A helper class to evaluate caption predictions in COCO format. This uses + :meth:`cider` and :meth:`spice` which exactly follow original COCO Captions + evaluation protocol. + + Parameters + ---------- + gt_annotations_path: str + Path to ground truth annotations in COCO format (typically this would + be COCO Captions ``val2017`` split). + """ + + def __init__(self, gt_annotations_path: str): + gt_annotations = json.load(open(gt_annotations_path))["annotations"] + + # Keep a mapping from image id to a list of captions. + self.ground_truth: Dict[int, List[str]] = defaultdict(list) + for ann in gt_annotations: + self.ground_truth[ann["image_id"]].append(ann["caption"]) + + self.ground_truth = tokenize(self.ground_truth) + + def evaluate(self, preds: List[Dict[str, Any]]) -> Dict[str, float]: + r"""Compute CIDEr and SPICE scores for predictions. + + Parameters + ---------- + preds: List[Dict[str, Any]] + List of per instance predictions in COCO Captions format: + ``[ {"image_id": int, "caption": str} ...]``. + + Returns + ------- + Dict[str, float] + Computed metrics; a dict with keys ``{"CIDEr", "SPICE"}``. + """ + if isinstance(preds, str): + preds = json.load(open(preds)) + + res = {ann["image_id"]: [ann["caption"]] for ann in preds} + res = tokenize(res) + + # Remove IDs from predictions which are not in GT. + common_image_ids = self.ground_truth.keys() & res.keys() + res = {k: v for k, v in res.items() if k in common_image_ids} + + # Add dummy entries for IDs absent in preds, but present in GT. + for k in self.ground_truth: + res[k] = res.get(k, [""]) + + cider_score = cider(res, self.ground_truth) + spice_score = spice(res, self.ground_truth) + + return {"CIDEr": 100 * cider_score, "SPICE": 100 * spice_score} + + +def tokenize(image_id_to_captions: Dict[int, List[str]]) -> Dict[int, List[str]]: + r""" + Given a mapping of image id to a list of corrsponding captions, tokenize + captions in place according to Penn Treebank Tokenizer. This method assumes + the presence of Stanford CoreNLP JAR file in directory of this module. + """ + # Path to the Stanford CoreNLP JAR file. + CORENLP_JAR = ( + "assets/stanford-corenlp-full-2014-08-27/stanford-corenlp-3.4.1.jar" + ) + + # Prepare data for Tokenizer: write captions to a text file, one per line. + image_ids = [k for k, v in image_id_to_captions.items() for _ in range(len(v))] + sentences = "\n".join( + [c.replace("\n", " ") for k, v in image_id_to_captions.items() for c in v] + ) + tmp_file = tempfile.NamedTemporaryFile(delete=False) + tmp_file.write(sentences.encode()) + tmp_file.close() + + # fmt: off + # Tokenize sentences. We use the JAR file for tokenization. + command = [ + "java", "-cp", CORENLP_JAR, "edu.stanford.nlp.process.PTBTokenizer", + "-preserveLines", "-lowerCase", tmp_file.name + ] + tokenized_captions = ( + Popen(command, cwd=os.path.dirname(os.path.abspath(__file__)), stdout=PIPE) + .communicate(input=sentences.rstrip())[0] + .decode() + .split("\n") + ) + # fmt: on + os.remove(tmp_file.name) + + # Map tokenized captions back to their image IDs. + # Punctuations to be removed from the sentences (PTB style)). + # fmt: off + PUNCTS = [ + "''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", ".", "?", + "!", ",", ":", "-", "--", "...", ";", + ] + # fmt: on + image_id_to_tokenized_captions: Dict[int, List[str]] = defaultdict(list) + for image_id, caption in zip(image_ids, tokenized_captions): + image_id_to_tokenized_captions[image_id].append( + " ".join([w for w in caption.rstrip().split(" ") if w not in PUNCTS]) + ) + + return image_id_to_tokenized_captions + + +def cider( + predictions: Dict[int, List[str]], + ground_truth: Dict[int, List[str]], + n: int = 4, + sigma: float = 6.0, +) -> float: + r"""Compute CIDEr score given ground truth captions and predictions.""" + + # ------------------------------------------------------------------------- + def to_ngrams(sentence: str, n: int = 4): + r"""Convert a sentence into n-grams and their counts.""" + words = sentence.split() + counts = defaultdict(int) # type: ignore + for k in range(1, n + 1): + for i in range(len(words) - k + 1): + ngram = tuple(words[i : i + k]) + counts[ngram] += 1 + return counts + + def counts2vec(cnts, document_frequency, log_reference_length): + r"""Function maps counts of ngram to vector of tfidf weights.""" + vec = [defaultdict(float) for _ in range(n)] + length = 0 + norm = [0.0 for _ in range(n)] + for (ngram, term_freq) in cnts.items(): + df = np.log(max(1.0, document_frequency[ngram])) + # tf (term_freq) * idf (precomputed idf) for n-grams + vec[len(ngram) - 1][ngram] = float(term_freq) * ( + log_reference_length - df + ) + # Compute norm for the vector: will be used for computing similarity + norm[len(ngram) - 1] += pow(vec[len(ngram) - 1][ngram], 2) + + if len(ngram) == 2: + length += term_freq + norm = [np.sqrt(nn) for nn in norm] + return vec, norm, length + + def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): + r"""Compute the cosine similarity of two vectors.""" + delta = float(length_hyp - length_ref) + val = np.array([0.0 for _ in range(n)]) + for nn in range(n): + for (ngram, count) in vec_hyp[nn].items(): + val[nn] += ( + min(vec_hyp[nn][ngram], vec_ref[nn][ngram]) * vec_ref[nn][ngram] + ) + + val[nn] /= (norm_hyp[nn] * norm_ref[nn]) or 1 + val[nn] *= np.e ** (-(delta ** 2) / (2 * sigma ** 2)) + return val + + # ------------------------------------------------------------------------- + + ctest = [to_ngrams(predictions[image_id][0]) for image_id in ground_truth] + crefs = [ + [to_ngrams(gt) for gt in ground_truth[image_id]] for image_id in ground_truth + ] + # Build document frequency and compute IDF. + document_frequency = defaultdict(float) + for refs in crefs: + # refs, k ref captions of one image + for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]): + document_frequency[ngram] += 1 + + # Compute log reference length. + log_reference_length = np.log(float(len(crefs))) + + scores = [] + for test, refs in zip(ctest, crefs): + # Compute vector for test captions. + vec, norm, length = counts2vec( + test, document_frequency, log_reference_length + ) + # Compute vector for ref captions. + score = np.array([0.0 for _ in range(n)]) + for ref in refs: + vec_ref, norm_ref, length_ref = counts2vec( + ref, document_frequency, log_reference_length + ) + score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) + + score_avg = np.mean(score) + score_avg /= len(refs) + score_avg *= 10.0 + scores.append(score_avg) + + return np.mean(scores) + + +def spice( + predictions: Dict[int, List[str]], ground_truth: Dict[int, List[str]] +) -> float: + r"""Compute SPICE score given ground truth captions and predictions.""" + + # Prepare temporary input file for the SPICE scorer. + input_data = [ + { + "image_id": image_id, + "test": predictions[image_id][0], + "refs": ground_truth[image_id], + } + for image_id in ground_truth + ] + # Create a temporary directory and dump input file to SPICE. + temp_dir = tempfile.mkdtemp() + INPUT_PATH = os.path.join(temp_dir, "input_file.json") + OUTPUT_PATH = os.path.join(temp_dir, "output_file.json") + json.dump(input_data, open(INPUT_PATH, "w")) + + # fmt: off + # Run the command to execute SPICE jar. + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) + SPICE_JAR = f"{CURRENT_DIR}/assets/SPICE-1.0/spice-1.0.jar" + CACHE_DIR = f"{CURRENT_DIR}/assets/cache" + os.makedirs(CACHE_DIR, exist_ok=True) + spice_cmd = [ + "java", "-jar", "-Xmx8G", SPICE_JAR, INPUT_PATH, + "-cache", CACHE_DIR, "-out", OUTPUT_PATH, "-subset", "-silent", + ] + check_call(spice_cmd, cwd=CURRENT_DIR) + # fmt: on + + # Read and process results + results = json.load(open(OUTPUT_PATH)) + image_id_to_scores = {item["image_id"]: item["scores"] for item in results} + spice_scores = [ + np.array(item["scores"]["All"]["f"]).astype(float) for item in results + ] + return np.mean(spice_scores) diff --git a/virtex/virtex/utils/nucleus_sampling.py b/virtex/virtex/utils/nucleus_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..a905c228eef653d014417e7dac87b4ef62c6929f --- /dev/null +++ b/virtex/virtex/utils/nucleus_sampling.py @@ -0,0 +1,131 @@ +r""" +Nucleus Sampling was introduced in the paper +`The Curious Case of Neural Text Degeneration `_. +If you take it from here, make sure to cite them: + +.. code-block:: text + + @inproceedings{, + title={The Curious Case of Neural Text Degeneration}, + author={Ari Holtzman and Jan Buys and Li Du and Maxwell Forbes and Yejin Choi}, + journal={ICLR}, + year={2020} + } + +Some core parts of this code are adapted with minor modifications from Thomas Wolf's +gist: https://gist.githubusercontent.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 +""" + +from typing import Callable, List, Tuple + +import torch +import torch.nn.functional as F + + +class AutoRegressiveNucleusSampling(object): + """ + Implements the nucleus sampling for decoding captions. This class only works + for auto-regressive models (Transformer-like), not recurrent models (LSTM-like). + + Parameters + ---------- + eos_index: int + The index of the end token (``[EOS]``) in vocabulary. + max_steps: int, optional (default = 50) + The maximum number of decoding steps. + nucleus_size: int, optional (default = 5) + Size of top-K nucleus for sampling. + """ + + def __init__( + self, + eos_index: int, + max_steps: int = 50, + nucleus_size: float = 0.9, + ): + super().__init__() + self._eos_index = eos_index + self.max_steps = max_steps + self.nucleus_size = nucleus_size + + def search( + self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor] + ) -> Tuple[torch.Tensor, None]: + + batch_size = start_predictions.size()[0] + + # List of `(batch_size, )` tensors. One for each timestep. + # This includes the start-of-sentence tokens, unlike the implementation + # in `AutoregressiveBeamSearch`. We will remove them in the end. + + # Transpose `start_predictions` and make a list when prompt is provided. + predictions = [ + start_predictions[:, i] for i in range(start_predictions.size(1)) + ] + + for timestep in range(self.max_steps): + # Get the predictions from last timestep (most recent). + # shape: (batch_size, ) + last_predictions = predictions[-1] + + # If every predicted token from the last step is end-of-sentence token, + # then we can stop early. + if (last_predictions == self._eos_index).all(): + break + + # Combine step predictions made so far into one tensor. This is our + # "partial" caption input to the transformer. + # shape: (batch_size, timestep + 1) + predictions_so_far = torch.stack(predictions).permute(1, 0) + + # Take a step, get the distribution of logits from next timestep. + # shape: (batch_size, num_classes) + current_logits = step(predictions_so_far) + + # Sort logits in descending order to determine the nucleus. + sorted_logits, sorted_idx = torch.sort(current_logits, descending=True) + + # Get cumulative softmax probabilites. For every instance in batch, a + # variable amount of tokens (N) will consitute the nucleus. + # shape: (batch_size, num_classes) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Determine indices of tokens at the tail of distribution. These will be + # removed from the nucleus. + sorted_idx_to_remove = cumulative_probs > self.nucleus_size + + # Shift the indices to the right to keep the first token outside nucleus. + sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone() + sorted_idx_to_remove[..., 0] = 0 + + # Set logits to large negative value to avoid sampling them. Iterate over + # the batch of examples. + for t in range(current_logits.size()[0]): + idx_to_remove = sorted_idx[t][sorted_idx_to_remove[t]] + current_logits[t][idx_to_remove] = -1e12 + + # Set logits for last predicted token to a large negative value to + # avoid repetition. + current_logits[t][last_predictions[t]] = -1e12 + + # Sample from the filtered distribution. + # shape: (batch_size, num_classes) + current_probs = F.softmax(current_logits, dim=-1) + + # shape: (batch_size, ) + current_predictions = torch.multinomial(current_probs, 1) + current_predictions = current_predictions.view(batch_size) + + # Set current predicted tokens to be end-of-sentence for instances where + # last prediction was also end-of-sentence token. + current_predictions[last_predictions == self._eos_index] = self._eos_index + + predictions.append(current_predictions) + + # Remove start-of-sentence token from predictions, and collect them together. + # shape: (batch_size, max_steps) .. or could be less than max_steps. + all_predictions = torch.stack(predictions[1:]).permute(1, 0) + + # We don't return any logprobs of generated sequence with nucleus sampling, + # unlike `AutoregressiveBeamSearch`. + return all_predictions, None diff --git a/virtex/virtex/utils/timer.py b/virtex/virtex/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..02f7992b6e382e57c5ce3f85f4d08df53e3ebe1a --- /dev/null +++ b/virtex/virtex/utils/timer.py @@ -0,0 +1,69 @@ +import time +from typing import Optional + + +class Timer(object): + r""" + A simple timer to record time per iteration and ETA of training. ETA is + estimated by moving window average with fixed window size. + + Parameters + ---------- + start_from: int, optional (default = 1) + Iteration from which counting should be started/resumed. + total_iterations: int, optional (default = None) + Total number of iterations. ETA will not be tracked (will remain "N/A") + if this is not provided. + window_size: int, optional (default = 20) + Window size for calculating ETA based on average of past few iterations. + """ + + def __init__( + self, + start_from: int = 1, + total_iterations: Optional[int] = None, + window_size: int = 20, + ): + # We decrement by 1 because `current_iter` changes increment during + # an iteration (for example, will change from 0 -> 1 on iteration 1). + self.current_iter = start_from - 1 + self.total_iters = total_iterations + + self._start_time = time.time() + self._times = [0.0] * window_size + + def tic(self) -> None: + r"""Start recording time: call at the beginning of iteration.""" + self._start_time = time.time() + + def toc(self) -> None: + r"""Stop recording time: call at the end of iteration.""" + self._times.append(time.time() - self._start_time) + self._times = self._times[1:] + self.current_iter += 1 + + @property + def stats(self) -> str: + r"""Return a single string with current iteration, time and ETA.""" + return ( + f"Iter {self.current_iter} | Time: {self._times[-1]:.3f} sec | " + f"ETA: {self.eta_hhmm}" + ) + + @property + def eta_hhmm(self) -> str: + r"""Return ETA in the form of ``hh mm`` string.""" + if self.total_iters: + eta_sec = int(self.eta_sec) + return f"{eta_sec // 3600}h {((eta_sec % 3600) // 60):02d}m" + else: + return "N/A" + + @property + def eta_sec(self) -> float: + r"""Return ETA in the form of seconds.""" + if self.total_iters: + avg_time = sum(self._times) / len(self._times) + return avg_time * (self.total_iters - self.current_iter) + else: + return 0.0