# Hugging Face Spaces app (Space status when this file was captured: "Sleeping").
import gradio as gr | |
import torch | |
from transformers import AutoModelForSeq2SeqLM | |
from huggingface_hub import InferenceClient | |
# Define tokenizer vocabulary.
# NOTE: the order of `char_vocab` defines the token ids the model was trained
# with, so it must not be reordered or de-duplicated.
special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
nepali_chars = list("अआइईउऊऋॠऌॡऎएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह्ािीुूृॄेैोौंंःँ।०१२३४५६७८९,.;?!़ॅंःॊॅऒऽॉड़ॐ॥ऑऱफ़ढ़")
char_vocab = special_tokens + nepali_chars
# `nepali_chars` contains some characters more than once; for encoding, the
# LAST index of a repeated character wins (dict comprehension overwrites).
char2id = {char: idx for idx, char in enumerate(char_vocab)}
# Build the reverse map from positions instead of inverting `char2id`:
# inverting would drop the earlier ids of repeated characters, so decoding
# those ids would fall back to "<unk>" instead of the real character.
id2char = dict(enumerate(char_vocab))
class CharTokenizer:
    """Character-level tokenizer: one token id per input character.

    Characters missing from ``char2id`` encode to the ``"<unk>"`` id. With the
    vocabulary built in this file the space character is not in the
    vocabulary, so spaces encode to ``"<unk>"``; :meth:`decodex` turns
    ``"<unk>"`` back into single spaces.
    """

    # Structural tokens that carry no text; `decodex` drops them.
    _SILENT_TOKENS = frozenset({"<pad>", "<s>", "</s>"})

    def __init__(self, char2id, id2char):
        """Store the char -> id and id -> char lookup tables."""
        self.char2id = char2id
        self.id2char = id2char

    def encode(self, text):
        """Return one token id per character of ``text``.

        Characters not in ``char2id`` map to the ``"<unk>"`` id.

        Raises:
            KeyError: if ``char2id`` has no ``"<unk>"`` entry.
        """
        unk_id = self.char2id["<unk>"]
        return [self.char2id.get(char, unk_id) for char in text]

    def decode(self, tokens):
        """Return the raw concatenation of each token's vocabulary string.

        Special tokens are emitted literally (e.g. ``"<pad>"``), and ids
        missing from ``id2char`` become ``"<unk>"``.
        """
        return "".join(self.id2char.get(token, "<unk>") for token in tokens)

    def decodex(self, tokens):
        """Decode model output ids into display text.

        - ``"<pad>"``, ``"<s>"`` and ``"</s>"`` are dropped so sequence markers
          emitted by the model do not leak into the text.
        - ``"<unk>"`` (and ids missing from ``id2char``) act as word
          separators: each run becomes a single space, and no space is kept at
          the start or end of the result.
        """
        pieces = []
        for token in tokens:
            char = self.id2char.get(token, "<unk>")
            if char in self._SILENT_TOKENS:
                continue
            if char == "<unk>":
                # Separator: never leading, never doubled.
                if pieces and pieces[-1] != " ":
                    pieces.append(" ")
            else:
                pieces.append(char)
        # Drop a separator left dangling at the end.
        if pieces and pieces[-1] == " ":
            pieces.pop()
        return "".join(pieces)
# Shared tokenizer instance used by `correct_text`.
tokenizer = CharTokenizer(char2id=char2id, id2char=id2char)

# Hugging Face Hub checkpoint of the T5 correction model.
# Alternative checkpoint kept for reference:
# model_name = "bashyaldhiraj2067/attention_epoch_2_xpu_64_copymechanism_nepali_GEC_new_21"
model_name = "bashyaldhiraj2067/t5_char_nepali"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def correct_text(input_text, max_length=256):
    """Return the model's corrected version of ``input_text``.

    Args:
        input_text: Nepali text to correct. Empty or ``None`` input returns
            ``""`` without running the model.
        max_length: Maximum length of the generated sequence in tokens
            (roughly characters, since the tokenizer is per character).
            Cast to ``int`` because UI sliders may deliver a float.

    Returns:
        The generated sequence decoded with ``tokenizer.decodex``.
    """
    if not input_text:
        # torch.tensor([[]]) would be an empty *float* tensor, which the
        # model's embedding layer cannot consume; nothing to correct anyway.
        return ""
    input_ids = tokenizer.encode(input_text)
    input_tensor = torch.tensor([input_ids], dtype=torch.long)
    # Single unpadded sequence: every position is real input.
    attention_mask = torch.ones_like(input_tensor)
    with torch.no_grad():
        outputs = model.generate(
            input_tensor,
            attention_mask=attention_mask,
            max_length=int(max_length),
            return_dict_in_generate=True,
        )
    generated_tokens = outputs.sequences[0].tolist()
    return tokenizer.decodex(generated_tokens)
# Gradio interface: a text box and a max-length slider feed `correct_text`.
demo = gr.Interface(
    fn=correct_text,
    inputs=[
        gr.Textbox(label="Enter Nepali Text"),
        # Without an explicit `value`, Gradio starts the slider at its minimum
        # (50), which truncates character-level output for longer sentences.
        # 250 is the largest value reachable with minimum=50 and step=10.
        gr.Slider(50, 256, value=250, step=10, label="Max Length"),
    ],
    outputs=gr.Textbox(label="Corrected Text"),
    title="Nepali Text Correction",
    description="Enter text with errors and get corrected output using a T5 model trained on Nepali text.",
)
if __name__ == "__main__":
    # Launch the web UI only when executed directly, not when imported.
    demo.launch()