Spaces:
Runtime error
import os | |
import pandas as pd | |
from transformers import NllbTokenizer, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, Accelerator | |
import torch | |
import gradio as gr | |
def fix_tokenizer(tokenizer, new_lang: str = 'tok_Latn') -> None:
    """ Add a new language token to the tokenizer vocabulary (this should be done each time after its initialization)

    Registers ``new_lang`` as an extra language code on the tokenizer by
    patching its internal lookup tables in place. Nothing is returned.

    Args:
        tokenizer: the tokenizer to patch; it is modified in place.
        new_lang: the language code to register (default ``'tok_Latn'``).

    NOTE(review): this writes to internal attributes (``lang_code_to_id``,
    ``id_to_lang_code``, ``fairseq_tokens_to_ids``, ``fairseq_ids_to_tokens``,
    ``sp_model``, ``fairseq_offset``, ``_additional_special_tokens``,
    ``added_tokens_encoder``/``added_tokens_decoder``). These presumably exist
    only in certain (older) ``transformers`` releases -- pin and verify the
    version this Space runs with.
    """
    # The new code takes id ``old_len - 1``. If ``new_lang`` is already present
    # as an added token, it is not counted in ``old_len``.
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len-1
    tokenizer.id_to_lang_code[old_len-1] = new_lang
    # always move "mask" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    # Merge the language codes into the token -> id map, then rebuild the
    # reverse id -> token map from it so both directions stay consistent.
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    # Register the code among the additional special tokens (only once).
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
# Load the fine-tuned German/toki pona NLLB checkpoint once, at import time.
#
# NOTE(review): `transformers` does not provide `Accelerator` (it belongs to the
# separate `accelerate` package), so importing it from `transformers` fails.
# The Accelerator object created here was never used, so it has been dropped;
# `Accelerator` should also be removed from the transformers import line.
model_load_name = 'RedDev/nllb-deu-tok-v1'
# Use the GPU only when one is present. The previous unconditional `.cuda()`
# crashed on CPU-only hardware (and contradicted the `cpu=True` intent).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForSeq2SeqLM.from_pretrained(model_load_name).to(device)
tokenizer = NllbTokenizer.from_pretrained(model_load_name)
# Patch the freshly loaded tokenizer so it knows the added 'tok_Latn' code.
fix_tokenizer(tokenizer)
def translate(text, src_lang='deu_Latn', tgt_lang='tok_Latn', a=32, b=3, max_input_length=1024, num_beams=4, **kwargs):
    """Translate ``text`` from ``src_lang`` into ``tgt_lang`` with the module-level model.

    Args:
        text: input to translate (a string; presumably also a list of strings,
            since padding is enabled for batching).
        src_lang: language code of the input, e.g. ``'deu_Latn'``.
        tgt_lang: language code of the output, e.g. ``'tok_Latn'``.
        a, b: the generation budget is ``a + b * input_length`` new tokens.
        max_input_length: inputs are truncated to this many tokens.
        num_beams: beam width passed to ``model.generate``.
        **kwargs: forwarded unchanged to ``model.generate``.

    Returns:
        A list of decoded strings, one per generated sequence.
    """
    # The tokenizer reads these attributes when building the encoded input.
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=max_input_length)

    source_length = encoded.input_ids.shape[1]
    budget = int(a + b * source_length)
    # Force the decoder to start with the target-language token.
    target_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    generated = model.generate(
        **encoded.to(model.device),
        forced_bos_token_id=target_token_id,
        max_new_tokens=budget,
        num_beams=num_beams,
        **kwargs
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
# Display name shown in the UI -> language code understood by the tokenizer.
LANG_CODES = {
    'Deutsch': 'deu_Latn',
    'toki pona': 'tok_Latn',  # code registered by fix_tokenizer
}
def _translate_ui(src_name, tgt_name, text):
    """Gradio callback: translate ``text`` between two languages picked by display name.

    The Interface calls its function with the input components' values
    positionally, in the order (Source, Target, Input text). That order does
    not match ``translate(text, src_lang, tgt_lang)``. This adapter reorders
    the arguments and maps the display names to language codes via
    ``LANG_CODES``.

    Returns:
        The list produced by ``translate``, or ``[]`` when ``text`` is blank.
    """
    if not text or not text.strip():
        # Nothing to translate; avoid running beam search on empty input.
        return []
    return translate(text, src_lang=LANG_CODES[src_name], tgt_lang=LANG_CODES[tgt_name])


if __name__ == '__main__':
    print('\tinit models')
    # define gradio demo
    lang_codes = list(LANG_CODES.keys())
    # `gr.inputs`/`gr.outputs` and the `default=` keyword were removed in
    # Gradio 4; top-level components with `value=` work in Gradio 3 and 4.
    inputs = [
        gr.Dropdown(lang_codes, value='Deutsch', label='Source'),
        gr.Dropdown(lang_codes, value='toki pona', label='Target'),
        gr.Textbox(lines=10, label="Input text"),
    ]
    outputs = gr.JSON()
    title = "NLLB Deutsch toki pona"
    description = "Details: https://github.com/facebookresearch/fairseq/tree/nllb."
    # Each example row follows the order of `inputs`: source, target, text.
    examples = [
        ['Deutsch', 'toki pona', 'Er stand auf und sah zum Fenster hinaus. Da war es hell und er war glücklich. Aber es war Nacht also half er seinen Großeltern beim Einkauf, danach gibt es Eiscreme und er überwachte den garten.'],
    ]
    gr.Interface(
        _translate_ui,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    ).launch()