# mimir-perplexity / normalization.py
# (Hugging Face file-viewer residue kept as comments: uploader "versae",
#  commit message "Models and code", revision dcc5cd1)
import argparse
import unicodedata
import re
from tqdm import tqdm
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import re
import unicodedata
# ASCII punctuation plus the guillemets; used by normalize_text below to
# decide whether a line already ends with a punctuation mark.
PUNCTS = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~«»'

# Map unicode punctuation to an ASCII replacement (cc_net table).
# NOTE(review): several keys had been garbled into their ASCII look-alikes.
# In particular "１" (fullwidth one) had become plain "1", which made
# replace_unicode_punct rewrite every digit 1 into a double-quote, and
# plain "." mapped to ". ", inserting a space after every period.
# The fullwidth forms are restored here.
UNICODE_PUNCT = {
    "，": ",",
    "。": ".",
    "、": ",",
    "„": '"',
    "”": '"',
    "“": '"',
    "«": '"',
    "»": '"',
    "１": '"',
    "」": '"',
    "「": '"',
    "《": '"',
    "》": '"',
    "´": "'",
    "∶": ":",
    "：": ":",
    "？": "?",
    "！": "!",
    "（": "(",
    "）": ")",
    "；": ";",
    "–": "-",
    "—": " - ",
    "．": ". ",
    "～": "~",
    "’": "'",
    "…": "...",
    "━": "-",
    "〈": "<",
    "〉": ">",
    "【": "[",
    "】": "]",
    "％": "%",
    "►": "-",
    "■": " ",  # added for Mimir
}

# Single character class matching any key of UNICODE_PUNCT.
UNICODE_PUNCT_RE = re.compile(f"[{''.join(UNICODE_PUNCT.keys())}]")
def replace_unicode_punct(text: str) -> str:
    """Return *text* with each UNICODE_PUNCT key swapped for its ASCII
    replacement; characters not in the table pass through unchanged."""
    pieces = []
    for ch in text:
        pieces.append(UNICODE_PUNCT.get(ch, ch))
    return "".join(pieces)
def remove_unicode_punct(text: str) -> str:
    """Drop every UNICODE_PUNCT character instead of replacing it.

    More aggressive version of replace_unicode_punct but also faster.
    """
    return re.sub(UNICODE_PUNCT_RE, "", text)
def strip_accents(line: str) -> str:
    """Strip accents from a piece of text.

    Decomposes to NFD and drops combining marks (category "Mn"),
    e.g. "café" -> "cafe".
    """
    nfd = unicodedata.normalize("NFD", line)
    output = [c for c in nfd if unicodedata.category(c) != "Mn"]
    # Bug fix: the original guard `if len(output) == line: return line`
    # compared an int to a str, so it was always False and the early return
    # was dead code.  Note the "obvious" repair (comparing against
    # len(line)) would be WRONG: a composed "é" (length 1) decomposes to
    # "e" + combining accent and strips back to length 1, which would
    # return the accented original.  Simply join the filtered characters.
    return "".join(output)
# Build a regex matching all control characters.
# Covers the C0 range (0-31) plus DEL and the C1 range (127-159).
NON_PRINTING_CHARS_RE = re.compile(
    f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
)
# Any decimal digit; \d on str is unicode-aware, so non-ASCII digits match too.
DIGIT_RE = re.compile(r"\d")
# One combined character class matching unicode punctuation OR a control
# character: the .replace("][", "") splices the two "[...]" classes into one.
PUNCT_OR_NON_PRINTING_CHARS_RE = re.compile(
    (UNICODE_PUNCT_RE.pattern + NON_PRINTING_CHARS_RE.pattern).replace("][", "")
)
def remove_non_printing_char(text: str) -> str:
    """Delete every control character (see NON_PRINTING_CHARS_RE) from *text*."""
    return re.sub(NON_PRINTING_CHARS_RE, "", text)
def normalize(line: str, accent=True, case=True, numbers=True, punct=1) -> str:
    """Apply the text-normalization pipeline to one line.

    Steps, in order: strip surrounding whitespace; optionally lowercase;
    optionally drop accents; optionally rewrite each digit as "0"; then
    replace (punct=1) or delete (punct=2) unicode punctuation; finally
    remove control characters.  An empty/whitespace-only line is returned
    as the empty string immediately.
    """
    result = line.strip()
    if not result:
        return result
    if case:
        result = result.lower()
    if accent:
        result = strip_accents(result)
    if numbers:
        result = DIGIT_RE.sub("0", result)
    if punct == 1:
        result = replace_unicode_punct(result)
    elif punct == 2:
        result = remove_unicode_punct(result)
    return remove_non_printing_char(result)
def slow_normalize_for_dedup(line: str) -> str:
    """Reference implementation of normalize_for_dedup built on normalize():
    lowercase, digits -> "0", punctuation removed (punct=2), accents kept."""
    options = dict(accent=False, case=True, numbers=True, punct=2)
    return normalize(line, **options)
def normalize_for_dedup(line: str) -> str:
    """Fast path matching slow_normalize_for_dedup: lowercase, map digits to
    "0", then strip punctuation and control chars in one combined regex pass."""
    text = line.strip()
    if not text:
        return text
    # case + numbers
    text = DIGIT_RE.sub("0", text.lower())
    # punctuation and non-printing characters in a single sub()
    return PUNCT_OR_NON_PRINTING_CHARS_RE.sub("", text)
## START OF MIMIR CODE
def normalize_text(line):
    """NFKC-fold and lowercase *line*, guarantee it ends with a punctuation
    mark, then run the normalize() pass (digits -> "0", punct replaced)."""
    text = unicodedata.normalize('NFKC', line).lower().rstrip()
    # Append a sentence-final dot when the line ends without punctuation.
    if text and text[-1] not in PUNCTS:
        text += '.'
    return normalize(text, accent=False, case=True, numbers=True, punct=1)
def normalize_file(input_file, output_file, cutoff=None):
    """Normalize *input_file* line by line into *output_file*.

    Each input line goes through normalize_text() and is written with a
    trailing newline.  If *cutoff* is given, at most *cutoff* lines are
    processed.

    Bug fixes vs. the original: the 0-based index was checked *after*
    writing, so cutoff=N wrote N+1 lines; and `if cutoff` ignored an
    explicit cutoff of 0.
    """
    with (open(output_file, 'w', encoding='utf-8') as f,
          open(input_file, 'r', encoding='utf-8') as lines):
        for line_count, line in tqdm(enumerate(lines), desc="Processing"):
            if cutoff is not None and line_count >= cutoff:
                break
            f.write(normalize_text(line) + "\n")
if __name__ == "__main__":
    # CLI entry point: positional input/output paths, optional line cutoff.
    arg_parser = argparse.ArgumentParser(
        description=(
            'Normalize text file line by line, ensure trailing punctuation, '
            'replace newlines with spaces, and show progress.'
        )
    )
    arg_parser.add_argument('input_file', type=str, help='Input file path')
    arg_parser.add_argument('output_file', type=str, help='Output file path')
    arg_parser.add_argument('--cutoff', required=False, type=int,
                            help='Max number of lines to process')
    cli_args = arg_parser.parse_args()
    normalize_file(cli_args.input_file, cli_args.output_file, cli_args.cutoff)