import os
import urllib.request

import fasttext
import gradio as gr
import numpy as np
import torch
from transformers import MBartForConditionalGeneration, MBart50Tokenizer

MODEL_URL_MYV_MUL = 'slone/mbart-large-51-myv-mul-v1'
MODEL_URL_MUL_MYV = 'slone/mbart-large-51-mul-myv-v1'
MODEL_URL_LANGID = 'https://huggingface.co/slone/fastText-LID-323/resolve/main/lid.323.ftz'
MODEL_PATH_LANGID = 'lid.323.ftz'

HF_TOKEN = os.getenv('HF_TOKEN')

# Logs flagged examples (and optional corrections) to a private HF dataset.
hf_writer = gr.HuggingFaceDatasetSaver(
    hf_token=HF_TOKEN,
    dataset_name="myv-translation-2022-demo-flags-v2",
    organization="slone",
    private=True,
)

lang_to_code = {
    'Эрзянь | Erzya': 'myv_XX',
    'Русский | Рузонь | Russian': 'ru_RU',
    'Suomi | Суоминь | Finnish': 'fi_FI',
    'Deutsch | Немецень | German': 'de_DE',
    'Español | Испанонь | Spanish': 'es_XX',
    'English | Англань': 'en_XX',
    'हिन्दी | Хинди | Hindi': 'hi_IN',
    '漢語 | Китаень | Chinese': 'zh_CN',
    'Türkçe | Турконь | Turkish': 'tr_TR',
    'Українська | Украинань | Ukrainian': 'uk_UA',
    'Français | Французонь | French': 'fr_XX',
    'العربية | Арабонь | Arabic': 'ar_AR',
}


def fix_tokenizer(tokenizer, extra_lang='myv_XX'):
    """Add a new language id to an mBART-50 tokenizer (because it is not serialized) and shift the mask token id."""
    old_len = len(tokenizer) - int(extra_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[extra_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = extra_lang
    # The mask token sits after all the language codes, so adding a code shifts it by one.
    tokenizer.fairseq_tokens_to_ids["<mask>"] = (
        len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    )
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if extra_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(extra_lang)
    tokenizer.added_tokens_encoder = {}


def translate(
    text, model, tokenizer,
    src='ru_RU', trg='myv_XX',
    max_length='auto', num_beams=3, repetition_penalty=5.0,
    train_mode=False, n_out=None,
    **kwargs
):
    tokenizer.src_lang = src
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
    if max_length == 'auto':
        # Allow the output to be somewhat longer than the input.
        max_length = int(32 + 1.5 * encoded.input_ids.shape[1])
    if train_mode:
        model.train()
    else:
        model.eval()
    generated_tokens = model.generate(
        **encoded.to(model.device),
        forced_bos_token_id=tokenizer.lang_code_to_id[trg],
        max_length=max_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        # early_stopping=True,
        num_return_sequences=n_out or 1,
        **kwargs
    )
    out = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    if isinstance(text, str) and n_out is None:
        return out[0]
    return out


def translate_rerank(
    text, model, tokenizer,
    src='ru_RU', trg='myv_XX',
    max_length='auto', num_beams=3, repetition_penalty=5.0,
    train_mode=False,
    n=5, diversity_penalty=3.0, lang='myv', max_score=0.3, order_penalty=0.01,
    verbose=False,
    **kwargs
):
    """Generate n candidates with diverse beam search and pick the one that looks most like `lang`."""
    texts = translate(
        text, model, tokenizer, src, trg,
        max_length=max_length, train_mode=train_mode, repetition_penalty=repetition_penalty,
        num_beams=n, num_beam_groups=n, diversity_penalty=diversity_penalty, n_out=n,
        **kwargs
    )
    scores = [get_mean_lang_score(t, lang=lang, max_score=max_score) for t in texts]
    # Penalize lower-ranked beams slightly, so ties are broken by the model's own ordering.
    pen_scores = np.array(scores) - order_penalty * np.arange(n)
    if verbose:
        print(texts)
        print(scores)
        print(pen_scores)
    return texts[np.argmax(pen_scores)]
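# Usage sketch for the reranker (hypothetical inputs; assumes the models and
# tokenizer loaded in the __main__ block below). Diverse beam search returns n
# distinct candidates, and the one whose words the LID model most confidently
# labels as Erzya wins:
#
#   best = translate_rerank(
#       'Привет, мир!', model_mul_myv, tokenizer,
#       src='ru_RU', trg='myv_XX', n=5, diversity_penalty=3.0,
#   )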
def get_mean_lang_score(text, lang='myv', k=300, max_score=0.3):
    """Return the mean confidence (capped at max_score) that each word, and the whole text, is in `lang`."""
    # Score each word separately, plus the full text as one more item.
    words = text.split() + [text]
    res = []
    for langs, scores in zip(*langid_model.predict(words, k=k)):
        # Strip the '__label__' prefix (9 characters) from fastText labels.
        d = dict(zip([l[9:] for l in langs], scores))
        score = min(d.get(lang, 0), max_score) / max_score
        res.append(score)
    # print(res)
    return np.mean(res)


def translate_wrapper(text, src, trg, correct=None):
    # `correct` is the user's correction textbox; it is stored only by the flagging callback.
    src = lang_to_code.get(src)
    trg = lang_to_code.get(trg)
    if src == trg:
        return 'Please choose two different languages'
    if src == 'myv_XX':
        model = model_myv_mul
    elif trg == 'myv_XX':
        model = model_mul_myv
    else:
        return 'Please translate to or from Erzya'
    print(text, src, trg)
    # Rerank by language-id score only when translating into Erzya.
    fn = translate_rerank if trg == 'myv_XX' else translate
    result = fn(
        text=text,
        model=model,
        tokenizer=tokenizer,
        src=src,
        trg=trg,
    )
    return result


article = """
Те эрзянь кельсэ автоматической васенце ютавтыця. Тонавкстнэ улить – [сёрмадовкссо](https://arxiv.org/abs/2209.09368).

Это первый автоматический переводчик для эрзянского языка. Подробности – в [статье](https://arxiv.org/abs/2209.09368).
Пожалуйста, оставляйте своё мнение о качестве переводов с помощью кнопок с эмодзи!

This is the first automatic translator for the Erzya language. The details are in the [paper](https://arxiv.org/abs/2209.09368).
Please leave your feedback about the quality of translations using the buttons with emojis.

The code and models for translation can be found in the repository: https://github.com/slone-nlp/myv-nmt.
"""

# In English: "If the model's translation is wrong, enter the correct text here, press 'Submit' again,
# and then 'bad 🙁'. Our database will then record that the translation was wrong, along with its correction."
fix_instruction = 'Если перевод модели неправильный, впишите сюда правильный текст, снова нажмите "Исполнить", и затем "bad 🙁". ' \
                  'Тогда к нам в базу попадёт пометка, что перевод был неверным, и его исправление.'

interface = gr.Interface(
    translate_wrapper,
    [
        gr.Textbox(label="Text / текст", lines=2, placeholder='text to translate / текст ютавтозь'),
        gr.Dropdown(list(lang_to_code.keys()), type="value", label='source language / васенце кель',
                    value=list(lang_to_code.keys())[0]),
        gr.Dropdown(list(lang_to_code.keys()), type="value", label='target language / эрявикс кель',
                    value=list(lang_to_code.keys())[1]),
        gr.Textbox(label="Correct translation", lines=2, placeholder=fix_instruction),
    ],
    "text",
    allow_flagging="manual",
    flagging_options=["good 🙂", "50/50 😐", "bad 🙁"],
    flagging_callback=hf_writer,
    title='Эрзянь ютавтыця | Эрзянский переводчик | Erzya translator',
    article=article,
)

if __name__ == '__main__':
    model_mul_myv = MBartForConditionalGeneration.from_pretrained(MODEL_URL_MUL_MYV)
    model_myv_mul = MBartForConditionalGeneration.from_pretrained(MODEL_URL_MYV_MUL)
    if torch.cuda.is_available():
        model_mul_myv.cuda()
        model_myv_mul.cuda()

    # Both directions share one tokenizer; patch it with the myv_XX language code.
    tokenizer = MBart50Tokenizer.from_pretrained(MODEL_URL_MYV_MUL)
    fix_tokenizer(tokenizer)

    if not os.path.exists(MODEL_PATH_LANGID):
        print('downloading LID model...')
        urllib.request.urlretrieve(MODEL_URL_LANGID, MODEL_PATH_LANGID)
    langid_model = fasttext.load_model(MODEL_PATH_LANGID)

    interface.launch()
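# Local smoke test (hypothetical; call inside the __main__ block instead of
# interface.launch() to check the models without starting the web UI):
#
#   print(translate_wrapper('Hello, world!', 'English | Англань', 'Эрзянь | Erzya'))
#
# On CPU the two mBART-large checkpoints are slow (seconds per sentence) and
# need several GB of RAM; the .cuda() branch above is optional.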