# HuggingFace Space: IndicTrans2 Indic-to-Indic translation demo.
# (Space status at time of capture: "Runtime error" — see fixes below.)
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import gradio as gr

# Load the IndicTrans2 indic-indic checkpoint once at startup.
# trust_remote_code is required: the repo ships custom tokenizer/model code.
model_name = "ai4bharat/indictrans2-indic-indic-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

# Pre/post-processor for IndicTrans2 (adds language tags, handles script normalization).
ip = IndicProcessor(inference=True)
# Supported languages as (display name, FLORES-style language code) pairs.
# Order matters: it is the order shown in the Gradio dropdowns.
_LANGUAGE_PAIRS = [
    ("Assamese", "asm_Beng"),
    ("Kashmiri", "kas_Arab"),
    ("Punjabi", "pan_Guru"),
    ("Bengali", "ben_Beng"),
    ("Kashmiri", "kas_Deva"),
    ("Sanskrit", "san_Deva"),
    ("Bodo", "brx_Deva"),
    ("Maithili", "mai_Deva"),
    ("Santali", "sat_Olck"),
    ("Dogri", "doi_Deva"),
    ("Malayalam", "mal_Mlym"),
    ("Sindhi", "snd_Arab"),
    ("English", "eng_Latn"),
    ("Marathi", "mar_Deva"),
    ("Sindhi", "snd_Deva"),
    ("Konkani", "gom_Deva"),
    ("Manipuri", "mni_Beng"),
    ("Tamil", "tam_Taml"),
    ("Gujarati", "guj_Gujr"),
    ("Manipuri", "mni_Mtei"),
    ("Telugu", "tel_Telu"),
    ("Hindi", "hin_Deva"),
    ("Nepali", "npi_Deva"),
    ("Urdu", "urd_Arab"),
    ("Kannada", "kan_Knda"),
    ("Odia", "ory_Orya"),
]

# Map "Name (code)" dropdown labels to the language code the model expects.
LANGUAGES = {f"{name} ({code})": code for name, code in _LANGUAGE_PAIRS}
# Define the translation function
def translate(text, src_lang, tgt_lang):
    """Translate *text* from *src_lang* to *tgt_lang* using IndicTrans2.

    Args:
        text: Input sentence (a single string).
        src_lang: Source language code, e.g. "hin_Deva".
        tgt_lang: Target language code, e.g. "tam_Taml".

    Returns:
        The translated sentence as a string.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # BUG FIX: the original moved only the inputs to CUDA while the model
    # stayed on CPU, causing a device-mismatch RuntimeError on GPU hosts.
    model.to(device)

    # Attach source/target language tags expected by the model.
    batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = tokenizer(
        batch, truncation=True, padding="longest", return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode with the target-side tokenizer context.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # BUG FIX: IndicTrans2 output must be post-processed (detokenized and
    # converted back to native script) — the original skipped this step.
    return ip.postprocess_batch(decoded, lang=tgt_lang)[0]
# Create a Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### Indic Translations")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text to translate")
    src_lang = gr.Dropdown(label="Source Language", choices=list(LANGUAGES.keys()))
    tgt_lang = gr.Dropdown(label="Target Language", choices=list(LANGUAGES.keys()))
    translate_button = gr.Button("Translate")
    translation_output = gr.Textbox(label="Translation", interactive=False)

    def on_translate(text, src_lang, tgt_lang):
        """Map dropdown labels to language codes and run the translation."""
        # BUG FIX: the original assigned translation_output.value, which does
        # not update the UI — Gradio event handlers must RETURN the new value.
        return translate(text, LANGUAGES[src_lang], LANGUAGES[tgt_lang])

    # BUG FIX: the original never wired the button to the handler, so
    # clicking "Translate" did nothing.
    translate_button.click(
        fn=on_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=translation_output,
    )

demo.launch()