File size: 1,503 Bytes
e50fe35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from tqdm import tqdm
import sys


def remove_large_sentences(src_path, tgt_path):
    count = 0
    new_src_lines = []
    new_tgt_lines = []
    src_num_lines = sum(1 for line in open(src_path, "r", encoding="utf-8"))
    tgt_num_lines = sum(1 for line in open(tgt_path, "r", encoding="utf-8"))
    assert src_num_lines == tgt_num_lines
    with open(src_path, encoding="utf-8") as f1, open(tgt_path, encoding="utf-8") as f2:
        for src_line, tgt_line in tqdm(zip(f1, f2), total=src_num_lines):
            src_tokens = src_line.strip().split(" ")
            tgt_tokens = tgt_line.strip().split(" ")
            if len(src_tokens) > 200 or len(tgt_tokens) > 200:
                count += 1
                continue
            new_src_lines.append(src_line)
            new_tgt_lines.append(tgt_line)
    return count, new_src_lines, new_tgt_lines


def create_txt(outFile, lines, add_newline=False):
    outfile = open("{0}".format(outFile), "w", encoding="utf-8")
    for line in lines:
        if add_newline:
            outfile.write(line + "\n")
        else:
            outfile.write(line)
    outfile.close()


if __name__ == "__main__":

    src_path = sys.argv[1]
    tgt_path = sys.argv[2]
    new_src_path = sys.argv[3]
    new_tgt_path = sys.argv[4]

    count, new_src_lines, new_tgt_lines = remove_large_sentences(src_path, tgt_path)
    print(f'{count} lines removed due to seq_len > 200')
    create_txt(new_src_path, new_src_lines)
    create_txt(new_tgt_path, new_tgt_lines)