Spaces:
Sleeping
Sleeping
import os | |
from dotenv import load_dotenv | |
load_dotenv() | |
token = os.getenv("HUGGINGFACEHUB_API_TOKEN") | |
import streamlit as st | |
from langchain_chroma import Chroma | |
from utils.load_embeddings import get_local_embeddings | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer | |
import torch | |
def load_local_model(model_id): | |
if model_id == "plguillou/t5-base-fr-sum-cnndm": | |
tokenizer = T5Tokenizer.from_pretrained(model_id) | |
else: | |
tokenizer = AutoTokenizer.from_pretrained(model_id) | |
model = AutoModelForSeq2SeqLM.from_pretrained( | |
model_id, | |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
device_map="auto" | |
) | |
return tokenizer, model | |
# Pré-chargement des deux modèles | |
flan_tokenizer, flan_model = load_local_model("google/flan-t5-small") | |
plg_tokenizer, plg_model = load_local_model("plguillou/t5-base-fr-sum-cnndm") | |
def generate_response(prompt, tokenizer, model): | |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=512, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9 | |
) | |
text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Si le tag [RESPONSE] n'est pas généré, on affiche tout | |
if "[RESPONSE]" in text: | |
text = text.split("[RESPONSE]", 1)[-1].strip() | |
# Si la réponse est vide, on affiche la sortie brute | |
if not text.strip(): | |
return tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return text | |
st.set_page_config(page_title="Assistant Juridique IA", layout="wide") | |
st.title("📚 Assistant Juridique avec IA") | |
st.write("Posez une question juridique.") | |
# Réorganisation de la sidebar : paramètres avancés en haut | |
st.sidebar.header("🔧 Paramètres avancés") | |
max_docs = st.sidebar.slider( | |
"Nombre maximal de documents à utiliser", | |
min_value=1, | |
max_value=20, | |
value=5, | |
step=1 | |
) | |
similarity_threshold = st.sidebar.slider( | |
"Seuil de pertinence (%)", | |
min_value=0, | |
max_value=200, | |
value=90, | |
step=5 | |
) | |
# Choix multi-bases avec checkbox | |
st.sidebar.markdown("**Bases de documents à interroger :**") | |
base_options = [ | |
("Archive mails", "archive_mail", "archives_mails"), | |
("Textes de loi", "textes_loi", "textes_loi"), | |
("Jurisprudence", "jurisprudence", "jurisprudence") | |
] | |
selected_bases = [ | |
key for label, key, _ in base_options if st.sidebar.checkbox(label, value=True) | |
] | |
# Vérification qu'au moins une base est sélectionnée | |
if not selected_bases: | |
st.sidebar.warning("⚠️ Veuillez sélectionner au moins une base de documents pour continuer.") | |
st.stop() | |
# Affichage des modèles utilisés (en bas de la sidebar) | |
st.sidebar.markdown("---") | |
st.sidebar.markdown("🧠 **Modèle d'embedding :** `paraphrase-multilingual-mpnet-base-v2`") | |
st.sidebar.markdown("🗂️ **Base vectorielle :** `Chroma`") | |
st.sidebar.markdown("💬 **Modèle LLM :** `google/flan-t5-small` (text-generation, multilingue, open source)") | |
# Saisie de l'utilisateur et personnalisation du prompt en même temps | |
col1, col2 = st.columns([2, 3]) | |
with col1: | |
user_input = st.text_area("✉️ Votre question :", height=200, key="user_question") | |
with col2: | |
user_prompt_intro = st.text_area( | |
"Début du prompt (modifiable)", | |
value="Vous êtes un assistant juridique spécialisé en droit français.\nVotre tâche est de proposer une réponse synthétique et argumentée à la question suivante, en vous appuyant uniquement sur les extraits de documents fournis, classés par pertinence. Indiquez clairement si la réponse est incertaine ou partielle. Répondez en français.", | |
height=120, | |
key="prompt_intro" | |
) | |
# Bouton d'envoi de la question | |
if st.button("📤 Envoyer") and user_input.strip(): | |
user_input = st.session_state["user_question"] | |
user_prompt_intro = st.session_state["prompt_intro"] | |
def distance_to_percent(score, max_dist=10.0): | |
score = max(0, min(score, max_dist)) | |
return round((1 - score / max_dist) * 100) | |
with st.spinner("Recherche des documents pertinents..."): | |
embeddings = get_local_embeddings() | |
db_path = os.path.abspath("./db") | |
db = Chroma(persist_directory=db_path, embedding_function=embeddings) | |
retriever = db.as_retriever(search_kwargs={"k": max_docs}) | |
docs_and_scores = [ | |
(doc, score) | |
for doc, score in retriever.vectorstore.similarity_search_with_score(user_input, k=30) | |
if doc.metadata.get("source") in selected_bases | |
][:max_docs] | |
docs_scores_pertinences = [ | |
(doc, score, distance_to_percent(score, max_dist=10.0)) | |
for doc, score in docs_and_scores | |
] | |
max_dist = 10.0 | |
distance_seuil = max_dist * (1 - similarity_threshold / 100) | |
filtered_docs = [ | |
(doc, score, pertinence) | |
for doc, score, pertinence in docs_scores_pertinences | |
if pertinence >= similarity_threshold | |
] | |
# Affichage des documents pertinents (dropdown fermé par défaut) | |
st.subheader("📎 Documents pertinents trouvés") | |
if not filtered_docs: | |
# Calcul de la meilleure pertinence trouvée | |
best_pertinence = max((p for _, _, p in docs_scores_pertinences), default=None) | |
st.warning("❗ Aucun document suffisamment pertinent trouvé pour cette question.") | |
st.info("L'assistant ne peut pas formuler de réponse fiable sans documents de référence.") | |
if best_pertinence is not None: | |
st.info(f"💡 Astuce : La meilleure pertinence trouvée est {best_pertinence}%. Essayez de baisser le seuil de pertinence dans les paramètres avancés pour augmenter vos chances de trouver des documents pertinents.") | |
else: | |
st.info("💡 Astuce : Essayez de baisser le seuil de pertinence dans les paramètres avancés pour augmenter vos chances de trouver des documents pertinents.") | |
st.stop() | |
else: | |
for idx, (doc, score, pertinence) in enumerate(filtered_docs, 1): | |
titre = os.path.basename(doc.metadata.get("ref", doc.metadata.get("source", "inconnu.txt"))) | |
with st.expander(f"📄 Document {idx} — {titre} (🔍 Pertinence : {pertinence}%)", expanded=False): | |
st.markdown( | |
f""" | |
<div style='white-space: pre-wrap; word-wrap: break-word; overflow-x: hidden; background-color: #f9f9f9; padding: 1em; border-radius: 8px; border: 1px solid #ddd;'> | |
{doc.page_content} | |
</div> | |
""", | |
unsafe_allow_html=True | |
) | |
# Préparation du contexte documentaire (doit être défini avant les prompts) | |
context_text = "\n\n".join([ | |
f"<doc pertinence={score:.2f}>\n{doc.page_content.strip()}\n</doc>" | |
for doc, score, pertinence in filtered_docs | |
]) | |
# Construction du prompt à partir de la personnalisation utilisateur | |
prompt_flan = f"""{user_prompt_intro}\n\nQuestion : {user_input}\n\nContexte documentaire :\n{context_text}\n""" | |
prompt_plg = f"""{user_prompt_intro}\n\nQuestion : {user_input}\n\nContexte documentaire :\n{context_text}\n""" | |
# Génération des deux réponses en colonnes, d'abord le modèle le plus rapide (flan-t5-small) | |
col1, col2 = st.columns(2) | |
output_flan = None | |
output_plg = None | |
with col1: | |
with st.spinner("Génération de la réponse (flan-t5-small)..."): | |
try: | |
output_flan = generate_response(prompt_flan, flan_tokenizer, flan_model) | |
except Exception as e: | |
st.error(f"Erreur génération flan-t5-small : {e}") | |
st.subheader("Réponse (flan-t5-small)") | |
if output_flan: | |
st.write(output_flan) | |
else: | |
st.info("Aucune réponse générée par flan-t5-small.") | |
with col2: | |
with st.spinner("Génération de la réponse (t5-base-fr-sum-cnndm)..."): | |
try: | |
output_plg = generate_response(prompt_plg, plg_tokenizer, plg_model) | |
except Exception as e: | |
st.error(f"Erreur génération t5-base-fr-sum-cnndm : {e}") | |
st.subheader("Réponse (t5-base-fr-sum-cnndm)") | |
if output_plg: | |
st.write(output_plg) | |
else: | |
st.info("Aucune réponse générée par t5-base-fr-sum-cnndm.") |