nllb / app.py
davanstrien's picture
davanstrien HF staff
Refactor translate function to split input text into smaller chunks
15ccfd9
raw
history blame
2.17 kB
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores import code_mapping
import platform
# Checkpoint that is loaded for both the model and the tokenizer.
MODEL_NAME = "facebook/nllb-200-distilled-600M"

# macOS ("Darwin") runs on the CPU; every other platform is expected to
# provide a CUDA device.
device = "cuda" if platform.system() != "Darwin" else "cpu"
# Rebuild the mapping ordered by its values so the language dropdowns are
# listed in a stable, sorted order.
# NOTE(review): presumably maps a display name to a FLORES-200 language
# code -- confirm against flores.py.
code_mapping = {
    name: code
    for name, code in sorted(code_mapping.items(), key=lambda pair: pair[1])
}

# The mapping's keys, in sorted-by-value order; used as the dropdown choices.
flores_codes = list(code_mapping)
def load_model():
    """Load the seq2seq model and tokenizer for ``MODEL_NAME``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has already been moved to
        the module-level ``device``.
    """
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    seq2seq = seq2seq.to(device)
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    return seq2seq, tok
# Load once at import time so every translation request reuses the same
# model and tokenizer objects.
model, tokenizer = load_model()
# NOTE(review): `spaces.GPU` presumably requests a GPU for the duration of
# the call on ZeroGPU Spaces -- confirm against the `spaces` documentation.
@spaces.GPU
def _translate(text: str, src_lang: str, tgt_lang: str):
    """Translate a single piece of text with the shared model.

    Args:
        text: The text to translate.
        src_lang: Source language; must be a key of ``code_mapping``.
        tgt_lang: Target language; must be a key of ``code_mapping``.

    Returns:
        The ``"translation_text"`` field of the first pipeline result.

    Raises:
        KeyError: If ``src_lang`` or ``tgt_lang`` is not in ``code_mapping``.
    """
    # The language pair changes per request, so the pipeline is assembled
    # per call around the already-loaded model and tokenizer.
    nllb_pipeline = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=code_mapping[src_lang],
        tgt_lang=code_mapping[tgt_lang],
        device=device,
    )
    results = nllb_pipeline(text, max_length=400)
    return results[0]["translation_text"]
def translate(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate ``text`` line by line and return the re-joined result.

    The input is split on newlines and every non-blank line is translated
    on its own via ``_translate``, which keeps each model input short.
    Blank (empty or whitespace-only) lines are passed through untouched, so
    the line structure of the output mirrors the input exactly and no
    trailing newline is added.

    Args:
        text: The text to translate; may span multiple lines.
        src_lang: Source language name, forwarded to ``_translate``.
        tgt_lang: Target language name, forwarded to ``_translate``.

    Returns:
        The translated lines joined with ``"\\n"``. An input without any
        translatable content is returned unchanged.

    Raises:
        KeyError: Propagated from ``_translate`` when a language name is
            unknown (only if there is at least one non-blank line).
    """
    translated_lines = []
    for line in text.split("\n"):
        if not line.strip():
            # Nothing to translate: skip the model call and keep the line
            # so paragraph breaks survive.
            translated_lines.append(line)
            continue
        translated_lines.append(_translate(line, src_lang, tgt_lang))
    return "\n".join(translated_lines)
# Intro blurb rendered as Markdown directly under the page title.
description = """
No Language Left Behind (NLLB) is a series of open-source models aiming to provide high-quality translations between 200 languages."""
# Page layout: title, blurb, language pickers, input box, button, output box.
with gr.Blocks() as demo:
    gr.Markdown("# No Language Left Behind (NLLB) Translation Demo")
    gr.Markdown(description)

    # Source and target language pickers share one row.
    with gr.Row():
        src_lang = gr.Dropdown(label="Source Language", choices=flores_codes)
        target_lang = gr.Dropdown(label="Target Language", choices=flores_codes)

    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=6)

    with gr.Row():
        btn = gr.Button("Translate text")

    with gr.Row():
        output = gr.Textbox(label="Output Text", lines=6)

    # Clicking the button sends (text, source, target) to `translate` and
    # writes its return value into the output box.
    btn.click(
        fn=translate,
        inputs=[input_text, src_lang, target_lang],
        outputs=output,
    )

# Start the web server for the demo.
demo.launch()