DinoFrog's picture
Update app.py
a56dc0f verified
raw
history blame contribute delete
10.6 kB
import streamlit as st
from transformers import pipeline
from langdetect import detect
from huggingface_hub import InferenceClient
import pandas as pd
import os
import asyncio
import nltk
from nltk.tokenize import sent_tokenize
import torch
# Téléchargement de punkt_tab avec gestion d'erreur
try:
nltk.download('punkt_tab', download_dir='/usr/local/share/nltk_data')
except Exception as e:
st.error(f"Erreur lors du téléchargement de punkt_tab : {str(e)}. Veuillez vérifier votre connexion réseau et les permissions du répertoire /usr/local/share/nltk_data.")
st.stop()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
st.error("Erreur : Clé API Hugging Face (HF_TOKEN) manquante. Veuillez configurer HF_TOKEN dans les variables d'environnement.")
st.stop()
# Fonctions pour charger les modèles avec st.cache_resource
@st.cache_resource
def load_classifier():
return pipeline("sentiment-analysis", model="mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis", device="cpu", map_location="cpu")
@st.cache_resource
def load_translator_to_en():
return pipeline("translation", model="Helsinki-NLP/opus-mt-mul-en", device="cpu", map_location="cpu")
@st.cache_resource
def load_translator_to_fr():
return pipeline("translation", model="Helsinki-NLP/opus-mt-en-fr", device="cpu", map_location="cpu")
# Charger les modèles une seule fois
classifier = load_classifier()
translator_to_en = load_translator_to_en()
translator_to_fr = load_translator_to_fr()
# Fonction pour appeler l'API Zephyr avec des paramètres ajustés
async def call_zephyr_api(prompt, mode, hf_token=HF_TOKEN):
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
try:
if mode == "Rapide":
max_new_tokens = 50
temperature = 0.3
elif mode == "Équilibré":
max_new_tokens = 100
temperature = 0.5
else: # Précis
max_new_tokens = 150
temperature = 0.7
response = await asyncio.to_thread(client.text_generation, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
return response
except Exception as e:
st.error(f"❌ Erreur d'appel API Hugging Face : {str(e)}")
return None
# Traduction en français avec Helsinki-NLP
def safe_translate_to_fr(text, max_length=512):
try:
sentences = sent_tokenize(text)
translated_sentences = []
for sentence in sentences:
translated = translator_to_fr(sentence, max_length=max_length)[0]['translation_text']
translated_sentences.append(translated)
return " ".join(translated_sentences)
except Exception as e:
return f"Erreur de traduction : {str(e)}"
# Fonction pour suggérer le meilleur modèle
def suggest_model(text):
word_count = len(text.split())
if word_count < 50:
return "Rapide"
elif word_count <= 200:
return "Équilibré"
else:
return "Précis"
# Fonction pour créer une jauge de sentiment
def create_sentiment_gauge(sentiment, score):
score_percentage = score * 100
color = "#A9A9A9"
if sentiment.lower() == "positive":
color = "#2E8B57"
elif sentiment.lower() == "negative":
color = "#DC143C"
html = f"""
<div style='width: 100%; max-width: 300px; margin: 10px 0;'>
<div style='background-color: #D3D3D3; border-radius: 5px; height: 20px; position: relative;'>
<div style='background-color: {color}; width: {score_percentage}%; height: 100%; border-radius: 5px;'></div>
<span style='position: absolute; top: 0; left: 50%; transform: translateX(-50%); font-weight: bold;'>{score_percentage:.1f}%</span>
</div>
<div style='text-align: center; margin-top: 5px;'>Sentiment : {sentiment}</div>
</div>
"""
return html
# Fonction d'analyse
async def full_analysis(text, mode, detail_mode, history):
if not text:
st.warning("Entrez une phrase.")
return None, text, None, None, history, None, "Aucune analyse effectuée."
# Initialisation de la barre de progression
progress_bar = st.progress(0)
status_text = st.empty()
status_text.write("Analyse en cours... (Étape 1 : Détection de la langue)")
# Étape 1 : Détection de la langue
try:
lang = detect(text)
except:
lang = "unknown"
progress_bar.progress(25)
if lang != "en":
text_en = translator_to_en(text, max_length=512)[0]['translation_text']
else:
text_en = text
# Étape 2 : Analyse du sentiment
status_text.write("Analyse en cours... (Étape 2 : Analyse du sentiment)")
result = classifier(text_en)
result = result[0]
sentiment_output = f"Sentiment prédictif : {result['label']} (Score: {result['score']:.2f})"
sentiment_gauge = create_sentiment_gauge(result['label'], result['score'])
progress_bar.progress(50)
# Étape 3 : Explication IA
status_text.write("Analyse en cours... (Étape 3 : Explication IA)")
explanation_prompt = f"""<|system|>
You are a professional financial analyst AI with expertise in economic forecasting.
</s>
<|user|>
Given the following question about a potential economic event: "{text}"
The predicted sentiment for this event is: {result['label'].lower()}.
Assume the event happens. Explain why this event would likely have a {result['label'].lower()} economic impact.
</s>
<|assistant|>"""
explanation_en = await call_zephyr_api(explanation_prompt, mode)
if explanation_en is None:
return None, text, None, None, history, None, "Erreur lors de la génération de l'explication."
progress_bar.progress(75)
# Étape 4 : Traduction en français
status_text.write("Analyse en cours... (Étape 4 : Traduction en français)")
explanation_fr = safe_translate_to_fr(explanation_en)
progress_bar.progress(100)
# Mise à jour de l'historique
history.append({
"Texte": text,
"Sentiment": result['label'],
"Score": f"{result['score']:.2f}",
"Explication_EN": explanation_en,
"Explication_FR": explanation_fr
})
status_text.write("✅ Analyse terminée.")
return sentiment_output, text, explanation_en, explanation_fr, history, sentiment_gauge, "Analyse terminée."
# Historique CSV
def download_history(history):
if not history:
return None
df = pd.DataFrame(history)
file_path = "/tmp/analysis_history.csv"
df.to_csv(file_path, index=False)
return file_path
# Interface Streamlit
def main():
# CSS personnalisé pour Streamlit
st.markdown("""
<style>
body {
background: linear-gradient(135deg, #0A1D37 0%, #1A3C34 100%);
font-family: 'Inter', sans-serif;
color: #E0E0E0;
}
.stApp {
background: linear-gradient(135deg, #0A1D37 0%, #1A3C34 100%);
color: #E0E0E0;
}
.stTextArea textarea, .stSelectbox select {
background: #2A4A43 !important;
border: 1px solid #FFD700 !important;
border-radius: 12px !important;
padding: 10px !important;
color: #E0E0E0 !important;
box-shadow: 0px 4px 12px rgba(255, 215, 0, 0.4);
}
.stButton button {
background: linear-gradient(90deg, #FFD700, #D4AF37);
color: #0A1D37;
font-weight: bold;
border: none;
border-radius: 8px;
padding: 12px 24px;
transition: transform 0.2s;
}
.stButton button:hover {
transform: translateY(-2px);
box-shadow: 0 6px 12px rgba(255, 215, 0, 0.5);
}
h1, h2, h3 {
color: #FFD700;
}
</style>
""", unsafe_allow_html=True)
st.title("📈 Analyse Financière Premium avec IA")
st.markdown("**Posez une question économique.** L'IA analyse et explique l'impact.")
# Gestion de l'état de l'historique avec session_state
if 'history' not in st.session_state:
st.session_state.history = []
if 'count' not in st.session_state:
st.session_state.count = 0
# Layout avec colonnes
col1, col2 = st.columns([2, 1])
with col1:
input_text = st.text_area("Votre question économique", height=150)
with col2:
mode_selector = st.selectbox("Mode de réponse", ["Rapide", "Équilibré", "Précis"], index=1)
detail_mode_selector = st.selectbox("Niveau de détail", ["Normal", "Expert"], index=0)
# Suggestion automatique du mode
if input_text:
suggested_mode = suggest_model(input_text)
st.session_state.mode_selector = suggested_mode
mode_selector = suggested_mode
# Boutons
col_analyze, col_download = st.columns([1, 1])
with col_analyze:
analyze_btn = st.button("Analyser")
with col_download:
download_btn = st.button("Télécharger l'historique")
# Résultats
if analyze_btn and input_text:
# Exécuter l'analyse
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(full_analysis(input_text, mode_selector, detail_mode_selector, st.session_state.history))
loop.close()
if result:
sentiment_output, displayed_prompt, explanation_en, explanation_fr, st.session_state.history, sentiment_gauge, progress_message = result
st.session_state.count += 1
# Affichage des résultats
st.subheader("Résultats")
col_results1, col_results2 = st.columns([1, 2])
with col_results1:
st.text_area("Sentiment prédictif", sentiment_output, height=100, disabled=True)
st.markdown(sentiment_gauge, unsafe_allow_html=True)
st.text_area("Votre question", displayed_prompt, height=100, disabled=True)
st.text_area("Progression", progress_message, height=70, disabled=True)
with col_results2:
st.text_area("Explication en anglais", explanation_en, height=200, disabled=True)
st.text_area("Explication en français", explanation_fr, height=200, disabled=True)
# Téléchargement de l'historique
if download_btn:
history_file = download_history(st.session_state.history)
if history_file:
with open(history_file, "rb") as file:
st.download_button("Télécharger CSV", file, file_name="analysis_history.csv")
else:
st.warning("Aucun historique à télécharger.")
if __name__ == "__main__":
main()