File size: 3,316 Bytes
3a0d18e
5a33d64
3a0d18e
 
8913a87
 
 
3a0d18e
aa405d8
3a0d18e
 
 
aa405d8
 
 
 
 
 
 
 
 
 
 
 
 
 
68799de
aa405d8
68799de
aa405d8
 
 
5a33d64
bb6e5e6
 
3a0d18e
 
 
 
 
 
 
68799de
3a0d18e
bb6e5e6
33c47a5
3a0d18e
 
 
bb6e5e6
3a0d18e
68799de
3a0d18e
 
 
7386d6d
3a0d18e
 
aa405d8
3a0d18e
 
 
 
 
 
9247161
 
3a0d18e
 
 
 
 
 
aa405d8
3a0d18e
68799de
 
3a0d18e
25f26e5
aa405d8
3a0d18e
786a7c4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch


# --- Configuration ---

# M2M100 language code used as both source and target language when correcting.
# NOTE(review): both languages map to "yo" (Yoruba in M2M100) — presumably a
# proxy code because M2M100 has no Bambara/Zarma codes; confirm against the
# fine-tuned checkpoints' training setup before changing.
LANG_CODES = {
    "Bambara": "yo",
    "Zarma": "yo",
}

# Per-language (fine-tuned correction checkpoint, tokenizer checkpoint) pairs.
# Both use the base facebook/m2m100_418M tokenizer.
models = {
    "Bambara": ("Mamadou2727/m2m100_418M-correction", "facebook/m2m100_418M"),
    "Zarma": ("Mamadou2727/m2m100_418M-correction-zarma", "facebook/m2m100_418M")
}

# Run inference on the first GPU when available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Cache of already-loaded (model, tokenizer) pairs, keyed by language name.
# Loading an M2M100 checkpoint is expensive (network/disk + weight init), and
# the original code did it on every button click; load once per language.
_model_cache = {}

def load_model(language):
    """Return the (model, tokenizer) pair for *language*, loading on first use.

    Args:
        language: A key of the module-level ``models`` dict
            ("Bambara" or "Zarma").

    Returns:
        Tuple of ``(AutoModelForSeq2SeqLM, AutoTokenizer)``, cached across
        calls for the lifetime of the process.

    Raises:
        KeyError: If *language* is not a configured language.
    """
    if language not in _model_cache:
        model_name, tokenizer_name = models[language]
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        _model_cache[language] = (model, tokenizer)
    return _model_cache[language]

def correct(text, language, candidates):
    """Generate spelling-correction candidates for *text*.

    Args:
        text: Raw input sentence to correct.
        language: Language name ("Bambara" or "Zarma") selecting the model.
        candidates: Number of beam-search hypotheses to return; used for both
            ``num_beams`` and ``num_return_sequences``.

    Returns:
        The candidate corrections joined by newlines, best hypothesis first.
    """
    model, tokenizer = load_model(language)
    model.to(device)

    # Correction is monolingual: source and target share one language code.
    src = LANG_CODES[language]
    tokenizer.src_lang = src
    tokenizer.tgt_lang = src

    ins = tokenizer(text, return_tensors='pt').to(device)

    # Gradio sliders can deliver floats; generate() needs integer beam counts.
    n = int(candidates)
    gen_args = {
        'length_penalty': 0.0,  # don't encourage longer or shorter output
        'num_return_sequences': n,
        'num_beams': n,
        # Force the decoder to start generating in the selected language.
        'forced_bos_token_id': tokenizer.lang_code_to_id[src],
    }

    # Inference only: no_grad() avoids building the autograd graph. The
    # original also requested scores and hidden states via
    # return_dict_in_generate, but only ever read .sequences — dropped here,
    # so generate() returns the sequences tensor directly.
    with torch.no_grad():
        outs = model.generate(**ins, **gen_args)

    output = tokenizer.batch_decode(outs, skip_special_tokens=True)

    # One suggestion per line, as expected by the output Textbox.
    return '\n'.join(output)

with gr.Blocks() as app:
    # User-facing description shown beside the controls.
    # Fixed: typo "Inlcude", and the "Intended Uses" paragraph that described
    # a French<->Zarma *translation* model — this app is a spell checker.
    markdown = r"""
     # Hasegnan, The First Zarma Spell Checker        
        <img src="https://cdn-uploads.huggingface.co/production/uploads/63cc1d4bf488db9bb3c6449e/AtOKLAaL5kt0VhRsxE0vf.png" width="500" height="300">
        
        This is a beta version of the Zarma Spell Checker and includes Bambara spell checking as well.
            
        ## Intended Uses & Limitations        
        
        This model is intended for academic research and practical applications in spell checking. It can suggest corrections for Zarma and Bambara text. Users should note that the model's performance may vary based on the complexity and context of the input text.
            
        ## Authors:
        The project, **FERIJI**, was curated by **Elysabhete Ibrahim Amadou**, **Habibatou Abdoulaye Alfari**, **Adwoa Bremang**, **Dennis Owusu**, **Mamadou K. KEITA** and **Dr Christopher Homan**, with the aim to enhance linguistic studies for Zarma.
        
    """

    with gr.Row():
        gr.Markdown(markdown)
        with gr.Column():
            # Inputs: the text to check, the language model to use, and how
            # many beam-search suggestions to return.
            input_text = gr.components.Textbox(lines=7, label="Input Text", value="")
            language = gr.Dropdown(label="Language", choices=["Bambara", "Zarma"], value="Bambara")
            return_seqs = gr.Slider(label="Number of return sequences", value=1, minimum=1, maximum=12, step=1)

            # Outputs: the model's ranked suggestions, plus a scratch box for
            # the user to assemble their preferred final text (not wired to
            # any callback by design).
            correction_suggestions = gr.Textbox(lines=7, label="Correction Suggestions")
            final_output = gr.Textbox(lines=7, label="Final Output", placeholder="Copy your preferred correction here...")

            translate_btn = gr.Button("Corrige")
            translate_btn.click(correct, inputs=[input_text, language, return_seqs], outputs=correction_suggestions)

app.launch(share=True)