"""Gradio demo: character-level Nepali text correction with a T5 checkpoint.

At import time the module builds a character vocabulary, loads the seq2seq
model, and constructs the Gradio interface as ``demo``. The interface is
launched only when the file is run as a script.
"""

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM
from huggingface_hub import InferenceClient

# Define tokenizer
# NOTE(review): all four special tokens are empty strings, so they collapse
# into a single "" key in ``char2id`` (mapped to the last index, 3). This
# looks like angle-bracket token names were lost somewhere upstream -- confirm
# against the vocabulary the checkpoint was trained with before editing,
# because every character id below depends on these four leading slots.
special_tokens = ["", "", "", ""]
# NOTE(review): the string appears to contain repeated characters; for a
# repeated character the dict comprehension below keeps only its last index.
nepali_chars = list("अआइईउऊऋॠऌॡऎएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह्ािीुूृॄेैोौंंःँ।०१२३४५६७८९,.;?!़ॅंःॊॅऒऽॉड़ॐ॥ऑऱफ़ढ़")
char_vocab = special_tokens + nepali_chars
char2id = {char: idx for idx, char in enumerate(char_vocab)}
# Inverted from char2id, so ids overwritten by a later duplicate have no entry.
id2char = {idx: char for char, idx in char2id.items()}


class CharTokenizer:
    """Map text to per-character ids and back using fixed lookup tables."""

    def __init__(self, char2id, id2char):
        """Store the character->id and id->character dictionaries."""
        self.char2id = char2id
        self.id2char = id2char

    def encode(self, text):
        """Return one id per character of ``text``.

        Characters missing from the vocabulary (a plain space, for example,
        is not in ``nepali_chars``) fall back to the id stored under ``""``.
        """
        fallback_id = self.char2id[""]
        return [self.char2id.get(char, fallback_id) for char in text]

    def decode(self, tokens):
        """Concatenate the characters for ``tokens``; unknown ids become ""."""
        return "".join(self.id2char.get(token, "") for token in tokens)

    def decodex(self, tokens):
        """Decode ``tokens``, turning interior special/unknown ids into spaces.

        An id that decodes to ``""`` (the special-token string, or any id not
        present in ``id2char``) is treated as a separator: it yields a single
        space, unless it is the first or last token or directly follows
        another separator, in which case it yields nothing.
        """
        # NOTE(review): the original had a second ``elif char == ""`` branch
        # (a no-op) that could never run because it repeated the first
        # condition; it probably targeted a different special token.
        pieces = []
        last_index = len(tokens) - 1
        for i, token in enumerate(tokens):
            char = self.id2char.get(token, "")
            if char != "":
                pieces.append(char)
                continue
            if i == 0 or i == last_index:
                continue
            # Collapse a run of separators into one space.
            if self.id2char.get(tokens[i - 1], "") != "":
                pieces.append(" ")
        return "".join(pieces)


# Initialize tokenizer
tokenizer = CharTokenizer(char2id, id2char)

# Load T5 model
model_name = "bashyaldhiraj2067/t5_char_nepali"
# model_name = "bashyaldhiraj2067/attention_epoch_2_xpu_64_copymechanism_nepali_GEC_new_21"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


def correct_text(input_text, max_length=256):
    """Run the model on ``input_text`` and return the decoded output string.

    Args:
        input_text: Text to correct. Empty or ``None`` input returns ``""``
            without calling the model.
        max_length: Forwarded to ``model.generate`` as ``max_length``. It is
            coerced to ``int`` because the UI slider may deliver a float.

    Returns:
        The generated token ids decoded with ``CharTokenizer.decodex``.
    """
    if not input_text:
        # An empty id list would become an empty float tensor, which
        # ``generate`` cannot consume.
        return ""

    input_ids = tokenizer.encode(input_text)
    input_tensor = torch.tensor([input_ids], dtype=torch.long)
    with torch.no_grad():
        outputs = model.generate(
            input_tensor,
            max_length=int(max_length),
            return_dict_in_generate=True,
        )
    generated_tokens = outputs.sequences[0].tolist()
    return tokenizer.decodex(generated_tokens)


# Gradio interface
demo = gr.Interface(
    fn=correct_text,
    inputs=[
        gr.Textbox(label="Enter Nepali Text"),
        gr.Slider(50, 256, step=10, label="Max Length"),
    ],
    outputs=gr.Textbox(label="Corrected Text"),
    title="Nepali Text Correction",
    description=(
        "Enter text with errors and get corrected output using a T5 "
        "model trained on Nepali text."
    ),
)

if __name__ == "__main__":
    demo.launch()