#!/usr/bin/env python3
from datasets import load_dataset
from collections import Counter
import json
import os
import tempfile

from transformers import Wav2Vec2CTCTokenizer

# which dataset
dataset_name = "tedlium"
# which split -> we should only use "train" to train our tokenizer
split = "train"
# in case the dataset requires access
use_auth_token = True
# name of tokenizer to upload to the Hub
tokenizer_name = f"wav2vec2-ctc-{dataset_name}-tokenizer"
# FIX the cutoff freq for all datasets -> an entirely dataset-agnostic approach
cutoff_freq = 0.01

dataset = load_dataset(
    "esb/datasets",
    dataset_name,
    split=split,
    use_auth_token=use_auth_token,
)

# remove all columns except the transcriptions to save RAM
dataset = dataset.remove_columns(list(set(dataset.column_names) - {"text"}))


# define function to print stats about letters and to create the vocab
def create_vocabulary_from_data(dataset, word_delimiter_token="|", cutoff_freq=0.0):
    def extract_all_chars(batch):
        all_text = " ".join(batch["text"])
        count_chars_dict = Counter(list(all_text))
        # sort by freq
        count_chars_dict = sorted(count_chars_dict.items(), key=lambda item: (-item[1], item[0]))
        # retrieve chars and their freqs
        vocab, freqs = zip(*count_chars_dict)
        return {"vocab": list(vocab), "freqs": list(freqs)}

    dataset = dataset.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        remove_columns=dataset.column_names,
    )

    vocab, freqs = dataset["vocab"], dataset["freqs"]
    total_num_chars = sum(freqs)
    chars_to_remove = []

    print("Character Occurrences")
    print(f"Total characters in dataset: {total_num_chars}")
    print(50 * "-")
    print(f"{'Char'.rjust(5)} | {'Total occ'.rjust(10)} | {'% of total occ'.rjust(20)} |")
    print(50 * "-")

    for char, freq in zip(vocab, freqs):
        freq_in_percent = freq / total_num_chars * 100
        print(f"{char.rjust(5)} | {str(freq).rjust(10)} | {str(round(freq_in_percent, 3)).rjust(20)} |")

        # mark chars whose relative frequency is below the cutoff for removal
        if freq_in_percent < cutoff_freq:
            chars_to_remove.append(char)
    print(50 * "-")

    vocab = list(set(vocab) - set(chars_to_remove))

    # Wav2Vec2CTC tokenizers always have these special tokens as the first tokens (important for CTC)
    vocab = ["<pad>", "<s>", "</s>", "<unk>"] + vocab

    # create json dict
    vocab_dict = {v: k for k, v in enumerate(list(vocab))}

    # replace white space with the word delimiter token
    if word_delimiter_token is not None:
        vocab_dict[word_delimiter_token] = vocab_dict[" "]
        del vocab_dict[" "]

    return vocab_dict


# Note that the function accepts the following important arg
# --cutoff_freq
# => This is very important! Lots of datasets contain "wrong" characters in the training set, e.g.
# characters that occur only a couple of times.
# By default, the CTC vocab creation would add them to the vocab even if their occurrence is negligible
# compared to the "super frequent" letters. We can treat such characters as "errors" or irrelevant to the
# dataset, so we should delete them from the vocab. During training, they are then simply classified as
# unknown tokens, which the model can handle.
# In this script, we deploy a mechanism to remove all chars whose freq in % is below a certain threshold.
# We FIX this threshold for all datasets (i.e. dataset-agnostic).
vocab_dict = create_vocabulary_from_data(dataset, cutoff_freq=cutoff_freq)

# save vocab dict to a temporary directory so it can be loaded into the tokenizer
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "vocab.json"), "w") as file:
        json.dump(vocab_dict, file)

    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tmp)

# push tokenizer to the Hub
tokenizer.push_to_hub(tokenizer_name)
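
# ----------------------------------------------------------------------------
# Optional sanity check (a minimal sketch added for illustration; the names
# "sample_text" / "sample_ids" and the example sentence are arbitrary and not
# part of the original workflow): encode a sample string with the freshly
# built tokenizer and inspect the result. Characters that were dropped by the
# cutoff mechanism should map to the unknown token rather than raise an error.
sample_text = "hello world"
sample_ids = tokenizer(sample_text).input_ids
print(f"sample: {sample_text!r}")
print(f"ids:    {sample_ids}")
print(f"tokens: {tokenizer.convert_ids_to_tokens(sample_ids)}")
# ----------------------------------------------------------------------------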