diff --git a/app.py b/app.py
index dc042299f4e404ef8a0d146e49de3245072fc22a..e330aebf4955741cf4823806b8e34ccd9184aacf 100644
--- a/app.py
+++ b/app.py
@@ -1,18 +1,18 @@
-import streamlit as st
import io
-import sys
-import time
-import json
-sys.path.append("./virtex/")
+
+import streamlit as st
from model import *
# # TODO:
# - Reformat the model introduction
# - Make the iterative text generation
-def gen_show_caption(sub_prompt=None, cap_prompt = ""):
+
+def gen_show_caption(sub_prompt=None, cap_prompt=""):
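+    # Generate a caption for the currently loaded image, optionally conditioned on a
+    # subreddit and/or a partial caption prompt, then render the result as markdown.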
with st.spinner("Generating Caption"):
- subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
+ subreddit, caption = virtexModel.predict(
+ image_dict, sub_prompt=sub_prompt, prompt=cap_prompt
+ )
st.markdown(
f"""
### r/{subreddit} {cap_prompt} {caption}
- """,
- unsafe_allow_html=True)
-
-_, center, _ = st.columns([1,8,1])
+ """,
+ unsafe_allow_html=True,
+ )
+
+
+_, center, _ = st.columns([1, 8, 1])
with center:
st.title("Image Captioning Demo from RedCaps")
@@ -50,7 +52,7 @@ st.sidebar.markdown(
with st.spinner("Loading Model"):
virtexModel, imageLoader, sample_images, valid_subs = create_objects()
-
+
select_idx = None
@@ -66,9 +68,9 @@ uploaded_image = None
# with st.sidebar.form("file-uploader-form", clear_on_submit=True):
uploaded_file = st.sidebar.file_uploader("Choose a file")
# submitted = st.form_submit_button("Submit")
-if uploaded_file is not None:# and submitted:
+if uploaded_file is not None: # and submitted:
uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
- select_idx = None # set this to help rewrite the cache
+ select_idx = None # set this to help rewrite the cache
# class OnChange():
# def __init__(self, idx):
@@ -88,21 +90,26 @@ if uploaded_file is not None:# and submitted:
st.sidebar.title("Select a Subreddit")
sub = st.sidebar.selectbox(
"Type below to condition on a subreddit. Select None for a predicted subreddit",
- valid_subs
+ valid_subs,
)
st.sidebar.title("Write a Custom Prompt")
-cap_prompt = st.sidebar.text_input(
- "Write the start of your caption below",
- value=""
-)
+cap_prompt = st.sidebar.text_input("Write the start of your caption below", value="")
_ = st.sidebar.button("Regenerate Caption")
st.sidebar.write("Advanced Options:")
-num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
-nuc_size = st.sidebar.slider("Nucelus Size:\nLarger values lead to more diverse captions", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
+num_captions = st.sidebar.select_slider(
+ "Number of Captions to Predict", options=[1, 2, 3, 4, 5], value=1
+)
+nuc_size = st.sidebar.slider(
+ "Nucelus Size:\nLarger values lead to more diverse captions",
+ min_value=0.0,
+ max_value=1.0,
+ value=0.8,
+ step=0.05,
+)
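+# Nucleus (top-p) sampling size for the decoder: larger values keep more of the
+# probability mass at each step, which tends to produce more diverse captions.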
virtexModel.model.decoder.nucleus_size = nuc_size
image_file = sample_image
@@ -110,14 +117,14 @@ image_file = sample_image
# LOAD AND CACHE THE IMAGE
if uploaded_image is not None:
image = uploaded_image
-elif select_idx is None and 'image' in st.session_state:
- image = st.session_state['image']
+elif select_idx is None and "image" in st.session_state:
+ image = st.session_state["image"]
else:
image = Image.open(image_file)
image = image.convert("RGB")
-st.session_state['image'] = image
+st.session_state["image"] = image
image_dict = imageLoader.transform(image)
@@ -141,4 +148,4 @@ This demo accompanies our paper RedCaps.
Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson
"""
-)
\ No newline at end of file
+)
diff --git a/model.py b/model.py
index f071d7e96a31ac14fb8d84ad4d28da2eaab9c430..a615e1fb9a09443b2040bc3ab1a9e12c60e08f63 100644
--- a/model.py
+++ b/model.py
@@ -1,18 +1,17 @@
-import streamlit as st
-from huggingface_hub import hf_hub_url, cached_download
-from PIL import Image
import os
import json
import glob
import random
-from typing import Any, Dict, List
import torch
import torchvision
+import streamlit as st
import wordsegment as ws
+from PIL import Image
+from huggingface_hub import hf_hub_url, cached_download
from virtex.config import Config
-from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
+from virtex.factories import TokenizerFactory, PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
CONFIG_PATH = "config.yaml"
@@ -20,98 +19,108 @@ MODEL_PATH = "checkpoint_last5.pth"
VALID_SUBREDDITS_PATH = "subreddit_list.json"
SAMPLES_PATH = "./samples/*.jpg"
-class ImageLoader():
+
+class ImageLoader:
def __init__(self):
- self.image_transform = torchvision.transforms.Compose([
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Resize(256),
- torchvision.transforms.CenterCrop(224),
- torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))])
- self.show_size=500
-
+ self.image_transform = torchvision.transforms.Compose(
+ [
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Resize(256),
+ torchvision.transforms.CenterCrop(224),
+ torchvision.transforms.Normalize(
+ (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
+ ),
+ ]
+ )
+ self.show_size = 500
+
def load(self, im_path):
im = torch.FloatTensor(self.image_transform(Image.open(im_path))).unsqueeze(0)
return {"image": im}
-
+
def raw_load(self, im_path):
im = torch.FloatTensor(Image.open(im_path))
return {"image": im}
-
+
def transform(self, image):
im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0)
return {"image": im}
-
+
def text_transform(self, text):
# at present just lowercasing:
return text.lower()
-
+
def show_resize(self, image):
# ugh we need to do this manually cuz this is pytorch==0.8 not 1.9 lol
image = torchvision.transforms.functional.to_tensor(image)
- x,y = image.shape[-2:]
- ratio = float(self.show_size/max((x,y)))
- image = torchvision.transforms.functional.resize(image, [int(x * ratio), int(y * ratio)])
+ x, y = image.shape[-2:]
+ ratio = float(self.show_size / max((x, y)))
+ image = torchvision.transforms.functional.resize(
+ image, [int(x * ratio), int(y * ratio)]
+ )
return torchvision.transforms.functional.to_pil_image(image)
-
-class VirTexModel():
+
+class VirTexModel:
+
def __init__(self):
self.config = Config(CONFIG_PATH)
ws.load()
- self.device = 'cpu'
+ self.device = "cpu"
self.tokenizer = TokenizerFactory.from_config(self.config)
self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
CheckpointManager(model=self.model).load(MODEL_PATH)
self.model.eval()
self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
-
- def predict(self, image_dict, sub_prompt = None, prompt = ""):
+
+ def predict(self, image_dict, sub_prompt=None, prompt=""):
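+        # Build the decoding prompt: [SOS], then optional subreddit tokens followed
+        # by a [SEP] token, then an optional caption prefix.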
if sub_prompt is None:
- subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
+ subreddit_tokens = torch.tensor(
+ [self.model.sos_index], device=self.device
+ ).long()
else:
subreddit_tokens = " ".join(ws.segment(ws.clean(sub_prompt)))
subreddit_tokens = (
- [self.model.sos_index] +
- self.tokenizer.encode(subreddit_tokens) +
- [self.tokenizer.token_to_id("[SEP]")]
- )
+ [self.model.sos_index]
+ + self.tokenizer.encode(subreddit_tokens)
+ + [self.tokenizer.token_to_id("[SEP]")]
+ )
subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
-
+
if prompt is not "":
# at present prompts without subreddits will break without this change
# TODO FIX
cap_tokens = self.tokenizer.encode(prompt)
cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
- subreddit_tokens = subreddit_tokens if sub_prompt is not None else torch.tensor(
- (
- [self.model.sos_index] +
- self.tokenizer.encode("pics") +
- [self.tokenizer.token_to_id("[SEP]")]
- ), device = self.device).long()
-
- subreddit_tokens = torch.cat(
- [
- subreddit_tokens,
- cap_tokens
- ])
-
-
- predictions: List[Dict[str, Any]] = []
-
+ subreddit_tokens = (
+ subreddit_tokens
+ if sub_prompt is not None
+ else torch.tensor(
+ (
+ [self.model.sos_index]
+ + self.tokenizer.encode("pics")
+ + [self.tokenizer.token_to_id("[SEP]")]
+ ),
+ device=self.device,
+ ).long()
+ )
+
+ subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
+
is_valid_subreddit = False
subreddit, rest_of_caption = "", ""
image_dict["decode_prompt"] = subreddit_tokens
while not is_valid_subreddit:
-
+
with torch.no_grad():
caption = self.model(image_dict)["predictions"][0].tolist()
-
+
if self.tokenizer.token_to_id("[SEP]") in caption:
sep_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
caption[sep_index] = self.tokenizer.token_to_id("://")
-
+
caption = self.tokenizer.decode(caption)
-
+
if "://" in caption:
subreddit, rest_of_caption = caption.split("://")
subreddit = "".join(subreddit.split())
@@ -122,25 +131,29 @@ class VirTexModel():
# split prompt for coloring:
if prompt is not "":
_, rest_of_caption = caption.split(prompt.strip())
-
+
is_valid_subreddit = subreddit in self.valid_subs
-
+
return subreddit, rest_of_caption
+
def download_files():
- #download model files
+ # download model files
download_files = [CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH]
for f in download_files:
fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
os.system(f"cp {fp} ./{f}")
+
def get_samples():
return glob.glob(SAMPLES_PATH)
+
def get_rand_idx(samples):
- return random.randint(0,len(samples)-1)
+ return random.randint(0, len(samples) - 1)
+
-@st.cache(allow_output_mutation=True) # allow mutation to update nucleus size
+@st.cache(allow_output_mutation=True) # allow mutation to update nucleus size
def create_objects():
sample_images = get_samples()
virtexModel = VirTexModel()
@@ -149,7 +162,8 @@ def create_objects():
valid_subs.insert(0, None)
return virtexModel, imageLoader, sample_images, valid_subs
-footer="""
-
-    Model Config Name | VOC07 mAP | ImageNet Top-1 Acc. | Model URL
-    task_ablations/bicaptioning_R_50_L1_H2048.yaml | 88.7 | 53.8 | model
-    task_ablations/captioning_R_50_L1_H2048.yaml | 88.6 | 50.8 | model
-    task_ablations/token_classification_R_50.yaml | 88.8 | 48.6 | model
-    task_ablations/multilabel_classification_R_50.yaml | 86.2 | 46.2 | model
-    task_ablations/masked_lm_R_50_L1_H2048.yaml | 86.4 | 46.7 | model
-
-Width Ablations
-^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-    Model Config Name | VOC07 mAP | ImageNet Top-1 Acc. | Model URL
-    width_ablations/bicaptioning_R_50_L1_H512.yaml | 88.4 | 51.8 | model
-    width_ablations/bicaptioning_R_50_L1_H768.yaml | 88.3 | 52.3 | model
-    width_ablations/bicaptioning_R_50_L1_H1024.yaml | 88.3 | 53.2 | model
-    width_ablations/bicaptioning_R_50_L1_H2048.yaml | 88.7 | 53.8 | model
-
-Depth Ablations
-^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-    Model Config Name | VOC07 mAP | ImageNet Top-1 Acc. | Model URL
-    depth_ablations/bicaptioning_R_50_L1_H1024.yaml | 88.3 | 53.2 | model
-    depth_ablations/bicaptioning_R_50_L2_H1024.yaml | 88.8 | 53.8 | model
-    depth_ablations/bicaptioning_R_50_L3_H1024.yaml | 88.7 | 53.9 | model
-    depth_ablations/bicaptioning_R_50_L4_H1024.yaml | 88.7 | 53.9 | model
-
-Backbone Ablations
-^^^^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-    Model Config Name | VOC07 mAP | ImageNet Top-1 Acc. | Model URL
-    backbone_ablations/bicaptioning_R_50_L1_H1024.yaml | 88.3 | 53.2 | model
-    backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml | 88.5 | 52.9 | model
-    backbone_ablations/bicaptioning_R_101_L1_H1024.yaml | 88.7 | 52.1 | model
-
diff --git a/virtex/docs/virtex/usage/pretrain.rst b/virtex/docs/virtex/usage/pretrain.rst
deleted file mode 100644
index 2f14305f2152afdede708d45cbe5b2d165e9246a..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/usage/pretrain.rst
+++ /dev/null
@@ -1,100 +0,0 @@
-How to train your VirTex model?
-===============================
-
-We provide training scripts for all types of VirTex models from the paper,
-including our best-performing model and other ablations.
-Our training jobs are specified by config files (YAML).
-Execute all commands from project root to use the provided config files.
-
-
-Training the base VirTex model
-------------------------------
-
-Train the base VirTex model with a ResNet-50 visual backbone and a textual head
-with ``L = 1, H = 1024`` using all default optimization hyperparameters.
-
-.. code-block::
-
- python scripts/pretrain_virtex.py \
- --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \
- --num-gpus-per-machine 8 \
- --cpu-workers 4 \
- --serialization-dir /tmp/VIRTEX_R_50_L1_H1024
- # Default: --checkpoint-every 2000 --log-every 20
-
-The training job will save checkpoints, tensorboard logs (loss curves and metrics),
-and back up the config in ``--serialization-dir``. Use ``tensorboard --logdir
-<serialization_dir>`` to view training curves, validation metrics etc. directly
-on tensorboard.
-
-We recommend training with 8 GPUs on the same machine, although training with
-multiple GPUs across machines (see: ``--num-machines`` and ``--machine-rank``),
-single GPU (``--num-gpus-per-machine 1``) as well as CPU
-(``--num-gpus-per-machine 0``) is also supported. Using multiple GPUs for
-interactive debugging with PDB is not supported, as PDB and ``multiprocessing``
-module do not play nice.
-
--------------------------------------------------------------------------------
-
-Reproducing all VirTex ablations
---------------------------------
-
-To reproduce all ablations from the `paper <https://arxiv.org/abs/2006.06666>`_,
-replace the ``--config`` argument in above command with the following (all
-assumed to be relative to project root):
-
-Pretraining Task Ablations
-^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-1. **Bicaptioning:** configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml
-2. **Forward Captioning:** configs/task_ablations/captioning_R_50_L1_H2048.yaml
-3. **Token Classification:** configs/task_ablations/token_classification_R_50.yaml
-4. **Multilabel Classification:** configs/task_ablations/multilabel_classification_R_50.yaml
-5. **Masked Language Modeling:** configs/task_ablations/masked_lm_R_50_L1_H2048.yaml
-
-Transformer Size Ablations
-^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-1. **Width (H = 512):** configs/width_ablations/bicaptioning_R_50_L1_H512.yaml
-2. **Width (H = 768):** configs/width_ablations/bicaptioning_R_50_L1_H768.yaml
-3. **Width (H = 1024):** configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml
-4. **Width (H = 2048):** configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml
-5. **Depth (L = 1):** configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml
-6. **Depth (L = 2):** configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml
-7. **Depth (L = 3):** configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml
-8. **Depth (L = 4):** configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml
-
-Backbone Ablations
-^^^^^^^^^^^^^^^^^^
-
-1. **ResNet-50:** configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml
-2. **ResNet-50 w2x:** configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml
-3. **ResNet-101:** configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml
-
-.. note::
-
- **Pretraining Task Ablations** (1), **Transformer Size Ablations** (3 and 5)
- and **Backbone Ablations** (1) are all the same exact model.
-
-Data Efficiency Experiments
-^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-These are VirTex models trained on a subset of the COCO Captions dataset. For example,
-train a base VirTex model on a randomly selected ``50%`` of COCO Captions:
-
-.. code-block::
-
- python scripts/pretrain_virtex.py \
- --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \
- --config-override DATA.USE_PERCENTAGE 50.0 \
- --num-gpus-per-machine 8 \
- --cpu-workers 4 \
- --serialization-dir /tmp/VIRTEX_R_50_L1_H1024_PERCENT_50
- # Default: --checkpoint-every 2000 --log-every 20
-
-COCO Captions provides five captions per image. To train with one fixed caption
-per image, add ``DATA.USE_SINGLE_CAPTION True`` in ``--config-override``.
-
-The randomly selected subset is deterministic across runs based on random seed
-(``RANDOM_SEED`` in config). When training on less than ``50%`` dataset size, we
-recommend using multiple random seeds (results will have a variance of ``±1%``).
diff --git a/virtex/docs/virtex/usage/setup_dependencies.rst b/virtex/docs/virtex/usage/setup_dependencies.rst
deleted file mode 100644
index b4ece3964148f977154c367de2dfb84c57a86053..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/usage/setup_dependencies.rst
+++ /dev/null
@@ -1,153 +0,0 @@
-How to setup this codebase?
-===========================
-
-.. raw:: html
-
-
-
-This codebase requires Python 3.6 or higher. We recommend using Anaconda or
-Miniconda. We walk through installation and data preprocessing here.
-
-
-Install Dependencies
---------------------
-
-These steps install the dependencies through Anaconda (or Miniconda).
-
-1. Install Anaconda or Miniconda distribution based on Python 3+ from their
- `downloads site `_.
-
-
-2. Clone the repository first.
-
- .. code-block:: shell
-
- git clone https://www.github.com/kdexd/virtex
-
-
-3. Create a conda environment and install all the dependencies.
-
- .. code-block:: shell
-
- cd virtex
- conda create -n virtex python=3.6
- conda activate virtex
- pip install -r requirements.txt
-
-
-4. Install this codebase as a package in development version.
-
- .. code-block:: shell
-
- python setup.py develop
-
-Now you can ``import virtex`` from anywhere as long as you have this conda
-environment activated.
-
--------------------------------------------------------------------------------
-
-
-Setup Datasets
---------------
-
-Datasets are assumed to exist in ``./datasets`` directory (relative to the
-project root) following the structure specified below. COCO is used for
-pretraining, and rest of the datasets (including COCO) are used for downstream
-tasks. This structure is compatible when using
-`Detectron2 `_ for downstream
-tasks.
-
-COCO
-^^^^
-.. code-block::
-
- datasets/coco/
- annotations/
- captions_{train,val}2017.json
- instances_{train,val}2017.json
- train2017/
- # images in train2017 split
- val2017/
- # images in val2017 split
-
-LVIS
-^^^^
-.. code-block::
-
- datasets/coco/
- train2017/
- val2017/
- datasets/lvis/
- lvis_v1.0_{train,val}.json
-
-PASCAL VOC
-^^^^^^^^^^
-.. code-block::
-
- datasets/VOC2007/
- Annotations/
- ImageSets/
- Main/
- trainval.txt
- test.txt
- JPEGImages/
-
- datasets/VOC2012/
- # Same as VOC2007 above
-
-ImageNet
-^^^^^^^^
-.. code-block::
-
- datasets/imagenet/
- train/
- # One directory per category with images in it
- val/
- # One directory per category with images in it
- ILSVRC2012_devkit_t12.tar.gz
-
-iNaturalist 2018
-^^^^^^^^^^^^^^^^
-.. code-block::
-
- datasets/inaturalist/
- train_val2018/
- annotations/
- train2018.json
- val2018.json
-
--------------------------------------------------------------------------------
-
-
-Preprocess Data
----------------
-
-1. Build a vocabulary out of COCO Captions ``train2017`` split.
-
- .. code-block:: shell
-
- python scripts/preprocess/build_vocabulary.py \
- --captions datasets/coco/annotations/captions_train2017.json \
- --vocab-size 10000 \
- --output-prefix datasets/vocab/coco_10k \
- --do-lower-case
-
-
-2. Serialize COCO Captions (``train2017`` and ``val2017`` splits) into LMDB
- files. These are faster for data reading during pretraining.
-
- .. code-block:: shell
-
- python scripts/preprocess/preprocess_coco.py \
- --data-root datasets/coco \
- --split train \
- --output datasets/coco/serialized_train.lmdb
-
- .. code-block:: shell
-
- python scripts/preprocess/preprocess_coco.py \
- --data-root datasets/coco \
- --split val \
- --output datasets/coco/serialized_val.lmdb
-
-That's it! You are all set to use this codebase.
diff --git a/virtex/docs/virtex/utils.beam_search.rst b/virtex/docs/virtex/utils.beam_search.rst
deleted file mode 100644
index a04811e9c89a0c093e1ffb373467eb6ba9b81b87..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/utils.beam_search.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-virtex.utils.beam_search
-========================
-
-.. raw:: html
-
-
-
-.. automodule:: virtex.utils.beam_search
diff --git a/virtex/docs/virtex/utils.checkpointing.rst b/virtex/docs/virtex/utils.checkpointing.rst
deleted file mode 100644
index 1b3719bf7e330c13835dc57457a3bef238c29b0e..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/utils.checkpointing.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-virtex.utils.checkpointing
-==========================
-
-.. raw:: html
-
-
-
-.. automodule:: virtex.utils.checkpointing
diff --git a/virtex/docs/virtex/utils.common.rst b/virtex/docs/virtex/utils.common.rst
deleted file mode 100644
index cadd36d26a01f03b4457f1caed1c0c03dc58a9ef..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/utils.common.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-virtex.utils.common
-===================
-
-.. raw:: html
-
-
-
-.. automodule:: virtex.utils.common
diff --git a/virtex/docs/virtex/utils.distributed.rst b/virtex/docs/virtex/utils.distributed.rst
deleted file mode 100644
index e6a44d674ecb8a96d2568b1cd4072dd1e38f2a9d..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/utils.distributed.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-virtex.utils.distributed
-========================
-
-.. raw:: html
-
-
-
-.. automodule:: virtex.utils.distributed
diff --git a/virtex/docs/virtex/utils.metrics.rst b/virtex/docs/virtex/utils.metrics.rst
deleted file mode 100644
index 75234d5e4d230adf20192af77849b1a9c3f059d1..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/utils.metrics.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-virtex.utils.metrics
-====================
-
-.. raw:: html
-
-
-
-.. automodule:: virtex.utils.metrics
diff --git a/virtex/docs/virtex/utils.rst b/virtex/docs/virtex/utils.rst
deleted file mode 100644
index 9d021d9c4e1e255554130264d12abad06cc53911..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/utils.rst
+++ /dev/null
@@ -1,15 +0,0 @@
-virtex.utils
-============
-
-.. raw:: html
-
-
-
-.. toctree::
-
- utils.common
- utils.distributed
- utils.timer
- utils.checkpointing
- utils.beam_search
- utils.metrics
diff --git a/virtex/docs/virtex/utils.timer.rst b/virtex/docs/virtex/utils.timer.rst
deleted file mode 100644
index c2ddcdb3459f519d9a98766a6ddbef2adefa072d..0000000000000000000000000000000000000000
--- a/virtex/docs/virtex/utils.timer.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-virtex.utils.timer
-==================
-
-.. raw:: html
-
-
-
-.. automodule:: virtex.utils.timer
diff --git a/virtex/virtex/factories.py b/virtex/factories.py
similarity index 100%
rename from virtex/virtex/factories.py
rename to virtex/factories.py
diff --git a/virtex/hubconf.py b/virtex/hubconf.py
deleted file mode 100644
index f85d01d371151f0716680397b1c955d3c4dd42d7..0000000000000000000000000000000000000000
--- a/virtex/hubconf.py
+++ /dev/null
@@ -1,35 +0,0 @@
-dependencies = ["torch"]
-
-import torch
-import torchvision
-
-
-def resnet50(pretrained: bool = False, **kwargs):
- r"""
- ResNet-50 visual backbone from the best performing VirTex model: pretrained
- for bicaptioning on COCO Captions, with textual head ``L = 1, H = 2048``.
-
-    This is a torchvision-like model, with the last ``avgpool`` and ``fc``
- modules replaced with ``nn.Identity()`` modules. Given a batch of image
- tensors with size ``(B, 3, 224, 224)``, this model computes spatial image
- features of size ``(B, 7, 7, 2048)``, where B = batch size.
-
- pretrained (bool): Whether to load model with pretrained weights.
- """
-
- # Create a torchvision resnet50 with randomly initialized weights.
- model = torchvision.models.resnet50(pretrained=False, **kwargs)
-
- # Replace global average pooling and fully connected layers with identity
- # modules.
- model.avgpool = torch.nn.Identity()
- model.fc = torch.nn.Identity()
-
- if pretrained:
- model.load_state_dict(
- torch.hub.load_state_dict_from_url(
- "https://umich.box.com/shared/static/gsjqm4i4fm1wpzi947h27wweljd8gcpy.pth",
- progress=False,
- )["model"]
- )
- return model
diff --git a/virtex/virtex/model_zoo/__init__.py b/virtex/model_zoo/__init__.py
similarity index 100%
rename from virtex/virtex/model_zoo/__init__.py
rename to virtex/model_zoo/__init__.py
diff --git a/virtex/virtex/model_zoo/model_zoo.py b/virtex/model_zoo/model_zoo.py
similarity index 100%
rename from virtex/virtex/model_zoo/model_zoo.py
rename to virtex/model_zoo/model_zoo.py
diff --git a/virtex/virtex/models/__init__.py b/virtex/models/__init__.py
similarity index 100%
rename from virtex/virtex/models/__init__.py
rename to virtex/models/__init__.py
diff --git a/virtex/virtex/models/captioning.py b/virtex/models/captioning.py
similarity index 100%
rename from virtex/virtex/models/captioning.py
rename to virtex/models/captioning.py
diff --git a/virtex/virtex/models/classification.py b/virtex/models/classification.py
similarity index 100%
rename from virtex/virtex/models/classification.py
rename to virtex/models/classification.py
diff --git a/virtex/virtex/models/contrastive.py b/virtex/models/contrastive.py
similarity index 100%
rename from virtex/virtex/models/contrastive.py
rename to virtex/models/contrastive.py
diff --git a/virtex/virtex/models/masked_lm.py b/virtex/models/masked_lm.py
similarity index 100%
rename from virtex/virtex/models/masked_lm.py
rename to virtex/models/masked_lm.py
diff --git a/virtex/virtex/models/zero_shot_classification_eval.py b/virtex/models/zero_shot_classification_eval.py
similarity index 100%
rename from virtex/virtex/models/zero_shot_classification_eval.py
rename to virtex/models/zero_shot_classification_eval.py
diff --git a/virtex/virtex/modules/embedding.py b/virtex/modules/embedding.py
similarity index 100%
rename from virtex/virtex/modules/embedding.py
rename to virtex/modules/embedding.py
diff --git a/virtex/virtex/modules/label_smoothing.py b/virtex/modules/label_smoothing.py
similarity index 100%
rename from virtex/virtex/modules/label_smoothing.py
rename to virtex/modules/label_smoothing.py
diff --git a/virtex/virtex/modules/textual_heads.py b/virtex/modules/textual_heads.py
similarity index 100%
rename from virtex/virtex/modules/textual_heads.py
rename to virtex/modules/textual_heads.py
diff --git a/virtex/virtex/modules/transformer.py b/virtex/modules/transformer.py
similarity index 100%
rename from virtex/virtex/modules/transformer.py
rename to virtex/modules/transformer.py
diff --git a/virtex/virtex/modules/visual_backbones.py b/virtex/modules/visual_backbones.py
similarity index 100%
rename from virtex/virtex/modules/visual_backbones.py
rename to virtex/modules/visual_backbones.py
diff --git a/virtex/virtex/optim/__init__.py b/virtex/optim/__init__.py
similarity index 100%
rename from virtex/virtex/optim/__init__.py
rename to virtex/optim/__init__.py
diff --git a/virtex/virtex/optim/lookahead.py b/virtex/optim/lookahead.py
similarity index 100%
rename from virtex/virtex/optim/lookahead.py
rename to virtex/optim/lookahead.py
diff --git a/virtex/virtex/optim/lr_scheduler.py b/virtex/optim/lr_scheduler.py
similarity index 100%
rename from virtex/virtex/optim/lr_scheduler.py
rename to virtex/optim/lr_scheduler.py
diff --git a/virtex/scripts/clf_linear.py b/virtex/scripts/clf_linear.py
deleted file mode 100644
index 52ab5f22d974cf4e523f174aab09143d7d19b005..0000000000000000000000000000000000000000
--- a/virtex/scripts/clf_linear.py
+++ /dev/null
@@ -1,302 +0,0 @@
-import argparse
-import os
-
-from loguru import logger
-import torch
-from torch import nn
-from torch.cuda import amp
-from torch.utils.data import DataLoader, DistributedSampler
-from torch.utils.tensorboard import SummaryWriter
-
-from virtex.config import Config
-from virtex.factories import (
- DownstreamDatasetFactory,
- PretrainingModelFactory,
- OptimizerFactory,
- LRSchedulerFactory,
-)
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser, common_setup, cycle
-import virtex.utils.distributed as dist
-from virtex.utils.metrics import TopkAccuracy
-from virtex.utils.timer import Timer
-
-
-# fmt: off
-parser = common_parser(
- description="""Do image classification with linear models and frozen
- feature extractor, or fine-tune the feature extractor end-to-end."""
-)
-group = parser.add_argument_group("Downstream config arguments.")
-group.add_argument(
- "--down-config", metavar="FILE", help="Path to a downstream config file."
-)
-group.add_argument(
- "--down-config-override", nargs="*", default=[],
- help="A list of key-value pairs to modify downstream config params.",
-)
-
-parser.add_argument_group("Checkpointing and Logging")
-parser.add_argument(
- "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
- default="virtex", help="""How to initialize weights:
- 1. 'random' initializes all weights randomly
- 2. 'imagenet' initializes backbone weights from torchvision model zoo
- 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
- - with 'torchvision', state dict would be from PyTorch's training
- script.
- - with 'virtex' it should be for our full pretrained model."""
-)
-parser.add_argument(
- "--log-every", type=int, default=50,
- help="""Log training curves to tensorboard after every these many iterations
- only master process logs averaged loss values across processes.""",
-)
-parser.add_argument(
- "--checkpoint-path",
- help="""Path to load checkpoint and run downstream task evaluation. The
- name of checkpoint file is required to be `model_*.pth`, where * is
- iteration number from which the checkpoint was serialized."""
-)
-parser.add_argument(
- "--checkpoint-every", type=int, default=5000,
- help="""Serialize model to a checkpoint after every these many iterations.
- For ImageNet, (5005 iterations = 1 epoch); for iNaturalist (1710 iterations
- = 1 epoch).""",
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-
- if _A.num_gpus_per_machine == 0:
- # Set device as CPU if num_gpus_per_machine = 0.
- device = torch.device("cpu")
- else:
- # Get the current device as set for current distributed process.
- # Check `launch` function in `virtex.utils.distributed` module.
- device = torch.cuda.current_device()
-
- # Create a downstream config object (this will be immutable) and perform
- # common setup such as logging and setting up serialization directory.
- _DOWNC = Config(_A.down_config, _A.down_config_override)
- common_setup(_DOWNC, _A, job_type="downstream")
-
-    # Create a (pretraining) config object and backup in serialization directory.
- _C = Config(_A.config, _A.config_override)
- _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))
-
- # Get dataset name for tensorboard logging.
- DATASET = _DOWNC.DATA.ROOT.split("/")[-1]
-
- # Set number of output classes according to dataset:
- NUM_CLASSES_MAPPING = {"imagenet": 1000, "inaturalist": 8142}
- NUM_CLASSES = NUM_CLASSES_MAPPING[DATASET]
-
- # -------------------------------------------------------------------------
- # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
- # -------------------------------------------------------------------------
- train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="train")
- train_dataloader = DataLoader(
- train_dataset,
- batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),
- num_workers=_A.cpu_workers,
- sampler=DistributedSampler(
- train_dataset,
- num_replicas=dist.get_world_size(),
- rank=dist.get_rank(),
- shuffle=True,
- ),
- drop_last=False,
- pin_memory=True,
- collate_fn=train_dataset.collate_fn,
- )
- val_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="val")
- val_dataloader = DataLoader(
- val_dataset,
- batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),
- num_workers=_A.cpu_workers,
- sampler=DistributedSampler(
- val_dataset,
- num_replicas=dist.get_world_size(),
- rank=dist.get_rank(),
- shuffle=False,
- ),
- pin_memory=True,
- drop_last=False,
- collate_fn=val_dataset.collate_fn,
- )
- # Initialize model using pretraining config.
- pretrained_model = PretrainingModelFactory.from_config(_C)
-
- # Load weights according to the init method, do nothing for `random`, and
- # `imagenet` is already taken care of.
- if _A.weight_init == "virtex":
- CheckpointManager(model=pretrained_model).load(_A.checkpoint_path)
- elif _A.weight_init == "torchvision":
- # Keep strict=False because this state dict may have weights for
- # last fc layer.
- pretrained_model.visual.cnn.load_state_dict(
- torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
- strict=False,
- )
-
- # Pull out the CNN (torchvision-like) from our pretrained model and add
-    # back the FC layer - this exists in torchvision models, and is set to
- # `nn.Identity()` during pretraining.
- model = pretrained_model.visual.cnn # type: ignore
- model.fc = nn.Linear(_DOWNC.MODEL.VISUAL.FEATURE_SIZE, NUM_CLASSES).to(device)
- model = model.to(device)
-
- # Re-initialize the FC layer.
- torch.nn.init.normal_(model.fc.weight.data, mean=0.0, std=0.01)
- torch.nn.init.constant_(model.fc.bias.data, 0.0)
-
- # Freeze all layers except FC as per config param.
- if _DOWNC.MODEL.VISUAL.FROZEN:
- # Set model to eval mode to prevent BatchNorm from updating running
- # mean and std. With only a linear layer, being in eval mode when
- # training will not matter anyway.
- model.eval()
-
- for name, param in model.named_parameters():
- if "fc" not in name:
- param.requires_grad = False
-
- # Cross entropy loss and accuracy meter.
- criterion = nn.CrossEntropyLoss()
- top1 = TopkAccuracy(top_k=1)
-
- optimizer = OptimizerFactory.from_config(_DOWNC, model.named_parameters())
- scheduler = LRSchedulerFactory.from_config(_DOWNC, optimizer)
- del pretrained_model
-
- # -------------------------------------------------------------------------
- # BEFORE TRAINING STARTS
- # -------------------------------------------------------------------------
-
- # Create a gradient scaler for automatic mixed precision.
- scaler = amp.GradScaler(enabled=_DOWNC.AMP)
-
- # Create an iterator from dataloader to sample batches perpetually.
- train_dataloader_iter = cycle(train_dataloader, device)
-
- if dist.get_world_size() > 1:
- dist.synchronize()
- model = nn.parallel.DistributedDataParallel(
- model, device_ids=[device], find_unused_parameters=True
- )
-
- if dist.is_master_process():
- checkpoint_manager = CheckpointManager(
- _A.serialization_dir,
- model=model,
- optimizer=optimizer,
- scheduler=scheduler,
- )
- tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)
-
- # Keep track of time per iteration and ETA.
- timer = Timer(start_from=1, total_iterations=_DOWNC.OPTIM.NUM_ITERATIONS)
-
- # -------------------------------------------------------------------------
- # TRAINING LOOP
- # -------------------------------------------------------------------------
- for iteration in range(1, _DOWNC.OPTIM.NUM_ITERATIONS + 1):
- timer.tic()
- optimizer.zero_grad()
- batch = next(train_dataloader_iter)
-
- with amp.autocast(enabled=_DOWNC.AMP):
- logits = model(batch["image"])
- loss = criterion(logits, batch["label"])
-
- scaler.scale(loss).backward()
- scaler.step(optimizer)
- scaler.update()
-
- scheduler.step()
- timer.toc()
-
- if iteration % _A.log_every == 0 and dist.is_master_process():
- logger.info(
- f"{timer.stats} | Loss: {loss:.3f} | GPU: {dist.gpu_mem_usage()} MB"
- )
- tensorboard_writer.add_scalar(f"{DATASET}/train_loss", loss, iteration)
- tensorboard_writer.add_scalar(
- f"{DATASET}/learning_rate",
- optimizer.param_groups[0]["lr"],
- iteration,
- )
-
- # ---------------------------------------------------------------------
- # VALIDATION
- # ---------------------------------------------------------------------
- if iteration % _A.checkpoint_every == 0:
- torch.set_grad_enabled(False)
- model.eval()
-
- total_val_loss = torch.tensor(0.0).to(device)
-
- for val_iteration, batch in enumerate(val_dataloader, start=1):
- for key in batch:
- batch[key] = batch[key].to(device)
-
- logits = model(batch["image"])
- loss = criterion(logits, batch["label"])
- top1(logits, batch["label"])
- total_val_loss += loss
-
- # Divide each loss component by number of val batches per GPU.
- total_val_loss = total_val_loss / val_iteration
- dist.average_across_processes(total_val_loss)
-
- # Get accumulated Top-1 accuracy for logging across GPUs.
- acc = top1.get_metric(reset=True)
- dist.average_across_processes(acc)
-
- torch.set_grad_enabled(True)
-
- # Set model back to train mode only when fine-tuning end-to-end.
- if not _DOWNC.MODEL.VISUAL.FROZEN:
- model.train()
-
- # Save recent checkpoint and best checkpoint based on accuracy.
- if dist.is_master_process():
- checkpoint_manager.step(iteration)
-
- if iteration % _A.checkpoint_every == 0 and dist.is_master_process():
- logger.info(f"Iter: {iteration} | Top-1 accuracy: {acc})")
- tensorboard_writer.add_scalar(
- f"{DATASET}/val_loss", total_val_loss, iteration
- )
- # This name scoping will result in Tensorboard displaying all metrics
- # (VOC07, caption, etc.) together.
- tensorboard_writer.add_scalars(
- f"metrics/{DATASET}", {"top1": acc}, iteration
- )
-
- # All processes will wait till master process is done logging.
- dist.synchronize()
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
-
- # Add an arg in config override if `--weight-init` is imagenet.
- if _A.weight_init == "imagenet":
- _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])
-
- if _A.num_gpus_per_machine == 0:
- main(_A)
- else:
- # This will launch `main` and set appropriate CUDA device (GPU ID) as
- # per process (accessed in the beginning of `main`).
- dist.launch(
- main,
- num_machines=_A.num_machines,
- num_gpus_per_machine=_A.num_gpus_per_machine,
- machine_rank=_A.machine_rank,
- dist_url=_A.dist_url,
- args=(_A,),
- )
diff --git a/virtex/scripts/clf_voc07.py b/virtex/scripts/clf_voc07.py
deleted file mode 100644
index 0e382c1ac49a3c9c254ab9c97f14652ed664fbf6..0000000000000000000000000000000000000000
--- a/virtex/scripts/clf_voc07.py
+++ /dev/null
@@ -1,272 +0,0 @@
-import argparse
-import multiprocessing as mp
-import os
-from typing import Any, List
-
-from loguru import logger
-import numpy as np
-from sklearn.svm import LinearSVC
-from sklearn.metrics import average_precision_score
-from sklearn.model_selection import cross_val_score
-import torch
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-
-from virtex.config import Config
-from virtex.factories import PretrainingModelFactory, DownstreamDatasetFactory
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser, common_setup
-
-
-parser = common_parser(
- description="Train SVMs for VOC2007 classification on a pretrained model."
-)
-group = parser.add_argument_group("Downstream config arguments.")
-group.add_argument(
- "--down-config", metavar="FILE", help="Path to a downstream config file."
-)
-group.add_argument(
- "--down-config-override",
- nargs="*",
- default=[],
- help="A list of key-value pairs to modify downstream config params.",
-)
-
-# fmt: off
-parser.add_argument_group("Checkpointing")
-parser.add_argument(
- "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
- default="virtex", help="""How to initialize weights:
- 1. 'random' initializes all weights randomly
- 2. 'imagenet' initializes backbone weights from torchvision model zoo
- 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
- - with 'torchvision', state dict would be from PyTorch's training
- script.
- - with 'virtex' it should be for our full pretrained model."""
-)
-parser.add_argument(
- "--checkpoint-path",
- help="Path to load checkpoint and run downstream task evaluation."
-)
-# fmt: on
-
-
-def train_test_single_svm(args):
-
- feats_train, tgts_train, feats_test, tgts_test, cls_name = args
- SVM_COSTS = [0.01, 0.1, 1.0, 10.0]
-
- cls_labels = np.copy(tgts_train)
- # Meaning of labels in VOC/COCO original loaded target files:
- # label 0 = not present, set it to -1 as svm train target
- # label 1 = present. Make the svm train target labels as -1, 1.
- cls_labels[np.where(cls_labels == 0)] = -1
-
- # See which cost maximizes the AP for this class.
- best_crossval_ap: float = 0.0
- best_crossval_clf = None
- best_cost: float = 0.0
-
- # fmt: off
- for cost in SVM_COSTS:
- clf = LinearSVC(
- C=cost, class_weight={1: 2, -1: 1}, penalty="l2",
- loss="squared_hinge", max_iter=2000,
- )
- ap_scores = cross_val_score(
- clf, feats_train, cls_labels, cv=3, scoring="average_precision",
- )
- clf.fit(feats_train, cls_labels)
-
- # Keep track of best SVM (based on cost) for each class.
- if ap_scores.mean() > best_crossval_ap:
- best_crossval_ap = ap_scores.mean()
- best_crossval_clf = clf
- best_cost = cost
-
- logger.info(f"Best SVM {cls_name}: cost {best_cost}, mAP {best_crossval_ap * 100}")
- # fmt: on
-
- # -------------------------------------------------------------------------
- # TEST THE TRAINED SVM (PER CLASS)
- # -------------------------------------------------------------------------
- predictions = best_crossval_clf.decision_function(feats_test)
- evaluate_data_inds = tgts_test != -1
- eval_preds = predictions[evaluate_data_inds]
-
- cls_labels = np.copy(tgts_test)
- eval_cls_labels = cls_labels[evaluate_data_inds]
- eval_cls_labels[np.where(eval_cls_labels == 0)] = -1
-
- # Binarize class labels to make AP targets.
- targets = eval_cls_labels > 0
- return average_precision_score(targets, eval_preds)
-
-
-def main(_A: argparse.Namespace):
-
- if _A.num_gpus_per_machine == 0:
- # Set device as CPU if num_gpus_per_machine = 0.
- device = torch.device("cpu")
- else:
- # Get the current device (this will be zero here by default).
- device = torch.cuda.current_device()
-
- # Create a downstream config object (this will be immutable) and perform
- # common setup such as logging and setting up serialization directory.
- _DOWNC = Config(_A.down_config, _A.down_config_override)
- common_setup(_DOWNC, _A, job_type="downstream")
-
- # Create a (pretraining) config object and backup in serialization directory.
- _C = Config(_A.config, _A.config_override)
- _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))
-
- # -------------------------------------------------------------------------
- # INSTANTIATE DATALOADER, MODEL, AND FEATURE EXTRACTOR
- # -------------------------------------------------------------------------
-
- train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="trainval")
- train_dataloader = DataLoader(
- train_dataset,
- batch_size=_DOWNC.OPTIM.BATCH_SIZE,
- num_workers=_A.cpu_workers,
- pin_memory=True,
- )
- test_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="test")
- test_dataloader = DataLoader(
- test_dataset,
- batch_size=_DOWNC.OPTIM.BATCH_SIZE,
- num_workers=_A.cpu_workers,
- pin_memory=True,
- )
- NUM_CLASSES = len(train_dataset.class_names)
-
- # Initialize from a checkpoint, but only keep the visual module.
- model = PretrainingModelFactory.from_config(_C)
-
- # Load weights according to the init method, do nothing for `random`, and
- # `imagenet` is already taken care of.
- if _A.weight_init == "virtex":
- ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
- elif _A.weight_init == "torchvision":
- # Keep strict=False because this state dict may have weights for
- # last fc layer.
- model.visual.cnn.load_state_dict(
- torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
- strict=False,
- )
- # Set ``ITERATION`` to a dummy value.
- ITERATION = 0
-
- # Transfer model to GPU and set to eval mode. This is a torchvision model
- # and it returns features as ``(batch_size, 2048, 7, 7)``.
- model = model.visual.cnn.to(device).eval()
-
- # -------------------------------------------------------------------------
- # EXTRACT FEATURES FOR TRAINING SVMs
- # -------------------------------------------------------------------------
-
- features_train: List[torch.Tensor] = []
- targets_train: List[torch.Tensor] = []
-
- features_test: List[torch.Tensor] = []
- targets_test: List[torch.Tensor] = []
-
- # VOC07 is small, extract all features and keep them in memory.
- with torch.no_grad():
- for batch in tqdm(train_dataloader, desc="Extracting train features:"):
- features = model(batch["image"].to(device))
-
- # Global average pool features. Assume the tensor is in NCHW format.
- if len(features.size()) > 2:
- features = features.view(features.size(0), features.size(1), -1)
-
- # shape: (batch_size, visual_feature_size)
- features = features.mean(dim=-1)
-
- # shape: (batch_size, visual_feature_size)
- features = features.view(features.size(0), -1)
-
- # L2-normalize the global average pooled features.
- features = features / torch.norm(features, dim=-1).unsqueeze(-1)
-
- features_train.append(features.cpu())
- targets_train.append(batch["label"])
-
- # Similarly extract test features.
- for batch in tqdm(test_dataloader, desc="Extracting test features:"):
- features = model(batch["image"].to(device))
-
- if len(features.size()) > 2:
- features = features.view(features.size(0), features.size(1), -1)
- features = features.mean(dim=-1)
-
- features = features.view(features.size(0), -1)
- features = features / torch.norm(features, dim=-1).unsqueeze(-1)
-
- features_test.append(features.cpu())
- targets_test.append(batch["label"])
-
- # Convert batches of features/targets to one large numpy array
- features_train = torch.cat(features_train, dim=0).numpy()
- targets_train = torch.cat(targets_train, dim=0).numpy().astype(np.int32)
-
- features_test = torch.cat(features_test, dim=0).numpy()
- targets_test = torch.cat(targets_test, dim=0).numpy().astype(np.int32)
-
- # -------------------------------------------------------------------------
- # TRAIN AND TEST SVMs WITH EXTRACTED FEATURES
- # -------------------------------------------------------------------------
-
- input_args: List[Any] = []
-
- # Iterate over all VOC07 classes and train one-vs-all linear SVMs.
- for cls_idx in range(NUM_CLASSES):
- # fmt: off
- input_args.append((
- features_train, targets_train[:, cls_idx],
- features_test, targets_test[:, cls_idx],
- train_dataset.class_names[cls_idx],
- ))
- # fmt: on
-
- pool = mp.Pool(processes=_A.cpu_workers)
- pool_output = pool.map(train_test_single_svm, input_args)
-
- # -------------------------------------------------------------------------
- # TENSORBOARD LOGGING (RELEVANT MAINLY FOR weight_init=checkpoint)
- # -------------------------------------------------------------------------
-
- # Tensorboard writer for logging mAP scores. This is useful especially
-    # when weight_init=checkpoint (which may be coming from a training job).
- tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)
-
- # Test set mAP for each class, for features from every layer.
- test_map = torch.tensor(pool_output).mean()
- logger.info(f"Iteration: {ITERATION}, mAP: {test_map * 100}")
- tensorboard_writer.add_scalars(
- "metrics/voc07_clf", {f"voc07_mAP": test_map * 100}, ITERATION
- )
-
- # NOTE: for copy-pasting to spreadsheet.
- logger.info(
- f"{_C.DATA.ROOT.split('/')[1]},{_C.DATA.TOKENIZER_MODEL.split('/')[-1][:-6]},"
- f"{_C.DATA.VOCAB_SIZE},{_C.MODEL.NAME},{_C.MODEL.VISUAL.NAME},{_C.MODEL.TEXTUAL.NAME},"
- f"{_C.MODEL.LABEL_SMOOTHING},{_C.OPTIM.OPTIMIZER_NAME},{_C.OPTIM.BATCH_SIZE},"
- f"{_C.OPTIM.NUM_ITERATIONS},{_C.OPTIM.LR},{_C.OPTIM.WEIGHT_DECAY},"
- f"{ITERATION},{test_map * 100:.3f}"
- )
-
-if __name__ == "__main__":
- _A = parser.parse_args()
-
- if _A.num_gpus_per_machine > 1:
- raise ValueError("Using multiple GPUs is not supported for this script.")
-
- # Add an arg in config override if `--weight-init` is imagenet.
- if _A.weight_init == "imagenet":
- _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])
-
- # No distributed training here, just a single process.
- main(_A)
diff --git a/virtex/scripts/eval_captioning.py b/virtex/scripts/eval_captioning.py
deleted file mode 100644
index 8da98284f1726027536e38b72b4a82ba04bea396..0000000000000000000000000000000000000000
--- a/virtex/scripts/eval_captioning.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import argparse
-import json
-import os
-from typing import Any, Dict, List
-
-from loguru import logger
-import torch
-from torch.utils.data import DataLoader
-
-from virtex.config import Config
-from virtex.data import ImageDirectoryDataset
-from virtex.factories import TokenizerFactory, PretrainingModelFactory
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser
-from virtex.utils.metrics import CocoCaptionsEvaluator
-
-
-# fmt: off
-parser = common_parser(
- description="""Run image captioning inference on a pretrained model, and/or
- evaluate pretrained model on COCO Captions val2017 split."""
-)
-parser.add_argument(
- "--data-root", default=None,
- help="""Path to a directory containing image files to generate captions for.
- Default: COCO val2017 image directory as expected relative to project root."""
-)
-parser.add_argument(
- "--checkpoint-path", required=True,
- help="Path to load checkpoint and run captioning evaluation."
-)
-parser.add_argument(
- "--output", default=None,
- help="Path to save predictions as a JSON file."
-)
-parser.add_argument(
- "--calc-metrics", action="store_true",
- help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions.
- This flag should not be set when running inference on arbitrary images."""
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-
- if _A.num_gpus_per_machine == 0:
- # Set device as CPU if num_gpus_per_machine = 0.
- device = torch.device("cpu")
- else:
- # Get the current device (this will be zero here by default).
- device = torch.cuda.current_device()
-
- _C = Config(_A.config, _A.config_override)
-
- tokenizer = TokenizerFactory.from_config(_C)
-
- if _A.data_root is None:
- _A.data_root = os.path.join(_C.DATA.ROOT, "val2017")
-
- val_dataloader = DataLoader(
- ImageDirectoryDataset(_A.data_root),
- batch_size=_C.OPTIM.BATCH_SIZE,
- num_workers=_A.cpu_workers,
- pin_memory=True,
- )
- # Initialize model from a checkpoint.
- model = PretrainingModelFactory.from_config(_C).to(device)
- ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
- model.eval()
-
- # Make a list of predictions to evaluate.
- predictions: List[Dict[str, Any]] = []
-
- for val_iteration, val_batch in enumerate(val_dataloader, start=1):
-
- val_batch["image"] = val_batch["image"].to(device)
- with torch.no_grad():
- output_dict = model(val_batch)
-
- # Make a dictionary of predictions in COCO format.
- for image_id, caption in zip(
- val_batch["image_id"], output_dict["predictions"]
- ):
- predictions.append(
- {
- # Convert image id to int if possible (mainly for COCO eval).
- "image_id": int(image_id) if image_id.isdigit() else image_id,
- "caption": tokenizer.decode(caption.tolist()),
- }
- )
-
- # Save predictions as a JSON file if specified.
- if _A.output is not None:
- os.makedirs(os.path.dirname(_A.output), exist_ok=True)
- json.dump(predictions, open(_A.output, "w"))
- logger.info(f"Saved predictions to {_A.output}")
-
- # Calculate CIDEr and SPICE metrics using ground truth COCO Captions. This
- # should be skipped when running inference on arbitrary images.
- if _A.calc_metrics:
- # Assume ground truth (COCO val2017 annotations) exist.
- gt = os.path.join(_C.DATA.ROOT, "annotations", "captions_val2017.json")
-
- metrics = CocoCaptionsEvaluator(gt).evaluate(predictions)
- logger.info(f"Iter: {ITERATION} | Metrics: {metrics}")
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
- if _A.num_gpus_per_machine > 1:
- raise ValueError("Using multiple GPUs is not supported for this script.")
-
- # No distributed training here, just a single process.
- main(_A)
diff --git a/virtex/scripts/eval_detectron2.py b/virtex/scripts/eval_detectron2.py
deleted file mode 100644
index b79147080f8c56313e1a809b9f1a791ecd380e11..0000000000000000000000000000000000000000
--- a/virtex/scripts/eval_detectron2.py
+++ /dev/null
@@ -1,248 +0,0 @@
-"""
-Finetune a pre-trained model on a downstream task, one of those available in
-Detectron2.
-Supported downstream:
- - LVIS Instance Segmentation
- - COCO Instance Segmentation
- - Pascal VOC 2007+12 Object Detection
-
-Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py
-Thanks to the developers of Detectron2!
-"""
-import argparse
-import os
-import re
-from typing import Any, Dict, Union
-
-import torch
-from torch.utils.tensorboard import SummaryWriter
-
-import detectron2 as d2
-from detectron2.checkpoint import DetectionCheckpointer
-from detectron2.engine import DefaultTrainer, default_setup
-from detectron2.evaluation import (
- LVISEvaluator,
- PascalVOCDetectionEvaluator,
- COCOEvaluator,
-)
-from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
-
-from virtex.config import Config
-from virtex.factories import PretrainingModelFactory
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser
-import virtex.utils.distributed as dist
-
-# fmt: off
-parser = common_parser(
- description="Train object detectors from pretrained visual backbone."
-)
-parser.add_argument(
- "--d2-config", required=True,
- help="Path to a detectron2 config for downstream task finetuning."
-)
-parser.add_argument(
- "--d2-config-override", nargs="*", default=[],
- help="""Key-value pairs from Detectron2 config to override from file.
- Some keys will be ignored because they are set from other args:
- [DATALOADER.NUM_WORKERS, SOLVER.EVAL_PERIOD, SOLVER.CHECKPOINT_PERIOD,
- TEST.EVAL_PERIOD, OUTPUT_DIR]""",
-)
-
-parser.add_argument_group("Checkpointing and Logging")
-parser.add_argument(
- "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
- default="virtex", help="""How to initialize weights:
- 1. 'random' initializes all weights randomly
- 2. 'imagenet' initializes backbone weights from torchvision model zoo
- 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
- - with 'torchvision', state dict would be from PyTorch's training
- script.
- - with 'virtex' it should be for our full pretrained model."""
-)
-parser.add_argument(
- "--checkpoint-path",
- help="Path to load checkpoint and run downstream task evaluation."
-)
-parser.add_argument(
- "--resume", action="store_true", help="""Specify this flag when resuming
- training from a checkpoint saved by Detectron2."""
-)
-parser.add_argument(
- "--eval-only", action="store_true",
- help="Skip training and evaluate checkpoint provided at --checkpoint-path.",
-)
-parser.add_argument(
- "--checkpoint-every", type=int, default=5000,
- help="Serialize model to a checkpoint after every these many iterations.",
-)
-# fmt: on
-
-
-@ROI_HEADS_REGISTRY.register()
-class Res5ROIHeadsExtraNorm(Res5ROIHeads):
- r"""
- ROI head with ``res5`` stage followed by a BN layer. Used with Faster R-CNN
- C4/DC5 backbones for VOC detection.
- """
-
- def _build_res5_block(self, cfg):
- seq, out_channels = super()._build_res5_block(cfg)
- norm = d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, out_channels)
- seq.add_module("norm", norm)
- return seq, out_channels
-
-
-def build_detectron2_config(_C: Config, _A: argparse.Namespace):
- r"""Build detectron2 config based on our pre-training config and args."""
- _D2C = d2.config.get_cfg()
-
- # Override some default values based on our config file.
- _D2C.merge_from_file(_A.d2_config)
- _D2C.merge_from_list(_A.d2_config_override)
-
- # Set some config parameters from args.
- _D2C.DATALOADER.NUM_WORKERS = _A.cpu_workers
- _D2C.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every
- _D2C.OUTPUT_DIR = _A.serialization_dir
-
- # Set ResNet depth to override in Detectron2's config.
- _D2C.MODEL.RESNETS.DEPTH = int(
- re.search(r"resnet(\d+)", _C.MODEL.VISUAL.NAME).group(1)
- if "torchvision" in _C.MODEL.VISUAL.NAME
- else re.search(r"_R_(\d+)", _C.MODEL.VISUAL.NAME).group(1)
- if "detectron2" in _C.MODEL.VISUAL.NAME
- else 0
- )
- return _D2C
-
-
-class DownstreamTrainer(DefaultTrainer):
- r"""
- Extension of detectron2's ``DefaultTrainer``: custom evaluator and hooks.
-
- Parameters
- ----------
- cfg: detectron2.config.CfgNode
- Detectron2 config object containing all config params.
- weights: Union[str, Dict[str, Any]]
- Weights to load in the initialized model. If ``str``, then we assume path
- to a checkpoint, or if a ``dict``, we assume a state dict. This will be
- an ``str`` only if we resume training from a Detectron2 checkpoint.
- """
-
- def __init__(self, cfg, weights: Union[str, Dict[str, Any]]):
-
- super().__init__(cfg)
-
- # Load pre-trained weights before wrapping to DDP because `ApexDDP` has
- # some weird issue with `DetectionCheckpointer`.
- # fmt: off
- if isinstance(weights, str):
-            # A ``str`` is a checkpoint path: let Detectron2 resume from it.
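-            # ``resume_or_load`` restores model/optimizer/scheduler state; if the
-            # checkpoint has no "iteration" key, -1 + 1 starts training at 0.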
- self.start_iter = (
- DetectionCheckpointer(
- self._trainer.model,
- optimizer=self._trainer.optimizer,
- scheduler=self.scheduler
- ).resume_or_load(weights, resume=True).get("iteration", -1) + 1
- )
- elif isinstance(weights, dict):
-            # A ``dict`` is a backbone state dict from our pretraining init.
- DetectionCheckpointer(self._trainer.model)._load_model(weights)
- # fmt: on
-
- @classmethod
- def build_evaluator(cls, cfg, dataset_name, output_folder=None):
- if output_folder is None:
- output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
- evaluator_list = []
- evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type
- if evaluator_type == "pascal_voc":
- return PascalVOCDetectionEvaluator(dataset_name)
- elif evaluator_type == "coco":
- return COCOEvaluator(dataset_name, cfg, True, output_folder)
- elif evaluator_type == "lvis":
- return LVISEvaluator(dataset_name, cfg, True, output_folder)
-
- def test(self, cfg=None, model=None, evaluators=None):
- r"""Evaluate the model and log results to stdout and tensorboard."""
- cfg = cfg or self.cfg
- model = model or self.model
-
- tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
- results = super().test(cfg, model)
- flat_results = d2.evaluation.testing.flatten_results_dict(results)
- for k, v in flat_results.items():
- tensorboard_writer.add_scalar(k, v, self.start_iter)
-
-
-def main(_A: argparse.Namespace):
-
- # Get the current device as set for current distributed process.
- # Check `launch` function in `virtex.utils.distributed` module.
- device = torch.cuda.current_device()
-
- # Local process group is needed for detectron2.
- pg = list(range(dist.get_world_size()))
- d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(pg)
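-    # Note: this places all ranks in one "local" group, which matches
-    # single-machine training.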
-
- # Create a config object (this will be immutable) and perform common setup
- # such as logging and setting up serialization directory.
- if _A.weight_init == "imagenet":
- _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])
- _C = Config(_A.config, _A.config_override)
-
- # We use `default_setup` from detectron2 to do some common setup, such as
- # logging, setting up serialization etc. For more info, look into source.
- _D2C = build_detectron2_config(_C, _A)
- default_setup(_D2C, _A)
-
- # Prepare weights to pass in instantiation call of trainer.
- if _A.weight_init in {"virtex", "torchvision"}:
- if _A.resume:
- # If resuming training, let detectron2 load weights by providing path.
- model = None
- weights = _A.checkpoint_path
- else:
- # Load backbone weights from VirTex pretrained checkpoint.
- model = PretrainingModelFactory.from_config(_C)
- if _A.weight_init == "virtex":
- CheckpointManager(model=model).load(_A.checkpoint_path)
- else:
- model.visual.cnn.load_state_dict(
- torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
- strict=False,
- )
- weights = model.visual.detectron2_backbone_state_dict()
- else:
- # If random or imagenet init, just load weights after initializing model.
- model = PretrainingModelFactory.from_config(_C)
- weights = model.visual.detectron2_backbone_state_dict()
-
- # Back up pretrain config and model checkpoint (if provided).
- _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))
- if _A.weight_init == "virtex" and not _A.resume:
- torch.save(
- model.state_dict(),
- os.path.join(_A.serialization_dir, "pretrain_model.pth"),
- )
-
- del model
- trainer = DownstreamTrainer(_D2C, weights)
- trainer.test() if _A.eval_only else trainer.train()
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
-
- # This will launch `main` and set appropriate CUDA device (GPU ID) as
- # per process (accessed in the beginning of `main`).
- dist.launch(
- main,
- num_machines=_A.num_machines,
- num_gpus_per_machine=_A.num_gpus_per_machine,
- machine_rank=_A.machine_rank,
- dist_url=_A.dist_url,
- args=(_A, ),
- )
diff --git a/virtex/scripts/preprocess/build_redcaps_vocab.py b/virtex/scripts/preprocess/build_redcaps_vocab.py
deleted file mode 100644
index fd28f1b8d72fde036f631032a539c4fe16d169f2..0000000000000000000000000000000000000000
--- a/virtex/scripts/preprocess/build_redcaps_vocab.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import argparse
-import glob
-import json
-import os
-import re
-import tempfile
-from functools import lru_cache
-from typing import List
-
-import ftfy
-import sentencepiece as sp
-import wordsegment as ws
-from tqdm import tqdm
-
-
-ws.load()
-
-# fmt: off
-parser = argparse.ArgumentParser(
-    description="""Build a vocabulary from a captions corpus. The vocabulary is
-    saved as a file which our tokenizer can read.
- """
-)
-parser.add_argument(
-    "-f", "--files", nargs="+", default=["datasets/redcaps/annotations/*.json"],
-    help="Glob pattern(s) or path(s) to SBU, Conceptual, or RedCaps annotation files.",
-)
-parser.add_argument(
- "-s", "--vocab-size", type=int, default=32000,
- help="Total desired size of our vocabulary.",
-)
-parser.add_argument(
- "-o", "--output-prefix", default="datasets/vocab/redcaps_32k",
- help="Prefix of the files to be saved. Two files will be saved: "
- "[prefix].model and [prefix].vocab",
-)
-# fmt: on
-
-
-def read_captions_from_file(annotations_path: str) -> List[str]:
- r"""
- Given a path to annotation file, read it and return a list of captions.
-
- Parameters
- ----------
- annotations_path: str
- Path to an annotations file containing captions.
-
- Returns
- -------
- List[str]
- List of captions from this annotation file.
- """
-
- _annotations = json.load(open(annotations_path))
-
- captions: List[str] = []
- for ann in tqdm(_annotations["annotations"], desc=annotations_path):
-
- # This field only exists in RedCaps. Perform word segmentation on the
-        # subreddit name to add appropriate whitespace.
- if "subreddit" in ann:
- subreddit_seg = _segment_subreddit(ann["subreddit"].lower())
- caption = f"{subreddit_seg} {ann['caption']}"
- else:
- caption = ann["caption"]
-
- captions.append(caption.lower())
- return captions
-
-
-@lru_cache(maxsize=10)
-def _segment_subreddit(subreddit):
- return " ".join(ws.segment(ws.clean(subreddit)))
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
-
- all_filepaths: List[str] = []
- for f in _A.files:
- all_filepaths.extend(glob.glob(f))
-
- captions: List[str] = []
- for path in tqdm(all_filepaths, desc="Reading captions"):
- captions.extend(read_captions_from_file(path))
-
- # Create a temporary directory and dump the captions corpus as a text file
- # with one caption per line. That's how sentencepiece wants its input.
- tmpdir_path = tempfile.mkdtemp()
-
- with open(os.path.join(tmpdir_path, "captions.txt"), "w") as captions_file:
- for caption in captions:
- captions_file.write(caption + "\n")
-
-    # Padding/out-of-vocab token will be "<unk>" with ID 0 by default.
-    # Add [SOS], [EOS] and [SEP] tokens. [SEP] will not be used during
-    # captioning, but it is useful for reusing the vocabulary across pretext tasks.
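-    # `--bos_id=-1 --eos_id=-1` disables sentencepiece's built-in BOS/EOS pieces;
-    # the [SOS]/[EOS] control symbols declared below are used instead.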
- sp.SentencePieceTrainer.train(
- f" --input={os.path.join(tmpdir_path, 'captions.txt')}"
- f" --vocab_size={_A.vocab_size}"
- f" --model_prefix={_A.output_prefix}"
- " --model_type=bpe --character_coverage=1.0"
- " --bos_id=-1 --eos_id=-1"
- " --control_symbols=[SOS],[EOS],[SEP]"
- " --user_defined_symbols="
- )
diff --git a/virtex/scripts/preprocess/build_vocabulary.py b/virtex/scripts/preprocess/build_vocabulary.py
deleted file mode 100644
index bc7a592b40d8044919279dc8116ca03dce20b5d1..0000000000000000000000000000000000000000
--- a/virtex/scripts/preprocess/build_vocabulary.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import argparse
-import json
-import os
-import tempfile
-import unicodedata
-from typing import List
-
-import sentencepiece as sp
-
-
-# fmt: off
-parser = argparse.ArgumentParser(
-    description="""Build a vocabulary from a captions corpus. The vocabulary is
-    saved as a file which our tokenizer can read.
- """
-)
-parser.add_argument(
- "-c", "--captions", default="datasets/coco/annotations/captions_train2017.json",
- help="Path to caption annotations file in COCO format.",
-)
-parser.add_argument(
- "-s", "--vocab-size", type=int, default=10000,
- help="Total desired size of our vocabulary.",
-)
-parser.add_argument(
- "-o", "--output-prefix", default="datasets/vocab/coco_10k",
- help="Prefix of the files to be saved. Two files will be saved: "
- "[prefix].model and [prefix].vocab",
-)
-parser.add_argument(
- "-l", "--do-lower-case", action="store_true",
-    help="Whether to lowercase the captions before building the vocabulary.",
-)
-parser.add_argument(
- "-a", "--keep-accents", action="store_true",
-    help="Whether to keep accents when building the vocabulary (dropped by default).",
-)
-# fmt: on
-
-
-def _read_captions(annotations_path: str) -> List[str]:
- r"""
- Given a path to annotation file, read it and return a list of captions.
- These are not processed by any means, returned from the file as-is.
-
- Parameters
- ----------
- annotations_path: str
- Path to an annotations file containing captions.
-
- Returns
- -------
- List[str]
- List of captions from this annotation file.
- """
-
- _annotations = json.load(open(annotations_path))
-
- captions: List[str] = []
- for ann in _annotations["annotations"]:
- captions.append(ann["caption"])
-
- return captions
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
- captions: List[str] = _read_captions(_A.captions)
-
- # Lower case the captions and remove accents according to arguments.
- for i, caption in enumerate(captions):
- caption = caption.lower() if _A.do_lower_case else caption
-
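-        # NFKD decomposes accented characters into a base character plus
-        # combining marks; dropping the combining marks strips the accents.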
- if not _A.keep_accents:
- caption = unicodedata.normalize("NFKD", caption)
- caption = "".join(
-                [char for char in caption if not unicodedata.combining(char)]
- )
-
- captions[i] = caption
-
- # Create a temporary directory and dump the captions corpus as a text file
- # with one caption per line. That's how sentencepiece wants its input.
- tmpdir_path = tempfile.mkdtemp()
-
- with open(os.path.join(tmpdir_path, "captions.txt"), "w") as captions_file:
- for caption in captions:
- captions_file.write(caption + "\n")
-
-    # Padding/out-of-vocab token will be "<unk>" with ID 0 by default.
-    # Add [SOS], [EOS] and [MASK] tokens. [MASK] will not be used during
-    # captioning, but it is useful for reusing the vocabulary across pretext tasks.
- sp.SentencePieceTrainer.train(
- f" --input={os.path.join(tmpdir_path, 'captions.txt')}"
- f" --vocab_size={_A.vocab_size}"
- f" --model_prefix={_A.output_prefix}"
- " --model_type=bpe --character_coverage=1.0"
- " --bos_id=-1 --eos_id=-1"
- " --control_symbols=[SOS],[EOS],[MASK]"
- )
diff --git a/virtex/scripts/preprocess/preprocess_coco.py b/virtex/scripts/preprocess/preprocess_coco.py
deleted file mode 100644
index abf768d94d494a2e8397b596ca7993a638d2d840..0000000000000000000000000000000000000000
--- a/virtex/scripts/preprocess/preprocess_coco.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import argparse
-import os
-import pickle
-import platform
-from typing import Any, List
-
-import albumentations as alb
-import lmdb
-from tqdm import tqdm
-from torch.utils.data import DataLoader
-
-from virtex.data.readers import SimpleCocoCaptionsReader
-
-
-# fmt: off
-parser = argparse.ArgumentParser("Serialize a COCO Captions split to LMDB.")
-parser.add_argument(
- "-d", "--data-root", default="datasets/coco",
- help="Path to the root directory of COCO dataset.",
-)
-parser.add_argument(
- "-s", "--split", choices=["train", "val"],
- help="Which split to process, either `train` or `val`.",
-)
-parser.add_argument(
- "-b", "--batch-size", type=int, default=16,
-    help="Batch size for processing and serializing data. Set according to available CPU memory.",
-)
-parser.add_argument(
- "-j", "--cpu-workers", type=int, default=4,
- help="Number of CPU workers for data loading.",
-)
-parser.add_argument(
- "-e", "--short-edge-size", type=int, default=None,
- help="""Resize shorter edge to this size (keeping aspect ratio constant)
-    before serializing. Useful for saving disk space and for faster reads.
- If None, no images are resized."""
-)
-parser.add_argument(
- "-o", "--output", default="datasets/serialized/coco_train2017.lmdb",
- help="Path to store the file containing serialized dataset.",
-)
-# fmt: on
-
-
-def collate_fn(instances: List[Any]):
- r"""Collate function for data loader to return list of instances as-is."""
- return instances
-
-
-if __name__ == "__main__":
-
- _A = parser.parse_args()
- os.makedirs(os.path.dirname(_A.output), exist_ok=True)
-
- dloader = DataLoader(
- SimpleCocoCaptionsReader(_A.data_root, _A.split),
- batch_size=_A.batch_size,
- num_workers=_A.cpu_workers,
- shuffle=False,
- drop_last=False,
- collate_fn=collate_fn
- )
- # Open an LMDB database.
- # Set a sufficiently large map size for LMDB (based on platform).
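-    # LMDB reserves `map_size` bytes of address space; on Linux the file grows
-    # lazily so a ~2 TiB cap is safe, while other platforms may allocate upfront.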
- map_size = 1099511627776 * 2 if platform.system() == "Linux" else 1280000
- db = lmdb.open(
- _A.output, map_size=map_size, subdir=False, meminit=False, map_async=True
- )
-
- # Transform to resize shortest edge and keep aspect ratio same.
- if _A.short_edge_size is not None:
- resize = alb.SmallestMaxSize(max_size=_A.short_edge_size, always_apply=True)
-
- # Serialize each instance (as a dictionary). Use `pickle.dumps`. Key will
- # be an integer (cast as string) starting from `0`.
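-    # `pickle.dumps(..., protocol=-1)` uses the highest available pickle protocol.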
- INSTANCE_COUNTER: int = 0
-
- for idx, batch in enumerate(tqdm(dloader)):
-
- txn = db.begin(write=True)
-
- for instance in batch:
- image = instance["image"]
-            height, width, channels = image.shape
-
- # Resize image from instance and convert instance to tuple.
- if _A.short_edge_size is not None and min(width, height) > _A.short_edge_size:
- image = resize(image=image)["image"]
-
- instance = (instance["image_id"], instance["image"], instance["captions"])
- txn.put(
- f"{INSTANCE_COUNTER}".encode("ascii"),
- pickle.dumps(instance, protocol=-1)
- )
- INSTANCE_COUNTER += 1
-
- txn.commit()
-
- db.sync()
- db.close()
diff --git a/virtex/scripts/preprocess/preprocess_redcaps.py b/virtex/scripts/preprocess/preprocess_redcaps.py
deleted file mode 100644
index abb4e9e71e272797e2caf3eed304bc4cdc98f85e..0000000000000000000000000000000000000000
--- a/virtex/scripts/preprocess/preprocess_redcaps.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import argparse
-import json
-import os
-import tarfile
-import tempfile
-from typing import Dict, List
-
-from loguru import logger
-from tqdm import tqdm
-
-
-# fmt: off
-parser = argparse.ArgumentParser(
- description="""Pre-process RedCaps dataset for training VirTex models - make
- small shards of TAR files containing images and captions."""
-)
-parser.add_argument(
- "-a", "--annotations", required=True, help="Path to a RedCaps annotation file."
-)
-parser.add_argument(
- "-i", "--images", default="datasets/redcaps/images",
- help="""Path to RedCaps image directory. This directory is expected to have
-    subreddit-specific sub-directories containing images.""",
-)
-parser.add_argument(
- "-z", "--shard-size", type=int, default=1000,
- help="Maximum number of RedCaps instances in a single TAR file shard.",
-)
-parser.add_argument(
- "-o", "--output-prefix", required=True,
- help="Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` "
- "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ...",
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
- r"""
- Make TAR files containing images and annotations from a single RedCaps
- annotations file. These TAR files are arranged in a way that
-    WebDataset can understand.
- """
-
- ANNOTATIONS: List[Dict] = json.load(open(_A.annotations))["annotations"]
-
- # Keep track of the current index of TAR file shard and dataset index.
- SHARD_INDEX: int = 0
- DATASET_INDEX: int = 0
-
- # Create TAR file handle for the initial shard.
-    tar_handle = tarfile.open(f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w")
-
- # Keep a count of submissions that were skipped because their image was
- # not downloaded (not present in image dir).
- SKIPPED: int = 0
-
- for ann in tqdm(ANNOTATIONS):
-
- image_path = os.path.join(
- _A.images, ann["subreddit"], f"{ann['image_id']}.jpg"
- )
- # Add current image in shard if it exists.
- if os.path.exists(image_path):
-
- tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg")
-
- # Save subreddit name and caption as a JSON file.
- subreddit_and_caption = {
- "subreddit": ann["subreddit"], "caption": ann["caption"]
- }
- tmpfile = tempfile.NamedTemporaryFile("w+")
- tmpfile.write(json.dumps(subreddit_and_caption))
- tmpfile.seek(0)
- tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json")
- tmpfile.close()
-
- DATASET_INDEX += 1
-
- # Create new shard if current shard is full.
- if DATASET_INDEX % _A.shard_size == 0 and DATASET_INDEX > 0:
- tar_handle.close()
- logger.success(
- f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar"
- )
- SHARD_INDEX += 1
-
- # Open new TAR file shard.
- tar_handle = tarfile.open(
- f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w"
- )
- else:
- SKIPPED += 1
-
- # Close the file handle to properly save it.
- tar_handle.close()
- logger.success(f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar\n")
- logger.info(f"Skipped {SKIPPED} annotations due to missing images.")
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
- main(_A)
diff --git a/virtex/scripts/pretrain_virtex.py b/virtex/scripts/pretrain_virtex.py
deleted file mode 100644
index 73e36ed3428c6899876fca9961dbeb81dcb2bd0c..0000000000000000000000000000000000000000
--- a/virtex/scripts/pretrain_virtex.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import argparse
-from collections import Counter
-from typing import Any
-
-from loguru import logger
-import torch
-from torch import nn
-from torch.cuda import amp
-from torch.utils.data import DataLoader, DistributedSampler
-from torch.utils.tensorboard import SummaryWriter
-
-# fmt: off
-from virtex.config import Config
-from virtex.factories import (
- PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory,
- LRSchedulerFactory,
-)
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser, common_setup, cycle
-import virtex.utils.distributed as dist
-from virtex.utils.timer import Timer
-
-
-parser = common_parser(
- description="Train a VirTex model (CNN + Transformer) on COCO Captions."
-)
-group = parser.add_argument_group("Checkpointing and Logging")
-group.add_argument(
- "--resume-from", default=None,
- help="Path to a checkpoint to resume training from (if provided)."
-)
-group.add_argument(
- "--checkpoint-every", type=int, default=2000,
-    help="Serialize the model to a checkpoint every this many iterations.",
-)
-group.add_argument(
- "--log-every", type=int, default=20,
-    help="""Log training curves to tensorboard every this many iterations.
-    Only the master process logs loss values averaged across processes.""",
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-
- if _A.num_gpus_per_machine == 0:
- # Set device as CPU if num_gpus_per_machine = 0.
- device: Any = torch.device("cpu")
- else:
- # Get the current device as set for current distributed process.
- # Check `launch` function in `virtex.utils.distributed` module.
- device = torch.cuda.current_device()
-
- # Create a config object (this will be immutable) and perform common setup
- # such as logging and setting up serialization directory.
- _C = Config(_A.config, _A.config_override)
- common_setup(_C, _A)
-
- # -------------------------------------------------------------------------
- # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
- # -------------------------------------------------------------------------
- train_dataset = PretrainingDatasetFactory.from_config(_C, split="train")
- val_dataset = PretrainingDatasetFactory.from_config(_C, split="val")
-
- # Make `DistributedSampler`s to shard datasets across GPU processes.
- # Skip this if training on CPUs.
- train_sampler = (
- DistributedSampler(train_dataset, shuffle=True) # type: ignore
- if _A.num_gpus_per_machine > 0
- else None
- )
- val_sampler = (
- DistributedSampler(val_dataset, shuffle=False) # type: ignore
- if _A.num_gpus_per_machine > 0
- else None
- )
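-    # _C.OPTIM.BATCH_SIZE is the total (global) batch size; each process loads
-    # an equal share of it.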
- train_dataloader = DataLoader(
- train_dataset,
- batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
- sampler=train_sampler,
- shuffle=train_sampler is None,
- num_workers=_A.cpu_workers,
- pin_memory=True,
- drop_last=True,
- collate_fn=train_dataset.collate_fn,
- )
- val_dataloader = DataLoader(
- val_dataset,
- batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
- sampler=val_sampler,
- shuffle=False,
- num_workers=_A.cpu_workers,
- pin_memory=True,
- drop_last=False,
- collate_fn=val_dataset.collate_fn,
- )
-
- model = PretrainingModelFactory.from_config(_C).to(device)
- optimizer = OptimizerFactory.from_config(_C, model.named_parameters())
- scheduler = LRSchedulerFactory.from_config(_C, optimizer)
-
- # -------------------------------------------------------------------------
- # BEFORE TRAINING STARTS
- # -------------------------------------------------------------------------
-
- # Create a gradient scaler for automatic mixed precision.
- scaler = amp.GradScaler(enabled=_C.AMP)
-
- # Load checkpoint to resume training if specified.
- if _A.resume_from is not None:
- start_iteration = CheckpointManager(
- model=model, optimizer=optimizer, scheduler=scheduler, scaler=scaler,
- ).load(_A.resume_from)
- else:
- start_iteration = 0
-
- # Create an iterator from dataloader to sample batches perpetually.
- train_dataloader_iter = cycle(train_dataloader, device, start_iteration)
-
-    # Wrap model in DDP if using more than one process.
- if dist.get_world_size() > 1:
- dist.synchronize()
- model = nn.parallel.DistributedDataParallel(
- model, device_ids=[device], find_unused_parameters=True
- )
-
- # Keep track of time per iteration and ETA.
- timer = Timer(
- start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS
- )
- # Create tensorboard writer and checkpoint manager (only in master process).
- if dist.is_master_process():
- tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)
- tensorboard_writer.add_text("config", f"```\n{_C}\n```")
-
- checkpoint_manager = CheckpointManager(
- _A.serialization_dir,
- model=model,
- optimizer=optimizer,
- scheduler=scheduler,
- scaler=scaler,
- )
-
- # -------------------------------------------------------------------------
- # TRAINING LOOP
- # -------------------------------------------------------------------------
- for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1):
- timer.tic()
- optimizer.zero_grad()
- batch = next(train_dataloader_iter)
-
- with amp.autocast(enabled=_C.AMP):
- output_dict = model(batch)
- loss = output_dict["loss"]
-
- scaler.scale(loss).backward()
-
- # First clip norm of gradients, and then perform optimizer step.
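-        # Gradients currently carry the AMP loss scale; unscale them so the
-        # clipping threshold applies to their true norm.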
- scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM)
- scaler.step(optimizer)
-
- scaler.update()
- scheduler.step()
- timer.toc()
-
- # ---------------------------------------------------------------------
- # LOGGING
- # ---------------------------------------------------------------------
- if iteration % _A.log_every == 0:
- logger.info(
- f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]"
- )
- if dist.is_master_process():
- tensorboard_writer.add_scalars(
- "learning_rate",
- {
- "visual": optimizer.param_groups[0]["lr"],
- "common": optimizer.param_groups[-1]["lr"],
- },
- iteration,
- )
- tensorboard_writer.add_scalars(
- "train", output_dict["loss_components"], iteration
- )
-
- # ---------------------------------------------------------------------
- # VALIDATION
- # ---------------------------------------------------------------------
- if iteration % _A.checkpoint_every == 0:
- if dist.is_master_process():
- checkpoint_manager.step(iteration)
-
- # All processes will wait till master process is done serializing.
- dist.synchronize()
-
- torch.set_grad_enabled(False)
- model.eval()
-
- # Accumulate different val loss components according to the type of
- # pretraining model.
- val_loss_counter: Counter = Counter()
-
- for val_iteration, val_batch in enumerate(val_dataloader, start=1):
- for key in val_batch:
- val_batch[key] = val_batch[key].to(device)
- output_dict = model(val_batch)
-
- val_loss_counter.update(output_dict["loss_components"])
-
- # Divide each loss component by number of val batches per GPU.
- val_loss_dict = {
- k: v / val_iteration for k, v in dict(val_loss_counter).items()
- }
- dist.average_across_processes(val_loss_dict)
- torch.set_grad_enabled(True)
- model.train()
-
- logger.info(f"Iteration: {iteration} [Val loss: {val_loss_dict}]")
- if dist.is_master_process():
- tensorboard_writer.add_scalars("val", val_loss_dict, iteration)
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
-
- if _A.num_gpus_per_machine == 0:
- main(_A)
- else:
- # This will launch `main` and set appropriate CUDA device (GPU ID) as
- # per process (accessed in the beginning of `main`).
- dist.launch(
- main,
- num_machines=_A.num_machines,
- num_gpus_per_machine=_A.num_gpus_per_machine,
- machine_rank=_A.machine_rank,
- dist_url=_A.dist_url,
- args=(_A, ),
- )
diff --git a/virtex/scripts/redcaps_caption_decode.py b/virtex/scripts/redcaps_caption_decode.py
deleted file mode 100644
index d63b69ac6a13dc235a3ba4980dba582a9cd75be6..0000000000000000000000000000000000000000
--- a/virtex/scripts/redcaps_caption_decode.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import argparse
-import json
-import os
-from typing import Any, Dict, List
-
-from loguru import logger
-import torch
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-import wordsegment as ws
-
-from virtex.config import Config
-from virtex.data import ImageDirectoryDataset
-from virtex.factories import TokenizerFactory, PretrainingModelFactory
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser
-
-ws.load()
-
-# fmt: off
-parser = common_parser(
- description="Decode captions using a RedCaps-pretrained VirTex model."
-)
-parser.add_argument(
- "--images", required=True,
- help="Path to a directory containing image files to generate captions for."
-)
-parser.add_argument(
- "--checkpoint-path", required=True,
- help="Path to load checkpoint and run captioning evaluation."
-)
-parser.add_argument(
- "--output", required=True,
- help="Path to save predictions as a JSON file."
-)
-parser.add_argument(
- "--subreddit-prompt", default=None,
- help="Optional subreddit prompt for controllable subreddit-style captioning."
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-
- if _A.num_gpus_per_machine == 0:
- # Set device as CPU if num_gpus_per_machine = 0.
- device = torch.device("cpu")
- else:
- # Get the current device (this will be zero here by default).
- device = torch.cuda.current_device()
-
- _C = Config(_A.config, _A.config_override)
-
- tokenizer = TokenizerFactory.from_config(_C)
-
- val_dataloader = DataLoader(
- ImageDirectoryDataset(_A.images),
- batch_size=_C.OPTIM.BATCH_SIZE,
- num_workers=_A.cpu_workers,
- pin_memory=True,
- )
- # Initialize model from a checkpoint.
- model = PretrainingModelFactory.from_config(_C).to(device)
- CheckpointManager(model=model).load(_A.checkpoint_path)
- model.eval()
-
- # Prepare subreddit prompt for the model if provided.
- if _A.subreddit_prompt is not None:
-
- # Remove "r/" if provided.
- _A.subreddit_prompt = _A.subreddit_prompt.replace("r/", "")
-
- # Word segmenting (e.g. "itookapicture" -> "i took a picture").
- _segments = " ".join(ws.segment(ws.clean(_A.subreddit_prompt)))
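-        # Prompt layout is "[SOS] <subreddit tokens> [SEP]"; the model then
-        # decodes the caption after the separator.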
- subreddit_tokens = (
- [model.sos_index]
- + tokenizer.encode(_segments)
- + [tokenizer.token_to_id("[SEP]")]
- )
- else:
- # Just seed the model with [SOS]
- subreddit_tokens = [model.sos_index]
-
-    # Move the subreddit prompt to the appropriate device.
- subreddit_tokens = torch.tensor(subreddit_tokens, device=device).long()
-
- # Make a list of predictions to evaluate.
- predictions: List[Dict[str, Any]] = []
-
- for val_batch in tqdm(val_dataloader):
- val_batch["image"] = val_batch["image"].to(device)
-
- # Add the subreddit tokens as decoding prompt to batch.
- val_batch["decode_prompt"] = subreddit_tokens
-
- with torch.no_grad():
- output_dict = model(val_batch)
-
- for idx, (image_id, caption) in enumerate(
- zip(val_batch["image_id"], output_dict["predictions"])
- ):
- caption = caption.tolist()
-
-            # Replace the [SEP] token with "::" temporarily so it gets decoded.
-            if tokenizer.token_to_id("[SEP]") in caption:
-                sep_index = caption.index(tokenizer.token_to_id("[SEP]"))
-                caption[sep_index] = tokenizer.token_to_id("::")
-
- caption = tokenizer.decode(caption)
-
- # Separate out subreddit from the rest of caption.
- if "::" in caption:
- subreddit, rest_of_caption = caption.split("::")
- subreddit = "".join(subreddit.split())
- rest_of_caption = rest_of_caption.strip()
- else:
- subreddit, rest_of_caption = "", caption
-
- predictions.append(
- {"image_id": image_id, "subreddit": subreddit, "caption": rest_of_caption}
- )
-
- logger.info("Displaying first 25 caption predictions:")
- for pred in predictions[:25]:
- logger.info(f"{pred['image_id']} - r/{pred['subreddit']}:: {pred['caption']}")
-
- # Save predictions as a JSON file.
- os.makedirs(os.path.dirname(_A.output), exist_ok=True)
- json.dump(predictions, open(_A.output, "w"))
- logger.info(f"Saved predictions to {_A.output}")
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
- if _A.num_gpus_per_machine > 1:
- raise ValueError("Using multiple GPUs is not supported for this script.")
-
- # No distributed training here, just a single process.
- main(_A)
diff --git a/virtex/scripts/redcaps_train.py b/virtex/scripts/redcaps_train.py
deleted file mode 100644
index b8c63010361c80ffcdca873a8166448aa2f359ef..0000000000000000000000000000000000000000
--- a/virtex/scripts/redcaps_train.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import argparse
-import os
-import tempfile
-from typing import Any
-
-from loguru import logger
-import torch
-from torch import nn
-from torch.cuda import amp
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-
-# fmt: off
-from virtex.config import Config
-from virtex.factories import (
- PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory,
- LRSchedulerFactory,
-)
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser, common_setup, cycle
-import virtex.utils.distributed as dist
-from virtex.utils.timer import Timer
-
-
-parser = common_parser(
-    description="Train a VirTex model (CNN + Transformer) on RedCaps."
-)
-group = parser.add_argument_group("Checkpointing and Logging")
-group.add_argument(
- "--resume-from", default=None,
- help="Path to a checkpoint to resume training from (if provided)."
-)
-group.add_argument(
- "--checkpoint-every", type=int, default=2000,
-    help="Serialize the model to a checkpoint every this many iterations.",
-)
-group.add_argument(
- "--log-every", type=int, default=50,
-    help="""Log training curves to tensorboard every this many iterations.
-    Only the master process logs loss values averaged across processes.""",
-)
-# fmt: on
-
-
-def main(_A: argparse.Namespace):
-
- if _A.num_gpus_per_machine == 0:
- # Set device as CPU if num_gpus_per_machine = 0.
- device: Any = torch.device("cpu")
- else:
- # Get the current device as set for current distributed process.
- # Check `launch` function in `virtex.utils.distributed` module.
- device = torch.cuda.current_device()
-
- # Create a config object (this will be immutable) and perform common setup
- # such as logging and setting up serialization directory.
- _C = Config(_A.config, _A.config_override)
- common_setup(_C, _A)
-
- # -------------------------------------------------------------------------
- # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
- # -------------------------------------------------------------------------
-
- # fmt: off
- train_dataset = PretrainingDatasetFactory.from_config(_C)
- train_dataloader = DataLoader(
- train_dataset, batch_size=None, shuffle=False,
- num_workers=_A.cpu_workers, pin_memory=True,
- )
- # fmt: on
-
- model = PretrainingModelFactory.from_config(_C).to(device)
- optimizer = OptimizerFactory.from_config(_C, model.named_parameters())
- scheduler = LRSchedulerFactory.from_config(_C, optimizer)
-
- # -------------------------------------------------------------------------
- # BEFORE TRAINING STARTS
- # -------------------------------------------------------------------------
-
- # Create a gradient scaler for automatic mixed precision.
- scaler = amp.GradScaler(enabled=_C.AMP)
-
- # Load checkpoint to resume training if specified.
- if _A.resume_from is not None:
- start_iteration = CheckpointManager(
- model=model, optimizer=optimizer, scheduler=scheduler,
- ).load(_A.resume_from)
- else:
- start_iteration = 0
-
- # Create an iterator from dataloader to sample batches perpetually.
- train_dataloader_iter = cycle(train_dataloader, device, start_iteration)
-
-    # Wrap model in DDP if using more than one process.
- if dist.get_world_size() > 1:
- dist.synchronize()
- model = nn.parallel.DistributedDataParallel(
- model, device_ids=[device], find_unused_parameters=True
- )
-
- # Keep track of time per iteration and ETA.
- timer = Timer(
- start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS
- )
- # Create tensorboard writer and checkpoint manager (only in master process).
- if dist.is_master_process():
- tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)
- tensorboard_writer.add_text("config", f"```\n{_C}\n```")
-
- checkpoint_manager = CheckpointManager(
- _A.serialization_dir,
- model=model,
- optimizer=optimizer,
- scheduler=scheduler,
- scaler=scaler,
- )
-
- # -------------------------------------------------------------------------
- # TRAINING LOOP
- # -------------------------------------------------------------------------
- for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1):
- timer.tic()
- optimizer.zero_grad()
- batch = next(train_dataloader_iter)
-
- with amp.autocast(enabled=_C.AMP):
- output_dict = model(batch)
- loss = output_dict["loss"]
-
- scaler.scale(loss).backward()
-
- # First clip norm of gradients, and then perform optimizer step.
- scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM)
- scaler.step(optimizer)
-
- scaler.update()
- scheduler.step()
- timer.toc()
-
- # ---------------------------------------------------------------------
- # LOGGING
- # ---------------------------------------------------------------------
- if iteration % _A.log_every == 0:
- logger.info(
- f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]"
- )
- if dist.is_master_process():
- tensorboard_writer.add_scalars(
- "train", output_dict["loss_components"], iteration
- )
-
- if iteration % _A.checkpoint_every == 0 and dist.is_master_process():
- checkpoint_manager.step(iteration)
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
-
- if _A.num_gpus_per_machine == 0:
- main(_A)
- else:
- # This will launch `main` and set appropriate CUDA device (GPU ID) as
- # per process (accessed in the beginning of `main`).
- dist.launch(
- main,
- num_machines=_A.num_machines,
- num_gpus_per_machine=_A.num_gpus_per_machine,
- machine_rank=_A.machine_rank,
- dist_url=_A.dist_url,
- args=(_A, ),
- )
diff --git a/virtex/scripts/zero_shot_classification.py b/virtex/scripts/zero_shot_classification.py
deleted file mode 100644
index fc8523f6012aab90a86f77c1ba235ad740848fe2..0000000000000000000000000000000000000000
--- a/virtex/scripts/zero_shot_classification.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import argparse
-import json
-import os
-import random
-from typing import Any, Dict, List
-
-from loguru import logger
-import torch
-from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler
-from torch.nn.utils.rnn import pad_sequence
-from tqdm import tqdm
-
-import wordsegment as ws
-
-from virtex.config import Config
-from virtex.data import ZeroShotDataset
-
-from virtex.data.tokenizers import SentencePieceBPETokenizer
-
-from virtex.factories import TokenizerFactory, VisualBackboneFactory, TextualHeadFactory
-from virtex.utils.checkpointing import CheckpointManager
-from virtex.utils.common import common_parser
-from virtex.utils.metrics import TopkAccuracy
-import virtex.utils.distributed as dist
-
-
-# Import the zero-shot classifier.
-from virtex.models.zero_shot_classification_eval import ZeroShotClassifier
-
-ws.load()
-
-# fmt: off
-parser = common_parser(
-    description="""Evaluate a pretrained VirTex model on zero-shot image
-    classification."""
-)
-parser.add_argument(
- "--data-root", default=None,
-    help="""Path to a directory containing image files to classify (e.g. ImageNet).
-    Default: COCO val2017 image directory as expected relative to project root."""
-)
-parser.add_argument(
- "--checkpoint-path", required=False,
- help="Path to load checkpoint and run captioning evaluation."
-)
-parser.add_argument(
- "--output", default=None,
- help="Path to save predictions as a JSON file."
-)
-parser.add_argument(
- "--calc-metrics", action="store_true",
- help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions.
- This flag should not be set when running inference on arbitrary images."""
-)
-
-parser.add_argument(
- "--idx_label_dict", default=None, required=False,
-    help="""A dictionary that maps label indices to label strings for classification."""
-)
-parser.add_argument(
-    "--is_redcaps", type=int, default=None, required=False,
-    help="""Set to 1 if the checkpoint was pretrained on RedCaps."""
-)
-parser.add_argument(
- "--prompt_cls_sos", default=None, required=False,
-    help="""Prompt text passed to ZeroShotDataset as prompt_cls_sos; underscores are replaced with spaces."""
-)
-parser.add_argument(
- "--prompt_sos_eos", default=None, required=False,
-    help="""Prompt text passed to ZeroShotDataset as prompt_sos_eos; underscores are replaced with spaces."""
-)
-# fmt: on
-
-print("###########")
-print(os.getcwd() )
-print("###########")
-
-tokenizer = SentencePieceBPETokenizer("datasets_1/vocab/common_32k.model")
-
-def main(_A: argparse.Namespace):
- if _A.num_gpus_per_machine == 0:
- # Set device as CPU if num_gpus_per_machine = 0.
- device = torch.device("cpu")
- else:
- # Get the current device (this will be zero here by default).
- device = torch.cuda.current_device()
-
- _C = Config(_A.config, _A.config_override)
-
- #tokenizer = TokenizerFactory.from_config(_C)
-
- if _A.data_root is None:
- _A.data_root = os.path.join(_C.DATA.ROOT, "val2017")
-
- if _A.is_redcaps == 1:
- model_dataset = 'redcaps'
- else:
- model_dataset = 'gcc or sbu'
-
- print(_A.idx_label_dict)
-
-    val_dataset = ZeroShotDataset(
-        data_root=_A.data_root,
-        split="test/",
-        label_map=_A.idx_label_dict,
-        tokenizer=tokenizer,
-        prompt_cls_sos=_A.prompt_cls_sos.replace("_", " "),
-        prompt_sos_eos=_A.prompt_sos_eos.replace("_", " "),
-    )
-
- val_dataloader = DataLoader(
- val_dataset,
-        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
- num_workers=_A.cpu_workers,
- sampler=DistributedSampler(
- val_dataset,
- num_replicas=dist.get_world_size(),
- rank=dist.get_rank(),
- ),
- pin_memory=True,
- drop_last=False,
- collate_fn=val_dataset.collate_fn,
- )
-
- # Initialize model from a checkpoint
- visual = VisualBackboneFactory.from_config(_C)
- textual = TextualHeadFactory.from_config(_C)
-    model = ZeroShotClassifier(visual, textual)
- ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
- model.to(device).eval()
-
-    # Wrap model in DDP if using more than one process.
- if dist.get_world_size() > 1:
- dist.synchronize()
- model = nn.parallel.DistributedDataParallel(
- model, device_ids=[device], find_unused_parameters=True
- )
-
- top_1 = TopkAccuracy(top_k=1)
- top_5 = TopkAccuracy(top_k=5)
- batch_num = 0
-
-
- for val_iteration, val_batch in tqdm(enumerate(val_dataloader, start=1)):
- val_batch["image"] = val_batch["image"].to(device)
- val_batch["caption_tokens"] = val_batch["caption_tokens"].to(device)
-        val_batch["noitpac_tokens"] = val_batch["noitpac_tokens"].to(device)
- val_batch["caption_lengths"] = val_batch["caption_lengths"].to(device)
- val_batch["label"] = val_batch["label"].to(device)
-
- with torch.no_grad():
- classification_losses = model(val_batch)
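-            # The classifier returns one value per candidate label for each
-            # image; TopkAccuracy ranks labels by these values below.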
-
-        batch_num += 1
- top_1(classification_losses, val_batch["label"])
- top_1_acc = top_1.get_metric(reset=False)
- dist.average_across_processes(top_1_acc)
-
- top_5(classification_losses, val_batch["label"])
- top_5_acc = top_5.get_metric(reset=False)
- dist.average_across_processes(top_5_acc)
-
- logger.info(f"Iter: {val_iteration} | Top-1 accuracy: {top_1_acc} | Top-5 accuracy: {top_5_acc}")
-
-
-
-if __name__ == "__main__":
- _A = parser.parse_args()
- #if _A.num_gpus_per_machine > 1:
- # raise ValueError("Using multiple GPUs is not supported for this script.")
-
- # No distributed training here, just a single process.
- main(_A)
diff --git a/virtex/setup.py b/virtex/setup.py
deleted file mode 100644
index fc715695a0b1e6eb83a52205c9fec3224131bb21..0000000000000000000000000000000000000000
--- a/virtex/setup.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#!/usr/bin/env python
-import glob
-import os
-from setuptools import setup
-import shutil
-from typing import List
-
-
-def get_model_zoo_configs() -> List[str]:
- """
-    Return a list of config files to include in the package for the model zoo.
-    These configs are symlinked (or copied) into virtex/model_zoo/configs.
- """
-
- # Use absolute paths while symlinking.
- source_configs_dir = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), "configs"
- )
- destination = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), "virtex", "model_zoo", "configs"
- )
- # Symlink the config directory inside package to have a cleaner pip install.
-
- # Remove stale symlink/directory from a previous build.
- if os.path.exists(source_configs_dir):
- if os.path.islink(destination):
- os.unlink(destination)
- elif os.path.isdir(destination):
- shutil.rmtree(destination)
-
- if not os.path.exists(destination):
- try:
- os.symlink(source_configs_dir, destination)
- except OSError:
- # Fall back to copying if symlink fails: ex. on Windows.
- shutil.copytree(source_configs_dir, destination)
-
- config_paths = glob.glob("configs/**/*.yaml", recursive=True)
- return config_paths
-
-
-setup(
- name="virtex",
- version="1.1.0",
- author="Karan Desai and Justin Johnson",
- description="VirTex: Learning Visual Representations with Textual Annotations",
- package_data={"virtex.model_zoo": get_model_zoo_configs()},
- python_requires=">=3.6",
- license="MIT",
- zip_safe=True,
-)
diff --git a/virtex/virtex/utils/assets/download_spice.sh b/virtex/utils/assets/download_spice.sh
similarity index 100%
rename from virtex/virtex/utils/assets/download_spice.sh
rename to virtex/utils/assets/download_spice.sh
diff --git a/virtex/virtex/utils/beam_search.py b/virtex/utils/beam_search.py
similarity index 100%
rename from virtex/virtex/utils/beam_search.py
rename to virtex/utils/beam_search.py
diff --git a/virtex/virtex/utils/checkpointing.py b/virtex/utils/checkpointing.py
similarity index 100%
rename from virtex/virtex/utils/checkpointing.py
rename to virtex/utils/checkpointing.py
diff --git a/virtex/virtex/utils/common.py b/virtex/utils/common.py
similarity index 100%
rename from virtex/virtex/utils/common.py
rename to virtex/utils/common.py
diff --git a/virtex/virtex/utils/distributed.py b/virtex/utils/distributed.py
similarity index 100%
rename from virtex/virtex/utils/distributed.py
rename to virtex/utils/distributed.py
diff --git a/virtex/virtex/utils/metrics.py b/virtex/utils/metrics.py
similarity index 100%
rename from virtex/virtex/utils/metrics.py
rename to virtex/utils/metrics.py
diff --git a/virtex/virtex/utils/nucleus_sampling.py b/virtex/utils/nucleus_sampling.py
similarity index 100%
rename from virtex/virtex/utils/nucleus_sampling.py
rename to virtex/utils/nucleus_sampling.py
diff --git a/virtex/virtex/utils/timer.py b/virtex/utils/timer.py
similarity index 100%
rename from virtex/virtex/utils/timer.py
rename to virtex/utils/timer.py