File size: 4,964 Bytes
e96a87f
 
 
 
7ec6cbc
e96a87f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b2961b
 
a0a0875
34dfc89
 
3a4239c
10e4994
 
90c70f7
e96a87f
 
 
 
 
 
 
21a09fc
e96a87f
 
 
 
 
 
 
 
 
 
7ec6cbc
e96a87f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b0b40d
e96a87f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21a09fc
e96a87f
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
from transformers import pipeline
import torch
import logging
import spaces
from typing import Literal, Tuple

# Set up logging: module-level logger named after this module, INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Automatically detect the available device (CUDA, MPS, or CPU)
# Preference order: NVIDIA CUDA > Apple MPS > CPU fallback.
if torch.cuda.is_available():
    device = "cuda"
    logger.info("Using CUDA for inference.")
elif torch.backends.mps.is_available():
    device = "mps"
    logger.info("Using MPS for inference.")
else:
    device = "cpu"
    logger.info("Using CPU for inference.")

# Load the translation pipeline with the specified model and detected device.
# NOTE: this runs at import time and may download model weights on first use.
# Earlier checkpoints kept for reference:
#model_checkpoint = "oza75/bm-nllb-1.3B"
#revision = "f55e4e7193eb95483b084248e5e4a91e28eb7dc2"
# model_checkpoint = "oza75/bm-nllb-600-asr"
# model_checkpoint = "oza75/bm-nllb-600-02"
model_checkpoint = "oza75/bm-nllb-1.3B-03"
revision = None  # None -> use the default branch of the model repo
# model_checkpoint = "oza75/nllb-600M-mt-french-bambara"
# revision=None
translator = pipeline("translation", model=model_checkpoint, revision=revision, device=device, max_length=512)
logger.info("Translation pipeline initialized successfully.")

# Define the languages supported.
# Maps UI display names -> NLLB language codes (FLORES-200 style, e.g. "fra_Latn").
SOURCE_LANG_OPTIONS: dict[str, str] = {
    "French": "fra_Latn",
    "English": "eng_Latn",
    "Bambara": "bam_Latn",
    # "Bambara With Error": "bam_Error"
}

# Target languages offered in the UI (same code scheme as SOURCE_LANG_OPTIONS).
TARGET_LANG_OPTIONS: dict[str, str] = {
    "French": "fra_Latn",
    "English": "eng_Latn",
    "Bambara": "bam_Latn"
}


# Define the translation function with typing
@spaces.GPU()
def translate_text(text: str, source_lang: str, target_lang: str) -> str:
    """
    Translate the input text from the source language to the target language using the NLLB model.

    Args:
        text (str): The text to be translated.
        source_lang (str): The source language display name (a key of SOURCE_LANG_OPTIONS, e.g. "French").
        target_lang (str): The target language display name (a key of TARGET_LANG_OPTIONS, e.g. "Bambara").

    Returns:
        str: The translated text, an empty string for empty input, or a fixed
            error message if the pipeline raises.
    """
    # Short-circuit empty/whitespace-only input: avoids a pointless GPU call.
    if not text or not text.strip():
        return ""
    # Map UI display names to NLLB language codes.
    src_code = SOURCE_LANG_OPTIONS[source_lang]
    tgt_code = TARGET_LANG_OPTIONS[target_lang]
    # Lazy %-style args: the message is only formatted if the level is enabled.
    logger.info("Translating text from %s to %s.", src_code, tgt_code)
    try:
        # Perform translation using the Hugging Face pipeline
        result = translator(text, src_lang=src_code, tgt_lang=tgt_code, num_beams=2)
        translated_text = result[0]['translation_text']
        logger.info("Translation successful.")
        return translated_text
    except Exception:
        # Deliberate best-effort boundary for the UI: log the full traceback,
        # return a user-facing message instead of crashing the app.
        logger.exception("Translation failed")
        return "An error occurred during translation."


# Define the Gradio interface
def build_interface():
    """
    Builds the Gradio interface for translating text between supported languages.

    Returns:
        gr.Interface: The Gradio interface object.
    """
    # Human-readable description shown under the title.
    app_description = (
        "This application uses the NLLB model to translate text between French, English, and Bambara. "
        "The source and target languages should be chosen from the dropdown options. If you encounter "
        "any issues, please check your inputs."
    )

    # Pre-filled example rows: [text, source language, target language].
    example_rows = [
        ["Le Burkina Faso, « patrie des (personnes) intègres » ou « patrie de l'intégrité », anciennement république de Haute-Volta, est un pays d'Afrique de l'Ouest. Sans accès à la mer, il est entouré de six pays : le Niger à l'est-nord-est, le Bénin à l'est-sud-est, le Togo au sud-est, le Ghana au sud, la Côte d'Ivoire au sud-ouest et le Mali au nord-ouest.", "French", "Bambara"],
        ["Thomas Sankara, né le 21 décembre 1949 à Yako (Haute-Volta) et mort assassiné le 15 octobre 1987 à Ouagadougou (Burkina Faso), est un homme d'État voltaïque, chef de l’État de la république de 'Haute-Volta', rebaptisée Burkina Faso, de 1983 à 1987.", "French", "Bambara"],
        ["Good morning", "English", "Bambara"],
        ["- Ɔridinatɛri ye minɛn ye min bɛ se ka porogaramu - A bɛ se ka kunnafoniw mara - A bɛ se ka kunnafoniw sɔrɔ - A bɛ se ka kunnafoniw baara", "Bambara", "French"],
    ]

    # Wire the translation function to its input/output widgets.
    return gr.Interface(
        fn=translate_text,
        inputs=[
            gr.Textbox(lines=5, label="Text to Translate", placeholder="Enter text here..."),
            gr.Dropdown(choices=list(SOURCE_LANG_OPTIONS.keys()), value="French", label="Source Language"),
            gr.Dropdown(choices=list(TARGET_LANG_OPTIONS.keys()), value="Bambara", label="Target Language"),
        ],
        outputs=gr.Textbox(label="Translated Text"),
        title="Bambara NLLB Translation",
        description=app_description,
        examples=example_rows,
    )


# Run the Gradio application
if __name__ == "__main__":
    logger.info("Starting the Gradio interface for the Bambara NLLB model.")
    interface = build_interface()
    # Log before launch(): launch() blocks until the server is shut down, so
    # a message placed after it would only appear at shutdown.
    logger.info("Gradio interface running.")
    interface.launch()