import gradio as gr
import nltk
nltk.download('punkt_tab')  # sentence tokenizer data used by nltk.sent_tokenize
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import torch

# Load the IndicTrans2 model and its preprocessor
# Note: this is the Indic-to-Indic distilled checkpoint; English pairs may need the
# separate en-indic / indic-en checkpoints.
model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
ip = IndicProcessor(inference=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)

def split_text_into_batches(text, max_tokens_per_batch):
    """Greedily pack sentences into batches, using character length as a rough proxy for token count."""
    sentences = nltk.sent_tokenize(text)  # Split the text into sentences
    batches = []
    current_batch = ""
    for sentence in sentences:
        if len(current_batch) + len(sentence) + 1 <= max_tokens_per_batch:  # +1 for the joining space
            current_batch += sentence + " "  # Append the sentence to the current batch
        else:
            if current_batch:  # Avoid emitting an empty batch when a single sentence exceeds the limit
                batches.append(current_batch.strip())  # Close off the current batch
            current_batch = sentence + " "  # Start a new batch with this sentence
    if current_batch:
        batches.append(current_batch.strip())  # Add the final batch
    return batches

def run_translation(file_uploader, input_text, source_language, target_language):
    """Translate the uploaded text file (if provided) or the pasted text."""
    if file_uploader is not None:
        with open(file_uploader.name, "r", encoding="utf-8") as file:
            input_text = file.read()

    # Map UI labels to the FLORES-style language codes used by IndicTrans2
    lang_code_map = {
        "Hindi": "hin_Deva",
        "Punjabi": "pan_Guru",
        "English": "eng_Latn",
    }

    src_lang = lang_code_map[source_language]
    tgt_lang = lang_code_map[target_language]

    max_tokens_per_batch = 256  # character budget per batch (see split_text_into_batches)
    batches = split_text_into_batches(input_text, max_tokens_per_batch)
    translated_text = ""

    for batch in batches:
        # Add source/target language tags and normalize the text as the model expects
        batch_preprocessed = ip.preprocess_batch([batch], src_lang=src_lang, tgt_lang=tgt_lang)
        inputs = tokenizer(
            batch_preprocessed,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated ids back to text, skipping special tokens
        with tokenizer.as_target_tokenizer():
            decoded_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Convert the decoded output back to clean target-language text
        translations = ip.postprocess_batch(decoded_tokens, lang=tgt_lang)
        translated_text += " ".join(translations) + " "

    output = translated_text.strip()
    _output_name = "result.txt"
    with open(_output_name, "w", encoding="utf-8") as out_file:
        out_file.write(output)

    return output, _output_name

# Define Gradio UI
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            file_uploader = gr.File(label="Upload a text file (Optional)")
            input_text = gr.Textbox(label="Input text", lines=5, placeholder="Enter text here...")
            source_language = gr.Dropdown(
                label="Source language",
                choices=["Hindi", "Punjabi", "English"],
                value="Hindi",
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=["Hindi", "Punjabi", "English"],
                value="English",
            )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text", lines=5)
            output_file = gr.File(label="Translated text file")

    btn.click(
        fn=run_translation,
        inputs=[file_uploader, input_text, source_language, target_language],
        outputs=[output_text, output_file],
    )

if __name__ == "__main__":
    demo.launch(debug=True)