transw / app.py
jojonocode's picture
Update app.py
4ffaa7c verified
import torch
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# --------------------------------------------------
# Chargement du modèle NLLB
# --------------------------------------------------
MODEL_NAME = "facebook/nllb-200-distilled-1.3B"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Chargement du modèle {MODEL_NAME} sur {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
# --------------------------------------------------
# Dictionnaire de langues
# --------------------------------------------------
LANGUAGES = {
"Français": "fra_Latn",
"Ewe": "ewe_Latn",
"Fon": "fon_Latn",
"Anglais": "eng_Latn",
"Espagnol": "spa_Latn",
"Allemand": "deu_Latn",
"Swahili": "swh_Latn",
"Lingala": "lin_Latn",
"Portugais": "por_Latn"
}
# --------------------------------------------------
# Fonction de traduction
# --------------------------------------------------
def translate(text, src_lang, tgt_lang="Ewe"):
if not text.strip():
return "⚠️ Veuillez entrer un texte à traduire."
try:
# Configuration des langues (Source dynamique, Cible forcée sur Ewe)
src_code = LANGUAGES.get(src_lang, "fra_Latn")
tgt_code = LANGUAGES.get("Ewe", "ewe_Latn")
# Tokenization avec spécification précise de la langue source
# Note: Passer src_lang au tokenizer est crucial pour NLLB-200
inputs = tokenizer(text, return_tensors="pt", padding=True, src_lang=src_code).to(device)
# Génération avec paramètres optimisés pour éviter les répétitions
with torch.no_grad():
translated_tokens = model.generate(
**inputs,
forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
max_length=512,
num_beams=5,
no_repeat_ngram_size=3,
repetition_penalty=1.5, # Augmenté pour éviter "etudiant etudiant"
early_stopping=True,
length_penalty=1.0
)
# Décodage
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
# Sécurité pour les sorties vides ou identiques à l'entrée
if not result.strip() or result.strip().lower() == text.strip().lower():
# Si le modèle échoue, on tente une génération plus simple sans pénalités agressives
with torch.no_grad():
translated_tokens = model.generate(
**inputs,
forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
max_length=512,
num_beams=2
)
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
if not result.strip():
return "⚠️ Traduction impossible pour ce texte."
return result
except Exception as e:
return f"❌ Erreur : {str(e)}"
# --------------------------------------------------
# Interface Gradio
# --------------------------------------------------
with gr.Blocks(title="🌍 Traduction EWE") as demo:
gr.Markdown(
"""
<div style="text-align: center;">
<h1>� Yawotrad NLLB Translator</h1>
<p style="font-size: 18px;">
Traduction haute performance vers l'<b>Ewe</b>, le <b>Fon</b> et plus encore.
</p>
</div>
"""
)
with gr.Row():
src_lang = gr.Dropdown(choices=list(LANGUAGES.keys()), value="Français", label="Langue source 🌍")
# Masqué car la cible est toujours l'Ewe
tgt_lang = gr.Textbox(value="Ewe", label="Langue cible (Fixée)", interactive=False)
with gr.Row():
text_input = gr.Textbox(placeholder="Entre ton texte ici...", lines=6, label="Texte à traduire")
text_output = gr.Textbox(placeholder="Résultat de la traduction...", lines=6, label="Traduction")
translate_btn = gr.Button("🔁 Traduire")
translate_btn.click(translate, [text_input, src_lang, tgt_lang], text_output)
gr.Markdown(
"""
---
<div style="text-align: center; font-size: 14px;">
🧠 Propulsé par <a href="https://strivenew.com" target="_blank">Strive y-200 (1.1B)</a><br>
<b>Strive</b>
</div>
"""
)
demo.launch()