jerpint committed
Commit: f97aa81
1 Parent(s): 2c4fa53

development branch (#7)


* fix relative import

* add embeddings requirement

* update openai embeddings requirements...

* format responses appropriately

* add markdown response

* Fix newline formatting

* add threshold and top_k

* update response

* fix merge conflict

Files changed (1)
  1. buster/chatbot.py +41 -6
buster/chatbot.py CHANGED
@@ -12,13 +12,16 @@ logging.basicConfig(level=logging.INFO)
 
 
 # search through the reviews for a specific product
-def rank_documents(df: pd.DataFrame, query: str, top_k: int = 3) -> pd.DataFrame:
+def rank_documents(df: pd.DataFrame, query: str, top_k: int = 1, thresh: float = None) -> pd.DataFrame:
     product_embedding = get_embedding(
         query,
         engine=EMBEDDING_MODEL,
     )
     df["similarity"] = df.embedding.apply(lambda x: cosine_similarity(x, product_embedding))
 
+    if thresh:
+        df = df[df.similarity > thresh]
+
     if top_k == -1:
         # return all results
         n = len(df)
@@ -28,13 +31,43 @@ def rank_documents(df: pd.DataFrame, query: str, top_k: int = 3) -> pd.DataFrame
 
 
 def engineer_prompt(question: str, documents: list[str]) -> str:
-    return " ".join(documents) + "\nNow answer the following question:\n" + question
+    documents_str = " ".join(documents)
+    if len(documents_str) > 3000:
+        logger.info("truncating documents to fit...")
+        documents_str = documents_str[0:3000]
+    return documents_str + "\nNow answer the following question:\n" + question
+
+
+def format_response(response_text, sources_url=None):
+
+    response = f"{response_text}\n"
+
+    if sources_url:
+        response += f"<br><br>Here are the sources I used to answer your question:\n"
+        for url in sources_url:
+            response += f"<br>[{url}]({url})\n"
 
+    response += "<br><br>"
+    response += """
+    ```
+    I'm a bot 🤖 and not always perfect.
+    For more info, view the full documentation here (https://docs.mila.quebec/) or contact support@mila.quebec
+    ```
+    """
+    return response
 
-def answer_question(question: str, df) -> str:
+
+def answer_question(question: str, df, top_k: int = 1, thresh: float = None) -> str:
     # rank the documents, get the highest scoring doc and generate the prompt
-    candidates = rank_documents(df, query=question, top_k=1)
+    candidates = rank_documents(df, query=question, top_k=top_k, thresh=thresh)
+
+    logger.info(f"candidate responses: {candidates}")
+
+    if len(candidates) == 0:
+        return format_response("I did not find any relevant documentation related to your question.")
+
     documents = candidates.text.to_list()
+    sources_url = candidates.url.to_list()
     prompt = engineer_prompt(question, documents)
 
     logger.info(f"querying GPT...")
@@ -58,12 +91,14 @@ def answer_question(question: str, df) -> str:
         GPT Response:\n{response_text}
         """
         )
-        return response_text
+        return format_response(response_text, sources_url)
+
     except Exception as e:
         import traceback
 
         logging.error(traceback.format_exc())
-        return "Oops, something went wrong. Try again later!"
+        response = "Oops, something went wrong. Try again later!"
+        return format_response(response)
 
 
 def load_embeddings(path: str) -> pd.DataFrame:
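
For reference, a minimal usage sketch of the entry points touched by this commit. The function names and signatures come from the diff above; the import path follows the repository layout, while the embeddings file, question text, threshold, and top_k values below are hypothetical stand-ins, not part of this change.

from buster.chatbot import load_embeddings, answer_question

# Load the precomputed document embeddings (hypothetical path).
df = load_embeddings("data/document_embeddings.csv")

# thresh drops candidates whose cosine similarity falls below 0.7, and
# top_k keeps only the 3 best-matching documents before prompting GPT.
answer = answer_question(
    "How do I submit a job to the cluster?",  # hypothetical question
    df,
    top_k=3,
    thresh=0.7,
)

# The result is passed through format_response: the answer text followed
# by the source URLs and the "I'm a bot" footer.
print(answer)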