import argparse
import json
import os
from typing import Any, Dict, List

from loguru import logger
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import wordsegment as ws

from virtex.config import Config
from virtex.data import ImageDirectoryDataset
from virtex.factories import TokenizerFactory, PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser

# Load wordsegment's unigram/bigram frequency data once, up front;
# ws.segment() below depends on it.
ws.load()

# fmt: off
parser = common_parser(
    description="Decode captions using a RedCaps-pretrained VirTex model."
)
parser.add_argument(
    "--images", required=True,
    help="Path to a directory containing image files to generate captions for."
)
parser.add_argument(
    "--checkpoint-path", required=True,
    help="Path to load checkpoint and run captioning evaluation."
)
parser.add_argument(
    "--output", required=True,
    help="Path to save predictions as a JSON file."
)
parser.add_argument(
    "--subreddit-prompt", default=None,
    help="Optional subreddit prompt for controllable subreddit-style captioning."
)
# fmt: on
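# Example invocation (illustrative: the script name and config path are
# placeholders, and --config is assumed to be added by `common_parser`, going
# by how `_A.config` is used in main() below):
#
#   python caption_decode.py \
#       --config /path/to/pretrain_config.yaml \
#       --checkpoint-path /path/to/checkpoint.pth \
#       --images /path/to/images/ \
#       --output predictions.json \
#       --subreddit-prompt "r/itookapicture"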
def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device (this will be zero here by default).
        device = torch.cuda.current_device()

    _C = Config(_A.config, _A.config_override)

    tokenizer = TokenizerFactory.from_config(_C)
    val_dataloader = DataLoader(
        ImageDirectoryDataset(_A.images),
        batch_size=_C.OPTIM.BATCH_SIZE,
        num_workers=_A.cpu_workers,
        pin_memory=True,
    )
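    # ImageDirectoryDataset is assumed to yield dicts with "image_id" and
    # "image" keys, inferred from how `val_batch` is indexed in the loop below.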
    # Initialize the model from a checkpoint.
    model = PretrainingModelFactory.from_config(_C).to(device)
    CheckpointManager(model=model).load(_A.checkpoint_path)
    model.eval()

    # Prepare the subreddit prompt for the model, if provided.
    if _A.subreddit_prompt is not None:
        # Remove the "r/" prefix if present.
        _A.subreddit_prompt = _A.subreddit_prompt.replace("r/", "")

        # Word segmentation (e.g. "itookapicture" -> "i took a picture").
        _segments = " ".join(ws.segment(ws.clean(_A.subreddit_prompt)))
        subreddit_tokens = (
            [model.sos_index]
            + tokenizer.encode(_segments)
            + [tokenizer.token_to_id("[SEP]")]
        )
    else:
        # Just seed the model with [SOS].
        subreddit_tokens = [model.sos_index]

    # Move the subreddit prompt to the appropriate device.
    subreddit_tokens = torch.tensor(subreddit_tokens, device=device).long()
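    # The decode prompt is now "[SOS] <segmented subreddit> [SEP]" (or just
    # "[SOS]" when no prompt was given); the model generates caption tokens
    # after it, which is what makes the captioning subreddit-controllable.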
    # Make a list of predictions to evaluate.
    predictions: List[Dict[str, Any]] = []

    for val_batch in tqdm(val_dataloader):
        val_batch["image"] = val_batch["image"].to(device)

        # Add the subreddit tokens as a decoding prompt to the batch.
        val_batch["decode_prompt"] = subreddit_tokens

        with torch.no_grad():
            output_dict = model(val_batch)
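        # output_dict["predictions"] is assumed to hold one tensor of generated
        # token IDs per image; this is inferred from how it is consumed below,
        # not from a documented VirTex API.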
        for image_id, caption in zip(
            val_batch["image_id"], output_dict["predictions"]
        ):
            caption = caption.tolist()

            # Temporarily replace the [SEP] token with "::" so the boundary
            # between subreddit and caption survives decoding.
            if tokenizer.token_to_id("[SEP]") in caption:
                sep_index = caption.index(tokenizer.token_to_id("[SEP]"))
                caption[sep_index] = tokenizer.token_to_id("::")

            caption = tokenizer.decode(caption)

            # Separate the subreddit from the rest of the caption. Split at
            # most once, in case "::" also occurs inside the caption itself.
            if "::" in caption:
                subreddit, rest_of_caption = caption.split("::", 1)
                subreddit = "".join(subreddit.split())
                rest_of_caption = rest_of_caption.strip()
            else:
                subreddit, rest_of_caption = "", caption
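            # Worked example (illustrative): a decoded string like
            # "i took a picture :: my dog at the beach" yields
            # subreddit="itookapicture" (spaces removed to undo segmentation)
            # and caption="my dog at the beach".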
            predictions.append(
                {"image_id": image_id, "subreddit": subreddit, "caption": rest_of_caption}
            )

    logger.info("Displaying first 25 caption predictions:")
    for pred in predictions[:25]:
        logger.info(f"{pred['image_id']} - r/{pred['subreddit']}:: {pred['caption']}")

    # Save predictions as a JSON file. os.makedirs("") raises an error, so only
    # create a directory when the output path actually contains one.
    output_dir = os.path.dirname(_A.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(_A.output, "w") as f:
        json.dump(predictions, f)
    logger.info(f"Saved predictions to {_A.output}")


if __name__ == "__main__":
    _A = parser.parse_args()
    if _A.num_gpus_per_machine > 1:
        raise ValueError("Using multiple GPUs is not supported for this script.")

    # No distributed training here, just a single process.
    main(_A)