import argparse
import glob
import json
import os
import re
import tempfile
from functools import lru_cache
from typing import List

import ftfy
import sentencepiece as sp
import wordsegment as ws
from tqdm import tqdm


# Load wordsegment's word-frequency data; required before calling ws.segment().
ws.load()

# fmt: off
parser = argparse.ArgumentParser(
    description="""Build a vocabulary out of captions corpus. This vocabulary
    would be a file which our tokenizer can understand.
    """
)
parser.add_argument(
    "-f", "--files", nargs="+", default=["datasets/redcaps/annotations/*.json"],
    help="Path(s) to SBU, Conceptual Captions, or RedCaps annotation files.",
)
parser.add_argument(
    "-s", "--vocab-size", type=int, default=32000,
    help="Total desired size of our vocabulary.",
)
parser.add_argument(
    "-o", "--output-prefix", default="datasets/vocab/redcaps_32k",
    help="Prefix of the files to be saved. Two files will be saved: "
    "[prefix].model and [prefix].vocab",
)
# fmt: on
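
# Example invocation (assuming this script is saved as "build_vocabulary.py";
# all flags shown are just the defaults defined above):
#
#   python build_vocabulary.py \
#       --files "datasets/redcaps/annotations/*.json" \
#       --vocab-size 32000 \
#       --output-prefix datasets/vocab/redcaps_32k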


def read_captions_from_file(annotations_path: str) -> List[str]:
    r"""
    Given a path to annotation file, read it and return a list of captions.

    Parameters
    ----------
    annotations_path: str
        Path to an annotations file containing captions.

    Returns
    -------
    List[str]
        List of captions from this annotation file.
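
    Examples
    --------
    A minimal sketch of expected usage ("tiny.json" is a hypothetical file
    following the ``{"annotations": [{"caption": ...}, ...]}`` layout this
    function expects)::

        captions = read_captions_from_file("tiny.json")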
    """

    with open(annotations_path) as f:
        _annotations = json.load(f)

    captions: List[str] = []
    for ann in tqdm(_annotations["annotations"], desc=annotations_path):

        # This field only exists in RedCaps. Perform word segmentation on the
        # subreddit name to add appropriate whitespace.
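        # (For illustration: a name like "mildlyinteresting" is expected to
        # segment into "mildly interesting".)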
        if "subreddit" in ann:
            subreddit_seg = _segment_subreddit(ann["subreddit"].lower())
            caption = f"{subreddit_seg} {ann['caption']}"
        else:
            caption = ann["caption"]

        captions.append(caption.lower())
    return captions


@lru_cache(maxsize=10)
def _segment_subreddit(subreddit: str) -> str:
    r"""Segment a subreddit name into space-separated words using wordsegment."""
    return " ".join(ws.segment(ws.clean(subreddit)))


if __name__ == "__main__":
    _A = parser.parse_args()

    all_filepaths: List[str] = []
    for f in _A.files:
        all_filepaths.extend(glob.glob(f))

    captions: List[str] = []
    for path in tqdm(all_filepaths, desc="Reading captions"):
        captions.extend(read_captions_from_file(path))

    # Create a temporary directory and dump the captions corpus as a text file
    # with one caption per line. That's how sentencepiece wants its input.
    tmpdir_path = tempfile.mkdtemp()

    with open(os.path.join(tmpdir_path, "captions.txt"), "w") as captions_file:
        for caption in captions:
            captions_file.write(caption + "\n")

    # Padding/out-of-vocab token will be "<unk>" with ID 0 by default.
    # Add [SOS], [EOS] and [SEP] tokens. [SEP] will not be used during
    # captioning, but it is useful for reusing the vocabulary across pretext tasks.
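    # (Note: SentencePiece control symbols are reserved IDs that are never
    # produced from raw text, whereas user-defined symbols such as "<usr>"
    # are matched directly when they appear in the input.)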
    sp.SentencePieceTrainer.train(
        f" --input={os.path.join(tmpdir_path, 'captions.txt')}"
        f" --vocab_size={_A.vocab_size}"
        f" --model_prefix={_A.output_prefix}"
        " --model_type=bpe --character_coverage=1.0"
        " --bos_id=-1 --eos_id=-1"
        " --control_symbols=[SOS],[EOS],[SEP]"
        " --user_defined_symbols=<usr>"
    )
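
    # A minimal sketch of loading the trained files afterwards (paths follow
    # --output-prefix; SentencePieceProcessor ships with the same
    # `sentencepiece` package imported as `sp` above):
    #
    #   spm = sp.SentencePieceProcessor(model_file=f"{_A.output_prefix}.model")
    #   print(spm.encode("a photo of a cat", out_type=str))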