Spaces: Build error
optimize code
app.py CHANGED
@@ -2,36 +2,25 @@ from typing import List, Optional, Tuple
 from queue import Empty, Queue
 from threading import Thread
 from bot.web_scrapping.crawler_and_indexer import content_crawler_and_index
-from bot.web_scrapping.searchable_index import SearchableIndex
 from bot.utils.callbacks import QueueCallback
 from bot.utils.constanst import set_api_key
 from bot.utils.show_log import logger
+from bot.web_scrapping.default import *
 from langchain.chat_models import ChatOpenAI
 from langchain.prompts import HumanMessagePromptTemplate
 from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
 import gradio as gr
 
 set_api_key()
-MODELS_NAMES = ["gpt-3.5-turbo"]
-DEFAULT_TEMPERATURE = 0.7
-
-ChatHistory = List[str]
-
-default_system_prompt = 'Put your prompt here'
-default_system_format = 'txt'
 human_message_prompt_template = HumanMessagePromptTemplate.from_template("{text}")
 
 
-def …
-def bot_learning(urls, file_formats, chat_mode=False):
-    index = content_crawler_and_index(url=str(urls), file_format=file_formats)
+def bot_learning(urls, file_formats, llm, prompt, chat_mode=False):
+    index = content_crawler_and_index(url=str(urls), llm=llm, prompt=prompt, file_format=file_formats)
     if chat_mode:
         return index
     else:
-        fb = …
+        fb = 'Training Completed'
         return fb
 
 
@@ -54,10 +43,10 @@ def chat_start(
     job_done = object()
     messages.append(HumanMessage(content=f':{message}'))
     chatbot_messages.append((message, ""))
-    index = bot_learning(urls='NO_URL', file_formats='txt', chat_mode=True)
+    index = bot_learning(urls='NO_URL', file_formats='txt', llm=chat, prompt=message, chat_mode=True)
 
     def query_retrieval():
-        response = …
+        response = index.query()
         chatbot_message = AIMessage(content=response)
         messages.append(chatbot_message)
         queue.put(job_done)
 
 
@@ -105,90 +94,96 @@ def on_apply_settings_button_click(
     return chat, *on_clear_button_click(system_prompt)
 
 
-…
-        chatbot = gr.Chatbot()
-        with gr.Column():
-            message = gr.Textbox(label="Type some message")
-            message.submit(
-                chat_start,
-                [chat, message, chatbot, messages],
-                [chat, message, chatbot, messages],
-                queue=True,
-            )
-            message_button = gr.Button("Submit", variant="primary")
-            message_button.click(
-                chat_start,
-                [chat, message, chatbot, messages],
-                [chat, message, chatbot, messages],
-            )
-            url = gr.Textbox(label="URL to Documents")
-            file_format = gr.Textbox(label="Set your file format:", placeholder='Example: pdf, txt')
-            url.submit(
-                bot_learning,
-                [url, file_format],
-                [learning_status]
-            )
-            training_button = gr.Button("Training", variant="primary")
-            training_button.click(
-                bot_learning,
-                [url, file_format],
-                [learning_status]
-            )
-        with gr.Row():
-            with gr.Column():
-…
-            system_prompt_button = gr.Button("Set")
-            system_prompt_button.click(
-                on_apply_settings_button_click,
-                [system_prompt, model_name, temperature],
-                [chat, message, chatbot, messages],
-            )
-
-demo…
+def main():
+    with gr.Blocks() as demo:
+        system_prompt = gr.State(default_system_prompt)
+        messages = gr.State([SystemMessage(content=default_system_prompt)])
+        chat = gr.State(None)
+
+        with gr.Column(elem_id="col_container"):
+            gr.Markdown("# Welcome to OWN-GPT! 🤖")
+            gr.Markdown(
+                "Demo Chat Bot Platform"
+            )
+
+            chatbot = gr.Chatbot()
+            with gr.Column():
+                message = gr.Textbox(label="Type some message")
+                message.submit(
+                    chat_start,
+                    [chat, message, chatbot, messages],
+                    [chat, message, chatbot, messages],
+                    queue=True,
+                )
+                message_button = gr.Button("Submit", variant="primary")
+                message_button.click(
+                    chat_start,
+                    [chat, message, chatbot, messages],
+                    [chat, message, chatbot, messages],
+                )
+            with gr.Column():
+                learning_status = gr.Textbox(label='Training Status')
+                url = gr.Textbox(label="URL to Documents")
+                file_format = gr.Textbox(label="Set your file format:", placeholder='Example: pdf, txt')
+                url.submit(
+                    bot_learning,
+                    [url, file_format, chat, message],
+                    [learning_status]
+                )
+                training_button = gr.Button("Training", variant="primary")
+                training_button.click(
+                    bot_learning,
+                    [url, file_format, chat, message],
+                    [learning_status]
+                )
+            with gr.Row():
+                with gr.Column():
+                    clear_button = gr.Button("Clear")
+                    clear_button.click(
+                        on_clear_button_click,
+                        [system_prompt],
+                        [message, chatbot, messages],
+                        queue=False,
+                    )
+                    with gr.Accordion("Settings", open=False):
+                        model_name = gr.Dropdown(
+                            choices=MODELS_NAMES, value=MODELS_NAMES[0], label="model"
+                        )
+                        temperature = gr.Slider(
+                            minimum=0.0,
+                            maximum=1.0,
+                            value=0.7,
+                            step=0.1,
+                            label="temperature",
+                            interactive=True,
+                        )
+                        apply_settings_button = gr.Button("Apply")
+                        apply_settings_button.click(
+                            on_apply_settings_button_click,
+                            [system_prompt, model_name, temperature],
+                            [chat, message, chatbot, messages],
+                        )
+
+                with gr.Column():
+                    system_prompt_area = gr.TextArea(
+                        default_system_prompt, lines=4, label="prompt", interactive=True
+                    )
+                    system_prompt_area.input(
+                        system_prompt_handler,
+                        inputs=[system_prompt_area],
+                        outputs=[system_prompt],
+                    )
+                    system_prompt_button = gr.Button("Set")
+                    system_prompt_button.click(
+                        on_apply_settings_button_click,
+                        [system_prompt, model_name, temperature],
+                        [chat, message, chatbot, messages],
+                    )
+
+    return demo
+
+
+if __name__ == '__main__':
+    demo = main()
+    demo.queue()
+    demo.launch()
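The reworked chat_start (second hunk above) keeps the original file's streaming pattern: a worker thread pushes tokens onto a Queue while the handler drains it, and a sentinel job_done object marks the end of the stream. Note that "response = index.query()" only works because embed_index now returns a Query object with the question and LLM already bound (see searchable_index.py below). A minimal, self-contained sketch of the queue-and-sentinel pattern, with illustrative names that are not taken from the repo:

from queue import Empty, Queue
from threading import Thread


def stream_response(produce_tokens):
    # Yield progressively longer partial answers, the shape Gradio
    # expects from a generator-based chat handler.
    q = Queue()
    job_done = object()  # unique sentinel; nothing else is identical to it

    def worker():
        for token in produce_tokens():
            q.put(token)
        q.put(job_done)  # signal completion

    Thread(target=worker).start()
    answer = ""
    while True:
        try:
            item = q.get(timeout=1)
        except Empty:
            continue  # producer still working; poll again
        if item is job_done:
            break
        answer += item
        yield answer


if __name__ == "__main__":
    for partial in stream_response(lambda: iter(["Hel", "lo", "!"])):
        print(partial)  # prints Hel, then Hello, then Hello!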
bot/web_scrapping/crawler_and_indexer.py CHANGED
@@ -7,66 +7,69 @@ import pandas as pd
 import requests
 import os
 
-set_api_key(api_key='sk-…
+set_api_key(api_key='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')
 
 
 def save_content_to_file(url=None, text=None, output_folder=None, file_format=None):
     file_path = os.path.join(output_folder, f"combined_content.{file_format}")
-…
-        logger.info(f"Content appended to {file_path}")
-    elif file_format == 'xml':
-        xml_content = ''.join([f'<item>{t.text}</item>' for t in text])
-        with open(f"{file_path}", "a", encoding="utf-8") as file:
-            file.write(xml_content)
+
+    write_functions = {
+        'txt': lambda: write_text(file_path, text),
+        'pdf': lambda: write_pdf(url, file_path),
+        'csv': lambda: write_csv(file_path, text),
+        'xml': lambda: write_xml(file_path, text)
+    }
+
+    write_function = write_functions.get(file_format)
+    if write_function:
+        write_function()
         logger.info(f"Content appended to {file_path}")
     else:
        logger.warning("Invalid file format. Supported formats: txt, pdf, csv, xml")
+
     return file_path
 
 
-def …
-    responses = requests.get(url)
-    if responses.status_code == 200:
-        # Create output folder if it doesn't exist
-        if not os.path.exists(output_folder):
-            os.makedirs(output_folder)
-        # Parse the HTML content using BeautifulSoup
-        soup = BeautifulSoup(responses.text, "html.parser")
-        text = soup.find_all(['h2', 'p', 'i', 'ul'])
-        if text:
-            # Save content based on the specified file format
-            file_path = save_content_to_file(text=text, output_folder=output_folder, file_format=file_format)
-
-            # Create or update the index
-            index = SearchableIndex.embed_index(url, file_path)
-            if os.path.isfile(file_path):
-                os.remove(file_path)
-            return index
-        else:
-            file_path = save_content_to_file(url=url, output_folder=output_folder, file_format=file_format)
-            index = SearchableIndex.embed_index(url, file_path)
-            if os.path.isfile(file_path):
-                os.remove(file_path)
-            return index
-    else:
-        logger.warning("Failed to retrieve content from the URL.")
+def write_text(file_path, text):
+    with open(file_path, "a", encoding="utf-8") as file:
+        for t in text:
+            file.write(f'{t.text}\n')
+
+
+def write_pdf(url, file_path):
+    request.urlretrieve(url, file_path)
+
+
+def write_csv(file_path, text):
+    df = pd.DataFrame({'Content': [t.text for t in text]})
+    df.to_csv(file_path, mode='a', index=False, header=False)
+
+
+def write_xml(file_path, text):
+    xml_content = ''.join([f'<item>{t.text}</item>' for t in text])
+    with open(file_path, "a", encoding="utf-8") as file:
+        file.write(xml_content)
+
+
+def content_crawler_and_index(url, llm, prompt, file_format='txt', output_folder='learning_documents'):
+    if url == 'NO_URL':
+        file_path = output_folder
+    else:
+        responses = requests.get(url)
+        if responses.status_code != 200:
+            logger.warning("Failed to retrieve content from the URL.")
+            return None
+        if not os.path.exists(output_folder):
+            os.makedirs(output_folder)
+        soup = BeautifulSoup(responses.text, "html.parser")
+        text = soup.find_all(['h2', 'p', 'i', 'ul'])
+        file_path = save_content_to_file(text=text, url=url, output_folder=output_folder, file_format=file_format)
+
+    index = SearchableIndex.embed_index(url=url, path=file_path, llm=llm, prompt=prompt)
+    if url != 'NO_URL' and os.path.isfile(file_path):
+        os.remove(file_path)
+
+    return index
 
 
 if __name__ == '__main__':
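The save_content_to_file rewrite replaces the old if/elif chain with a dispatch table: each format string maps to a zero-argument callable, so supporting a new format is one dict entry instead of another branch. One caveat worth flagging: write_pdf calls request.urlretrieve, which needs "from urllib import request" in scope; the requests package imported above is a different module. A standalone sketch of the dispatch pattern, with dummy writers that only build strings:

def render(file_format, payload):
    # Dispatch table: the lambdas are built eagerly but only the
    # selected one is ever called, matching the if/elif behaviour.
    writers = {
        'txt': lambda: payload,
        'csv': lambda: f'"{payload}"',
        'xml': lambda: f"<item>{payload}</item>",
    }
    writer = writers.get(file_format)
    if writer is None:
        raise ValueError(f"Unsupported format: {file_format!r}")
    return writer()


print(render('xml', 'hello'))  # -> <item>hello</item>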
bot/web_scrapping/default.py ADDED
@@ -0,0 +1,6 @@
+from typing import List
+MODELS_NAMES = ["gpt-3.5-turbo"]
+DEFAULT_TEMPERATURE = 0.7
+ChatHistory = List[str]
+default_system_prompt = 'Put your prompt here'
+default_system_format = 'txt'
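app.py consumes this module via a star import (from bot.web_scrapping.default import *), which currently re-exports List as well. If the star import stays, an explicit __all__ would keep the exported surface deliberate; this is a sketch, not part of the commit:

from typing import List

MODELS_NAMES = ["gpt-3.5-turbo"]
DEFAULT_TEMPERATURE = 0.7
ChatHistory = List[str]
default_system_prompt = 'Put your prompt here'
default_system_format = 'txt'

# Restrict what "from bot.web_scrapping.default import *" pulls in.
__all__ = [
    "MODELS_NAMES",
    "DEFAULT_TEMPERATURE",
    "ChatHistory",
    "default_system_prompt",
    "default_system_format",
]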
bot/web_scrapping/searchable_index.py CHANGED
@@ -16,125 +16,98 @@ import os
 import queue
 
 
+class Query:
+    def __init__(self, question, llm, index):
+        self.question = question
+        self.llm = llm
+        self.index = index
+
+    def query(self):
+        """Query the vectorstore."""
+        llm = self.llm or ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
+        chain = RetrievalQA.from_chain_type(
+            llm, retriever=self.index.as_retriever()
+        )
+        return chain.run(self.question)
+
+
 class SearchableIndex:
     def __init__(self, path):
         self.path = path
 
-…
+    @classmethod
+    def get_splits(cls, path, target_col=None, sheet_name=None):
+        extension = os.path.splitext(path)[1].lower()
+        doc_list = None
+        if extension == ".txt":
+            with open(path, 'r') as txt:
+                data = txt.read()
+            text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
+                                                        chunk_overlap=0,
+                                                        length_function=len)
+            doc_list = text_split.split_text(data)
+        elif extension == ".pdf":
+            loader = PyPDFLoader(path)
+            pages = loader.load_and_split()
+            text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
+                                                        chunk_overlap=0,
+                                                        length_function=len)
+            doc_list = []
+            for pg in pages:
+                pg_splits = text_split.split_text(pg.page_content)
+                doc_list.extend(pg_splits)
+        elif extension == ".xml":
+            df = pd.read_excel(io=path, engine='openpyxl', sheet_name=sheet_name)
+            df_loader = DataFrameLoader(df, page_content_column=target_col)
+            doc_list = df_loader.load()
+        elif extension == ".csv":
+            csv_loader = CSVLoader(path)
+            doc_list = csv_loader.load()
+        if doc_list is None:
+            raise ValueError("Unsupported file format")
         return doc_list
 
-    def get_xml_splits(self, target_col, sheet_name):
-        df = pd.read_excel(io=self.path,
-                           engine='openpyxl',
-                           sheet_name=sheet_name)
-
-        df_loader = DataFrameLoader(df,
-                                    page_content_column=target_col)
-
-        excel_docs = df_loader.load()
-
-        return excel_docs
-
-    def get_csv_splits(self):
-        csv_loader = CSVLoader(self.path)
-        csv_docs = csv_loader.load()
-        return csv_docs
-
     @classmethod
     def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
         if os.path.exists(index_store):
             local_db = FAISS.load_local(index_store, embeddings)
             local_db.merge_from(faiss_db)
-            logger.info("Merge index completed")
             local_db.save_local(index_store)
+            logger.info("Merge index completed")
         else:
             faiss_db.save_local(folder_path=index_store)
             logger.info("New store created and loaded...")
             local_db = FAISS.load_local(index_store, embeddings)
+        return local_db
 
     @classmethod
-    def check_and_load_index(cls, index_files, embeddings, logger, …
+    def check_and_load_index(cls, index_files, embeddings, logger, result_queue):
         if index_files:
             local_db = FAISS.load_local(index_files[0], embeddings)
-            file_to_remove = os.path.join(path, 'combined_content.txt')
-            if os.path.exists(file_to_remove):
-                os.remove(file_to_remove)
         else:
             raise logger.warning("Index store does not exist")
         result_queue.put(local_db)  # Put the result in the queue
 
     @classmethod
-    def embed_index(cls, url, path, target_col=None, sheet_name=None):
+    def embed_index(cls, url, path, llm, prompt, target_col=None, sheet_name=None):
         embeddings = OpenAIEmbeddings()
 
-        def process_docs(queues, extension):
-            nonlocal doc_list
-            instance = cls(path)
-            if extension == ".txt":
-                doc_list = instance.get_text_splits()
-            elif extension == ".pdf":
-                doc_list = instance.get_pdf_splits()
-            elif extension == ".xml":
-                doc_list = instance.get_xml_splits(target_col, sheet_name)
-            elif extension == ".csv":
-                doc_list = instance.get_csv_splits()
-            else:
-                doc_list = None
-            queues.put(doc_list)
-
         if url != 'NO_URL' and path:
-            data_queue = queue.Queue()
-            thread = threading.Thread(target=process_docs, args=(data_queue, file_extension))
-            thread.start()
-            doc_list = data_queue.get()
-            if not doc_list:
-                raise ValueError("Unsupported file format")
-
+            doc_list = cls.get_splits(path, target_col, sheet_name)
             faiss_db = FAISS.from_texts(doc_list, embeddings)
             index_store = os.path.splitext(path)[0] + "_index"
             local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
-            return …
+            return Query(prompt, llm, local_db)
         elif url == 'NO_URL' and path:
             index_files = glob.glob(os.path.join(path, '*_index'))
 
             result_queue = queue.Queue()  # Create a queue to store the result
 
             thread = threading.Thread(target=cls.check_and_load_index,
-                                      args=(index_files, embeddings, logger, …
+                                      args=(index_files, embeddings, logger, result_queue))
             thread.start()
             local_db = result_queue.get()  # Retrieve the result from the queue
-            return local_db
-
-    @classmethod
-    def query(cls, question: str, llm, index):
-        """Query the vectorstore."""
-        llm = llm or ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
-        chain = RetrievalQA.from_chain_type(
-            llm, retriever=index.as_retriever()
-        )
-        return chain.run(question)
+            return Query(prompt, llm, local_db)
 
 
 if __name__ == '__main__':
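In embed_index, the NO_URL branch loads the FAISS store on a threading.Thread and hands the result back through a queue.Queue, since a thread's target cannot return a value directly. Because result_queue.get() blocks immediately after thread.start(), this is equivalent to a plain function call; the indirection only pays off if other work happens between start and get. Two smaller issues survive the change: "raise logger.warning(...)" raises None (a TypeError) rather than an exception, and on that error path nothing is ever put on the queue, so the caller would block forever. A FAISS-free sketch of the handoff itself, with illustrative names:

import queue
import threading


def load_store(result_queue):
    store = {"kind": "stand-in for FAISS.load_local"}
    result_queue.put(store)  # hand the result back to the caller


result_q = queue.Queue()
worker = threading.Thread(target=load_store, args=(result_q,))
worker.start()
# ...work placed here would overlap with the load...
store = result_q.get()  # blocks until the worker publishes its result
worker.join()
print(store)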
learning_documents/combined_content_index/index.faiss ADDED
Binary file (651 kB).
learning_documents/combined_content_index/index.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7246a7c708997772e992539fa7cef62d0e33a4a77a03f6483be6a108106a7c1c
+size 100825
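The three lines added to index.pkl are a Git LFS pointer rather than the pickle itself: the repository keeps only the version/oid/size stub and LFS serves the actual 100,825-byte object, while index.faiss appears above only as a 651 kB binary blob.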