Spaces:
Runtime error
Runtime error
compartmentalize buster config
Browse files- buster/busterbot.py +48 -54
- buster/completers/__init__.py +2 -2
- buster/completers/base.py +1 -1
- buster/examples/cfg.py +13 -10
- buster/formatters/prompts.py +8 -0
- tests/test_chatbot.py +3 -1
buster/busterbot.py
CHANGED
@@ -6,9 +6,10 @@ import numpy as np
|
|
6 |
import pandas as pd
|
7 |
from openai.embeddings_utils import cosine_similarity, get_embedding
|
8 |
|
9 |
-
from buster.completers import
|
10 |
from buster.completers.base import Completion
|
11 |
-
from buster.formatters.prompts import SystemPromptFormatter
|
|
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
logging.basicConfig(level=logging.INFO)
|
@@ -23,36 +24,30 @@ class Response:
|
|
23 |
|
24 |
@dataclass
|
25 |
class BusterConfig:
|
26 |
-
"""Configuration object for a chatbot.
|
27 |
-
|
28 |
-
documents_csv: Path to the csv file containing the documents and their embeddings.
|
29 |
-
embedding_model: OpenAI model to use to get embeddings.
|
30 |
-
top_k: Max number of documents to retrieve, ordered by cosine similarity
|
31 |
-
thresh: threshold for cosine similarity to be considered
|
32 |
-
max_words: maximum number of words the retrieved documents can be. Will truncate otherwise.
|
33 |
-
completion_kwargs: kwargs for the OpenAI.Completion() method
|
34 |
-
separator: the separator to use, can be either "\n" or <p> depending on rendering.
|
35 |
-
response_format: the type of format to render links with, e.g. slack or markdown
|
36 |
-
unknown_prompt: Prompt to use to generate the "I don't know" embedding to compare to.
|
37 |
-
text_before_prompt: Text to prompt GPT with before the user prompt, but after the documentation.
|
38 |
-
reponse_footnote: Generic response to add the the chatbot's reply.
|
39 |
-
source: the source of the document to consider
|
40 |
-
"""
|
41 |
-
|
42 |
-
documents_file: str = ""
|
43 |
embedding_model: str = "text-embedding-ada-002"
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
completer_cfg: dict = field(
|
49 |
-
# TODO: Put all this in its own config with sane defaults?
|
50 |
default_factory=lambda: {
|
51 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
"text_before_documents": "You are a chatbot answering questions.\n",
|
53 |
"text_before_prompt": "Answer the following question:\n",
|
|
|
|
|
|
|
|
|
|
|
54 |
"completion_kwargs": {
|
55 |
-
"engine": "
|
56 |
"max_tokens": 200,
|
57 |
"temperature": None,
|
58 |
"top_p": None,
|
@@ -61,18 +56,11 @@ class BusterConfig:
|
|
61 |
},
|
62 |
}
|
63 |
)
|
64 |
-
unknown_prompt: str = "I Don't know how to answer your question."
|
65 |
-
response_format: str = "slack"
|
66 |
-
source: str = ""
|
67 |
-
|
68 |
-
|
69 |
-
from buster.retriever import Retriever
|
70 |
|
71 |
|
72 |
class Buster:
|
73 |
def __init__(self, cfg: BusterConfig, retriever: Retriever):
|
74 |
self._unk_embedding = None
|
75 |
-
self.cfg = cfg
|
76 |
self.update_cfg(cfg)
|
77 |
|
78 |
self.retriever = retriever
|
@@ -89,16 +77,23 @@ class Buster:
|
|
89 |
|
90 |
def update_cfg(self, cfg: BusterConfig):
|
91 |
"""Every time we set a new config, we update the things that need to be updated."""
|
92 |
-
logger.info(f"Updating config to {cfg.
|
93 |
-
self.
|
94 |
-
self.
|
95 |
-
self.
|
96 |
-
|
97 |
-
self.
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
logger.info(f"Config Updated.")
|
104 |
|
@@ -129,9 +124,8 @@ class Buster:
|
|
129 |
logger.info(f"matched documents before thresh: {matched_documents}")
|
130 |
|
131 |
# filter out matched_documents using a threshold
|
132 |
-
|
133 |
-
|
134 |
-
logger.info(f"matched documents after thresh: {matched_documents}")
|
135 |
|
136 |
return matched_documents
|
137 |
|
@@ -168,10 +162,10 @@ class Buster:
|
|
168 |
|
169 |
matched_documents = self.rank_documents(
|
170 |
query=user_input,
|
171 |
-
top_k=self.
|
172 |
-
thresh=self.
|
173 |
-
engine=self.
|
174 |
-
source=self.
|
175 |
)
|
176 |
|
177 |
if len(matched_documents) == 0:
|
@@ -189,15 +183,15 @@ class Buster:
|
|
189 |
# check for relevance
|
190 |
is_relevant = self.check_response_relevance(
|
191 |
completion_text=completion.text,
|
192 |
-
engine=self.
|
193 |
unk_embedding=self.unk_embedding,
|
194 |
-
unk_threshold=self.
|
195 |
)
|
196 |
if not is_relevant:
|
197 |
matched_documents = pd.DataFrame(columns=matched_documents.columns)
|
198 |
# answer generated was the chatbot saying it doesn't know how to answer
|
199 |
# uncomment override completion with unknown prompt
|
200 |
-
# completion = Completion(text=self.
|
201 |
|
202 |
response = Response(completion=completion, matched_documents=matched_documents, is_relevant=is_relevant)
|
203 |
return response
|
|
|
6 |
import pandas as pd
|
7 |
from openai.embeddings_utils import cosine_similarity, get_embedding
|
8 |
|
9 |
+
from buster.completers import completer_factory
|
10 |
from buster.completers.base import Completion
|
11 |
+
from buster.formatters.prompts import SystemPromptFormatter, prompt_formatter_factory
|
12 |
+
from buster.retriever import Retriever
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
logging.basicConfig(level=logging.INFO)
|
|
|
24 |
|
25 |
@dataclass
|
26 |
class BusterConfig:
|
27 |
+
"""Configuration object for a chatbot."""
|
28 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
embedding_model: str = "text-embedding-ada-002"
|
30 |
+
unknown_threshold: float = 0.9
|
31 |
+
unknown_prompt: str = "I Don't know how to answer your question."
|
32 |
+
document_source: str = ""
|
33 |
+
retriever_cfg: dict = field(
|
|
|
|
|
34 |
default_factory=lambda: {
|
35 |
+
"top_k": 3,
|
36 |
+
"thresh": 0.7,
|
37 |
+
}
|
38 |
+
)
|
39 |
+
prompt_cfg: dict = field(
|
40 |
+
default_factory=lambda: {
|
41 |
+
"max_words": 3000,
|
42 |
"text_before_documents": "You are a chatbot answering questions.\n",
|
43 |
"text_before_prompt": "Answer the following question:\n",
|
44 |
+
}
|
45 |
+
)
|
46 |
+
completion_cfg: dict = field(
|
47 |
+
default_factory=lambda: {
|
48 |
+
"name": "ChatGPT",
|
49 |
"completion_kwargs": {
|
50 |
+
"engine": "gpt-3.5-turbo",
|
51 |
"max_tokens": 200,
|
52 |
"temperature": None,
|
53 |
"top_p": None,
|
|
|
56 |
},
|
57 |
}
|
58 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
|
61 |
class Buster:
|
62 |
def __init__(self, cfg: BusterConfig, retriever: Retriever):
|
63 |
self._unk_embedding = None
|
|
|
64 |
self.update_cfg(cfg)
|
65 |
|
66 |
self.retriever = retriever
|
|
|
77 |
|
78 |
def update_cfg(self, cfg: BusterConfig):
|
79 |
"""Every time we set a new config, we update the things that need to be updated."""
|
80 |
+
logger.info(f"Updating config to {cfg.document_source}:\n{cfg}")
|
81 |
+
self._cfg = cfg
|
82 |
+
self.embedding_model = cfg.embedding_model
|
83 |
+
self.unknown_threshold = cfg.unknown_threshold
|
84 |
+
self.unknown_prompt = cfg.unknown_prompt
|
85 |
+
self.document_source = cfg.document_source
|
86 |
+
|
87 |
+
self.retriever_cfg = cfg.retriever_cfg
|
88 |
+
self.completion_cfg = cfg.completion_cfg
|
89 |
+
self.prompt_cfg = cfg.prompt_cfg
|
90 |
+
|
91 |
+
# set the unk. embedding
|
92 |
+
self.unk_embedding = self.get_embedding(self.unknown_prompt, engine=self.embedding_model)
|
93 |
+
|
94 |
+
# update completer and formatter cfg
|
95 |
+
self.completer = completer_factory(self.completion_cfg)
|
96 |
+
self.prompt_formatter = prompt_formatter_factory(self.prompt_cfg)
|
97 |
|
98 |
logger.info(f"Config Updated.")
|
99 |
|
|
|
124 |
logger.info(f"matched documents before thresh: {matched_documents}")
|
125 |
|
126 |
# filter out matched_documents using a threshold
|
127 |
+
matched_documents = matched_documents[matched_documents.similarity > thresh]
|
128 |
+
logger.info(f"matched documents after thresh: {matched_documents}")
|
|
|
129 |
|
130 |
return matched_documents
|
131 |
|
|
|
162 |
|
163 |
matched_documents = self.rank_documents(
|
164 |
query=user_input,
|
165 |
+
top_k=self.retriever_cfg["top_k"],
|
166 |
+
thresh=self.retriever_cfg["thresh"],
|
167 |
+
engine=self.embedding_model,
|
168 |
+
source=self.document_source,
|
169 |
)
|
170 |
|
171 |
if len(matched_documents) == 0:
|
|
|
183 |
# check for relevance
|
184 |
is_relevant = self.check_response_relevance(
|
185 |
completion_text=completion.text,
|
186 |
+
engine=self.embedding_model,
|
187 |
unk_embedding=self.unk_embedding,
|
188 |
+
unk_threshold=self.unknown_threshold,
|
189 |
)
|
190 |
if not is_relevant:
|
191 |
matched_documents = pd.DataFrame(columns=matched_documents.columns)
|
192 |
# answer generated was the chatbot saying it doesn't know how to answer
|
193 |
# uncomment override completion with unknown prompt
|
194 |
+
# completion = Completion(text=self.unknown_prompt)
|
195 |
|
196 |
response = Response(completion=completion, matched_documents=matched_documents, is_relevant=is_relevant)
|
197 |
return response
|
buster/completers/__init__.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
from .base import ChatGPTCompleter, GPT3Completer,
|
2 |
|
3 |
__all__ = [
|
4 |
-
|
5 |
GPT3Completer,
|
6 |
ChatGPTCompleter,
|
7 |
]
|
|
|
1 |
+
from .base import ChatGPTCompleter, GPT3Completer, completer_factory
|
2 |
|
3 |
__all__ = [
|
4 |
+
completer_factory,
|
5 |
GPT3Completer,
|
6 |
ChatGPTCompleter,
|
7 |
]
|
buster/completers/base.py
CHANGED
@@ -91,7 +91,7 @@ class ChatGPTCompleter(Completer):
|
|
91 |
return response["choices"][0]["message"]["content"]
|
92 |
|
93 |
|
94 |
-
def
|
95 |
name = completer_cfg["name"]
|
96 |
completers = {
|
97 |
"GPT3": GPT3Completer,
|
|
|
91 |
return response["choices"][0]["message"]["content"]
|
92 |
|
93 |
|
94 |
+
def completer_factory(completer_cfg):
|
95 |
name = completer_cfg["name"]
|
96 |
completers = {
|
97 |
"GPT3": GPT3Completer,
|
buster/examples/cfg.py
CHANGED
@@ -2,13 +2,20 @@ from buster.busterbot import BusterConfig
|
|
2 |
|
3 |
documents_filepath = "./documents.db"
|
4 |
buster_cfg = BusterConfig(
|
5 |
-
unknown_prompt="I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
|
6 |
embedding_model="text-embedding-ada-002",
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
11 |
"name": "ChatGPT",
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
"text_before_documents": (
|
13 |
"You are a chatbot assistant answering technical questions about artificial intelligence (AI)."
|
14 |
"You can only respond to a question if the content necessary to answer the question is contained in the following provided documentation. "
|
@@ -34,10 +41,6 @@ buster_cfg = BusterConfig(
|
|
34 |
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?"
|
35 |
"Now answer the following question:\n"
|
36 |
),
|
37 |
-
"completion_kwargs": {
|
38 |
-
"model": "gpt-3.5-turbo",
|
39 |
-
},
|
40 |
},
|
41 |
-
|
42 |
-
source="stackoverflow",
|
43 |
)
|
|
|
2 |
|
3 |
documents_filepath = "./documents.db"
|
4 |
buster_cfg = BusterConfig(
|
|
|
5 |
embedding_model="text-embedding-ada-002",
|
6 |
+
unknown_prompt="I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
|
7 |
+
retriever_cfg={
|
8 |
+
"top_k": 3,
|
9 |
+
"thresh": 0.7,
|
10 |
+
},
|
11 |
+
completion_cfg={
|
12 |
"name": "ChatGPT",
|
13 |
+
"completion_kwargs": {
|
14 |
+
"model": "gpt-3.5-turbo",
|
15 |
+
},
|
16 |
+
},
|
17 |
+
prompt_cfg={
|
18 |
+
"max_words": 3000,
|
19 |
"text_before_documents": (
|
20 |
"You are a chatbot assistant answering technical questions about artificial intelligence (AI)."
|
21 |
"You can only respond to a question if the content necessary to answer the question is contained in the following provided documentation. "
|
|
|
41 |
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?"
|
42 |
"Now answer the following question:\n"
|
43 |
),
|
|
|
|
|
|
|
44 |
},
|
45 |
+
document_source="stackoverflow",
|
|
|
46 |
)
|
buster/formatters/prompts.py
CHANGED
@@ -40,3 +40,11 @@ class SystemPromptFormatter:
|
|
40 |
documents = self.format_documents(matched_documents, max_words=self.max_words)
|
41 |
system_prompt = self.text_before_docs + documents + self.text_after_docs
|
42 |
return system_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
documents = self.format_documents(matched_documents, max_words=self.max_words)
|
41 |
system_prompt = self.text_before_docs + documents + self.text_after_docs
|
42 |
return system_prompt
|
43 |
+
|
44 |
+
|
45 |
+
def prompt_formatter_factory(prompt_cfg):
|
46 |
+
return SystemPromptFormatter(
|
47 |
+
text_before_docs=prompt_cfg["text_before_documents"],
|
48 |
+
text_after_docs=prompt_cfg["text_before_prompt"],
|
49 |
+
max_words=prompt_cfg["max_words"],
|
50 |
+
)
|
tests/test_chatbot.py
CHANGED
@@ -60,7 +60,9 @@ logging.basicConfig(level=logging.INFO)
|
|
60 |
def test_chatbot_mock_data(tmp_path, monkeypatch):
|
61 |
gpt_expected_answer = "this is GPT answer"
|
62 |
monkeypatch.setattr(Buster, "get_embedding", lambda self, prompt, engine: get_fake_embedding())
|
63 |
-
monkeypatch.setattr(
|
|
|
|
|
64 |
|
65 |
hf_transformers_cfg = BusterConfig(
|
66 |
unknown_prompt="This doesn't seem to be related to the huggingface library. I am not sure how to answer.",
|
|
|
60 |
def test_chatbot_mock_data(tmp_path, monkeypatch):
|
61 |
gpt_expected_answer = "this is GPT answer"
|
62 |
monkeypatch.setattr(Buster, "get_embedding", lambda self, prompt, engine: get_fake_embedding())
|
63 |
+
monkeypatch.setattr(
|
64 |
+
"buster.busterbot.completer_factory", lambda x: MockCompleter(expected_answer=gpt_expected_answer)
|
65 |
+
)
|
66 |
|
67 |
hf_transformers_cfg = BusterConfig(
|
68 |
unknown_prompt="This doesn't seem to be related to the huggingface library. I am not sure how to answer.",
|