jerpint committed
Commit: d16a006
Parent: 25a0d11

compartmentalize buster config

buster/busterbot.py CHANGED
@@ -6,9 +6,10 @@ import numpy as np
 import pandas as pd
 from openai.embeddings_utils import cosine_similarity, get_embedding
 
-from buster.completers import get_completer
+from buster.completers import completer_factory
 from buster.completers.base import Completion
-from buster.formatters.prompts import SystemPromptFormatter
+from buster.formatters.prompts import SystemPromptFormatter, prompt_formatter_factory
+from buster.retriever import Retriever
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -23,36 +24,30 @@ class Response:
 
 @dataclass
 class BusterConfig:
-    """Configuration object for a chatbot.
-
-    documents_csv: Path to the csv file containing the documents and their embeddings.
-    embedding_model: OpenAI model to use to get embeddings.
-    top_k: Max number of documents to retrieve, ordered by cosine similarity
-    thresh: threshold for cosine similarity to be considered
-    max_words: maximum number of words the retrieved documents can be. Will truncate otherwise.
-    completion_kwargs: kwargs for the OpenAI.Completion() method
-    separator: the separator to use, can be either "\n" or <p> depending on rendering.
-    response_format: the type of format to render links with, e.g. slack or markdown
-    unknown_prompt: Prompt to use to generate the "I don't know" embedding to compare to.
-    text_before_prompt: Text to prompt GPT with before the user prompt, but after the documentation.
-    reponse_footnote: Generic response to add the the chatbot's reply.
-    source: the source of the document to consider
-    """
-
-    documents_file: str = ""
+    """Configuration object for a chatbot."""
+
     embedding_model: str = "text-embedding-ada-002"
-    top_k: int = 3
-    thresh: float = 0.7
-    max_words: int = 3000
-    unknown_threshold: float = 0.9  # set to 0 to deactivate
-    completer_cfg: dict = field(
-        # TODO: Put all this in its own config with sane defaults?
+    unknown_threshold: float = 0.9
+    unknown_prompt: str = "I Don't know how to answer your question."
+    document_source: str = ""
+    retriever_cfg: dict = field(
         default_factory=lambda: {
-            "name": "GPT3",
+            "top_k": 3,
+            "thresh": 0.7,
+        }
+    )
+    prompt_cfg: dict = field(
+        default_factory=lambda: {
+            "max_words": 3000,
             "text_before_documents": "You are a chatbot answering questions.\n",
             "text_before_prompt": "Answer the following question:\n",
+        }
+    )
+    completion_cfg: dict = field(
+        default_factory=lambda: {
+            "name": "ChatGPT",
             "completion_kwargs": {
-                "engine": "text-davinci-003",
+                "engine": "gpt-3.5-turbo",
                 "max_tokens": 200,
                 "temperature": None,
                 "top_p": None,
@@ -61,18 +56,11 @@ class BusterConfig:
             },
         }
     )
-    unknown_prompt: str = "I Don't know how to answer your question."
-    response_format: str = "slack"
-    source: str = ""
-
-
-from buster.retriever import Retriever
 
 
 class Buster:
     def __init__(self, cfg: BusterConfig, retriever: Retriever):
         self._unk_embedding = None
-        self.cfg = cfg
         self.update_cfg(cfg)
 
         self.retriever = retriever
@@ -89,16 +77,23 @@ class Buster:
 
     def update_cfg(self, cfg: BusterConfig):
         """Every time we set a new config, we update the things that need to be updated."""
-        logger.info(f"Updating config to {cfg.source}:\n{cfg}")
-        self.cfg = cfg
-        self.completer = get_completer(cfg.completer_cfg)
-        self.unk_embedding = self.get_embedding(self.cfg.unknown_prompt, engine=self.cfg.embedding_model)
-
-        self.prompt_formatter = SystemPromptFormatter(
-            text_before_docs=self.cfg.completer_cfg["text_before_documents"],
-            text_after_docs=self.cfg.completer_cfg["text_before_prompt"],
-            max_words=self.cfg.max_words,
-        )
+        logger.info(f"Updating config to {cfg.document_source}:\n{cfg}")
+        self._cfg = cfg
+        self.embedding_model = cfg.embedding_model
+        self.unknown_threshold = cfg.unknown_threshold
+        self.unknown_prompt = cfg.unknown_prompt
+        self.document_source = cfg.document_source
+
+        self.retriever_cfg = cfg.retriever_cfg
+        self.completion_cfg = cfg.completion_cfg
+        self.prompt_cfg = cfg.prompt_cfg
+
+        # set the unk. embedding
+        self.unk_embedding = self.get_embedding(self.unknown_prompt, engine=self.embedding_model)
+
+        # update completer and formatter cfg
+        self.completer = completer_factory(self.completion_cfg)
+        self.prompt_formatter = prompt_formatter_factory(self.prompt_cfg)
 
         logger.info(f"Config Updated.")
 
@@ -129,9 +124,8 @@ class Buster:
         logger.info(f"matched documents before thresh: {matched_documents}")
 
         # filter out matched_documents using a threshold
-        if thresh:
-            matched_documents = matched_documents[matched_documents.similarity > thresh]
-            logger.info(f"matched documents after thresh: {matched_documents}")
+        matched_documents = matched_documents[matched_documents.similarity > thresh]
+        logger.info(f"matched documents after thresh: {matched_documents}")
 
         return matched_documents
 
@@ -168,10 +162,10 @@ class Buster:
 
         matched_documents = self.rank_documents(
             query=user_input,
-            top_k=self.cfg.top_k,
-            thresh=self.cfg.thresh,
-            engine=self.cfg.embedding_model,
-            source=self.cfg.source,
+            top_k=self.retriever_cfg["top_k"],
+            thresh=self.retriever_cfg["thresh"],
+            engine=self.embedding_model,
+            source=self.document_source,
         )
 
         if len(matched_documents) == 0:
@@ -189,15 +183,15 @@ class Buster:
         # check for relevance
         is_relevant = self.check_response_relevance(
             completion_text=completion.text,
-            engine=self.cfg.embedding_model,
+            engine=self.embedding_model,
             unk_embedding=self.unk_embedding,
-            unk_threshold=self.cfg.unknown_threshold,
+            unk_threshold=self.unknown_threshold,
         )
         if not is_relevant:
             matched_documents = pd.DataFrame(columns=matched_documents.columns)
             # answer generated was the chatbot saying it doesn't know how to answer
             # uncomment override completion with unknown prompt
-            # completion = Completion(text=self.cfg.unknown_prompt)
+            # completion = Completion(text=self.unknown_prompt)
 
         response = Response(completion=completion, matched_documents=matched_documents, is_relevant=is_relevant)
         return response
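
With this change, update_cfg just stores the three sub-config dicts and hands them to the factories, so a Buster instance can be repointed at a different configuration without touching its internals. A minimal usage sketch under the new layout (my_retriever stands in for any object implementing the Retriever interface; it is not part of this commit):

from buster.busterbot import Buster, BusterConfig

# Default sub-configs come from the dataclass field factories shown above.
default_cfg = BusterConfig()

# Override only the retrieval behaviour; prompt_cfg and completion_cfg keep their defaults.
strict_cfg = BusterConfig(
    retriever_cfg={"top_k": 1, "thresh": 0.85},
    document_source="stackoverflow",
)

buster = Buster(cfg=default_cfg, retriever=my_retriever)  # my_retriever: assumed Retriever instance

# update_cfg rebuilds the completer, the prompt formatter and the "unknown" embedding from the new config.
buster.update_cfg(strict_cfg)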
buster/completers/__init__.py CHANGED
@@ -1,7 +1,7 @@
-from .base import ChatGPTCompleter, GPT3Completer, get_completer
+from .base import ChatGPTCompleter, GPT3Completer, completer_factory
 
 __all__ = [
-    get_completer,
+    completer_factory,
     GPT3Completer,
     ChatGPTCompleter,
 ]
buster/completers/base.py CHANGED
@@ -91,7 +91,7 @@ class ChatGPTCompleter(Completer):
         return response["choices"][0]["message"]["content"]
 
 
-def get_completer(completer_cfg):
+def completer_factory(completer_cfg):
     name = completer_cfg["name"]
     completers = {
         "GPT3": GPT3Completer,
buster/examples/cfg.py CHANGED
@@ -2,13 +2,20 @@ from buster.busterbot import BusterConfig
 
 documents_filepath = "./documents.db"
 buster_cfg = BusterConfig(
-    unknown_prompt="I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
     embedding_model="text-embedding-ada-002",
-    top_k=3,
-    thresh=0.7,
-    max_words=3000,
-    completer_cfg={
+    unknown_prompt="I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
+    retriever_cfg={
+        "top_k": 3,
+        "thresh": 0.7,
+    },
+    completion_cfg={
         "name": "ChatGPT",
+        "completion_kwargs": {
+            "model": "gpt-3.5-turbo",
+        },
+    },
+    prompt_cfg={
+        "max_words": 3000,
         "text_before_documents": (
             "You are a chatbot assistant answering technical questions about artificial intelligence (AI)."
             "You can only respond to a question if the content necessary to answer the question is contained in the following provided documentation. "
@@ -34,10 +41,6 @@ buster_cfg = BusterConfig(
             "I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?"
             "Now answer the following question:\n"
         ),
-        "completion_kwargs": {
-            "model": "gpt-3.5-turbo",
-        },
     },
-    response_format="gradio",
-    source="stackoverflow",
+    document_source="stackoverflow",
 )
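
The example config is consumed the same way as before; only the grouping of the fields changed. A hedged sketch of wiring it up (the retriever construction and the name of the question-answering entry point, process_input, are assumptions, not shown in this diff):

from buster.busterbot import Buster

# `retriever` is assumed to be a Retriever built from documents_filepath elsewhere.
buster = Buster(cfg=buster_cfg, retriever=retriever)

# Assumed entry point; returns a Response carrying a Completion and the matched documents.
response = buster.process_input("What is a transformer?")
print(response.completion.text)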
buster/formatters/prompts.py CHANGED
@@ -40,3 +40,11 @@ class SystemPromptFormatter:
         documents = self.format_documents(matched_documents, max_words=self.max_words)
         system_prompt = self.text_before_docs + documents + self.text_after_docs
         return system_prompt
+
+
+def prompt_formatter_factory(prompt_cfg):
+    return SystemPromptFormatter(
+        text_before_docs=prompt_cfg["text_before_documents"],
+        text_after_docs=prompt_cfg["text_before_prompt"],
+        max_words=prompt_cfg["max_words"],
+    )
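
prompt_formatter_factory simply maps the prompt_cfg keys onto SystemPromptFormatter's constructor. A small usage sketch with illustrative values matching the defaults above:

from buster.formatters.prompts import prompt_formatter_factory

prompt_formatter = prompt_formatter_factory(
    {
        "max_words": 3000,
        "text_before_documents": "You are a chatbot answering questions.\n",
        "text_before_prompt": "Answer the following question:\n",
    }
)
# The formatter assembles: text_before_documents + (truncated) documents + text_before_prompt.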
tests/test_chatbot.py CHANGED
@@ -60,7 +60,9 @@ logging.basicConfig(level=logging.INFO)
 def test_chatbot_mock_data(tmp_path, monkeypatch):
     gpt_expected_answer = "this is GPT answer"
     monkeypatch.setattr(Buster, "get_embedding", lambda self, prompt, engine: get_fake_embedding())
-    monkeypatch.setattr("buster.busterbot.get_completer", lambda x: MockCompleter(expected_answer=gpt_expected_answer))
+    monkeypatch.setattr(
+        "buster.busterbot.completer_factory", lambda x: MockCompleter(expected_answer=gpt_expected_answer)
+    )
 
     hf_transformers_cfg = BusterConfig(
         unknown_prompt="This doesn't seem to be related to the huggingface library. I am not sure how to answer.",