AssistantJuridique2 / src /streamlit_app.py
Guillaumedbx's picture
refactor
a42113e
import os
from dotenv import load_dotenv
load_dotenv()
token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
import streamlit as st
from langchain_chroma import Chroma
from utils.load_embeddings import get_local_embeddings
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer
import torch
@st.cache_resource
def load_local_model(model_id):
if model_id == "plguillou/t5-base-fr-sum-cnndm":
tokenizer = T5Tokenizer.from_pretrained(model_id)
else:
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_id,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
return tokenizer, model
# Pré-chargement des deux modèles
flan_tokenizer, flan_model = load_local_model("google/flan-t5-small")
plg_tokenizer, plg_model = load_local_model("plguillou/t5-base-fr-sum-cnndm")
def generate_response(prompt, tokenizer, model):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9
)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Si le tag [RESPONSE] n'est pas généré, on affiche tout
if "[RESPONSE]" in text:
text = text.split("[RESPONSE]", 1)[-1].strip()
# Si la réponse est vide, on affiche la sortie brute
if not text.strip():
return tokenizer.decode(outputs[0], skip_special_tokens=True)
return text
st.set_page_config(page_title="Assistant Juridique IA", layout="wide")
st.title("📚 Assistant Juridique avec IA")
st.write("Posez une question juridique.")
# Réorganisation de la sidebar : paramètres avancés en haut
st.sidebar.header("🔧 Paramètres avancés")
max_docs = st.sidebar.slider(
"Nombre maximal de documents à utiliser",
min_value=1,
max_value=20,
value=5,
step=1
)
similarity_threshold = st.sidebar.slider(
"Seuil de pertinence (%)",
min_value=0,
max_value=200,
value=90,
step=5
)
# Choix multi-bases avec checkbox
st.sidebar.markdown("**Bases de documents à interroger :**")
base_options = [
("Archive mails", "archive_mail", "archives_mails"),
("Textes de loi", "textes_loi", "textes_loi"),
("Jurisprudence", "jurisprudence", "jurisprudence")
]
selected_bases = [
key for label, key, _ in base_options if st.sidebar.checkbox(label, value=True)
]
# Vérification qu'au moins une base est sélectionnée
if not selected_bases:
st.sidebar.warning("⚠️ Veuillez sélectionner au moins une base de documents pour continuer.")
st.stop()
# Affichage des modèles utilisés (en bas de la sidebar)
st.sidebar.markdown("---")
st.sidebar.markdown("🧠 **Modèle d'embedding :** `paraphrase-multilingual-mpnet-base-v2`")
st.sidebar.markdown("🗂️ **Base vectorielle :** `Chroma`")
st.sidebar.markdown("💬 **Modèle LLM :** `google/flan-t5-small` (text-generation, multilingue, open source)")
# Saisie de l'utilisateur et personnalisation du prompt en même temps
col1, col2 = st.columns([2, 3])
with col1:
user_input = st.text_area("✉️ Votre question :", height=200, key="user_question")
with col2:
user_prompt_intro = st.text_area(
"Début du prompt (modifiable)",
value="Vous êtes un assistant juridique spécialisé en droit français.\nVotre tâche est de proposer une réponse synthétique et argumentée à la question suivante, en vous appuyant uniquement sur les extraits de documents fournis, classés par pertinence. Indiquez clairement si la réponse est incertaine ou partielle. Répondez en français.",
height=120,
key="prompt_intro"
)
# Bouton d'envoi de la question
if st.button("📤 Envoyer") and user_input.strip():
user_input = st.session_state["user_question"]
user_prompt_intro = st.session_state["prompt_intro"]
def distance_to_percent(score, max_dist=10.0):
score = max(0, min(score, max_dist))
return round((1 - score / max_dist) * 100)
with st.spinner("Recherche des documents pertinents..."):
embeddings = get_local_embeddings()
db_path = os.path.abspath("./db")
db = Chroma(persist_directory=db_path, embedding_function=embeddings)
retriever = db.as_retriever(search_kwargs={"k": max_docs})
docs_and_scores = [
(doc, score)
for doc, score in retriever.vectorstore.similarity_search_with_score(user_input, k=30)
if doc.metadata.get("source") in selected_bases
][:max_docs]
docs_scores_pertinences = [
(doc, score, distance_to_percent(score, max_dist=10.0))
for doc, score in docs_and_scores
]
max_dist = 10.0
distance_seuil = max_dist * (1 - similarity_threshold / 100)
filtered_docs = [
(doc, score, pertinence)
for doc, score, pertinence in docs_scores_pertinences
if pertinence >= similarity_threshold
]
# Affichage des documents pertinents (dropdown fermé par défaut)
st.subheader("📎 Documents pertinents trouvés")
if not filtered_docs:
# Calcul de la meilleure pertinence trouvée
best_pertinence = max((p for _, _, p in docs_scores_pertinences), default=None)
st.warning("❗ Aucun document suffisamment pertinent trouvé pour cette question.")
st.info("L'assistant ne peut pas formuler de réponse fiable sans documents de référence.")
if best_pertinence is not None:
st.info(f"💡 Astuce : La meilleure pertinence trouvée est {best_pertinence}%. Essayez de baisser le seuil de pertinence dans les paramètres avancés pour augmenter vos chances de trouver des documents pertinents.")
else:
st.info("💡 Astuce : Essayez de baisser le seuil de pertinence dans les paramètres avancés pour augmenter vos chances de trouver des documents pertinents.")
st.stop()
else:
for idx, (doc, score, pertinence) in enumerate(filtered_docs, 1):
titre = os.path.basename(doc.metadata.get("ref", doc.metadata.get("source", "inconnu.txt")))
with st.expander(f"📄 Document {idx}{titre} (🔍 Pertinence : {pertinence}%)", expanded=False):
st.markdown(
f"""
<div style='white-space: pre-wrap; word-wrap: break-word; overflow-x: hidden; background-color: #f9f9f9; padding: 1em; border-radius: 8px; border: 1px solid #ddd;'>
{doc.page_content}
</div>
""",
unsafe_allow_html=True
)
# Préparation du contexte documentaire (doit être défini avant les prompts)
context_text = "\n\n".join([
f"<doc pertinence={score:.2f}>\n{doc.page_content.strip()}\n</doc>"
for doc, score, pertinence in filtered_docs
])
# Construction du prompt à partir de la personnalisation utilisateur
prompt_flan = f"""{user_prompt_intro}\n\nQuestion : {user_input}\n\nContexte documentaire :\n{context_text}\n"""
prompt_plg = f"""{user_prompt_intro}\n\nQuestion : {user_input}\n\nContexte documentaire :\n{context_text}\n"""
# Génération des deux réponses en colonnes, d'abord le modèle le plus rapide (flan-t5-small)
col1, col2 = st.columns(2)
output_flan = None
output_plg = None
with col1:
with st.spinner("Génération de la réponse (flan-t5-small)..."):
try:
output_flan = generate_response(prompt_flan, flan_tokenizer, flan_model)
except Exception as e:
st.error(f"Erreur génération flan-t5-small : {e}")
st.subheader("Réponse (flan-t5-small)")
if output_flan:
st.write(output_flan)
else:
st.info("Aucune réponse générée par flan-t5-small.")
with col2:
with st.spinner("Génération de la réponse (t5-base-fr-sum-cnndm)..."):
try:
output_plg = generate_response(prompt_plg, plg_tokenizer, plg_model)
except Exception as e:
st.error(f"Erreur génération t5-base-fr-sum-cnndm : {e}")
st.subheader("Réponse (t5-base-fr-sum-cnndm)")
if output_plg:
st.write(output_plg)
else:
st.info("Aucune réponse générée par t5-base-fr-sum-cnndm.")