kusht55's picture
Create app.py
c781d20 verified
raw
history blame
2.85 kB
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import gradio as gr
# Define the model and tokenizer.
# The IndicTrans2 checkpoint ships custom model/tokenizer code, hence
# trust_remote_code=True; IndicProcessor handles the pre/post-processing
# that must wrap every call to the model.
model_name = "ai4bharat/indictrans2-indic-indic-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Place the weights on the GPU when one is available so inference actually
# uses it; model inputs must live on the same device as the model.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(DEVICE)
ip = IndicProcessor(inference=True)
# Language choices offered in the UI.
# Each entry pairs a display name with the script-qualified language code that
# is passed to IndicProcessor / the model. The dropdown label is derived as
# "<name> (<code>)", and the order below is the order shown in the dropdown.
# NOTE(review): eng_Latn is listed, but the checkpoint name
# ("indic-indic") suggests Indic<->Indic only -- confirm English is supported.
_LANGUAGE_TABLE = (
    ("Assamese", "asm_Beng"),
    ("Kashmiri", "kas_Arab"),
    ("Punjabi", "pan_Guru"),
    ("Bengali", "ben_Beng"),
    ("Kashmiri", "kas_Deva"),
    ("Sanskrit", "san_Deva"),
    ("Bodo", "brx_Deva"),
    ("Maithili", "mai_Deva"),
    ("Santali", "sat_Olck"),
    ("Dogri", "doi_Deva"),
    ("Malayalam", "mal_Mlym"),
    ("Sindhi", "snd_Arab"),
    ("English", "eng_Latn"),
    ("Marathi", "mar_Deva"),
    ("Sindhi", "snd_Deva"),
    ("Konkani", "gom_Deva"),
    ("Manipuri", "mni_Beng"),
    ("Tamil", "tam_Taml"),
    ("Gujarati", "guj_Gujr"),
    ("Manipuri", "mni_Mtei"),
    ("Telugu", "tel_Telu"),
    ("Hindi", "hin_Deva"),
    ("Nepali", "npi_Deva"),
    ("Urdu", "urd_Arab"),
    ("Kannada", "kan_Knda"),
    ("Odia", "ory_Orya"),
)

# Display label -> language code, preserving table order.
LANGUAGES = {f"{name} ({code})": code for name, code in _LANGUAGE_TABLE}
# Define the translation function
def translate(text, src_lang, tgt_lang):
    """Translate one string between two IndicTrans2 language codes.

    Args:
        text: The sentence to translate.
        src_lang: Language code of ``text`` (a value of ``LANGUAGES``,
            e.g. ``"hin_Deva"``).
        tgt_lang: Language code to translate into.

    Returns:
        The translated string, or ``""`` when ``text`` is empty/whitespace.
    """
    # Nothing to translate: skip the (expensive) beam search entirely.
    if not text or not text.strip():
        return ""

    batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
    # Send the inputs to wherever the model's weights live so CPU/GPU never
    # mismatch (the model is not guaranteed to be on CUDA).
    inputs = tokenizer(
        batch, truncation=True, padding="longest", return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode with the target-side vocabulary of the IndicTrans tokenizer.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # IndicTransToolkit's pipeline requires postprocess_batch after decoding
    # to undo what preprocess_batch did (entity placeholders, script
    # normalisation) for the target language.
    return ip.postprocess_batch(decoded, lang=tgt_lang)[0]
# Create a Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### Indic Translations")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text to translate")
    src_lang = gr.Dropdown(label="Source Language", choices=list(LANGUAGES.keys()))
    tgt_lang = gr.Dropdown(label="Target Language", choices=list(LANGUAGES.keys()))
    translate_button = gr.Button("Translate")
    translation_output = gr.Textbox(label="Translation", interactive=False)

    def on_translate(text, src_label, tgt_label):
        """Handle a click on the Translate button.

        Args:
            text: Current contents of the input textbox.
            src_label: Selected source dropdown label (a key of ``LANGUAGES``),
                or ``None`` if nothing was selected.
            tgt_label: Selected target dropdown label, or ``None``.

        Returns:
            The translation, which Gradio writes into ``translation_output``.

        Raises:
            gr.Error: If either language has not been selected.
        """
        if not src_label or not tgt_label:
            raise gr.Error("Please select both a source and a target language.")
        return translate(text, LANGUAGES[src_label], LANGUAGES[tgt_label])

    # Wire the event explicitly: Gradio passes the `inputs` component values as
    # positional arguments and writes the handler's return value into
    # `outputs`. Assigning `.value` inside a handler does not update the UI.
    translate_button.click(
        fn=on_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=translation_output,
    )

demo.launch()