import os
import string
import shutil
from itertools import permutations, chain
from collections import defaultdict
from tqdm import tqdm
import sys

INDIC_LANGS = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
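# ISO 639-1 codes for the 11 Indic languages covered: Assamese, Bengali, Gujarati,
# Hindi, Kannada, Malayalam, Marathi, Odia, Punjabi, Tamil, Telugu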
# we will test the training data for overlaps against all of these benchmarks
# benchmarks = ['wat2021-devtest', 'wat2020-devtest', 'wat-2018', 'wmt-news', 'ufal-ta', 'pmi']


def read_lines(path):
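    """Read all lines from path (trailing newlines kept); return [] if the file does not exist."""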
    # if path doesn't exist, return an empty list
    if not os.path.exists(path):
        return []
    with open(path, "r") as f:
        lines = f.readlines()
    return lines


def create_txt(outFile, lines):
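    """Write lines to outFile, appending a trailing newline to each line only when the lines lack one."""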
    # guard against an empty list before peeking at the first line
    add_newline = bool(lines) and "\n" not in lines[0]
    with open(outFile, "w") as outfile:
        for line in lines:
            if add_newline:
                outfile.write(line + "\n")
            else:
                outfile.write(line)


def pair_dedup_files(src_file, tgt_file):
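    """Drop duplicate (src, tgt) sentence pairs and rewrite both files in place."""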
    src_lines = read_lines(src_file)
    tgt_lines = read_lines(tgt_file)
    len_before = len(src_lines)
    src_dedupped, tgt_dedupped = pair_dedup_lists(src_lines, tgt_lines)
    len_after = len(src_dedupped)
    num_duplicates = len_before - len_after
    print(f"Dropped {num_duplicates} duplicate pairs from {src_file}")
    create_txt(src_file, src_dedupped)
    create_txt(tgt_file, tgt_dedupped)


def pair_dedup_lists(src_list, tgt_list):
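    """De-duplicate two parallel lists pairwise via a set; the original ordering is not preserved."""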
    src_tgt = list(set(zip(src_list, tgt_list)))
    src_deduped, tgt_deduped = zip(*src_tgt)
    return src_deduped, tgt_deduped


def strip_and_normalize(line):
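    """Lowercase the line, remove spaces and strip punctuation, e.g. "Hello, World!" -> "helloworld"."""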
    # one of the fastest ways to apply an exclusion list and remove those
    # characters from a string:
    # https://towardsdatascience.com/how-to-efficiently-remove-punctuations-from-a-string-899ad4a059fb
    exclist = string.punctuation + "\u0964"  # \u0964 is the Devanagari danda
    table_ = str.maketrans("", "", exclist)
    line = line.replace(" ", "").lower()
    # don't use this method, it is painfully slow
    # line = "".join([i for i in line if i not in string.punctuation])
    line = line.translate(table_)
    return line


def expand_tupled_list(list_of_tuples):
    # convert a list of tuples into two lists
    # https://stackoverflow.com/questions/8081545/how-to-convert-list-of-tuples-to-multiple-lists
    # [(en, as), (as, bn), (bn, gu)] -> [en, as, bn], [as, bn, gu]
    list_a, list_b = map(list, zip(*list_of_tuples))
    return list_a, list_b


def get_src_tgt_lang_lists(many2many=False):
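    """Return (SRC_LANGS, TGT_LANGS): en->Indic by default, or all languages in both directions when many2many is True."""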
    if many2many is False:
        SRC_LANGS = ["en"]
        TGT_LANGS = INDIC_LANGS
    else:
        all_languages = INDIC_LANGS + ["en"]
        # lang_pairs = list(permutations(all_languages, 2))
        SRC_LANGS, TGT_LANGS = all_languages, all_languages
    return SRC_LANGS, TGT_LANGS


def normalize_and_gather_all_benchmarks(devtest_dir, many2many=False):
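    """Gather the dev and test lines of every benchmark under devtest_dir, normalize them,
    and de-duplicate them per language pair."""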
    # This is a dict of dicts of lists:
    # the first key is the lang pair, the second key is src/tgt,
    # and the values are the devtest lines.
    # So devtest_pairs_normalized[en-as][src] stores the src (en) lines
    # and devtest_pairs_normalized[en-as][tgt] stores the tgt (as) lines.
    devtest_pairs_normalized = defaultdict(lambda: defaultdict(list))
    SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
    benchmarks = os.listdir(devtest_dir)
    for dataset in benchmarks:
        for src_lang in SRC_LANGS:
            for tgt_lang in TGT_LANGS:
                if src_lang == tgt_lang:
                    continue
                if dataset == "wat2021-devtest":
                    # the wat2021 dev and test sets have a different folder structure
                    src_dev = read_lines(f"{devtest_dir}/{dataset}/dev.{src_lang}")
                    tgt_dev = read_lines(f"{devtest_dir}/{dataset}/dev.{tgt_lang}")
                    src_test = read_lines(f"{devtest_dir}/{dataset}/test.{src_lang}")
                    tgt_test = read_lines(f"{devtest_dir}/{dataset}/test.{tgt_lang}")
                else:
                    src_dev = read_lines(
                        f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/dev.{src_lang}"
                    )
                    tgt_dev = read_lines(
                        f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/dev.{tgt_lang}"
                    )
                    src_test = read_lines(
                        f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/test.{src_lang}"
                    )
                    tgt_test = read_lines(
                        f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}/test.{tgt_lang}"
                    )
                # if the tgt_pair data doesn't exist for a particular test set,
                # it will be an empty list
                if tgt_test == [] or tgt_dev == []:
                    # print(f'{dataset} does not have {src_lang}-{tgt_lang} data')
                    continue
                # combine the dev and test sets into one
                src_devtest = src_dev + src_test
                tgt_devtest = tgt_dev + tgt_test
                src_devtest = [strip_and_normalize(line) for line in src_devtest]
                tgt_devtest = [strip_and_normalize(line) for line in tgt_devtest]
                devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"].extend(
                    src_devtest
                )
                devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"].extend(
                    tgt_devtest
                )
    # dedup the merged benchmark datasets
    for src_lang in SRC_LANGS:
        for tgt_lang in TGT_LANGS:
            if src_lang == tgt_lang:
                continue
            src_devtest, tgt_devtest = (
                devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"],
                devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"],
            )
            # if no devtest data exists for the src-tgt pair, skip it
            if src_devtest == [] or tgt_devtest == []:
                continue
            src_devtest, tgt_devtest = pair_dedup_lists(src_devtest, tgt_devtest)
            (
                devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["src"],
                devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]["tgt"],
            ) = (
                src_devtest,
                tgt_devtest,
            )
    return devtest_pairs_normalized


def remove_train_devtest_overlaps(train_dir, devtest_dir, many2many=False):
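    """Remove every training pair whose normalized src or tgt side also appears in a
    benchmark dev/test set, then overwrite the train files in place."""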
    devtest_pairs_normalized = normalize_and_gather_all_benchmarks(
        devtest_dir, many2many
    )
    SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
    if not many2many:
        all_src_sentences_normalized = []
        for key in devtest_pairs_normalized:
            all_src_sentences_normalized.extend(devtest_pairs_normalized[key]["src"])
        # remove all duplicates. Now this contains all the normalized
        # english sentences across all test benchmarks and language pairs
        all_src_sentences_normalized = list(set(all_src_sentences_normalized))
    else:
        all_src_sentences_normalized = None
    src_overlaps = []
    tgt_overlaps = []
    for src_lang in SRC_LANGS:
        for tgt_lang in TGT_LANGS:
            if src_lang == tgt_lang:
                continue
            new_src_train = []
            new_tgt_train = []
            pair = f"{src_lang}-{tgt_lang}"
            src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
            tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")
            len_before = len(src_train)
            if len_before == 0:
                continue
            src_train_normalized = [strip_and_normalize(line) for line in src_train]
            tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]
            if all_src_sentences_normalized:
                src_devtest_normalized = all_src_sentences_normalized
            else:
                src_devtest_normalized = devtest_pairs_normalized[pair]["src"]
            tgt_devtest_normalized = devtest_pairs_normalized[pair]["tgt"]
            # compute all src and tgt super-strict overlaps for a lang pair
            overlaps = set(src_train_normalized) & set(src_devtest_normalized)
            src_overlaps.extend(list(overlaps))
            overlaps = set(tgt_train_normalized) & set(tgt_devtest_normalized)
            tgt_overlaps.extend(list(overlaps))
            # dictionaries offer O(1) lookup
            src_overlaps_dict = {}
            tgt_overlaps_dict = {}
            for line in src_overlaps:
                src_overlaps_dict[line] = 1
            for line in tgt_overlaps:
                tgt_overlaps_dict[line] = 1
            # loop to drop the overlapping training pairs
            for idx, (src_line_norm, tgt_line_norm) in enumerate(
                tqdm(zip(src_train_normalized, tgt_train_normalized), total=len_before)
            ):
                if src_overlaps_dict.get(src_line_norm, None):
                    continue
                if tgt_overlaps_dict.get(tgt_line_norm, None):
                    continue
                new_src_train.append(src_train[idx])
                new_tgt_train.append(tgt_train[idx])
            len_after = len(new_src_train)
            print(
                f"Overlapping pairs removed between train and devtest for {pair}: {len_before - len_after}"
            )
            print(f"Saving new files at {train_dir}/{pair}/")
            create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
            create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)


if __name__ == "__main__":
    train_data_dir = sys.argv[1]
    # the benchmarks directory should contain all the test sets
    devtest_data_dir = sys.argv[2]
    # optional third argument: pass "true" to enable many-to-many mode
    many2many = len(sys.argv) == 4 and sys.argv[3].lower() == "true"
    remove_train_devtest_overlaps(train_data_dir, devtest_data_dir, many2many)
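
# Example invocation (script and directory names are placeholders; the expected layout is
# train_dir/<src>-<tgt>/train.<lang> and devtest_dir/<benchmark>/...):
#   python remove_overlaps.py ./train ./benchmarks        # en-centric (default)
#   python remove_overlaps.py ./train ./benchmarks true   # many-to-many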