from sacremoses import MosesPunctNormalizer
from sacremoses import MosesTokenizer
from sacremoses import MosesDetokenizer
from subword_nmt.apply_bpe import BPE, read_vocabulary
import codecs
from tqdm import tqdm
from indicnlp.tokenize import indic_tokenize
from indicnlp.tokenize import indic_detokenize
from indicnlp.normalize import indic_normalize
from indicnlp.transliterate import unicode_transliterate
from mosestokenizer import MosesSentenceSplitter
from indicnlp.tokenize import sentence_tokenize
from inference.custom_interactive import Translator
INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]
def split_sentences(paragraph, language):
if language == "en":
with MosesSentenceSplitter(language) as splitter:
return splitter([paragraph])
elif language in INDIC:
return sentence_tokenize.sentence_split(paragraph, lang=language)
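
# Illustrative behaviour (hypothetical inputs; the exact splits depend on the
# underlying Moses / Indic NLP sentence splitters):
#   split_sentences("Hello world. How are you?", "en")
#   -> roughly ["Hello world.", "How are you?"]
#   split_sentences("नमस्ते। आप कैसे हैं?", "hi")
#   -> roughly ["नमस्ते।", "आप कैसे हैं?"]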
def add_token(sent, tag_infos):
"""add special tokens specified by tag_infos to each element in list
tag_infos: list of tuples (tag_type,tag)
each tag_info results in a token of the form: __{tag_type}__{tag}__
"""
tokens = []
for tag_type, tag in tag_infos:
token = "__" + tag_type + "__" + tag + "__"
tokens.append(token)
return " ".join(tokens) + " " + sent
def apply_lang_tags(sents, src_lang, tgt_lang):
tagged_sents = []
for sent in sents:
tagged_sent = add_token(sent.strip(), [("src", src_lang), ("tgt", tgt_lang)])
tagged_sents.append(tagged_sent)
return tagged_sents
def truncate_long_sentences(sents):
MAX_SEQ_LEN = 200
new_sents = []
for sent in sents:
words = sent.split()
num_words = len(words)
if num_words > MAX_SEQ_LEN:
print_str = " ".join(words[:5]) + " .... " + " ".join(words[-5:])
sent = " ".join(words[:MAX_SEQ_LEN])
            print(
                f"WARNING: Sentence '{print_str}' truncated to {MAX_SEQ_LEN} tokens as it exceeds the maximum sequence length"
            )
new_sents.append(sent)
return new_sents
class Model:
def __init__(self, expdir):
self.expdir = expdir
self.en_tok = MosesTokenizer(lang="en")
self.en_normalizer = MosesPunctNormalizer()
self.en_detok = MosesDetokenizer(lang="en")
self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
print("Initializing vocab and bpe")
self.vocabulary = read_vocabulary(
codecs.open(f"{expdir}/vocab/vocab.SRC", encoding="utf-8"), 5
)
self.bpe = BPE(
codecs.open(f"{expdir}/vocab/bpe_codes.32k.SRC", encoding="utf-8"),
-1,
"@@",
self.vocabulary,
None,
)
print("Initializing model for translation")
# initialize the model
self.translator = Translator(
f"{expdir}/final_bin", f"{expdir}/model/checkpoint_best.pt", batch_size=100
)
# translate a batch of sentences from src_lang to tgt_lang
def batch_translate(self, batch, src_lang, tgt_lang):
assert isinstance(batch, list)
preprocessed_sents = self.preprocess(batch, lang=src_lang)
bpe_sents = self.apply_bpe(preprocessed_sents)
tagged_sents = apply_lang_tags(bpe_sents, src_lang, tgt_lang)
tagged_sents = truncate_long_sentences(tagged_sents)
translations = self.translator.translate(tagged_sents)
postprocessed_sents = self.postprocess(translations, tgt_lang)
return postprocessed_sents
# translate a paragraph from src_lang to tgt_lang
def translate_paragraph(self, paragraph, src_lang, tgt_lang):
assert isinstance(paragraph, str)
sents = split_sentences(paragraph, src_lang)
postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
translated_paragraph = " ".join(postprocessed_sents)
return translated_paragraph
def preprocess_sent(self, sent, normalizer, lang):
if lang == "en":
return " ".join(
self.en_tok.tokenize(
self.en_normalizer.normalize(sent.strip()), escape=False
)
)
else:
# line = indic_detokenize.trivial_detokenize(line.strip(), lang)
return unicode_transliterate.UnicodeIndicTransliterator.transliterate(
" ".join(
indic_tokenize.trivial_tokenize(
normalizer.normalize(sent.strip()), lang
)
),
lang,
"hi",
).replace(" ् ", "्")
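
    # Illustrative behaviour of preprocess_sent (hypothetical inputs, assuming
    # default Moses behaviour for English):
    #   model.preprocess_sent("Hello, world!", None, "en") -> "Hello , world !"
    #   For Indic languages, the sentence is normalized, tokenized and
    #   transliterated to the common Devanagari script before BPE is applied.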
def preprocess(self, sents, lang):
"""
Normalize, tokenize and script convert(for Indic)
return number of sentences input file
"""
if lang == "en":
# processed_sents = Parallel(n_jobs=-1, backend="multiprocessing")(
# delayed(preprocess_line)(line, None, lang) for line in tqdm(sents, total=num_lines)
# )
processed_sents = [
self.preprocess_sent(line, None, lang) for line in tqdm(sents)
]
else:
normfactory = indic_normalize.IndicNormalizerFactory()
normalizer = normfactory.get_normalizer(lang)
# processed_sents = Parallel(n_jobs=-1, backend="multiprocessing")(
# delayed(preprocess_line)(line, normalizer, lang) for line in tqdm(infile, total=num_lines)
# )
processed_sents = [
self.preprocess_sent(line, normalizer, lang) for line in tqdm(sents)
]
return processed_sents
def postprocess(self, sents, lang, common_lang="hi"):
"""
parse fairseq interactive output, convert script back to native Indic script (in case of Indic languages) and detokenize.
infname: fairseq log file
outfname: output file of translation (sentences not translated contain the dummy string 'DUMMY_OUTPUT'
input_size: expected number of output sentences
lang: language
"""
postprocessed_sents = []
if lang == "en":
for sent in sents:
# outfile.write(en_detok.detokenize(sent.split(" ")) + "\n")
postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
else:
for sent in sents:
outstr = indic_detokenize.trivial_detokenize(
self.xliterator.transliterate(sent, common_lang, lang), lang
)
# outfile.write(outstr + "\n")
postprocessed_sents.append(outstr)
return postprocessed_sents
def apply_bpe(self, sents):
return [self.bpe.process_line(sent) for sent in sents]
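
# Illustrative usage (the experiment directory and inputs are hypothetical;
# expdir must contain vocab/, final_bin/ and model/checkpoint_best.pt as
# expected by Model.__init__ above):
#
#   model = Model(expdir="en-indic")
#   model.translate_paragraph("This is a test. It has two sentences.", "en", "hi")
#   model.batch_translate(["How are you?"], "en", "ta")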