Spaces:
Running
Running
import torch
import re
from collections.abc import Iterator
from converter import Converter
from sacremoses import MosesPunctNormalizer
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
import time
import gradio as gr

# UI language name -> NLLB language code. "shu_Cyrl" is a custom code used by
# the fine-tuned Shughni model (not part of stock NLLB-200) — TODO confirm it
# is registered in that model's tokenizer vocabulary.
code_mapping = {
    "Russian": "rus_Cyrl",
    "English": "eng_Latn",
    "Shughni": "shu_Cyrl",
}
# Reversed so "Shughni" is offered first in the source dropdown.
source_languages = list(code_mapping.keys())[::-1]
# FIX: materialize as a list for consistency with source_languages instead of
# handing a live dict_keys view to the UI layer.
target_languages = list(code_mapping.keys())

# Punctuation normalizer applied to all input text before translation.
punct_normalizer = MosesPunctNormalizer(lang="en")
# Shughni orthography converters: to Cyrillic for model input, and to Latin
# for the optional Latin-script display of the output.
converter = Converter(dest="cyr", settings="auto", lang="sgh")
converter_latn = Converter(dest="lat", settings="auto", lang="sgh")

start_time = time.time()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fine-tuned NLLB model: Shughni <-> Russian direction.
model = AutoModelForSeq2SeqLM.from_pretrained("Novokshanov/nllb-200-distilled-600M-Shughni-v1").to(device).eval()
tokenizer = NllbTokenizer.from_pretrained("Novokshanov/nllb-200-distilled-600M-Shughni-v1", src_lang='rus_Cyrl')
# Stock NLLB model: Russian <-> English leg (pivot for English directions).
model2 = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to(device).eval()
tokenizer2 = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang='rus_Cyrl')
load_time = time.time() - start_time
print(f"Model loaded in {load_time:.2f} seconds")
# Pre-compiled sentence splitter: break after ., ! or ? followed by whitespace.
# Hoisted out of the loops — it was being recompiled for every paragraph.
_SENTENCE_SPLITTER = re.compile(r'(?<=[.!?])\s+')


def _translate_paragraphs(paragraphs, tgt_code, mdl, tok):
    """Translate paragraphs sentence-by-sentence with the given model pair.

    Shared implementation behind sh_ru_translate and ru_en_translate, which
    were previously two byte-identical copies of this loop.

    Args:
        paragraphs: iterable of paragraph strings. The source language is
            taken from ``tok.src_lang`` (set by the caller).
        tgt_code: NLLB target-language code, forced as the BOS token.
        mdl: seq2seq model already moved to ``device`` and in eval mode.
        tok: the NLLB tokenizer matching ``mdl``.

    Returns:
        One string: translated paragraphs joined with newlines, sentences
        within a paragraph re-joined with single spaces.
    """
    translated_paragraphs = []
    for paragraph in paragraphs:
        translated_sentences = []
        for sentence in _SENTENCE_SPLITTER.split(paragraph):
            # Tokenize straight to a device tensor — the original round-tripped
            # ids through cpu -> numpy -> list -> torch.tensor for no reason.
            encoded = tok(sentence, return_tensors="pt").to(device)
            input_len = encoded.input_ids.shape[1]
            # no_grad: pure inference, don't build autograd graphs.
            with torch.no_grad():
                generated = mdl.generate(
                    input_ids=encoded.input_ids,
                    forced_bos_token_id=tok.convert_tokens_to_ids(tgt_code),
                    max_length=input_len + 50,
                    num_return_sequences=1,
                    num_beams=5,
                    # repetition blocking works better if this number is below num_beams
                    no_repeat_ngram_size=4,
                    # recompute token probabilities after banning the repetitions
                    renormalize_logits=True,
                )
            translated_sentences.append(
                tok.decode(generated[0], skip_special_tokens=True)
            )
        translated_paragraphs.append(" ".join(translated_sentences))
    return "\n".join(translated_paragraphs)


def sh_ru_translate(paragraphs, tgt_code):
    """Translate paragraphs with the fine-tuned Shughni<->Russian model."""
    return _translate_paragraphs(paragraphs, tgt_code, model, tokenizer)


def ru_en_translate(paragraphs, tgt_code):
    """Translate paragraphs with the stock NLLB model (Russian<->English leg)."""
    return _translate_paragraphs(paragraphs, tgt_code, model2, tokenizer2)
def translate(text: str, src_lang: str, tgt_lang: str, show_latin: bool = False):
    """Translate ``text`` between Shughni, Russian and English.

    The fine-tuned model only covers Shughni<->Russian, so any direction
    involving English pivots through Russian: the stock NLLB model handles
    the English<->Russian leg.

    Args:
        text: input text; paragraphs are separated by newlines.
        src_lang / tgt_lang: UI language names (keys of ``code_mapping``).
        show_latin: if True and the target is Shughni, transliterate the
            output to Latin script.

    Returns:
        The translated text with paragraph structure preserved.
    """
    src_code = code_mapping[src_lang]
    tgt_code = code_mapping[tgt_lang]
    # normalizing the punctuation first
    text = punct_normalizer.normalize(text)
    if src_lang == 'Shughni':
        # The model expects Cyrillic Shughni; convert any Latin-script input.
        text = converter.convert(text).text
    paragraphs = text.split("\n")
    # BUG FIX: each leg must tokenize with the source language of *that* leg.
    # Previously tokenizer2.src_lang was never updated (stuck at rus_Cyrl from
    # load time), so English input was tokenized with a Russian source tag,
    # and in the English->Shughni pivot the second leg tokenized Russian text
    # with the fine-tuned tokenizer still set to eng_Latn.
    if src_lang == 'Shughni' and tgt_lang == 'English':
        # Shughni -> Russian (fine-tuned model) -> English (stock model).
        tokenizer.src_lang = src_code
        result_ru = sh_ru_translate(paragraphs, 'rus_Cyrl')
        tokenizer2.src_lang = 'rus_Cyrl'
        result = ru_en_translate(result_ru.split('\n'), 'eng_Latn')
    elif src_lang == 'English' and tgt_lang == 'Shughni':
        # English -> Russian (stock model) -> Shughni (fine-tuned model).
        tokenizer2.src_lang = src_code
        result_ru = ru_en_translate(paragraphs, 'rus_Cyrl')
        tokenizer.src_lang = 'rus_Cyrl'
        result = sh_ru_translate(result_ru.split('\n'), 'shu_Cyrl')
    elif src_lang == 'English' and tgt_lang == 'Russian':
        tokenizer2.src_lang = src_code
        result = ru_en_translate(paragraphs, 'rus_Cyrl')
    elif src_lang == 'Russian' and tgt_lang == 'English':
        tokenizer2.src_lang = src_code
        result = ru_en_translate(paragraphs, 'eng_Latn')
    else:
        # Direct Shughni <-> Russian translation with the fine-tuned model.
        tokenizer.src_lang = src_code
        result = sh_ru_translate(paragraphs, tgt_code)
    if show_latin and tgt_lang == "Shughni":
        result = converter_latn.convert(result).text
    return result
def swap_langs(src, tgt):
    """Exchange the two language dropdown values (callback for the ⇄ button)."""
    swapped_pair = (tgt, src)
    return swapped_pair
# HTML header rendered at the top of the Gradio page.
description = """
<div style="text-align: center;">
<h1>Shughni - Russian - English Translator</h1>
</div>
"""
# Markdown note shown under the UI explaining the English pivot setup.
disclaimer = """
## Disclaimer
This is a demo of a work-in-progress translation project. Translation from and into English is implemented through Russian with a stand alone facebook/nllb-200-distilled-600M model.
"""
# Example rows for gr.Examples: [input text, source language, target language].
examples_inputs = [["wāδ-ен сат wи-рд-ен лу̊д: потх̌о лу̊д иди, ту бойад шич тар wи хез сāwи.","Shughni","Russian"],
["Карамшо прочитал интересную книгу.", "Russian", "Shughni"],
["Азими ху х̌āр тарк чӯд.", "Shughni", "Russian"],
["Азим вышел из дома.", "Russian", "Shughni"],
["Асал х̌увдқати алалаш сут.", "Shughni", "Russian"],
["Лола спела красивую песню.", "Russian", "Shughni"],
["аҷойиб китоб х̌êйдум.", "Shughni", "English"],
["You read an interesting book.", "English", "Shughni"],
]
# CSS tweak that keeps the swap button compact between the two dropdowns.
css = """
#swap-button {
height: 90px !important;
width: 10px !important;
}
"""
# --- Gradio UI layout (component creation order defines the page order) ---
with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    # Language selectors with the swap button between them.
    with gr.Row():
        src_lang = gr.Dropdown(label="Source Language", choices=source_languages)
        swap_button = gr.Button("⇄", elem_id="swap-button")
        target_lang = gr.Dropdown(label="Target Language", choices=target_languages)
    # Clicking the swap button exchanges the two dropdown values.
    swap_button.click(
        fn=swap_langs,
        inputs=[src_lang, target_lang],
        outputs=[src_lang, target_lang],
    )
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=6)
    with gr.Row():
        # Optional Latin transliteration of Shughni output (4th translate arg).
        show_latin = gr.Checkbox(label="Show Shughni Output in Latin", value=False)
    with gr.Row():
        btn = gr.Button("Translate text")
    with gr.Row():
        output = gr.Textbox(label="Output Text", lines=6)
    btn.click(
        translate,
        inputs=[input_text, src_lang, target_lang, show_latin],
        outputs=output,
    )
    # Clickable example rows; cache_examples=True pre-computes the outputs
    # (runs translate on every example at startup, so startup is slow).
    examples = gr.Examples(
        examples=examples_inputs,
        inputs=[input_text, src_lang, target_lang],
        fn=translate,
        outputs=output,
        cache_examples=True
    )
    with gr.Row():
        gr.Markdown(disclaimer)

if __name__ == "__main__":
    # Queue caps concurrent requests so the single model isn't overloaded.
    demo.queue(max_size=20).launch()