"""Gradio demo that translates text with a preloaded NLLB-200 model."""

import asyncio
import time

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from flores200_codes import flores_codes

# Call name of the model used when the caller does not ask for another one.
DEFAULT_MODEL_NAME = "nllb-distilled-600M"


async def load_models():
    """Load every configured model and tokenizer once.

    The blocking ``from_pretrained`` calls are run through
    ``asyncio.to_thread`` so they do not execute on the event-loop thread.

    Returns:
        A dict mapping each model call name to
        ``{"model": <model>, "tokenizer": <tokenizer>}``.
    """
    model_name_dict = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
    }

    model_dict = {}
    for call_name, real_name in model_name_dict.items():
        print("\tLoading model:", call_name)
        model = await asyncio.to_thread(
            AutoModelForSeq2SeqLM.from_pretrained, real_name
        )
        tokenizer = await asyncio.to_thread(
            AutoTokenizer.from_pretrained, real_name
        )
        model_dict[call_name] = {
            "model": model,
            "tokenizer": tokenizer,
        }

    return model_dict


def translate_text(
    source_lang,
    target_lang,
    input_text,
    model_dict,
    model_name=DEFAULT_MODEL_NAME,
):
    """Translate ``input_text`` using a preloaded model and tokenizer.

    Args:
        source_lang: Key of ``flores_codes`` naming the source language.
        target_lang: Key of ``flores_codes`` naming the target language.
        input_text: Text to translate.
        model_dict: Mapping as returned by ``load_models``.
        model_name: Call name of the entry in ``model_dict`` to use.

    Returns:
        A dict with the keys ``"inference_time"`` (seconds), ``"source"``,
        ``"target"`` and ``"result"`` (the translated text).

    Raises:
        KeyError: If a language is not in ``flores_codes`` or ``model_name``
            is not in ``model_dict``.
    """
    start_time = time.time()
    source_code = flores_codes[source_lang]
    target_code = flores_codes[target_lang]

    if model_name not in model_dict:
        raise KeyError(f"Model '{model_name}' not found in model_dict")

    entry = model_dict[model_name]
    translator = pipeline(
        "translation",
        model=entry["model"],
        tokenizer=entry["tokenizer"],
        src_lang=source_code,
        tgt_lang=target_code,
    )
    translated_output = translator(input_text, max_length=400)
    end_time = time.time()

    return {
        "inference_time": end_time - start_time,
        "source": source_lang,
        "target": target_lang,
        "result": translated_output[0]["translation_text"],
    }


async def main():
    """Load the models, build the Gradio interface and launch it."""
    print("\tInitializing models")

    # Load models and tokenizers
    model_dict = await load_models()

    # Gradio calls ``fn`` with exactly one positional argument per input
    # component. ``translate_text`` also needs the loaded models, so bind
    # them here; passing ``translate_text`` directly fails with a TypeError
    # (missing ``model_dict``) on every request.
    def translate(source_lang, target_lang, input_text):
        return translate_text(source_lang, target_lang, input_text, model_dict)

    lang_codes = list(flores_codes.keys())
    inputs = [
        gr.inputs.Dropdown(lang_codes, default="English", label="Source Language"),
        gr.inputs.Dropdown(lang_codes, default="Nepali", label="Target Language"),
        gr.inputs.Textbox(lines=5, label="Input Text"),
    ]

    outputs = gr.outputs.JSON()

    title = "Masterful Translator"

    app_description = (
        "This is a beta version of the Masterful Translator that utilizes "
        "pre-trained language models for translation."
    )
    examples = [["English", "Nepali", "Hello, how are you?"]]

    gr.Interface(
        fn=translate,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=app_description,
        examples=examples,
        examples_per_page=50,
    ).launch()


if __name__ == "__main__":
    asyncio.run(main())