Spaces:
Sleeping
Sleeping
File size: 8,551 Bytes
156c1af ecc74f6 5e8b427 03947a1 5e8b427 38a1518 8653a09 ecc74f6 8653a09 38a1518 5e8b427 38a1518 5e8b427 38a1518 5e8b427 dc90d7f e8d52de dc90d7f e8d52de dc90d7f 5e8b427 156c1af fdb7505 156c1af fdb7505 5efeaa8 fdb7505 5efeaa8 fdb7505 8553e6c 5efeaa8 fdb7505 38a1518 9884ff8 baeaaf2 156c1af baeaaf2 9884ff8 dc90d7f 940d242 eccd0f1 940d242 156c1af 9884ff8 fdb7505 9884ff8 78b32f3 9884ff8 78b32f3 9884ff8 e8d52de 9884ff8 c7793f6 9884ff8 99fd60b 1d7b03b 1f71c11 1d7b03b e8d52de 38a1518 e8d52de 1d7b03b e8d52de 38a1518 e8d52de 38a1518 e8d52de 1d7b03b e8d52de 38a1518 e8d52de 1d7b03b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import os
from dotenv import load_dotenv
load_dotenv()
token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
import streamlit as st
from langchain_chroma import Chroma
from utils.load_embeddings import get_local_embeddings
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer
import torch
@st.cache_resource
def load_local_model(model_id):
if model_id == "plguillou/t5-base-fr-sum-cnndm":
tokenizer = T5Tokenizer.from_pretrained(model_id)
else:
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_id,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
return tokenizer, model
# Pré-chargement des deux modèles
flan_tokenizer, flan_model = load_local_model("google/flan-t5-small")
plg_tokenizer, plg_model = load_local_model("plguillou/t5-base-fr-sum-cnndm")
def generate_response(prompt, tokenizer, model):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9
)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Si le tag [RESPONSE] n'est pas généré, on affiche tout
if "[RESPONSE]" in text:
text = text.split("[RESPONSE]", 1)[-1].strip()
# Si la réponse est vide, on affiche la sortie brute
if not text.strip():
return tokenizer.decode(outputs[0], skip_special_tokens=True)
return text
st.set_page_config(page_title="Assistant Juridique IA", layout="wide")
st.title("📚 Assistant Juridique avec IA")
st.write("Posez une question juridique.")
# Réorganisation de la sidebar : paramètres avancés en haut
st.sidebar.header("🔧 Paramètres avancés")
max_docs = st.sidebar.slider(
"Nombre maximal de documents à utiliser",
min_value=1,
max_value=20,
value=5,
step=1
)
similarity_threshold = st.sidebar.slider(
"Seuil de pertinence (%)",
min_value=0,
max_value=200,
value=90,
step=5
)
# Choix multi-bases avec checkbox
st.sidebar.markdown("**Bases de documents à interroger :**")
base_options = [
("Archive mails", "archive_mail", "archives_mails"),
("Textes de loi", "textes_loi", "textes_loi"),
("Jurisprudence", "jurisprudence", "jurisprudence")
]
selected_bases = [
key for label, key, _ in base_options if st.sidebar.checkbox(label, value=True)
]
# Vérification qu'au moins une base est sélectionnée
if not selected_bases:
st.sidebar.warning("⚠️ Veuillez sélectionner au moins une base de documents pour continuer.")
st.stop()
# Affichage des modèles utilisés (en bas de la sidebar)
st.sidebar.markdown("---")
st.sidebar.markdown("🧠 **Modèle d'embedding :** `paraphrase-multilingual-mpnet-base-v2`")
st.sidebar.markdown("🗂️ **Base vectorielle :** `Chroma`")
st.sidebar.markdown("💬 **Modèle LLM :** `google/flan-t5-small` (text-generation, multilingue, open source)")
# Saisie de l'utilisateur et personnalisation du prompt en même temps
col1, col2 = st.columns([2, 3])
with col1:
user_input = st.text_area("✉️ Votre question :", height=200, key="user_question")
with col2:
user_prompt_intro = st.text_area(
"Début du prompt (modifiable)",
value="Vous êtes un assistant juridique spécialisé en droit français.\nVotre tâche est de proposer une réponse synthétique et argumentée à la question suivante, en vous appuyant uniquement sur les extraits de documents fournis, classés par pertinence. Indiquez clairement si la réponse est incertaine ou partielle. Répondez en français.",
height=120,
key="prompt_intro"
)
# Bouton d'envoi de la question
if st.button("📤 Envoyer") and user_input.strip():
user_input = st.session_state["user_question"]
user_prompt_intro = st.session_state["prompt_intro"]
def distance_to_percent(score, max_dist=10.0):
score = max(0, min(score, max_dist))
return round((1 - score / max_dist) * 100)
with st.spinner("Recherche des documents pertinents..."):
embeddings = get_local_embeddings()
db_path = os.path.abspath("./db")
db = Chroma(persist_directory=db_path, embedding_function=embeddings)
retriever = db.as_retriever(search_kwargs={"k": max_docs})
docs_and_scores = [
(doc, score)
for doc, score in retriever.vectorstore.similarity_search_with_score(user_input, k=30)
if doc.metadata.get("source") in selected_bases
][:max_docs]
docs_scores_pertinences = [
(doc, score, distance_to_percent(score, max_dist=10.0))
for doc, score in docs_and_scores
]
max_dist = 10.0
distance_seuil = max_dist * (1 - similarity_threshold / 100)
filtered_docs = [
(doc, score, pertinence)
for doc, score, pertinence in docs_scores_pertinences
if pertinence >= similarity_threshold
]
# Affichage des documents pertinents (dropdown fermé par défaut)
st.subheader("📎 Documents pertinents trouvés")
if not filtered_docs:
# Calcul de la meilleure pertinence trouvée
best_pertinence = max((p for _, _, p in docs_scores_pertinences), default=None)
st.warning("❗ Aucun document suffisamment pertinent trouvé pour cette question.")
st.info("L'assistant ne peut pas formuler de réponse fiable sans documents de référence.")
if best_pertinence is not None:
st.info(f"💡 Astuce : La meilleure pertinence trouvée est {best_pertinence}%. Essayez de baisser le seuil de pertinence dans les paramètres avancés pour augmenter vos chances de trouver des documents pertinents.")
else:
st.info("💡 Astuce : Essayez de baisser le seuil de pertinence dans les paramètres avancés pour augmenter vos chances de trouver des documents pertinents.")
st.stop()
else:
for idx, (doc, score, pertinence) in enumerate(filtered_docs, 1):
titre = os.path.basename(doc.metadata.get("ref", doc.metadata.get("source", "inconnu.txt")))
with st.expander(f"📄 Document {idx} — {titre} (🔍 Pertinence : {pertinence}%)", expanded=False):
st.markdown(
f"""
<div style='white-space: pre-wrap; word-wrap: break-word; overflow-x: hidden; background-color: #f9f9f9; padding: 1em; border-radius: 8px; border: 1px solid #ddd;'>
{doc.page_content}
</div>
""",
unsafe_allow_html=True
)
# Préparation du contexte documentaire (doit être défini avant les prompts)
context_text = "\n\n".join([
f"<doc pertinence={score:.2f}>\n{doc.page_content.strip()}\n</doc>"
for doc, score, pertinence in filtered_docs
])
# Construction du prompt à partir de la personnalisation utilisateur
prompt_flan = f"""{user_prompt_intro}\n\nQuestion : {user_input}\n\nContexte documentaire :\n{context_text}\n"""
prompt_plg = f"""{user_prompt_intro}\n\nQuestion : {user_input}\n\nContexte documentaire :\n{context_text}\n"""
# Génération des deux réponses en colonnes, d'abord le modèle le plus rapide (flan-t5-small)
col1, col2 = st.columns(2)
output_flan = None
output_plg = None
with col1:
with st.spinner("Génération de la réponse (flan-t5-small)..."):
try:
output_flan = generate_response(prompt_flan, flan_tokenizer, flan_model)
except Exception as e:
st.error(f"Erreur génération flan-t5-small : {e}")
st.subheader("Réponse (flan-t5-small)")
if output_flan:
st.write(output_flan)
else:
st.info("Aucune réponse générée par flan-t5-small.")
with col2:
with st.spinner("Génération de la réponse (t5-base-fr-sum-cnndm)..."):
try:
output_plg = generate_response(prompt_plg, plg_tokenizer, plg_model)
except Exception as e:
st.error(f"Erreur génération t5-base-fr-sum-cnndm : {e}")
st.subheader("Réponse (t5-base-fr-sum-cnndm)")
if output_plg:
st.write(output_plg)
else:
st.info("Aucune réponse générée par t5-base-fr-sum-cnndm.") |