import re
import time

import gradio as gr
import torch
from converter import Converter
from sacremoses import MosesPunctNormalizer
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

# UI language names mapped to NLLB language codes ("shu_Cyrl" is the code
# used for Shughni by the fine-tuned checkpoint).
code_mapping = {
    "Russian": "rus_Cyrl",
    "English": "eng_Latn",
    "Shughni": "shu_Cyrl",
}
source_languages = list(code_mapping.keys())[::-1]
target_languages = list(code_mapping.keys())

punct_normalizer = MosesPunctNormalizer(lang="en")
# Shughni script converters: the model works with Cyrillic, but the UI can
# also display Latin output.
converter = Converter(dest="cyr", settings="auto", lang="sgh")
converter_latn = Converter(dest="lat", settings="auto", lang="sgh")

start_time = time.time()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fine-tuned checkpoint for the Shughni <-> Russian directions.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "Novokshanov/nllb-200-distilled-600M-Shughni-v1"
).to(device).eval()
tokenizer = NllbTokenizer.from_pretrained(
    "Novokshanov/nllb-200-distilled-600M-Shughni-v1", src_lang="rus_Cyrl"
)
# Stock NLLB-200 checkpoint for the Russian <-> English directions.
model2 = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M"
).to(device).eval()
tokenizer2 = NllbTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="rus_Cyrl"
)
load_time = time.time() - start_time
print(f"Models loaded in {load_time:.2f} seconds")

# Split paragraphs into sentences on end-of-sentence punctuation.
sentence_splitter = re.compile(r"(?<=[.!?])\s+")


def translate_paragraphs(paragraphs, src_code, tgt_code, model, tokenizer):
    """Translate each paragraph sentence by sentence with the given model."""
    tokenizer.src_lang = src_code
    translated_paragraphs = []
    for paragraph in paragraphs:
        sentences = sentence_splitter.split(paragraph)
        translated_sentences = []
        for sentence in sentences:
            inputs = tokenizer(sentence, return_tensors="pt").to(device)
            output = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                max_length=inputs.input_ids.shape[1] + 50,
                num_return_sequences=1,
                num_beams=5,
                no_repeat_ngram_size=4,  # repetition blocking works better if this number is below num_beams
                renormalize_logits=True,  # recompute token probabilities after banning the repetitions
            )
            translated_sentences.append(
                tokenizer.decode(output[0], skip_special_tokens=True)
            )
        translated_paragraphs.append(" ".join(translated_sentences))
    return "\n".join(translated_paragraphs)


def sh_ru_translate(paragraphs, src_code, tgt_code):
    """Shughni <-> Russian step, handled by the fine-tuned checkpoint."""
    return translate_paragraphs(paragraphs, src_code, tgt_code, model, tokenizer)


def ru_en_translate(paragraphs, src_code, tgt_code):
    """Russian <-> English step, handled by the stock NLLB-200 checkpoint."""
    return translate_paragraphs(paragraphs, src_code, tgt_code, model2, tokenizer2)
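# Illustrative calls only (not executed at import time; the sample sentence
# is hypothetical):
#   sh_ru_translate(["Добрый день."], "rus_Cyrl", "shu_Cyrl")  # Russian -> Shughni
#   ru_en_translate(["Добрый день."], "rus_Cyrl", "eng_Latn")  # Russian -> English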
def translate(text: str, src_lang: str, tgt_lang: str, show_latin: bool = False):
    src_code = code_mapping[src_lang]
    tgt_code = code_mapping[tgt_lang]
    # normalizing the punctuation first
    text = punct_normalizer.normalize(text)
    if src_lang == "Shughni":
        # the model expects Cyrillic Shughni, so convert the input script
        text = converter.convert(text).text
    paragraphs = text.split("\n")
    if src_lang == "Shughni" and tgt_lang == "English":
        # no direct Shughni <-> English model, so pivot through Russian
        result_ru = sh_ru_translate(paragraphs, src_code, "rus_Cyrl")
        result = ru_en_translate(result_ru.split("\n"), "rus_Cyrl", "eng_Latn")
    elif src_lang == "English" and tgt_lang == "Shughni":
        result_ru = ru_en_translate(paragraphs, src_code, "rus_Cyrl")
        result = sh_ru_translate(result_ru.split("\n"), "rus_Cyrl", "shu_Cyrl")
    elif "Shughni" not in (src_lang, tgt_lang):
        # Russian <-> English goes straight through the stock checkpoint
        result = ru_en_translate(paragraphs, src_code, tgt_code)
    else:
        # Shughni <-> Russian goes straight through the fine-tuned checkpoint
        result = sh_ru_translate(paragraphs, src_code, tgt_code)
    if show_latin and tgt_lang == "Shughni":
        result = converter_latn.convert(result).text
    return result


def swap_langs(src, tgt):
    return tgt, src


description = """