Spaces:
Runtime error
Runtime error
File size: 10,106 Bytes
e50fe35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 |
import os
import string
import shutil
from itertools import permutations, chain
from collections import defaultdict
from tqdm import tqdm
import sys
# Language codes of the Indic languages handled by this script. The codes are
# used verbatim as file suffixes (e.g. train.hi) and, joined with a hyphen, as
# language-pair directory names (e.g. en-hi).
INDIC_LANGS = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
# Training data is checked for overlap against every benchmark directory found
# inside the devtest directory passed on the command line, for example:
# benchmarks = ['wat2021-devtest', 'wat2020-devtest', 'wat-2018', 'wmt-news', 'ufal-ta', 'pmi']
def read_lines(path):
    """Return the lines of the text file at *path*, or ``[]`` if it is missing.

    Lines are returned exactly as ``readlines`` yields them, i.e. each one
    keeps its trailing newline (the last line may lack one).

    The file is decoded as UTF-8 explicitly: the corpora contain Indic-script
    text, and relying on the platform's default encoding can fail or corrupt
    the text on systems whose locale is not UTF-8.
    """
    # A missing file is not an error: callers treat an empty list as
    # "no data available for this language / split".
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as f:
        return f.readlines()
def create_txt(outFile, lines):
    """Write *lines* to *outFile*, one entry per line, overwriting the file.

    Each entry that does not already end with a newline gets one appended;
    entries that already end with a newline are written unchanged. The
    decision is made per entry, so a single unterminated entry (for example
    the last line of a file that had no final newline) is never glued to the
    entry that follows it.

    An empty *lines* produces an empty file. The file is written as UTF-8.
    """
    with open(outFile, "w", encoding="utf-8") as outfile:
        outfile.writelines(
            line if line.endswith("\n") else line + "\n" for line in lines
        )
def pair_dedup_files(src_file, tgt_file):
    """Drop duplicate (source, target) line pairs from two parallel files in place.

    Both files are read, their lines are paired up and de-duplicated by
    ``pair_dedup_lists``, and each file is then overwritten with its side of
    the surviving pairs. The number of dropped pairs is printed.

    If either file is missing or empty there is nothing to pair up; the
    function reports this and returns without writing anything, instead of
    failing while trying to unpack an empty result.
    """
    src_lines = read_lines(src_file)
    tgt_lines = read_lines(tgt_file)
    if not src_lines or not tgt_lines:
        print(f"Nothing to dedup for {src_file} and {tgt_file}: missing or empty file")
        return

    len_before = len(src_lines)
    src_dedupped, tgt_dedupped = pair_dedup_lists(src_lines, tgt_lines)
    num_duplicates = len_before - len(src_dedupped)
    print(f"Dropped duplicate pairs in {src_file} Num duplicates -> {num_duplicates}")
    create_txt(src_file, src_dedupped)
    create_txt(tgt_file, tgt_dedupped)
def pair_dedup_lists(src_list, tgt_list):
    """Remove duplicate (src, tgt) pairs from two parallel sequences.

    The sequences are zipped into pairs (extra items in the longer one are
    ignored), exact duplicate pairs are dropped, and the survivors are split
    back into two parallel tuples.

    The first occurrence of every pair is kept and the input order is
    preserved, so the result is deterministic from run to run.

    Returns:
        ``(src_deduped, tgt_deduped)`` as two tuples of equal length; two
        empty tuples when there are no pairs at all.
    """
    # dict keys are unique and remember insertion order, which gives an
    # order-preserving de-duplication in a single pass.
    unique_pairs = list(dict.fromkeys(zip(src_list, tgt_list)))
    if not unique_pairs:
        # zip(*[]) yields nothing, so the two-way unpacking below would raise.
        return (), ()
    src_deduped, tgt_deduped = zip(*unique_pairs)
    return src_deduped, tgt_deduped
# Characters deleted by strip_and_normalize: ASCII punctuation, the danda
# (U+0964), spaces, and line terminators. The table is built once at import
# time because the function is called for every line of every corpus.
# str.translate with a deletion table is much faster than filtering the
# string character by character in Python:
# https://towardsdatascience.com/how-to-efficiently-remove-punctuations-from-a-string-899ad4a059fb
_STRIP_TABLE = str.maketrans("", "", string.punctuation + "\u0964" + " \r\n")


def strip_and_normalize(line):
    """Return *line* lowercased, with spaces, punctuation and newlines removed.

    The result is only meant to be used as a comparison key for overlap
    detection. Line terminators are removed as well, so that a line read
    from the very end of a file without a final newline still compares equal
    to the same sentence read with its trailing newline.
    """
    return line.lower().translate(_STRIP_TABLE)
def expand_tupled_list(list_of_tuples):
    """Split a list of 2-tuples into two parallel lists.

    Example:
        [(en, as), (as, bn), (bn, gu)] -> [en, as, bn], [as, bn, gu]

    An empty input yields two empty lists.

    Raises:
        ValueError: if an element does not unpack into exactly two items.
    """
    list_a, list_b = [], []
    for first, second in list_of_tuples:
        list_a.append(first)
        list_b.append(second)
    return list_a, list_b
def get_src_tgt_lang_lists(many2many=False):
    """Return the ``(source_languages, target_languages)`` lists to iterate over.

    Args:
        many2many: when falsy, the only source language is ``"en"`` and the
            targets are the languages in ``INDIC_LANGS``. When truthy, both
            lists contain every language in ``INDIC_LANGS`` followed by
            ``"en"``; callers are expected to skip pairs whose source and
            target are the same language.

    Returns:
        Two freshly built lists. Neither is the module-level ``INDIC_LANGS``
        object, and the two lists are distinct objects, so a caller can
        modify one without affecting the constant or the other list.
    """
    if not many2many:
        return ["en"], list(INDIC_LANGS)
    all_languages = INDIC_LANGS + ["en"]
    return all_languages, list(all_languages)
def normalize_and_gather_all_benchmarks(devtest_dir, many2many=False):
    """Collect the normalized dev+test sentences of every benchmark, per language pair.

    Every entry of *devtest_dir* is treated as a benchmark. For each
    benchmark and each source/target pair, the dev and test files of both
    languages are read, concatenated (dev first), normalized with
    ``strip_and_normalize`` and appended to that pair's lists. Once all
    benchmarks are merged, each pair's lists are passed through
    ``pair_dedup_lists``.

    Returns:
        A dict of dicts keyed first by ``"<src>-<tgt>"`` and then by
        ``"src"`` / ``"tgt"``, e.g. ``result["en-as"]["src"]`` holds the
        normalized English lines and ``result["en-as"]["tgt"]`` the
        normalized Assamese lines.
    """
    devtest_pairs_normalized = defaultdict(lambda: defaultdict(list))
    src_langs, tgt_langs = get_src_tgt_lang_lists(many2many)
    # every (src, tgt) combination except a language paired with itself
    lang_pairs = [
        (src, tgt) for src in src_langs for tgt in tgt_langs if src != tgt
    ]

    for dataset in os.listdir(devtest_dir):
        for src_lang, tgt_lang in lang_pairs:
            if dataset == "wat2021-devtest":
                # wat2021 keeps all languages in one folder instead of one
                # sub-folder per language pair
                prefix = f"{devtest_dir}/{dataset}"
            else:
                prefix = f"{devtest_dir}/{dataset}/{src_lang}-{tgt_lang}"

            src_dev = read_lines(f"{prefix}/dev.{src_lang}")
            tgt_dev = read_lines(f"{prefix}/dev.{tgt_lang}")
            src_test = read_lines(f"{prefix}/test.{src_lang}")
            tgt_test = read_lines(f"{prefix}/test.{tgt_lang}")

            # read_lines returns an empty list for a missing file; a benchmark
            # lacking the target dev or test file for this pair is skipped
            if not tgt_test or not tgt_dev:
                continue

            # dev and test are merged into a single devtest set
            merged = devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]
            merged["src"].extend(
                strip_and_normalize(line) for line in src_dev + src_test
            )
            merged["tgt"].extend(
                strip_and_normalize(line) for line in tgt_dev + tgt_test
            )

    # dedup the merged benchmark data of each language pair
    for src_lang, tgt_lang in lang_pairs:
        merged = devtest_pairs_normalized[f"{src_lang}-{tgt_lang}"]
        src_devtest, tgt_devtest = merged["src"], merged["tgt"]
        # nothing was gathered for this pair
        if not src_devtest or not tgt_devtest:
            continue
        merged["src"], merged["tgt"] = pair_dedup_lists(src_devtest, tgt_devtest)

    return devtest_pairs_normalized
def remove_train_devtest_overlaps(train_dir, devtest_dir, many2many=False):
    """Remove from the training data every pair that overlaps with a benchmark.

    For each language pair, ``{train_dir}/{src}-{tgt}/train.{src}`` and
    ``train.{tgt}`` are read and normalized with ``strip_and_normalize``. A
    training pair is dropped when its normalized source side OR its
    normalized target side matches a detected overlap. The two training
    files are then overwritten with the surviving lines, and the number of
    removed pairs is printed.

    Args:
        train_dir: directory containing one ``<src>-<tgt>`` folder per pair.
        devtest_dir: directory containing the benchmark folders.
        many2many: when falsy, source sentences are compared against the
            source sentences of ALL benchmark pairs pooled together; when
            truthy, only against the benchmark data of the same pair.

    Note:
        Overlaps accumulate over the language pairs processed so far: a
        sentence detected as an overlap for an earlier pair is also removed
        from later pairs.
    """
    devtest_pairs_normalized = normalize_and_gather_all_benchmarks(
        devtest_dir, many2many
    )
    SRC_LANGS, TGT_LANGS = get_src_tgt_lang_lists(many2many)

    if not many2many:
        # all normalized source sentences of all benchmarks across all pairs,
        # kept as a set so it is built once rather than once per pair
        all_src_sentences_normalized = set()
        for pair_data in devtest_pairs_normalized.values():
            all_src_sentences_normalized.update(pair_data["src"])
    else:
        all_src_sentences_normalized = None

    # sets give O(1) membership tests and grow incrementally across pairs
    src_overlaps = set()
    tgt_overlaps = set()
    for src_lang in SRC_LANGS:
        for tgt_lang in TGT_LANGS:
            if src_lang == tgt_lang:
                continue
            pair = f"{src_lang}-{tgt_lang}"
            src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
            tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")
            len_before = len(src_train)
            if len_before == 0:
                continue

            src_train_normalized = [strip_and_normalize(line) for line in src_train]
            tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]

            if all_src_sentences_normalized:
                src_devtest_normalized = all_src_sentences_normalized
            else:
                src_devtest_normalized = set(devtest_pairs_normalized[pair]["src"])
            tgt_devtest_normalized = set(devtest_pairs_normalized[pair]["tgt"])

            # compute all src and tgt super strict overlaps for this lang pair
            src_overlaps |= src_devtest_normalized.intersection(src_train_normalized)
            tgt_overlaps |= tgt_devtest_normalized.intersection(tgt_train_normalized)

            # keep only the pairs whose two sides are both overlap-free
            new_src_train = []
            new_tgt_train = []
            rows = zip(src_train, tgt_train, src_train_normalized, tgt_train_normalized)
            for src_line, tgt_line, src_line_norm, tgt_line_norm in tqdm(
                rows, total=len_before
            ):
                if src_line_norm in src_overlaps or tgt_line_norm in tgt_overlaps:
                    continue
                new_src_train.append(src_line)
                new_tgt_train.append(tgt_line)

            len_after = len(new_src_train)
            print(
                f"Detected overlaps between train and devetest for {pair} is {len_before - len_after}"
            )
            print(f"saving new files at {train_dir}/{pair}/")
            create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
            create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)
if __name__ == "__main__":
    # usage: <script> <train_data_dir> <devtest_data_dir> [many2many]
    # The optional third argument enables many-to-many mode when it is
    # "true" (case-insensitive); any other value, or omitting it, disables it.
    if len(sys.argv) not in (3, 4):
        sys.exit(
            f"usage: {sys.argv[0]} <train_data_dir> <devtest_data_dir> [true|false]"
        )
    train_data_dir = sys.argv[1]
    # benchmarks directory should contain all the test sets
    devtest_data_dir = sys.argv[2]
    # The optional flag is the fourth argv entry, i.e. index 3.
    many2many = len(sys.argv) == 4 and sys.argv[3].lower() == "true"
    remove_train_devtest_overlaps(train_data_dir, devtest_data_dir, many2many)
|