"""Gradio demo that translates text with the NLLB-200 distilled 600M model."""

import platform

import gradio as gr
import spaces
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from flores import code_mapping

# macOS ("Darwin") runs on CPU; every other platform is assumed to have CUDA.
device = "cpu" if platform.system() == "Darwin" else "cuda"

MODEL_NAME = "facebook/nllb-200-distilled-600M"

# Paragraphs longer than this many characters are translated sentence by
# sentence instead of in one call.
MAX_CHUNK_CHARS = 500
# `max_length` forwarded to the translation pipeline for each call.
MAX_OUTPUT_LENGTH = 400

# Re-order the mapping by its values. The keys are what the dropdowns display
# and what `_translate` looks up; the values are what is handed to the
# pipeline as `src_lang` / `tgt_lang`.
code_mapping = dict(sorted(code_mapping.items(), key=lambda item: item[1]))
flores_codes = list(code_mapping.keys())


def load_model():
    """Load the NLLB model (moved to `device`) and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple for ``MODEL_NAME``.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer


model, tokenizer = load_model()


# NOTE(review): presumably `spaces.GPU` schedules this call on a GPU when the
# app runs on Hugging Face Spaces -- confirm against the `spaces` package docs.
@spaces.GPU
def _translate(text: str, src_lang: str, tgt_lang: str):
    """Translate a single piece of text in one pipeline call.

    Args:
        text: The text to translate.
        src_lang: Key of ``code_mapping`` identifying the source language.
        tgt_lang: Key of ``code_mapping`` identifying the target language.

    Returns:
        The ``translation_text`` of the first pipeline result.

    Raises:
        KeyError: If either language is not a key of ``code_mapping``.
    """
    source = code_mapping[src_lang]
    target = code_mapping[tgt_lang]

    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
        device=device,
    )
    output = translator(text, max_length=MAX_OUTPUT_LENGTH)
    return output[0]["translation_text"]


def translate(text: str, src_lang: str, tgt_lang: str):
    """Translate arbitrary-length text paragraph by paragraph.

    The input is split on newlines. Paragraphs longer than
    ``MAX_CHUNK_CHARS`` characters are further split on full stops and
    translated one sentence at a time. Blank paragraphs and empty sentence
    fragments (e.g. the one produced by a trailing full stop) are skipped
    rather than sent to the model.

    Args:
        text: The text to translate; may be empty.
        src_lang: Key of ``code_mapping`` identifying the source language.
        tgt_lang: Key of ``code_mapping`` identifying the target language.

    Returns:
        The translated paragraphs joined by a blank line, or ``""`` when
        there is nothing to translate.

    Raises:
        gr.Error: If either language has not been selected or is unknown.
    """
    # The dropdowns have no default value, so an unselected one arrives as
    # None; report that to the user instead of failing with a bare KeyError.
    if src_lang not in code_mapping or tgt_lang not in code_mapping:
        raise gr.Error("Please select both a source and a target language.")

    if not text or not text.strip():
        return ""

    translated_paragraphs = []
    # split the input text into smaller chunks, first on newlines
    for chunk in text.split("\n"):
        if not chunk.strip():
            continue

        if len(chunk) > MAX_CHUNK_CHARS:
            # the chunk is too long: split on full stops and restore the
            # full stop after each translated sentence
            sentences = (part.strip() for part in chunk.split("."))
            translated = " ".join(
                f"{_translate(sentence, src_lang, tgt_lang)}."
                for sentence in sentences
                if sentence
            )
        else:
            translated = _translate(chunk, src_lang, tgt_lang)

        if translated:
            translated_paragraphs.append(translated)

    return "\n\n".join(translated_paragraphs)


description = """
No Language Left Behind (NLLB) is a series of open-source models aiming to provide high-quality translations between 200 language."""

with gr.Blocks() as demo:
    gr.Markdown("# No Language Left Behind (NLLB) Translation Demo")
    gr.Markdown(description)
    with gr.Row():
        src_lang = gr.Dropdown(label="Source Language", choices=flores_codes)
        target_lang = gr.Dropdown(label="Target Language", choices=flores_codes)
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=6)
    with gr.Row():
        btn = gr.Button("Translate text")
    with gr.Row():
        output = gr.Textbox(label="Output Text", lines=6)
    btn.click(
        translate,
        inputs=[input_text, src_lang, target_lang],
        outputs=output,
    )

demo.launch()