from sacremoses import MosesPunctNormalizer
from sacremoses import MosesTokenizer
from sacremoses import MosesDetokenizer
from subword_nmt.apply_bpe import BPE, read_vocabulary
import codecs
from tqdm import tqdm
from indicnlp.tokenize import indic_tokenize
from indicnlp.tokenize import indic_detokenize
from indicnlp.normalize import indic_normalize
from indicnlp.transliterate import unicode_transliterate
from mosestokenizer import MosesSentenceSplitter
from indicnlp.tokenize import sentence_tokenize
from inference.custom_interactive import Translator

# Language codes of the supported Indic languages.
INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]

def split_sentences(paragraph, language):
    """Split a paragraph into sentences using the splitter for `language`."""
    if language == "en":
        with MosesSentenceSplitter(language) as splitter:
            return splitter([paragraph])
    elif language in INDIC:
        return sentence_tokenize.sentence_split(paragraph, lang=language)
    raise ValueError(f"Sentence splitting is not supported for language: {language}")
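
# Illustrative usage (hypothetical output, not from a real run):
#   split_sentences("Hello there. How are you?", "en")
#   -> ["Hello there.", "How are you?"]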

def add_token(sent, tag_infos):
    """Prepend special tokens specified by tag_infos to a single sentence.

    tag_infos: list of tuples (tag_type, tag)
    each tag_info results in a token of the form: __{tag_type}__{tag}__
    """
tokens = []
for tag_type, tag in tag_infos:
token = "__" + tag_type + "__" + tag + "__"
tokens.append(token)
return " ".join(tokens) + " " + sent
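
# The tag format produced above, on an illustrative input:
#   add_token("hello", [("src", "en"), ("tgt", "hi")])
#   -> "__src__en__ __tgt__hi__ hello"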

def apply_lang_tags(sents, src_lang, tgt_lang):
    """Prepend source/target language tags to every sentence in sents."""
    tagged_sents = []
for sent in sents:
tagged_sent = add_token(sent.strip(), [("src", src_lang), ("tgt", tgt_lang)])
tagged_sents.append(tagged_sent)
return tagged_sents

def truncate_long_sentences(sents):
    """Truncate any sentence longer than MAX_SEQ_LEN words, with a warning."""
    MAX_SEQ_LEN = 200
new_sents = []
for sent in sents:
words = sent.split()
num_words = len(words)
if num_words > MAX_SEQ_LEN:
print_str = " ".join(words[:5]) + " .... " + " ".join(words[-5:])
sent = " ".join(words[:MAX_SEQ_LEN])
            print(
                f"WARNING: Sentence {print_str} truncated to {MAX_SEQ_LEN} tokens as it exceeds the maximum length limit"
            )
new_sents.append(sent)
return new_sents
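
# Illustrative check (hypothetical input): a 250-word sentence is cut back to
# its first MAX_SEQ_LEN (200) words:
#   len(truncate_long_sentences([" ".join(["w"] * 250)])[0].split()) -> 200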

class Model:
    """Wraps preprocessing, BPE segmentation, fairseq translation, and
    postprocessing behind batch_translate / translate_paragraph."""

    def __init__(self, expdir):
        self.expdir = expdir
        self.en_tok = MosesTokenizer(lang="en")
        self.en_normalizer = MosesPunctNormalizer()
        self.en_detok = MosesDetokenizer(lang="en")
        self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
        print("Initializing vocab and bpe")
        # Keep only vocabulary entries seen at least 5 times.
        self.vocabulary = read_vocabulary(
            codecs.open(f"{expdir}/vocab/vocab.SRC", encoding="utf-8"), 5
        )
        # -1 = use all merge operations from the codes file; "@@" is the
        # subword separator; no glossaries.
        self.bpe = BPE(
            codecs.open(f"{expdir}/vocab/bpe_codes.32k.SRC", encoding="utf-8"),
            -1,
            "@@",
            self.vocabulary,
            None,
        )
print("Initializing model for translation")
# initialize the model
self.translator = Translator(
f"{expdir}/final_bin", f"{expdir}/model/checkpoint_best.pt", batch_size=100
)
# translate a batch of sentences from src_lang to tgt_lang
def batch_translate(self, batch, src_lang, tgt_lang):
assert isinstance(batch, list)
preprocessed_sents = self.preprocess(batch, lang=src_lang)
bpe_sents = self.apply_bpe(preprocessed_sents)
tagged_sents = apply_lang_tags(bpe_sents, src_lang, tgt_lang)
tagged_sents = truncate_long_sentences(tagged_sents)
translations = self.translator.translate(tagged_sents)
postprocessed_sents = self.postprocess(translations, tgt_lang)
return postprocessed_sents
# translate a paragraph from src_lang to tgt_lang
def translate_paragraph(self, paragraph, src_lang, tgt_lang):
assert isinstance(paragraph, str)
sents = split_sentences(paragraph, src_lang)
postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
translated_paragraph = " ".join(postprocessed_sents)
return translated_paragraph
    def preprocess_sent(self, sent, normalizer, lang):
        """Normalize and tokenize one sentence; Indic text is additionally
        transliterated to the common script (Devanagari)."""
        if lang == "en":
return " ".join(
self.en_tok.tokenize(
self.en_normalizer.normalize(sent.strip()), escape=False
)
)
        else:
            # Tokenize with the Indic NLP tokenizer, transliterate to the
            # common script (Devanagari), and re-join any halant (्) that
            # tokenization left surrounded by spaces.
            return unicode_transliterate.UnicodeIndicTransliterator.transliterate(
                " ".join(
                    indic_tokenize.trivial_tokenize(
                        normalizer.normalize(sent.strip()), lang
                    )
                ),
                lang,
                "hi",
            ).replace(" ् ", "्")
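
    # Note (illustrative): for an Indic input such as lang="ta", the string
    # returned above is in Devanagari rather than Tamil script, since every
    # Indic language is mapped to a common script before BPE; postprocess()
    # later transliterates translations back to the native script.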
    def preprocess(self, sents, lang):
        """
        Normalize, tokenize, and script-convert (for Indic) each input sentence.
        Returns the list of preprocessed sentences.
        """
if lang == "en":
processed_sents = [
self.preprocess_sent(line, None, lang) for line in tqdm(sents)
]
else:
normfactory = indic_normalize.IndicNormalizerFactory()
normalizer = normfactory.get_normalizer(lang)
processed_sents = [
self.preprocess_sent(line, normalizer, lang) for line in tqdm(sents)
]
return processed_sents
    def postprocess(self, sents, lang, common_lang="hi"):
        """
        Postprocess translator output: for Indic target languages, convert the
        script back from the common script (Devanagari) to the native script,
        then detokenize.

        sents: list of translated sentences
        lang: target language
        common_lang: language whose script served as the common script
        """
postprocessed_sents = []
if lang == "en":
for sent in sents:
postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
else:
for sent in sents:
outstr = indic_detokenize.trivial_detokenize(
self.xliterator.transliterate(sent, common_lang, lang), lang
)
postprocessed_sents.append(outstr)
return postprocessed_sents

    def apply_bpe(self, sents):
        """Apply the learned BPE codes to each whitespace-tokenized sentence."""
        return [self.bpe.process_line(sent) for sent in sents]
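
if __name__ == "__main__":
    # Minimal usage sketch. The experiment directory path below is a
    # hypothetical placeholder; it must contain vocab/, final_bin/, and
    # model/checkpoint_best.pt laid out as Model.__init__ above assumes.
    model = Model(expdir="../en-indic")

    # Translate a pre-split batch of sentences from English to Hindi.
    print(model.batch_translate(["How are you?"], "en", "hi"))

    # Translate a full paragraph; sentences are split internally.
    print(model.translate_paragraph("This is a test. It has two sentences.", "en", "hi"))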