"""Tokenize captions in an audio-captioning json dataset (Chinese via CoreNLP, English via PTB)."""
import json
from tqdm import tqdm
import re
import fire
def tokenize_caption(input_json: str,
                     keep_punctuation: bool = False,
                     host_address: str = None,
                     character_level: bool = False,
                     zh: bool = True,
                     output_json: str = None):
    """Tokenize every caption in a preprocessed json file and write the result back.

    Adds a "tokens" field (space-joined token string) to each caption entry.

    Args:
        input_json (str): Preprocessed json file. Structure:
            {
                'audios': [
                    {
                        'audio_id': 'xxx',
                        'captions': [
                            {'caption': 'xxx', 'cap_id': 'xxx'}
                        ]
                    },
                    ...
                ]
            }
        keep_punctuation (bool): Keep punctuation instead of stripping it
            (only consulted when `zh` is True).
        host_address (str): Address of a running Stanford CoreNLP server
            (only used when `zh` is True and `character_level` is False).
        character_level (bool): Split Chinese captions into single characters
            instead of CoreNLP word tokens (only when `zh` is True).
        zh (bool): Captions are Chinese. If False, the English PTB tokenizer
            from pycocoevalcap is used.
        output_json (str): Destination file. Defaults to overwriting
            `input_json`.

    Returns:
        None. The tokenized data is written to `output_json` (or `input_json`).
    """
    # Use a context manager so the file handle is closed even on parse errors.
    with open(input_json, "r") as reader:
        data = json.load(reader)["audios"]

    if zh:
        # Chinese pipeline: strip CJK punctuation, then tokenize per character
        # or via a CoreNLP server.
        from nltk.parse.corenlp import CoreNLPParser
        from zhon.hanzi import punctuation
        parser = CoreNLPParser(host_address)
        punct_pattern = re.compile("[{}]".format(punctuation))
        for audio in tqdm(data, leave=False, ascii=True):
            for cap in audio["captions"]:
                caption = cap["caption"]
                if not keep_punctuation:
                    caption = punct_pattern.sub("", caption)
                if character_level:
                    tokens = list(caption)
                else:
                    tokens = list(parser.tokenize(caption))
                # Dicts are mutated in place, so `data` carries the tokens.
                cap["tokens"] = " ".join(tokens)
    else:
        # English pipeline: PTBTokenizer expects {id: [{"caption": ...}, ...]}.
        from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
        captions = {}
        for audio in data:
            audio_id = audio["audio_id"]
            captions[audio_id] = [
                {"audio_id": audio_id, "id": cap_idx, "caption": cap["caption"]}
                for cap_idx, cap in enumerate(audio["captions"])
            ]
        tokenized = PTBTokenizer().tokenize(captions)
        for audio in tqdm(data, leave=False, ascii=True):
            audio_id = audio["audio_id"]
            for cap_idx, cap in enumerate(audio["captions"]):
                cap["tokens"] = tokenized[audio_id][cap_idx]

    # Keep non-ASCII (Chinese) text readable in the output file.
    destination = output_json if output_json else input_json
    with open(destination, "w") as writer:
        json.dump({"audios": data}, writer, indent=4, ensure_ascii=not zh)
if __name__ == "__main__":
    # Expose tokenize_caption as a command-line interface via python-fire.
    fire.Fire(tokenize_caption)