eliot-hub committed
Commit f55a67c · 1 Parent(s): 9599b7c

agent + gr blocks

Files changed (4):
  1. .gitignore +3 -1
  2. README.md +1 -1
  3. app.py +77 -162
  4. tools.py +126 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
 .env
 hf_to_chroma_ds
-__pycache__
+__pycache__
+app_archive.py
+test_app.ipynb
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📚
 colorFrom: red
 colorTo: purple
 sdk: gradio
-sdk_version: 4.44.0
+sdk_version: 5.5.0
 app_file: app.py
 pinned: false
 startup_duration_timeout: 1h
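Note: the Gradio 5 bump is likely what drives the app.py rewrite below: in Gradio 5, gr.ChatInterface no longer accepts the retry_btn / undo_btn / clear_btn arguments the old app.py used, so the UI is rebuilt with gr.Blocks.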
app.py CHANGED
@@ -1,166 +1,81 @@
-import os
-from dotenv import load_dotenv
-
+import time
 import gradio as gr
-
-from langchain_chroma import Chroma
-from langchain.prompts import ChatPromptTemplate
-from langchain.chains import create_retrieval_chain, create_history_aware_retriever
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_core.prompts import MessagesPlaceholder
-from langchain_community.chat_message_histories import ChatMessageHistory
-from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain_core.documents import Document
-from langchain_core.retrievers import BaseRetriever
-from langchain_core.callbacks import CallbackManagerForRetrieverRun
-from langchain_core.vectorstores import VectorStoreRetriever
-from langchain_openai import ChatOpenAI
-from langchain.callbacks.tracers import ConsoleCallbackHandler
-from langchain_huggingface import HuggingFaceEmbeddings
-
-from datasets import load_dataset
-import chromadb
-from typing import List
-from mixedbread_ai.client import MixedbreadAI
-from tqdm import tqdm
-
-# Global params
-CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
-MODEL_EMB = "mxbai-embed-large"
-MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
-LLM_NAME = "gpt-4o-mini"
-OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
-MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
-HF_TOKEN = os.environ.get("HF_TOKEN")
-HF_API_KEY = os.environ.get("HF_API_KEY")
-
-# MixedbreadAI Client
-# device = "cuda:0" if torch.cuda.is_available() else "cpu"
-mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
-model_emb = "mixedbread-ai/mxbai-embed-large-v1"
-
-# Set up ChromaDB
-memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
-batched_ds = memoires_ds.batch(batch_size=41000)
-client = chromadb.Client()
-collection = client.get_or_create_collection(name="embeddings_mxbai")
-
-for batch in tqdm(batched_ds, desc="Processing dataset batches"):
-    collection.add(
-        ids=batch["id"],
-        metadatas=batch["metadata"],
-        documents=batch["document"],
-        embeddings=batch["embedding"],
+from tools import create_agent
+from langchain_core.messages import RemoveMessage
+from langchain_core.messages import trim_messages
+# from toolkits import create_agent
+# from langchain.schema import AIMessage, HumanMessage, SystemMessage
+
+AGENT = create_agent()
+theme = gr.themes.Default(primary_hue="red", secondary_hue="red")
+
+def filter_msg(msg_list:list, keep_n:int) -> list:
+    """Keep only last keep_n messages from chat history. Preserves structure user msg -> tool msg -> ai msg"""
+    msg = trim_messages(
+        msg_list,
+        strategy="last",
+        token_counter=len,
+        max_tokens=keep_n,
+        start_on="human",
+        end_on=("tool", "ai"),
+        include_system=True,
     )
-print(f"Collection complete: {collection.count()}")
-
-db = Chroma(
-    client=client,
-    collection_name=f"embeddings_mxbai",
-    embedding_function = HuggingFaceEmbeddings(model_name=model_emb)
-)
-
-
-# Reranker class
-class Reranker(BaseRetriever):
-    retriever: VectorStoreRetriever
-    # model: CrossEncoder
-    k: int
-
-    def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-    ) -> List[Document]:
-        docs = self.retriever.invoke(query)
-        results = mxbai_client.reranking(model=MODEL_RRK, query=query, input=[doc.page_content for doc in docs], return_input=True, top_k=self.k)
-        return [Document(page_content=res.input) for res in results.data]
-
-# Set up reranker + LLM
-retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
-reranker = Reranker(retriever=retriever, k=4) #Reranker(retriever=retriever, model=model, k=4)
-llm = ChatOpenAI(model=LLM_NAME, verbose=True) #, api_key=OPENAI_API_KEY, )
-
-# Set up the contextualize question prompt
-contextualize_q_system_prompt = (
-    "Compte tenu de l'historique des discussions et de la dernière question de l'utilisateur "
-    "qui peut faire référence à un contexte dans l'historique du chat, "
-    "formuler une question autonome qui peut être comprise "
-    "sans l'historique du chat. Ne répondez PAS à la question, "
-    "juste la reformuler si nécessaire et sinon la renvoyer telle quelle."
-)
-
-contextualize_q_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", contextualize_q_system_prompt),
-        MessagesPlaceholder("chat_history"),
-        ("human", "{input}"),
-    ]
-)
-
-# Create the history-aware retriever
-history_aware_retriever = create_history_aware_retriever(
-    llm, reranker, contextualize_q_prompt
-)
-
-# Set up the QA prompt
-system_prompt = (
-    "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
-    "Si tu ne connais pas la réponse, dis que tu ne sais pas."
-)
-qa_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", system_prompt),
-        MessagesPlaceholder("chat_history"),
-        ("human", "{input}"),
-    ]
-)
-
-# Create the question-answer chain
-question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-# Set up the conversation history
-store = {}
-
-def get_session_history(session_id: str) -> ChatMessageHistory:
-    if session_id not in store:
-        store[session_id] = ChatMessageHistory()
-    return store[session_id]
-
-conversational_rag_chain = RunnableWithMessageHistory(
-    rag_chain,
-    get_session_history,
-    input_messages_key="input",
-    history_messages_key="chat_history",
-    output_messages_key="answer",
-)
-
-# Gradio interface
-def chatbot(message, history):
-    session_id = "gradio_session"
-    response = conversational_rag_chain.invoke(
-        {"input": message},
-        config={
-            "configurable": {"session_id": session_id},
-            "callbacks": [ConsoleCallbackHandler()]
-        },
-    )["answer"]
-    return response
-
-iface = gr.ChatInterface(
-    chatbot,
-    title="Dataltist Chatbot",
-    description="Posez vos questions sur l'assurance",
-    textbox=gr.Textbox(placeholder="Qu'est-ce que l'assurance multirisque habitation ?", container=False, scale=9),
-    theme="soft",
-    # examples=[
-    #     "Qu'est-ce que l'assurance multirisque habitation ?",
-    #     "Qu'est-ce que la garantie DTA ?",
-    # ],
-    retry_btn=None,
-    undo_btn=None,
-    submit_btn=gr.Button(value="Envoyer", icon="./send_icon.png", variant="primary"),
-    clear_btn="Effacer la conversation",
-)
+    return [m.id for m in msg]
+
+def agent_response(query, config, keep_n=10):
+    messages = AGENT.get_state(config).values.get("messages", [])
+
+    if len(messages) > keep_n:
+        keep_msg_ids = filter_msg(messages, keep_n)
+        AGENT.update_state(config, {"messages": [RemoveMessage(id=m.id) for m in messages if m.id not in keep_msg_ids]})
+        print("msg removed")
+
+    # Generate answer
+    answer = AGENT.invoke({"messages":query}, config=config)
+    return answer["messages"][-1].content
+
+
+js_func = """
+function refresh() {
+    const url = new URL(window.location);
+
+    if (url.searchParams.get('__theme') != 'light') {
+        url.searchParams.set('__theme', 'light');
+        window.location.href = url.href;
+    }
+}
+"""
+
+
+def delete_agent():
+    print("del agent")
+    global AGENT
+    AGENT = create_agent()
+    # print(AGENT.get_state(config).values.get("messages"), "\n\n")
+
+with gr.Blocks(theme=theme, js=js_func, title="Dataltist", fill_height=True) as iface:
+    gr.Markdown("# Dataltist Chatbot 🚀")
+    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=False, type="messages", scale=1)
+    msg = gr.Textbox(lines=1, show_label=False, placeholder="Posez vos questions sur l'assurance") # submit_btn=True
+    clear = gr.ClearButton([msg, chatbot], value="Effacer 🗑")
+    config = {"configurable": {"thread_id": "1"}}
+
+
+    def user(user_message, history: list):
+        return "", history + [{"role": "user", "content": user_message}]
+
+    def bot(history: list):
+        bot_message = agent_response(history[-1]["content"], config) #AGENT.invoke({"messages":history[-1]["content"]}, config=config)
+        history.append({"role": "assistant", "content": ""})
+        for character in bot_message:
+            history[-1]['content'] += character
+            # time.sleep(0.005)
+            yield history
+
+    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+        bot, chatbot, chatbot
+    )
+    iface.unload(delete_agent)
 
 if __name__ == "__main__":
-    iface.launch() # share=True
+    iface.launch() # share=True # auth=("admin", "admin")
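For context on filter_msg above: passing token_counter=len to trim_messages makes each message count as 1 "token", so max_tokens=keep_n means "keep the last keep_n messages"; filter_msg returns the ids of the survivors, and agent_response drops everything else from the graph state with RemoveMessage. A minimal standalone sketch of that counting behaviour (message contents are invented for illustration):

from langchain_core.messages import AIMessage, HumanMessage, trim_messages

history = [
    HumanMessage("q1"), AIMessage("a1"),
    HumanMessage("q2"), AIMessage("a2"),
]

kept = trim_messages(
    history,
    strategy="last",    # keep the tail of the conversation
    token_counter=len,  # len() counts messages, so the budget is in messages, not tokens
    max_tokens=2,       # i.e. keep at most the last 2 messages
    start_on="human",   # the kept window must open on a user turn
)
# kept == [HumanMessage("q2"), AIMessage("a2")]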
tools.py ADDED
@@ -0,0 +1,126 @@
+from langchain_community.tools import TavilySearchResults
+
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.vectorstores import VectorStoreRetriever
+from langgraph.prebuilt import create_react_agent
+from langchain_core.documents import Document
+from langchain_openai import ChatOpenAI
+from langgraph.checkpoint.memory import MemorySaver
+from mixedbread_ai.client import MixedbreadAI
+from langchain.chains import create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.prompts import ChatPromptTemplate
+from dotenv import load_dotenv
+import os
+from langchain_chroma import Chroma
+import chromadb
+from typing import List
+from datasets import load_dataset
+from langchain_huggingface import HuggingFaceEmbeddings
+
+
+load_dotenv()
+# Global params
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+MODEL_EMB = "mxbai-embed-large"
+MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
+LLM_NAME = "gpt-4o-mini"
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
+HF_TOKEN = os.environ.get("HF_TOKEN")
+HF_API_KEY = os.environ.get("HF_API_KEY")
+
+# MixedbreadAI Client
+mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
+model_emb = "mixedbread-ai/mxbai-embed-large-v1"
+
+# # Set up ChromaDB
+memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
+batched_ds = memoires_ds.batch(batch_size=41000)
+client = chromadb.Client()
+collection = client.get_or_create_collection(name="embeddings_mxbai")
+
+
+
+llm_4o = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, temperature=0)
+
+
+
+def init_rag_tool():
+    """Init tools to allow an LLM to query the documents"""
+    # client = chromadb.PersistentClient(path=CHROMA_PATH)
+    db = Chroma(
+        client=client,
+        collection_name=f"embeddings_mxbai",
+        embedding_function = HuggingFaceEmbeddings(model_name=model_emb)
+    )
+    # Reranker class
+    class Reranker(BaseRetriever):
+        retriever: VectorStoreRetriever
+        # model: CrossEncoder
+        k: int
+
+        def _get_relevant_documents(
+            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+        ) -> List[Document]:
+            docs = self.retriever.invoke(query)
+            results = mxbai_client.reranking(model=MODEL_RRK, query=query, input=[doc.page_content for doc in docs], return_input=True, top_k=self.k)
+            return [Document(page_content=res.input) for res in results.data]
+
+    # Set up reranker + LLM
+    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
+    reranker = Reranker(retriever=retriever, k=4) #Reranker(retriever=retriever, model=model, k=4)
+    llm = ChatOpenAI(model=LLM_NAME, verbose=True)
+
+
+    system_prompt = (
+        "Réponds à la question en te basant uniquement sur le contexte suivant: \n\n {context}"
+        "Si tu ne connais pas la réponse, dis que tu ne sais pas."
+    )
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", system_prompt),
+            ("human", "{input}"),
+        ]
+    )
+
+    question_answer_chain = create_stuff_documents_chain(llm, prompt)
+    rag_chain = create_retrieval_chain(reranker, question_answer_chain)
+
+    rag_tool = rag_chain.as_tool(
+        name="RAG_search",
+        description="Recherche d'information dans les mémoires d'actuariat",
+        arg_types={"input": str},
+    )
+    return rag_tool
+
+
+def init_websearch_tool():
+    web_search_tool = TavilySearchResults(
+        name="Web_search",
+        max_results=5,
+        description="Recherche d'informations sur le web",
+        search_depth="advanced",
+        include_answer=True,
+        include_raw_content=True,
+        include_images=False,
+        verbose=False,
+    )
+    return web_search_tool
+
+
+def create_agent():
+    rag_tool = init_rag_tool()
+    web_search_tool = init_websearch_tool()
+    memory = MemorySaver()
+    llm_4o = ChatOpenAI(model="gpt-4o-mini", api_key=OPENAI_API_KEY, verbose=True, temperature=0, streaming=True)
+    tools = [rag_tool, web_search_tool]
+    system_message = """
+    Tu es un assistant dont la fonction est de répondre à des questions à propos de l'assurance et de l'actuariat.
+    Utilise les outils RAG_search ou Web_search pour répondre aux questions de l'utilisateur.
+    """ # Dans la réponse finale, sépare les informations de l'outil RAG et de l'outil Web.
+
+    react_agent = create_react_agent(llm_4o, tools, state_modifier=system_message, checkpointer=memory, debug=False)
+    return react_agent
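A minimal usage sketch of the new agent outside Gradio, assuming OPENAI_API_KEY, MXBAI_API_KEY, HF_TOKEN and TAVILY_API_KEY (required by TavilySearchResults) are set; the question and thread_id are illustrative:

from tools import create_agent

agent = create_agent()
# MemorySaver keys the conversation history on the thread_id in the config
config = {"configurable": {"thread_id": "demo"}}

result = agent.invoke({"messages": "Qu'est-ce que la garantie DTA ?"}, config=config)
print(result["messages"][-1].content)  # final answer, after any RAG_search / Web_search tool calls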