File size: 2,852 Bytes
c781d20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import gradio as gr

# Load the IndicTrans2 indic-to-indic model, its tokenizer, and the
# pre/post-processor once at startup (downloads weights on first run).
model_name = "ai4bharat/indictrans2-indic-indic-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

# Fix: place the model on the device inference will run on. The original
# left the model on CPU while translate() moved inputs to CUDA, which
# crashes with a device-mismatch error on any GPU machine.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(DEVICE)
model.eval()  # inference only — disable dropout and similar training behavior

ip = IndicProcessor(inference=True)

# Define the language codes
LANGUAGES = {
    "Assamese (asm_Beng)": "asm_Beng",
    "Kashmiri (kas_Arab)": "kas_Arab",
    "Punjabi (pan_Guru)": "pan_Guru",
    "Bengali (ben_Beng)": "ben_Beng",
    "Kashmiri (kas_Deva)": "kas_Deva",
    "Sanskrit (san_Deva)": "san_Deva",
    "Bodo (brx_Deva)": "brx_Deva",
    "Maithili (mai_Deva)": "mai_Deva",
    "Santali (sat_Olck)": "sat_Olck",
    "Dogri (doi_Deva)": "doi_Deva",
    "Malayalam (mal_Mlym)": "mal_Mlym",
    "Sindhi (snd_Arab)": "snd_Arab",
    "English (eng_Latn)": "eng_Latn",
    "Marathi (mar_Deva)": "mar_Deva",
    "Sindhi (snd_Deva)": "snd_Deva",
    "Konkani (gom_Deva)": "gom_Deva",
    "Manipuri (mni_Beng)": "mni_Beng",
    "Tamil (tam_Taml)": "tam_Taml",
    "Gujarati (guj_Gujr)": "guj_Gujr",
    "Manipuri (mni_Mtei)": "mni_Mtei",
    "Telugu (tel_Telu)": "tel_Telu",
    "Hindi (hin_Deva)": "hin_Deva",
    "Nepali (npi_Deva)": "npi_Deva",
    "Urdu (urd_Arab)": "urd_Arab",
    "Kannada (kan_Knda)": "kan_Knda",
    "Odia (ory_Orya)": "ory_Orya",
}

# Define the translation function
def translate(text, src_lang, tgt_lang):
    batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )
    with tokenizer.as_target_tokenizer():
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    return generated_text

# Create a Gradio interface
# Build the Gradio UI and wire the button click to the translator.
with gr.Blocks() as demo:
    gr.Markdown("### Indic Translations")
    input_text = gr.Textbox(label="Input Text", placeholder="Enter text to translate")
    src_lang = gr.Dropdown(label="Source Language", choices=list(LANGUAGES.keys()))
    tgt_lang = gr.Dropdown(label="Target Language", choices=list(LANGUAGES.keys()))
    translate_button = gr.Button("Translate")
    translation_output = gr.Textbox(label="Translation", interactive=False)

    def on_translate(text, src_lang_label, tgt_lang_label):
        # Map the human-readable dropdown labels back to language codes.
        return translate(text, LANGUAGES[src_lang_label], LANGUAGES[tgt_lang_label])

    # Fix: the original used `@translate_button.click` as a bare decorator,
    # which registers the callback with inputs=None/outputs=None — the
    # handler never received the textbox values, and assigning to
    # `translation_output.value` does not update a rendered component.
    # Explicitly wire inputs/outputs and return the translation instead.
    translate_button.click(
        fn=on_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=translation_output,
    )

demo.launch()