elyxlz committed
Commit 0c7add2 · Parent(s): be8dc83
initial commit

Files changed:
- .gitattributes +1 -33
- .gitignore +2 -0
- app.py +104 -0
- config/conf_0.1.yaml +52 -0
- data/store.pkl +3 -0
- ingest.py +78 -0
- modules.py +75 -0
- requirements.txt +7 -0
- tools.py +76 -0
.gitattributes
CHANGED
@@ -1,34 +1,2 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pkl* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
+.env
+__pycache__
app.py
ADDED
@@ -0,0 +1,104 @@
+import os
+from typing import Optional, Tuple
+import gradio as gr
+import argparse
+import datetime
+import pickle
+#import whisper
+import dotenv
+import sys
+from io import StringIO
+import re
+dotenv.load_dotenv()
+
+from langchain.callbacks import get_openai_callback
+
+import hydra
+from omegaconf import DictConfig, open_dict, OmegaConf
+
+
+class ChatbotAgentGradio():
+    def __init__(
+        self,
+        config_name,
+    ):
+        config = OmegaConf.load(f'./config/{config_name}.yaml')
+        self.chatbot = hydra.utils.instantiate(config.model, _convert_="partial")
+
+    def chat(self,
+             inp: str,
+             history: Optional[Tuple[str, str]],
+             ):
+        """Method for integration with gradio Chatbot"""
+        print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
+        print("inp: " + inp)
+        history = history or []
+
+        output = self.chatbot.run(inp)
+
+        history.append((inp, output))
+
+        return history, history  #, ""
+
+    def update_foo(self, widget, state):
+        if widget:
+            state = widget
+            return state
+
+    def launch_app(self):
+        block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
+
+        with block:
+            instance = gr.State()
+            show_chain_state = gr.State(False)
+
+            with gr.Row():
+                gr.Markdown("<h3><center>UNHCR</center></h3>")
+
+            with gr.Row():
+                chatbot = gr.Chatbot()
+
+            with gr.Row():
+                message = gr.Textbox(
+                    label="What's your question?",
+                    lines=1,
+                )
+                submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
+
+            state = gr.State()
+            agent_state = gr.State()
+
+            submit.click(self.chat, inputs=[message, state], outputs=[chatbot, state])
+            message.submit(self.chat, inputs=[message, state], outputs=[chatbot, state])
+
+        block.launch(debug=True, share=False, server_port=7861)  #, server_name='192.168.0.73', )
+
+
+def simple(config):
+    config = OmegaConf.load(f'./config/{config}.yaml')
+    chatbot = hydra.utils.instantiate(config.model, _convert_="partial")
+    while True:
+        inp = input("\nUser: ")
+        print(chatbot.run(inp))
+
+
+if __name__ == '__main__':
+    #simple('conf_0.1')
+    app = ChatbotAgentGradio('conf_0.1')
+    app.launch_app()
+
+    #QA = QA(store, k=1)
+    #app = QAGradio(QA)
+    #app.launch_app()
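Note on the Gradio wiring: the trailing "#, """ comment on chat()'s return suggests a variant that also clears the textbox after each send. A minimal sketch of that variant, relying on Gradio's positional mapping of return values to outputs (this wiring is not part of the commit):

    # in chat():  return history, history, ""   # third value resets the textbox
    submit.click(self.chat, inputs=[message, state], outputs=[chatbot, state, message])
    message.submit(self.chat, inputs=[message, state], outputs=[chatbot, state, message])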
config/conf_0.1.yaml
ADDED
@@ -0,0 +1,52 @@
+
+# chatbot model
+model:
+  _target_: modules.initialize_agent
+
+  agent: "conversational-react-description" # langchain template for agent
+
+  tools:
+    - _target_: langchain.agents.Tool
+      name: "Content Search"
+      func:
+        _target_: tools.SemanticSearch
+        threshold: 0.5
+        k: 5
+      description: >
+        A content search through the UNHCR documents, it will return relevant extracts for your query.
+        The action input should be a full English sentence.
+        ALWAYS use this to answer ANY question. If the tool doesn't return anything, say that you don't know.
+
+  llm:
+    _target_: langchain.llms.OpenAI
+    temperature: 0
+    openai_api_key: ${oc.env:OPENAI_API_KEY} # environment variable
+
+  memory:
+    _target_: langchain.chains.conversation.memory.ConversationBufferWindowMemory
+    memory_key: "chat_history"
+    k: 5 # how many of the past interactions it keeps
+    #verbose: True
+
+  prefix: |
+    - You are an AI whose purpose is to help answer questions about the UNHCR documents.
+    - You answer in a factual manner, always basing your answer on the context provided to you.
+    - You are free to ignore irrelevant information.
+    - If you do not know something, you will say that you don't know.
+    - Give long answers, answering every question with a lot of detail.

+    TOOLS:
+    ------
+    You have access to the following tools:
+
+  suffix: |
+    Begin!
+    Previous conversation history:
+    {chat_history}
+    New input: {input}
+    {agent_scratchpad}
+
+  ai_prefix: "AI"
+
+  verbose: True
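For orientation: hydra.utils.instantiate builds each nested _target_ object first, then calls modules.initialize_agent with them. A rough plain-Python equivalent of this config (a sketch; the prompt strings are abbreviated, and extra keys like memory and verbose pass through **kwargs to AgentExecutor):

    from langchain.agents import Tool
    from langchain.llms import OpenAI
    from langchain.chains.conversation.memory import ConversationBufferWindowMemory
    from modules import initialize_agent
    from tools import SemanticSearch

    agent = initialize_agent(
        tools=[Tool(
            name="Content Search",
            func=SemanticSearch(threshold=0.5, k=5),
            description="A content search through the UNHCR documents...",  # abbreviated
        )],
        llm=OpenAI(temperature=0),
        agent="conversational-react-description",
        prefix="...",  # abbreviated
        suffix="...",  # abbreviated
        ai_prefix="AI",
        memory=ConversationBufferWindowMemory(memory_key="chat_history", k=5),
        verbose=True,
    )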
data/store.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cccec9eb3ff0488f652e5db4ec9f263979c25deaecf773b4413c108f5493fb0e
+size 6998194
ingest.py
ADDED
@@ -0,0 +1,78 @@
+from pathlib import Path
+import faiss
+import pickle
+from PyPDF2 import PdfReader
+from tqdm import tqdm
+import glob
+import os
+import re
+
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.vectorstores import FAISS
+from langchain.document_loaders import TextLoader
+
+
+import dotenv
+
+dotenv.load_dotenv()
+
+def get_all_pdf_filenames(paths, recursive):
+    extensions = ["pdf"]
+    filenames = []
+    for ext_name in extensions:
+        ext = f"**/*.{ext_name}" if recursive else f"*.{ext_name}"
+        for path in paths:
+            filenames.extend(glob.glob(os.path.join(path, ext), recursive=recursive))
+    return filenames
+
+
+#all_pdf_paths = get_all_pdf_filenames(["/mnt/c/users/elio/Downloads/UNHCR Emergency Manual"], recursive=True)
+#print(f"Found {len(all_pdf_paths)} PDF files")
+#assert len(all_pdf_paths) > 0
+#all_pdf_paths = ['/mnt/c/users/elio/Downloads/UNHCR Emergency Manual/UNHCR Emergency Manual/46a9e29a2.pdf']
+
+class Ingester():
+    """
+    Vectorises chunks of the data and puts source as metadata
+    """
+    def __init__(
+        self,
+        separator='\n',
+        chunk_overlap=200,
+        chunk_size=200,
+    ):
+        self.splitter = CharacterTextSplitter(chunk_size=chunk_size, separator=separator, chunk_overlap=chunk_overlap)
+
+    def ingest(self, path):
+        #ps = get_all_pdf_filenames([path], recursive=True) # get paths
+        ps = ['/mnt/c/users/elio/Downloads/UNHCR Emergency Manual/UNHCR Emergency Manual/46a9e29a2.pdf']
+        data = []
+        sources = []
+        for p in tqdm(ps): # extract data from paths
+            reader = PdfReader(p)
+            page = '\n'.join([reader.pages[i].extract_text() for i in range(len(reader.pages))])
+            data.append(page)
+            sources.append(p)
+
+        docs = []
+        metadatas = []
+        for i, d in tqdm(enumerate(data)): # split text and make documents
+            splits = self.splitter.split_text(d)
+            if all(s != "" for s in splits):
+                docs.extend(splits)
+                metadatas.extend([{"source": sources[i]}] * len(splits))
+
+        assert len(docs) > 0
+
+        print("Extracting embeddings")
+        store = FAISS.from_texts(docs, OpenAIEmbeddings(), metadatas=metadatas)
+
+        with open(os.path.join('./data', 'store.pkl'), "wb") as f:
+            pickle.dump(store, f)
+
+        print(f"Saved store at {os.path.join('./data', 'store.pkl')}.")
+
+ingester = Ingester(chunk_size=2000)
+ingester.ingest("/mnt/c/users/elio/Downloads/UNHCR Emergency Manual")
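A quick round-trip check of the saved store, mirroring how tools.py loads it back (a sketch, assuming OPENAI_API_KEY is set in the environment; the query string is made up):

    import pickle

    with open('./data/store.pkl', 'rb') as f:
        store = pickle.load(f)

    for doc, score in store.similarity_search_with_score("camp site planning", k=3):
        print(round(score, 3), doc.metadata['source'], doc.page_content[:80])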
modules.py
ADDED
@@ -0,0 +1,75 @@
+
+"""Load agent."""
+from typing import Any, List, Optional
+
+from langchain.agents.agent import AgentExecutor
+from langchain.agents.loading import AGENT_TO_CLASS, load_agent
+from langchain.agents.tools import Tool
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.llms.base import BaseLLM
+
+
+def initialize_agent(
+    tools: List[Tool],
+    llm: BaseLLM,
+    agent: Optional[str] = None,
+    callback_manager: Optional[BaseCallbackManager] = None,
+    agent_path: Optional[str] = None,
+    prefix: Optional[str] = None,
+    suffix: Optional[str] = None,
+    ai_prefix: Optional[str] = None,
+    human_prefix: Optional[str] = None,
+    **kwargs: Any,
+) -> AgentExecutor:
+    """Load agent given tools and LLM.
+
+    Args:
+        tools: List of tools this agent has access to.
+        llm: Language model to use as the agent.
+        agent: The agent to use. Valid options are:
+            `zero-shot-react-description`
+            `react-docstore`
+            `self-ask-with-search`
+            `conversational-react-description`
+            If None and agent_path is also None, will default to
+            `zero-shot-react-description`.
+        callback_manager: CallbackManager to use. Global callback manager is used if
+            not provided. Defaults to None.
+        agent_path: Path to serialized agent to use.
+        **kwargs: Additional keyword arguments to pass to the agent.
+
+    Returns:
+        An agent.
+    """
+    if agent is None and agent_path is None:
+        agent = "zero-shot-react-description"
+    if agent is not None and agent_path is not None:
+        raise ValueError(
+            "Both `agent` and `agent_path` are specified, "
+            "but at most only one should be."
+        )
+    if agent is not None:
+        if agent not in AGENT_TO_CLASS:
+            raise ValueError(
+                f"Got unknown agent type: {agent}. "
+                f"Valid types are: {AGENT_TO_CLASS.keys()}."
+            )
+        agent_cls = AGENT_TO_CLASS[agent]
+        agent_obj = agent_cls.from_llm_and_tools(
+            llm, tools, prefix=prefix, suffix=suffix, ai_prefix=ai_prefix, human_prefix=human_prefix, callback_manager=callback_manager  # added prefix and suffix
+        )
+    elif agent_path is not None:
+        agent_obj = load_agent(
+            agent_path, llm=llm, tools=tools, callback_manager=callback_manager
+        )
+    else:
+        raise ValueError(
+            "Somehow both `agent` and `agent_path` are None, "
+            "this should never happen."
+        )
+    return AgentExecutor.from_agent_and_tools(
+        agent=agent_obj,
+        tools=tools,
+        callback_manager=callback_manager,
+        **kwargs,
+    )
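Context for this file: it reads as a vendored copy of langchain's initialize_agent, with the inline "# added prefix and suffix" comment marking the local change, namely forwarding the prompt-customisation arguments. A sketch of the delta against upstream's call (approximate, from the API of that era, not upstream's exact code):

    # upstream, approximately:
    #   agent_obj = agent_cls.from_llm_and_tools(llm, tools, callback_manager=callback_manager)
    # this copy additionally forwards:
    #   prefix=prefix, suffix=suffix, ai_prefix=ai_prefix, human_prefix=human_prefix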
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+faiss-cpu
+langchain
+openai
+numpy
+gradio
+PyPDF2
+python-dotenv
tools.py
ADDED
@@ -0,0 +1,76 @@
+import faiss
+import pickle
+import os
+from PyPDF2 import PdfReader
+import glob
+from pathlib import Path
+import re
+import requests
+
+from langchain.chains import LLMChain
+from langchain.llms import OpenAI
+from langchain import PromptTemplate
+
+from langchain.vectorstores import FAISS
+from langchain.embeddings import OpenAIEmbeddings
+
+import dotenv
+dotenv.load_dotenv()
+
+
+def call_semantic_api(query, store_path, k):
+    payload = {
+        "query": query,
+        "store_path": store_path,
+        "k": k,
+    }
+
+    # response = requests.post("http://localhost:3001/search", json=payload)
+    response = semantic_search.search(payload)  # NOTE: `semantic_search` is not defined in this module; this helper is unused and would raise a NameError if called
+    return response
+
+
+class SemanticSearch():
+    def __init__(
+        self,
+        threshold: float,
+        with_source=False,
+        k=5,
+    ):
+        self.threshold = threshold  # stored but not applied to the scores below
+        self.with_source = with_source
+        self.k = k
+
+        with open('./data/store.pkl', 'rb') as f:
+            self.db = pickle.load(f)
+
+    def __call__(self, query):
+        documents = self.db.similarity_search_with_score(query, k=self.k)
+        if len(documents) == 0:
+            return None
+
+        if not self.with_source:
+            output = '\n\n\n'.join([i[0].page_content for i in documents])
+        else:
+            output = '\n\n\n'.join([i[0].page_content + '\n\nSource:' + os.path.basename(
+                str(i[0].metadata['source']) + '\n') for i in documents])
+
+        return output
+
+
+class ContentSearch():
+    def __init__(
+        self,
+        semantic_search,
+        prompt_template,
+    ):
+        self.semantic_search = semantic_search
+        self.prompt_template = prompt_template
+
+    def __call__(self, query):
+        content = self.semantic_search(query)
+        if content is None:
+            return "No results found"
+        else:
+            return self.prompt_template.format(content=content)
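ContentSearch isn't referenced by conf_0.1.yaml (the config wires SemanticSearch directly as the tool's func), but a minimal usage sketch, with a hypothetical template and query:

    from langchain import PromptTemplate

    template = PromptTemplate(
        input_variables=["content"],
        template="Answer using only this context:\n{content}",  # hypothetical
    )
    search = SemanticSearch(threshold=0.5, with_source=True, k=3)
    content_search = ContentSearch(search, template)
    print(content_search("water supply standards in emergencies"))  # hypothetical query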