File size: 5,440 Bytes
6321ecb
94788eb
991572d
2920e24
d78dcaa
f7b57e5
 
6321ecb
991572d
 
94788eb
991572d
94788eb
2920e24
 
 
 
 
 
991572d
f7b57e5
 
 
 
 
 
 
 
 
 
 
 
991572d
 
 
 
 
 
 
 
 
 
 
 
 
 
d78dcaa
2920e24
991572d
 
 
2920e24
 
5790c07
2920e24
991572d
 
94788eb
991572d
 
 
f7b57e5
 
 
991572d
 
 
 
2920e24
991572d
 
 
f7b57e5
991572d
2920e24
5790c07
f7b57e5
991572d
d78dcaa
 
f7b57e5
 
 
 
 
 
 
 
 
 
 
 
 
991572d
 
 
 
 
 
 
 
5790c07
991572d
 
 
 
 
 
 
 
94788eb
 
991572d
 
f7b57e5
 
991572d
d78dcaa
991572d
 
 
f7b57e5
 
 
 
 
 
 
 
 
f184c54
f7b57e5
 
 
 
d78dcaa
991572d
d78dcaa
f7b57e5
 
991572d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import gradio as gr
from huggingface_hub import login, InferenceClient
import os
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import umap
import pandas as pd

HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

login(token=HF_TOKEN)
client = InferenceClient(token=HF_TOKEN)


embeddings = HuggingFaceEmbeddings(model_name="OrdalieTech/Solon-embeddings-large-0.1")

db_code = FAISS.load_local("faiss_code_education",
        embeddings,
        allow_dangerous_deserialization=True)

reducer = umap.UMAP()
index = db_code.index
ntotal = min(index.ntotal, 4998)
embeds = index.reconstruct_n(0, ntotal)
umap_embeds = reducer.fit_transform(embeds)

articles_df = pd.DataFrame({
    "x" : umap_embeds[:,0],
    "y" : umap_embeds[:,1],
    "type" : [ "Source" ] * len(umap_embeds),
})

system_prompt = """Tu es un assistant juridique spécialisé dans le Code de l'éducation français. 
Ta mission est d'aider les utilisateurs à comprendre la législation en répondant à leurs questions.

Voici comment tu dois procéder :

1. **Analyse de la question:** Lis attentivement la question de l'utilisateur.
2. **Identification des articles pertinents:** Examine les 10 articles de loi fournis et sélectionne ceux qui sont les plus pertinents pour répondre à la question.
3. **Formulation de la réponse:** Rédige une réponse claire et concise en français, en utilisant les informations des articles sélectionnés. Cite explicitement les articles que tu utilises (par exemple, "Selon l'article L311-3...").
4. **Structure de la réponse:** Si ta réponse s'appuie sur plusieurs articles, structure-la de manière logique, en séparant les informations provenant de chaque article.
5. **Cas ambigus:** 
* Si la question est trop vague, demande des précisions à l'utilisateur.
* Si plusieurs articles pourraient s'appliquer, présente les différentes interprétations possibles."""


def query_rag(query, model, system_prompt):
    docs = db_code.similarity_search(query, 10)

    article_dict = {}
    context_list = []
    for doc in docs:
        article = doc.metadata
        context_list.append(' > '.join(article['chemin'])+'\n'+article['texte']+'\n---\n')
        article_dict[article['article']] = article

    user = 'Question de l\'utilisateur : ' + query + '\nContexte législatif :\n' + '\n'.join(context_list)

    messages = [ { "role" : "system", "content" : system_prompt } ]
    messages.append( { "role" : "user", "content" : user } )

    if "factice" in model:
        return user, article_dict

    chat_completion = client.chat_completion(
        messages=messages,
        model=model,
        max_tokens=1024)

    return chat_completion.choices[0].message.content, article_dict

def create_context_response(response, article_dict):
    context = '\n'
    for i, article in enumerate(article_dict):
        art = article_dict[article]
        context += '* **' + ' > '.join(art['chemin']) + '** : '+ art['texte'].replace('\n', '\n    ')+'\n'
    return context

def chat_interface(query, model, system_prompt):
    response, article_dict = query_rag(query, model, system_prompt)
    context = create_context_response(response, article_dict)
    return response, context

def update_plot(query):
    query_embed = embeddings.embed_documents([query])[0]
    query_umap_embed = reducer.transform([query_embed])
    
    data = {
        "x": umap_embeds[:, 0].tolist() + [query_umap_embed[0, 0]],
        "y": umap_embeds[:, 1].tolist() + [query_umap_embed[0, 1]],
        "type": ["Source"] * len(umap_embeds) + ["Requête"]
    }
    return pd.DataFrame(data)

with gr.Blocks(title="Assistant Juridique pour le Code de l'éducation (Beta)") as demo:
    gr.Markdown(
        """
        ## Posez vos questions sur le Code de l'éducation
        
        **Créé par Marc de Falco**

        **Avertissement :** Les informations fournies sont données à titre indicatif et ne constituent pas un avis juridique. Les serveurs étant publics, veuillez ne pas inclure de données sensibles.
        """
    )

    query_box = gr.Textbox(label="Votre question")

    model = gr.Dropdown(
        label="Modèle de langage",
        choices=[
            "meta-llama/Meta-Llama-3-70B-Instruct",
            "meta-llama/Meta-Llama-3-8B-Instruct",
            "HuggingFaceH4/zephyr-7b-beta",
            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mixtral-8x22B-v0.1",
            "factice: question+contexte"
            ],
        value="meta-llama/Meta-Llama-3-70B-Instruct")

    submit_button = gr.Button("Envoyer")

    with gr.Tab(label="Réponse"):
        response_box = gr.Markdown()
    with gr.Tab(label="Sources"):
        sources_box = gr.Markdown()
    with gr.Tab(label="Visualisation"):
        scatter_plot = gr.ScatterPlot(articles_df,
                x = "x", y = "y",
                color="type",
                label="Visualisation des embeddings",
                width=500,
                height=500)
    with gr.Tab(label="Paramètres"):
        system_box = gr.Textbox(label="Invite systeme", value=system_prompt,
                                lines=20)

    submit_button.click(chat_interface, 
                inputs=[query_box, model, system_box], 
                outputs=[response_box, sources_box])
    submit_button.click(update_plot, inputs=[query_box], outputs=[scatter_plot])

demo.launch()