Omar Solano commited on
Commit
5a84661
Β·
1 Parent(s): e9ec472

refactor code

Browse files
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.39.0
8
- app_file: scripts/gradio-ui.py
9
  pinned: false
10
  ---
11
 
 
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.39.0
8
+ app_file: scripts/main.py
9
  pinned: false
10
  ---
11
 
requirements.txt CHANGED
@@ -200,6 +200,7 @@ pydantic-core==2.20.1
200
  pydispatcher==2.0.7
201
  pydub==0.25.1
202
  pygments==2.18.0
 
203
  pyopenssl==24.2.1
204
  pyparsing==3.1.2
205
  pypdf==4.3.1
 
200
  pydispatcher==2.0.7
201
  pydub==0.25.1
202
  pygments==2.18.0
203
+ pymongo==4.8.0
204
  pyopenssl==24.2.1
205
  pyparsing==3.1.2
206
  pypdf==4.3.1
scripts/custom_retriever.py CHANGED
@@ -1,4 +1,3 @@
1
- import logging
2
  import time
3
  from typing import List
4
 
@@ -8,9 +7,6 @@ from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
8
  from llama_index.core.schema import NodeWithScore, TextNode
9
  from llama_index.postprocessor.cohere_rerank import CohereRerank
10
 
11
- logger = logging.getLogger(__name__)
12
- logging.basicConfig(level=logging.INFO)
13
-
14
 
15
  class CustomRetriever(BaseRetriever):
16
  """Custom retriever that performs both semantic search and hybrid search."""
 
 
1
  import time
2
  from typing import List
3
 
 
7
  from llama_index.core.schema import NodeWithScore, TextNode
8
  from llama_index.postprocessor.cohere_rerank import CohereRerank
9
 
 
 
 
10
 
11
  class CustomRetriever(BaseRetriever):
12
  """Custom retriever that performs both semantic search and hybrid search."""
scripts/gradio-ui.py DELETED
@@ -1,409 +0,0 @@
1
- import logging
2
- import os
3
- import pickle
4
-
5
- import chromadb
6
- import gradio as gr
7
- import logfire
8
- from custom_retriever import CustomRetriever
9
- from dotenv import load_dotenv
10
- from llama_index.agent.openai import OpenAIAgent
11
- from llama_index.core import VectorStoreIndex
12
- from llama_index.core.llms import MessageRole
13
- from llama_index.core.memory import ChatMemoryBuffer, ChatSummaryMemoryBuffer
14
- from llama_index.core.node_parser import SentenceSplitter
15
- from llama_index.core.retrievers import VectorIndexRetriever
16
- from llama_index.core.tools import RetrieverTool, ToolMetadata
17
- from llama_index.embeddings.openai import OpenAIEmbedding
18
- from llama_index.llms.openai import OpenAI
19
- from llama_index.vector_stores.chroma import ChromaVectorStore
20
- from tutor_prompts import system_message_openai_agent
21
-
22
- # from utils import init_mongo_db
23
-
24
- load_dotenv()
25
-
26
- logger = logging.getLogger(__name__)
27
- logging.basicConfig(level=logging.INFO)
28
- logging.getLogger("httpx").setLevel(logging.WARNING)
29
- logfire.configure()
30
-
31
-
32
- CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
33
- MONGODB_URI = os.getenv("MONGODB_URI")
34
-
35
-
36
- if not os.path.exists("data/chroma-db-transformers"):
37
- # Download the vector database from the Hugging Face Hub if it doesn't exist locally
38
- # https://huggingface.co/datasets/towardsai-buster/ai-tutor-db/tree/main
39
- logfire.warn(
40
- f"Vector database does not exist at 'data/chroma-db-transformers', downloading from Hugging Face Hub"
41
- )
42
- from huggingface_hub import snapshot_download
43
-
44
- snapshot_download(
45
- repo_id="towardsai-buster/ai-tutor-vector-db",
46
- local_dir="data",
47
- repo_type="dataset",
48
- )
49
- logfire.info(f"Downloaded vector database to 'data/chroma-db-transformers'")
50
-
51
- AVAILABLE_SOURCES_UI = [
52
- "HF Transformers",
53
- "PEFT",
54
- "TRL",
55
- "LlamaIndex Docs",
56
- "Towards AI Blog",
57
- # "Wikipedia",
58
- # "OpenAI Docs",
59
- # "LangChain Docs",
60
- "RAG Course",
61
- ]
62
-
63
- AVAILABLE_SOURCES = [
64
- "HF_Transformers",
65
- "PEFT",
66
- "TRL",
67
- "LlamaIndex",
68
- "towards_ai_blog",
69
- # "wikipedia",
70
- # "openai_docs",
71
- # "langchain_docs",
72
- "rag_course",
73
- ]
74
-
75
-
76
- # mongo_db = (
77
- # init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
78
- # if MONGODB_URI
79
- # else logfire.warn("No mongodb uri found, you will not be able to save data.")
80
- # )
81
-
82
- DB_PATH = os.getenv("DB_PATH", "data/chroma-db-transformers")
83
- DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-transformers")
84
-
85
- db2 = chromadb.PersistentClient(path=DB_PATH)
86
- chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
87
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
88
-
89
- index = VectorStoreIndex.from_vector_store(
90
- vector_store=vector_store,
91
- embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
92
- transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
93
- show_progress=True,
94
- use_async=True,
95
- )
96
- vector_retriever = VectorIndexRetriever(
97
- index=index,
98
- similarity_top_k=10,
99
- use_async=True,
100
- embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
101
- )
102
- with open(f"{DB_PATH}/document_dict_tf.pkl", "rb") as f:
103
- document_dict = pickle.load(f)
104
-
105
- custom_retriever_tf = CustomRetriever(vector_retriever, document_dict)
106
-
107
- DB_PATH = os.getenv("DB_PATH", "data/chroma-db-peft")
108
- DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-peft")
109
-
110
- db2 = chromadb.PersistentClient(path=DB_PATH)
111
- chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
112
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
113
-
114
- index = VectorStoreIndex.from_vector_store(
115
- vector_store=vector_store,
116
- embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
117
- transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
118
- show_progress=True,
119
- use_async=True,
120
- )
121
- vector_retriever = VectorIndexRetriever(
122
- index=index,
123
- similarity_top_k=10,
124
- use_async=True,
125
- embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
126
- )
127
- with open(f"{DB_PATH}/document_dict_peft.pkl", "rb") as f:
128
- document_dict = pickle.load(f)
129
-
130
- custom_retriever_peft = CustomRetriever(vector_retriever, document_dict)
131
-
132
- DB_PATH = os.getenv("DB_PATH", f"data/chroma-db-trl")
133
- DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-trl")
134
-
135
- db2 = chromadb.PersistentClient(path=DB_PATH)
136
- chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
137
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
138
-
139
- index = VectorStoreIndex.from_vector_store(
140
- vector_store=vector_store,
141
- embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
142
- transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
143
- show_progress=True,
144
- use_async=True,
145
- )
146
- vector_retriever = VectorIndexRetriever(
147
- index=index,
148
- similarity_top_k=10,
149
- use_async=True,
150
- embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
151
- )
152
- with open(f"{DB_PATH}/document_dict_trl.pkl", "rb") as f:
153
- document_dict = pickle.load(f)
154
-
155
- custom_retriever_trl = CustomRetriever(vector_retriever, document_dict)
156
-
157
- DB_PATH = os.getenv("DB_PATH", "data/chroma-db-llama-index")
158
- DB_COLLECTION = os.getenv("DB_NAME", "chroma-db-llama-index")
159
-
160
- db2 = chromadb.PersistentClient(path=DB_PATH)
161
- chroma_collection = db2.get_or_create_collection(DB_COLLECTION)
162
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
163
-
164
- index = VectorStoreIndex.from_vector_store(
165
- vector_store=vector_store,
166
- embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
167
- transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
168
- show_progress=True,
169
- use_async=True,
170
- )
171
- vector_retriever = VectorIndexRetriever(
172
- index=index,
173
- similarity_top_k=10,
174
- use_async=True,
175
- embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
176
- )
177
- with open(f"{DB_PATH}/document_dict_llamaindex.pkl", "rb") as f:
178
- document_dict = pickle.load(f)
179
-
180
- custom_retriever_llamaindex = CustomRetriever(vector_retriever, document_dict)
181
-
182
-
183
- def format_sources(completion) -> str:
184
- if len(completion.sources) == 0:
185
- return ""
186
-
187
- display_source_to_ui = {
188
- src: ui for src, ui in zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI)
189
- }
190
-
191
- documents_answer_template: str = (
192
- "πŸ“ Here are the sources I used to answer your question:\n{documents}"
193
- )
194
- document_template: str = "[πŸ”— {source}: {title}]({url}), relevance: {score:2.2f}"
195
-
196
- all_documents = []
197
- for source in completion.sources:
198
- for src in source.raw_output:
199
- document = document_template.format(
200
- title=src.metadata["title"],
201
- score=src.score,
202
- source=display_source_to_ui.get(
203
- src.metadata["source"], src.metadata["source"]
204
- ),
205
- url=src.metadata["url"],
206
- )
207
- all_documents.append(document)
208
-
209
- documents = "\n".join(all_documents)
210
-
211
- return documents_answer_template.format(documents=documents)
212
-
213
-
214
- def add_sources(answer_str, completion):
215
- if completion is None:
216
- yield answer_str
217
-
218
- formatted_sources = format_sources(completion)
219
- if formatted_sources == "":
220
- yield answer_str
221
-
222
- if formatted_sources != "":
223
- answer_str += "\n\n" + formatted_sources
224
-
225
- yield answer_str
226
-
227
-
228
- def generate_completion(
229
- query,
230
- history,
231
- sources,
232
- model,
233
- memory,
234
- ):
235
-
236
- with logfire.span("Running query"):
237
- logfire.info(f"query: {query}")
238
- logfire.info(f"model: {model}")
239
- logfire.info(f"sources: {sources}")
240
-
241
- chat_list = memory.get()
242
-
243
- if len(chat_list) != 0:
244
- user_index = [
245
- i for i, msg in enumerate(chat_list) if msg.role == MessageRole.USER
246
- ]
247
- if len(user_index) > len(history):
248
- user_index_to_remove = user_index[len(history)]
249
- chat_list = chat_list[:user_index_to_remove]
250
- memory.set(chat_list)
251
-
252
- logfire.info(f"chat_history: {len(memory.get())} {memory.get()}")
253
- logfire.info(f"gradio_history: {len(history)} {history}")
254
-
255
- # # # TODO: change source UI name to actual source name
256
- # filters = MetadataFilters(
257
- # filters=[
258
- # # MetadataFilter(key="source", value="HF_Transformers"),
259
- # # MetadataFilter(key="source", value="towards_ai_blog"),
260
- # MetadataFilter(
261
- # key="source", operator=FilterOperator.EQ, value="HF_Transformers"
262
- # ),
263
- # ],
264
- # # condition=FilterCondition.OR,
265
- # )
266
- # vector_retriever = VectorIndexRetriever(
267
- # # filters=filters,
268
- # index=index,
269
- # similarity_top_k=10,
270
- # use_async=True,
271
- # embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
272
- # )
273
- # custom_retriever = CustomRetriever(vector_retriever, document_dict)
274
-
275
- llm = OpenAI(temperature=1, model=model, max_tokens=None)
276
- client = llm._get_client()
277
- logfire.instrument_openai(client)
278
-
279
- query_engine_tools = [
280
- RetrieverTool(
281
- retriever=custom_retriever_tf,
282
- metadata=ToolMetadata(
283
- name="Transformers_information",
284
- description="""Useful for general questions asking about the artificial intelligence (AI) field. Employ this tool to fetch general information on topics such as language models theory (transformer architectures), tips on prompting, models, quantization, etc.""",
285
- ),
286
- ),
287
- RetrieverTool(
288
- retriever=custom_retriever_peft,
289
- metadata=ToolMetadata(
290
- name="PEFT_information",
291
- description=" Useful for questions asking about efficient LLM fine-tuning. Employ this tool to fetch information on topics such as LoRA, QLoRA, etc."
292
- "",
293
- ),
294
- ),
295
- RetrieverTool(
296
- retriever=custom_retriever_trl,
297
- metadata=ToolMetadata(
298
- name="TRL_information",
299
- description="""Useful for questions asking about fine-tuning LLMs with reinforcement learning (RLHF). Includes information about the Supervised Fine-tuning step (SFT), Reward Modeling step (RM), and the Proximal Policy Optimization (PPO) step.""",
300
- ),
301
- ),
302
- RetrieverTool(
303
- retriever=custom_retriever_llamaindex,
304
- metadata=ToolMetadata(
305
- name="LlamaIndex_information",
306
- description="""Useful for questions asking about retrieval augmented generation (RAG) with LLMs and embedding models. It is the documentation of the LlamaIndex framework, includes info about fine-tuning embedding models, building chatbots, and agents with llms, using vector databases, embeddings, information retrieval with cosine similarity or bm25, etc.""",
307
- ),
308
- ),
309
- ]
310
-
311
- agent = OpenAIAgent.from_tools(
312
- llm=llm,
313
- memory=memory,
314
- tools=query_engine_tools, # type: ignore
315
- system_prompt=system_message_openai_agent,
316
- )
317
-
318
- completion = agent.stream_chat(query)
319
-
320
- answer_str = ""
321
- for token in completion.response_gen:
322
- answer_str += token
323
- yield answer_str
324
-
325
- for answer_str in add_sources(answer_str, completion):
326
- yield answer_str
327
-
328
-
329
- def vote(data: gr.LikeData):
330
- collection = "liked_data-test"
331
- if data.liked:
332
- print("You upvoted this response: " + data.value["value"])
333
- else:
334
- print("You downvoted this response: " + data.value["value"])
335
-
336
- # completion_json["liked"] = like_data.liked
337
- # logger.info(f"User reported {like_data.liked=}")
338
-
339
- # try:
340
- # cfg.mongo_db[collection].insert_one(completion_json)
341
- # except:
342
- # logger.info("Something went wrong logging")
343
-
344
-
345
- # def save_completion(completion: Completion, history):
346
- # collection = "completion_data-hf"
347
-
348
- # # Convert completion to JSON and ignore certain columns
349
- # completion_json = completion.to_json(
350
- # columns_to_ignore=["embedding", "similarity", "similarity_to_answer"]
351
- # )
352
-
353
- # # Add the current date and time to the JSON
354
- # completion_json["timestamp"] = datetime.utcnow().isoformat()
355
- # completion_json["history"] = history
356
- # completion_json["history_len"] = len(history)
357
-
358
- # try:
359
- # cfg.mongo_db[collection].insert_one(completion_json)
360
- # logger.info("Completion saved to db")
361
- # except Exception as e:
362
- # logger.info(f"Something went wrong logging completion to db: {e}")
363
-
364
-
365
- accordion = gr.Accordion(label="Customize Sources (Click to expand)", open=False)
366
- sources = gr.CheckboxGroup(
367
- AVAILABLE_SOURCES_UI, label="Sources", value=["HF Transformers", "PEFT", "TRL", "LlamaIndex Docs"], interactive=False # type: ignore
368
- )
369
- model = gr.Dropdown(
370
- [
371
- "gemini-1.5-pro",
372
- "gemini-1.5-flash",
373
- "gpt-4o-mini",
374
- "gpt-4o",
375
- ],
376
- label="Model",
377
- value="gpt-4o-mini",
378
- interactive=False,
379
- )
380
-
381
- with gr.Blocks(
382
- fill_height=True,
383
- title="Towards AI πŸ€–",
384
- analytics_enabled=True,
385
- ) as demo:
386
-
387
- memory = gr.State(
388
- ChatSummaryMemoryBuffer.from_defaults(
389
- token_limit=120000,
390
- )
391
- )
392
- chatbot = gr.Chatbot(
393
- scale=1,
394
- placeholder="<strong>Towards AI πŸ€–: A Question-Answering Bot for anything AI-related</strong><br>",
395
- show_label=False,
396
- likeable=True,
397
- show_copy_button=True,
398
- )
399
- chatbot.like(vote, None, None)
400
- gr.ChatInterface(
401
- fn=generate_completion,
402
- chatbot=chatbot,
403
- additional_inputs=[sources, model, memory],
404
- additional_inputs_accordion=accordion,
405
- )
406
-
407
- if __name__ == "__main__":
408
- demo.queue(default_concurrency_limit=CONCURRENCY_COUNT)
409
- demo.launch(debug=False, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/main.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import logfire
3
+ from llama_index.agent.openai import OpenAIAgent
4
+ from llama_index.core.llms import MessageRole
5
+ from llama_index.core.memory import ChatSummaryMemoryBuffer
6
+ from llama_index.core.tools import RetrieverTool, ToolMetadata
7
+ from llama_index.llms.openai import OpenAI
8
+ from prompts import system_message_openai_agent
9
+ from setup import (
10
+ AVAILABLE_SOURCES,
11
+ AVAILABLE_SOURCES_UI,
12
+ CONCURRENCY_COUNT,
13
+ custom_retriever_llamaindex,
14
+ custom_retriever_peft,
15
+ custom_retriever_tf,
16
+ custom_retriever_trl,
17
+ )
18
+
19
+
20
+ def update_query_engine_tools(selected_sources):
21
+ tools = []
22
+ source_mapping = {
23
+ "HF Transformers": (
24
+ custom_retriever_tf,
25
+ "Transformers_information",
26
+ """Useful for general questions asking about the artificial intelligence (AI) field. Employ this tool to fetch general information on topics such as language models theory (transformer architectures), tips on prompting, models, quantization, etc.""",
27
+ ),
28
+ "PEFT": (
29
+ custom_retriever_peft,
30
+ "PEFT_information",
31
+ """Useful for questions asking about efficient LLM fine-tuning. Employ this tool to fetch information on topics such as LoRA, QLoRA, etc.""",
32
+ ),
33
+ "TRL": (
34
+ custom_retriever_trl,
35
+ "TRL_information",
36
+ """Useful for questions asking about fine-tuning LLMs with reinforcement learning (RLHF). Includes information about the Supervised Fine-tuning step (SFT), Reward Modeling step (RM), and the Proximal Policy Optimization (PPO) step.""",
37
+ ),
38
+ "LlamaIndex Docs": (
39
+ custom_retriever_llamaindex,
40
+ "LlamaIndex_information",
41
+ """Useful for questions asking about retrieval augmented generation (RAG) with LLMs and embedding models. It is the documentation of the LlamaIndex framework, includes info about fine-tuning embedding models, building chatbots, and agents with llms, using vector databases, embeddings, information retrieval with cosine similarity or bm25, etc.""",
42
+ ),
43
+ }
44
+
45
+ for source in selected_sources:
46
+ if source in source_mapping:
47
+ retriever, name, description = source_mapping[source]
48
+ tools.append(
49
+ RetrieverTool(
50
+ retriever=retriever,
51
+ metadata=ToolMetadata(
52
+ name=name,
53
+ description=description,
54
+ ),
55
+ )
56
+ )
57
+
58
+ return tools
59
+
60
+
61
+ def generate_completion(
62
+ query,
63
+ history,
64
+ sources,
65
+ model,
66
+ memory,
67
+ ):
68
+ with logfire.span("Running query"):
69
+ logfire.info(f"query: {query}")
70
+ logfire.info(f"model: {model}")
71
+ logfire.info(f"sources: {sources}")
72
+
73
+ chat_list = memory.get()
74
+
75
+ if len(chat_list) != 0:
76
+ user_index = [
77
+ i for i, msg in enumerate(chat_list) if msg.role == MessageRole.USER
78
+ ]
79
+ if len(user_index) > len(history):
80
+ user_index_to_remove = user_index[len(history)]
81
+ chat_list = chat_list[:user_index_to_remove]
82
+ memory.set(chat_list)
83
+
84
+ logfire.info(f"chat_history: {len(memory.get())} {memory.get()}")
85
+ logfire.info(f"gradio_history: {len(history)} {history}")
86
+
87
+ llm = OpenAI(temperature=1, model=model, max_tokens=None)
88
+ client = llm._get_client()
89
+ logfire.instrument_openai(client)
90
+
91
+ query_engine_tools = update_query_engine_tools(sources)
92
+
93
+ agent = OpenAIAgent.from_tools(
94
+ llm=llm,
95
+ memory=memory,
96
+ tools=query_engine_tools,
97
+ system_prompt=system_message_openai_agent,
98
+ )
99
+
100
+ completion = agent.stream_chat(query)
101
+
102
+ answer_str = ""
103
+ for token in completion.response_gen:
104
+ answer_str += token
105
+ yield answer_str
106
+
107
+ for answer_str in add_sources(answer_str, completion):
108
+ yield answer_str
109
+
110
+
111
+ def add_sources(answer_str, completion):
112
+ if completion is None:
113
+ yield answer_str
114
+
115
+ formatted_sources = format_sources(completion)
116
+ if formatted_sources == "":
117
+ yield answer_str
118
+
119
+ if formatted_sources != "":
120
+ answer_str += "\n\n" + formatted_sources
121
+
122
+ yield answer_str
123
+
124
+
125
+ def format_sources(completion) -> str:
126
+ if len(completion.sources) == 0:
127
+ return ""
128
+
129
+ display_source_to_ui = {
130
+ src: ui for src, ui in zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI)
131
+ }
132
+
133
+ documents_answer_template: str = (
134
+ "πŸ“ Here are the sources I used to answer your question:\n{documents}"
135
+ )
136
+ document_template: str = "[πŸ”— {source}: {title}]({url}), relevance: {score:2.2f}"
137
+
138
+ all_documents = []
139
+ for source in completion.sources:
140
+ for src in source.raw_output:
141
+ document = document_template.format(
142
+ title=src.metadata["title"],
143
+ score=src.score,
144
+ source=display_source_to_ui.get(
145
+ src.metadata["source"], src.metadata["source"]
146
+ ),
147
+ url=src.metadata["url"],
148
+ )
149
+ all_documents.append(document)
150
+
151
+ documents = "\n".join(all_documents)
152
+
153
+ return documents_answer_template.format(documents=documents)
154
+
155
+
156
+ def save_completion(completion, history):
157
+ pass
158
+
159
+
160
+ def vote(data: gr.LikeData):
161
+ pass
162
+
163
+
164
+ accordion = gr.Accordion(label="Customize Sources (Click to expand)", open=False)
165
+ sources = gr.CheckboxGroup(
166
+ AVAILABLE_SOURCES_UI,
167
+ label="Sources",
168
+ value=["HF Transformers", "PEFT", "TRL", "LlamaIndex Docs"],
169
+ interactive=True,
170
+ )
171
+ model = gr.Dropdown(
172
+ [
173
+ "gpt-4o-mini",
174
+ "gpt-4o",
175
+ ],
176
+ label="Model",
177
+ value="gpt-4o-mini",
178
+ interactive=False,
179
+ )
180
+
181
+ with gr.Blocks(
182
+ fill_height=True,
183
+ title="Towards AI πŸ€–",
184
+ analytics_enabled=True,
185
+ ) as demo:
186
+
187
+ memory = gr.State(
188
+ ChatSummaryMemoryBuffer.from_defaults(
189
+ token_limit=120000,
190
+ )
191
+ )
192
+ chatbot = gr.Chatbot(
193
+ scale=1,
194
+ placeholder="<strong>Towards AI πŸ€–: A Question-Answering Bot for anything AI-related</strong><br>",
195
+ show_label=False,
196
+ likeable=True,
197
+ show_copy_button=True,
198
+ )
199
+ chatbot.like(vote, None, None)
200
+ gr.ChatInterface(
201
+ fn=generate_completion,
202
+ chatbot=chatbot,
203
+ additional_inputs=[sources, model, memory],
204
+ additional_inputs_accordion=accordion,
205
+ )
206
+
207
+ if __name__ == "__main__":
208
+ demo.queue(default_concurrency_limit=CONCURRENCY_COUNT)
209
+ demo.launch(debug=False, share=False)
scripts/prompts.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs. Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks, Langchain, LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback. Questions should be understood in this context.
2
+
3
+ Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
4
+
5
+ Use the available tools to gather insights pertinent to the field of AI. Always use two tools at the same time. These tools accept a string (a user query rewritten as a statement) and return informative content regarding the domain of AI.
6
+ e.g:
7
+ User question: 'How can I fine-tune an LLM?'
8
+ Input to the tool: 'Fine-tuning an LLM'
9
+
10
+ User question: How can quantize an LLM?
11
+ Input to the tool: 'Quantization for LLMs'
12
+
13
+ User question: 'Teach me how to build an AI agent"'
14
+ Input to the tool: 'Building an AI Agent'
15
+
16
+ Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
17
+
18
+ Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
19
+
20
+ When the conversation deepens or shifts focus within a topic, adapt your input to the tools to reflect these nuances. This means if a user requests further elaboration on a specific aspect of a previously discussed topic, you should reformulate your input to the tool to capture this new angle or more profound layer of inquiry.
21
+
22
+ Provide comprehensive answers, ideally structured in multiple paragraphs, drawing from the tool's variety of relevant details. The depth and breadth of your responses should align with the scope and specificity of the information retrieved.
23
+
24
+ Should the tools repository lack information on the queried topic, politely inform the user that the question transcends the bounds of your current knowledge base, citing the absence of relevant content in the tool's documentation.
25
+
26
+ At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
27
+
28
+ Do not refer to the documentation directly, but use the information provided within it to answer questions.
29
+
30
+ If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
31
+
32
+ Make sure to format your answers in Markdown format, including code blocks and snippets.
33
+ """
scripts/setup.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import pickle
4
+
5
+ import chromadb
6
+ import logfire
7
+ from custom_retriever import CustomRetriever
8
+ from dotenv import load_dotenv
9
+ from llama_index.core import VectorStoreIndex
10
+ from llama_index.core.node_parser import SentenceSplitter
11
+ from llama_index.core.retrievers import VectorIndexRetriever
12
+ from llama_index.embeddings.openai import OpenAIEmbedding
13
+ from llama_index.vector_stores.chroma import ChromaVectorStore
14
+
15
+ # from utils import init_mongo_db
16
+
17
+ load_dotenv()
18
+
19
+ logger = logging.getLogger(__name__)
20
+ logging.basicConfig(level=logging.INFO)
21
+ logging.getLogger("httpx").setLevel(logging.WARNING)
22
+ logfire.configure()
23
+
24
+
25
+ if not os.path.exists("data/chroma-db-transformers"):
26
+ # Download the vector database from the Hugging Face Hub if it doesn't exist locally
27
+ # https://huggingface.co/datasets/towardsai-buster/ai-tutor-vector-db/tree/main
28
+ logfire.warn(
29
+ f"Vector database does not exist at 'data/chroma-db-transformers', downloading from Hugging Face Hub"
30
+ )
31
+ from huggingface_hub import snapshot_download
32
+
33
+ snapshot_download(
34
+ repo_id="towardsai-buster/ai-tutor-vector-db",
35
+ local_dir="data",
36
+ repo_type="dataset",
37
+ )
38
+ logfire.info(f"Downloaded vector database to 'data/chroma-db-transformers'")
39
+
40
+
41
+ def setup_database(db_collection, dict_file_name):
42
+ db = chromadb.PersistentClient(path=f"data/{db_collection}")
43
+ chroma_collection = db.get_or_create_collection(db_collection)
44
+ vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
45
+
46
+ index = VectorStoreIndex.from_vector_store(
47
+ vector_store=vector_store,
48
+ embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
49
+ transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=400)],
50
+ show_progress=True,
51
+ use_async=True,
52
+ )
53
+ vector_retriever = VectorIndexRetriever(
54
+ index=index,
55
+ similarity_top_k=10,
56
+ use_async=True,
57
+ embed_model=OpenAIEmbedding(model="text-embedding-3-large", mode="similarity"),
58
+ )
59
+ with open(f"data/{db_collection}/{dict_file_name}", "rb") as f:
60
+ document_dict = pickle.load(f)
61
+
62
+ return CustomRetriever(vector_retriever, document_dict)
63
+
64
+
65
+ # Setup retrievers
66
+ custom_retriever_tf = setup_database(
67
+ "chroma-db-transformers",
68
+ "document_dict_tf.pkl",
69
+ )
70
+ custom_retriever_peft = setup_database("chroma-db-peft", "document_dict_peft.pkl")
71
+ custom_retriever_trl = setup_database("chroma-db-trl", "document_dict_trl.pkl")
72
+ custom_retriever_llamaindex = setup_database(
73
+ "chroma-db-llama-index",
74
+ "document_dict_llamaindex.pkl",
75
+ )
76
+
77
+ # Constants
78
+ CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
79
+ MONGODB_URI = os.getenv("MONGODB_URI")
80
+
81
+ AVAILABLE_SOURCES_UI = [
82
+ "HF Transformers",
83
+ "PEFT",
84
+ "TRL",
85
+ "LlamaIndex Docs",
86
+ "Towards AI Blog",
87
+ "RAG Course",
88
+ ]
89
+
90
+ AVAILABLE_SOURCES = [
91
+ "HF_Transformers",
92
+ "PEFT",
93
+ "TRL",
94
+ "LlamaIndex",
95
+ "towards_ai_blog",
96
+ "rag_course",
97
+ ]
98
+
99
+ # mongo_db = (
100
+ # init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
101
+ # if MONGODB_URI
102
+ # else logfire.warn("No mongodb uri found, you will not be able to save data.")
103
+ # )
104
+
105
+ __all__ = [
106
+ "custom_retriever_tf",
107
+ "custom_retriever_peft",
108
+ "custom_retriever_trl",
109
+ "custom_retriever_llamaindex",
110
+ "CONCURRENCY_COUNT",
111
+ "MONGODB_URI",
112
+ "AVAILABLE_SOURCES_UI",
113
+ "AVAILABLE_SOURCES",
114
+ ]
scripts/tutor_prompts.py DELETED
@@ -1,117 +0,0 @@
1
- from llama_index.core import ChatPromptTemplate
2
- from llama_index.core.llms import ChatMessage, MessageRole
3
- from pydantic import BaseModel, Field
4
-
5
- default_user_prompt = (
6
- "Context information is below.\n"
7
- "---------------------\n"
8
- "{context_str}\n"
9
- "---------------------\n"
10
- "Given the context information and not prior knowledge, "
11
- "answer the question: {query_str}\n"
12
- )
13
-
14
- system_prompt = (
15
- "You are an AI teacher, answering questions from students of an applied artificial intelligence course on Large Language Models (LLMs or LLM). "
16
- "Your answers are aimed to teach students, so they should be complete, clear, and easy to understand. "
17
- "Topics covered include training models, fine-tuning models, giving 'memory' to LLMs, prompting, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks, Langchain, Llama-Index, LLMs interact with tool use, AI agents, reinforcement learning with human feedback. Understand the questions with this context. "
18
- "You are provided information in Hugging Face's documentation and a RAG course. "
19
- "Only some information might be relevant to the question, so ignore the irrelevant part and use the relevant part to answer the question. "
20
- "Formulate your answer with the information given to you below. DO NOT use additional information, even if you know the answer. "
21
- "If the answer is somewhere in the documentation below, answer the question, depending on the question and the variety of relevant information in the documentation, give complete and helpful answers. "
22
- "If code is provided in the information, share it with the students. It's important to provide complete code blocks. "
23
- "Here is the information you can use, the order is not important: \n\n"
24
- "---------------------\n"
25
- "{context_str}\n"
26
- "---------------------\n\n"
27
- "REMEMBER:\n"
28
- "You are an AI teacher, answering questions from students of an applied artificial intelligence course on Large Language Models (LLMs or llm). Topics covered include training models, fine tuning models, giving memory to LLMs, prompting, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks, Langchain, making LLMs interact with tool use, AI agents, reinforcement learning with human feedback. Questions should be understood with this context. "
29
- "Your answers are aimed to teach students, so they should be complete, clear, and easy to understand. "
30
- "You are provided information found in Hugging Face's documentation and a RAG course. "
31
- "Here are the rules you must follow: \n"
32
- "* Only respond with information inside the documentation. DO NOT provide additional information, even if you know the answer. "
33
- "* If the answer is in the documentation, answer the question (depending on the questions and the variety of relevant information in the documentation. Your answer needs to give a clear and complete explanation as if you were a teacher. "
34
- "* Do not refer to the documentation directly, but use the information provided within it to answer questions. "
35
- "* Do not reference any links, urls or hyperlinks in your answers.\n "
36
- "* If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute it.\n "
37
- "* Make sure to format your answers in Markdown format, including code block and snippets.\n "
38
- "Now answer the following question: \n"
39
- )
40
-
41
- chat_text_qa_msgs: list[ChatMessage] = [
42
- ChatMessage(role=MessageRole.SYSTEM, content=system_prompt),
43
- ChatMessage(
44
- role=MessageRole.USER,
45
- content="{query_str}",
46
- ),
47
- ]
48
-
49
- TEXT_QA_TEMPLATE = ChatPromptTemplate(chat_text_qa_msgs)
50
-
51
-
52
- system_message_validation = """- You are a witty AI teacher, helpfully answering questions from students studying the field of applied artificial intelligence.
53
- - Your job is to determine whether user's question is valid or not. Users will not always submit a question either.
54
- - Users will ask all sorts of questions, and some might be tangentially related to artificial intelligence (AI), machine learning (ML), natural language processing (NLP), computer vision (CV) or generative AI.
55
- - Users can ask how to build LLM-powered apps, with LangChain, LlamaIndex, Deep Lake, Chroma DB among other technologies including OpenAI, RAG and more.
56
- - As long as a question is somewhat related to the topic of AI, ML, NLP, RAG, data and techniques used in AI like vector embeddings, memories, embeddings, tokenization, encoding, databases, RAG (Retrieval-Augmented Generation), Langchain, LlamaIndex, LLMs (Large Language Models), Preprocessing techniques, Document loading, Chunking, Indexing of document segments, Embedding models, Chains, Memory modules, Vector stores, Chat models, Sequential chains, Information Retrieval, Data connectors, LlamaHub, Node objects, Query engines, Fine-tuning, Activeloop’s Deep Memory, Prompt engineering, Synthetic training dataset, Inference, Recall rates, Query construction, Query expansion, Query transformation, Re-ranking, Cohere Reranker, Recursive retrieval, Small-to-big retrieval, Hybrid searches, Hit Rate, Mean Reciprocal Rank (MRR), GPT-4, Agents, OpenGPTs, Zero-shot ReAct, Conversational Agent, OpenAI Assistants API, Hugging Face Inference API, Code Interpreter, Knowledge Retrieval, Function Calling, Whisper, Dall-E 3, GPT-4 Vision, Unstructured, Deep Lake, FaithfulnessEvaluator, RAGAS, LangSmith, LangChain Hub, LangServe, REST API, respond 'true'. If a question is on a different subject or unrelated, respond 'false'.
57
- - Make sure the question is a valid question.
58
-
59
- Here is a list of acronyms and concepts related to Artificial Intelligence AI that are valid. The following terms can be Uppercase or Lowercase:
60
- You are case insensitive.
61
- 'TQL', 'Deep Memory', 'LLM', 'Llama', 'llamaindex', 'llama-index', 'lang chain', 'langchain', 'llama index', 'GPT', 'NLP', 'RLHF', 'RLAIF', 'Mistral', 'SFT', 'Cohere', 'NanoGPT', 'ReAct', 'LoRA', 'QLoRA', 'LMMOps', 'Alpaca', 'Flan', 'Weights and Biases', 'W&B', 'IDEFICS', 'Flamingo', 'LLaVA', 'BLIP', 'Falcon', 'Gemini'
62
-
63
- """
64
-
65
-
66
- class QueryValidation(BaseModel):
67
- """
68
- Validate the user query. Use the guidelines given to you.
69
- """
70
-
71
- user_query: str = Field(
72
- description="The user query to validate.",
73
- )
74
- chain_of_thought: str = Field(
75
- description="Is the user query valid given the above guidelines? Think step-by-step. Write down your reasoning here.",
76
- )
77
- is_valid: bool = Field(
78
- description="Based on the previous reasoning, answer with True if the query is related to AI. Answer False otherwise.",
79
- )
80
- reason: str = Field(
81
- description="Explain why the query was valid or not. What are the keywords that make it valid or invalid?",
82
- )
83
-
84
-
85
- system_message_openai_agent = """You are an AI teacher, answering questions from students of an applied AI course on Large Language Models (LLMs or llm) and Retrieval Augmented Generation (RAG) for LLMs. Topics covered include training models, fine-tuning models, giving memory to LLMs, prompting tips, hallucinations and bias, vector databases, transformer architectures, embeddings, RAG frameworks, Langchain, LlamaIndex, making LLMs interact with tools, AI agents, reinforcement learning with human feedback. Questions should be understood in this context.
86
-
87
- Your answers are aimed to teach students, so they should be complete, clear, and easy to understand.
88
-
89
- Use the available tools to gather insights pertinent to the field of AI. Always use two tools at the same time. These tools accept a string (a user query rewritten as a statement) and return informative content regarding the domain of AI.
90
- e.g:
91
- User question: 'How can I fine-tune an LLM?'
92
- Input to the tool: 'Fine-tuning an LLM'
93
-
94
- User question: How can quantize an LLM?
95
- Input to the tool: 'Quantization for LLMs'
96
-
97
- User question: 'Teach me how to build an AI agent"'
98
- Input to the tool: 'Building an AI Agent'
99
-
100
- Only some information returned by the tools might be relevant to the question, so ignore the irrelevant part and answer the question with what you have.
101
-
102
- Your responses are exclusively based on the output provided by the tools. Refrain from incorporating information not directly obtained from the tool's responses.
103
-
104
- When the conversation deepens or shifts focus within a topic, adapt your input to the tools to reflect these nuances. This means if a user requests further elaboration on a specific aspect of a previously discussed topic, you should reformulate your input to the tool to capture this new angle or more profound layer of inquiry.
105
-
106
- Provide comprehensive answers, ideally structured in multiple paragraphs, drawing from the tool's variety of relevant details. The depth and breadth of your responses should align with the scope and specificity of the information retrieved.
107
-
108
- Should the tools repository lack information on the queried topic, politely inform the user that the question transcends the bounds of your current knowledge base, citing the absence of relevant content in the tool's documentation.
109
-
110
- At the end of your answers, always invite the students to ask deeper questions about the topic if they have any. Make sure reformulate the question to the tool to capture this new angle or more profound layer of inquiry.
111
-
112
- Do not refer to the documentation directly, but use the information provided within it to answer questions.
113
-
114
- If code is provided in the information, share it with the students. It's important to provide complete code blocks so they can execute the code when they copy and paste them.
115
-
116
- Make sure to format your answers in Markdown format, including code blocks and snippets.
117
- """