File size: 3,773 Bytes
c0e34e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d6528d
c0e34e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python3
from datasets import load_dataset
from collections import Counter
import json
import os
import tempfile
from transformers import Wav2Vec2CTCTokenizer

# ESB config name of the dataset whose transcriptions define the vocabulary
dataset_name = "switchboard"
# only the train split feeds the tokenizer, so no evaluation text leaks into the vocab
split = "train"
# pass the Hub auth token in case the dataset is gated
use_auth_token = True
# repository name used when pushing the finished tokenizer to the Hub
tokenizer_name = f"wav2vec2-ctc-{dataset_name}-tokenizer"

# character-frequency cutoff (in percent), identical for every dataset so the
# procedure stays entirely dataset-agnostic
cutoff_freq = 0.01

dataset = load_dataset("esb/datasets", dataset_name, split=split, use_auth_token=use_auth_token)

# keep only the transcription column; everything else would just occupy RAM
dataset = dataset.remove_columns([column for column in dataset.column_names if column != "text"])

# define function to see stats about letters and to create vocab
def create_vocabulary_from_data(dataset, word_delimiter_token="|", cutoff_freq=0.0):
    """Build a CTC character vocabulary (token -> id) from the ``"text"`` column of ``dataset``.

    All transcriptions are joined with single spaces and their characters are counted; a
    frequency table is printed to stdout. Characters whose share of all counted characters
    (in percent) is below ``cutoff_freq`` are dropped, so at tokenization time they fall
    back to ``<unk>``.

    Args:
        dataset: dataset with a ``"text"`` column of strings (here the ``load_dataset``
            result); it must support ``map`` and column access by name.
        word_delimiter_token: token that takes over the id of the space character. A literal
            occurrence of this character in the text is merged into the delimiter instead of
            getting its own id. ``None`` keeps the plain space as a token.
        cutoff_freq: minimum frequency in percent a character needs to be kept
            (e.g. ``0.01`` means 0.01 % of all counted characters).

    Returns:
        dict mapping each token to a contiguous integer id. ``<pad>``, ``<s>``, ``</s>`` and
        ``<unk>`` always get ids 0-3; the remaining tokens follow in descending frequency
        order (ties broken by character), so the ids are reproducible across runs.
    """

    def extract_all_chars(batch):
        # The " " separators inserted between transcriptions are counted as spaces too.
        all_text = " ".join(batch["text"])

        # sort by descending freq, ties broken by the char itself -> deterministic order
        char_counts = sorted(Counter(all_text).items(), key=lambda item: (-item[1], item[0]))
        # build the two columns directly; `zip(*[])` could not be unpacked for empty text
        vocab = [char for char, _ in char_counts]
        freqs = [freq for _, freq in char_counts]

        return {"vocab": vocab, "freqs": freqs}

    # batch_size=-1 passes the whole dataset to extract_all_chars as one batch, so the
    # mapped dataset has one row per distinct character.
    dataset = dataset.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        remove_columns=dataset.column_names,
    )

    vocab, freqs = dataset["vocab"], dataset["freqs"]
    total_num_chars = sum(freqs)
    chars_to_remove = set()

    print("Character Occurences")
    print(f"Total characters in dataset: {total_num_chars}")
    print(50 * "-")
    print(f"{'Char'.rjust(5)} | {'Total occ'.rjust(10)} | {'% of total occ'.rjust(20)} |")
    print(50 * "-")
    for char, freq in zip(vocab, freqs):
        freq_in_percent = freq / total_num_chars * 100
        print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |")
        if freq_in_percent < cutoff_freq:
            chars_to_remove.add(char)
    print(50 * "-")

    # Filter the frequency-sorted list instead of subtracting sets: iteration order of a
    # set of str depends on hash randomization (PYTHONHASHSEED), which would reshuffle the
    # token ids on every run.
    vocab = [char for char in vocab if char not in chars_to_remove]

    # replace white space with delimiter token
    if word_delimiter_token is not None and word_delimiter_token != " ":
        if " " in vocab:
            # The delimiter takes over the space's slot. A literal delimiter char is dropped
            # so that it does not keep a separate id that would end up unused.
            vocab = [
                word_delimiter_token if char == " " else char
                for char in vocab
                if char != word_delimiter_token
            ]
        elif word_delimiter_token not in vocab:
            # No (surviving) space in the data: still register the delimiter token.
            vocab.append(word_delimiter_token)

    # Wav2Vec2CTC Tokenizers always have those as the first tokens (important for CTC)
    vocab = ["<pad>", "<s>", "</s>", "<unk>"] + vocab

    # create json dict with contiguous ids
    return {token: idx for idx, token in enumerate(vocab)}

# Note that the function accepts the following important argument:
# cutoff_freq
# => This is very important! Many datasets contain "wrong" characters in the training set, e.g.
# characters that occur only a handful of times.
# By default, CTC vocabulary creation would add them to the vocab even though their occurrence is
# negligible compared to the "super frequent" letters. We can treat such characters as errors or as
# irrelevant to the dataset, so we should remove them from the vocab. During training, they are then
# simply classified as the unknown token <unk>, which the model can handle.
# In this script, we remove every character whose frequency, as a percentage of all characters, is below
# a certain threshold (e.g. cutoff_freq = 0.01 drops characters making up less than 0.01 % of the text).
# We FIX this threshold for all datasets (i.e. dataset-agnostic).

vocab_dict = create_vocabulary_from_data(dataset, cutoff_freq=cutoff_freq)

# The tokenizer is loaded from a directory holding vocab.json, so write the vocabulary into a
# throwaway directory and instantiate the tokenizer from it before the directory is removed.
with tempfile.TemporaryDirectory() as tmp_dir:
    vocab_path = os.path.join(tmp_dir, "vocab.json")
    with open(vocab_path, "w") as vocab_file:
        json.dump(vocab_dict, vocab_file)

    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tmp_dir)

# upload the tokenizer to the Hub under tokenizer_name
tokenizer.push_to_hub(tokenizer_name)