Spaces:
Runtime error
Runtime error
from argparse import ArgumentParser | |
import os | |
import re | |
import time | |
from processors import FilesProcessor, get_text_distorter | |
from processes import ( | |
CharsRemover, | |
LengthFilter, | |
LinesSplitter, | |
LoadFile, | |
NumbersFilter, | |
OOVFilter, | |
RepeatedCharsCollapsor, | |
# SoloCharFilter, | |
SpacesRemover, | |
ValidCharsKeeper, | |
WordsFilter, | |
WordsNumberFilter, | |
CharsNormalizer, | |
TokenizerLengthFilter, | |
) | |
from helpers import load_json, save_text_file | |
from typing import Union, List | |
from pathlib import Path | |
import constants | |
import pandas as pd | |
def get_paths(
    main_dir: Union[Path, str]
) -> List[Union[Path, str]]:
    """Return the joined path of every entry directly inside ``main_dir``.

    Entries come from ``os.listdir`` (files and sub-directories alike, no
    recursion) in whatever order it yields them; each is joined onto
    ``main_dir`` with ``os.path.join``.
    """
    collected: List[Union[Path, str]] = []
    for entry_name in os.listdir(main_dir):
        collected.append(os.path.join(main_dir, entry_name))
    return collected
def get_path(
    file_path: Union[Path, str]
) -> List[Union[Path, str]]:
    """Wrap a single existing file path in a one-element list.

    The list shape matches what ``get_paths`` returns, so either can be fed
    to the files processor.

    Args:
        file_path: Path of the input file.

    Returns:
        ``[file_path]`` (the argument itself, unchanged).

    Raises:
        FileNotFoundError: If ``file_path`` is not an existing regular file
            (missing path or a directory). The message names the path.
    """
    if not os.path.isfile(file_path):
        # os.path.isfile is also False for directories, so word the message
        # to cover both cases instead of raising a bare, message-less error.
        raise FileNotFoundError(f"Not an existing file: {str(file_path)!r}")
    return [file_path]
def get_file_processor(args):
    """Create a ``FilesProcessor`` configured from parsed CLI arguments.

    Args:
        args: Namespace from ``get_argparser()``. Reads ``execlude_words_files``,
            ``sep``, ``max_rep_chars``, ``min_words``, ``max_words``,
            ``min_len`` and ``max_len``.

    Returns:
        A ``FilesProcessor`` built from the list of processing steps below,
        given in list order (how it applies them is defined in
        ``processors`` — not visible here).
    """
    banned_words = load_json(args.execlude_words_files)

    steps = [LoadFile()]
    # One splitter per separator supplied via --sep.
    steps.extend(LinesSplitter(sep=separator) for separator in args.sep)
    steps.extend([
        RepeatedCharsCollapsor(args.max_rep_chars),
        NumbersFilter(),
        WordsFilter(banned_words),
        ValidCharsKeeper(constants.VALID_CHARS),
        SpacesRemover(),
        WordsNumberFilter(args.min_words, args.max_words),
        # SoloCharFilter (after NumbersFilter) and TokenizerLengthFilter
        # (before LengthFilter) are currently disabled.
        LengthFilter(args.min_len, args.max_len),
    ])
    return FilesProcessor(steps)
def post_process(data: List[List[str]]) -> List[str]:
    """Flatten per-item line lists and drop duplicate lines.

    Args:
        data: Sequence whose items are themselves sequences of lines
            (presumably one list per processed file — confirm against
            ``FilesProcessor.run`` / ``dist_run``).

    Returns:
        The unique lines in order of first appearance. The previous
        ``list(set(...))`` returned a hash-randomized order that changed from
        run to run, which made any downstream slicing of the result
        (first half, fixed-size tail) non-reproducible.
    """
    # dict keeps insertion order, so this de-duplicates while retaining the
    # first occurrence of each line.
    unique_lines = dict.fromkeys(line for item in data for line in item)
    # OOV filtering (OOVFilter with --max_oov) is currently disabled here.
    return list(unique_lines)
# Marks stripped by remove_punctuation: ASCII ! . : ; ? plus the guillemets
# « » and the Arabic/Urdu marks ، ؛ ؟ ۔ ٫ ٪. The previous class also held a
# literal space (left over from joining two character sets), which deleted
# every space not next to a digit and glued words together.
clean_punctuation = re.compile(r"(?<!\d)[!.:;?«»،؛؟۔٫٪](?!\d)")


def remove_punctuation(text):
    """Remove punctuation from ``text`` unless the mark touches a digit.

    A mark is removed only when it is neither immediately preceded nor
    immediately followed by a digit (``\\d`` also matches Arabic-Indic
    digits for ``str`` patterns). So decimals such as ``3.14`` and percentages
    such as ``50٪`` keep their marks. Spaces are never removed.
    """
    return clean_punctuation.sub("", text)
def get_argparser():
    """Build the command-line parser for the cleaning / distortion script.

    Returns:
        ArgumentParser whose parsed namespace is consumed by
        ``get_file_processor`` and ``main``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--sep', default=[
            '\n',
            # Other candidates: '\t', '.', '،', ',', '=', ':', '-', '\\', '/'
        ], nargs='+', type=str,
        help='The seperator to be used to split the lines on'
    )
    parser.add_argument(
        '--min_len', default=5, type=int,
        help='The minimum line length to keep'
    )
    parser.add_argument(
        '--max_len', default=1020, type=int,
        help='The maximum line length to keep'
    )
    # When set, main() calls FilesProcessor.dist_run instead of .run.
    parser.add_argument(
        '--dist_run', default=False, action='store_true'
    )
    parser.add_argument(
        '--data_path', default='data/data.txt'
    )
    # NOTE: main() writes to the fixed path data/data.csv and does not read
    # this option.
    parser.add_argument(
        '--save_path', default='data/clean_data.txt'
    )
    # type=int so a value given on the command line reaches
    # RepeatedCharsCollapsor as an int instead of a str (argparse only
    # converts string values, so the int default is unaffected).
    parser.add_argument(
        '--max_rep_chars', default=2, type=int
    )
    parser.add_argument(
        '--execlude_words_files', default='data/words.json'
    )
    # NOTE: currently unused — the OOVFilter step in post_process is disabled.
    parser.add_argument(
        '--max_oov', default=100, type=int
    )
    parser.add_argument(
        '--min_words', default=3, type=int
    )
    parser.add_argument(
        '--max_words', default=100, type=int
    )
    # nargs='+' / type=float so "--dist_ratios 0.1 0.2" yields [0.1, 0.2].
    # Without them a command-line value was one string, and main() would
    # iterate over its individual characters.
    parser.add_argument(
        '--dist_ratios', default=[0.05, 0.1, 0.15], nargs='+', type=float
    )
    parser.add_argument(
        '--remove_punc', default=False, action='store_true', help='Remove punctuation of the distorted lines'
    )
    return parser
def main(args) -> None:
    """Clean the input file, build distorted variants and write data/data.csv.

    Steps:
      1. Run the pipeline from ``get_file_processor`` over ``args.data_path``
         (through ``dist_run`` when ``--dist_run`` is given).
      2. Flatten and de-duplicate the result with ``post_process``.
      3. For each ratio in ``args.dist_ratios``, distort every clean line into
         a ``distorted_<ratio>`` column next to the ``clean`` column.
      4. With ``--remove_punc``, strip punctuation from the distorted columns.
      5. Write the table, including its integer index, to ``data/data.csv``.

    ``args.save_path`` is not used; the output path is fixed.
    """
    fp = get_file_processor(args)
    files = get_path(args.data_path)
    print('Started!')
    start = time.time()
    if args.dist_run:
        print('dist run')
        data = fp.dist_run(files)
    else:
        data = fp.run(files)
    end = time.time()
    print(f'Files Processing completed in {end - start}')
    data = post_process(data)
    # NOTE(review): the distorter is built from the first half of the lines;
    # the reason is not visible here — confirm against get_text_distorter.
    sentences = data[: len(data) // 2]
    print("Length of data after post processing", len(data))

    # Create the frame from the clean column up front so it exists even when
    # no distortion ratios are configured (previously df stayed None and
    # df.to_csv raised AttributeError).
    df = pd.DataFrame({'clean': data})
    for ratio in args.dist_ratios:
        distorter = get_text_distorter(ratio, sentences)
        # TODO: Don't touch 2 percent of sentences to keep the model from having a high bias towards the noise
        df[f'distorted_{ratio}'] = list(map(distorter.run, data))

    if args.remove_punc:
        print("Removing punctuations for the distorted lines")
        for ratio in args.dist_ratios:
            df[f'distorted_{ratio}'] = df[f'distorted_{ratio}'].apply(
                remove_punctuation
            )
    df.to_csv('data/data.csv', encoding='utf-8')
def write_train_test_splits(
    csv_path: Union[Path, str] = 'data/data.csv',
    train_path: Union[Path, str] = 'train.csv',
    test_path: Union[Path, str] = 'test.csv',
    test_size: int = 50000,
) -> None:
    """Turn the table written by ``main`` into ``text,summary`` train/test CSVs.

    Each ``distorted_*`` column yields one copy of the rows, with the distorted
    text as ``text`` and the ``clean`` column as ``summary``. Copies are
    stacked from the last distorted column to the first. The final
    ``test_size`` rows of ``csv_path`` go to ``test_path`` and the rest to
    ``train_path``. When there are fewer rows than ``test_size``, everything
    goes to test and the train file holds only the header.

    Reading and writing with pandas keeps fields that contain commas, quotes
    or newlines intact. The former ``awk -F','`` + ``sed`` pipeline split
    quoted fields on embedded commas and deleted every double-quote
    character. It also hard-coded exactly three distorted columns.
    """
    # dtype=str + keep_default_na=False: keep text such as "NA" or "null"
    # as literal strings instead of converting it to NaN.
    table = pd.read_csv(
        csv_path, index_col=0, encoding='utf-8',
        dtype=str, keep_default_na=False,
    )
    distorted_columns = [
        column for column in table.columns if column.startswith('distorted_')
    ]
    split_at = max(len(table) - test_size, 0)

    def to_pairs(rows):
        # One (text, summary) frame per distorted column, last column first.
        frames = [
            pd.DataFrame({
                'text': rows[column].to_numpy(),
                'summary': rows['clean'].to_numpy(),
            })
            for column in reversed(distorted_columns)
        ]
        if not frames:
            return pd.DataFrame(columns=['text', 'summary'])
        return pd.concat(frames, ignore_index=True)

    to_pairs(table.iloc[:split_at]).to_csv(
        train_path, index=False, encoding='utf-8'
    )
    to_pairs(table.iloc[split_at:]).to_csv(
        test_path, index=False, encoding='utf-8'
    )


if __name__ == '__main__':
    parser = get_argparser()
    args = parser.parse_args()
    main(args)
    write_train_test_splits()