import argparse
import os

from loguru import logger
import torch
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
import wordsegment as ws

from virtex.config import Config
from virtex.data import ZeroShotDataset
from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.factories import TokenizerFactory, VisualBackboneFactory, TextualHeadFactory
from virtex.models.zero_shot_classification_eval import ZeroShotClassifier
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
from virtex.utils.metrics import TopkAccuracy
import virtex.utils.distributed as dist

# Load wordsegment data once at import time (used by the dataset pipeline).
ws.load()

# fmt: off
parser = common_parser(
    description="""Run zero-shot image classification with a pretrained model
    and report top-1 / top-5 accuracy."""
)
parser.add_argument(
    "--data-root", default=None,
    help="""Path to a directory containing the images to classify (e.g. an
    ImageNet val split). Default: COCO val2017 image directory as expected
    relative to project root.""",
)
parser.add_argument(
    "--checkpoint-path", required=False,
    help="Path to a pretrained checkpoint to load for evaluation.",
)
# The two flags below are kept from the captioning evaluation script; they are
# not read anywhere in this script.
parser.add_argument(
    "--output", default=None,
    help="Path to save predictions as a JSON file.",
)
parser.add_argument(
    "--calc-metrics", action="store_true",
    help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions.
    This flag should not be set when running inference on arbitrary images.""",
)
parser.add_argument(
    "--idx_label_dict", default=None, required=False,
    help="""Dictionary that maps label indices to label strings for
    classification.""",
)
parser.add_argument(
    "--is_redcaps", type=int, default=None, required=False,
    help="Set to 1 if the checkpoint was pretrained on RedCaps.",
)
parser.add_argument(
    "--prompt_cls_sos", default=None, required=False,
    help="""Prompt inserted between the [CLS] and [SOS] tokens; underscores
    are replaced with spaces.""",
)
parser.add_argument(
    "--prompt_sos_eos", default=None, required=False,
    help="""Prompt inserted between the [SOS] and [EOS] tokens; underscores
    are replaced with spaces.""",
)
# fmt: on

print("###########")
print(os.getcwd())
print("###########")

tokenizer = SentencePieceBPETokenizer("datasets_1/vocab/common_32k.model")
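# NOTE: a sketch of the contents expected behind --idx_label_dict, inferred
# from the `label_map` argument of ZeroShotDataset. The exact on-disk format
# is an assumption, not confirmed by this script:
#
#   {"0": "tench", "1": "goldfish", "2": "great white shark", ...}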
def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device (this will be zero here by default).
        device = torch.cuda.current_device()

    _C = Config(_A.config, _A.config_override)
    # tokenizer = TokenizerFactory.from_config(_C)

    if _A.data_root is None:
        _A.data_root = os.path.join(_C.DATA.ROOT, "val2017")

    # Note: `model_dataset` records the pretraining dataset of the checkpoint
    # but is not used further in this script.
    if _A.is_redcaps == 1:
        model_dataset = "redcaps"
    else:
        model_dataset = "gcc or sbu"

    print(_A.idx_label_dict)

    val_dataset = ZeroShotDataset(
        data_root=_A.data_root,
        split="test/",
        label_map=_A.idx_label_dict,
        tokenizer=tokenizer,
        prompt_cls_sos=_A.prompt_cls_sos.replace("_", " "),
        prompt_sos_eos=_A.prompt_sos_eos.replace("_", " "),
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            val_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
        ),
        pin_memory=True,
        drop_last=False,
        collate_fn=val_dataset.collate_fn,
    )

    # Initialize model from a checkpoint.
    visual = VisualBackboneFactory.from_config(_C)
    textual = TextualHeadFactory.from_config(_C)
    model = ZeroShotClassifier(visual, textual)
    ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
    model.to(device).eval()

    # Set up distributed evaluation.
    if dist.get_world_size() > 1:
        dist.synchronize()
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )

    top_1 = TopkAccuracy(top_k=1)
    top_5 = TopkAccuracy(top_k=5)
    batch_num = 0

    for val_iteration, val_batch in tqdm(enumerate(val_dataloader, start=1)):
        val_batch["image"] = val_batch["image"].to(device)
        val_batch["caption_tokens"] = val_batch["caption_tokens"].to(device)
        val_batch["noitpac_tokens"] = val_batch["noitpac_tokens"].to(device)
        val_batch["caption_lengths"] = val_batch["caption_lengths"].to(device)
        val_batch["label"] = val_batch["label"].to(device)

        with torch.no_grad():
            classification_losses = model(val_batch)

        batch_num += 1
        top_1(classification_losses, val_batch["label"])
        top_1_acc = top_1.get_metric(reset=False)
        dist.average_across_processes(top_1_acc)

        top_5(classification_losses, val_batch["label"])
        top_5_acc = top_5.get_metric(reset=False)
        dist.average_across_processes(top_5_acc)

        logger.info(
            f"Iter: {val_iteration} | Top-1 accuracy: {top_1_acc} | "
            f"Top-5 accuracy: {top_5_acc}"
        )


if __name__ == "__main__":
    _A = parser.parse_args()
    # if _A.num_gpus_per_machine > 1:
    #     raise ValueError("Using multiple GPUs is not supported for this script.")

    # No distributed training here, just a single process.
    main(_A)
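# Example invocation (a sketch: the file and path names below are hypothetical,
# and the prompt strings are illustrative; --config, --num-gpus-per-machine,
# and --cpu-workers are assumed to come from virtex's `common_parser`):
#
#   python zero_shot_classification_eval.py \
#       --config configs/pretrain_config.yaml \
#       --checkpoint-path checkpoints/checkpoint.pth \
#       --num-gpus-per-machine 1 --cpu-workers 4 \
#       --data-root datasets/imagenet/val \
#       --idx_label_dict datasets/idx_to_label.json \
#       --is_redcaps 1 \
#       --prompt_cls_sos "i_took_a_picture" \
#       --prompt_sos_eos "itap_of_a"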