import torch
import gradio as gr
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM


def fix_tokenizer(tokenizer, new_lang='tok_Latn'):
    """Add a new language token to the tokenizer vocabulary
    (this should be done each time after its initialization)."""
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # always move "<mask>" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}


model_load_name = 'RedDev/nllb-deu-tok-v1'
# use the GPU when one is available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForSeq2SeqLM.from_pretrained(model_load_name).to(device)
tokenizer = NllbTokenizer.from_pretrained(model_load_name)
fix_tokenizer(tokenizer)


def translate(text, src_lang='deu_Latn', tgt_lang='tok_Latn',
              a=32, b=3, max_input_length=1024, num_beams=4, **kwargs):
    """Translate a text (or list of texts) with the fine-tuned NLLB model."""
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(text, return_tensors='pt', padding=True,
                       truncation=True, max_length=max_input_length)
    result = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        # cap the output length relative to the input length
        max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
        num_beams=num_beams,
        **kwargs
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)


LANG_CODES = {
    "Deutsch": "deu_Latn",
    "toki pona": "tok_Latn",
}


def translate_interface(src_lang_name, tgt_lang_name, text):
    """Map the display names from the dropdowns to NLLB language codes
    and pass the arguments to translate() in the order it expects."""
    return translate(text, src_lang=LANG_CODES[src_lang_name], tgt_lang=LANG_CODES[tgt_lang_name])


if __name__ == '__main__':
    print('\tinit models')

    # define gradio demo
    lang_codes = list(LANG_CODES.keys())
    # inputs = [gr.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
    inputs = [
        gr.Dropdown(lang_codes, value='Deutsch', label='Source'),
        gr.Dropdown(lang_codes, value='toki pona', label='Target'),
        gr.Textbox(lines=10, label='Input text'),
    ]
    outputs = gr.JSON()

    title = "NLLB Deutsch toki pona"
    description = "Details: https://github.com/facebookresearch/fairseq/tree/nllb."

    examples = [
        ['Deutsch', 'toki pona',
         'Er stand auf und sah zum Fenster hinaus. Da war es hell und er war glücklich. '
         'Aber es war Nacht also half er seinen Großeltern beim Einkauf, danach gibt es Eiscreme '
         'und er überwachte den garten.'],
    ]

    gr.Interface(
        translate_interface,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    ).launch()
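
# Optional quick check from a Python shell, without launching the Gradio app
# (a minimal sketch assuming the model checkpoint above loads successfully;
# translate() returns a list with one decoded string per input):
#
#   >>> from app import translate
#   >>> translate('Guten Morgen!', src_lang='deu_Latn', tgt_lang='tok_Latn')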