import argparse
import os

import torch
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler
from loguru import logger
from tqdm import tqdm
import wordsegment as ws

from virtex.config import Config
from virtex.data import ZeroShotDataset
from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.factories import TokenizerFactory, VisualBackboneFactory, TextualHeadFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
from virtex.utils.metrics import TopkAccuracy
import virtex.utils.distributed as dist

# Import the zero-shot classification head built on the pretrained model.
from virtex.models.zero_shot_classification_eval import ZeroShotClassifier

# Load wordsegment's corpus data; required once before any segmentation calls.
ws.load()
# fmt: off
parser = common_parser(
    description="""Run zero-shot image classification with a pretrained
    captioning model and report top-1/top-5 accuracy."""
)
parser.add_argument(
    "--data-root", default=None,
    help="""Path to a directory containing the images to classify.
    Default: COCO val2017 image directory as expected relative to project root.""",
)
parser.add_argument(
    "--checkpoint-path", required=False,
    help="Path to load checkpoint and run classification evaluation.",
)
parser.add_argument(
    "--output", default=None,
    help="Path to save predictions as a JSON file.",
)
parser.add_argument(
    "--calc-metrics", action="store_true",
    help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions.
    This flag should not be set when running inference on arbitrary images.""",
)
parser.add_argument(
    "--idx_label_dict", default=None, required=False,
    help="A mapping from label index to label string for classification.",
)
parser.add_argument(
    "--is_redcaps", type=int, default=None, required=False,
    help="Set to 1 if the checkpoint was pretrained on RedCaps (otherwise GCC/SBU is assumed).",
)
parser.add_argument(
    "--prompt_cls_sos", default=None, required=False,
    help="Prompt string passed to ZeroShotDataset as prompt_cls_sos; underscores are replaced with spaces.",
)
parser.add_argument(
    "--prompt_sos_eos", default=None, required=False,
    help="Prompt string passed to ZeroShotDataset as prompt_sos_eos; underscores are replaced with spaces.",
)
# fmt: on
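
# Example invocation (script name, paths, and prompt strings below are
# illustrative placeholders; --config, --num-gpus-per-machine, and
# --cpu-workers are assumed to come from `common_parser`):
#
#   python zero_shot_classification.py \
#       --config /path/to/pretrain_config.yaml \
#       --checkpoint-path /path/to/checkpoint.pth \
#       --data-root /path/to/images \
#       --idx_label_dict /path/to/idx_to_label.json \
#       --is_redcaps 1 \
#       --prompt_cls_sos "a_photo_of" \
#       --prompt_sos_eos "a_photo_of" \
#       --num-gpus-per-machine 1 --cpu-workers 4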
print("###########") | |
print(os.getcwd() ) | |
print("###########") | |
tokenizer = SentencePieceBPETokenizer("datasets_1/vocab/common_32k.model") | |

def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device (this will be zero here by default).
        device = torch.cuda.current_device()

    _C = Config(_A.config, _A.config_override)
    # Alternatively, the tokenizer could be built from the config:
    # tokenizer = TokenizerFactory.from_config(_C)

    if _A.data_root is None:
        _A.data_root = os.path.join(_C.DATA.ROOT, "val2017")

    # Record which dataset the checkpoint was pretrained on (informational only;
    # this variable is not used further below).
    if _A.is_redcaps == 1:
        model_dataset = "redcaps"
    else:
        model_dataset = "gcc or sbu"
    logger.info(f"Label map: {_A.idx_label_dict}")

    val_dataset = ZeroShotDataset(
        data_root=_A.data_root,
        split="test/",
        label_map=_A.idx_label_dict,
        tokenizer=tokenizer,
        prompt_cls_sos=_A.prompt_cls_sos.replace("_", " "),
        prompt_sos_eos=_A.prompt_sos_eos.replace("_", " "),
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            val_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
        ),
        pin_memory=True,
        drop_last=False,
        collate_fn=val_dataset.collate_fn,
    )
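
    # Each process evaluates a disjoint shard of the dataset (DistributedSampler),
    # and the global batch size from the config is split across processes, so the
    # effective per-process batch is _C.OPTIM.BATCH_SIZE // world_size.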
    # Initialize model from a checkpoint.
    visual = VisualBackboneFactory.from_config(_C)
    textual = TextualHeadFactory.from_config(_C)
    model = ZeroShotClassifier(visual, textual)
    ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
    model.to(device).eval()

    # Wrap the model for distributed evaluation when running multiple processes.
    if dist.get_world_size() > 1:
        dist.synchronize()
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )

    top_1 = TopkAccuracy(top_k=1)
    top_5 = TopkAccuracy(top_k=5)
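
    # TopkAccuracy accumulates counts across batches; calling get_metric with
    # reset=False returns the running accuracy without clearing those counts.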
    batch_num = 0
    for val_iteration, val_batch in tqdm(enumerate(val_dataloader, start=1)):
        val_batch["image"] = val_batch["image"].to(device)
        val_batch["caption_tokens"] = val_batch["caption_tokens"].to(device)
        val_batch["noitpac_tokens"] = val_batch["noitpac_tokens"].to(device)
        val_batch["caption_lengths"] = val_batch["caption_lengths"].to(device)
        val_batch["label"] = val_batch["label"].to(device)

        with torch.no_grad():
            classification_losses = model(val_batch)

        batch_num += 1
        # Update running top-1/top-5 accuracy and average across processes.
        top_1(classification_losses, val_batch["label"])
        top_1_acc = top_1.get_metric(reset=False)
        dist.average_across_processes(top_1_acc)

        top_5(classification_losses, val_batch["label"])
        top_5_acc = top_5.get_metric(reset=False)
        dist.average_across_processes(top_5_acc)

        logger.info(
            f"Iter: {val_iteration} | Top-1 accuracy: {top_1_acc} | "
            f"Top-5 accuracy: {top_5_acc}"
        )

if __name__ == "__main__":
    _A = parser.parse_args()
    main(_A)