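"""Gradio demo for a Shughni - Russian - English translator.

Shughni <-> Russian uses an NLLB-200-distilled-600M checkpoint fine-tuned for
Shughni; translation from and into English is pivoted through Russian with the
stock facebook/nllb-200-distilled-600M model.
"""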
import torch
import re
from converter import Converter
from sacremoses import MosesPunctNormalizer
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
import time
import gradio as gr
code_mapping = {
"Russian": "rus_Cyrl",
"English": "eng_Latn",
"Shughni": "shu_Cyrl"
}
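# NOTE: "shu_Cyrl" is not a stock NLLB-200 language code; it is assumed here to
# be a Shughni tag that the fine-tuned checkpoint's tokenizer was extended with.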
source_languages = list(code_mapping.keys())[::-1]
target_languages = list(code_mapping.keys())
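# Moses punctuation normalization with the English rules; these are largely
# language-agnostic (quotes, spacing, ellipses) and are applied to every input.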
punct_normalizer = MosesPunctNormalizer(lang="en")
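# Shughni orthography converters: `converter` normalizes input to the Cyrillic
# spelling the model was trained on; `converter_latn` transliterates the model
# output to Latin when the "Show Shughni Output in Latin" box is ticked.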
converter = Converter(dest="cyr", settings="auto", lang="sgh")
converter_latn = Converter(dest="lat", settings="auto", lang="sgh")
start_time = time.time()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
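# Two model/tokenizer pairs: the Shughni fine-tune handles the
# Shughni <-> Russian leg, the stock NLLB-200 checkpoint handles the
# Russian <-> English leg of the pivot.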
model = AutoModelForSeq2SeqLM.from_pretrained("Novokshanov/nllb-200-distilled-600M-Shughni-v1").to(device).eval()
tokenizer = NllbTokenizer.from_pretrained("Novokshanov/nllb-200-distilled-600M-Shughni-v1", src_lang='rus_Cyrl')
model2 = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to(device).eval()
tokenizer2 = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang='rus_Cyrl')
load_time = time.time() - start_time
print(f"Model loaded in {load_time:.2f} seconds")
# Paragraphs are split into sentences on end-of-sentence punctuation so each
# generation call stays short; the compiled pattern is shared by both models.
sentence_splitter = re.compile(r"(?<=[.!?])\s+")


def _translate_paragraphs(mdl, tok, paragraphs, tgt_code):
    """Translate paragraphs sentence by sentence with the given model and tokenizer."""
    translated_paragraphs = []
    for paragraph in paragraphs:
        translated_sentences = []
        for sentence in sentence_splitter.split(paragraph):
            inputs = tok(sentence, return_tensors="pt").to(device)
            with torch.inference_mode():
                generated = mdl.generate(
                    **inputs,
                    forced_bos_token_id=tok.convert_tokens_to_ids(tgt_code),
                    max_length=inputs.input_ids.shape[1] + 50,
                    num_return_sequences=1,
                    num_beams=5,
                    no_repeat_ngram_size=4,  # repetition blocking tends to work better when this is below num_beams
                    renormalize_logits=True,  # recompute token probabilities after banning the repetitions
                )
            translated_sentences.append(
                tok.decode(generated[0], skip_special_tokens=True)
            )
        translated_paragraphs.append(" ".join(translated_sentences))
    return "\n".join(translated_paragraphs)


def sh_ru_translate(paragraphs, tgt_code):
    """Shughni <-> Russian with the fine-tuned checkpoint."""
    return _translate_paragraphs(model, tokenizer, paragraphs, tgt_code)


def ru_en_translate(paragraphs, tgt_code):
    """Russian <-> English with the stock NLLB-200 checkpoint."""
    return _translate_paragraphs(model2, tokenizer2, paragraphs, tgt_code)
def translate(text: str, src_lang: str, tgt_lang: str, show_latin: bool = False):
    src_code = code_mapping[src_lang]
    tgt_code = code_mapping[tgt_lang]
    # Normalize the punctuation first.
    text = punct_normalizer.normalize(text)
    if src_lang == "Shughni":
        # Bring Shughni input into the Cyrillic orthography the model expects.
        text = converter.convert(text).text
    paragraphs = text.split("\n")
    # Each tokenizer's src_lang is set right before it is used, so the pivot
    # steps tag their inputs with the correct source language.
    if src_lang == "Shughni" and tgt_lang == "English":
        # Pivot through Russian: Shughni -> Russian -> English.
        tokenizer.src_lang = src_code
        result_ru = sh_ru_translate(paragraphs, "rus_Cyrl")
        tokenizer2.src_lang = "rus_Cyrl"
        result = ru_en_translate(result_ru.split("\n"), "eng_Latn")
    elif src_lang == "English" and tgt_lang == "Shughni":
        # Pivot through Russian: English -> Russian -> Shughni.
        tokenizer2.src_lang = "eng_Latn"
        result_ru = ru_en_translate(paragraphs, "rus_Cyrl")
        tokenizer.src_lang = "rus_Cyrl"
        result = sh_ru_translate(result_ru.split("\n"), "shu_Cyrl")
    elif src_lang == "English" and tgt_lang == "Russian":
        tokenizer2.src_lang = "eng_Latn"
        result = ru_en_translate(paragraphs, "rus_Cyrl")
    elif src_lang == "Russian" and tgt_lang == "English":
        tokenizer2.src_lang = "rus_Cyrl"
        result = ru_en_translate(paragraphs, "eng_Latn")
    else:
        # Direct Shughni <-> Russian with the fine-tuned model.
        tokenizer.src_lang = src_code
        result = sh_ru_translate(paragraphs, tgt_code)
    if show_latin and tgt_lang == "Shughni":
        result = converter_latn.convert(result).text
    return result
def swap_langs(src, tgt):
return tgt, src
description = """
<div style="text-align: center;">
<h1>Shughni - Russian - English Translator</h1>
</div>
"""
disclaimer = """
## Disclaimer
This is a demo of a work-in-progress translation project. Translation from and into English is pivoted through Russian, using a standalone facebook/nllb-200-distilled-600M model for the English leg.
"""
examples_inputs = [["wāδ-ен сат wи-рд-ен лу̊д: потх̌о лу̊д иди, ту бойад шич тар wи хез сāwи.","Shughni","Russian"],
["Карамшо прочитал интересную книгу.", "Russian", "Shughni"],
["Азими ху х̌āр тарк чӯд.", "Shughni", "Russian"],
["Азим вышел из дома.", "Russian", "Shughni"],
["Асал х̌увдқати алалаш сут.", "Shughni", "Russian"],
["Лола спела красивую песню.", "Russian", "Shughni"],
["аҷойиб китоб х̌êйдум.", "Shughni", "English"],
["You read an interesting book.", "English", "Shughni"],
]
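# The #swap-button rule below targets elem_id="swap-button" on the swap button,
# shrinking it so the two language dropdowns sit close together.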
css = """
#swap-button {
height: 90px !important;
width: 10px !important;
}
"""
with gr.Blocks(css=css) as demo:
gr.Markdown(description)
with gr.Row():
src_lang = gr.Dropdown(label="Source Language", choices=source_languages)
swap_button = gr.Button("⇄", elem_id="swap-button")
target_lang = gr.Dropdown(label="Target Language", choices=target_languages)
swap_button.click(
fn=swap_langs,
inputs=[src_lang, target_lang],
outputs=[src_lang, target_lang],
)
with gr.Row():
input_text = gr.Textbox(label="Input Text", lines=6)
with gr.Row():
show_latin = gr.Checkbox(label="Show Shughni Output in Latin", value=False)
with gr.Row():
btn = gr.Button("Translate text")
with gr.Row():
output = gr.Textbox(label="Output Text", lines=6)
btn.click(
translate,
inputs=[input_text, src_lang, target_lang, show_latin],
outputs=output,
)
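# cache_examples=True runs `translate` on every example at startup and caches
# the outputs, so the very first launch can take a while (especially on CPU).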
examples = gr.Examples(
examples=examples_inputs,
inputs=[input_text, src_lang, target_lang],
fn=translate,
outputs=output,
cache_examples=True
)
with gr.Row():
gr.Markdown(disclaimer)
if __name__ == "__main__":
demo.queue(max_size=20).launch()