File size: 2,647 Bytes
bc5a3a6
5a7f978
 
bc5a3a6
 
5a7f978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc5a3a6
5a7f978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a21e33
 
5a7f978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc5a3a6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM
from huggingface_hub import InferenceClient

# Define tokenizer
# Character-level vocabulary: 4 special tokens followed by the Devanagari
# characters (consonants, vowels, matras, digits, punctuation) the model
# was trained on. Vocabulary ids are positions in char_vocab, so the
# special tokens occupy ids 0-3 (<pad>=0, <s>=1, </s>=2, <unk>=3).
special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
nepali_chars = list("अआइईउऊऋॠऌॡऎएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह्ािीुूृॄेैोौंंःँ।०१२३४५६७८९,.;?!़ॅंःॊॅऒऽॉड़ॐ॥ऑऱफ़ढ़")
char_vocab = special_tokens + nepali_chars

# Forward (char -> id) and reverse (id -> char) lookup tables used by
# CharTokenizer below.
char2id = {char: idx for idx, char in enumerate(char_vocab)}
id2char = {idx: char for char, idx in char2id.items()}

class CharTokenizer:
    """Minimal character-level tokenizer backed by explicit lookup tables.

    ``char2id`` maps single characters (plus the special-token strings) to
    integer ids; ``id2char`` is its inverse.
    """

    # Tokens that carry no surface text and must never appear in decoded
    # output shown to the user.
    _STRUCTURAL = ("<pad>", "<s>", "</s>")

    def __init__(self, char2id, id2char):
        self.char2id = char2id
        self.id2char = id2char

    def encode(self, text):
        """Map each character of *text* to its id; unknown chars -> <unk> id."""
        unk = self.char2id["<unk>"]
        return [self.char2id.get(char, unk) for char in text]

    def decode(self, tokens):
        """Literal inverse of encode(): ids back to their strings, concatenated.

        Special-token ids come back as their literal markers ("<pad>", ...);
        use decodex() when you want display-ready text.
        """
        return "".join([self.id2char.get(token, "<unk>") for token in tokens])

    def decodex(self, tokens):
        """Decode *tokens* into display text.

        Rules:
        - <pad>, <s>, </s> are dropped entirely. (Fix: the original only
          dropped <pad>, so generated text leaked literal "<s>"/"</s>".)
        - Unknown ids act as separators: a run of unknowns collapses to a
          single space, and unknowns at either end of the sequence are
          dropped outright.
        """
        decoded_string = ""
        for i, token in enumerate(tokens):
            char = self.id2char.get(token, "<unk>")
            if char == "<unk>":
                # Emit at most one space per unknown run, never at the edges.
                if i == 0 or i == len(tokens) - 1 or self.id2char.get(tokens[i - 1], "<unk>") == "<unk>":
                    continue
                decoded_string += " "
            elif char in self._STRUCTURAL:
                continue  # structural markers never reach the output
            else:
                decoded_string += char
        return decoded_string

# Initialize tokenizer (module-level singleton used by correct_text below)
tokenizer = CharTokenizer(char2id, id2char)

# Load T5 model from the Hugging Face Hub; downloads the checkpoint on
# first run, then uses the local cache.
model_name = "bashyaldhiraj2067/t5_char_nepali"
# Alternative checkpoint kept for reference:
# model_name = "bashyaldhiraj2067/attention_epoch_2_xpu_64_copymechanism_nepali_GEC_new_21"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def correct_text(input_text, max_length=256):
    """Run the seq2seq model on *input_text* and return the corrected string.

    Args:
        input_text: Raw Nepali text to correct.
        max_length: Generation length cap. May arrive as a float from the
            Gradio slider, so it is coerced to int before generate().

    Returns:
        The decoded model output with special tokens stripped; "" for
        empty input.
    """
    # Guard: an empty prompt would build a (1, 0) float-dtype tensor,
    # which generate() cannot consume.
    if not input_text:
        return ""

    input_ids = tokenizer.encode(input_text)
    # Explicit dtype: token ids must be int64 for the embedding lookup.
    input_tensor = torch.tensor([input_ids], dtype=torch.long)

    with torch.no_grad():
        outputs = model.generate(
            input_tensor,
            max_length=int(max_length),
            return_dict_in_generate=True,
        )

    generated_tokens = outputs.sequences[0].tolist()
    return tokenizer.decodex(generated_tokens)

# Gradio interface: the textbox and slider map positionally onto
# correct_text's (input_text, max_length) parameters.
demo = gr.Interface(
    fn=correct_text,
    inputs=[gr.Textbox(label="Enter Nepali Text"), gr.Slider(50, 256, step=10, label="Max Length")],
    outputs=gr.Textbox(label="Corrected Text"),
    title="Nepali Text Correction",
    description="Enter text with errors and get corrected output using a T5 model trained on Nepali text.",
)

# Launch the web UI only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()