Spaces:
Sleeping
Sleeping
File size: 4,077 Bytes
71ad94c 6d1e318 0e04945 df369fa 0e04945 71ad94c 0e04945 71ad94c 6d1e318 71ad94c 0e04945 79868fd 0e04945 6d1e318 0e04945 6d1e318 0e04945 0702769 0e04945 af923d2 0e04945 af923d2 0e04945 af923d2 0e04945 3134ca6 af923d2 0e04945 79868fd 3134ca6 af923d2 0e04945 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
import nltk
nltk.download('punkt_tab')
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import torch
# Load IndicTrans2 model
# Distilled 320M-parameter Indic-to-Indic translation model from AI4Bharat.
model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
# trust_remote_code is required: the checkpoint ships custom tokenizer/model code.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
# Pre/post-processor used to tag language codes and normalize/detokenize text.
ip = IndicProcessor(inference=True)
# Use a GPU when one is available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)
def split_text_into_batches(text, max_tokens_per_batch, sent_tokenize=None):
    """Greedily pack whole sentences into batches no longer than the limit.

    NOTE: despite the parameter name, the limit is measured in characters
    (``len``), not model tokens — a conservative proxy used to keep each
    batch within the model's input window.

    Args:
        text: Input text to split.
        max_tokens_per_batch: Maximum character length of one batch.
        sent_tokenize: Optional callable mapping text -> list of sentences.
            Defaults to ``nltk.sent_tokenize`` (backward compatible).

    Returns:
        List of non-empty batch strings, each made of whole sentences.
        A sentence longer than the limit becomes its own batch.
    """
    if sent_tokenize is None:
        sent_tokenize = nltk.sent_tokenize
    sentences = sent_tokenize(text)
    batches = []
    current_batch = ""
    for sentence in sentences:
        # +1 accounts for the separating space appended below.
        if len(current_batch) + len(sentence) + 1 <= max_tokens_per_batch:
            current_batch += sentence + " "
        else:
            # Bug fix: previously an empty string was appended when the very
            # first sentence already exceeded the limit, producing an empty
            # batch that was later sent to the model.
            if current_batch:
                batches.append(current_batch.strip())
            current_batch = sentence + " "
    if current_batch:
        batches.append(current_batch.strip())  # flush the final batch
    return batches
def run_translation(file_uploader, input_text, source_language, target_language):
    """Translate text (or an uploaded text file) with IndicTrans2.

    Uses the module-level ``model``/``tokenizer``/``ip``/``DEVICE`` globals.

    Args:
        file_uploader: Optional uploaded file; when present its contents
            replace ``input_text``.
        input_text: Text typed into the UI textbox.
        source_language / target_language: UI language names; must be keys
            of the mapping below (otherwise a KeyError propagates).

    Returns:
        Tuple of (translated text, path of a file holding the same text).
    """
    # An uploaded file takes precedence over the textbox contents.
    if file_uploader is not None:
        with open(file_uploader.name, "r", encoding="utf-8") as fh:
            input_text = fh.read()

    # Map UI language names to IndicTrans2 language codes.
    lang_code_map = {
        "Hindi": "hin_Deva",
        "Punjabi": "pan_Guru",
        "English": "eng_Latn",
    }
    src_code = lang_code_map[source_language]
    tgt_code = lang_code_map[target_language]

    # Translate one sentence batch at a time to bound the model input length.
    max_tokens_per_batch = 256
    segments = []
    for chunk in split_text_into_batches(input_text, max_tokens_per_batch):
        prepped = ip.preprocess_batch([chunk], src_lang=src_code, tgt_lang=tgt_code)
        encoded = tokenizer(
            prepped,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        segments.append(" ".join(ip.postprocess_batch(decoded, lang=tgt_code)))

    output = " ".join(segments).strip()

    # Persist the translation so the UI can also offer it as a download.
    result_path = "result.txt"
    with open(result_path, "w", encoding="utf-8") as out_fh:
        out_fh.write(output)
    return output, result_path
# Define Gradio UI
# Two-column layout: inputs (file/text/language pickers) on the left,
# translated text plus a downloadable file on the right.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            file_uploader = gr.File(label="Upload a text file (Optional)")
            input_text = gr.Textbox(label="Input text", lines=5, placeholder="Enter text here...")
            source_language = gr.Dropdown(
                label="Source language",
                choices=["Hindi", "Punjabi", "English"],
                value="Hindi",
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=["Hindi", "Punjabi", "English"],
                value="English",
            )
            btn = gr.Button("Translate")
        # Right column: translation results.
        with gr.Column():
            output_text = gr.Textbox(label="Translated text", lines=5)
            output_file = gr.File(label="Translated text file")
    # Wire the button to run_translation; outputs fill both result widgets.
    btn.click(
        fn=run_translation,
        inputs=[file_uploader, input_text, source_language, target_language],
        outputs=[output_text, output_file],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
|