import argparse
import os
import random
import re
import subprocess

from nemo.collections.nlp.data.token_classification.token_classification_utils import create_text_and_labels
from nemo.utils import logging

URL = {'tatoeba': 'https://downloads.tatoeba.org/exports/sentences.csv'}

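# Illustrative command-line usage; the script file name and the argument values are
# assumptions for demonstration, not prescribed by this file:
#   python get_tatoeba_data.py --data_dir=./tatoeba_data --num_samples=-1 \
#       --percent_to_cut=0.1 --num_lines_to_combine=2 --percent_dev=0.2 --clean_dir

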
def __maybe_download_file(destination: str, source: str):
    """
    Downloads the source to the destination if the destination does not exist;
    otherwise the download is skipped.
    Args:
        destination: local filepath
        source: name of the resource to download (a key in the URL dict)
    """
    source = URL[source]
    if not os.path.exists(destination):
        logging.info(f'Downloading {source} to {destination}')
        subprocess.run(['wget', '-O', destination, source])
    else:
        logging.info(f'{destination} found. Skipping download')


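# Note: the download above shells out to the `wget` binary, which therefore must be
# available on PATH. A pure-Python fallback (an assumption for illustration, not part
# of this script) could be:
#   import urllib.request
#   urllib.request.urlretrieve(source, destination)

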
def __process_english_sentences(
    in_file: str, out_file: str, percent_to_cut: float = 0, num_to_combine: int = 1, num_samples: int = -1
):
    """
    Extract English sentences from the Tatoeba dataset that contain only letters
    and the punctuation marks , . ? and the apostrophe.
    Optionally chop sentences in the middle and combine several sentences into a single example.
    Args:
        in_file: local filepath to the tatoeba dataset.
            Format: id [TAB] region_name [TAB] sentence,
            for example: "1276\teng\tLet's try something.\n"
        out_file: local filepath to the clean dataset
        percent_to_cut: Percent of sentences to cut in the middle
            to get examples of incomplete sentences.
            This could be useful since ASR output does not always
            represent a complete sentence
        num_to_combine: Number of sentences to combine into
            a single example
        num_samples: Number of samples in the final dataset, -1 to use all sentences
    """
    if not os.path.exists(in_file):
        raise FileNotFoundError(f'{in_file} not found.')

    in_file = open(in_file, 'r')
    out_file = open(out_file, 'w')
    lines_to_combine = []
    samples_count = 0

    for line in in_file:
        line = line.split('\t')

        # keep only English sentences
        if line[1] == 'eng':
            line = line[2].strip()
            # keep sentences that start with a capital letter and contain only
            # letters, whitespace and the punctuation marks , . ? '
            if re.match(r"^[A-Z][A-Za-z.,'?\s]+$", line):

                # with probability percent_to_cut, keep only the first half of the
                # sentence to imitate incomplete ASR output
                if percent_to_cut > 0:
                    line = line.split()
                    if random.random() < percent_to_cut:
                        line = line[: len(line) // 2]
                    line = ' '.join(line)

                # once enough sentences have been collected, write them out
                # as a single example and start a new one
                if len(lines_to_combine) >= num_to_combine:
                    if samples_count == num_samples:
                        return
                    out_file.write(' '.join(lines_to_combine) + '\n')
                    lines_to_combine = []
                    samples_count += 1
                lines_to_combine.append(line)

    # flush the remaining sentences
    if len(lines_to_combine) > 0 and (samples_count < num_samples or num_samples < 0):
        out_file.write(' '.join(lines_to_combine) + '\n')


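# For illustration only (the ids and the second sentence are made up, not taken
# from this file): with num_to_combine=2 and percent_to_cut=0, two raw Tatoeba
# rows such as
#   "1276\teng\tLet's try something.\n"
#   "1277\teng\tI have to go to sleep.\n"
# would be cleaned and written to out_file as a single line:
#   "Let's try something. I have to go to sleep."

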
def __split_into_train_dev(in_file: str, train_file: str, dev_file: str, percent_dev: float):
    """
    Create train and dev splits of the dataset.
    Args:
        in_file: local filepath to the dataset
        train_file: local filepath to the train dataset
        dev_file: local filepath to the dev dataset
        percent_dev: Percent of the sentences to put into the dev set
    """
    if not os.path.exists(in_file):
        raise FileNotFoundError(f'{in_file} not found.')

    lines = open(in_file, 'r').readlines()
    train_file = open(train_file, 'w')
    dev_file = open(dev_file, 'w')

    dev_size = int(len(lines) * percent_dev)
    # lines already end with '\n', so concatenate them without an extra separator;
    # indexing with len(lines) - dev_size also handles dev_size == 0 correctly
    train_file.write(''.join(lines[: len(lines) - dev_size]))
    dev_file.write(''.join(lines[len(lines) - dev_size :]))
    train_file.close()
    dev_file.close()


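# For illustration (numbers are made up): with 1,000 cleaned sentences and
# percent_dev=0.2, dev_size = int(1000 * 0.2) = 200, so the first 800 lines go to
# the train file and the last 200 lines go to the dev file.

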
def __delete_file(file_to_del: str):
    """
    Deletes the file if it exists.
    Args:
        file_to_del: local filepath to the file to delete
    """
    if os.path.exists(file_to_del):
        os.remove(file_to_del)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Prepare tatoeba dataset')
    parser.add_argument("--data_dir", required=True, type=str)
    parser.add_argument("--dataset", default='tatoeba', type=str)
    parser.add_argument("--num_samples", default=-1, type=int, help='-1 to use the whole dataset')
    parser.add_argument("--percent_to_cut", default=0, type=float, help='Percent of sentences to cut in the middle')
    parser.add_argument(
        "--num_lines_to_combine", default=1, type=int, help='Number of lines to combine into a single example'
    )
    parser.add_argument("--percent_dev", default=0.2, type=float, help='Size of the dev set, float')
    parser.add_argument("--clean_dir", action='store_true')
    args = parser.parse_args()

    if not os.path.exists(args.data_dir):
        os.makedirs(args.data_dir)

    if args.dataset != 'tatoeba':
        raise ValueError("Unsupported dataset.")

    logging.info('Downloading tatoeba dataset')
    tatoeba_dataset = os.path.join(args.data_dir, 'sentences.csv')
    __maybe_download_file(tatoeba_dataset, args.dataset)

    logging.info('Processing English sentences...')
    clean_eng_sentences = os.path.join(args.data_dir, 'clean_eng_sentences.txt')
    __process_english_sentences(
        tatoeba_dataset, clean_eng_sentences, args.percent_to_cut, args.num_lines_to_combine, args.num_samples
    )

    train_file = os.path.join(args.data_dir, 'train.txt')
    dev_file = os.path.join(args.data_dir, 'dev.txt')

    logging.info(f'Splitting the {args.dataset} dataset into train and dev sets and creating labels and text files')
    __split_into_train_dev(clean_eng_sentences, train_file, dev_file, args.percent_dev)

    # create_text_and_labels generates the text and label files used for training
    logging.info('Creating text and label files for training')
    create_text_and_labels(args.data_dir, train_file)
    create_text_and_labels(args.data_dir, dev_file)

    if args.clean_dir:
        logging.info(f'Cleaning up {args.data_dir}')
        __delete_file(clean_eng_sentences)
        __delete_file(tatoeba_dataset)
        __delete_file(train_file)
        __delete_file(dev_file)
    logging.info(f'Processing of the {args.dataset} dataset is complete')