# IndicTrans-MultilingualTranslation/scripts/remove_train_devtest_overlaps.py
# Uploaded by hussain-shk: "Duplicate from ai4bharat/IndicTrans-MultilingualTranslation"
# (commit ef23634, 10.1 kB).
import os
import string
import shutil
from itertools import permutations, chain
from collections import defaultdict
from tqdm import tqdm
import sys
# Codes of the Indic languages handled by this script. English ("en") is added
# separately where needed (see get_src_tgt_lang_lists).
INDIC_LANGS = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
# The training data is checked for overlaps against every benchmark found in the
# devtest directory (each entry of that directory is treated as one benchmark), e.g.:
# benchmarks = ['wat2021-devtest', 'wat2020-devtest', 'wat-2018', 'wmt-news', 'ufal-ta', 'pmi']
def read_lines(path):
    """Return the lines of the text file at ``path``.

    Each returned line keeps its trailing newline, if it has one. A missing
    file is not an error: an empty list is returned instead. Callers use the
    empty list to detect that a benchmark / language pair has no data.

    Args:
        path: path of the file to read.

    Returns:
        list[str]: the file's lines, or ``[]`` when ``path`` does not exist.
    """
    if not os.path.exists(path):
        return []
    # The corpora hold Indic-script text. Decode explicitly instead of relying
    # on the platform's locale encoding (e.g. cp1252 on Windows).
    # Assumes the data files are UTF-8 encoded -- confirm for new corpora.
    with open(path, "r", encoding="utf-8") as f:
        return f.readlines()
def create_txt(outFile, lines):
    """Write ``lines`` to ``outFile``, overwriting any existing content.

    Only the first line is inspected. If it already contains a newline, the
    lines are assumed to carry their own terminators (as returned by
    ``read_lines``) and are written verbatim. Otherwise a newline is appended
    to every line.

    Args:
        outFile: destination path.
        lines: sequence of strings. It may be empty (e.g. when every training
            line was filtered out); an empty file is then written.
    """
    # Guard lines[0]: an empty sequence has no first line to inspect.
    add_newline = bool(lines) and "\n" not in lines[0]
    with open(outFile, "w", encoding="utf-8") as outfile:
        if add_newline:
            outfile.writelines(line + "\n" for line in lines)
        else:
            outfile.writelines(lines)
def pair_dedup_files(src_file, tgt_file):
    """Remove duplicate (source, target) pairs from a parallel corpus, in place.

    Both files are read. Pairs are formed line by line and deduplicated with
    ``pair_dedup_lists``. The number of dropped pairs is printed, and both
    files are overwritten with the surviving pairs.

    Args:
        src_file: path of the source-side file.
        tgt_file: path of the target-side file, paired with ``src_file`` by
            line index.
    """
    original_src = read_lines(src_file)
    unique_src, unique_tgt = pair_dedup_lists(original_src, read_lines(tgt_file))
    dropped = len(original_src) - len(unique_src)
    print(f"Dropped duplicate pairs in {src_file} Num duplicates -> {dropped}")
    for path, content in ((src_file, unique_src), (tgt_file, unique_tgt)):
        create_txt(path, content)
def pair_dedup_lists(src_list, tgt_list):
    """Drop duplicate (src, tgt) pairs from two index-aligned sequences.

    Pairs are compared exactly, with no normalization. The first occurrence of
    each pair is kept and the original relative order is preserved. The
    output is therefore deterministic across runs, unlike iterating a ``set``
    of strings, whose order depends on hash randomization. As with ``zip``,
    items beyond the length of the shorter input are ignored.

    Args:
        src_list: source-side items.
        tgt_list: target-side items, paired with ``src_list`` by index.

    Returns:
        tuple: ``(src_deduped, tgt_deduped)``, two tuples of equal length.
        Both are empty tuples when there are no pairs.
    """
    # A dict keeps insertion order, so its keys act as an ordered set of pairs.
    unique_pairs = dict.fromkeys(zip(src_list, tgt_list))
    if not unique_pairs:
        # zip(*<empty>) yields nothing, and unpacking it would raise ValueError.
        return (), ()
    src_deduped, tgt_deduped = zip(*unique_pairs)
    return src_deduped, tgt_deduped
# Characters deleted by strip_and_normalize:
#   - ASCII punctuation;
#   - U+0964 (DEVANAGARI DANDA);
#   - line terminators, so "foo\n" and a final "foo" without newline compare equal.
# Built once at import time, because the function runs on every line of every corpus.
_NORMALIZE_TABLE = str.maketrans("", "", string.punctuation + "\u0964" + "\r\n")


def strip_and_normalize(line):
    """Return a normalized form of ``line`` for strict overlap comparison.

    Spaces are removed, the text is lower-cased, and punctuation, the danda
    and line terminators are deleted. Two lines are considered overlapping
    when their normalized forms are equal.

    Args:
        line: raw text line (may include its trailing newline).

    Returns:
        str: the normalized line.
    """
    # str.translate deletes every excluded character in a single C-level pass.
    # Filtering character by character in Python is much slower. See
    # https://towardsdatascience.com/how-to-efficiently-remove-punctuations-from-a-string-899ad4a059fb
    return line.replace(" ", "").lower().translate(_NORMALIZE_TABLE)
def expand_tupled_list(list_of_tuples):
    """Split a sequence of 2-tuples into two lists (the inverse of ``zip``).

    Example: [(en, as), (as, bn), (bn, gu)] -> [en, as, bn], [as, bn, gu]
    (see https://stackoverflow.com/questions/8081545/how-to-convert-list-of-tuples-to-multiple-lists)

    Args:
        list_of_tuples: iterable of ``(a, b)`` pairs.

    Returns:
        tuple[list, list]: the first elements and the second elements.
        Both lists are empty when the input is empty.
    """
    # Materialize first, so one-shot iterators can be checked for emptiness.
    pairs = list(list_of_tuples)
    if not pairs:
        # zip(*[]) yields nothing, so the two-name unpacking below would fail.
        return [], []
    list_a, list_b = map(list, zip(*pairs))
    return list_a, list_b
def get_src_tgt_lang_lists(many2many=False):
    """Return the source and target language lists to iterate over.

    Args:
        many2many: if falsy, only English -> Indic directions are produced:
            sources ``["en"]`` and targets ``INDIC_LANGS``. If truthy, every
            language in ``INDIC_LANGS + ["en"]`` appears on both sides. Callers
            skip the ``src == tgt`` combinations themselves.

    Returns:
        tuple[list[str], list[str]]: ``(SRC_LANGS, TGT_LANGS)``. The returned
        lists may be ``INDIC_LANGS`` itself, or one list shared by both
        positions, so callers must not mutate them.
    """
    # Test truthiness rather than ``is False``. This agrees with the
    # ``if not many2many`` check in remove_train_devtest_overlaps for values
    # such as None or 0.
    if not many2many:
        return ["en"], INDIC_LANGS
    all_languages = INDIC_LANGS + ["en"]
    return all_languages, all_languages
def _read_benchmark_split(devtest_dir, dataset, src_lang, tgt_lang, split):
    """Read one split ("dev" or "test") of a benchmark for a language pair.

    wat2021-devtest stores one file per language (``<split>.<lang>``) directly
    in its folder. Every other benchmark uses a ``<src>-<tgt>/<split>.<lang>``
    sub-folder per language pair.

    Returns:
        tuple[list[str], list[str]]: raw source and target lines. A side is an
        empty list when its file does not exist.
    """
    if dataset == "wat2021-devtest":
        split_dir = f"{devtest_dir}/{dataset}"
    else:
        split_dir = f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}"
    src_lines = read_lines(f"{split_dir}/{split}.{src_lang}")
    tgt_lines = read_lines(f"{split_dir}/{split}.{tgt_lang}")
    return src_lines, tgt_lines


def normalize_and_gather_all_benchmarks(devtest_dir, many2many=False):
    """Collect the normalized dev+test sentences of every benchmark, per direction.

    Every entry of ``devtest_dir`` is treated as a benchmark. For each
    (src_lang, tgt_lang) direction given by ``get_src_tgt_lang_lists``, the
    dev and test splits are read and normalized with ``strip_and_normalize``,
    then appended to that direction's lists. Finally, duplicate (src, tgt)
    pairs are dropped with ``pair_dedup_lists``.

    A split is used only when BOTH its source and target files exist and are
    non-empty, which keeps a direction's src/tgt lists index-aligned. Dev and
    test are handled independently: a benchmark that ships only one of them
    still contributes that split.

    Args:
        devtest_dir: directory containing one entry per benchmark.
        many2many: forwarded to ``get_src_tgt_lang_lists``.

    Returns:
        defaultdict: ``result["<src>-<tgt>"]["src"]`` and ``["tgt"]`` hold the
        normalized source and target sentences. Every direction has an entry.
        Directions with benchmark data hold pair-deduplicated tuples; the rest
        hold empty lists.
    """
    devtest_pairs_normalized = defaultdict(lambda: defaultdict(list))
    SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
    # sorted(): os.listdir order is arbitrary; sorting makes the output reproducible.
    for dataset in sorted(os.listdir(devtest_dir)):
        for src_lang in SRC_LANGS:
            for tgt_lang in TGT_LANGS:
                if src_lang == tgt_lang:
                    continue
                pair = f"{src_lang}-{tgt_lang}"
                for split in ("dev", "test"):
                    src_lines, tgt_lines = _read_benchmark_split(
                        devtest_dir, dataset, src_lang, tgt_lang, split
                    )
                    # Skip splits missing on either side. Appending only one
                    # side would shift the src/tgt alignment for everything
                    # appended afterwards.
                    if not src_lines or not tgt_lines:
                        continue
                    # NOTE(review): src/tgt files of different lengths are not
                    # detected here; pair_dedup_lists will truncate to the shorter.
                    devtest_pairs_normalized[pair]["src"].extend(
                        strip_and_normalize(line) for line in src_lines
                    )
                    devtest_pairs_normalized[pair]["tgt"].extend(
                        strip_and_normalize(line) for line in tgt_lines
                    )
    # Dedup the merged benchmark data. Every direction is visited (and thus
    # gets an entry), as before.
    for src_lang in SRC_LANGS:
        for tgt_lang in TGT_LANGS:
            if src_lang == tgt_lang:
                continue
            sides = devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]
            # No devtest data for this direction: leave the empty lists.
            if not sides["src"] or not sides["tgt"]:
                continue
            sides["src"], sides["tgt"] = pair_dedup_lists(sides["src"], sides["tgt"])
    return devtest_pairs_normalized
def remove_train_devtest_overlaps(train_dir, devtest_dir, many2many=False):
    """Delete training pairs that overlap with any dev/test benchmark, in place.

    Each direction ``<src>-<tgt>`` comes from ``get_src_tgt_lang_lists``. For
    each one:
      1. ``<train_dir>/<src>-<tgt>/train.<src>`` and ``train.<tgt>`` are read.
         Directions without training data are skipped.
      2. Every line is normalized with ``strip_and_normalize``.
      3. A training pair is dropped when its normalized source OR target is
         a known overlap.
      4. The filtered lines overwrite the original files.

    In the default (English -> Indic) mode, source sides are compared against
    the English sentences of the benchmarks of ALL directions, not only the
    direction's own benchmarks.

    Note: the sets of overlapping sentences accumulate across directions. A
    sentence found to overlap for one direction is therefore also removed
    from every later direction that contains it.
    NOTE(review): confirm this cross-direction removal is intended.

    Args:
        train_dir: directory with one ``<src>-<tgt>`` sub-directory per direction.
        devtest_dir: directory with one entry per benchmark.
        many2many: if truthy, consider all directions among INDIC_LANGS + "en".
    """
    devtest_pairs_normalized = normalize_and_gather_all_benchmarks(
        devtest_dir, many2many
    )
    SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)
    if not many2many:
        # All normalized English sentences of every benchmark, across all pairs.
        # Built as a set once instead of being re-converted for every pair.
        all_src_sentences_normalized = set()
        for sides in devtest_pairs_normalized.values():
            all_src_sentences_normalized.update(sides["src"])
    else:
        all_src_sentences_normalized = None

    # Accumulated over all directions (see the note in the docstring).
    src_overlaps = set()
    tgt_overlaps = set()
    for src_lang in SRC_LANGS:
        for tgt_lang in TGT_LANGS:
            if src_lang == tgt_lang:
                continue
            pair = f"{src_lang}-{tgt_lang}"
            src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
            tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")
            len_before = len(src_train)
            if len_before == 0:
                continue
            src_train_normalized = [strip_and_normalize(line) for line in src_train]
            tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]

            if all_src_sentences_normalized:
                src_devtest_normalized = all_src_sentences_normalized
            else:
                src_devtest_normalized = devtest_pairs_normalized[pair]["src"]
            tgt_devtest_normalized = devtest_pairs_normalized[pair]["tgt"]

            # Super-strict (normalized exact-match) overlaps for this direction.
            src_overlaps.update(
                set(src_train_normalized).intersection(src_devtest_normalized)
            )
            tgt_overlaps.update(
                set(tgt_train_normalized).intersection(tgt_devtest_normalized)
            )

            # Keep the raw lines of every pair with no overlapping side.
            # NOTE(review): zip stops at the shorter file, so a src/tgt line
            # count mismatch silently truncates the rewritten files.
            new_src_train = []
            new_tgt_train = []
            for src_line, tgt_line, src_norm, tgt_norm in tqdm(
                zip(src_train, tgt_train, src_train_normalized, tgt_train_normalized),
                total=len_before,
            ):
                if src_norm in src_overlaps or tgt_norm in tgt_overlaps:
                    continue
                new_src_train.append(src_line)
                new_tgt_train.append(tgt_line)

            len_after = len(new_src_train)
            print(
                f"Detected overlaps between train and devetest for {pair} is {len_before - len_after}"
            )
            print(f"saving new files at {train_dir}/{pair}/")
            create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
            create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)
if __name__ == "__main__":
    # Usage:
    #   python remove_train_devtest_overlaps.py <train_dir> <devtest_dir> [many2many]
    # The optional third argument enables many-to-many mode when it equals
    # "true" (case-insensitive). Any other value keeps the English -> Indic mode.
    if len(sys.argv) not in (3, 4):
        sys.exit(
            f"usage: {sys.argv[0]} <train_data_dir> <devtest_data_dir> [many2many: true|false]"
        )
    train_data_dir = sys.argv[1]
    # The benchmarks directory should contain all the test sets.
    devtest_data_dir = sys.argv[2]
    # With 4 entries, the flag is sys.argv[3] (sys.argv[0] is the script path).
    # Indexing sys.argv[4] here would raise IndexError.
    many2many = len(sys.argv) == 4 and sys.argv[3].lower() == "true"
    remove_train_devtest_overlaps(train_data_dir, devtest_data_dir, many2many)