Spaces:
Runtime error
Runtime error
import gradio as gr | |
import time | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
from flores200_codes import flores_codes | |
def load_models():
    """Load every enabled NLLB checkpoint together with its tokenizer.

    Returns:
        A dict with two entries per checkpoint: ``"<call name>_model"`` holds
        the seq2seq model and ``"<call name>_tokenizer"`` holds its tokenizer.
    """
    # Short call name -> checkpoint identifier passed to ``from_pretrained``.
    checkpoints = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
        "nllb-distilled-1.3B": "facebook/nllb-200-distilled-1.3B",
        # "nllb-1.3B": "facebook/nllb-200-1.3B",
        # "nllb-3.3B": "facebook/nllb-200-3.3B",
    }

    loaded = {}
    for alias in checkpoints:
        checkpoint_id = checkpoints[alias]
        print("\tLoading model: %s" % alias)
        # The model is loaded before its tokenizer for each checkpoint.
        loaded[alias + "_model"] = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_id)
        loaded[alias + "_tokenizer"] = AutoTokenizer.from_pretrained(checkpoint_id)
    return loaded
def translation(model_name, source, target, text):
    """Translate ``text`` line by line with one of the loaded models.

    Args:
        model_name: Call name of a loaded checkpoint; the entries
            ``model_name + "_model"`` and ``model_name + "_tokenizer"`` are
            read from the module-level ``model_dict``.
        source: Key into ``flores_codes`` for the input language.
        target: Key into ``flores_codes`` for the output language.
        text: Input text. Each ``"\\n"``-separated line is translated on its
            own; ``None`` is treated as empty text.

    Returns:
        A ``(result, output)`` tuple. ``output`` is the translated text with
        the input's line structure preserved. ``result`` is a dict with the
        keys ``"inference_time"`` (seconds), ``"source"`` and ``"target"``
        (the values looked up in ``flores_codes``) and ``"result"`` (same
        string as ``output``).

    Raises:
        KeyError: If ``source`` or ``target`` is not in ``flores_codes``, or
            if ``model_name`` has no entries in ``model_dict``.
    """
    start_time = time.time()

    source = flores_codes[source]
    target = flores_codes[target]

    translator = pipeline(
        "translation",
        model=model_dict[model_name + "_model"],
        tokenizer=model_dict[model_name + "_tokenizer"],
        src_lang=source,
        tgt_lang=target,
    )

    # Sentence-wise translation. Blank lines are emitted as empty lines
    # without calling the model: there is nothing to translate, and running
    # inference on them only costs time.
    translated_sentences = [
        translator(sentence, max_length=400)[0]["translation_text"]
        if sentence.strip()
        else ""
        for sentence in (text or "").split("\n")
    ]
    output = "\n".join(translated_sentences)

    end_time = time.time()

    result = {
        "inference_time": end_time - start_time,
        "source": source,
        "target": target,
        "result": output,
    }
    return result, output
if __name__ == "__main__":
    print("\tinit models")
    # ``translation`` reads this module-level name at call time, so it must be
    # assigned before the interface starts serving requests.
    model_dict = load_models()

    # Define the Gradio demo. Components come from the top-level ``gr``
    # namespace with ``value=`` for the initial selection; the legacy
    # ``gr.inputs`` / ``gr.outputs`` modules and their ``default=`` keyword
    # are deprecated and no longer available in current Gradio releases.
    lang_codes = list(flores_codes.keys())
    inputs = [
        gr.Radio(
            [
                "nllb-distilled-600M",
                "nllb-distilled-1.3B",
                # "nllb-1.3B",
                # "nllb-3.3B"
            ],
            label="NLLB Model",
            value="nllb-distilled-1.3B",
        ),
        gr.Dropdown(lang_codes, value="Najdi Arabic", label="Source"),
        gr.Dropdown(lang_codes, value="English", label="Target"),
        gr.Textbox(lines=5, label="Input text"),
    ]
    # Order matches the ``(result, output)`` tuple returned by ``translation``.
    outputs = [
        gr.JSON(label="Metadata"),
        gr.Textbox(label="Output text"),
    ]

    title = "NLLB (No Language Left Behind) demo"
    demo_status = "Demo is running on CPU"
    description = f"Using NLLB model, details: https://github.com/facebookresearch/fairseq/tree/nllb.\n{demo_status}"

    # Each example row follows the order of ``inputs``:
    # model, source language, target language, text.
    examples = [
        ["nllb-distilled-1.3B", "Najdi Arabic", "English", "جلست اطفال"],
        [
            "nllb-distilled-600M",
            "Najdi Arabic",
            "English",
            "شد للبيع طابقين مع شرع له نظيف حق غمارتين",
        ],
    ]

    gr.Interface(
        translation,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    ).launch()