APERTUSMM / src /streamlit_app.py
MMOON's picture
Update src/streamlit_app.py
3116c2c verified
raw
history blame
5.69 kB
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login
# --- CONFIGURATION DE LA PAGE ---
st.set_page_config(
page_title="Chat Avancé avec Apertus",
page_icon="🚀",
layout="wide",
initial_sidebar_state="expanded",
)
# --- STYLES CSS PERSONNALISÉS (Optionnel) ---
st.markdown("""
<style>
.stSpinner > div > div {
border-top-color: #f63366;
}
.stChatMessage {
background-color: #f0f2f6;
border-radius: 10px;
padding: 15px;
margin-bottom: 10px;
}
</style>
""", unsafe_allow_html=True)
# --- BARRE LATÉRALE DE CONFIGURATION ---
with st.sidebar:
st.title("🚀 Paramètres")
st.markdown("Configurez l'assistant et le modèle de langage.")
# --- Authentification Hugging Face ---
st.subheader("Authentification Hugging Face")
hf_token = st.text_input("Votre Token Hugging Face (hf_...)", type="password")
if st.button("Se Connecter"):
if hf_token:
try:
login(token=hf_token)
st.success("Connecté à Hugging Face Hub !")
st.session_state.hf_logged_in = True
except Exception as e:
st.error(f"Échec de la connexion : {e}")
else:
st.warning("Veuillez entrer un token Hugging Face.")
# --- Sélection du Modèle ---
st.subheader("Sélection du Modèle")
model_options = {
"Apertus 8B (Rapide)": "swiss-ai/Apertus-8B-Instruct-2509",
"Apertus 70B (Puissant)": "swiss-ai/Apertus-70B-2509"
}
selected_model_name = st.selectbox("Choisissez un modèle :", options=list(model_options.keys()))
model_id = model_options[selected_model_name]
st.caption(f"ID du modèle : `{model_id}`")
# --- Paramètres de Génération ---
st.subheader("Paramètres de Génération")
temperature = st.slider("Température", min_value=0.1, max_value=1.5, value=0.7, step=0.05,
help="Plus la valeur est élevée, plus la réponse est créative et aléatoire.")
max_new_tokens = st.slider("Tokens Max", min_value=64, max_value=1024, value=256, step=64,
help="Longueur maximale de la réponse générée.")
top_p = st.slider("Top-p (Nucleus Sampling)", min_value=0.1, max_value=1.0, value=0.95, step=0.05,
help="Contrôle la diversité en sélectionnant les mots les plus probables dont la somme des probabilités dépasse ce seuil.")
# --- Bouton pour effacer l'historique ---
if st.button("🗑️ Effacer l'historique"):
st.session_state.messages = []
st.experimental_rerun()
# --- CHARGEMENT DU MODÈLE (MIS EN CACHE) ---
@st.cache_resource(show_spinner=False)
def load_model(model_identifier):
"""Charge le tokenizer et le modèle avec quantification 4-bit."""
with st.spinner(f"Chargement du modèle '{model_identifier}'... Cela peut prendre un moment. ⏳"):
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_identifier)
model = AutoModelForCausalLM.from_pretrained(
model_identifier,
quantization_config=bnb_config,
device_map="auto",
)
return tokenizer, model
# Charge le modèle sélectionné
try:
tokenizer, model = load_model(model_id)
except Exception as e:
st.error(f"Impossible de charger le modèle. Assurez-vous d'être connecté si le modèle est privé. Erreur : {e}")
st.stop()
# --- INTERFACE DE CHAT PRINCIPALE ---
st.title("🤖 Chat avec Apertus")
st.caption(f"Vous discutez actuellement avec **{selected_model_name}**.")
# Initialisation de l'historique du chat
if "messages" not in st.session_state:
st.session_state.messages = []
# Affichage des messages de l'historique
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Zone de saisie utilisateur
if prompt := st.chat_input("Posez votre question à Apertus..."):
# Ajout et affichage du message utilisateur
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
# --- GÉNÉRATION DE LA RÉPONSE ---
with st.chat_message("assistant"):
response_placeholder = st.empty()
with st.spinner("Réflexion en cours... 🤔"):
# Préparation des entrées pour le modèle
# Nous ne formaterons plus le prompt, le modèle instruct est déjà finetuné pour ça.
input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
# Génération de la réponse
outputs = model.generate(
**input_ids,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
eos_token_id=tokenizer.eos_token_id
)
# Décodage et nettoyage de la réponse
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Nettoyage pour retirer la question initiale de la réponse
cleaned_response = response_text.replace(prompt, "").strip()
response_placeholder.markdown(cleaned_response)
# Ajout de la réponse de l'assistant à l'historique
st.session_state.messages.append({"role": "assistant", "content": cleaned_response})