from sacremoses import MosesPunctNormalizer
from sacremoses import MosesTokenizer
from sacremoses import MosesDetokenizer
from subword_nmt.apply_bpe import BPE, read_vocabulary
import codecs
from tqdm import tqdm
from indicnlp.tokenize import indic_tokenize
from indicnlp.tokenize import indic_detokenize
from indicnlp.normalize import indic_normalize
from indicnlp.transliterate import unicode_transliterate
from mosestokenizer import MosesSentenceSplitter
from indicnlp.tokenize import sentence_tokenize
from inference.custom_interactive import Translator

# Indic languages supported on the source/target side (ISO 639-1 codes)
INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]

def split_sentences(paragraph, language):
    """Split a paragraph into sentences, using the Moses splitter for English
    and the Indic NLP library splitter for Indic languages."""
    if language == "en":
        with MosesSentenceSplitter(language) as splitter:
            return splitter([paragraph])
    elif language in INDIC:
        return sentence_tokenize.sentence_split(paragraph, lang=language)

def add_token(sent, tag_infos):
    """Prepend the special tokens specified by tag_infos to a sentence.
    tag_infos: list of tuples (tag_type, tag)
    each tag_info results in a token of the form: __{tag_type}__{tag}__
    """
    tokens = []
    for tag_type, tag in tag_infos:
        token = "__" + tag_type + "__" + tag + "__"
        tokens.append(token)
    return " ".join(tokens) + " " + sent

def apply_lang_tags(sents, src_lang, tgt_lang):
    tagged_sents = []
    for sent in sents:
        tagged_sent = add_token(sent.strip(), [("src", src_lang), ("tgt", tgt_lang)])
        tagged_sents.append(tagged_sent)
    return tagged_sents

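# Illustrative example (not part of the original file): with the helpers above,
# apply_lang_tags(["How are you ?"], "en", "hi") returns
# ["__src__en__ __tgt__hi__ How are you ?"], i.e. the source- and target-language
# tags that the multilingual model reads ahead of the sentence itself.
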
def truncate_long_sentences(sents):
    MAX_SEQ_LEN = 200
    new_sents = []
    for sent in sents:
        words = sent.split()
        num_words = len(words)
        if num_words > MAX_SEQ_LEN:
            print_str = " ".join(words[:5]) + " .... " + " ".join(words[-5:])
            sent = " ".join(words[:MAX_SEQ_LEN])
            print(
                f"WARNING: Sentence {print_str} truncated to {MAX_SEQ_LEN} tokens as it exceeds the maximum length limit"
            )
        new_sents.append(sent)
    return new_sents

class Model:
    def __init__(self, expdir):
        self.expdir = expdir
        self.en_tok = MosesTokenizer(lang="en")
        self.en_normalizer = MosesPunctNormalizer()
        self.en_detok = MosesDetokenizer(lang="en")
        self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
        print("Initializing vocab and bpe")
        # load the source-side vocabulary, keeping tokens with frequency >= 5
        self.vocabulary = read_vocabulary(
            codecs.open(f"{expdir}/vocab/vocab.SRC", encoding="utf-8"), 5
        )
        # load the learned BPE codes (all merges, "@@" as the subword separator)
        self.bpe = BPE(
            codecs.open(f"{expdir}/vocab/bpe_codes.32k.SRC", encoding="utf-8"),
            -1,
            "@@",
            self.vocabulary,
            None,
        )
        print("Initializing model for translation")
        # initialize the fairseq translator
        self.translator = Translator(
            f"{expdir}/final_bin", f"{expdir}/model/checkpoint_best.pt", batch_size=100
        )

    # translate a batch of sentences from src_lang to tgt_lang
    def batch_translate(self, batch, src_lang, tgt_lang):
        assert isinstance(batch, list)
        preprocessed_sents = self.preprocess(batch, lang=src_lang)
        bpe_sents = self.apply_bpe(preprocessed_sents)
        tagged_sents = apply_lang_tags(bpe_sents, src_lang, tgt_lang)
        tagged_sents = truncate_long_sentences(tagged_sents)
        translations = self.translator.translate(tagged_sents)
        postprocessed_sents = self.postprocess(translations, tgt_lang)
        return postprocessed_sents

    # translate a paragraph from src_lang to tgt_lang
    def translate_paragraph(self, paragraph, src_lang, tgt_lang):
        assert isinstance(paragraph, str)
        sents = split_sentences(paragraph, src_lang)
        postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
        translated_paragraph = " ".join(postprocessed_sents)
        return translated_paragraph

    def preprocess_sent(self, sent, normalizer, lang):
        if lang == "en":
            return " ".join(
                self.en_tok.tokenize(
                    self.en_normalizer.normalize(sent.strip()), escape=False
                )
            )
        else:
            # normalize and tokenize, then transliterate to Devanagari (the common
            # script shared by all Indic languages in the vocabulary), re-attaching
            # any halant characters that tokenization separated
            return unicode_transliterate.UnicodeIndicTransliterator.transliterate(
                " ".join(
                    indic_tokenize.trivial_tokenize(
                        normalizer.normalize(sent.strip()), lang
                    )
                ),
                lang,
                "hi",
            ).replace(" ् ", "्")

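    # Illustrative call (hypothetical input, not from the original file): for a
    # Hindi source sentence the transliteration above is effectively a no-op, e.g.
    #     self.preprocess_sent("यह एक वाक्य है।", normalizer, "hi")
    # returns the normalized, space-tokenized sentence still in Devanagari, while a
    # Tamil or Telugu sentence would additionally be mapped into Devanagari so that
    # all Indic languages share one script (and one subword vocabulary).
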
    def preprocess(self, sents, lang):
        """
        Normalize and tokenize the sentences and, for Indic languages, convert
        their script to Devanagari. Returns the list of processed sentences.
        """
        if lang == "en":
            processed_sents = [
                self.preprocess_sent(line, None, lang) for line in tqdm(sents)
            ]
        else:
            normfactory = indic_normalize.IndicNormalizerFactory()
            normalizer = normfactory.get_normalizer(lang)
            processed_sents = [
                self.preprocess_sent(line, normalizer, lang) for line in tqdm(sents)
            ]
        return processed_sents

    def postprocess(self, sents, lang, common_lang="hi"):
        """
        Postprocess translations: for Indic target languages, convert the script
        back from the common script (Devanagari) to the native script and
        detokenize; for English, simply detokenize.
        sents: translated sentences (fairseq interactive output)
        lang: target language
        """
        postprocessed_sents = []
        if lang == "en":
            for sent in sents:
                postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
        else:
            for sent in sents:
                outstr = indic_detokenize.trivial_detokenize(
                    self.xliterator.transliterate(sent, common_lang, lang), lang
                )
                postprocessed_sents.append(outstr)
        return postprocessed_sents

    def apply_bpe(self, sents):
        # segment each preprocessed sentence into BPE subword units
        return [self.bpe.process_line(sent) for sent in sents]
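
# Minimal usage sketch (not part of the original file). It assumes a trained
# experiment directory, here "../en-indic" as an illustrative path, laid out as
# __init__ expects: vocab/vocab.SRC, vocab/bpe_codes.32k.SRC, final_bin/ and
# model/checkpoint_best.pt.
if __name__ == "__main__":
    model = Model(expdir="../en-indic")
    # translate a whole English paragraph into Hindi
    print(model.translate_paragraph("This is a test. It has two sentences.", "en", "hi"))
    # or translate an already split batch of sentences into Tamil
    print(model.batch_translate(["How are you?", "See you soon."], "en", "ta"))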