XThomasBU commited on
Commit
f51bb92
β€’
1 Parent(s): 48f8268

init commit

Browse files
.chainlit/translations/en-US.json DELETED
@@ -1,231 +0,0 @@
1
- {
2
- "components": {
3
- "atoms": {
4
- "buttons": {
5
- "userButton": {
6
- "menu": {
7
- "settings": "Settings",
8
- "settingsKey": "S",
9
- "APIKeys": "API Keys",
10
- "logout": "Logout"
11
- }
12
- }
13
- }
14
- },
15
- "molecules": {
16
- "newChatButton": {
17
- "newChat": "New Chat"
18
- },
19
- "tasklist": {
20
- "TaskList": {
21
- "title": "\ud83d\uddd2\ufe0f Task List",
22
- "loading": "Loading...",
23
- "error": "An error occured"
24
- }
25
- },
26
- "attachments": {
27
- "cancelUpload": "Cancel upload",
28
- "removeAttachment": "Remove attachment"
29
- },
30
- "newChatDialog": {
31
- "createNewChat": "Create new chat?",
32
- "clearChat": "This will clear the current messages and start a new chat.",
33
- "cancel": "Cancel",
34
- "confirm": "Confirm"
35
- },
36
- "settingsModal": {
37
- "settings": "Settings",
38
- "expandMessages": "Expand Messages",
39
- "hideChainOfThought": "Hide Chain of Thought",
40
- "darkMode": "Dark Mode"
41
- },
42
- "detailsButton": {
43
- "using": "Using",
44
- "running": "Running",
45
- "took_one": "Took {{count}} step",
46
- "took_other": "Took {{count}} steps"
47
- },
48
- "auth": {
49
- "authLogin": {
50
- "title": "Login to access the app.",
51
- "form": {
52
- "email": "Email address",
53
- "password": "Password",
54
- "noAccount": "Don't have an account?",
55
- "alreadyHaveAccount": "Already have an account?",
56
- "signup": "Sign Up",
57
- "signin": "Sign In",
58
- "or": "OR",
59
- "continue": "Continue",
60
- "forgotPassword": "Forgot password?",
61
- "passwordMustContain": "Your password must contain:",
62
- "emailRequired": "email is a required field",
63
- "passwordRequired": "password is a required field"
64
- },
65
- "error": {
66
- "default": "Unable to sign in.",
67
- "signin": "Try signing in with a different account.",
68
- "oauthsignin": "Try signing in with a different account.",
69
- "redirect_uri_mismatch": "The redirect URI is not matching the oauth app configuration.",
70
- "oauthcallbackerror": "Try signing in with a different account.",
71
- "oauthcreateaccount": "Try signing in with a different account.",
72
- "emailcreateaccount": "Try signing in with a different account.",
73
- "callback": "Try signing in with a different account.",
74
- "oauthaccountnotlinked": "To confirm your identity, sign in with the same account you used originally.",
75
- "emailsignin": "The e-mail could not be sent.",
76
- "emailverify": "Please verify your email, a new email has been sent.",
77
- "credentialssignin": "Sign in failed. Check the details you provided are correct.",
78
- "sessionrequired": "Please sign in to access this page."
79
- }
80
- },
81
- "authVerifyEmail": {
82
- "almostThere": "You're almost there! We've sent an email to ",
83
- "verifyEmailLink": "Please click on the link in that email to complete your signup.",
84
- "didNotReceive": "Can't find the email?",
85
- "resendEmail": "Resend email",
86
- "goBack": "Go Back",
87
- "emailSent": "Email sent successfully.",
88
- "verifyEmail": "Verify your email address"
89
- },
90
- "providerButton": {
91
- "continue": "Continue with {{provider}}",
92
- "signup": "Sign up with {{provider}}"
93
- },
94
- "authResetPassword": {
95
- "newPasswordRequired": "New password is a required field",
96
- "passwordsMustMatch": "Passwords must match",
97
- "confirmPasswordRequired": "Confirm password is a required field",
98
- "newPassword": "New password",
99
- "confirmPassword": "Confirm password",
100
- "resetPassword": "Reset Password"
101
- },
102
- "authForgotPassword": {
103
- "email": "Email address",
104
- "emailRequired": "email is a required field",
105
- "emailSent": "Please check the email address {{email}} for instructions to reset your password.",
106
- "enterEmail": "Enter your email address and we will send you instructions to reset your password.",
107
- "resendEmail": "Resend email",
108
- "continue": "Continue",
109
- "goBack": "Go Back"
110
- }
111
- }
112
- },
113
- "organisms": {
114
- "chat": {
115
- "history": {
116
- "index": {
117
- "showHistory": "Show history",
118
- "lastInputs": "Last Inputs",
119
- "noInputs": "Such empty...",
120
- "loading": "Loading..."
121
- }
122
- },
123
- "inputBox": {
124
- "input": {
125
- "placeholder": "Type your message here..."
126
- },
127
- "speechButton": {
128
- "start": "Start recording",
129
- "stop": "Stop recording"
130
- },
131
- "SubmitButton": {
132
- "sendMessage": "Send message",
133
- "stopTask": "Stop Task"
134
- },
135
- "UploadButton": {
136
- "attachFiles": "Attach files"
137
- },
138
- "waterMark": {
139
- "text": "Built with"
140
- }
141
- },
142
- "Messages": {
143
- "index": {
144
- "running": "Running",
145
- "executedSuccessfully": "executed successfully",
146
- "failed": "failed",
147
- "feedbackUpdated": "Feedback updated",
148
- "updating": "Updating"
149
- }
150
- },
151
- "dropScreen": {
152
- "dropYourFilesHere": "Drop your files here"
153
- },
154
- "index": {
155
- "failedToUpload": "Failed to upload",
156
- "cancelledUploadOf": "Cancelled upload of",
157
- "couldNotReachServer": "Could not reach the server",
158
- "continuingChat": "Continuing previous chat"
159
- },
160
- "settings": {
161
- "settingsPanel": "Settings panel",
162
- "reset": "Reset",
163
- "cancel": "Cancel",
164
- "confirm": "Confirm"
165
- }
166
- },
167
- "threadHistory": {
168
- "sidebar": {
169
- "filters": {
170
- "FeedbackSelect": {
171
- "feedbackAll": "Feedback: All",
172
- "feedbackPositive": "Feedback: Positive",
173
- "feedbackNegative": "Feedback: Negative"
174
- },
175
- "SearchBar": {
176
- "search": "Search"
177
- }
178
- },
179
- "DeleteThreadButton": {
180
- "confirmMessage": "This will delete the thread as well as it's messages and elements.",
181
- "cancel": "Cancel",
182
- "confirm": "Confirm",
183
- "deletingChat": "Deleting chat",
184
- "chatDeleted": "Chat deleted"
185
- },
186
- "index": {
187
- "pastChats": "Past Chats"
188
- },
189
- "ThreadList": {
190
- "empty": "Empty...",
191
- "today": "Today",
192
- "yesterday": "Yesterday",
193
- "previous7days": "Previous 7 days",
194
- "previous30days": "Previous 30 days"
195
- },
196
- "TriggerButton": {
197
- "closeSidebar": "Close sidebar",
198
- "openSidebar": "Open sidebar"
199
- }
200
- },
201
- "Thread": {
202
- "backToChat": "Go back to chat",
203
- "chatCreatedOn": "This chat was created on"
204
- }
205
- },
206
- "header": {
207
- "chat": "Chat",
208
- "readme": "Readme"
209
- }
210
- }
211
- },
212
- "hooks": {
213
- "useLLMProviders": {
214
- "failedToFetchProviders": "Failed to fetch providers:"
215
- }
216
- },
217
- "pages": {
218
- "Design": {},
219
- "Env": {
220
- "savedSuccessfully": "Saved successfully",
221
- "requiredApiKeys": "Required API Keys",
222
- "requiredApiKeysInfo": "To use this app, the following API keys are required. The keys are stored on your device's local storage."
223
- },
224
- "Page": {
225
- "notPartOfProject": "You are not part of this project."
226
- },
227
- "ResumeButton": {
228
- "resumeChat": "Resume Chat"
229
- }
230
- }
231
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.chainlit/translations/pt-BR.json DELETED
@@ -1,155 +0,0 @@
1
- {
2
- "components": {
3
- "atoms": {
4
- "buttons": {
5
- "userButton": {
6
- "menu": {
7
- "settings": "Configura\u00e7\u00f5es",
8
- "settingsKey": "S",
9
- "APIKeys": "Chaves de API",
10
- "logout": "Sair"
11
- }
12
- }
13
- }
14
- },
15
- "molecules": {
16
- "newChatButton": {
17
- "newChat": "Nova Conversa"
18
- },
19
- "tasklist": {
20
- "TaskList": {
21
- "title": "\ud83d\uddd2\ufe0f Lista de Tarefas",
22
- "loading": "Carregando...",
23
- "error": "Ocorreu um erro"
24
- }
25
- },
26
- "attachments": {
27
- "cancelUpload": "Cancelar envio",
28
- "removeAttachment": "Remover anexo"
29
- },
30
- "newChatDialog": {
31
- "createNewChat": "Criar novo chat?",
32
- "clearChat": "Isso limpar\u00e1 as mensagens atuais e iniciar\u00e1 uma nova conversa.",
33
- "cancel": "Cancelar",
34
- "confirm": "Confirmar"
35
- },
36
- "settingsModal": {
37
- "expandMessages": "Expandir Mensagens",
38
- "hideChainOfThought": "Esconder Sequ\u00eancia de Pensamento",
39
- "darkMode": "Modo Escuro"
40
- }
41
- },
42
- "organisms": {
43
- "chat": {
44
- "history": {
45
- "index": {
46
- "lastInputs": "\u00daltimas Entradas",
47
- "noInputs": "Vazio...",
48
- "loading": "Carregando..."
49
- }
50
- },
51
- "inputBox": {
52
- "input": {
53
- "placeholder": "Digite sua mensagem aqui..."
54
- },
55
- "speechButton": {
56
- "start": "Iniciar grava\u00e7\u00e3o",
57
- "stop": "Parar grava\u00e7\u00e3o"
58
- },
59
- "SubmitButton": {
60
- "sendMessage": "Enviar mensagem",
61
- "stopTask": "Parar Tarefa"
62
- },
63
- "UploadButton": {
64
- "attachFiles": "Anexar arquivos"
65
- },
66
- "waterMark": {
67
- "text": "Constru\u00eddo com"
68
- }
69
- },
70
- "Messages": {
71
- "index": {
72
- "running": "Executando",
73
- "executedSuccessfully": "executado com sucesso",
74
- "failed": "falhou",
75
- "feedbackUpdated": "Feedback atualizado",
76
- "updating": "Atualizando"
77
- }
78
- },
79
- "dropScreen": {
80
- "dropYourFilesHere": "Solte seus arquivos aqui"
81
- },
82
- "index": {
83
- "failedToUpload": "Falha ao enviar",
84
- "cancelledUploadOf": "Envio cancelado de",
85
- "couldNotReachServer": "N\u00e3o foi poss\u00edvel conectar ao servidor",
86
- "continuingChat": "Continuando o chat anterior"
87
- },
88
- "settings": {
89
- "settingsPanel": "Painel de Configura\u00e7\u00f5es",
90
- "reset": "Redefinir",
91
- "cancel": "Cancelar",
92
- "confirm": "Confirmar"
93
- }
94
- },
95
- "threadHistory": {
96
- "sidebar": {
97
- "filters": {
98
- "FeedbackSelect": {
99
- "feedbackAll": "Feedback: Todos",
100
- "feedbackPositive": "Feedback: Positivo",
101
- "feedbackNegative": "Feedback: Negativo"
102
- },
103
- "SearchBar": {
104
- "search": "Buscar"
105
- }
106
- },
107
- "DeleteThreadButton": {
108
- "confirmMessage": "Isso deletar\u00e1 a conversa, assim como suas mensagens e elementos.",
109
- "cancel": "Cancelar",
110
- "confirm": "Confirmar",
111
- "deletingChat": "Deletando conversa",
112
- "chatDeleted": "Conversa deletada"
113
- },
114
- "index": {
115
- "pastChats": "Conversas Anteriores"
116
- },
117
- "ThreadList": {
118
- "empty": "Vazio..."
119
- },
120
- "TriggerButton": {
121
- "closeSidebar": "Fechar barra lateral",
122
- "openSidebar": "Abrir barra lateral"
123
- }
124
- },
125
- "Thread": {
126
- "backToChat": "Voltar para a conversa",
127
- "chatCreatedOn": "Esta conversa foi criada em"
128
- }
129
- },
130
- "header": {
131
- "chat": "Conversa",
132
- "readme": "Leia-me"
133
- }
134
- },
135
- "hooks": {
136
- "useLLMProviders": {
137
- "failedToFetchProviders": "Falha ao buscar provedores:"
138
- }
139
- },
140
- "pages": {
141
- "Design": {},
142
- "Env": {
143
- "savedSuccessfully": "Salvo com sucesso",
144
- "requiredApiKeys": "Chaves de API necess\u00e1rias",
145
- "requiredApiKeysInfo": "Para usar este aplicativo, as seguintes chaves de API s\u00e3o necess\u00e1rias. As chaves s\u00e3o armazenadas localmente em seu dispositivo."
146
- },
147
- "Page": {
148
- "notPartOfProject": "Voc\u00ea n\u00e3o faz parte deste projeto."
149
- },
150
- "ResumeButton": {
151
- "resumeChat": "Continuar Conversa"
152
- }
153
- }
154
- }
155
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore CHANGED
@@ -160,4 +160,10 @@ cython_debug/
160
  #.idea/
161
 
162
  # log files
163
- *.log
 
 
 
 
 
 
 
160
  #.idea/
161
 
162
  # log files
163
+ *.log
164
+
165
+ .ragatouille/*
166
+ */__pycache__/*
167
+ */.chainlit/translations/*
168
+ storage/logs/*
169
+ vectorstores/*
{.chainlit β†’ code/.chainlit}/config.toml RENAMED
File without changes
code/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import *
chainlit.md β†’ code/chainlit.md RENAMED
File without changes
code/main.py CHANGED
@@ -1,9 +1,8 @@
1
- from langchain.document_loaders import PyPDFLoader, DirectoryLoader
2
  from langchain import PromptTemplate
3
- from langchain.embeddings import HuggingFaceEmbeddings
4
- from langchain.vectorstores import FAISS
5
  from langchain.chains import RetrievalQA
6
- from langchain.llms import CTransformers
7
  import chainlit as cl
8
  from langchain_community.chat_models import ChatOpenAI
9
  from langchain_community.embeddings import OpenAIEmbeddings
@@ -11,13 +10,22 @@ import yaml
11
  import logging
12
  from dotenv import load_dotenv
13
 
14
- from modules.llm_tutor import LLMTutor
15
- from modules.constants import *
16
- from modules.helpers import get_sources
17
 
 
 
 
18
 
 
 
 
 
 
 
19
  logger = logging.getLogger(__name__)
20
  logger.setLevel(logging.INFO)
 
21
 
22
  # Console Handler
23
  console_handler = logging.StreamHandler()
@@ -26,13 +34,6 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
26
  console_handler.setFormatter(formatter)
27
  logger.addHandler(console_handler)
28
 
29
- # File Handler
30
- log_file_path = "log_file.log" # Change this to your desired log file path
31
- file_handler = logging.FileHandler(log_file_path)
32
- file_handler.setLevel(logging.INFO)
33
- file_handler.setFormatter(formatter)
34
- logger.addHandler(file_handler)
35
-
36
 
37
  # Adding option to select the chat profile
38
  @cl.set_chat_profiles
@@ -66,12 +67,26 @@ def rename(orig_author: str):
66
  # chainlit code
67
  @cl.on_chat_start
68
  async def start():
69
- with open("code/config.yml", "r") as f:
70
  config = yaml.safe_load(f)
71
- print(config)
72
- logger.info("Config file loaded")
73
- logger.info(f"Config: {config}")
74
- logger.info("Creating llm_tutor instance")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  chat_profile = cl.user_session.get("chat_profile")
77
  if chat_profile is not None:
@@ -93,8 +108,7 @@ async def start():
93
  llm_tutor = LLMTutor(config, logger=logger)
94
 
95
  chain = llm_tutor.qa_bot()
96
- model = config["llm_params"]["local_llm_params"]["model"]
97
- msg = cl.Message(content=f"Starting the bot {model}...")
98
  await msg.send()
99
  msg.content = opening_message
100
  await msg.update()
@@ -104,24 +118,17 @@ async def start():
104
 
105
  @cl.on_message
106
  async def main(message):
 
107
  user = cl.user_session.get("user")
108
  chain = cl.user_session.get("chain")
109
- # cb = cl.AsyncLangchainCallbackHandler(
110
- # stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
111
- # )
112
- # cb.answer_reached = True
113
- # res=await chain.acall(message, callbacks=[cb])
114
- res = await chain.acall(message.content)
115
- print(f"response: {res}")
116
  try:
117
  answer = res["answer"]
118
  except:
119
  answer = res["result"]
120
- print(f"answer: {answer}")
121
-
122
- logger.info(f"Question: {res['question']}")
123
- logger.info(f"History: {res['chat_history']}")
124
- logger.info(f"Answer: {answer}\n")
125
 
126
  answer_with_sources, source_elements = get_sources(res, answer)
127
 
 
1
+ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
2
  from langchain import PromptTemplate
3
+ from langchain_community.embeddings import HuggingFaceEmbeddings
4
+ from langchain_community.vectorstores import FAISS
5
  from langchain.chains import RetrievalQA
 
6
  import chainlit as cl
7
  from langchain_community.chat_models import ChatOpenAI
8
  from langchain_community.embeddings import OpenAIEmbeddings
 
10
  import logging
11
  from dotenv import load_dotenv
12
 
13
+ import os
14
+ import sys
 
15
 
16
+ # Add the 'code' directory to the Python path
17
+ current_dir = os.path.dirname(os.path.abspath(__file__))
18
+ sys.path.append(current_dir)
19
 
20
+ from modules.chat.llm_tutor import LLMTutor
21
+ from modules.config.constants import *
22
+ from modules.chat.helpers import get_sources
23
+
24
+
25
+ global logger
26
  logger = logging.getLogger(__name__)
27
  logger.setLevel(logging.INFO)
28
+ logger.propagate = False
29
 
30
  # Console Handler
31
  console_handler = logging.StreamHandler()
 
34
  console_handler.setFormatter(formatter)
35
  logger.addHandler(console_handler)
36
 
 
 
 
 
 
 
 
37
 
38
  # Adding option to select the chat profile
39
  @cl.set_chat_profiles
 
67
  # chainlit code
68
  @cl.on_chat_start
69
  async def start():
70
+ with open("modules/config/config.yml", "r") as f:
71
  config = yaml.safe_load(f)
72
+
73
+ # Ensure log directory exists
74
+ log_directory = config["log_dir"]
75
+ if not os.path.exists(log_directory):
76
+ os.makedirs(log_directory)
77
+
78
+ # File Handler
79
+ log_file_path = (
80
+ f"{log_directory}/tutor.log" # Change this to your desired log file path
81
+ )
82
+ file_handler = logging.FileHandler(log_file_path, mode="w")
83
+ file_handler.setLevel(logging.INFO)
84
+ file_handler.setFormatter(formatter)
85
+ logger.addHandler(file_handler)
86
+
87
+ logger.info("Config file loaded")
88
+ logger.info(f"Config: {config}")
89
+ logger.info("Creating llm_tutor instance")
90
 
91
  chat_profile = cl.user_session.get("chat_profile")
92
  if chat_profile is not None:
 
108
  llm_tutor = LLMTutor(config, logger=logger)
109
 
110
  chain = llm_tutor.qa_bot()
111
+ msg = cl.Message(content=f"Starting the bot {chat_profile}...")
 
112
  await msg.send()
113
  msg.content = opening_message
114
  await msg.update()
 
118
 
119
  @cl.on_message
120
  async def main(message):
121
+ global logger
122
  user = cl.user_session.get("user")
123
  chain = cl.user_session.get("chain")
124
+ cb = cl.AsyncLangchainCallbackHandler() # TODO: fix streaming here
125
+ cb.answer_reached = True
126
+ res = await chain.acall(message.content, callbacks=[cb])
127
+ # res = await chain.acall(message.content)
 
 
 
128
  try:
129
  answer = res["answer"]
130
  except:
131
  answer = res["result"]
 
 
 
 
 
132
 
133
  answer_with_sources, source_elements = get_sources(res, answer)
134
 
code/modules/__init__.py CHANGED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import vectorstore
2
+ from . import dataloader
code/modules/chat/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .llm_tutor import LLMTutor
2
+ from .chat_model_loader import ChatModelLoader
code/modules/{chat_model_loader.py β†’ chat/chat_model_loader.py} RENAMED
@@ -1,8 +1,7 @@
1
  from langchain_community.chat_models import ChatOpenAI
2
- from langchain.llms import CTransformers
3
- from langchain.llms.huggingface_pipeline import HuggingFacePipeline
4
  from transformers import AutoTokenizer, TextStreamer
5
- from langchain.llms import LlamaCpp
6
  import torch
7
  import transformers
8
  import os
 
1
  from langchain_community.chat_models import ChatOpenAI
2
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 
3
  from transformers import AutoTokenizer, TextStreamer
4
+ from langchain_community.llms import LlamaCpp
5
  import torch
6
  import transformers
7
  import os
code/modules/{helpers.py β†’ chat/helpers.py} RENAMED
@@ -1,176 +1,6 @@
1
- import requests
2
- from bs4 import BeautifulSoup
3
- from tqdm import tqdm
4
  import chainlit as cl
5
- from langchain import PromptTemplate
6
- import requests
7
- from bs4 import BeautifulSoup
8
- from urllib.parse import urlparse, urljoin, urldefrag
9
- import asyncio
10
- import aiohttp
11
- from aiohttp import ClientSession
12
- from typing import Dict, Any, List
13
-
14
- try:
15
- from modules.constants import *
16
- except:
17
- from constants import *
18
-
19
- """
20
- Ref: https://python.plainenglish.io/scraping-the-subpages-on-a-website-ea2d4e3db113
21
- """
22
-
23
-
24
- class WebpageCrawler:
25
- def __init__(self):
26
- self.dict_href_links = {}
27
-
28
- async def fetch(self, session: ClientSession, url: str) -> str:
29
- async with session.get(url) as response:
30
- try:
31
- return await response.text()
32
- except UnicodeDecodeError:
33
- return await response.text(encoding="latin1")
34
-
35
- def url_exists(self, url: str) -> bool:
36
- try:
37
- response = requests.head(url)
38
- return response.status_code == 200
39
- except requests.ConnectionError:
40
- return False
41
-
42
- async def get_links(self, session: ClientSession, website_link: str, base_url: str):
43
- html_data = await self.fetch(session, website_link)
44
- soup = BeautifulSoup(html_data, "html.parser")
45
- list_links = []
46
- for link in soup.find_all("a", href=True):
47
- href = link["href"].strip()
48
- full_url = urljoin(base_url, href)
49
- normalized_url = self.normalize_url(full_url) # sections removed
50
- if (
51
- normalized_url not in self.dict_href_links
52
- and self.is_child_url(normalized_url, base_url)
53
- and self.url_exists(normalized_url)
54
- ):
55
- self.dict_href_links[normalized_url] = None
56
- list_links.append(normalized_url)
57
-
58
- return list_links
59
-
60
- async def get_subpage_links(
61
- self, session: ClientSession, urls: list, base_url: str
62
- ):
63
- tasks = [self.get_links(session, url, base_url) for url in urls]
64
- results = await asyncio.gather(*tasks)
65
- all_links = [link for sublist in results for link in sublist]
66
- return all_links
67
-
68
- async def get_all_pages(self, url: str, base_url: str):
69
- async with aiohttp.ClientSession() as session:
70
- dict_links = {url: "Not-checked"}
71
- counter = None
72
- while counter != 0:
73
- unchecked_links = [
74
- link
75
- for link, status in dict_links.items()
76
- if status == "Not-checked"
77
- ]
78
- if not unchecked_links:
79
- break
80
- new_links = await self.get_subpage_links(
81
- session, unchecked_links, base_url
82
- )
83
- for link in unchecked_links:
84
- dict_links[link] = "Checked"
85
- print(f"Checked: {link}")
86
- dict_links.update(
87
- {
88
- link: "Not-checked"
89
- for link in new_links
90
- if link not in dict_links
91
- }
92
- )
93
- counter = len(
94
- [
95
- status
96
- for status in dict_links.values()
97
- if status == "Not-checked"
98
- ]
99
- )
100
-
101
- checked_urls = [
102
- url for url, status in dict_links.items() if status == "Checked"
103
- ]
104
- return checked_urls
105
-
106
- def is_webpage(self, url: str) -> bool:
107
- try:
108
- response = requests.head(url, allow_redirects=True)
109
- content_type = response.headers.get("Content-Type", "").lower()
110
- return "text/html" in content_type
111
- except requests.RequestException:
112
- return False
113
-
114
- def clean_url_list(self, urls):
115
- files, webpages = [], []
116
-
117
- for url in urls:
118
- if self.is_webpage(url):
119
- webpages.append(url)
120
- else:
121
- files.append(url)
122
-
123
- return files, webpages
124
-
125
- def is_child_url(self, url, base_url):
126
- return url.startswith(base_url)
127
-
128
- def normalize_url(self, url: str):
129
- # Strip the fragment identifier
130
- defragged_url, _ = urldefrag(url)
131
- return defragged_url
132
-
133
-
134
- def get_urls_from_file(file_path: str):
135
- """
136
- Function to get urls from a file
137
- """
138
- with open(file_path, "r") as f:
139
- urls = f.readlines()
140
- urls = [url.strip() for url in urls]
141
- return urls
142
-
143
-
144
- def get_base_url(url):
145
- parsed_url = urlparse(url)
146
- base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
147
- return base_url
148
-
149
-
150
- def get_prompt(config):
151
- if config["llm_params"]["use_history"]:
152
- if config["llm_params"]["llm_loader"] == "local_llm":
153
- custom_prompt_template = tinyllama_prompt_template_with_history
154
- elif config["llm_params"]["llm_loader"] == "openai":
155
- custom_prompt_template = openai_prompt_template_with_history
156
- # else:
157
- # custom_prompt_template = tinyllama_prompt_template_with_history # default
158
- prompt = PromptTemplate(
159
- template=custom_prompt_template,
160
- input_variables=["context", "chat_history", "question"],
161
- )
162
- else:
163
- if config["llm_params"]["llm_loader"] == "local_llm":
164
- custom_prompt_template = tinyllama_prompt_template
165
- elif config["llm_params"]["llm_loader"] == "openai":
166
- custom_prompt_template = openai_prompt_template
167
- # else:
168
- # custom_prompt_template = tinyllama_prompt_template
169
- prompt = PromptTemplate(
170
- template=custom_prompt_template,
171
- input_variables=["context", "question"],
172
- )
173
- return prompt
174
 
175
 
176
  def get_sources(res, answer):
@@ -248,90 +78,27 @@ def get_sources(res, answer):
248
  return full_answer, source_elements
249
 
250
 
251
- def get_metadata(lectures_url, schedule_url):
252
- """
253
- Function to get the lecture metadata from the lectures and schedule URLs.
254
- """
255
- lecture_metadata = {}
256
-
257
- # Get the main lectures page content
258
- r_lectures = requests.get(lectures_url)
259
- soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
260
-
261
- # Get the main schedule page content
262
- r_schedule = requests.get(schedule_url)
263
- soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
264
-
265
- # Find all lecture blocks
266
- lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
267
-
268
- # Create a mapping from slides link to date
269
- date_mapping = {}
270
- schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
271
- for row in schedule_rows:
272
- try:
273
- date = (
274
- row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
275
- )
276
- description_div = row.find("div", {"data-label": "Description"})
277
- slides_link_tag = description_div.find("a", title="Download slides")
278
- slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
279
- slides_link = (
280
- f"https://dl4ds.github.io{slides_link}" if slides_link else None
281
- )
282
- if slides_link:
283
- date_mapping[slides_link] = date
284
- except Exception as e:
285
- print(f"Error processing schedule row: {e}")
286
- continue
287
-
288
- for block in lecture_blocks:
289
- try:
290
- # Extract the lecture title
291
- title = block.find("span", style="font-weight: bold;").text.strip()
292
-
293
- # Extract the TL;DR
294
- tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
295
-
296
- # Extract the link to the slides
297
- slides_link_tag = block.find("a", title="Download slides")
298
- slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
299
- slides_link = (
300
- f"https://dl4ds.github.io{slides_link}" if slides_link else None
301
- )
302
-
303
- # Extract the link to the lecture recording
304
- recording_link_tag = block.find("a", title="Download lecture recording")
305
- recording_link = (
306
- recording_link_tag["href"].strip() if recording_link_tag else None
307
- )
308
-
309
- # Extract suggested readings or summary if available
310
- suggested_readings_tag = block.find("p", text="Suggested Readings:")
311
- if suggested_readings_tag:
312
- suggested_readings = suggested_readings_tag.find_next_sibling("ul")
313
- if suggested_readings:
314
- suggested_readings = suggested_readings.get_text(
315
- separator="\n"
316
- ).strip()
317
- else:
318
- suggested_readings = "No specific readings provided."
319
- else:
320
- suggested_readings = "No specific readings provided."
321
-
322
- # Get the date from the schedule
323
- date = date_mapping.get(slides_link, "No date available")
324
-
325
- # Add to the dictionary
326
- lecture_metadata[slides_link] = {
327
- "date": date,
328
- "tldr": tldr,
329
- "title": title,
330
- "lecture_recording": recording_link,
331
- "suggested_readings": suggested_readings,
332
- }
333
- except Exception as e:
334
- print(f"Error processing block: {e}")
335
- continue
336
-
337
- return lecture_metadata
 
1
+ from modules.config.constants import *
 
 
2
  import chainlit as cl
3
+ from langchain_core.prompts import PromptTemplate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def get_sources(res, answer):
 
78
  return full_answer, source_elements
79
 
80
 
81
+ def get_prompt(config):
82
+ if config["llm_params"]["use_history"]:
83
+ if config["llm_params"]["llm_loader"] == "local_llm":
84
+ custom_prompt_template = tinyllama_prompt_template_with_history
85
+ elif config["llm_params"]["llm_loader"] == "openai":
86
+ custom_prompt_template = openai_prompt_template_with_history
87
+ # else:
88
+ # custom_prompt_template = tinyllama_prompt_template_with_history # default
89
+ prompt = PromptTemplate(
90
+ template=custom_prompt_template,
91
+ input_variables=["context", "chat_history", "question"],
92
+ )
93
+ else:
94
+ if config["llm_params"]["llm_loader"] == "local_llm":
95
+ custom_prompt_template = tinyllama_prompt_template
96
+ elif config["llm_params"]["llm_loader"] == "openai":
97
+ custom_prompt_template = openai_prompt_template
98
+ # else:
99
+ # custom_prompt_template = tinyllama_prompt_template
100
+ prompt = PromptTemplate(
101
+ template=custom_prompt_template,
102
+ input_variables=["context", "question"],
103
+ )
104
+ return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
code/modules/{llm_tutor.py β†’ chat/llm_tutor.py} RENAMED
@@ -1,24 +1,52 @@
1
- from langchain import PromptTemplate
2
- from langchain.embeddings import HuggingFaceEmbeddings
3
- from langchain_community.chat_models import ChatOpenAI
4
- from langchain_community.embeddings import OpenAIEmbeddings
5
- from langchain.vectorstores import FAISS
6
  from langchain.chains import RetrievalQA, ConversationalRetrievalChain
7
- from langchain.llms import CTransformers
8
- from langchain.memory import ConversationBufferWindowMemory, ConversationSummaryBufferMemory
 
 
9
  from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
10
  import os
11
- from modules.constants import *
12
- from modules.helpers import get_prompt
13
- from modules.chat_model_loader import ChatModelLoader
14
- from modules.vector_db import VectorDB, VectorDBScore
15
- from typing import Dict, Any, Optional
 
 
 
16
  from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
17
  import inspect
18
  from langchain.chains.conversational_retrieval.base import _get_chat_history
 
 
 
 
 
 
 
19
 
20
 
21
  class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  async def _acall(
23
  self,
24
  inputs: Dict[str, Any],
@@ -26,13 +54,31 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
26
  ) -> Dict[str, Any]:
27
  _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
28
  question = inputs["question"]
29
- get_chat_history = self.get_chat_history or _get_chat_history
30
  chat_history_str = get_chat_history(inputs["chat_history"])
31
- print(f"chat_history_str: {chat_history_str}")
32
  if chat_history_str:
33
- callbacks = _run_manager.get_child()
34
- new_question = await self.question_generator.arun(
35
- question=question, chat_history=chat_history_str, callbacks=callbacks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
  else:
38
  new_question = question
@@ -56,27 +102,24 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
56
  # Prepare the final prompt with metadata
57
  context = "\n\n".join(
58
  [
59
- f"Document content: {doc.page_content}\nMetadata: {doc.metadata}"
60
- for doc in docs
61
  ]
62
  )
63
- final_prompt = f"""
64
- You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Use the following pieces of information to answer the user's question.
65
- If you don't know the answer, just say that you don't knowβ€”don't try to make up an answer.
66
- Use the chat history to answer the question only if it's relevant; otherwise, ignore it. The context for the answer will be under "Document context:".
67
- Use the metadata from each document to guide the user to the correct sources.
68
- The context is ordered by relevance to the question. Give more weight to the most relevant documents.
69
- Talk in a friendly and personalized manner, similar to how you would speak to a friend who needs help. Make the conversation engaging and avoid sounding repetitive or robotic.
70
-
71
- Chat History:
72
- {chat_history_str}
73
-
74
- Context:
75
- {context}
76
-
77
- Question: {new_question}
78
- AI Tutor:
79
- """
80
 
81
  new_inputs["input"] = final_prompt
82
  new_inputs["question"] = final_prompt
@@ -98,8 +141,9 @@ class LLMTutor:
98
  def __init__(self, config, logger=None):
99
  self.config = config
100
  self.llm = self.load_llm()
101
- self.vector_db = VectorDB(config, logger=logger)
102
- if self.config["embedding_options"]["embedd_files"]:
 
103
  self.vector_db.create_database()
104
  self.vector_db.save_database()
105
 
@@ -114,24 +158,20 @@ class LLMTutor:
114
 
115
  # Retrieval QA Chain
116
  def retrieval_qa_chain(self, llm, prompt, db):
117
- if self.config["embedding_options"]["db_option"] in ["FAISS", "Chroma"]:
118
- retriever = VectorDBScore(
119
- vectorstore=db,
120
- # search_type="similarity_score_threshold",
121
- # search_kwargs={
122
- # "score_threshold": self.config["embedding_options"][
123
- # "score_threshold"
124
- # ],
125
- # "k": self.config["embedding_options"]["search_top_k"],
126
- # },
127
- )
128
- elif self.config["embedding_options"]["db_option"] == "RAGatouille":
129
  retriever = db.as_langchain_retriever(
130
- k=self.config["embedding_options"]["search_top_k"]
131
  )
 
132
  if self.config["llm_params"]["use_history"]:
133
- memory = ConversationSummaryBufferMemory(
134
- llm = llm,
135
  k=self.config["llm_params"]["memory_window"],
136
  memory_key="chat_history",
137
  return_messages=True,
@@ -145,6 +185,7 @@ class LLMTutor:
145
  return_source_documents=True,
146
  memory=memory,
147
  combine_docs_chain_kwargs={"prompt": prompt},
 
148
  )
149
  else:
150
  qa_chain = RetrievalQA.from_chain_type(
@@ -166,7 +207,9 @@ class LLMTutor:
166
  def qa_bot(self):
167
  db = self.vector_db.load_database()
168
  qa_prompt = self.set_custom_prompt()
169
- qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)
 
 
170
 
171
  return qa
172
 
 
 
 
 
 
 
1
  from langchain.chains import RetrievalQA, ConversationalRetrievalChain
2
+ from langchain.memory import (
3
+ ConversationBufferWindowMemory,
4
+ ConversationSummaryBufferMemory,
5
+ )
6
  from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
7
  import os
8
+ from modules.config.constants import *
9
+ from modules.chat.helpers import get_prompt
10
+ from modules.chat.chat_model_loader import ChatModelLoader
11
+ from modules.vectorstore.store_manager import VectorStoreManager
12
+
13
+ from modules.retriever import FaissRetriever, ChromaRetriever
14
+
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
16
  from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
17
  import inspect
18
  from langchain.chains.conversational_retrieval.base import _get_chat_history
19
+ from langchain_core.messages import BaseMessage
20
+
21
+ CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
22
+
23
+ from langchain_core.output_parsers import StrOutputParser
24
+ from langchain_core.prompts import ChatPromptTemplate
25
+ from langchain_community.chat_models import ChatOpenAI
26
 
27
 
28
  class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
29
+
30
+ def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
31
+ _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
32
+ buffer = ""
33
+ for dialogue_turn in chat_history:
34
+ if isinstance(dialogue_turn, BaseMessage):
35
+ role_prefix = _ROLE_MAP.get(
36
+ dialogue_turn.type, f"{dialogue_turn.type}: "
37
+ )
38
+ buffer += f"\n{role_prefix}{dialogue_turn.content}"
39
+ elif isinstance(dialogue_turn, tuple):
40
+ human = "Student: " + dialogue_turn[0]
41
+ ai = "AI Tutor: " + dialogue_turn[1]
42
+ buffer += "\n" + "\n".join([human, ai])
43
+ else:
44
+ raise ValueError(
45
+ f"Unsupported chat history format: {type(dialogue_turn)}."
46
+ f" Full chat history: {chat_history} "
47
+ )
48
+ return buffer
49
+
50
  async def _acall(
51
  self,
52
  inputs: Dict[str, Any],
 
54
  ) -> Dict[str, Any]:
55
  _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
56
  question = inputs["question"]
57
+ get_chat_history = self._get_chat_history
58
  chat_history_str = get_chat_history(inputs["chat_history"])
 
59
  if chat_history_str:
60
+ # callbacks = _run_manager.get_child()
61
+ # new_question = await self.question_generator.arun(
62
+ # question=question, chat_history=chat_history_str, callbacks=callbacks
63
+ # )
64
+ system = (
65
+ "You are an AI Tutor helping a student. Your task is to rephrase the student's question to provide more context from their chat history (only if relevant), ensuring the rephrased question still reflects the student's point of view. "
66
+ "The rephrased question should incorporate relevant details from the chat history to make it clearer and more specific. It should also expand upon the original question to provide more context on only what the student provided."
67
+ "Always end the rephrased question with the original question in parentheses for reference. "
68
+ "Do not change the meaning of the question, and keep the tone and perspective as if it were asked by the student. "
69
+ "Here is the chat history for context: \n{chat_history_str}\n"
70
+ "Now, rephrase the following question: '{question}'"
71
+ )
72
+ prompt = ChatPromptTemplate.from_messages(
73
+ [
74
+ ("system", system),
75
+ ("human", "{question}, {chat_history_str}"),
76
+ ]
77
+ )
78
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
79
+ step_back = prompt | llm | StrOutputParser()
80
+ new_question = step_back.invoke(
81
+ {"question": question, "chat_history_str": chat_history_str}
82
  )
83
  else:
84
  new_question = question
 
102
  # Prepare the final prompt with metadata
103
  context = "\n\n".join(
104
  [
105
+ f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source']}))"
106
+ for idx, doc in enumerate(docs)
107
  ]
108
  )
109
+ final_prompt = (
110
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. "
111
+ "Use the following pieces of information to answer the user's question. "
112
+ "If you don't know the answer, try your best, but don't try to make up an answer. Keep the flow of the conversation going. "
113
+ "Use the chat history just as a gist to answer the question only if it's relevant; otherwise, ignore it. Do not repeat responses in the history. Use the context as a guide to construct your answer. The context for the answer will be under 'Document context:'. Remember, the conext may include text not directly related to the question."
114
+ "Make sure to use the source_file field in metadata from each document to provide links to the user to the correct sources. "
115
+ "The context is ordered by relevance to the question. "
116
+ "Talk in a friendly and personalized manner, similar to how you would speak to a friend who needs help. Make the conversation engaging and avoid sounding repetitive or robotic.\n\n"
117
+ f"Chat History:\n{chat_history_str}\n\n"
118
+ f"Context:\n{context}\n\n"
119
+ f"Student: {new_question}\n"
120
+ "Anwer the student's question in a friendly, concise, and engaging manner.\n"
121
+ "AI Tutor:"
122
+ )
 
 
 
123
 
124
  new_inputs["input"] = final_prompt
125
  new_inputs["question"] = final_prompt
 
141
  def __init__(self, config, logger=None):
142
  self.config = config
143
  self.llm = self.load_llm()
144
+ self.logger = logger
145
+ self.vector_db = VectorStoreManager(config, logger=self.logger)
146
+ if self.config["vectorstore"]["embedd_files"]:
147
  self.vector_db.create_database()
148
  self.vector_db.save_database()
149
 
 
158
 
159
  # Retrieval QA Chain
160
  def retrieval_qa_chain(self, llm, prompt, db):
161
+
162
+ if self.config["vectorstore"]["db_option"] == "FAISS":
163
+ retriever = FaissRetriever().return_retriever(db, self.config)
164
+
165
+ elif self.config["vectorstore"]["db_option"] == "Chroma":
166
+ retriever = ChromaRetriever().return_retriever(db, self.config)
167
+
168
+ elif self.config["vectorstore"]["db_option"] == "RAGatouille":
 
 
 
 
169
  retriever = db.as_langchain_retriever(
170
+ k=self.config["vectorstore"]["search_top_k"]
171
  )
172
+
173
  if self.config["llm_params"]["use_history"]:
174
+ memory = ConversationBufferWindowMemory(
 
175
  k=self.config["llm_params"]["memory_window"],
176
  memory_key="chat_history",
177
  return_messages=True,
 
185
  return_source_documents=True,
186
  memory=memory,
187
  combine_docs_chain_kwargs={"prompt": prompt},
188
+ response_if_no_docs_found="No context found",
189
  )
190
  else:
191
  qa_chain = RetrievalQA.from_chain_type(
 
207
  def qa_bot(self):
208
  db = self.vector_db.load_database()
209
  qa_prompt = self.set_custom_prompt()
210
+ qa = self.retrieval_qa_chain(
211
+ self.llm, qa_prompt, db
212
+ ) # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
213
 
214
  return qa
215
 
code/modules/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .constants import *
code/{config.yml β†’ modules/config/config.yml} RENAMED
@@ -1,23 +1,32 @@
1
- embedding_options:
 
 
 
 
2
  embedd_files: False # bool
3
- data_path: 'storage/data' # str
4
- url_file_path: 'storage/data/urls.txt' # str
5
- expand_urls: True # bool
6
- db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille]
7
- db_path : 'vectorstores' # str
8
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
9
  search_top_k : 3 # int
10
  score_threshold : 0.2 # float
 
 
 
 
 
 
 
 
11
  llm_params:
12
  use_history: True # bool
13
  memory_window: 3 # int
14
  llm_loader: 'openai' # str [local_llm, openai]
15
  openai_params:
16
  model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
17
- local_llm_params:
18
- model: "storage/models/llama-2-7b-chat.Q4_0.gguf"
19
- model_type: "llama"
20
- temperature: 0.2
21
  splitter_options:
22
  use_splitter: True # bool
23
  split_by_token : True # bool
 
1
+ log_dir: '../storage/logs' # str
2
+ log_chunk_dir: '../storage/logs/chunks' # str
3
+ device: 'cpu' # str [cuda, cpu]
4
+
5
+ vectorstore:
6
  embedd_files: False # bool
7
+ data_path: '../storage/data' # str
8
+ url_file_path: '../storage/data/urls.txt' # str
9
+ expand_urls: False # bool
10
+ db_option : 'Chroma' # str [FAISS, Chroma, RAGatouille]
11
+ db_path : '../vectorstores' # str
12
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
13
  search_top_k : 3 # int
14
  score_threshold : 0.2 # float
15
+
16
+ faiss_params: # Not used as of now
17
+ index_path: '../vectorstores/faiss.index' # str
18
+ index_type: 'Flat' # str [Flat, HNSW, IVF]
19
+ index_dimension: 384 # int
20
+ index_nlist: 100 # int
21
+ index_nprobe: 10 # int
22
+
23
  llm_params:
24
  use_history: True # bool
25
  memory_window: 3 # int
26
  llm_loader: 'openai' # str [local_llm, openai]
27
  openai_params:
28
  model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
29
+
 
 
 
30
  splitter_options:
31
  use_splitter: True # bool
32
  split_by_token : True # bool
code/modules/{constants.py β†’ config/constants.py} RENAMED
File without changes
code/modules/dataloader/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .webpage_crawler import WebpageCrawler
2
+ from .data_loader import DataLoader
code/modules/{data_loader.py β†’ dataloader/data_loader.py} RENAMED
@@ -16,15 +16,12 @@ import logging
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from ragatouille import RAGPretrainedModel
18
  from langchain.chains import LLMChain
19
- from langchain.llms import OpenAI
20
  from langchain import PromptTemplate
 
 
21
 
22
- try:
23
- from modules.helpers import get_metadata
24
- except:
25
- from helpers import get_metadata
26
-
27
- logger = logging.getLogger(__name__)
28
 
29
 
30
  class PDFReader:
@@ -40,8 +37,9 @@ class PDFReader:
40
 
41
 
42
  class FileReader:
43
- def __init__(self):
44
  self.pdf_reader = PDFReader()
 
45
 
46
  def extract_text_from_pdf(self, pdf_path):
47
  text = ""
@@ -61,7 +59,7 @@ class FileReader:
61
  temp_file_path = temp_file.name
62
  return temp_file_path
63
  else:
64
- print("Failed to download PDF from URL:", pdf_url)
65
  return None
66
 
67
  def read_pdf(self, temp_file_path: str):
@@ -99,13 +97,18 @@ class FileReader:
99
  if response.status_code == 200:
100
  return [Document(page_content=response.text)]
101
  else:
102
- print("Failed to fetch .tex file from URL:", tex_url)
103
  return None
104
 
105
 
106
  class ChunkProcessor:
107
- def __init__(self, config):
108
  self.config = config
 
 
 
 
 
109
 
110
  if config["splitter_options"]["use_splitter"]:
111
  if config["splitter_options"]["split_by_token"]:
@@ -124,7 +127,7 @@ class ChunkProcessor:
124
  )
125
  else:
126
  self.splitter = None
127
- logger.info("ChunkProcessor instance created")
128
 
129
  def remove_delimiters(self, document_chunks: list):
130
  for chunk in document_chunks:
@@ -139,7 +142,6 @@ class ChunkProcessor:
139
  del document_chunks[0]
140
  for _ in range(end):
141
  document_chunks.pop()
142
- logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
143
  return document_chunks
144
 
145
  def process_chunks(
@@ -172,122 +174,184 @@ class ChunkProcessor:
172
 
173
  return document_chunks
174
 
175
- def get_chunks(self, file_reader, uploaded_files, weblinks):
176
- self.document_chunks_full = []
177
- self.parent_document_names = []
178
- self.child_document_names = []
179
- self.documents = []
180
- self.document_metadata = []
181
-
182
  addl_metadata = get_metadata(
183
  "https://dl4ds.github.io/sp2024/lectures/",
184
  "https://dl4ds.github.io/sp2024/schedule/",
185
  ) # For any additional metadata
186
 
187
- for file_index, file_path in enumerate(uploaded_files):
188
- file_name = os.path.basename(file_path)
189
- if file_name not in self.parent_document_names:
190
- file_type = file_name.split(".")[-1].lower()
191
-
192
- # try:
193
- if file_type == "pdf":
194
- documents = file_reader.read_pdf(file_path)
195
- elif file_type == "txt":
196
- documents = file_reader.read_txt(file_path)
197
- elif file_type == "docx":
198
- documents = file_reader.read_docx(file_path)
199
- elif file_type == "srt":
200
- documents = file_reader.read_srt(file_path)
201
- elif file_type == "tex":
202
- documents = file_reader.read_tex_from_url(file_path)
203
- else:
204
- logger.warning(f"Unsupported file type: {file_type}")
205
- continue
206
-
207
- for doc in documents:
208
- page_num = doc.metadata.get("page", 0)
209
- self.documents.append(doc.page_content)
210
- self.document_metadata.append(
211
- {"source": file_path, "page": page_num}
212
- )
213
- metadata = addl_metadata.get(file_path, {})
214
- self.document_metadata[-1].update(metadata)
215
-
216
- self.child_document_names.append(f"{file_name}_{page_num}")
217
-
218
- self.parent_document_names.append(file_name)
219
- if self.config["embedding_options"]["db_option"] not in [
220
- "RAGatouille"
221
- ]:
222
- document_chunks = self.process_chunks(
223
- self.documents[-1],
224
- file_type,
225
- source=file_path,
226
- page=page_num,
227
- metadata=metadata,
228
- )
229
- self.document_chunks_full.extend(document_chunks)
230
-
231
- # except Exception as e:
232
- # logger.error(f"Error processing file {file_name}: {str(e)}")
233
-
234
- self.process_weblinks(file_reader, weblinks)
235
-
236
- logger.info(
237
  f"Total document chunks extracted: {len(self.document_chunks_full)}"
238
  )
239
- return (
240
- self.document_chunks_full,
241
- self.child_document_names,
242
- self.documents,
243
- self.document_metadata,
244
- )
245
 
246
- def process_weblinks(self, file_reader, weblinks):
247
- if weblinks[0] != "":
248
- logger.info(f"Splitting weblinks: total of {len(weblinks)}")
249
-
250
- for link_index, link in enumerate(weblinks):
251
- if link not in self.parent_document_names:
252
- try:
253
- logger.info(f"\tSplitting link {link_index+1} : {link}")
254
- if "youtube" in link:
255
- documents = file_reader.read_youtube_transcript(link)
256
- else:
257
- documents = file_reader.read_html(link)
258
-
259
- for doc in documents:
260
- page_num = doc.metadata.get("page", 0)
261
- self.documents.append(doc.page_content)
262
- self.document_metadata.append(
263
- {"source": link, "page": page_num}
264
- )
265
- self.child_document_names.append(f"{link}")
266
-
267
- self.parent_document_names.append(link)
268
- if self.config["embedding_options"]["db_option"] not in [
269
- "RAGatouille"
270
- ]:
271
- document_chunks = self.process_chunks(
272
- self.documents[-1],
273
- "txt",
274
- source=link,
275
- page=0,
276
- metadata={"source_type": "webpage"},
277
- )
278
- self.document_chunks_full.extend(document_chunks)
279
- except Exception as e:
280
- logger.error(
281
- f"Error splitting link {link_index+1} : {link}: {str(e)}"
282
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
 
285
  class DataLoader:
286
- def __init__(self, config):
287
- self.file_reader = FileReader()
288
- self.chunk_processor = ChunkProcessor(config)
289
 
290
  def get_chunks(self, uploaded_files, weblinks):
291
- return self.chunk_processor.get_chunks(
292
  self.file_reader, uploaded_files, weblinks
293
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from ragatouille import RAGPretrainedModel
18
  from langchain.chains import LLMChain
19
+ from langchain_community.llms import OpenAI
20
  from langchain import PromptTemplate
21
+ import json
22
+ from concurrent.futures import ThreadPoolExecutor
23
 
24
+ from modules.dataloader.helpers import get_metadata
 
 
 
 
 
25
 
26
 
27
  class PDFReader:
 
37
 
38
 
39
  class FileReader:
40
+ def __init__(self, logger):
41
  self.pdf_reader = PDFReader()
42
+ self.logger = logger
43
 
44
  def extract_text_from_pdf(self, pdf_path):
45
  text = ""
 
59
  temp_file_path = temp_file.name
60
  return temp_file_path
61
  else:
62
+ self.logger.error(f"Failed to download PDF from URL: {pdf_url}")
63
  return None
64
 
65
  def read_pdf(self, temp_file_path: str):
 
97
  if response.status_code == 200:
98
  return [Document(page_content=response.text)]
99
  else:
100
+ self.logger.error(f"Failed to fetch .tex file from URL: {tex_url}")
101
  return None
102
 
103
 
104
  class ChunkProcessor:
105
+ def __init__(self, config, logger):
106
  self.config = config
107
+ self.logger = logger
108
+
109
+ self.document_data = {}
110
+ self.document_metadata = {}
111
+ self.document_chunks_full = []
112
 
113
  if config["splitter_options"]["use_splitter"]:
114
  if config["splitter_options"]["split_by_token"]:
 
127
  )
128
  else:
129
  self.splitter = None
130
+ self.logger.info("ChunkProcessor instance created")
131
 
132
  def remove_delimiters(self, document_chunks: list):
133
  for chunk in document_chunks:
 
142
  del document_chunks[0]
143
  for _ in range(end):
144
  document_chunks.pop()
 
145
  return document_chunks
146
 
147
  def process_chunks(
 
174
 
175
  return document_chunks
176
 
177
+ def chunk_docs(self, file_reader, uploaded_files, weblinks):
 
 
 
 
 
 
178
  addl_metadata = get_metadata(
179
  "https://dl4ds.github.io/sp2024/lectures/",
180
  "https://dl4ds.github.io/sp2024/schedule/",
181
  ) # For any additional metadata
182
 
183
+ with ThreadPoolExecutor() as executor:
184
+ executor.map(
185
+ self.process_file,
186
+ uploaded_files,
187
+ range(len(uploaded_files)),
188
+ [file_reader] * len(uploaded_files),
189
+ [addl_metadata] * len(uploaded_files),
190
+ )
191
+ executor.map(
192
+ self.process_weblink,
193
+ weblinks,
194
+ range(len(weblinks)),
195
+ [file_reader] * len(weblinks),
196
+ [addl_metadata] * len(weblinks),
197
+ )
198
+
199
+ document_names = [
200
+ f"{file_name}_{page_num}"
201
+ for file_name, pages in self.document_data.items()
202
+ for page_num in pages.keys()
203
+ ]
204
+ documents = [
205
+ page for doc in self.document_data.values() for page in doc.values()
206
+ ]
207
+ document_metadata = [
208
+ page for doc in self.document_metadata.values() for page in doc.values()
209
+ ]
210
+
211
+ self.save_document_data()
212
+
213
+ self.logger.info(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  f"Total document chunks extracted: {len(self.document_chunks_full)}"
215
  )
 
 
 
 
 
 
216
 
217
+ return self.document_chunks_full, document_names, documents, document_metadata
218
+
219
+ def process_documents(
220
+ self, documents, file_path, file_type, metadata_source, addl_metadata
221
+ ):
222
+ file_data = {}
223
+ file_metadata = {}
224
+
225
+ for doc in documents:
226
+ if len(doc.page_content) <= 400:
227
+ continue
228
+
229
+ page_num = doc.metadata.get("page", 0)
230
+ file_data[page_num] = doc.page_content
231
+ metadata = (
232
+ addl_metadata.get(file_path, {})
233
+ if metadata_source == "file"
234
+ else {"source": file_path, "page": page_num}
235
+ )
236
+ file_metadata[page_num] = metadata
237
+
238
+ if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
239
+ document_chunks = self.process_chunks(
240
+ doc.page_content,
241
+ file_type,
242
+ source=file_path,
243
+ page=page_num,
244
+ metadata=metadata,
245
+ )
246
+ self.document_chunks_full.extend(document_chunks)
247
+
248
+ self.document_data[file_path] = file_data
249
+ self.document_metadata[file_path] = file_metadata
250
+
251
+ def process_file(self, file_path, file_index, file_reader, addl_metadata):
252
+ file_name = os.path.basename(file_path)
253
+ if file_name in self.document_data:
254
+ return
255
+
256
+ file_type = file_name.split(".")[-1].lower()
257
+ self.logger.info(f"Reading file {file_index + 1}: {file_path}")
258
+
259
+ read_methods = {
260
+ "pdf": file_reader.read_pdf,
261
+ "txt": file_reader.read_txt,
262
+ "docx": file_reader.read_docx,
263
+ "srt": file_reader.read_srt,
264
+ "tex": file_reader.read_tex_from_url,
265
+ }
266
+ if file_type not in read_methods:
267
+ self.logger.warning(f"Unsupported file type: {file_type}")
268
+ return
269
+
270
+ try:
271
+ documents = read_methods[file_type](file_path)
272
+ self.process_documents(
273
+ documents, file_path, file_type, "file", addl_metadata
274
+ )
275
+ except Exception as e:
276
+ self.logger.error(f"Error processing file {file_name}: {str(e)}")
277
+
278
+ def process_weblink(self, link, link_index, file_reader, addl_metadata):
279
+ if link in self.document_data:
280
+ return
281
+
282
+ self.logger.info(f"Reading link {link_index + 1} : {link}")
283
+
284
+ try:
285
+ if "youtube" in link:
286
+ documents = file_reader.read_youtube_transcript(link)
287
+ else:
288
+ documents = file_reader.read_html(link)
289
+
290
+ self.process_documents(documents, link, "txt", "link", addl_metadata)
291
+ except Exception as e:
292
+ self.logger.error(f"Error Reading link {link_index + 1} : {link}: {str(e)}")
293
+
294
+ def save_document_data(self):
295
+ if not os.path.exists(f"{self.config['log_chunk_dir']}/docs"):
296
+ os.makedirs(f"{self.config['log_chunk_dir']}/docs")
297
+ self.logger.info(
298
+ f"Creating directory {self.config['log_chunk_dir']}/docs for document data"
299
+ )
300
+ self.logger.info(
301
+ f"Saving document content to {self.config['log_chunk_dir']}/docs/doc_content.json"
302
+ )
303
+ if not os.path.exists(f"{self.config['log_chunk_dir']}/metadata"):
304
+ os.makedirs(f"{self.config['log_chunk_dir']}/metadata")
305
+ self.logger.info(
306
+ f"Creating directory {self.config['log_chunk_dir']}/metadata for document metadata"
307
+ )
308
+ self.logger.info(
309
+ f"Saving document metadata to {self.config['log_chunk_dir']}/metadata/doc_metadata.json"
310
+ )
311
+ with open(
312
+ f"{self.config['log_chunk_dir']}/docs/doc_content.json", "w"
313
+ ) as json_file:
314
+ json.dump(self.document_data, json_file, indent=4)
315
+ with open(
316
+ f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "w"
317
+ ) as json_file:
318
+ json.dump(self.document_metadata, json_file, indent=4)
319
+
320
+ def load_document_data(self):
321
+ with open(
322
+ f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
323
+ ) as json_file:
324
+ self.document_data = json.load(json_file)
325
+ with open(
326
+ f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
327
+ ) as json_file:
328
+ self.document_metadata = json.load(json_file)
329
 
330
 
331
  class DataLoader:
332
+ def __init__(self, config, logger=None):
333
+ self.file_reader = FileReader(logger=logger)
334
+ self.chunk_processor = ChunkProcessor(config, logger=logger)
335
 
336
  def get_chunks(self, uploaded_files, weblinks):
337
+ return self.chunk_processor.chunk_docs(
338
  self.file_reader, uploaded_files, weblinks
339
  )
340
+
341
+
342
+ if __name__ == "__main__":
343
+ import yaml
344
+
345
+ logger = logging.getLogger(__name__)
346
+ logger.setLevel(logging.INFO)
347
+
348
+ with open("../code/config.yml", "r") as f:
349
+ config = yaml.safe_load(f)
350
+
351
+ data_loader = DataLoader(config, logger=logger)
352
+ document_chunks, document_names, documents, document_metadata = (
353
+ data_loader.get_chunks(
354
+ [],
355
+ ["https://dl4ds.github.io/sp2024/"],
356
+ )
357
+ )
code/modules/dataloader/helpers.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ from tqdm import tqdm
4
+
5
+
6
+ def get_urls_from_file(file_path: str):
7
+ """
8
+ Function to get urls from a file
9
+ """
10
+ with open(file_path, "r") as f:
11
+ urls = f.readlines()
12
+ urls = [url.strip() for url in urls]
13
+ return urls
14
+
15
+
16
+ def get_base_url(url):
17
+ parsed_url = urlparse(url)
18
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
19
+ return base_url
20
+
21
+
22
+ def get_metadata(lectures_url, schedule_url):
23
+ """
24
+ Function to get the lecture metadata from the lectures and schedule URLs.
25
+ """
26
+ lecture_metadata = {}
27
+
28
+ # Get the main lectures page content
29
+ r_lectures = requests.get(lectures_url)
30
+ soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
31
+
32
+ # Get the main schedule page content
33
+ r_schedule = requests.get(schedule_url)
34
+ soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
35
+
36
+ # Find all lecture blocks
37
+ lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
38
+
39
+ # Create a mapping from slides link to date
40
+ date_mapping = {}
41
+ schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
42
+ for row in schedule_rows:
43
+ try:
44
+ date = (
45
+ row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
46
+ )
47
+ description_div = row.find("div", {"data-label": "Description"})
48
+ slides_link_tag = description_div.find("a", title="Download slides")
49
+ slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
50
+ slides_link = (
51
+ f"https://dl4ds.github.io{slides_link}" if slides_link else None
52
+ )
53
+ if slides_link:
54
+ date_mapping[slides_link] = date
55
+ except Exception as e:
56
+ print(f"Error processing schedule row: {e}")
57
+ continue
58
+
59
+ for block in lecture_blocks:
60
+ try:
61
+ # Extract the lecture title
62
+ title = block.find("span", style="font-weight: bold;").text.strip()
63
+
64
+ # Extract the TL;DR
65
+ tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
66
+
67
+ # Extract the link to the slides
68
+ slides_link_tag = block.find("a", title="Download slides")
69
+ slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
70
+ slides_link = (
71
+ f"https://dl4ds.github.io{slides_link}" if slides_link else None
72
+ )
73
+
74
+ # Extract the link to the lecture recording
75
+ recording_link_tag = block.find("a", title="Download lecture recording")
76
+ recording_link = (
77
+ recording_link_tag["href"].strip() if recording_link_tag else None
78
+ )
79
+
80
+ # Extract suggested readings or summary if available
81
+ suggested_readings_tag = block.find("p", text="Suggested Readings:")
82
+ if suggested_readings_tag:
83
+ suggested_readings = suggested_readings_tag.find_next_sibling("ul")
84
+ if suggested_readings:
85
+ suggested_readings = suggested_readings.get_text(
86
+ separator="\n"
87
+ ).strip()
88
+ else:
89
+ suggested_readings = "No specific readings provided."
90
+ else:
91
+ suggested_readings = "No specific readings provided."
92
+
93
+ # Get the date from the schedule
94
+ date = date_mapping.get(slides_link, "No date available")
95
+
96
+ # Add to the dictionary
97
+ lecture_metadata[slides_link] = {
98
+ "date": date,
99
+ "tldr": tldr,
100
+ "title": title,
101
+ "lecture_recording": recording_link,
102
+ "suggested_readings": suggested_readings,
103
+ }
104
+ except Exception as e:
105
+ print(f"Error processing block: {e}")
106
+ continue
107
+
108
+ return lecture_metadata
code/modules/dataloader/webpage_crawler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import aiohttp
2
+ from aiohttp import ClientSession
3
+ import asyncio
4
+ import requests
5
+ from bs4 import BeautifulSoup
6
+ from urllib.parse import urlparse, urljoin, urldefrag
7
+
8
+ class WebpageCrawler:
9
+ def __init__(self):
10
+ self.dict_href_links = {}
11
+
12
+ async def fetch(self, session: ClientSession, url: str) -> str:
13
+ async with session.get(url) as response:
14
+ try:
15
+ return await response.text()
16
+ except UnicodeDecodeError:
17
+ return await response.text(encoding="latin1")
18
+
19
+ def url_exists(self, url: str) -> bool:
20
+ try:
21
+ response = requests.head(url)
22
+ return response.status_code == 200
23
+ except requests.ConnectionError:
24
+ return False
25
+
26
+ async def get_links(self, session: ClientSession, website_link: str, base_url: str):
27
+ html_data = await self.fetch(session, website_link)
28
+ soup = BeautifulSoup(html_data, "html.parser")
29
+ list_links = []
30
+ for link in soup.find_all("a", href=True):
31
+ href = link["href"].strip()
32
+ full_url = urljoin(base_url, href)
33
+ normalized_url = self.normalize_url(full_url) # sections removed
34
+ if (
35
+ normalized_url not in self.dict_href_links
36
+ and self.is_child_url(normalized_url, base_url)
37
+ and self.url_exists(normalized_url)
38
+ ):
39
+ self.dict_href_links[normalized_url] = None
40
+ list_links.append(normalized_url)
41
+
42
+ return list_links
43
+
44
+ async def get_subpage_links(
45
+ self, session: ClientSession, urls: list, base_url: str
46
+ ):
47
+ tasks = [self.get_links(session, url, base_url) for url in urls]
48
+ results = await asyncio.gather(*tasks)
49
+ all_links = [link for sublist in results for link in sublist]
50
+ return all_links
51
+
52
+ async def get_all_pages(self, url: str, base_url: str):
53
+ async with aiohttp.ClientSession() as session:
54
+ dict_links = {url: "Not-checked"}
55
+ counter = None
56
+ while counter != 0:
57
+ unchecked_links = [
58
+ link
59
+ for link, status in dict_links.items()
60
+ if status == "Not-checked"
61
+ ]
62
+ if not unchecked_links:
63
+ break
64
+ new_links = await self.get_subpage_links(
65
+ session, unchecked_links, base_url
66
+ )
67
+ for link in unchecked_links:
68
+ dict_links[link] = "Checked"
69
+ print(f"Checked: {link}")
70
+ dict_links.update(
71
+ {
72
+ link: "Not-checked"
73
+ for link in new_links
74
+ if link not in dict_links
75
+ }
76
+ )
77
+ counter = len(
78
+ [
79
+ status
80
+ for status in dict_links.values()
81
+ if status == "Not-checked"
82
+ ]
83
+ )
84
+
85
+ checked_urls = [
86
+ url for url, status in dict_links.items() if status == "Checked"
87
+ ]
88
+ return checked_urls
89
+
90
+ def is_webpage(self, url: str) -> bool:
91
+ try:
92
+ response = requests.head(url, allow_redirects=True)
93
+ content_type = response.headers.get("Content-Type", "").lower()
94
+ return "text/html" in content_type
95
+ except requests.RequestException:
96
+ return False
97
+
98
+ def clean_url_list(self, urls):
99
+ files, webpages = [], []
100
+
101
+ for url in urls:
102
+ if self.is_webpage(url):
103
+ webpages.append(url)
104
+ else:
105
+ files.append(url)
106
+
107
+ return files, webpages
108
+
109
+ def is_child_url(self, url, base_url):
110
+ return url.startswith(base_url)
111
+
112
+ def normalize_url(self, url: str):
113
+ # Strip the fragment identifier
114
+ defragged_url, _ = urldefrag(url)
115
+ return defragged_url
code/modules/retriever/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .faiss_retriever import FaissRetriever
2
+ from .chroma_retriever import ChromaRetriever
code/modules/retriever/base.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ class BaseRetriever:
2
+ def __init__(self, config):
3
+ self.config = config
4
+
5
+ def return_retriever(self):
6
+ raise NotImplementedError
code/modules/retriever/chroma_retriever.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .helpers import VectorStoreRetrieverScore
2
+ from .base import BaseRetriever
3
+
4
+
5
+ class ChromaRetriever(BaseRetriever):
6
+ def __init__(self):
7
+ pass
8
+
9
+ def return_retriever(self, db, config):
10
+ retriever = VectorStoreRetrieverScore(
11
+ vectorstore=db,
12
+ # search_type="similarity_score_threshold",
13
+ # search_kwargs={
14
+ # "score_threshold": self.config["vectorstore"][
15
+ # "score_threshold"
16
+ # ],
17
+ # "k": self.config["vectorstore"]["search_top_k"],
18
+ # },
19
+ search_kwargs={
20
+ "k": config["vectorstore"]["search_top_k"],
21
+ },
22
+ )
23
+
24
+ return retriever
code/modules/retriever/faiss_retriever.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .helpers import VectorStoreRetrieverScore
2
+ from .base import BaseRetriever
3
+
4
+
5
+ class FaissRetriever(BaseRetriever):
6
+ def __init__(self):
7
+ pass
8
+
9
+ def return_retriever(self, db, config):
10
+ retriever = VectorStoreRetrieverScore(
11
+ vectorstore=db,
12
+ # search_type="similarity_score_threshold",
13
+ # search_kwargs={
14
+ # "score_threshold": self.config["vectorstore"][
15
+ # "score_threshold"
16
+ # ],
17
+ # "k": self.config["vectorstore"]["search_top_k"],
18
+ # },
19
+ search_kwargs={
20
+ "k": config["vectorstore"]["search_top_k"],
21
+ },
22
+ )
23
+ return retriever
code/modules/retriever/helpers.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.schema.vectorstore import VectorStoreRetriever
2
+ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
3
+ from langchain.schema.document import Document
4
+ from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
5
+ from typing import List
6
+
7
+
8
+ class VectorStoreRetrieverScore(VectorStoreRetriever):
9
+
10
+ # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
11
+ def _get_relevant_documents(
12
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
13
+ ) -> List[Document]:
14
+ docs_and_similarities = (
15
+ self.vectorstore.similarity_search_with_relevance_scores(
16
+ query, **self.search_kwargs
17
+ )
18
+ )
19
+ # Make the score part of the document metadata
20
+ for doc, similarity in docs_and_similarities:
21
+ doc.metadata["score"] = similarity
22
+
23
+ docs = [doc for doc, _ in docs_and_similarities]
24
+ return docs
25
+
26
+ async def _aget_relevant_documents(
27
+ self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
28
+ ) -> List[Document]:
29
+ docs_and_similarities = (
30
+ self.vectorstore.similarity_search_with_relevance_scores(
31
+ query, **self.search_kwargs
32
+ )
33
+ )
34
+ # Make the score part of the document metadata
35
+ for doc, similarity in docs_and_similarities:
36
+ doc.metadata["score"] = similarity
37
+
38
+ docs = [doc for doc, _ in docs_and_similarities]
39
+ return docs
code/modules/vectorstore/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .base import VectorStoreBase
2
+ from .faiss import FAISS
code/modules/vectorstore/base.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class VectorStoreBase:
2
+ def __init__(self, config):
3
+ self.config = config
4
+
5
+ def _init_vector_db(self):
6
+ raise NotImplementedError
7
+
8
+ def create_database(self, database_name):
9
+ raise NotImplementedError
10
+
11
+ def load_database(self, database_name):
12
+ raise NotImplementedError
13
+
14
+ def as_retriever(self):
15
+ raise NotImplementedError
16
+
17
+ def __str__(self):
18
+ return self.__class__.__name__
code/modules/vectorstore/chroma.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import Chroma
2
+ from modules.vectorstore.base import VectorStoreBase
3
+ import os
4
+
5
+
6
+ class ChromaVectorStore(VectorStoreBase):
7
+ def __init__(self, config):
8
+ self.config = config
9
+ self._init_vector_db()
10
+
11
+ def _init_vector_db(self):
12
+ self.chroma = Chroma()
13
+
14
+ def create_database(self, document_chunks, embedding_model):
15
+ self.vectorstore = self.chroma.from_documents(
16
+ documents=document_chunks,
17
+ embedding=embedding_model,
18
+ persist_directory=os.path.join(
19
+ self.config["vectorstore"]["db_path"],
20
+ "db_"
21
+ + self.config["vectorstore"]["db_option"]
22
+ + "_"
23
+ + self.config["vectorstore"]["model"],
24
+ ),
25
+ )
26
+
27
+ def load_database(self, embedding_model):
28
+ self.vectorstore = Chroma(
29
+ persist_directory=os.path.join(
30
+ self.config["vectorstore"]["db_path"],
31
+ "db_"
32
+ + self.config["vectorstore"]["db_option"]
33
+ + "_"
34
+ + self.config["vectorstore"]["model"],
35
+ ),
36
+ embedding_function=embedding_model,
37
+ )
38
+ return self.vectorstore
39
+
40
+ def as_retriever(self):
41
+ return self.vectorstore.as_retriever()
code/modules/{embedding_model_loader.py β†’ vectorstore/embedding_model_loader.py} RENAMED
@@ -2,10 +2,7 @@ from langchain_community.embeddings import OpenAIEmbeddings
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.embeddings import LlamaCppEmbeddings
4
 
5
- try:
6
- from modules.constants import *
7
- except:
8
- from constants import *
9
  import os
10
 
11
 
@@ -14,19 +11,19 @@ class EmbeddingModelLoader:
14
  self.config = config
15
 
16
  def load_embedding_model(self):
17
- if self.config["embedding_options"]["model"] in ["text-embedding-ada-002"]:
18
  embedding_model = OpenAIEmbeddings(
19
  deployment="SL-document_embedder",
20
- model=self.config["embedding_options"]["model"],
21
  show_progress_bar=True,
22
  openai_api_key=OPENAI_API_KEY,
23
  disallowed_special=(),
24
  )
25
  else:
26
  embedding_model = HuggingFaceEmbeddings(
27
- model_name=self.config["embedding_options"]["model"],
28
  model_kwargs={
29
- "device": "cpu",
30
  "token": f"{HUGGINGFACE_TOKEN}",
31
  "trust_remote_code": True,
32
  },
 
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.embeddings import LlamaCppEmbeddings
4
 
5
+ from modules.config.constants import *
 
 
 
6
  import os
7
 
8
 
 
11
  self.config = config
12
 
13
  def load_embedding_model(self):
14
+ if self.config["vectorstore"]["model"] in ["text-embedding-ada-002"]:
15
  embedding_model = OpenAIEmbeddings(
16
  deployment="SL-document_embedder",
17
+ model=self.config["vectorestore"]["model"],
18
  show_progress_bar=True,
19
  openai_api_key=OPENAI_API_KEY,
20
  disallowed_special=(),
21
  )
22
  else:
23
  embedding_model = HuggingFaceEmbeddings(
24
+ model_name=self.config["vectorstore"]["model"],
25
  model_kwargs={
26
+ "device": f"{self.config['device']}",
27
  "token": f"{HUGGINGFACE_TOKEN}",
28
  "trust_remote_code": True,
29
  },
code/modules/vectorstore/faiss.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import FAISS
2
+ from modules.vectorstore.base import VectorStoreBase
3
+ import os
4
+
5
+
6
+ class FaissVectorStore(VectorStoreBase):
7
+ def __init__(self, config):
8
+ self.config = config
9
+ self._init_vector_db()
10
+
11
+ def _init_vector_db(self):
12
+ self.faiss = FAISS(
13
+ embedding_function=None, index=0, index_to_docstore_id={}, docstore={}
14
+ )
15
+
16
+ def create_database(self, document_chunks, embedding_model):
17
+ self.vectorstore = self.faiss.from_documents(
18
+ documents=document_chunks, embedding=embedding_model
19
+ )
20
+ self.vectorstore.save_local(
21
+ os.path.join(
22
+ self.config["vectorstore"]["db_path"],
23
+ "db_"
24
+ + self.config["vectorstore"]["db_option"]
25
+ + "_"
26
+ + self.config["vectorstore"]["model"],
27
+ )
28
+ )
29
+
30
+ def load_database(self, embedding_model):
31
+ self.vectorstore = self.faiss.load_local(
32
+ os.path.join(
33
+ self.config["vectorstore"]["db_path"],
34
+ "db_"
35
+ + self.config["vectorstore"]["db_option"]
36
+ + "_"
37
+ + self.config["vectorstore"]["model"],
38
+ ),
39
+ embedding_model,
40
+ allow_dangerous_deserialization=True,
41
+ )
42
+ return self.vectorstore
43
+
44
+ def as_retriever(self):
45
+ return self.vectorstore.as_retriever()
code/modules/vectorstore/helpers.py ADDED
File without changes
code/modules/{vector_db.py β†’ vectorstore/store_manager.py} RENAMED
@@ -1,72 +1,27 @@
 
 
 
 
 
 
 
1
  import logging
2
  import os
3
- import yaml
4
- from langchain_community.vectorstores import FAISS, Chroma
5
- from langchain.schema.vectorstore import VectorStoreRetriever
6
- from langchain.callbacks.manager import CallbackManagerForRetrieverRun
7
- from langchain.schema.document import Document
8
- from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
9
- from ragatouille import RAGPretrainedModel
10
-
11
- try:
12
- from modules.embedding_model_loader import EmbeddingModelLoader
13
- from modules.data_loader import DataLoader
14
- from modules.constants import *
15
- from modules.helpers import *
16
- except:
17
- from embedding_model_loader import EmbeddingModelLoader
18
- from data_loader import DataLoader
19
- from constants import *
20
- from helpers import *
21
-
22
- from typing import List
23
-
24
-
25
- class VectorDBScore(VectorStoreRetriever):
26
-
27
- # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
28
- def _get_relevant_documents(
29
- self, query: str, *, run_manager: CallbackManagerForRetrieverRun
30
- ) -> List[Document]:
31
- docs_and_similarities = (
32
- self.vectorstore.similarity_search_with_relevance_scores(
33
- query, **self.search_kwargs
34
- )
35
- )
36
- # Make the score part of the document metadata
37
- for doc, similarity in docs_and_similarities:
38
- doc.metadata["score"] = similarity
39
-
40
- docs = [doc for doc, _ in docs_and_similarities]
41
- return docs
42
-
43
- async def _aget_relevant_documents(
44
- self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
45
- ) -> List[Document]:
46
- docs_and_similarities = (
47
- self.vectorstore.similarity_search_with_relevance_scores(
48
- query, **self.search_kwargs
49
- )
50
- )
51
- # Make the score part of the document metadata
52
- for doc, similarity in docs_and_similarities:
53
- doc.metadata["score"] = similarity
54
-
55
- docs = [doc for doc, _ in docs_and_similarities]
56
- return docs
57
 
58
 
59
- class VectorDB:
60
  def __init__(self, config, logger=None):
61
  self.config = config
62
- self.db_option = config["embedding_options"]["db_option"]
63
  self.document_names = None
64
- self.webpage_crawler = WebpageCrawler()
65
 
66
  # Set up logging to both console and a file
67
  if logger is None:
68
  self.logger = logging.getLogger(__name__)
69
  self.logger.setLevel(logging.INFO)
 
70
 
71
  # Console Handler
72
  console_handler = logging.StreamHandler()
@@ -75,8 +30,13 @@ class VectorDB:
75
  console_handler.setFormatter(formatter)
76
  self.logger.addHandler(console_handler)
77
 
 
 
 
 
 
78
  # File Handler
79
- log_file_path = "vector_db.log" # Change this to your desired log file path
80
  file_handler = logging.FileHandler(log_file_path, mode="w")
81
  file_handler.setLevel(logging.INFO)
82
  file_handler.setFormatter(formatter)
@@ -84,16 +44,18 @@ class VectorDB:
84
  else:
85
  self.logger = logger
86
 
 
 
87
  self.logger.info("VectorDB instance instantiated")
88
 
89
  def load_files(self):
90
- files = os.listdir(self.config["embedding_options"]["data_path"])
91
  files = [
92
- os.path.join(self.config["embedding_options"]["data_path"], file)
93
  for file in files
94
  ]
95
- urls = get_urls_from_file(self.config["embedding_options"]["url_file_path"])
96
- if self.config["embedding_options"]["expand_urls"]:
97
  all_urls = []
98
  for url in urls:
99
  loop = asyncio.get_event_loop()
@@ -109,8 +71,9 @@ class VectorDB:
109
 
110
  def create_embedding_model(self):
111
  self.logger.info("Creating embedding function")
112
- self.embedding_model_loader = EmbeddingModelLoader(self.config)
113
- self.embedding_model = self.embedding_model_loader.load_embedding_model()
 
114
 
115
  def initialize_database(
116
  self,
@@ -120,107 +83,153 @@ class VectorDB:
120
  document_metadata: list,
121
  ):
122
  if self.db_option in ["FAISS", "Chroma"]:
123
- self.create_embedding_model()
124
- # Track token usage
125
  self.logger.info("Initializing vector_db")
126
  self.logger.info("\tUsing {} as db_option".format(self.db_option))
127
  if self.db_option == "FAISS":
128
- self.vector_db = FAISS.from_documents(
129
- documents=document_chunks, embedding=self.embedding_model
130
- )
131
  elif self.db_option == "Chroma":
132
- self.vector_db = Chroma.from_documents(
133
- documents=document_chunks,
134
- embedding=self.embedding_model,
135
- persist_directory=os.path.join(
136
- self.config["embedding_options"]["db_path"],
137
- "db_"
138
- + self.config["embedding_options"]["db_option"]
139
- + "_"
140
- + self.config["embedding_options"]["model"],
141
- ),
142
- )
143
  elif self.db_option == "RAGatouille":
144
  self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
145
- index_path = self.RAG.index(
146
- index_name="new_idx",
147
- collection=documents,
148
- document_ids=document_names,
149
- document_metadatas=document_metadata,
150
- )
151
- self.logger.info("Completed initializing vector_db")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  def create_database(self):
154
- data_loader = DataLoader(self.config)
 
155
  self.logger.info("Loading data")
156
  files, urls = self.load_files()
157
  files, webpages = self.webpage_crawler.clean_url_list(urls)
158
- if "storage/data/urls.txt" in files:
159
- files.remove("storage/data/urls.txt")
 
 
160
  document_chunks, document_names, documents, document_metadata = (
161
  data_loader.get_chunks(files, webpages)
162
  )
 
 
 
 
163
  self.logger.info("Completed loading data")
164
  self.initialize_database(
165
  document_chunks, document_names, documents, document_metadata
166
  )
 
 
 
 
 
167
 
168
- def save_database(self):
169
- if self.db_option == "FAISS":
170
- self.vector_db.save_local(
171
- os.path.join(
172
- self.config["embedding_options"]["db_path"],
173
- "db_"
174
- + self.config["embedding_options"]["db_option"]
175
- + "_"
176
- + self.config["embedding_options"]["model"],
177
- )
178
- )
179
- elif self.db_option == "Chroma":
180
- # db is saved in the persist directory during initialization
181
- pass
182
- elif self.db_option == "RAGatouille":
183
- # index is saved during initialization
184
- pass
185
- self.logger.info("Saved database")
 
 
 
 
 
186
 
187
  def load_database(self):
188
- self.create_embedding_model()
 
 
189
  if self.db_option == "FAISS":
190
- self.vector_db = FAISS.load_local(
191
- os.path.join(
192
- self.config["embedding_options"]["db_path"],
193
- "db_"
194
- + self.config["embedding_options"]["db_option"]
195
- + "_"
196
- + self.config["embedding_options"]["model"],
197
- ),
198
- self.embedding_model,
199
- allow_dangerous_deserialization=True,
200
- )
201
  elif self.db_option == "Chroma":
202
- self.vector_db = Chroma(
203
- persist_directory=os.path.join(
204
- self.config["embedding_options"]["db_path"],
205
- "db_"
206
- + self.config["embedding_options"]["db_option"]
207
- + "_"
208
- + self.config["embedding_options"]["model"],
209
- ),
210
- embedding_function=self.embedding_model,
211
- )
212
- elif self.db_option == "RAGatouille":
213
- self.vector_db = RAGPretrainedModel.from_index(
214
- ".ragatouille/colbert/indexes/new_idx"
215
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  self.logger.info("Loaded database")
217
- return self.vector_db
218
 
219
 
220
  if __name__ == "__main__":
221
- with open("code/config.yml", "r") as f:
 
 
222
  config = yaml.safe_load(f)
223
  print(config)
224
- vector_db = VectorDB(config)
 
225
  vector_db.create_database()
226
- vector_db.save_database()
 
 
 
 
 
 
 
 
1
+ from modules.vectorstore.faiss import FaissVectorStore
2
+ from modules.vectorstore.chroma import ChromaVectorStore
3
+ from modules.vectorstore.helpers import *
4
+ from modules.dataloader.webpage_crawler import WebpageCrawler
5
+ from modules.dataloader.data_loader import DataLoader
6
+ from modules.dataloader.helpers import *
7
+ from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
8
  import logging
9
  import os
10
+ import time
11
+ import asyncio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
+ class VectorStoreManager:
15
  def __init__(self, config, logger=None):
16
  self.config = config
17
+ self.db_option = config["vectorstore"]["db_option"]
18
  self.document_names = None
 
19
 
20
  # Set up logging to both console and a file
21
  if logger is None:
22
  self.logger = logging.getLogger(__name__)
23
  self.logger.setLevel(logging.INFO)
24
+ self.logger.propagate = False
25
 
26
  # Console Handler
27
  console_handler = logging.StreamHandler()
 
30
  console_handler.setFormatter(formatter)
31
  self.logger.addHandler(console_handler)
32
 
33
+ # Ensure log directory exists
34
+ log_directory = self.config["log_dir"]
35
+ if not os.path.exists(log_directory):
36
+ os.makedirs(log_directory)
37
+
38
  # File Handler
39
+ log_file_path = f"{log_directory}/vector_db.log" # Change this to your desired log file path
40
  file_handler = logging.FileHandler(log_file_path, mode="w")
41
  file_handler.setLevel(logging.INFO)
42
  file_handler.setFormatter(formatter)
 
44
  else:
45
  self.logger = logger
46
 
47
+ self.webpage_crawler = WebpageCrawler()
48
+
49
  self.logger.info("VectorDB instance instantiated")
50
 
51
  def load_files(self):
52
+ files = os.listdir(self.config["vectorstore"]["data_path"])
53
  files = [
54
+ os.path.join(self.config["vectorstore"]["data_path"], file)
55
  for file in files
56
  ]
57
+ urls = get_urls_from_file(self.config["vectorstore"]["url_file_path"])
58
+ if self.config["vectorstore"]["expand_urls"]:
59
  all_urls = []
60
  for url in urls:
61
  loop = asyncio.get_event_loop()
 
71
 
72
  def create_embedding_model(self):
73
  self.logger.info("Creating embedding function")
74
+ embedding_model_loader = EmbeddingModelLoader(self.config)
75
+ embedding_model = embedding_model_loader.load_embedding_model()
76
+ return embedding_model
77
 
78
  def initialize_database(
79
  self,
 
83
  document_metadata: list,
84
  ):
85
  if self.db_option in ["FAISS", "Chroma"]:
86
+ self.embedding_model = self.create_embedding_model()
87
+
88
  self.logger.info("Initializing vector_db")
89
  self.logger.info("\tUsing {} as db_option".format(self.db_option))
90
  if self.db_option == "FAISS":
91
+ self.vector_db = FaissVectorStore(self.config)
92
+ self.vector_db.create_database(document_chunks, self.embedding_model)
 
93
  elif self.db_option == "Chroma":
94
+ self.vector_db = ChromaVectorStore(self.config)
95
+ self.vector_db.create_database(document_chunks, self.embedding_model)
 
 
 
 
 
 
 
 
 
96
  elif self.db_option == "RAGatouille":
97
  self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
98
+ # index_path = self.RAG.index(
99
+ # index_name="new_idx",
100
+ # collection=documents,
101
+ # document_ids=document_names,
102
+ # document_metadatas=document_metadata,
103
+ # )
104
+ batch_size = 32
105
+ for i in range(0, len(documents), batch_size):
106
+ if i == 0:
107
+ self.RAG.index(
108
+ index_name="new_idx",
109
+ collection=documents[i : i + batch_size],
110
+ document_ids=document_names[i : i + batch_size],
111
+ document_metadatas=document_metadata[i : i + batch_size],
112
+ )
113
+ else:
114
+ self.RAG = RAGPretrainedModel.from_index(
115
+ ".ragatouille/colbert/indexes/new_idx"
116
+ )
117
+ self.RAG.add_to_index(
118
+ new_collection=documents[i : i + batch_size],
119
+ new_document_ids=document_names[i : i + batch_size],
120
+ new_document_metadatas=document_metadata[i : i + batch_size],
121
+ )
122
 
123
  def create_database(self):
124
+ start_time = time.time() # Start time for creating database
125
+ data_loader = DataLoader(self.config, self.logger)
126
  self.logger.info("Loading data")
127
  files, urls = self.load_files()
128
  files, webpages = self.webpage_crawler.clean_url_list(urls)
129
+ self.logger.info(f"Number of files: {len(files)}")
130
+ self.logger.info(f"Number of webpages: {len(webpages)}")
131
+ if f"{self.config['vectorstore']['url_file_path']}" in files:
132
+ files.remove(f"{self.config['vectorstores']['url_file_path']}") # cleanup
133
  document_chunks, document_names, documents, document_metadata = (
134
  data_loader.get_chunks(files, webpages)
135
  )
136
+ num_documents = len(document_chunks)
137
+ self.logger.info(f"Number of documents in the DB: {num_documents}")
138
+ metadata_keys = list(document_metadata[0].keys())
139
+ self.logger.info(f"Metadata keys: {metadata_keys}")
140
  self.logger.info("Completed loading data")
141
  self.initialize_database(
142
  document_chunks, document_names, documents, document_metadata
143
  )
144
+ end_time = time.time() # End time for creating database
145
+ self.logger.info("Created database")
146
+ self.logger.info(
147
+ f"Time taken to create database: {end_time - start_time} seconds"
148
+ )
149
 
150
+ # def save_database(self):
151
+ # start_time = time.time() # Start time for saving database
152
+ # if self.db_option == "FAISS":
153
+ # self.vector_db.save_local(
154
+ # os.path.join(
155
+ # self.config["vectorstore"]["db_path"],
156
+ # "db_"
157
+ # + self.config["vectorstore"]["db_option"]
158
+ # + "_"
159
+ # + self.config["vectorstore"]["model"],
160
+ # )
161
+ # )
162
+ # elif self.db_option == "Chroma":
163
+ # # db is saved in the persist directory during initialization
164
+ # pass
165
+ # elif self.db_option == "RAGatouille":
166
+ # # index is saved during initialization
167
+ # pass
168
+ # self.logger.info("Saved database")
169
+ # end_time = time.time() # End time for saving database
170
+ # self.logger.info(
171
+ # f"Time taken to save database: {end_time - start_time} seconds"
172
+ # )
173
 
174
  def load_database(self):
175
+ start_time = time.time() # Start time for loading database
176
+ if self.db_option in ["FAISS", "Chroma"]:
177
+ self.embedding_model = self.create_embedding_model()
178
  if self.db_option == "FAISS":
179
+ self.vector_db = FaissVectorStore(self.config)
180
+ self.loaded_vector_db = self.vector_db.load_database(self.embedding_model)
 
 
 
 
 
 
 
 
 
181
  elif self.db_option == "Chroma":
182
+ self.vector_db = ChromaVectorStore(self.config)
183
+ self.loaded_vector_db = self.vector_db.load_database(self.embedding_model)
184
+ # if self.db_option == "FAISS":
185
+ # self.vector_db = FAISS.load_local(
186
+ # os.path.join(
187
+ # self.config["vectorstore"]["db_path"],
188
+ # "db_"
189
+ # + self.config["vectorstore"]["db_option"]
190
+ # + "_"
191
+ # + self.config["vectorstore"]["model"],
192
+ # ),
193
+ # self.embedding_model,
194
+ # allow_dangerous_deserialization=True,
195
+ # )
196
+ # elif self.db_option == "Chroma":
197
+ # self.vector_db = Chroma(
198
+ # persist_directory=os.path.join(
199
+ # self.config["embedding_options"]["db_path"],
200
+ # "db_"
201
+ # + self.config["embedding_options"]["db_option"]
202
+ # + "_"
203
+ # + self.config["embedding_options"]["model"],
204
+ # ),
205
+ # embedding_function=self.embedding_model,
206
+ # )
207
+ # elif self.db_option == "RAGatouille":
208
+ # self.vector_db = RAGPretrainedModel.from_index(
209
+ # ".ragatouille/colbert/indexes/new_idx"
210
+ # )
211
+ end_time = time.time() # End time for loading database
212
+ self.logger.info(
213
+ f"Time taken to load database: {end_time - start_time} seconds"
214
+ )
215
  self.logger.info("Loaded database")
216
+ return self.loaded_vector_db
217
 
218
 
219
  if __name__ == "__main__":
220
+ import yaml
221
+
222
+ with open("modules/config/config.yml", "r") as f:
223
  config = yaml.safe_load(f)
224
  print(config)
225
+ print(f"Trying to create database with config: {config}")
226
+ vector_db = VectorStoreManager(config)
227
  vector_db.create_database()
228
+ print("Created database")
229
+
230
+ print(f"Trying to load the database")
231
+ vector_db = VectorStoreManager(config)
232
+ vector_db.load_database()
233
+ print("Loaded database")
234
+
235
+ print(f"View the logs at {config['log_dir']}/vector_db.log")