diff --git a/app.py b/app.py index dc042299f4e404ef8a0d146e49de3245072fc22a..e330aebf4955741cf4823806b8e34ccd9184aacf 100644 --- a/app.py +++ b/app.py @@ -1,18 +1,18 @@ -import streamlit as st import io -import sys -import time -import json -sys.path.append("./virtex/") + +import streamlit as st from model import * # # TODO: # - Reformat the model introduction # - Make the iterative text generation -def gen_show_caption(sub_prompt=None, cap_prompt = ""): + +def gen_show_caption(sub_prompt=None, cap_prompt=""): with st.spinner("Generating Caption"): - subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt) + subreddit, caption = virtexModel.predict( + image_dict, sub_prompt=sub_prompt, prompt=cap_prompt + ) st.markdown( f""" ### r/{subreddit} {cap_prompt} {caption} - """, - unsafe_allow_html=True) - -_, center, _ = st.columns([1,8,1]) + """, + unsafe_allow_html=True, + ) + + +_, center, _ = st.columns([1, 8, 1]) with center: st.title("Image Captioning Demo from RedCaps") @@ -50,7 +52,7 @@ st.sidebar.markdown( with st.spinner("Loading Model"): virtexModel, imageLoader, sample_images, valid_subs = create_objects() - + select_idx = None @@ -66,9 +68,9 @@ uploaded_image = None # with st.sidebar.form("file-uploader-form", clear_on_submit=True): uploaded_file = st.sidebar.file_uploader("Choose a file") # submitted = st.form_submit_button("Submit") -if uploaded_file is not None:# and submitted: +if uploaded_file is not None: # and submitted: uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue())) - select_idx = None # set this to help rewrite the cache + select_idx = None # set this to help rewrite the cache # class OnChange(): # def __init__(self, idx): @@ -88,21 +90,26 @@ if uploaded_file is not None:# and submitted: st.sidebar.title("Select a Subreddit") sub = st.sidebar.selectbox( "Type below to condition on a subreddit. Select None for a predicted subreddit", - valid_subs + valid_subs, ) st.sidebar.title("Write a Custom Prompt") -cap_prompt = st.sidebar.text_input( - "Write the start of your caption below", - value="" -) +cap_prompt = st.sidebar.text_input("Write the start of your caption below", value="") _ = st.sidebar.button("Regenerate Caption") st.sidebar.write("Advanced Options:") -num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1) -nuc_size = st.sidebar.slider("Nucelus Size:\nLarger values lead to more diverse captions", min_value=0.0, max_value=1.0, value=0.8, step=0.05) +num_captions = st.sidebar.select_slider( + "Number of Captions to Predict", options=[1, 2, 3, 4, 5], value=1 +) +nuc_size = st.sidebar.slider( + "Nucelus Size:\nLarger values lead to more diverse captions", + min_value=0.0, + max_value=1.0, + value=0.8, + step=0.05, +) virtexModel.model.decoder.nucleus_size = nuc_size image_file = sample_image @@ -110,14 +117,14 @@ image_file = sample_image # LOAD AND CACHE THE IMAGE if uploaded_image is not None: image = uploaded_image -elif select_idx is None and 'image' in st.session_state: - image = st.session_state['image'] +elif select_idx is None and "image" in st.session_state: + image = st.session_state["image"] else: image = Image.open(image_file) image = image.convert("RGB") -st.session_state['image'] = image +st.session_state["image"] = image image_dict = imageLoader.transform(image) @@ -141,4 +148,4 @@ This demo accompanies our paper RedCaps. 
Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson """ -) \ No newline at end of file +) diff --git a/model.py b/model.py index f071d7e96a31ac14fb8d84ad4d28da2eaab9c430..a615e1fb9a09443b2040bc3ab1a9e12c60e08f63 100644 --- a/model.py +++ b/model.py @@ -1,18 +1,17 @@ -import streamlit as st -from huggingface_hub import hf_hub_url, cached_download -from PIL import Image import os import json import glob import random -from typing import Any, Dict, List import torch import torchvision +import streamlit as st import wordsegment as ws +from PIL import Image +from huggingface_hub import hf_hub_url, cached_download from virtex.config import Config -from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory +from virtex.factories import TokenizerFactory, PretrainingModelFactory from virtex.utils.checkpointing import CheckpointManager CONFIG_PATH = "config.yaml" @@ -20,98 +19,108 @@ MODEL_PATH = "checkpoint_last5.pth" VALID_SUBREDDITS_PATH = "subreddit_list.json" SAMPLES_PATH = "./samples/*.jpg" -class ImageLoader(): + +class ImageLoader: def __init__(self): - self.image_transform = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(256), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))]) - self.show_size=500 - + self.image_transform = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(256), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + ), + ] + ) + self.show_size = 500 + def load(self, im_path): im = torch.FloatTensor(self.image_transform(Image.open(im_path))).unsqueeze(0) return {"image": im} - + def raw_load(self, im_path): im = torch.FloatTensor(Image.open(im_path)) return {"image": im} - + def transform(self, image): im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0) return {"image": im} - + def text_transform(self, text): # at present just lowercasing: return text.lower() - + def show_resize(self, image): # ugh we need to do this manually cuz this is pytorch==0.8 not 1.9 lol image = torchvision.transforms.functional.to_tensor(image) - x,y = image.shape[-2:] - ratio = float(self.show_size/max((x,y))) - image = torchvision.transforms.functional.resize(image, [int(x * ratio), int(y * ratio)]) + x, y = image.shape[-2:] + ratio = float(self.show_size / max((x, y))) + image = torchvision.transforms.functional.resize( + image, [int(x * ratio), int(y * ratio)] + ) return torchvision.transforms.functional.to_pil_image(image) - -class VirTexModel(): + +class VirTexModel: + def __init__(self): self.config = Config(CONFIG_PATH) ws.load() - self.device = 'cpu' + self.device = "cpu" self.tokenizer = TokenizerFactory.from_config(self.config) self.model = PretrainingModelFactory.from_config(self.config).to(self.device) CheckpointManager(model=self.model).load(MODEL_PATH) self.model.eval() self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH)) - - def predict(self, image_dict, sub_prompt = None, prompt = ""): + + def predict(self, image_dict, sub_prompt=None, prompt=""): if sub_prompt is None: - subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long() + subreddit_tokens = torch.tensor( + [self.model.sos_index], device=self.device + ).long() else: subreddit_tokens = " ".join(ws.segment(ws.clean(sub_prompt))) subreddit_tokens = ( - [self.model.sos_index] + - 
self.tokenizer.encode(subreddit_tokens) + - [self.tokenizer.token_to_id("[SEP]")] - ) + [self.model.sos_index] + + self.tokenizer.encode(subreddit_tokens) + + [self.tokenizer.token_to_id("[SEP]")] + ) subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long() - + if prompt is not "": # at present prompts without subreddits will break without this change # TODO FIX cap_tokens = self.tokenizer.encode(prompt) cap_tokens = torch.tensor(cap_tokens, device=self.device).long() - subreddit_tokens = subreddit_tokens if sub_prompt is not None else torch.tensor( - ( - [self.model.sos_index] + - self.tokenizer.encode("pics") + - [self.tokenizer.token_to_id("[SEP]")] - ), device = self.device).long() - - subreddit_tokens = torch.cat( - [ - subreddit_tokens, - cap_tokens - ]) - - - predictions: List[Dict[str, Any]] = [] - + subreddit_tokens = ( + subreddit_tokens + if sub_prompt is not None + else torch.tensor( + ( + [self.model.sos_index] + + self.tokenizer.encode("pics") + + [self.tokenizer.token_to_id("[SEP]")] + ), + device=self.device, + ).long() + ) + + subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens]) + is_valid_subreddit = False subreddit, rest_of_caption = "", "" image_dict["decode_prompt"] = subreddit_tokens while not is_valid_subreddit: - + with torch.no_grad(): caption = self.model(image_dict)["predictions"][0].tolist() - + if self.tokenizer.token_to_id("[SEP]") in caption: sep_index = caption.index(self.tokenizer.token_to_id("[SEP]")) caption[sep_index] = self.tokenizer.token_to_id("://") - + caption = self.tokenizer.decode(caption) - + if "://" in caption: subreddit, rest_of_caption = caption.split("://") subreddit = "".join(subreddit.split()) @@ -122,25 +131,29 @@ class VirTexModel(): # split prompt for coloring: if prompt is not "": _, rest_of_caption = caption.split(prompt.strip()) - + is_valid_subreddit = subreddit in self.valid_subs - + return subreddit, rest_of_caption + def download_files(): - #download model files + # download model files download_files = [CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH] for f in download_files: fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f)) os.system(f"cp {fp} ./{f}") + def get_samples(): return glob.glob(SAMPLES_PATH) + def get_rand_idx(samples): - return random.randint(0,len(samples)-1) + return random.randint(0, len(samples) - 1) + -@st.cache(allow_output_mutation=True) # allow mutation to update nucleus size +@st.cache(allow_output_mutation=True) # allow mutation to update nucleus size def create_objects(): sample_images = get_samples() virtexModel = VirTexModel() @@ -149,7 +162,8 @@ def create_objects(): valid_subs.insert(0, None) return virtexModel, imageLoader, sample_images, valid_subs -footer=""" -
-
-Pretraining Task Ablations
-^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-==================================================  =========  ===================  =========
-Model Config Name                                   VOC07 mAP  ImageNet Top-1 Acc.  Model URL
-==================================================  =========  ===================  =========
-task_ablations/bicaptioning_R_50_L1_H2048.yaml      88.7       53.8                 model
-task_ablations/captioning_R_50_L1_H2048.yaml        88.6       50.8                 model
-task_ablations/token_classification_R_50.yaml       88.8       48.6                 model
-task_ablations/multilabel_classification_R_50.yaml  86.2       46.2                 model
-task_ablations/masked_lm_R_50_L1_H2048.yaml         86.4       46.7                 model
-==================================================  =========  ===================  =========
-
-Width Ablations
-^^^^^^^^^^^^^^^
-
-===============================================  =========  ===================  =========
-Model Config Name                                VOC07 mAP  ImageNet Top-1 Acc.  Model URL
-===============================================  =========  ===================  =========
-width_ablations/bicaptioning_R_50_L1_H512.yaml   88.4       51.8                 model
-width_ablations/bicaptioning_R_50_L1_H768.yaml   88.3       52.3                 model
-width_ablations/bicaptioning_R_50_L1_H1024.yaml  88.3       53.2                 model
-width_ablations/bicaptioning_R_50_L1_H2048.yaml  88.7       53.8                 model
-===============================================  =========  ===================  =========
-
-Depth Ablations
-^^^^^^^^^^^^^^^
-
-===============================================  =========  ===================  =========
-Model Config Name                                VOC07 mAP  ImageNet Top-1 Acc.  Model URL
-===============================================  =========  ===================  =========
-depth_ablations/bicaptioning_R_50_L1_H1024.yaml  88.3       53.2                 model
-depth_ablations/bicaptioning_R_50_L2_H1024.yaml  88.8       53.8                 model
-depth_ablations/bicaptioning_R_50_L3_H1024.yaml  88.7       53.9                 model
-depth_ablations/bicaptioning_R_50_L4_H1024.yaml  88.7       53.9                 model
-===============================================  =========  ===================  =========
-
-Backbone Ablations
-^^^^^^^^^^^^^^^^^^
-
-=====================================================  =========  ===================  =========
-Model Config Name                                      VOC07 mAP  ImageNet Top-1 Acc.  Model URL
-=====================================================  =========  ===================  =========
-backbone_ablations/bicaptioning_R_50_L1_H1024.yaml     88.3       53.2                 model
-backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml  88.5       52.9                 model
-backbone_ablations/bicaptioning_R_101_L1_H1024.yaml    88.7       52.1                 model
-=====================================================  =========  ===================  =========
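Any "Model Config Name" above can be loaded through the model zoo module that this diff keeps (renamed to ``virtex/model_zoo/``). A minimal sketch, assuming the ``virtex.model_zoo.get`` interface documented in the VirTex README; the checkpoint URLs behind the ``model`` links are not reproduced here:

.. code-block:: python

    # Minimal sketch: load a pretrained ablation model by its config name.
    # Assumes the `virtex.model_zoo.get` API from the VirTex README; any
    # "Model Config Name" from the tables above should work.
    import torch
    import virtex.model_zoo as mz

    model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)
    model.eval()

    # The ResNet visual backbone lives at `model.visual` (assumed forward
    # signature); its features can be reused for downstream tasks, as in
    # the evaluation scripts deleted below.
    with torch.no_grad():
        features = model.visual(torch.randn(1, 3, 224, 224))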
diff --git a/virtex/docs/virtex/usage/pretrain.rst b/virtex/docs/virtex/usage/pretrain.rst deleted file mode 100644 index 2f14305f2152afdede708d45cbe5b2d165e9246a..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/usage/pretrain.rst +++ /dev/null @@ -1,100 +0,0 @@ -How to train your VirTex model? -=============================== - -We provide training scripts for all type of VirTex models from the paper; -including our best-performing model and other ablations. -Our training jobs are specified by config files (YAML). -Execute all commands from project root to use the provided config files. - - -Training the base VirTex model ------------------------------- - -Train the base VirTex model with ResNet-50 visual backbone; and a textual head -with ``L = 1, H = 1024`` using all default optimization hyperparameters. - -.. code-block:: - - python scripts/pretrain_virtex.py \ - --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \ - --num-gpus-per-machine 8 \ - --cpu-workers 4 \ - --serialization-dir /tmp/VIRTEX_R_50_L1_H1024 - # Default: --checkpoint-every 2000 --log-every 20 - -Training job will save checkpoints, tensorboard logs (loss curves and metrics), -and back up the config in ``--serialization-dir``. Use ``tensorboard --logdir -`` to view training curves, validation metrics etc. directly -on tensorboard. - -We recommend training with 8 GPUs on the same machine, although training with -multiple GPUs across machines (see: ``--num-machines`` and ``--machine-rank``), -single GPU (``--num-gpus-per-machine 1``) as well as CPU -(``--num-gpus-per-machine 0``) is also supported. Using multiple GPUs for -interactive debugging with PDB is not supported, as PDB and ``multiprocessing`` -module do not play nice. - -------------------------------------------------------------------------------- - -Reproducing all VirTex ablations --------------------------------- - -To reproduce all ablations from the `paper `_, -replace the ``--config`` argument in above command with the following (all -assumed to be relative to project root): - -Pretraining Task Ablations -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -1. **Bicaptioning:** configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml -2. **Forward Captioning:** configs/task_ablations/captioning_R_50_L1_H2048.yaml -3. **Token Classification:** configs/task_ablations/token_classification_R_50.yaml -4. **Multilabel Classification:** configs/task_ablations/multilabel_classification_R_50.yaml -5. **Masked Language Modeling:** configs/task_ablations/masked_lm_R_50_L1_H2048.yaml - -Transformer Size Ablations -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -1. **Width (H = 512):** configs/width_ablations/bicaptioning_R_50_L1_H512.yaml -2. **Width (H = 768):** configs/width_ablations/bicaptioning_R_50_L1_H768.yaml -3. **Width (H = 1024):** configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml -4. **Width (H = 2048):** configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml -5. **Depth (L = 1):** configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml -6. **Depth (L = 2):** configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml -7. **Depth (L = 3):** configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml -8. **Depth (L = 4):** configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml - -Backbone Ablations -^^^^^^^^^^^^^^^^^^ - -1. **ResNet-50:** configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml -2. **ResNet-50 w2x:** configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml -3. **ResNet-101:** configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml - -.. 
note:: - - **Pretraining Task Ablations** (1), **Transformer Size Ablations** (3 and 5) - and **Backbone Ablations** (1) are all the same exact model. - -Data Efficiency Experiments -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -These are VirTex models trained on a subset of COCO Captions dataset. For example, -train a base VirTex model on randomly selected ``50%`` of COCO Captions: - -.. code-block:: - - python scripts/pretrain_virtex.py \ - --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \ - --config-override DATA.USE_PERCENTAGE 50.0 \ - --num-gpus-per-machine 8 \ - --cpu-workers 4 \ - --serialization-dir /tmp/VIRTEX_R_50_L1_H1024_PERCENT_50 - # Default: --checkpoint-every 2000 --log-every 20 - -COCO Captions provides five captions per image. To train with one fixed caption -per image, add ``DATA.USE_SINGLE_CAPTION True`` in ``--config-override``. - -The randomly selected subset is deterministic across runs based on random seed -(``RANDOM_SEED`` in config). When training on less than ``50%`` dataset size, we -recommend using multiple random seeds (results will have a variance of ``±1%``). diff --git a/virtex/docs/virtex/usage/setup_dependencies.rst b/virtex/docs/virtex/usage/setup_dependencies.rst deleted file mode 100644 index b4ece3964148f977154c367de2dfb84c57a86053..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/usage/setup_dependencies.rst +++ /dev/null @@ -1,153 +0,0 @@ -How to setup this codebase? -=========================== - -.. raw:: html - -
- -This codebase requires Python 3.6+ or higher. We recommend using Anaconda or -Miniconda. We walk through installation and data preprocessing here. - - -Install Dependencies --------------------- - -For these steps to install through Anaconda (or Miniconda). - -1. Install Anaconda or Miniconda distribution based on Python 3+ from their - `downloads site `_. - - -2. Clone the repository first. - - .. code-block:: shell - - git clone https://www.github.com/kdexd/virtex - - -3. Create a conda environment and install all the dependencies. - - .. code-block:: shell - - cd virtex - conda create -n virtex python=3.6 - conda activate virtex - pip install -r requirements.txt - - -4. Install this codebase as a package in development version. - - .. code-block:: shell - - python setup.py develop - -Now you can ``import virtex`` from anywhere as long as you have this conda -environment activated. - -------------------------------------------------------------------------------- - - -Setup Datasets --------------- - -Datasets are assumed to exist in ``./datasets`` directory (relative to the -project root) following the structure specified below. COCO is used for -pretraining, and rest of the datasets (including COCO) are used for downstream -tasks. This structure is compatible when using -`Detectron2 `_ for downstream -tasks. - -COCO -^^^^ -.. code-block:: - - datasets/coco/ - annotations/ - captions_{train,val}2017.json - instances_{train,val}2017.json - train2017/ - # images in train2017 split - val2017/ - # images in val2017 split - -LVIS -^^^^ -.. code-block:: - - datasets/coco/ - train2017/ - val2017/ - datasets/lvis/ - lvis_v1.0_{train,val}.json - -PASCAL VOC -^^^^^^^^^^ -.. code-block:: - - datasets/VOC2007/ - Annotations/ - ImageSets/ - Main/ - trainval.txt - test.txt - JPEGImages/ - - datasets/VOC2012/ - # Same as VOC2007 above - -ImageNet -^^^^^^^^ -.. code-block:: - - datasets/imagenet/ - train/ - # One directory per category with images in it - val/ - # One directory per category with images in it - ILSVRC2012_devkit_t12.tar.gz - -iNaturalist 2018 -^^^^^^^^^^^^^^^^ -.. code-block:: - - datasets/inaturalist/ - train_val2018/ - annotations/ - train2018.json - val2018.json - -------------------------------------------------------------------------------- - - -Preprocess Data ---------------- - -1. Build a vocabulary out of COCO Captions ``train2017`` split. - - .. code-block:: shell - - python scripts/preprocess/build_vocabulary.py \ - --captions datasets/coco/annotations/captions_train2017.json \ - --vocab-size 10000 \ - --output-prefix datasets/vocab/coco_10k \ - --do-lower-case - - -2. Serialize COCO Captions (``train2017`` and ``val2017`` splits) into LMDB - files. These are faster for data reading during pretraining. - - .. code-block:: shell - - python scripts/preprocess/preprocess_coco.py \ - --data-root datasets/coco \ - --split train \ - --output datasets/coco/serialized_train.lmdb - - .. code-block:: shell - - python scripts/preprocess/preprocess_coco.py \ - --data-root datasets/coco \ - --split val \ - --output datasets/coco/serialized_val.lmdb - -That's it! You are all set to use this codebase. diff --git a/virtex/docs/virtex/utils.beam_search.rst b/virtex/docs/virtex/utils.beam_search.rst deleted file mode 100644 index a04811e9c89a0c093e1ffb373467eb6ba9b81b87..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/utils.beam_search.rst +++ /dev/null @@ -1,8 +0,0 @@ -virtex.utils.beam_search -======================== - -.. raw:: html - -
- -.. automodule:: virtex.utils.beam_search diff --git a/virtex/docs/virtex/utils.checkpointing.rst b/virtex/docs/virtex/utils.checkpointing.rst deleted file mode 100644 index 1b3719bf7e330c13835dc57457a3bef238c29b0e..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/utils.checkpointing.rst +++ /dev/null @@ -1,8 +0,0 @@ -virtex.utils.checkpointing -========================== - -.. raw:: html - -
- -.. automodule:: virtex.utils.checkpointing diff --git a/virtex/docs/virtex/utils.common.rst b/virtex/docs/virtex/utils.common.rst deleted file mode 100644 index cadd36d26a01f03b4457f1caed1c0c03dc58a9ef..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/utils.common.rst +++ /dev/null @@ -1,8 +0,0 @@ -virtex.utils.common -=================== - -.. raw:: html - -
- -.. automodule:: virtex.utils.common diff --git a/virtex/docs/virtex/utils.distributed.rst b/virtex/docs/virtex/utils.distributed.rst deleted file mode 100644 index e6a44d674ecb8a96d2568b1cd4072dd1e38f2a9d..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/utils.distributed.rst +++ /dev/null @@ -1,8 +0,0 @@ -virtex.utils.distributed -======================== - -.. raw:: html - -
- -.. automodule:: virtex.utils.distributed diff --git a/virtex/docs/virtex/utils.metrics.rst b/virtex/docs/virtex/utils.metrics.rst deleted file mode 100644 index 75234d5e4d230adf20192af77849b1a9c3f059d1..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/utils.metrics.rst +++ /dev/null @@ -1,8 +0,0 @@ -virtex.utils.metrics -==================== - -.. raw:: html - -
- -.. automodule:: virtex.utils.metrics diff --git a/virtex/docs/virtex/utils.rst b/virtex/docs/virtex/utils.rst deleted file mode 100644 index 9d021d9c4e1e255554130264d12abad06cc53911..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/utils.rst +++ /dev/null @@ -1,15 +0,0 @@ -virtex.utils -============ - -.. raw:: html - -
- -.. toctree:: - - utils.common - utils.distributed - utils.timer - utils.checkpointing - utils.beam_search - utils.metrics diff --git a/virtex/docs/virtex/utils.timer.rst b/virtex/docs/virtex/utils.timer.rst deleted file mode 100644 index c2ddcdb3459f519d9a98766a6ddbef2adefa072d..0000000000000000000000000000000000000000 --- a/virtex/docs/virtex/utils.timer.rst +++ /dev/null @@ -1,8 +0,0 @@ -virtex.utils.timer -================== - -.. raw:: html - -
- -.. automodule:: virtex.utils.timer diff --git a/virtex/virtex/factories.py b/virtex/factories.py similarity index 100% rename from virtex/virtex/factories.py rename to virtex/factories.py diff --git a/virtex/hubconf.py b/virtex/hubconf.py deleted file mode 100644 index f85d01d371151f0716680397b1c955d3c4dd42d7..0000000000000000000000000000000000000000 --- a/virtex/hubconf.py +++ /dev/null @@ -1,35 +0,0 @@ -dependencies = ["torch"] - -import torch -import torchvision - - -def resnet50(pretrained: bool = False, **kwargs): - r""" - ResNet-50 visual backbone from the best performing VirTex model: pretrained - for bicaptioning on COCO Captions, with textual head ``L = 1, H = 2048``. - - This is a torchvision-like model, with the last ``avgpool`` and `fc`` - modules replaced with ``nn.Identity()`` modules. Given a batch of image - tensors with size ``(B, 3, 224, 224)``, this model computes spatial image - features of size ``(B, 7, 7, 2048)``, where B = batch size. - - pretrained (bool): Whether to load model with pretrained weights. - """ - - # Create a torchvision resnet50 with randomly initialized weights. - model = torchvision.models.resnet50(pretrained=False, **kwargs) - - # Replace global average pooling and fully connected layers with identity - # modules. - model.avgpool = torch.nn.Identity() - model.fc = torch.nn.Identity() - - if pretrained: - model.load_state_dict( - torch.hub.load_state_dict_from_url( - "https://umich.box.com/shared/static/gsjqm4i4fm1wpzi947h27wweljd8gcpy.pth", - progress=False, - )["model"] - ) - return model diff --git a/virtex/virtex/model_zoo/__init__.py b/virtex/model_zoo/__init__.py similarity index 100% rename from virtex/virtex/model_zoo/__init__.py rename to virtex/model_zoo/__init__.py diff --git a/virtex/virtex/model_zoo/model_zoo.py b/virtex/model_zoo/model_zoo.py similarity index 100% rename from virtex/virtex/model_zoo/model_zoo.py rename to virtex/model_zoo/model_zoo.py diff --git a/virtex/virtex/models/__init__.py b/virtex/models/__init__.py similarity index 100% rename from virtex/virtex/models/__init__.py rename to virtex/models/__init__.py diff --git a/virtex/virtex/models/captioning.py b/virtex/models/captioning.py similarity index 100% rename from virtex/virtex/models/captioning.py rename to virtex/models/captioning.py diff --git a/virtex/virtex/models/classification.py b/virtex/models/classification.py similarity index 100% rename from virtex/virtex/models/classification.py rename to virtex/models/classification.py diff --git a/virtex/virtex/models/contrastive.py b/virtex/models/contrastive.py similarity index 100% rename from virtex/virtex/models/contrastive.py rename to virtex/models/contrastive.py diff --git a/virtex/virtex/models/masked_lm.py b/virtex/models/masked_lm.py similarity index 100% rename from virtex/virtex/models/masked_lm.py rename to virtex/models/masked_lm.py diff --git a/virtex/virtex/models/zero_shot_classification_eval.py b/virtex/models/zero_shot_classification_eval.py similarity index 100% rename from virtex/virtex/models/zero_shot_classification_eval.py rename to virtex/models/zero_shot_classification_eval.py diff --git a/virtex/virtex/modules/embedding.py b/virtex/modules/embedding.py similarity index 100% rename from virtex/virtex/modules/embedding.py rename to virtex/modules/embedding.py diff --git a/virtex/virtex/modules/label_smoothing.py b/virtex/modules/label_smoothing.py similarity index 100% rename from virtex/virtex/modules/label_smoothing.py rename to virtex/modules/label_smoothing.py diff --git 
a/virtex/virtex/modules/textual_heads.py b/virtex/modules/textual_heads.py similarity index 100% rename from virtex/virtex/modules/textual_heads.py rename to virtex/modules/textual_heads.py diff --git a/virtex/virtex/modules/transformer.py b/virtex/modules/transformer.py similarity index 100% rename from virtex/virtex/modules/transformer.py rename to virtex/modules/transformer.py diff --git a/virtex/virtex/modules/visual_backbones.py b/virtex/modules/visual_backbones.py similarity index 100% rename from virtex/virtex/modules/visual_backbones.py rename to virtex/modules/visual_backbones.py diff --git a/virtex/virtex/optim/__init__.py b/virtex/optim/__init__.py similarity index 100% rename from virtex/virtex/optim/__init__.py rename to virtex/optim/__init__.py diff --git a/virtex/virtex/optim/lookahead.py b/virtex/optim/lookahead.py similarity index 100% rename from virtex/virtex/optim/lookahead.py rename to virtex/optim/lookahead.py diff --git a/virtex/virtex/optim/lr_scheduler.py b/virtex/optim/lr_scheduler.py similarity index 100% rename from virtex/virtex/optim/lr_scheduler.py rename to virtex/optim/lr_scheduler.py diff --git a/virtex/scripts/clf_linear.py b/virtex/scripts/clf_linear.py deleted file mode 100644 index 52ab5f22d974cf4e523f174aab09143d7d19b005..0000000000000000000000000000000000000000 --- a/virtex/scripts/clf_linear.py +++ /dev/null @@ -1,302 +0,0 @@ -import argparse -import os - -from loguru import logger -import torch -from torch import nn -from torch.cuda import amp -from torch.utils.data import DataLoader, DistributedSampler -from torch.utils.tensorboard import SummaryWriter - -from virtex.config import Config -from virtex.factories import ( - DownstreamDatasetFactory, - PretrainingModelFactory, - OptimizerFactory, - LRSchedulerFactory, -) -from virtex.utils.checkpointing import CheckpointManager -from virtex.utils.common import common_parser, common_setup, cycle -import virtex.utils.distributed as dist -from virtex.utils.metrics import TopkAccuracy -from virtex.utils.timer import Timer - - -# fmt: off -parser = common_parser( - description="""Do image classification with linear models and frozen - feature extractor, or fine-tune the feature extractor end-to-end.""" -) -group = parser.add_argument_group("Downstream config arguments.") -group.add_argument( - "--down-config", metavar="FILE", help="Path to a downstream config file." -) -group.add_argument( - "--down-config-override", nargs="*", default=[], - help="A list of key-value pairs to modify downstream config params.", -) - -parser.add_argument_group("Checkpointing and Logging") -parser.add_argument( - "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"], - default="virtex", help="""How to initialize weights: - 1. 'random' initializes all weights randomly - 2. 'imagenet' initializes backbone weights from torchvision model zoo - 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path - - with 'torchvision', state dict would be from PyTorch's training - script. - - with 'virtex' it should be for our full pretrained model.""" -) -parser.add_argument( - "--log-every", type=int, default=50, - help="""Log training curves to tensorboard after every these many iterations - only master process logs averaged loss values across processes.""", -) -parser.add_argument( - "--checkpoint-path", - help="""Path to load checkpoint and run downstream task evaluation. 
The - name of checkpoint file is required to be `model_*.pth`, where * is - iteration number from which the checkpoint was serialized.""" -) -parser.add_argument( - "--checkpoint-every", type=int, default=5000, - help="""Serialize model to a checkpoint after every these many iterations. - For ImageNet, (5005 iterations = 1 epoch); for iNaturalist (1710 iterations - = 1 epoch).""", -) -# fmt: on - - -def main(_A: argparse.Namespace): - - if _A.num_gpus_per_machine == 0: - # Set device as CPU if num_gpus_per_machine = 0. - device = torch.device("cpu") - else: - # Get the current device as set for current distributed process. - # Check `launch` function in `virtex.utils.distributed` module. - device = torch.cuda.current_device() - - # Create a downstream config object (this will be immutable) and perform - # common setup such as logging and setting up serialization directory. - _DOWNC = Config(_A.down_config, _A.down_config_override) - common_setup(_DOWNC, _A, job_type="downstream") - - # Create a (pretraining) config object and backup in serializaion directory. - _C = Config(_A.config, _A.config_override) - _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml")) - - # Get dataset name for tensorboard logging. - DATASET = _DOWNC.DATA.ROOT.split("/")[-1] - - # Set number of output classes according to dataset: - NUM_CLASSES_MAPPING = {"imagenet": 1000, "inaturalist": 8142} - NUM_CLASSES = NUM_CLASSES_MAPPING[DATASET] - - # ------------------------------------------------------------------------- - # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER - # ------------------------------------------------------------------------- - train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="train") - train_dataloader = DataLoader( - train_dataset, - batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(), - num_workers=_A.cpu_workers, - sampler=DistributedSampler( - train_dataset, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=True, - ), - drop_last=False, - pin_memory=True, - collate_fn=train_dataset.collate_fn, - ) - val_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="val") - val_dataloader = DataLoader( - val_dataset, - batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(), - num_workers=_A.cpu_workers, - sampler=DistributedSampler( - val_dataset, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=False, - ), - pin_memory=True, - drop_last=False, - collate_fn=val_dataset.collate_fn, - ) - # Initialize model using pretraining config. - pretrained_model = PretrainingModelFactory.from_config(_C) - - # Load weights according to the init method, do nothing for `random`, and - # `imagenet` is already taken care of. - if _A.weight_init == "virtex": - CheckpointManager(model=pretrained_model).load(_A.checkpoint_path) - elif _A.weight_init == "torchvision": - # Keep strict=False because this state dict may have weights for - # last fc layer. - pretrained_model.visual.cnn.load_state_dict( - torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], - strict=False, - ) - - # Pull out the CNN (torchvision-like) from our pretrained model and add - # back the FC layer - this is exists in torchvision models, and is set to - # `nn.Identity()` during pretraining. - model = pretrained_model.visual.cnn # type: ignore - model.fc = nn.Linear(_DOWNC.MODEL.VISUAL.FEATURE_SIZE, NUM_CLASSES).to(device) - model = model.to(device) - - # Re-initialize the FC layer. 
- torch.nn.init.normal_(model.fc.weight.data, mean=0.0, std=0.01) - torch.nn.init.constant_(model.fc.bias.data, 0.0) - - # Freeze all layers except FC as per config param. - if _DOWNC.MODEL.VISUAL.FROZEN: - # Set model to eval mode to prevent BatchNorm from updating running - # mean and std. With only a linear layer, being in eval mode when - # training will not matter anyway. - model.eval() - - for name, param in model.named_parameters(): - if "fc" not in name: - param.requires_grad = False - - # Cross entropy loss and accuracy meter. - criterion = nn.CrossEntropyLoss() - top1 = TopkAccuracy(top_k=1) - - optimizer = OptimizerFactory.from_config(_DOWNC, model.named_parameters()) - scheduler = LRSchedulerFactory.from_config(_DOWNC, optimizer) - del pretrained_model - - # ------------------------------------------------------------------------- - # BEFORE TRAINING STARTS - # ------------------------------------------------------------------------- - - # Create a gradient scaler for automatic mixed precision. - scaler = amp.GradScaler(enabled=_DOWNC.AMP) - - # Create an iterator from dataloader to sample batches perpetually. - train_dataloader_iter = cycle(train_dataloader, device) - - if dist.get_world_size() > 1: - dist.synchronize() - model = nn.parallel.DistributedDataParallel( - model, device_ids=[device], find_unused_parameters=True - ) - - if dist.is_master_process(): - checkpoint_manager = CheckpointManager( - _A.serialization_dir, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) - - # Keep track of time per iteration and ETA. - timer = Timer(start_from=1, total_iterations=_DOWNC.OPTIM.NUM_ITERATIONS) - - # ------------------------------------------------------------------------- - # TRAINING LOOP - # ------------------------------------------------------------------------- - for iteration in range(1, _DOWNC.OPTIM.NUM_ITERATIONS + 1): - timer.tic() - optimizer.zero_grad() - batch = next(train_dataloader_iter) - - with amp.autocast(enabled=_DOWNC.AMP): - logits = model(batch["image"]) - loss = criterion(logits, batch["label"]) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - - scheduler.step() - timer.toc() - - if iteration % _A.log_every == 0 and dist.is_master_process(): - logger.info( - f"{timer.stats} | Loss: {loss:.3f} | GPU: {dist.gpu_mem_usage()} MB" - ) - tensorboard_writer.add_scalar(f"{DATASET}/train_loss", loss, iteration) - tensorboard_writer.add_scalar( - f"{DATASET}/learning_rate", - optimizer.param_groups[0]["lr"], - iteration, - ) - - # --------------------------------------------------------------------- - # VALIDATION - # --------------------------------------------------------------------- - if iteration % _A.checkpoint_every == 0: - torch.set_grad_enabled(False) - model.eval() - - total_val_loss = torch.tensor(0.0).to(device) - - for val_iteration, batch in enumerate(val_dataloader, start=1): - for key in batch: - batch[key] = batch[key].to(device) - - logits = model(batch["image"]) - loss = criterion(logits, batch["label"]) - top1(logits, batch["label"]) - total_val_loss += loss - - # Divide each loss component by number of val batches per GPU. - total_val_loss = total_val_loss / val_iteration - dist.average_across_processes(total_val_loss) - - # Get accumulated Top-1 accuracy for logging across GPUs. 
- acc = top1.get_metric(reset=True) - dist.average_across_processes(acc) - - torch.set_grad_enabled(True) - - # Set model back to train mode only when fine-tuning end-to-end. - if not _DOWNC.MODEL.VISUAL.FROZEN: - model.train() - - # Save recent checkpoint and best checkpoint based on accuracy. - if dist.is_master_process(): - checkpoint_manager.step(iteration) - - if iteration % _A.checkpoint_every == 0 and dist.is_master_process(): - logger.info(f"Iter: {iteration} | Top-1 accuracy: {acc})") - tensorboard_writer.add_scalar( - f"{DATASET}/val_loss", total_val_loss, iteration - ) - # This name scoping will result in Tensorboard displaying all metrics - # (VOC07, caption, etc.) together. - tensorboard_writer.add_scalars( - f"metrics/{DATASET}", {"top1": acc}, iteration - ) - - # All processes will wait till master process is done logging. - dist.synchronize() - - -if __name__ == "__main__": - _A = parser.parse_args() - - # Add an arg in config override if `--weight-init` is imagenet. - if _A.weight_init == "imagenet": - _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True]) - - if _A.num_gpus_per_machine == 0: - main(_A) - else: - # This will launch `main` and set appropriate CUDA device (GPU ID) as - # per process (accessed in the beginning of `main`). - dist.launch( - main, - num_machines=_A.num_machines, - num_gpus_per_machine=_A.num_gpus_per_machine, - machine_rank=_A.machine_rank, - dist_url=_A.dist_url, - args=(_A,), - ) diff --git a/virtex/scripts/clf_voc07.py b/virtex/scripts/clf_voc07.py deleted file mode 100644 index 0e382c1ac49a3c9c254ab9c97f14652ed664fbf6..0000000000000000000000000000000000000000 --- a/virtex/scripts/clf_voc07.py +++ /dev/null @@ -1,272 +0,0 @@ -import argparse -import multiprocessing as mp -import os -from typing import Any, List - -from loguru import logger -import numpy as np -from sklearn.svm import LinearSVC -from sklearn.metrics import average_precision_score -from sklearn.model_selection import cross_val_score -import torch -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm - -from virtex.config import Config -from virtex.factories import PretrainingModelFactory, DownstreamDatasetFactory -from virtex.utils.checkpointing import CheckpointManager -from virtex.utils.common import common_parser, common_setup - - -parser = common_parser( - description="Train SVMs for VOC2007 classification on a pretrained model." -) -group = parser.add_argument_group("Downstream config arguments.") -group.add_argument( - "--down-config", metavar="FILE", help="Path to a downstream config file." -) -group.add_argument( - "--down-config-override", - nargs="*", - default=[], - help="A list of key-value pairs to modify downstream config params.", -) - -# fmt: off -parser.add_argument_group("Checkpointing") -parser.add_argument( - "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"], - default="virtex", help="""How to initialize weights: - 1. 'random' initializes all weights randomly - 2. 'imagenet' initializes backbone weights from torchvision model zoo - 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path - - with 'torchvision', state dict would be from PyTorch's training - script. - - with 'virtex' it should be for our full pretrained model.""" -) -parser.add_argument( - "--checkpoint-path", - help="Path to load checkpoint and run downstream task evaluation." 
-) -# fmt: on - - -def train_test_single_svm(args): - - feats_train, tgts_train, feats_test, tgts_test, cls_name = args - SVM_COSTS = [0.01, 0.1, 1.0, 10.0] - - cls_labels = np.copy(tgts_train) - # Meaning of labels in VOC/COCO original loaded target files: - # label 0 = not present, set it to -1 as svm train target - # label 1 = present. Make the svm train target labels as -1, 1. - cls_labels[np.where(cls_labels == 0)] = -1 - - # See which cost maximizes the AP for this class. - best_crossval_ap: float = 0.0 - best_crossval_clf = None - best_cost: float = 0.0 - - # fmt: off - for cost in SVM_COSTS: - clf = LinearSVC( - C=cost, class_weight={1: 2, -1: 1}, penalty="l2", - loss="squared_hinge", max_iter=2000, - ) - ap_scores = cross_val_score( - clf, feats_train, cls_labels, cv=3, scoring="average_precision", - ) - clf.fit(feats_train, cls_labels) - - # Keep track of best SVM (based on cost) for each class. - if ap_scores.mean() > best_crossval_ap: - best_crossval_ap = ap_scores.mean() - best_crossval_clf = clf - best_cost = cost - - logger.info(f"Best SVM {cls_name}: cost {best_cost}, mAP {best_crossval_ap * 100}") - # fmt: on - - # ------------------------------------------------------------------------- - # TEST THE TRAINED SVM (PER CLASS) - # ------------------------------------------------------------------------- - predictions = best_crossval_clf.decision_function(feats_test) - evaluate_data_inds = tgts_test != -1 - eval_preds = predictions[evaluate_data_inds] - - cls_labels = np.copy(tgts_test) - eval_cls_labels = cls_labels[evaluate_data_inds] - eval_cls_labels[np.where(eval_cls_labels == 0)] = -1 - - # Binarize class labels to make AP targets. - targets = eval_cls_labels > 0 - return average_precision_score(targets, eval_preds) - - -def main(_A: argparse.Namespace): - - if _A.num_gpus_per_machine == 0: - # Set device as CPU if num_gpus_per_machine = 0. - device = torch.device("cpu") - else: - # Get the current device (this will be zero here by default). - device = torch.cuda.current_device() - - # Create a downstream config object (this will be immutable) and perform - # common setup such as logging and setting up serialization directory. - _DOWNC = Config(_A.down_config, _A.down_config_override) - common_setup(_DOWNC, _A, job_type="downstream") - - # Create a (pretraining) config object and backup in serialization directory. - _C = Config(_A.config, _A.config_override) - _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml")) - - # ------------------------------------------------------------------------- - # INSTANTIATE DATALOADER, MODEL, AND FEATURE EXTRACTOR - # ------------------------------------------------------------------------- - - train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="trainval") - train_dataloader = DataLoader( - train_dataset, - batch_size=_DOWNC.OPTIM.BATCH_SIZE, - num_workers=_A.cpu_workers, - pin_memory=True, - ) - test_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="test") - test_dataloader = DataLoader( - test_dataset, - batch_size=_DOWNC.OPTIM.BATCH_SIZE, - num_workers=_A.cpu_workers, - pin_memory=True, - ) - NUM_CLASSES = len(train_dataset.class_names) - - # Initialize from a checkpoint, but only keep the visual module. - model = PretrainingModelFactory.from_config(_C) - - # Load weights according to the init method, do nothing for `random`, and - # `imagenet` is already taken care of. 
- if _A.weight_init == "virtex": - ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path) - elif _A.weight_init == "torchvision": - # Keep strict=False because this state dict may have weights for - # last fc layer. - model.visual.cnn.load_state_dict( - torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], - strict=False, - ) - # Set ``ITERATION`` to a dummy value. - ITERATION = 0 - - # Transfer model to GPU and set to eval mode. This is a torchvision model - # and it returns features as ``(batch_size, 2048, 7, 7)``. - model = model.visual.cnn.to(device).eval() - - # ------------------------------------------------------------------------- - # EXTRACT FEATURES FOR TRAINING SVMs - # ------------------------------------------------------------------------- - - features_train: List[torch.Tensor] = [] - targets_train: List[torch.Tensor] = [] - - features_test: List[torch.Tensor] = [] - targets_test: List[torch.Tensor] = [] - - # VOC07 is small, extract all features and keep them in memory. - with torch.no_grad(): - for batch in tqdm(train_dataloader, desc="Extracting train features:"): - features = model(batch["image"].to(device)) - - # Global average pool features. Assume the tensor is in NCHW format. - if len(features.size()) > 2: - features = features.view(features.size(0), features.size(1), -1) - - # shape: (batch_size, visual_feature_size) - features = features.mean(dim=-1) - - # shape: (batch_size, visual_feature_size) - features = features.view(features.size(0), -1) - - # L2-normalize the global average pooled features. - features = features / torch.norm(features, dim=-1).unsqueeze(-1) - - features_train.append(features.cpu()) - targets_train.append(batch["label"]) - - # Similarly extract test features. - for batch in tqdm(test_dataloader, desc="Extracting test features:"): - features = model(batch["image"].to(device)) - - if len(features.size()) > 2: - features = features.view(features.size(0), features.size(1), -1) - features = features.mean(dim=-1) - - features = features.view(features.size(0), -1) - features = features / torch.norm(features, dim=-1).unsqueeze(-1) - - features_test.append(features.cpu()) - targets_test.append(batch["label"]) - - # Convert batches of features/targets to one large numpy array - features_train = torch.cat(features_train, dim=0).numpy() - targets_train = torch.cat(targets_train, dim=0).numpy().astype(np.int32) - - features_test = torch.cat(features_test, dim=0).numpy() - targets_test = torch.cat(targets_test, dim=0).numpy().astype(np.int32) - - # ------------------------------------------------------------------------- - # TRAIN AND TEST SVMs WITH EXTRACTED FEATURES - # ------------------------------------------------------------------------- - - input_args: List[Any] = [] - - # Iterate over all VOC07 classes and train one-vs-all linear SVMs. - for cls_idx in range(NUM_CLASSES): - # fmt: off - input_args.append(( - features_train, targets_train[:, cls_idx], - features_test, targets_test[:, cls_idx], - train_dataset.class_names[cls_idx], - )) - # fmt: on - - pool = mp.Pool(processes=_A.cpu_workers) - pool_output = pool.map(train_test_single_svm, input_args) - - # ------------------------------------------------------------------------- - # TENSORBOARD LOGGING (RELEVANT MAINLY FOR weight_init=checkpoint) - # ------------------------------------------------------------------------- - - # Tensorboard writer for logging mAP scores. This is useful especially - # when weight_init=checkpoint (which maybe be coming from a training job). 
- tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) - - # Test set mAP for each class, for features from every layer. - test_map = torch.tensor(pool_output).mean() - logger.info(f"Iteration: {ITERATION}, mAP: {test_map * 100}") - tensorboard_writer.add_scalars( - "metrics/voc07_clf", {f"voc07_mAP": test_map * 100}, ITERATION - ) - - # NOTE: for copy-pasting to spreadsheet. - logger.info( - f"{_C.DATA.ROOT.split('/')[1]},{_C.DATA.TOKENIZER_MODEL.split('/')[-1][:-6]}," - f"{_C.DATA.VOCAB_SIZE},{_C.MODEL.NAME},{_C.MODEL.VISUAL.NAME},{_C.MODEL.TEXTUAL.NAME}," - f"{_C.MODEL.LABEL_SMOOTHING},{_C.OPTIM.OPTIMIZER_NAME},{_C.OPTIM.BATCH_SIZE}," - f"{_C.OPTIM.NUM_ITERATIONS},{_C.OPTIM.LR},{_C.OPTIM.WEIGHT_DECAY}," - f"{ITERATION},{test_map * 100:.3f}" - ) - -if __name__ == "__main__": - _A = parser.parse_args() - - if _A.num_gpus_per_machine > 1: - raise ValueError("Using multiple GPUs is not supported for this script.") - - # Add an arg in config override if `--weight-init` is imagenet. - if _A.weight_init == "imagenet": - _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True]) - - # No distributed training here, just a single process. - main(_A) diff --git a/virtex/scripts/eval_captioning.py b/virtex/scripts/eval_captioning.py deleted file mode 100644 index 8da98284f1726027536e38b72b4a82ba04bea396..0000000000000000000000000000000000000000 --- a/virtex/scripts/eval_captioning.py +++ /dev/null @@ -1,114 +0,0 @@ -import argparse -import json -import os -from typing import Any, Dict, List - -from loguru import logger -import torch -from torch.utils.data import DataLoader - -from virtex.config import Config -from virtex.data import ImageDirectoryDataset -from virtex.factories import TokenizerFactory, PretrainingModelFactory -from virtex.utils.checkpointing import CheckpointManager -from virtex.utils.common import common_parser -from virtex.utils.metrics import CocoCaptionsEvaluator - - -# fmt: off -parser = common_parser( - description="""Run image captioning inference on a pretrained model, and/or - evaluate pretrained model on COCO Captions val2017 split.""" -) -parser.add_argument( - "--data-root", default=None, - help="""Path to a directory containing image files to generate captions for. - Default: COCO val2017 image directory as expected relative to project root.""" -) -parser.add_argument( - "--checkpoint-path", required=True, - help="Path to load checkpoint and run captioning evaluation." -) -parser.add_argument( - "--output", default=None, - help="Path to save predictions as a JSON file." -) -parser.add_argument( - "--calc-metrics", action="store_true", - help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions. - This flag should not be set when running inference on arbitrary images.""" -) -# fmt: on - - -def main(_A: argparse.Namespace): - - if _A.num_gpus_per_machine == 0: - # Set device as CPU if num_gpus_per_machine = 0. - device = torch.device("cpu") - else: - # Get the current device (this will be zero here by default). - device = torch.cuda.current_device() - - _C = Config(_A.config, _A.config_override) - - tokenizer = TokenizerFactory.from_config(_C) - - if _A.data_root is None: - _A.data_root = os.path.join(_C.DATA.ROOT, "val2017") - - val_dataloader = DataLoader( - ImageDirectoryDataset(_A.data_root), - batch_size=_C.OPTIM.BATCH_SIZE, - num_workers=_A.cpu_workers, - pin_memory=True, - ) - # Initialize model from a checkpoint. 
- model = PretrainingModelFactory.from_config(_C).to(device) - ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path) - model.eval() - - # Make a list of predictions to evaluate. - predictions: List[Dict[str, Any]] = [] - - for val_iteration, val_batch in enumerate(val_dataloader, start=1): - - val_batch["image"] = val_batch["image"].to(device) - with torch.no_grad(): - output_dict = model(val_batch) - - # Make a dictionary of predictions in COCO format. - for image_id, caption in zip( - val_batch["image_id"], output_dict["predictions"] - ): - predictions.append( - { - # Convert image id to int if possible (mainly for COCO eval). - "image_id": int(image_id) if image_id.isdigit() else image_id, - "caption": tokenizer.decode(caption.tolist()), - } - ) - - # Save predictions as a JSON file if specified. - if _A.output is not None: - os.makedirs(os.path.dirname(_A.output), exist_ok=True) - json.dump(predictions, open(_A.output, "w")) - logger.info(f"Saved predictions to {_A.output}") - - # Calculate CIDEr and SPICE metrics using ground truth COCO Captions. This - # should be skipped when running inference on arbitrary images. - if _A.calc_metrics: - # Assume ground truth (COCO val2017 annotations) exist. - gt = os.path.join(_C.DATA.ROOT, "annotations", "captions_val2017.json") - - metrics = CocoCaptionsEvaluator(gt).evaluate(predictions) - logger.info(f"Iter: {ITERATION} | Metrics: {metrics}") - - -if __name__ == "__main__": - _A = parser.parse_args() - if _A.num_gpus_per_machine > 1: - raise ValueError("Using multiple GPUs is not supported for this script.") - - # No distributed training here, just a single process. - main(_A) diff --git a/virtex/scripts/eval_detectron2.py b/virtex/scripts/eval_detectron2.py deleted file mode 100644 index b79147080f8c56313e1a809b9f1a791ecd380e11..0000000000000000000000000000000000000000 --- a/virtex/scripts/eval_detectron2.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -Finetune a pre-trained model on a downstream task, one of those available in -Detectron2. -Supported downstream: - - LVIS Instance Segmentation - - COCO Instance Segmentation - - Pascal VOC 2007+12 Object Detection - -Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py -Thanks to the developers of Detectron2! -""" -import argparse -import os -import re -from typing import Any, Dict, Union - -import torch -from torch.utils.tensorboard import SummaryWriter - -import detectron2 as d2 -from detectron2.checkpoint import DetectionCheckpointer -from detectron2.engine import DefaultTrainer, default_setup -from detectron2.evaluation import ( - LVISEvaluator, - PascalVOCDetectionEvaluator, - COCOEvaluator, -) -from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads - -from virtex.config import Config -from virtex.factories import PretrainingModelFactory -from virtex.utils.checkpointing import CheckpointManager -from virtex.utils.common import common_parser -import virtex.utils.distributed as dist - -# fmt: off -parser = common_parser( - description="Train object detectors from pretrained visual backbone." -) -parser.add_argument( - "--d2-config", required=True, - help="Path to a detectron2 config for downstream task finetuning." -) -parser.add_argument( - "--d2-config-override", nargs="*", default=[], - help="""Key-value pairs from Detectron2 config to override from file. 
- Some keys will be ignored because they are set from other args: - [DATALOADER.NUM_WORKERS, SOLVER.EVAL_PERIOD, SOLVER.CHECKPOINT_PERIOD, - TEST.EVAL_PERIOD, OUTPUT_DIR]""", -) - -parser.add_argument_group("Checkpointing and Logging") -parser.add_argument( - "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"], - default="virtex", help="""How to initialize weights: - 1. 'random' initializes all weights randomly - 2. 'imagenet' initializes backbone weights from torchvision model zoo - 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path - - with 'torchvision', state dict would be from PyTorch's training - script. - - with 'virtex' it should be for our full pretrained model.""" -) -parser.add_argument( - "--checkpoint-path", - help="Path to load checkpoint and run downstream task evaluation." -) -parser.add_argument( - "--resume", action="store_true", help="""Specify this flag when resuming - training from a checkpoint saved by Detectron2.""" -) -parser.add_argument( - "--eval-only", action="store_true", - help="Skip training and evaluate checkpoint provided at --checkpoint-path.", -) -parser.add_argument( - "--checkpoint-every", type=int, default=5000, - help="Serialize model to a checkpoint after every these many iterations.", -) -# fmt: on - - -@ROI_HEADS_REGISTRY.register() -class Res5ROIHeadsExtraNorm(Res5ROIHeads): - r""" - ROI head with ``res5`` stage followed by a BN layer. Used with Faster R-CNN - C4/DC5 backbones for VOC detection. - """ - - def _build_res5_block(self, cfg): - seq, out_channels = super()._build_res5_block(cfg) - norm = d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, out_channels) - seq.add_module("norm", norm) - return seq, out_channels - - -def build_detectron2_config(_C: Config, _A: argparse.Namespace): - r"""Build detectron2 config based on our pre-training config and args.""" - _D2C = d2.config.get_cfg() - - # Override some default values based on our config file. - _D2C.merge_from_file(_A.d2_config) - _D2C.merge_from_list(_A.d2_config_override) - - # Set some config parameters from args. - _D2C.DATALOADER.NUM_WORKERS = _A.cpu_workers - _D2C.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every - _D2C.OUTPUT_DIR = _A.serialization_dir - - # Set ResNet depth to override in Detectron2's config. - _D2C.MODEL.RESNETS.DEPTH = int( - re.search(r"resnet(\d+)", _C.MODEL.VISUAL.NAME).group(1) - if "torchvision" in _C.MODEL.VISUAL.NAME - else re.search(r"_R_(\d+)", _C.MODEL.VISUAL.NAME).group(1) - if "detectron2" in _C.MODEL.VISUAL.NAME - else 0 - ) - return _D2C - - -class DownstreamTrainer(DefaultTrainer): - r""" - Extension of detectron2's ``DefaultTrainer``: custom evaluator and hooks. - - Parameters - ---------- - cfg: detectron2.config.CfgNode - Detectron2 config object containing all config params. - weights: Union[str, Dict[str, Any]] - Weights to load in the initialized model. If ``str``, then we assume path - to a checkpoint, or if a ``dict``, we assume a state dict. This will be - an ``str`` only if we resume training from a Detectron2 checkpoint. - """ - - def __init__(self, cfg, weights: Union[str, Dict[str, Any]]): - - super().__init__(cfg) - - # Load pre-trained weights before wrapping to DDP because `ApexDDP` has - # some weird issue with `DetectionCheckpointer`. - # fmt: off - if isinstance(weights, str): - # weights are ``str`` means ImageNet init or resume training. 
- self.start_iter = ( - DetectionCheckpointer( - self._trainer.model, - optimizer=self._trainer.optimizer, - scheduler=self.scheduler - ).resume_or_load(weights, resume=True).get("iteration", -1) + 1 - ) - elif isinstance(weights, dict): - # weights are a state dict means our pretrain init. - DetectionCheckpointer(self._trainer.model)._load_model(weights) - # fmt: on - - @classmethod - def build_evaluator(cls, cfg, dataset_name, output_folder=None): - if output_folder is None: - output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") - evaluator_list = [] - evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type - if evaluator_type == "pascal_voc": - return PascalVOCDetectionEvaluator(dataset_name) - elif evaluator_type == "coco": - return COCOEvaluator(dataset_name, cfg, True, output_folder) - elif evaluator_type == "lvis": - return LVISEvaluator(dataset_name, cfg, True, output_folder) - - def test(self, cfg=None, model=None, evaluators=None): - r"""Evaluate the model and log results to stdout and tensorboard.""" - cfg = cfg or self.cfg - model = model or self.model - - tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR) - results = super().test(cfg, model) - flat_results = d2.evaluation.testing.flatten_results_dict(results) - for k, v in flat_results.items(): - tensorboard_writer.add_scalar(k, v, self.start_iter) - - -def main(_A: argparse.Namespace): - - # Get the current device as set for current distributed process. - # Check `launch` function in `virtex.utils.distributed` module. - device = torch.cuda.current_device() - - # Local process group is needed for detectron2. - pg = list(range(dist.get_world_size())) - d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(pg) - - # Create a config object (this will be immutable) and perform common setup - # such as logging and setting up serialization directory. - if _A.weight_init == "imagenet": - _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True]) - _C = Config(_A.config, _A.config_override) - - # We use `default_setup` from detectron2 to do some common setup, such as - # logging, setting up serialization etc. For more info, look into source. - _D2C = build_detectron2_config(_C, _A) - default_setup(_D2C, _A) - - # Prepare weights to pass in instantiation call of trainer. - if _A.weight_init in {"virtex", "torchvision"}: - if _A.resume: - # If resuming training, let detectron2 load weights by providing path. - model = None - weights = _A.checkpoint_path - else: - # Load backbone weights from VirTex pretrained checkpoint. - model = PretrainingModelFactory.from_config(_C) - if _A.weight_init == "virtex": - CheckpointManager(model=model).load(_A.checkpoint_path) - else: - model.visual.cnn.load_state_dict( - torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], - strict=False, - ) - weights = model.visual.detectron2_backbone_state_dict() - else: - # If random or imagenet init, just load weights after initializing model. - model = PretrainingModelFactory.from_config(_C) - weights = model.visual.detectron2_backbone_state_dict() - - # Back up pretrain config and model checkpoint (if provided). 
-    _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))
-    if _A.weight_init == "virtex" and not _A.resume:
-        torch.save(
-            model.state_dict(),
-            os.path.join(_A.serialization_dir, "pretrain_model.pth"),
-        )
-
-    del model
-    trainer = DownstreamTrainer(_D2C, weights)
-    trainer.test() if _A.eval_only else trainer.train()
-
-
-if __name__ == "__main__":
-    _A = parser.parse_args()
-
-    # This will launch `main` and set appropriate CUDA device (GPU ID) as
-    # per process (accessed in the beginning of `main`).
-    dist.launch(
-        main,
-        num_machines=_A.num_machines,
-        num_gpus_per_machine=_A.num_gpus_per_machine,
-        machine_rank=_A.machine_rank,
-        dist_url=_A.dist_url,
-        args=(_A, ),
-    )
diff --git a/virtex/scripts/preprocess/build_redcaps_vocab.py b/virtex/scripts/preprocess/build_redcaps_vocab.py
deleted file mode 100644
index fd28f1b8d72fde036f631032a539c4fe16d169f2..0000000000000000000000000000000000000000
--- a/virtex/scripts/preprocess/build_redcaps_vocab.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import argparse
-import glob
-import json
-import os
-import re
-import tempfile
-from functools import lru_cache
-from typing import List
-
-import ftfy
-import sentencepiece as sp
-import wordsegment as ws
-from tqdm import tqdm
-
-
-ws.load()
-
-# fmt: off
-parser = argparse.ArgumentParser(
    description="""Build a vocabulary out of captions corpus. This vocabulary
-    would be a file which our tokenizer can understand.
-    """
-)
-parser.add_argument(
-    "-f", "--files", nargs="+", default=["datasets/redcaps/annotations/*.json"],
-    help="Path(s) to SBU, Conceptual, or RedCaps annotation files.",
-)
-parser.add_argument(
-    "-s", "--vocab-size", type=int, default=32000,
-    help="Total desired size of our vocabulary.",
-)
-parser.add_argument(
-    "-o", "--output-prefix", default="datasets/vocab/redcaps_32k",
-    help="Prefix of the files to be saved. Two files will be saved: "
-    "[prefix].model and [prefix].vocab",
-)
-# fmt: on
-
-
-def read_captions_from_file(annotations_path: str) -> List[str]:
-    r"""
-    Given a path to annotation file, read it and return a list of captions.
-
-    Parameters
-    ----------
-    annotations_path: str
-        Path to an annotations file containing captions.
-
-    Returns
-    -------
-    List[str]
-        List of captions from this annotation file.
-    """
-
-    _annotations = json.load(open(annotations_path))
-
-    captions: List[str] = []
-    for ann in tqdm(_annotations["annotations"], desc=annotations_path):
-
-        # This field only exists in RedCaps. Perform word segmentation on the
-        # subreddit name to add appropriate whitespace.
-        if "subreddit" in ann:
-            subreddit_seg = _segment_subreddit(ann["subreddit"].lower())
-            caption = f"{subreddit_seg} {ann['caption']}"
-        else:
-            caption = ann["caption"]
-
-        captions.append(caption.lower())
-    return captions
-
-
-@lru_cache(maxsize=10)
-def _segment_subreddit(subreddit):
-    return " ".join(ws.segment(ws.clean(subreddit)))
-
-
-if __name__ == "__main__":
-    _A = parser.parse_args()
-
-    all_filepaths: List[str] = []
-    for f in _A.files:
-        all_filepaths.extend(glob.glob(f))
-
-    captions: List[str] = []
-    for path in tqdm(all_filepaths, desc="Reading captions"):
-        captions.extend(read_captions_from_file(path))
-
-    # Create a temporary directory and dump the captions corpus as a text file
-    # with one caption per line. That's how sentencepiece wants its input.
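-    # SentencePiece trains from plain text (one sentence per line), so the
-    # JSON annotations are flattened into an intermediate corpus file first.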
- tmpdir_path = tempfile.mkdtemp() - - with open(os.path.join(tmpdir_path, "captions.txt"), "w") as captions_file: - for caption in captions: - captions_file.write(caption + "\n") - - # Padding/out-of-vocab token will be "" and ID 0 by default. - # Add [SOS],[EOS] and [SEP] tokens. [SEP] will not be used during - # captioning, but good to have to reuse vocabulary across pretext tasks. - sp.SentencePieceTrainer.train( - f" --input={os.path.join(tmpdir_path, 'captions.txt')}" - f" --vocab_size={_A.vocab_size}" - f" --model_prefix={_A.output_prefix}" - " --model_type=bpe --character_coverage=1.0" - " --bos_id=-1 --eos_id=-1" - " --control_symbols=[SOS],[EOS],[SEP]" - " --user_defined_symbols=" - ) diff --git a/virtex/scripts/preprocess/build_vocabulary.py b/virtex/scripts/preprocess/build_vocabulary.py deleted file mode 100644 index bc7a592b40d8044919279dc8116ca03dce20b5d1..0000000000000000000000000000000000000000 --- a/virtex/scripts/preprocess/build_vocabulary.py +++ /dev/null @@ -1,100 +0,0 @@ -import argparse -import json -import os -import tempfile -import unicodedata -from typing import List - -import sentencepiece as sp - - -# fmt: off -parser = argparse.ArgumentParser( - description="""Build a vocabulary out of captions corpus. This vocabulary - would be a file which our tokenizer can understand. - """ -) -parser.add_argument( - "-c", "--captions", default="datasets/coco/annotations/captions_train2017.json", - help="Path to caption annotations file in COCO format.", -) -parser.add_argument( - "-s", "--vocab-size", type=int, default=10000, - help="Total desired size of our vocabulary.", -) -parser.add_argument( - "-o", "--output-prefix", default="datasets/vocab/coco_10k", - help="Prefix of the files to be saved. Two files will be saved: " - "[prefix].model and [prefix].vocab", -) -parser.add_argument( - "-l", "--do-lower-case", action="store_true", - help="Whether to lower case the captions before forming vocabulary.", -) -parser.add_argument( - "-a", "--keep-accents", action="store_true", - help="Whether to keep accents before forming vocabulary (dropped by default).", -) -# fmt: on - - -def _read_captions(annotations_path: str) -> List[str]: - r""" - Given a path to annotation file, read it and return a list of captions. - These are not processed by any means, returned from the file as-is. - - Parameters - ---------- - annotations_path: str - Path to an annotations file containing captions. - - Returns - ------- - List[str] - List of captions from this annotation file. - """ - - _annotations = json.load(open(annotations_path)) - - captions: List[str] = [] - for ann in _annotations["annotations"]: - captions.append(ann["caption"]) - - return captions - - -if __name__ == "__main__": - _A = parser.parse_args() - captions: List[str] = _read_captions(_A.captions) - - # Lower case the captions and remove accents according to arguments. - for i, caption in enumerate(captions): - caption = caption.lower() if _A.do_lower_case else caption - - if not _A.keep_accents: - caption = unicodedata.normalize("NFKD", caption) - caption = "".join( - [chr for chr in caption if not unicodedata.combining(chr)] - ) - - captions[i] = caption - - # Create a temporary directory and dump the captions corpus as a text file - # with one caption per line. That's how sentencepiece wants its input. 
- tmpdir_path = tempfile.mkdtemp() - - with open(os.path.join(tmpdir_path, "captions.txt"), "w") as captions_file: - for caption in captions: - captions_file.write(caption + "\n") - - # Padding/out-of-vocab token will be "" and ID 0 by default. - # Add [SOS],[EOS] and [MASK] tokens. [MASK] will not be used during - # captioning, but good to have to reuse vocabulary across pretext tasks. - sp.SentencePieceTrainer.train( - f" --input={os.path.join(tmpdir_path, 'captions.txt')}" - f" --vocab_size={_A.vocab_size}" - f" --model_prefix={_A.output_prefix}" - " --model_type=bpe --character_coverage=1.0" - " --bos_id=-1 --eos_id=-1" - " --control_symbols=[SOS],[EOS],[MASK]" - ) diff --git a/virtex/scripts/preprocess/preprocess_coco.py b/virtex/scripts/preprocess/preprocess_coco.py deleted file mode 100644 index abf768d94d494a2e8397b596ca7993a638d2d840..0000000000000000000000000000000000000000 --- a/virtex/scripts/preprocess/preprocess_coco.py +++ /dev/null @@ -1,100 +0,0 @@ -import argparse -import os -import pickle -import platform -from typing import Any, List - -import albumentations as alb -import lmdb -from tqdm import tqdm -from torch.utils.data import DataLoader - -from virtex.data.readers import SimpleCocoCaptionsReader - - -# fmt: off -parser = argparse.ArgumentParser("Serialize a COCO Captions split to LMDB.") -parser.add_argument( - "-d", "--data-root", default="datasets/coco", - help="Path to the root directory of COCO dataset.", -) -parser.add_argument( - "-s", "--split", choices=["train", "val"], - help="Which split to process, either `train` or `val`.", -) -parser.add_argument( - "-b", "--batch-size", type=int, default=16, - help="Batch size to process and serialize data. Set as per CPU memory.", -) -parser.add_argument( - "-j", "--cpu-workers", type=int, default=4, - help="Number of CPU workers for data loading.", -) -parser.add_argument( - "-e", "--short-edge-size", type=int, default=None, - help="""Resize shorter edge to this size (keeping aspect ratio constant) - before serializing. Useful for saving disk memory, and faster read. - If None, no images are resized.""" -) -parser.add_argument( - "-o", "--output", default="datasets/serialized/coco_train2017.lmdb", - help="Path to store the file containing serialized dataset.", -) - - -def collate_fn(instances: List[Any]): - r"""Collate function for data loader to return list of instances as-is.""" - return instances - - -if __name__ == "__main__": - - _A = parser.parse_args() - os.makedirs(os.path.dirname(_A.output), exist_ok=True) - - dloader = DataLoader( - SimpleCocoCaptionsReader(_A.data_root, _A.split), - batch_size=_A.batch_size, - num_workers=_A.cpu_workers, - shuffle=False, - drop_last=False, - collate_fn=collate_fn - ) - # Open an LMDB database. - # Set a sufficiently large map size for LMDB (based on platform). - map_size = 1099511627776 * 2 if platform.system() == "Linux" else 1280000 - db = lmdb.open( - _A.output, map_size=map_size, subdir=False, meminit=False, map_async=True - ) - - # Transform to resize shortest edge and keep aspect ratio same. - if _A.short_edge_size is not None: - resize = alb.SmallestMaxSize(max_size=_A.short_edge_size, always_apply=True) - - # Serialize each instance (as a dictionary). Use `pickle.dumps`. Key will - # be an integer (cast as string) starting from `0`. 
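-    # LMDB keys must be bytes, hence the ASCII-encoded counter below; pickle
-    # protocol -1 selects the highest protocol available.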
-    INSTANCE_COUNTER: int = 0
-
-    for idx, batch in enumerate(tqdm(dloader)):
-
-        txn = db.begin(write=True)
-
-        for instance in batch:
-            image = instance["image"]
-            height, width, channels = image.shape
-
-            # Resize image from instance (if requested) and convert instance
-            # to a tuple, keeping the possibly-resized image.
-            if _A.short_edge_size is not None and min(height, width) > _A.short_edge_size:
-                image = resize(image=image)["image"]
-
-            instance = (instance["image_id"], image, instance["captions"])
-            txn.put(
-                f"{INSTANCE_COUNTER}".encode("ascii"),
-                pickle.dumps(instance, protocol=-1)
-            )
-            INSTANCE_COUNTER += 1
-
-        txn.commit()
-
-    db.sync()
-    db.close()
diff --git a/virtex/scripts/preprocess/preprocess_redcaps.py b/virtex/scripts/preprocess/preprocess_redcaps.py
deleted file mode 100644
index abb4e9e71e272797e2caf3eed304bc4cdc98f85e..0000000000000000000000000000000000000000
--- a/virtex/scripts/preprocess/preprocess_redcaps.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import argparse
-import json
-import os
-import tarfile
-import tempfile
-from typing import Dict, List
-
-from loguru import logger
-from tqdm import tqdm
-
-
-# fmt: off
-parser = argparse.ArgumentParser(
-    description="""Pre-process RedCaps dataset for training VirTex models - make
-    small shards of TAR files containing images and captions."""
-)
-parser.add_argument(
-    "-a", "--annotations", required=True, help="Path to a RedCaps annotation file."
-)
-parser.add_argument(
-    "-i", "--images", default="datasets/redcaps/images",
-    help="""Path to RedCaps image directory. This directory is expected to have
-    subreddit specific sub-directories containing images.""",
-)
-parser.add_argument(
-    "-z", "--shard-size", type=int, default=1000,
-    help="Maximum number of RedCaps instances in a single TAR file shard.",
-)
-parser.add_argument(
-    "-o", "--output-prefix", required=True,
-    help="Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` "
-    "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ...",
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-    r"""
-    Make TAR files containing images and annotations from a single RedCaps
-    annotations file. These TAR files are arranged in a way that
-    WebDataset can understand.
-    """
-
-    ANNOTATIONS: List[Dict] = json.load(open(_A.annotations))["annotations"]
-
-    # Keep track of the current index of TAR file shard and dataset index.
-    SHARD_INDEX: int = 0
-    DATASET_INDEX: int = 0
-
-    # Create TAR file handle for the initial shard.
-    tar_handle = tarfile.open(f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w")
-
-    # Keep a count of submissions that were skipped because their image was
-    # not downloaded (not present in image dir).
-    SKIPPED: int = 0
-
-    for ann in tqdm(ANNOTATIONS):
-
-        image_path = os.path.join(
-            _A.images, ann["subreddit"], f"{ann['image_id']}.jpg"
-        )
-        # Add current image in shard if it exists.
-        if os.path.exists(image_path):
-
-            tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg")
-
-            # Save subreddit name and caption as a JSON file.
-            subreddit_and_caption = {
-                "subreddit": ann["subreddit"], "caption": ann["caption"]
-            }
-            tmpfile = tempfile.NamedTemporaryFile("w+")
-            tmpfile.write(json.dumps(subreddit_and_caption))
-            tmpfile.seek(0)
-            tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json")
-            tmpfile.close()
-
-            DATASET_INDEX += 1
-
-            # Create new shard if current shard is full.
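-            # Only instances actually written advance ``DATASET_INDEX``, so
-            # every shard holds ``--shard-size`` instances (except the last).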
-            if DATASET_INDEX % _A.shard_size == 0 and DATASET_INDEX > 0:
-                tar_handle.close()
-                logger.success(
-                    f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar"
-                )
-                SHARD_INDEX += 1
-
-                # Open new TAR file shard.
-                tar_handle = tarfile.open(
-                    f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w"
-                )
-        else:
-            SKIPPED += 1
-
-    # Close the file handle to properly save it.
-    tar_handle.close()
-    logger.success(f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar\n")
-    logger.info(f"Skipped {SKIPPED} annotations due to missing images.")
-
-
-if __name__ == "__main__":
-    _A = parser.parse_args()
-    main(_A)
diff --git a/virtex/scripts/pretrain_virtex.py b/virtex/scripts/pretrain_virtex.py
deleted file mode 100644
index 73e36ed3428c6899876fca9961dbeb81dcb2bd0c..0000000000000000000000000000000000000000
--- a/virtex/scripts/pretrain_virtex.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import argparse
-from collections import Counter
-from typing import Any
-
-from loguru import logger
-import torch
-from torch import nn
-from torch.cuda import amp
-from torch.utils.data import DataLoader, DistributedSampler
-from torch.utils.tensorboard import SummaryWriter
-
-# fmt: off
-from virtex.config import Config
-from virtex.factories import (
-    PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory,
-    LRSchedulerFactory,
-)
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser, common_setup, cycle
-import virtex.utils.distributed as dist
-from virtex.utils.timer import Timer
-
-
-parser = common_parser(
-    description="Train a VirTex model (CNN + Transformer) on COCO Captions."
-)
-group = parser.add_argument_group("Checkpointing and Logging")
-group.add_argument(
-    "--resume-from", default=None,
-    help="Path to a checkpoint to resume training from (if provided)."
-)
-group.add_argument(
-    "--checkpoint-every", type=int, default=2000,
-    help="Serialize model to a checkpoint after this many iterations.",
-)
-group.add_argument(
-    "--log-every", type=int, default=20,
-    help="""Log training curves to tensorboard after this many iterations;
-    only the master process logs loss values averaged across processes.""",
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-
-    if _A.num_gpus_per_machine == 0:
-        # Set device as CPU if num_gpus_per_machine = 0.
-        device: Any = torch.device("cpu")
-    else:
-        # Get the current device as set for current distributed process.
-        # Check `launch` function in `virtex.utils.distributed` module.
-        device = torch.cuda.current_device()
-
-    # Create a config object (this will be immutable) and perform common setup
-    # such as logging and setting up serialization directory.
-    _C = Config(_A.config, _A.config_override)
-    common_setup(_C, _A)
-
-    # -------------------------------------------------------------------------
-    # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
-    # -------------------------------------------------------------------------
-    train_dataset = PretrainingDatasetFactory.from_config(_C, split="train")
-    val_dataset = PretrainingDatasetFactory.from_config(_C, split="val")
-
-    # Make `DistributedSampler`s to shard datasets across GPU processes.
-    # Skip this if training on CPUs.
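-    # Each process sees its own shard of the data through the sampler, which
-    # also owns shuffling; the DataLoader shuffles only in the sampler-less
-    # (CPU-only) case.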
- train_sampler = ( - DistributedSampler(train_dataset, shuffle=True) # type: ignore - if _A.num_gpus_per_machine > 0 - else None - ) - val_sampler = ( - DistributedSampler(val_dataset, shuffle=False) # type: ignore - if _A.num_gpus_per_machine > 0 - else None - ) - train_dataloader = DataLoader( - train_dataset, - batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(), - sampler=train_sampler, - shuffle=train_sampler is None, - num_workers=_A.cpu_workers, - pin_memory=True, - drop_last=True, - collate_fn=train_dataset.collate_fn, - ) - val_dataloader = DataLoader( - val_dataset, - batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(), - sampler=val_sampler, - shuffle=False, - num_workers=_A.cpu_workers, - pin_memory=True, - drop_last=False, - collate_fn=val_dataset.collate_fn, - ) - - model = PretrainingModelFactory.from_config(_C).to(device) - optimizer = OptimizerFactory.from_config(_C, model.named_parameters()) - scheduler = LRSchedulerFactory.from_config(_C, optimizer) - - # ------------------------------------------------------------------------- - # BEFORE TRAINING STARTS - # ------------------------------------------------------------------------- - - # Create a gradient scaler for automatic mixed precision. - scaler = amp.GradScaler(enabled=_C.AMP) - - # Load checkpoint to resume training if specified. - if _A.resume_from is not None: - start_iteration = CheckpointManager( - model=model, optimizer=optimizer, scheduler=scheduler, scaler=scaler, - ).load(_A.resume_from) - else: - start_iteration = 0 - - # Create an iterator from dataloader to sample batches perpetually. - train_dataloader_iter = cycle(train_dataloader, device, start_iteration) - - # Wrap model in DDP if using more than one processes. - if dist.get_world_size() > 1: - dist.synchronize() - model = nn.parallel.DistributedDataParallel( - model, device_ids=[device], find_unused_parameters=True - ) - - # Keep track of time per iteration and ETA. - timer = Timer( - start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS - ) - # Create tensorboard writer and checkpoint manager (only in master process). - if dist.is_master_process(): - tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) - tensorboard_writer.add_text("config", f"```\n{_C}\n```") - - checkpoint_manager = CheckpointManager( - _A.serialization_dir, - model=model, - optimizer=optimizer, - scheduler=scheduler, - scaler=scaler, - ) - - # ------------------------------------------------------------------------- - # TRAINING LOOP - # ------------------------------------------------------------------------- - for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1): - timer.tic() - optimizer.zero_grad() - batch = next(train_dataloader_iter) - - with amp.autocast(enabled=_C.AMP): - output_dict = model(batch) - loss = output_dict["loss"] - - scaler.scale(loss).backward() - - # First clip norm of gradients, and then perform optimizer step. 
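-        # Gradients still carry the AMP loss scale after ``backward()``, so
-        # unscale them in-place first to make the clipping threshold apply to
-        # true gradient values.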
- scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM) - scaler.step(optimizer) - - scaler.update() - scheduler.step() - timer.toc() - - # --------------------------------------------------------------------- - # LOGGING - # --------------------------------------------------------------------- - if iteration % _A.log_every == 0: - logger.info( - f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]" - ) - if dist.is_master_process(): - tensorboard_writer.add_scalars( - "learning_rate", - { - "visual": optimizer.param_groups[0]["lr"], - "common": optimizer.param_groups[-1]["lr"], - }, - iteration, - ) - tensorboard_writer.add_scalars( - "train", output_dict["loss_components"], iteration - ) - - # --------------------------------------------------------------------- - # VALIDATION - # --------------------------------------------------------------------- - if iteration % _A.checkpoint_every == 0: - if dist.is_master_process(): - checkpoint_manager.step(iteration) - - # All processes will wait till master process is done serializing. - dist.synchronize() - - torch.set_grad_enabled(False) - model.eval() - - # Accumulate different val loss components according to the type of - # pretraining model. - val_loss_counter: Counter = Counter() - - for val_iteration, val_batch in enumerate(val_dataloader, start=1): - for key in val_batch: - val_batch[key] = val_batch[key].to(device) - output_dict = model(val_batch) - - val_loss_counter.update(output_dict["loss_components"]) - - # Divide each loss component by number of val batches per GPU. - val_loss_dict = { - k: v / val_iteration for k, v in dict(val_loss_counter).items() - } - dist.average_across_processes(val_loss_dict) - torch.set_grad_enabled(True) - model.train() - - logger.info(f"Iteration: {iteration} [Val loss: {val_loss_dict}]") - if dist.is_master_process(): - tensorboard_writer.add_scalars("val", val_loss_dict, iteration) - - -if __name__ == "__main__": - _A = parser.parse_args() - - if _A.num_gpus_per_machine == 0: - main(_A) - else: - # This will launch `main` and set appropriate CUDA device (GPU ID) as - # per process (accessed in the beginning of `main`). - dist.launch( - main, - num_machines=_A.num_machines, - num_gpus_per_machine=_A.num_gpus_per_machine, - machine_rank=_A.machine_rank, - dist_url=_A.dist_url, - args=(_A, ), - ) diff --git a/virtex/scripts/redcaps_caption_decode.py b/virtex/scripts/redcaps_caption_decode.py deleted file mode 100644 index d63b69ac6a13dc235a3ba4980dba582a9cd75be6..0000000000000000000000000000000000000000 --- a/virtex/scripts/redcaps_caption_decode.py +++ /dev/null @@ -1,140 +0,0 @@ -import argparse -import json -import os -from typing import Any, Dict, List - -from loguru import logger -import torch -from torch.utils.data import DataLoader -from tqdm import tqdm - -import wordsegment as ws - -from virtex.config import Config -from virtex.data import ImageDirectoryDataset -from virtex.factories import TokenizerFactory, PretrainingModelFactory -from virtex.utils.checkpointing import CheckpointManager -from virtex.utils.common import common_parser - -ws.load() - -# fmt: off -parser = common_parser( - description="Decode captions using a RedCaps-pretrained VirTex model." -) -parser.add_argument( - "--images", required=True, - help="Path to a directory containing image files to generate captions for." -) -parser.add_argument( - "--checkpoint-path", required=True, - help="Path to load checkpoint and run captioning evaluation." 
-)
-parser.add_argument(
-    "--output", required=True,
-    help="Path to save predictions as a JSON file."
-)
-parser.add_argument(
-    "--subreddit-prompt", default=None,
-    help="Optional subreddit prompt for controllable subreddit-style captioning."
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-
-    if _A.num_gpus_per_machine == 0:
-        # Set device as CPU if num_gpus_per_machine = 0.
-        device = torch.device("cpu")
-    else:
-        # Get the current device (this will be zero here by default).
-        device = torch.cuda.current_device()
-
-    _C = Config(_A.config, _A.config_override)
-
-    tokenizer = TokenizerFactory.from_config(_C)
-
-    val_dataloader = DataLoader(
-        ImageDirectoryDataset(_A.images),
-        batch_size=_C.OPTIM.BATCH_SIZE,
-        num_workers=_A.cpu_workers,
-        pin_memory=True,
-    )
-    # Initialize model from a checkpoint.
-    model = PretrainingModelFactory.from_config(_C).to(device)
-    CheckpointManager(model=model).load(_A.checkpoint_path)
-    model.eval()
-
-    # Prepare subreddit prompt for the model if provided.
-    if _A.subreddit_prompt is not None:
-
-        # Remove "r/" if provided.
-        _A.subreddit_prompt = _A.subreddit_prompt.replace("r/", "")
-
-        # Word segmenting (e.g. "itookapicture" -> "i took a picture").
-        _segments = " ".join(ws.segment(ws.clean(_A.subreddit_prompt)))
-        subreddit_tokens = (
-            [model.sos_index]
-            + tokenizer.encode(_segments)
-            + [tokenizer.token_to_id("[SEP]")]
-        )
-    else:
-        # Just seed the model with [SOS].
-        subreddit_tokens = [model.sos_index]
-
-    # Move the subreddit prompt to the appropriate device.
-    subreddit_tokens = torch.tensor(subreddit_tokens, device=device).long()
-
-    # Make a list of predictions to evaluate.
-    predictions: List[Dict[str, Any]] = []
-
-    for val_batch in tqdm(val_dataloader):
-        val_batch["image"] = val_batch["image"].to(device)
-
-        # Add the subreddit tokens as decoding prompt to batch.
-        val_batch["decode_prompt"] = subreddit_tokens
-
-        with torch.no_grad():
-            output_dict = model(val_batch)
-
-        for idx, (image_id, caption) in enumerate(
-            zip(val_batch["image_id"], output_dict["predictions"])
-        ):
-            caption = caption.tolist()
-
-            # Replace the [SEP] index with "::" temporarily so it gets decoded.
-            if tokenizer.token_to_id("[SEP]") in caption:
-                sep_index = caption.index(tokenizer.token_to_id("[SEP]"))
-                caption[sep_index] = tokenizer.token_to_id("::")
-
-            caption = tokenizer.decode(caption)
-
-            # Separate out subreddit from the rest of the caption.
-            if "::" in caption:
-                subreddit, rest_of_caption = caption.split("::")
-                subreddit = "".join(subreddit.split())
-                rest_of_caption = rest_of_caption.strip()
-            else:
-                subreddit, rest_of_caption = "", caption
-
-            predictions.append(
-                {"image_id": image_id, "subreddit": subreddit, "caption": rest_of_caption}
-            )
-
-    logger.info("Displaying first 25 caption predictions:")
-    for pred in predictions[:25]:
-        logger.info(f"{pred['image_id']} - r/{pred['subreddit']}:: {pred['caption']}")
-
-    # Save predictions as a JSON file.
-    os.makedirs(os.path.dirname(_A.output), exist_ok=True)
-    json.dump(predictions, open(_A.output, "w"))
-    logger.info(f"Saved predictions to {_A.output}")
-
-
-if __name__ == "__main__":
-    _A = parser.parse_args()
-    if _A.num_gpus_per_machine > 1:
-        raise ValueError("Using multiple GPUs is not supported for this script.")
-
-    # No distributed training here, just a single process.
-    main(_A)
diff --git a/virtex/scripts/redcaps_train.py b/virtex/scripts/redcaps_train.py
deleted file mode 100644
index b8c63010361c80ffcdca873a8166448aa2f359ef..0000000000000000000000000000000000000000
--- a/virtex/scripts/redcaps_train.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import argparse
-import os
-import tempfile
-from typing import Any
-
-from loguru import logger
-import torch
-from torch import nn
-from torch.cuda import amp
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-
-# fmt: off
-from virtex.config import Config
-from virtex.factories import (
-    PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory,
-    LRSchedulerFactory,
-)
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser, common_setup, cycle
-import virtex.utils.distributed as dist
-from virtex.utils.timer import Timer
-
-
-parser = common_parser(
-    description="Train a VirTex model (CNN + Transformer) on RedCaps."
-)
-group = parser.add_argument_group("Checkpointing and Logging")
-group.add_argument(
-    "--resume-from", default=None,
-    help="Path to a checkpoint to resume training from (if provided)."
-)
-group.add_argument(
-    "--checkpoint-every", type=int, default=2000,
-    help="Serialize model to a checkpoint after this many iterations.",
-)
-group.add_argument(
-    "--log-every", type=int, default=50,
-    help="""Log training curves to tensorboard after this many iterations;
-    only the master process logs loss values averaged across processes.""",
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-
-    if _A.num_gpus_per_machine == 0:
-        # Set device as CPU if num_gpus_per_machine = 0.
-        device: Any = torch.device("cpu")
-    else:
-        # Get the current device as set for current distributed process.
-        # Check `launch` function in `virtex.utils.distributed` module.
-        device = torch.cuda.current_device()
-
-    # Create a config object (this will be immutable) and perform common setup
-    # such as logging and setting up serialization directory.
-    _C = Config(_A.config, _A.config_override)
-    common_setup(_C, _A)
-
-    # -------------------------------------------------------------------------
-    # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
-    # -------------------------------------------------------------------------
-
-    # fmt: off
-    train_dataset = PretrainingDatasetFactory.from_config(_C)
-    train_dataloader = DataLoader(
-        train_dataset, batch_size=None, shuffle=False,
-        num_workers=_A.cpu_workers, pin_memory=True,
-    )
-    # fmt: on
-
-    model = PretrainingModelFactory.from_config(_C).to(device)
-    optimizer = OptimizerFactory.from_config(_C, model.named_parameters())
-    scheduler = LRSchedulerFactory.from_config(_C, optimizer)
-
-    # -------------------------------------------------------------------------
-    # BEFORE TRAINING STARTS
-    # -------------------------------------------------------------------------
-
-    # Create a gradient scaler for automatic mixed precision.
-    scaler = amp.GradScaler(enabled=_C.AMP)
-
-    # Load checkpoint to resume training if specified.
-    if _A.resume_from is not None:
-        start_iteration = CheckpointManager(
-            model=model, optimizer=optimizer, scheduler=scheduler, scaler=scaler,
-        ).load(_A.resume_from)
-    else:
-        start_iteration = 0
-
-    # Create an iterator from dataloader to sample batches perpetually.
-    train_dataloader_iter = cycle(train_dataloader, device, start_iteration)
-
-    # Wrap model in DDP if using more than one process.
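-    # ``find_unused_parameters=True`` lets DDP tolerate parameters that get no
-    # gradient in a forward pass, at the cost of an extra graph traversal.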
- if dist.get_world_size() > 1: - dist.synchronize() - model = nn.parallel.DistributedDataParallel( - model, device_ids=[device], find_unused_parameters=True - ) - - # Keep track of time per iteration and ETA. - timer = Timer( - start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS - ) - # Create tensorboard writer and checkpoint manager (only in master process). - if dist.is_master_process(): - tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) - tensorboard_writer.add_text("config", f"```\n{_C}\n```") - - checkpoint_manager = CheckpointManager( - _A.serialization_dir, - model=model, - optimizer=optimizer, - scheduler=scheduler, - scaler=scaler, - ) - - # ------------------------------------------------------------------------- - # TRAINING LOOP - # ------------------------------------------------------------------------- - for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1): - timer.tic() - optimizer.zero_grad() - batch = next(train_dataloader_iter) - - with amp.autocast(enabled=_C.AMP): - output_dict = model(batch) - loss = output_dict["loss"] - - scaler.scale(loss).backward() - - # First clip norm of gradients, and then perform optimizer step. - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM) - scaler.step(optimizer) - - scaler.update() - scheduler.step() - timer.toc() - - # --------------------------------------------------------------------- - # LOGGING - # --------------------------------------------------------------------- - if iteration % _A.log_every == 0: - logger.info( - f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]" - ) - if dist.is_master_process(): - tensorboard_writer.add_scalars( - "train", output_dict["loss_components"], iteration - ) - - if iteration % _A.checkpoint_every == 0 and dist.is_master_process(): - checkpoint_manager.step(iteration) - - -if __name__ == "__main__": - _A = parser.parse_args() - - if _A.num_gpus_per_machine == 0: - main(_A) - else: - # This will launch `main` and set appropriate CUDA device (GPU ID) as - # per process (accessed in the beginning of `main`). 
- dist.launch( - main, - num_machines=_A.num_machines, - num_gpus_per_machine=_A.num_gpus_per_machine, - machine_rank=_A.machine_rank, - dist_url=_A.dist_url, - args=(_A, ), - ) diff --git a/virtex/scripts/zero_shot_classification.py b/virtex/scripts/zero_shot_classification.py deleted file mode 100644 index fc8523f6012aab90a86f77c1ba235ad740848fe2..0000000000000000000000000000000000000000 --- a/virtex/scripts/zero_shot_classification.py +++ /dev/null @@ -1,171 +0,0 @@ -import argparse -import json -import os -import random -from typing import Any, Dict, List - -from loguru import logger -import torch -from torch.utils.data import DataLoader, DistributedSampler -from torch.nn.utils.rnn import pad_sequence -from tqdm import tqdm - -import wordsegment as ws - -from virtex.config import Config -from virtex.data import ZeroShotDataset - -from virtex.data.tokenizers import SentencePieceBPETokenizer - -from virtex.factories import TokenizerFactory, VisualBackboneFactory,TextualHeadFactory -from virtex.utils.checkpointing import CheckpointManager -from virtex.utils.common import common_parser -from virtex.utils.metrics import TopkAccuracy -import virtex.utils.distributed as dist - - -#importing classifier -from virtex.models.zero_shot_classification_eval import ZeroShotClassifier - -ws.load() - -# fmt: off -parser = common_parser( - description="""Run image captioning inference on a pretrained model, and/or - evaluate pretrained model on COCO Captions val2017 split.""" -) -parser.add_argument( - "--data-root", default=None, - help="""Path to a directory containing image files to generate captions for imagenet. - Default: COCO val2017 image directory as expected relative to project root.""" -) -parser.add_argument( - "--checkpoint-path", required=False, - help="Path to load checkpoint and run captioning evaluation." -) -parser.add_argument( - "--output", default=None, - help="Path to save predictions as a JSON file." -) -parser.add_argument( - "--calc-metrics", action="store_true", - help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions. - This flag should not be set when running inference on arbitrary images.""" -) - -parser.add_argument( - "--idx_label_dict", default=None, required=False, - help="""a dictionary that maps from lable index to label string for classification""" -) -parser.add_argument( - "--is_redcaps", default=None, required=False, - help="""a dictionary that maps from lable index to label string for""" -) -parser.add_argument( - "--prompt_cls_sos", default=None, required=False, - help="""a dictionary that maps from lable index to label string for""" -) -parser.add_argument( - "--prompt_sos_eos", default=None, required=False, - help="""a dictionary that maps from lable index to label string for""" -) -# fmt: on - -print("###########") -print(os.getcwd() ) -print("###########") - -tokenizer = SentencePieceBPETokenizer("datasets_1/vocab/common_32k.model") - -def main(_A: argparse.Namespace): - if _A.num_gpus_per_machine == 0: - # Set device as CPU if num_gpus_per_machine = 0. - device = torch.device("cpu") - else: - # Get the current device (this will be zero here by default). 
-        device = torch.cuda.current_device()
-
-    _C = Config(_A.config, _A.config_override)
-
-    # tokenizer = TokenizerFactory.from_config(_C)
-
-    if _A.data_root is None:
-        _A.data_root = os.path.join(_C.DATA.ROOT, "val2017")
-
-    if _A.is_redcaps == 1:
-        model_dataset = 'redcaps'
-    else:
-        model_dataset = 'gcc or sbu'
-
-    print(_A.idx_label_dict)
-
-    val_dataset = ZeroShotDataset(data_root=_A.data_root,
-                                  split="test/",
-                                  label_map=_A.idx_label_dict,
-                                  tokenizer=tokenizer,
-                                  prompt_cls_sos=_A.prompt_cls_sos.replace("_", " "),
-                                  prompt_sos_eos=_A.prompt_sos_eos.replace("_", " "))
-
-    val_dataloader = DataLoader(
-        val_dataset,
-        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
-        num_workers=_A.cpu_workers,
-        sampler=DistributedSampler(
-            val_dataset,
-            num_replicas=dist.get_world_size(),
-            rank=dist.get_rank(),
-        ),
-        pin_memory=True,
-        drop_last=False,
-        collate_fn=val_dataset.collate_fn,
-    )
-
-    # Initialize model from a checkpoint.
-    visual = VisualBackboneFactory.from_config(_C)
-    textual = TextualHeadFactory.from_config(_C)
-    model = ZeroShotClassifier(visual, textual)
-    ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
-    model.to(device).eval()
-
-    # Set up distributed evaluation if using more than one process.
-    if dist.get_world_size() > 1:
-        dist.synchronize()
-        model = nn.parallel.DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
-
-    top_1 = TopkAccuracy(top_k=1)
-    top_5 = TopkAccuracy(top_k=5)
-    batch_num = 0
-
-
-    for val_iteration, val_batch in tqdm(enumerate(val_dataloader, start=1)):
-        val_batch["image"] = val_batch["image"].to(device)
-        val_batch["caption_tokens"] = val_batch["caption_tokens"].to(device)
-        val_batch["noitpac_tokens"] = val_batch["noitpac_tokens"].to(device)
-        val_batch["caption_lengths"] = val_batch["caption_lengths"].to(device)
-        val_batch["label"] = val_batch["label"].to(device)
-
-        with torch.no_grad():
-            classification_losses = model(val_batch)
-
-        batch_num += 1
-        top_1(classification_losses, val_batch["label"])
-        top_1_acc = top_1.get_metric(reset=False)
-        dist.average_across_processes(top_1_acc)
-
-        top_5(classification_losses, val_batch["label"])
-        top_5_acc = top_5.get_metric(reset=False)
-        dist.average_across_processes(top_5_acc)
-
-        logger.info(f"Iter: {val_iteration} | Top-1 accuracy: {top_1_acc} | Top-5 accuracy: {top_5_acc}")
-
-
-if __name__ == "__main__":
-    _A = parser.parse_args()
-    # if _A.num_gpus_per_machine > 1:
-    #     raise ValueError("Using multiple GPUs is not supported for this script.")
-
-    # No distributed training here, just a single process.
-    main(_A)
diff --git a/virtex/setup.py b/virtex/setup.py
deleted file mode 100644
index fc715695a0b1e6eb83a52205c9fec3224131bb21..0000000000000000000000000000000000000000
--- a/virtex/setup.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#!/usr/bin/env python
-import glob
-import os
-from setuptools import setup
-import shutil
-from typing import List
-
-
-def get_model_zoo_configs() -> List[str]:
-    """
-    Return a list of configs to include in package for model zoo. Copy over
-    these configs inside virtex/model_zoo.
-    """
-
-    # Use absolute paths while symlinking.
-    source_configs_dir = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs"
-    )
-    destination = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "virtex", "model_zoo", "configs"
-    )
-    # Symlink the config directory inside package to have a cleaner pip install.
-
-    # Remove stale symlink/directory from a previous build.
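-    # Unlink a leftover symlink or delete a leftover copied directory so the
-    # fresh symlink (or copy, on Windows) created below never collides.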
- if os.path.exists(source_configs_dir): - if os.path.islink(destination): - os.unlink(destination) - elif os.path.isdir(destination): - shutil.rmtree(destination) - - if not os.path.exists(destination): - try: - os.symlink(source_configs_dir, destination) - except OSError: - # Fall back to copying if symlink fails: ex. on Windows. - shutil.copytree(source_configs_dir, destination) - - config_paths = glob.glob("configs/**/*.yaml", recursive=True) - return config_paths - - -setup( - name="virtex", - version="1.1.0", - author="Karan Desai and Justin Johnson", - description="VirTex: Learning Visual Representations with Textual Annotations", - package_data={"virtex.model_zoo": get_model_zoo_configs()}, - python_requires=">=3.6", - license="MIT", - zip_safe=True, -) diff --git a/virtex/virtex/utils/assets/download_spice.sh b/virtex/utils/assets/download_spice.sh similarity index 100% rename from virtex/virtex/utils/assets/download_spice.sh rename to virtex/utils/assets/download_spice.sh diff --git a/virtex/virtex/utils/beam_search.py b/virtex/utils/beam_search.py similarity index 100% rename from virtex/virtex/utils/beam_search.py rename to virtex/utils/beam_search.py diff --git a/virtex/virtex/utils/checkpointing.py b/virtex/utils/checkpointing.py similarity index 100% rename from virtex/virtex/utils/checkpointing.py rename to virtex/utils/checkpointing.py diff --git a/virtex/virtex/utils/common.py b/virtex/utils/common.py similarity index 100% rename from virtex/virtex/utils/common.py rename to virtex/utils/common.py diff --git a/virtex/virtex/utils/distributed.py b/virtex/utils/distributed.py similarity index 100% rename from virtex/virtex/utils/distributed.py rename to virtex/utils/distributed.py diff --git a/virtex/virtex/utils/metrics.py b/virtex/utils/metrics.py similarity index 100% rename from virtex/virtex/utils/metrics.py rename to virtex/utils/metrics.py diff --git a/virtex/virtex/utils/nucleus_sampling.py b/virtex/utils/nucleus_sampling.py similarity index 100% rename from virtex/virtex/utils/nucleus_sampling.py rename to virtex/utils/nucleus_sampling.py diff --git a/virtex/virtex/utils/timer.py b/virtex/utils/timer.py similarity index 100% rename from virtex/virtex/utils/timer.py rename to virtex/utils/timer.py