"""Gradio demo for English <-> Kinyarwanda translation with a fine-tuned NLLB model."""

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# NLLB language codes offered in the UI dropdowns.
LANGS = ["kin_Latn", "eng_Latn"]
TASK = "translation"

# Pipeline device index: first CUDA GPU when available, otherwise CPU (-1).
device = 0 if torch.cuda.is_available() else -1

# Other checkpoints previously referenced by this demo (not loaded here):
#   facebook/nllb-200-distilled-600M, DigitalUmuganda/Finetuned-NLLB,
#   mbazaNLP/Nllb_finetuned_general_en_kin, mbazaNLP/Nllb_finetuned_tourism_en_kin
education_model = AutoModelForSeq2SeqLM.from_pretrained(
    "mbazaNLP/Nllb_finetuned_education_en_kin"
)
# NOTE(review): the tokenizer is loaded from the *general* checkpoint while the
# model is the *education* one; presumably both share the base NLLB vocabulary
# — confirm.
tokenizer = AutoTokenizer.from_pretrained("mbazaNLP/Nllb_finetuned_general_en_kin")

# Build the pipeline once at start-up instead of on every request: rebuilding it
# per call repeated the setup work (including placing the model on `device`).
# The language pair and max_length are supplied per request in `translate`.
translation_pipeline = pipeline(
    TASK,
    model=education_model,
    tokenizer=tokenizer,
    device=device,
)


def translate(text, source_lang, target_lang, max_length=400):
    """Translate text from source language to target language.

    Args:
        text: Input text to translate.
        source_lang: NLLB code of the input language; must be one of ``LANGS``.
        target_lang: NLLB code of the output language; must be one of ``LANGS``.
        max_length: Maximum generation length forwarded to the model.

    Returns:
        The translated text, an empty string for blank input, or an
        ``"Error: ..."`` message when a language is missing or unsupported.
    """
    # A dropdown left unselected arrives as None; report it rather than letting
    # the tokenizer fail with an opaque exception.
    if source_lang not in LANGS:
        return "Error: the source language is incorrect"
    if target_lang not in LANGS:
        return "Error: the target language is incorrect"
    if not text or not text.strip():
        return ""

    result = translation_pipeline(
        text,
        src_lang=source_lang,
        tgt_lang=target_lang,
        max_length=max_length,
    )
    return result[0]["translation_text"]


gradio_ui = gr.Interface(
    fn=translate,
    title="NLLB-Education EN-KIN Translation Demo",
    inputs=[
        gr.components.Textbox(label="Text"),
        gr.components.Dropdown(
            label="Source Language", choices=LANGS, value="eng_Latn"
        ),
        gr.components.Dropdown(
            label="Target Language", choices=LANGS, value="kin_Latn"
        ),
        # Optional: expose `max_length` in the UI by enabling this control.
        # gr.components.Slider(8, 400, value=400, step=8, label="Max Length")
    ],
    # `gr.outputs` is deprecated (removed in Gradio 4); use `gr.components`
    # like the inputs above.
    outputs=gr.components.Textbox(label="Translated text"),
)

gradio_ui.launch()