jerpint committed on
Commit
c5f5dc3
·
unverified ·
1 Parent(s): 7d4662a

fix bug when reading csv (#19)

Browse files

* fix bug when reading csv

* add pytorch bot

* Log the actual prompt

* update pytorch prompt

Files changed (3) hide show
  1. app.py +40 -3
  2. buster/chatbot.py +7 -11
  3. buster/docparser.py +4 -1
app.py CHANGED
@@ -6,10 +6,11 @@ from buster.chatbot import Chatbot, ChatbotConfig
6
 
7
  MILA_CLUSTER_CHANNEL = "C04LR4H9KQA"
8
  ORION_CHANNEL = "C04LYHGUYB0"
 
9
 
10
  buster_cfg = ChatbotConfig(
11
- documents_csv="buster/data/document_embeddings.csv",
12
- unknown_prompt="This doesn't seem to be related to cluster usage. I am not sure how to answer.",
13
  embedding_model="text-embedding-ada-002",
14
  top_k=3,
15
  thresh=0.7,
@@ -44,7 +45,7 @@ buster_cfg = ChatbotConfig(
44
  buster_chatbot = Chatbot(buster_cfg)
45
 
46
  orion_cfg = ChatbotConfig(
47
- documents_csv="buster/data/document_embeddings_orion.csv",
48
  unknown_prompt="This doesn't seem to be related to the orion library. I am not sure how to answer.",
49
  embedding_model="text-embedding-ada-002",
50
  top_k=3,
@@ -76,6 +77,39 @@ orion_cfg = ChatbotConfig(
76
  )
77
  orion_chatbot = Chatbot(orion_cfg)
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  app = App(token=os.environ.get("SLACK_BOT_TOKEN"), signing_secret=os.environ.get("SLACK_SIGNING_SECRET"))
80
 
81
 
@@ -93,6 +127,9 @@ def respond_to_question(event, say):
93
  elif channel == ORION_CHANNEL:
94
  print("*******using ORION********")
95
  answer = orion_chatbot.process_input(text)
 
 
 
96
 
97
  # responds to the message in the thread
98
  thread_ts = event["event_ts"]
 
6
 
7
  MILA_CLUSTER_CHANNEL = "C04LR4H9KQA"
8
  ORION_CHANNEL = "C04LYHGUYB0"
9
+ PYTORCH_CHANNEL = "C04MEK6N882"
10
 
11
  buster_cfg = ChatbotConfig(
12
+ documents_file="buster/data/document_embeddings.csv",
13
+ unknown_prompt="This doesn't seem to be related to cluster usage.",
14
  embedding_model="text-embedding-ada-002",
15
  top_k=3,
16
  thresh=0.7,
 
45
  buster_chatbot = Chatbot(buster_cfg)
46
 
47
  orion_cfg = ChatbotConfig(
48
+ documents_file="buster/data/document_embeddings_orion.csv",
49
  unknown_prompt="This doesn't seem to be related to the orion library. I am not sure how to answer.",
50
  embedding_model="text-embedding-ada-002",
51
  top_k=3,
 
77
  )
78
  orion_chatbot = Chatbot(orion_cfg)
79
 
80
+ pytorch_cfg = ChatbotConfig(
81
+ documents_file="buster/data/document_embeddings_pytorch.tar.gz",
82
+ unknown_prompt="This doesn't seem to be related to the pytorch library. I am not sure how to answer.",
83
+ embedding_model="text-embedding-ada-002",
84
+ top_k=3,
85
+ thresh=0.7,
86
+ max_chars=3000,
87
+ completion_kwargs={
88
+ "engine": "text-davinci-003",
89
+ "max_tokens": 500,
90
+ },
91
+ separator="\n",
92
+ link_format="slack",
93
+ text_after_response="I'm a bot 🤖 and not always perfect.",
94
+ text_before_prompt="""You are a slack chatbot assistant answering technical questions about pytorch, a library to train neural networks written in python.
95
+ Make sure to format your answers in Markdown format, including code block and snippets.
96
+ Do not include any links to urls or hyperlinks in your answers.
97
+
98
+ If you do not know the answer to a question, or if it is completely irrelevant to the library usage, simply reply with:
99
+
100
+ 'This doesn't seem to be related to the pytorch library.'
101
+
102
+ For example:
103
+
104
+ What is the meaning of life for pytorch?
105
+
106
+ This doesn't seem to be related to the pytorch library.
107
+
108
+ Now answer the following question:
109
+ """,
110
+ )
111
+ pytorch_chatbot = Chatbot(pytorch_cfg)
112
+
113
  app = App(token=os.environ.get("SLACK_BOT_TOKEN"), signing_secret=os.environ.get("SLACK_SIGNING_SECRET"))
114
 
115
 
 
127
  elif channel == ORION_CHANNEL:
128
  print("*******using ORION********")
129
  answer = orion_chatbot.process_input(text)
130
+ elif channel == PYTORCH_CHANNEL:
131
+ print("*******using PYTORCH********")
132
+ answer = pytorch_chatbot.process_input(text)
133
 
134
  # responds to the message in the thread
135
  thread_ts = event["event_ts"]
buster/chatbot.py CHANGED
@@ -7,20 +7,12 @@ import pandas as pd
7
  from omegaconf import OmegaConf
8
  from openai.embeddings_utils import cosine_similarity, get_embedding
9
 
10
- from buster.docparser import EMBEDDING_MODEL
11
 
12
  logger = logging.getLogger(__name__)
13
  logging.basicConfig(level=logging.INFO)
14
 
15
 
16
- def load_documents(path: str) -> pd.DataFrame:
17
- logger.info(f"loading embeddings from {path}...")
18
- df = pd.read_csv(path)
19
- df["embedding"] = df.embedding.apply(eval).apply(np.array)
20
- logger.info(f"embeddings loaded.")
21
- return df
22
-
23
-
24
  class Chatbot:
25
  def __init__(self, cfg: OmegaConf):
26
  # TODO: right now, the cfg is being passed as an omegaconf, is this what we want?
@@ -29,7 +21,10 @@ class Chatbot:
29
  self._init_unk_embedding()
30
 
31
  def _init_documents(self):
32
- self.documents = load_documents(self.cfg.documents_csv)
 
 
 
33
 
34
  def _init_unk_embedding(self):
35
  logger.info("Generating UNK token...")
@@ -101,6 +96,7 @@ class Chatbot:
101
  return response_text
102
 
103
  logger.info(f"querying GPT...")
 
104
  # Call the API to generate a response
105
  try:
106
  completion_kwargs = self.cfg.completion_kwargs
@@ -198,7 +194,7 @@ class ChatbotConfig:
198
  text_after_response: Generic response to add to the chatbot's reply.
199
  """
200
 
201
- documents_csv: str = "buster/data/document_embeddings.csv"
202
  embedding_model: str = "text-embedding-ada-002"
203
  top_k: int = 3
204
  thresh: float = 0.7
 
7
  from omegaconf import OmegaConf
8
  from openai.embeddings_utils import cosine_similarity, get_embedding
9
 
10
+ from buster.docparser import EMBEDDING_MODEL, read_documents
11
 
12
  logger = logging.getLogger(__name__)
13
  logging.basicConfig(level=logging.INFO)
14
 
15
 
 
 
 
 
 
 
 
 
16
  class Chatbot:
17
  def __init__(self, cfg: OmegaConf):
18
  # TODO: right now, the cfg is being passed as an omegaconf, is this what we want?
 
21
  self._init_unk_embedding()
22
 
23
  def _init_documents(self):
24
+ filepath = self.cfg.documents_file
25
+ logger.info(f"loading embeddings from {filepath}...")
26
+ self.documents = read_documents(filepath)
27
+ logger.info(f"embeddings loaded.")
28
 
29
  def _init_unk_embedding(self):
30
  logger.info("Generating UNK token...")
 
96
  return response_text
97
 
98
  logger.info(f"querying GPT...")
99
+ logger.info(f"Prompt: {prompt}")
100
  # Call the API to generate a response
101
  try:
102
  completion_kwargs = self.cfg.completion_kwargs
 
194
  text_after_response: Generic response to add to the chatbot's reply.
195
  """
196
 
197
+ documents_file: str = "buster/data/document_embeddings.csv"
198
  embedding_model: str = "text-embedding-ada-002"
199
  top_k: int = 3
200
  thresh: float = 0.7
buster/docparser.py CHANGED
@@ -3,6 +3,7 @@ import math
3
  import os
4
 
5
  import bs4
 
6
  import pandas as pd
7
  import tiktoken
8
  from bs4 import BeautifulSoup
@@ -126,7 +127,9 @@ def read_documents(filepath: str) -> pd.DataFrame:
126
  ext = get_file_extension(filepath)
127
 
128
  if ext == ".csv":
129
- return pd.read_csv(filepath)
 
 
130
  elif ext in PICKLE_EXTENSIONS:
131
  return pd.read_pickle(filepath)
132
  else:
 
3
  import os
4
 
5
  import bs4
6
+ import numpy as np
7
  import pandas as pd
8
  import tiktoken
9
  from bs4 import BeautifulSoup
 
127
  ext = get_file_extension(filepath)
128
 
129
  if ext == ".csv":
130
+ df = pd.read_csv(filepath)
131
+ df["embedding"] = df.embedding.apply(eval).apply(np.array)
132
+ return df
133
  elif ext in PICKLE_EXTENSIONS:
134
  return pd.read_pickle(filepath)
135
  else: