virtex-redcaps / virtex /scripts /preprocess /build_redcaps_vocab.py
zamborg's picture
added datasets and virtex
a5f8a35
raw history blame
No virus
3.22 kB
import argparse
import glob
import json
import os
import re
import tempfile
from functools import lru_cache
from typing import List
import ftfy
import sentencepiece as sp
import wordsegment as ws
from tqdm import tqdm
# Load wordsegment's word-frequency data once, at import time. The
# `ws.segment` call in `_segment_subreddit` relies on this data being loaded.
ws.load()
# Command-line interface of this script. Kept at module level so that the
# ``__main__`` section below can call ``parser.parse_args()``.
# fmt: off
parser = argparse.ArgumentParser(
    description="""Build a vocabulary out of captions corpus. This vocabulary
would be a file which our tokenizer can understand.
"""
)
parser.add_argument(
    # ``nargs="+"`` produces a list of strings, so the default has to be a
    # list as well. A bare string default would be iterated character by
    # character by the code that expands these glob patterns.
    "-f", "--files", nargs="+", default=["datasets/redcaps/annotations/*.json"],
    help="Path(s) to SBU, Conceptual, or RedCaps annotation files.",
)
parser.add_argument(
    "-s", "--vocab-size", type=int, default=32000,
    help="Total desired size of our vocabulary.",
)
parser.add_argument(
    "-o", "--output-prefix", default="datasets/vocab/redcaps_32k",
    help="Prefix of the files to be saved. Two files will be saved: "
    "[prefix].model and [prefix].vocab",
)
# fmt: on
def read_captions_from_file(annotations_path: str) -> List[str]:
    r"""
    Given a path to annotation file, read it and return a list of captions.

    Every returned caption is lower-cased. For annotations that carry a
    ``subreddit`` field, the word-segmented subreddit name is prepended to
    the caption, separated by a single space.

    Parameters
    ----------
    annotations_path: str
        Path to a JSON annotations file containing captions. The top-level
        object must have an ``"annotations"`` list whose items each have a
        ``"caption"`` key (and optionally a ``"subreddit"`` key).

    Returns
    -------
    List[str]
        List of captions from this annotation file.

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file is not valid JSON.
    KeyError
        If ``"annotations"`` or a ``"caption"`` entry is missing.
    """
    # Close the file deterministically instead of leaking the handle, and
    # decode as UTF-8 explicitly rather than using the platform's locale.
    with open(annotations_path, encoding="utf-8") as annotations_file:
        _annotations = json.load(annotations_file)

    captions: List[str] = []
    for ann in tqdm(_annotations["annotations"], desc=annotations_path):
        # This field only exists in RedCaps. Perform word segmentation on the
        # subreddit name to add appropriate whitespaces.
        if "subreddit" in ann:
            subreddit_seg = _segment_subreddit(ann["subreddit"].lower())
            caption = f"{subreddit_seg} {ann['caption']}"
        else:
            caption = ann["caption"]

        captions.append(caption.lower())

    return captions
@lru_cache(maxsize=10)
def _segment_subreddit(subreddit):
    """Return ``subreddit`` split into words joined by single spaces.

    The name is passed through ``ws.clean`` and then ``ws.segment``. Results
    are memoized (up to 10 distinct names) because the same subreddit name is
    looked up once per annotation.
    """
    cleaned_name = ws.clean(subreddit)
    words = ws.segment(cleaned_name)
    return " ".join(words)
if __name__ == "__main__":
    _A = parser.parse_args()

    # ``--files`` is expected to be a list of glob patterns. Guard against a
    # bare string, which would otherwise be iterated character by character.
    file_patterns = [_A.files] if isinstance(_A.files, str) else _A.files

    # Expand every glob pattern into concrete annotation file paths.
    all_filepaths: List[str] = []
    for f in file_patterns:
        all_filepaths.extend(glob.glob(f))

    # Fail early with a clear message instead of letting the trainer choke on
    # an empty corpus.
    if not all_filepaths:
        parser.error(f"No annotation files matched: {file_patterns}")

    # Make sure the output directory exists before the (slow) training run,
    # so that saving the model at the very end cannot fail because of it.
    output_dir = os.path.dirname(_A.output_prefix)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    captions: List[str] = []
    for path in tqdm(all_filepaths, desc="Reading captions"):
        captions.extend(read_captions_from_file(path))

    # Create a temporary directory and dump the captions corpus as a text file
    # with one caption per line. That's how sentencepiece wants its input.
    # The directory (and the corpus in it) is removed once training is done.
    with tempfile.TemporaryDirectory() as tmpdir_path:
        corpus_path = os.path.join(tmpdir_path, "captions.txt")

        # Write UTF-8 explicitly: captions may contain non-ASCII characters
        # and the platform's default encoding is not guaranteed to be UTF-8.
        with open(corpus_path, "w", encoding="utf-8") as captions_file:
            captions_file.writelines(f"{caption}\n" for caption in captions)

        # Padding/out-of-vocab token will be "<unk>" and ID 0 by default.
        # Add [SOS],[EOS] and [SEP] tokens. [SEP] will not be used during
        # captioning, but good to have to reuse vocabulary across pretext tasks.
        sp.SentencePieceTrainer.train(
            f" --input={corpus_path}"
            f" --vocab_size={_A.vocab_size}"
            f" --model_prefix={_A.output_prefix}"
            " --model_type=bpe --character_coverage=1.0"
            " --bos_id=-1 --eos_id=-1"
            " --control_symbols=[SOS],[EOS],[SEP]"
            " --user_defined_symbols=<usr>"
        )