# RedDev's picture
# Update app.py
# c602df2
import os
import pandas as pd
from transformers import NllbTokenizer, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, Accelerator
import torch
import gradio as gr
def fix_tokenizer(tokenizer, new_lang='tok_Latn'):
    """Register ``new_lang`` as a language-code token on ``tokenizer``, in place.

    This has to be repeated every time the tokenizer is initialized.

    Args:
        tokenizer: Tokenizer object exposing ``lang_code_to_id``,
            ``id_to_lang_code``, ``fairseq_tokens_to_ids``,
            ``fairseq_ids_to_tokens``, ``fairseq_offset``, ``sp_model``,
            ``_additional_special_tokens`` and the ``added_tokens_*`` dicts.
        new_lang: Language code to add.

    Returns:
        None. ``tokenizer`` is modified in place.
    """
    # If the code already sits in the added-token table it is counted in
    # len(tokenizer), so discount it before deriving the new id.
    already_added = new_lang in tokenizer.added_tokens_encoder
    base_size = len(tokenizer) - int(already_added)
    new_lang_id = base_size - 1

    tokenizer.lang_code_to_id[new_lang] = new_lang_id
    tokenizer.id_to_lang_code[new_lang_id] = new_lang

    # "<mask>" is always moved to the last position, after every language code.
    mask_id = (
        len(tokenizer.sp_model)
        + len(tokenizer.lang_code_to_id)
        + tokenizer.fairseq_offset
    )
    tokenizer.fairseq_tokens_to_ids["<mask>"] = mask_id
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {
        token_id: token
        for token, token_id in tokenizer.fairseq_tokens_to_ids.items()
    }

    special_tokens = tokenizer._additional_special_tokens
    if new_lang not in special_tokens:
        special_tokens.append(new_lang)

    # Reset the added-token tables; otherwise a new token may end up there
    # by mistake.
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
# Module-level setup: load the fine-tuned model and its tokenizer once at
# import time so that every translate() call reuses them.
# NOTE(review): `Accelerator` is imported from `transformers` at the top of
# the file; it normally lives in the `accelerate` package - confirm the import.
accelerator = Accelerator(cpu=True)
model_load_name = 'RedDev/nllb-deu-tok-v1'
model = AutoModelForSeq2SeqLM.from_pretrained(model_load_name)
# Move the model to the GPU only when one is present; an unconditional
# .cuda() raises on CPU-only hosts.
if torch.cuda.is_available():
    model = model.cuda()
tokenizer = NllbTokenizer.from_pretrained(model_load_name)
# The new language code must be re-registered after every tokenizer init.
fix_tokenizer(tokenizer)
def translate(text, src_lang='deu_Latn', tgt_lang='tok_Latn', a=32, b=3, max_input_length=1024, num_beams=4, **kwargs):
    """Translate ``text`` with the module-level ``model`` and ``tokenizer``.

    Args:
        text: Input passed straight to the tokenizer.
        src_lang: Language code assigned to ``tokenizer.src_lang``.
        tgt_lang: Language code assigned to ``tokenizer.tgt_lang`` and forced
            as the first generated token.
        a: Constant term of the output-length budget.
        b: Per-input-token factor of the output-length budget; generation is
            capped at ``a + b * <tokenized input length>`` new tokens.
        max_input_length: Inputs are truncated to this many tokens.
        num_beams: Beam count forwarded to ``model.generate``.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated ids, with
        special tokens skipped.
    """
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang

    encoded = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_input_length,
    )
    encoded = encoded.to(model.device)

    # Output budget grows linearly with the tokenized input length.
    source_len = encoded.input_ids.shape[1]
    target_bos_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    generated = model.generate(
        **encoded,
        forced_bos_token_id=target_bos_id,
        max_new_tokens=int(a + b * source_len),
        num_beams=num_beams,
        **kwargs
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
# Display name shown in the UI -> language code used by the tokenizer.
LANG_CODES = {
    "Deutsch": "deu_Latn",
    "toki pona": "tok_Latn",
}
if __name__ == '__main__':
    print('\tinit models')

    def translate_from_ui(source_name, target_name, text):
        """Adapt the UI inputs (Source, Target, Input text) to translate().

        The interface passes its inputs positionally in the order they are
        declared below, and the dropdowns yield display names. translate()
        expects the text first and language codes, so reorder the arguments
        and map the names through LANG_CODES.
        """
        return translate(
            text,
            src_lang=LANG_CODES[source_name],
            tgt_lang=LANG_CODES[target_name],
        )

    # define gradio demo
    lang_names = list(LANG_CODES.keys())
    inputs = [
        gr.inputs.Dropdown(lang_names, default='Deutsch', label='Source'),
        gr.inputs.Dropdown(lang_names, default='toki pona', label='Target'),
        gr.inputs.Textbox(lines=10, label="Input text"),
    ]
    outputs = gr.outputs.JSON()

    title = "NLLB Deutsch toki pona"
    description = "Details: https://github.com/facebookresearch/fairseq/tree/nllb."
    # Each example row follows the same (Source, Target, Input text) order
    # as `inputs`.
    examples = [
        ['Deutsch', 'toki pona', 'Er stand auf und sah zum Fenster hinaus. Da war es hell und er war glücklich. Aber es war Nacht also half er seinen Großeltern beim Einkauf, danach gibt es Eiscreme und er überwachte den garten.']
    ]

    gr.Interface(
        translate_from_ui,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    ).launch()