File size: 8,551 Bytes
156c1af
 
 
 
 
 
 
 
 
ecc74f6
5e8b427
03947a1
 
5e8b427
38a1518
8653a09
ecc74f6
8653a09
 
38a1518
5e8b427
 
 
 
 
 
38a1518
 
 
5e8b427
38a1518
 
5e8b427
 
 
 
 
 
 
 
 
dc90d7f
e8d52de
dc90d7f
 
e8d52de
 
 
dc90d7f
5e8b427
156c1af
 
 
 
fdb7505
156c1af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdb7505
 
 
5efeaa8
 
 
fdb7505
 
5efeaa8
fdb7505
 
8553e6c
 
 
 
5efeaa8
fdb7505
 
 
 
38a1518
9884ff8
baeaaf2
 
 
 
 
 
 
 
 
 
 
156c1af
 
 
baeaaf2
 
9884ff8
 
 
 
dc90d7f
940d242
eccd0f1
940d242
156c1af
9884ff8
 
 
fdb7505
9884ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78b32f3
 
9884ff8
 
78b32f3
 
 
 
9884ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
e8d52de
9884ff8
c7793f6
 
9884ff8
99fd60b
1d7b03b
 
 
1f71c11
1d7b03b
e8d52de
 
38a1518
 
e8d52de
1d7b03b
e8d52de
 
38a1518
e8d52de
 
 
 
38a1518
 
e8d52de
1d7b03b
e8d52de
 
38a1518
e8d52de
 
 
1d7b03b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
from dotenv import load_dotenv
load_dotenv()
token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

import streamlit as st
from langchain_chroma import Chroma
from utils.load_embeddings import get_local_embeddings

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer
import torch


@st.cache_resource
def load_local_model(model_id):
    if model_id == "plguillou/t5-base-fr-sum-cnndm":
        tokenizer = T5Tokenizer.from_pretrained(model_id)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto"
    )
    return tokenizer, model

# Pré-chargement des deux modèles
flan_tokenizer, flan_model = load_local_model("google/flan-t5-small")
plg_tokenizer, plg_model = load_local_model("plguillou/t5-base-fr-sum-cnndm")


def generate_response(prompt, tokenizer, model):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Si le tag [RESPONSE] n'est pas généré, on affiche tout
    if "[RESPONSE]" in text:
        text = text.split("[RESPONSE]", 1)[-1].strip()
    # Si la réponse est vide, on affiche la sortie brute
    if not text.strip():
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text

st.set_page_config(page_title="Assistant Juridique IA", layout="wide")
st.title("📚 Assistant Juridique avec IA")
st.write("Posez une question juridique.")

# Réorganisation de la sidebar : paramètres avancés en haut
st.sidebar.header("🔧 Paramètres avancés")
max_docs = st.sidebar.slider(
    "Nombre maximal de documents à utiliser",
    min_value=1,
    max_value=20,
    value=5,
    step=1
)
similarity_threshold = st.sidebar.slider(
    "Seuil de pertinence (%)",
    min_value=0,
    max_value=200,
    value=90,
    step=5
)

# Choix multi-bases avec checkbox
st.sidebar.markdown("**Bases de documents à interroger :**")
base_options = [
    ("Archive mails", "archive_mail", "archives_mails"),
    ("Textes de loi", "textes_loi", "textes_loi"),
    ("Jurisprudence", "jurisprudence", "jurisprudence")
]
selected_bases = [
    key for label, key, _ in base_options if st.sidebar.checkbox(label, value=True)
]

# Vérification qu'au moins une base est sélectionnée
if not selected_bases:
    st.sidebar.warning("⚠️ Veuillez sélectionner au moins une base de documents pour continuer.")
    st.stop()  

# Affichage des modèles utilisés (en bas de la sidebar)
st.sidebar.markdown("---")
st.sidebar.markdown("🧠 **Modèle d'embedding :** `paraphrase-multilingual-mpnet-base-v2`")
st.sidebar.markdown("🗂️ **Base vectorielle :** `Chroma`")
st.sidebar.markdown("💬 **Modèle LLM :** `google/flan-t5-small` (text-generation, multilingue, open source)")

# Saisie de l'utilisateur et personnalisation du prompt en même temps
col1, col2 = st.columns([2, 3])
with col1:
    user_input = st.text_area("✉️ Votre question :", height=200, key="user_question")
with col2:
    user_prompt_intro = st.text_area(
        "Début du prompt (modifiable)",
        value="Vous êtes un assistant juridique spécialisé en droit français.\nVotre tâche est de proposer une réponse synthétique et argumentée à la question suivante, en vous appuyant uniquement sur les extraits de documents fournis, classés par pertinence. Indiquez clairement si la réponse est incertaine ou partielle. Répondez en français.",
        height=120,
        key="prompt_intro"
    )

# Bouton d'envoi de la question
if st.button("📤 Envoyer") and user_input.strip():
    user_input = st.session_state["user_question"]
    user_prompt_intro = st.session_state["prompt_intro"]
    def distance_to_percent(score, max_dist=10.0):
        score = max(0, min(score, max_dist))
        return round((1 - score / max_dist) * 100)

    with st.spinner("Recherche des documents pertinents..."):
        embeddings = get_local_embeddings()
        db_path = os.path.abspath("./db")
        db = Chroma(persist_directory=db_path, embedding_function=embeddings)
        retriever = db.as_retriever(search_kwargs={"k": max_docs})
        docs_and_scores = [
            (doc, score)
            for doc, score in retriever.vectorstore.similarity_search_with_score(user_input, k=30)
            if doc.metadata.get("source") in selected_bases
        ][:max_docs]
        docs_scores_pertinences = [
            (doc, score, distance_to_percent(score, max_dist=10.0))
            for doc, score in docs_and_scores
        ]
        max_dist = 10.0
        distance_seuil = max_dist * (1 - similarity_threshold / 100)
        filtered_docs = [
            (doc, score, pertinence)
            for doc, score, pertinence in docs_scores_pertinences
            if pertinence >= similarity_threshold
        ]

    # Affichage des documents pertinents (dropdown fermé par défaut)
    st.subheader("📎 Documents pertinents trouvés")
    if not filtered_docs:
        # Calcul de la meilleure pertinence trouvée
        best_pertinence = max((p for _, _, p in docs_scores_pertinences), default=None)
        st.warning("❗ Aucun document suffisamment pertinent trouvé pour cette question.")
        st.info("L'assistant ne peut pas formuler de réponse fiable sans documents de référence.")
        if best_pertinence is not None:
            st.info(f"💡 Astuce : La meilleure pertinence trouvée est {best_pertinence}%. Essayez de baisser le seuil de pertinence dans les paramètres avancés pour augmenter vos chances de trouver des documents pertinents.")
        else:
            st.info("💡 Astuce : Essayez de baisser le seuil de pertinence dans les paramètres avancés pour augmenter vos chances de trouver des documents pertinents.")
        st.stop()
    else:
        for idx, (doc, score, pertinence) in enumerate(filtered_docs, 1):
            titre = os.path.basename(doc.metadata.get("ref", doc.metadata.get("source", "inconnu.txt")))
            with st.expander(f"📄 Document {idx}{titre} (🔍 Pertinence : {pertinence}%)", expanded=False):
                st.markdown(
                    f"""
                    <div style='white-space: pre-wrap; word-wrap: break-word; overflow-x: hidden; background-color: #f9f9f9; padding: 1em; border-radius: 8px; border: 1px solid #ddd;'>
                        {doc.page_content}
                    </div>
                    """,
                    unsafe_allow_html=True
                )

    # Préparation du contexte documentaire (doit être défini avant les prompts)
    context_text = "\n\n".join([
        f"<doc pertinence={score:.2f}>\n{doc.page_content.strip()}\n</doc>"
        for doc, score, pertinence in filtered_docs
    ])
    
    # Construction du prompt à partir de la personnalisation utilisateur
    prompt_flan = f"""{user_prompt_intro}\n\nQuestion : {user_input}\n\nContexte documentaire :\n{context_text}\n"""
    prompt_plg = f"""{user_prompt_intro}\n\nQuestion : {user_input}\n\nContexte documentaire :\n{context_text}\n"""
    # Génération des deux réponses en colonnes, d'abord le modèle le plus rapide (flan-t5-small)
    col1, col2 = st.columns(2)
    output_flan = None
    output_plg = None
    with col1:
        with st.spinner("Génération de la réponse (flan-t5-small)..."):
            try:
                output_flan = generate_response(prompt_flan, flan_tokenizer, flan_model)
            except Exception as e:
                st.error(f"Erreur génération flan-t5-small : {e}")
        st.subheader("Réponse (flan-t5-small)")
        if output_flan:
            st.write(output_flan)
        else:
            st.info("Aucune réponse générée par flan-t5-small.")
    with col2:
        with st.spinner("Génération de la réponse (t5-base-fr-sum-cnndm)..."):
            try:
                output_plg = generate_response(prompt_plg, plg_tokenizer, plg_model)
            except Exception as e:
                st.error(f"Erreur génération t5-base-fr-sum-cnndm : {e}")
        st.subheader("Réponse (t5-base-fr-sum-cnndm)")
        if output_plg:
            st.write(output_plg)
        else:
            st.info("Aucune réponse générée par t5-base-fr-sum-cnndm.")