sabazo commited on
Commit
c49aa7c
1 Parent(s): e55546b

added gradio UI + google, arxiv and wikipedia tools with HF hub

Browse files
Files changed (1) hide show
  1. Copy_of_mixtral_react_agent.ipynb +441 -0
Copy_of_mixtral_react_agent.ipynb ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "authorship_tag": "ABX9TyM4ysnzp2PKemzKgh131B0g",
8
+ "include_colab_link": true
9
+ },
10
+ "kernelspec": {
11
+ "name": "python3",
12
+ "display_name": "Python 3"
13
+ },
14
+ "language_info": {
15
+ "name": "python"
16
+ }
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "metadata": {
22
+ "id": "view-in-github",
23
+ "colab_type": "text"
24
+ },
25
+ "source": [
26
+ "<a href=\"https://colab.research.google.com/github/almutareb/InnovationPathfinderAI/blob/main/Copy_of_mixtral_react_agent.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {
33
+ "id": "ydmVy2pS_hwU"
34
+ },
35
+ "outputs": [],
36
+ "source": [
37
+ "!pip install -qU langchain_community\n",
38
+ "!pip install -qU langchain\n",
39
+ "!pip install -qU google-search-results\n",
40
+ "!pip install -qU langchainhub\n",
41
+ "!pip install -qU text_generation\n",
42
+ "!pip install -qU arxiv\n",
43
+ "!pip install -qU wikipedia\n",
44
+ "!pip install -qU gradio==3.48.0\n",
45
+ "!pip install -qU youtube_search\n",
46
+ "!pip install -qU sentence_transformers\n",
47
+ "!pip install -qU hromadb"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "source": [
53
+ "import os\n",
54
+ "from google.colab import userdata\n",
55
+ "os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = userdata.get('HUGGINGFACEHUB_API_TOKEN')\n",
56
+ "#os.environ[\"SERPAPI_API_KEY\"] = userdata.get('SERPAPI_API_KEY')\n",
57
+ "os.environ[\"GOOGLE_CSE_ID\"] = userdata.get('GOOGLE_CSE_ID')\n",
58
+ "os.environ[\"GOOGLE_API_KEY\"] = userdata.get('GOOGLE_API_KEY')\n",
59
+ "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
60
+ "os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n",
61
+ "os.environ[\"LANGCHAIN_API_KEY\"] = userdata.get('LANGCHAIN_API_KEY')\n",
62
+ "os.environ[\"LANGCHAIN_PROJECT\"] = \"arxiv_ollama_agent\""
63
+ ],
64
+ "metadata": {
65
+ "id": "JYt3cFVnQiPe"
66
+ },
67
+ "execution_count": null,
68
+ "outputs": []
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "source": [
73
+ "from langchain.tools import WikipediaQueryRun\n",
74
+ "from langchain_community.utilities import WikipediaAPIWrapper\n",
75
+ "\n",
76
+ "from langchain.tools import Tool\n",
77
+ "from langchain_community.utilities import GoogleSearchAPIWrapper"
78
+ ],
79
+ "metadata": {
80
+ "id": "bpb1dYzBZsRR"
81
+ },
82
+ "execution_count": null,
83
+ "outputs": []
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "source": [
88
+ "api_wrapper = WikipediaAPIWrapper()\n",
89
+ "wikipedia = WikipediaQueryRun(api_wrapper=api_wrapper)"
90
+ ],
91
+ "metadata": {
92
+ "id": "_NAkY8FkMHcx"
93
+ },
94
+ "execution_count": null,
95
+ "outputs": []
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "source": [
100
+ "wikipedia.run(\"large language model\")"
101
+ ],
102
+ "metadata": {
103
+ "id": "ADu6renzI3bi"
104
+ },
105
+ "execution_count": null,
106
+ "outputs": []
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "source": [
111
+ "websearch = GoogleSearchAPIWrapper()\n",
112
+ "\n",
113
+ "def top5_results(query):\n",
114
+ " return websearch.results(query, 5)\n",
115
+ "\n",
116
+ "google_search = Tool(\n",
117
+ " name=\"google_search\",\n",
118
+ " description=\"Search Google for recent results.\",\n",
119
+ " #func=top5_results,\n",
120
+ " func=websearch.run,\n",
121
+ ")"
122
+ ],
123
+ "metadata": {
124
+ "id": "QtWQgcDpblGx"
125
+ },
126
+ "execution_count": null,
127
+ "outputs": []
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "google_search.run(\"large language model\")"
133
+ ],
134
+ "metadata": {
135
+ "id": "IVAbbQ04ZE9M"
136
+ },
137
+ "execution_count": null,
138
+ "outputs": []
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "source": [
143
+ "wikipedia.args"
144
+ ],
145
+ "metadata": {
146
+ "id": "Cv2z8MFNJ3sD"
147
+ },
148
+ "execution_count": null,
149
+ "outputs": []
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "source": [
154
+ "# HF libraries\n",
155
+ "from langchain.llms import HuggingFaceHub\n",
156
+ "\n",
157
+ "# Load the model from the Hugging Face Hub\n",
158
+ "model_id = HuggingFaceHub(repo_id=\"mistralai/Mixtral-8x7B-Instruct-v0.1\", model_kwargs={\n",
159
+ " \"temperature\":0.1,\n",
160
+ " \"max_new_tokens\":1024,\n",
161
+ " \"repetition_penalty\":1.2,\n",
162
+ " \"return_full_text\":False\n",
163
+ " })"
164
+ ],
165
+ "metadata": {
166
+ "id": "JHO0Hr5phBLH"
167
+ },
168
+ "execution_count": null,
169
+ "outputs": []
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "source": [
174
+ "from langchain import hub\n",
175
+ "from langchain.agents import AgentExecutor, create_react_agent, load_tools\n",
176
+ "from langchain.tools.render import render_text_description\n",
177
+ "from langchain.tools.retriever import create_retriever_tool\n",
178
+ "from langchain.retrievers import ArxivRetriever\n",
179
+ "from langchain.agents.format_scratchpad import format_log_to_str\n",
180
+ "from langchain.agents.output_parsers import (\n",
181
+ " ReActJsonSingleInputOutputParser,\n",
182
+ ")\n",
183
+ "from langchain.tools import YouTubeSearchTool\n",
184
+ "\n",
185
+ "from langchain_community.chat_message_histories import ChatMessageHistory\n",
186
+ "from langchain_core.runnables.history import RunnableWithMessageHistory\n",
187
+ "\n",
188
+ "message_history = ChatMessageHistory()\n",
189
+ "\n",
190
+ "arxiv_retriever = ArxivRetriever()\n",
191
+ "\n",
192
+ "arxiv_search = create_retriever_tool(\n",
193
+ " arxiv_retriever,\n",
194
+ " \"arxiv_database\",\n",
195
+ " \"Search arxiv database for scientific research papers and studies\",\n",
196
+ ")\n",
197
+ "\n",
198
+ "youtube_search = YouTubeSearchTool()\n",
199
+ "\n",
200
+ "tools = [arxiv_search, wikipedia, google_search]\n",
201
+ "\n",
202
+ "#prompt = hub.pull(\"hwchase17/react\")\n",
203
+ "prompt = hub.pull(\"hwchase17/react-json\")\n",
204
+ "prompt = prompt.partial(\n",
205
+ " tools=render_text_description(tools),\n",
206
+ " tool_names=\", \".join([t.name for t in tools]),\n",
207
+ ")\n",
208
+ "chat_model_with_stop = model_id.bind(stop=[\"\\nObservation\"])\n",
209
+ "agent = (\n",
210
+ " {\n",
211
+ " \"input\": lambda x: x[\"input\"],\n",
212
+ " \"agent_scratchpad\": lambda x: format_log_to_str(x[\"intermediate_steps\"]),\n",
213
+ " }\n",
214
+ " | prompt\n",
215
+ " | chat_model_with_stop\n",
216
+ "# | model_id\n",
217
+ " | ReActJsonSingleInputOutputParser()\n",
218
+ ")\n",
219
+ "\n",
220
+ "#agent = create_react_agent(model_id, tools, prompt)\n",
221
+ "agent_executor = AgentExecutor(\n",
222
+ " agent=agent,\n",
223
+ " tools=tools,\n",
224
+ " verbose=True,\n",
225
+ " max_iterations=10, # cap number of iterations\n",
226
+ " #max_execution_time=60, # timout at 60 sec\n",
227
+ " return_intermediate_steps=True,\n",
228
+ " handle_parsing_errors=True,\n",
229
+ " )\n",
230
+ "\n",
231
+ "def stream_output(query):\n",
232
+ " for chunk in agent_executor.stream({\"input\": query}):\n",
233
+ " # Agent Action\n",
234
+ " if \"actions\" in chunk:\n",
235
+ " for action in chunk[\"actions\"]:\n",
236
+ " print(\n",
237
+ " f\"Calling Tool ```{action.tool}``` with input ```{action.tool_input}```\"\n",
238
+ " )\n",
239
+ " # Observation\n",
240
+ " elif \"steps\" in chunk:\n",
241
+ " for step in chunk[\"steps\"]:\n",
242
+ " print(f\"Got result: ```{step.observation}```\")\n",
243
+ "\n",
244
+ "# Chat memory not working yet\n",
245
+ "agent_with_chat_history = RunnableWithMessageHistory(\n",
246
+ " agent_executor,\n",
247
+ " lambda session_id: message_history,\n",
248
+ " input_message_key=\"input\",\n",
249
+ " history_messages_key=\"chat_history\",\n",
250
+ ")"
251
+ ],
252
+ "metadata": {
253
+ "id": "D4Gj_dZtgzci"
254
+ },
255
+ "execution_count": null,
256
+ "outputs": []
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "source": [
261
+ "stream_output(\"what is corrective retrieval augmeneted generation\")"
262
+ ],
263
+ "metadata": {
264
+ "id": "ItAD-n6BnTc6"
265
+ },
266
+ "execution_count": null,
267
+ "outputs": []
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "source": [
272
+ "## Youtube search tool, not used yet\n",
273
+ "import ast\n",
274
+ "def you_four(query):\n",
275
+ " fquery = query+',4'\n",
276
+ " videos_str = youtube_search.run(fquery)\n",
277
+ "# video_list.replace('watch?v=','embed/')\n",
278
+ "# video_list = [word.replace('watch?v=','embed/') for word in video_list]\n",
279
+ " video_list = convert_urls(videos_str)\n",
280
+ "\n",
281
+ " return video_list\n",
282
+ "\n",
283
+ "def convert_urls(urls):\n",
284
+ " # Convert the string representation of the list into an actual list\n",
285
+ " urls = ast.literal_eval(urls)\n",
286
+ " #urls = [ for url in urls]\n",
287
+ " iframes = []\n",
288
+ " for url in urls:\n",
289
+ " embed_url = url.replace('watch?v=','embed/')\n",
290
+ " iframe = f'<iframe width=\"560\" height=\"315\" src=\"{embed_url}\" frameborder=\"0\" allowfullscreen></iframe>'\n",
291
+ " iframes.append(iframe)\n",
292
+ " return iframes"
293
+ ],
294
+ "metadata": {
295
+ "id": "3LzQEqeTzH0L"
296
+ },
297
+ "execution_count": null,
298
+ "outputs": []
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "source": [
303
+ "list_d=you_four(\"air taxi\")\n",
304
+ "list_d"
305
+ ],
306
+ "metadata": {
307
+ "id": "T7I6eIh318rU"
308
+ },
309
+ "execution_count": null,
310
+ "outputs": []
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "source": [
315
+ "agent_with_chat_history.invoke(\n",
316
+ " {\"input\": \"hi! I'm bob\"},\n",
317
+ " # This is needed because in most real world scenarios, a session id is needed\n",
318
+ " # It isn't really used here because we are using a simple in memory ChatMessageHistory\n",
319
+ " config={\"configurable\": {\"session_id\": \"<foo>\"}},\n",
320
+ ")"
321
+ ],
322
+ "metadata": {
323
+ "id": "7MxiaD6qffZG"
324
+ },
325
+ "execution_count": null,
326
+ "outputs": []
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "source": [
331
+ "agent_with_chat_history.invoke(\n",
332
+ " {\"input\": \"what's my name?\"},\n",
333
+ " # This is needed because in most real world scenarios, a session id is needed\n",
334
+ " # It isn't really used here because we are using a simple in memory ChatMessageHistory\n",
335
+ " config={\"configurable\": {\"session_id\": \"<foo>\"}},\n",
336
+ ")"
337
+ ],
338
+ "metadata": {
339
+ "id": "5cjo4j2nfkbQ"
340
+ },
341
+ "execution_count": null,
342
+ "outputs": []
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "source": [
347
+ "return_txt= agent_executor.invoke(\n",
348
+ " {\n",
349
+ " \"input\": \"how could a concept for an airtaxi fleet management look like?\",\n",
350
+ " }\n",
351
+ ")"
352
+ ],
353
+ "metadata": {
354
+ "id": "-q81PaZijPvO"
355
+ },
356
+ "execution_count": null,
357
+ "outputs": []
358
+ },
359
+ {
360
+ "cell_type": "code",
361
+ "source": [
362
+ "agent_executor.invoke(\n",
363
+ " {\n",
364
+ " \"input\": \"What's the latest paper on corrective retrieval augmeneted generation?\"\n",
365
+ " }\n",
366
+ ")"
367
+ ],
368
+ "metadata": {
369
+ "id": "GCAOXXdPJ_wL"
370
+ },
371
+ "execution_count": null,
372
+ "outputs": []
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "source": [
377
+ "\n",
378
+ "import gradio as gr\n",
379
+ "def add_text(history, text):\n",
380
+ " history = history + [(text, None)]\n",
381
+ " return history, \"\"\n",
382
+ "\n",
383
+ "def bot(history):\n",
384
+ " response = infer(history[-1][0], history)\n",
385
+ " history[-1][1] = response['output']\n",
386
+ " return history\n",
387
+ "\n",
388
+ "def infer(question, history):\n",
389
+ " query = question\n",
390
+ " result = agent_executor.invoke(\n",
391
+ " {\n",
392
+ " \"input\": question,\n",
393
+ " }\n",
394
+ " )\n",
395
+ " return result\n",
396
+ "\n",
397
+ "def you_frame(question):\n",
398
+ " iframes=you_four(question)\n",
399
+ " return '\\n'.join(iframes)\n",
400
+ "\n",
401
+ "def vote(data: gr.LikeData):\n",
402
+ " if data.liked:\n",
403
+ " print(\"You upvoted this response: \" + data.value)\n",
404
+ " else:\n",
405
+ " print(\"You downvoted this response: \" + data.value)\n",
406
+ "\n",
407
+ "css=\"\"\"\n",
408
+ "#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}\n",
409
+ "\"\"\"\n",
410
+ "\n",
411
+ "title = \"\"\"\n",
412
+ "<div style=\"text-align: center;max-width: 700px;\">\n",
413
+ " <p>Hello Dave, how can I help today?<br />\n",
414
+ "</div>\n",
415
+ "\"\"\"\n",
416
+ "\n",
417
+ "with gr.Blocks(theme=gr.themes.Soft()) as demo:\n",
418
+ " with gr.Tab(\"Google|Wikipedia|Arxiv\"):\n",
419
+ " with gr.Column(elem_id=\"col-container\"):\n",
420
+ " gr.HTML(title)\n",
421
+ " with gr.Row():\n",
422
+ " question = gr.Textbox(label=\"Question\", placeholder=\"Type your question and hit Enter \")\n",
423
+ " chatbot = gr.Chatbot([], elem_id=\"chatbot\")\n",
424
+ " chatbot.like(vote, None, None)\n",
425
+ " clear = gr.Button(\"Clear\")\n",
426
+ " question.submit(add_text, [chatbot, question], [chatbot, question], queue=False).then(\n",
427
+ " bot, chatbot, chatbot\n",
428
+ " )\n",
429
+ " clear.click(lambda: None, None, chatbot, queue=False)\n",
430
+ "\n",
431
+ "demo.queue()\n",
432
+ "demo.launch(debug=True)"
433
+ ],
434
+ "metadata": {
435
+ "id": "J7xy7c2LcEbe"
436
+ },
437
+ "execution_count": null,
438
+ "outputs": []
439
+ }
440
+ ]
441
+ }