bhulston committed
Commit
4e0f9dd
1 Parent(s): d8083af

Update app.py


Add a routing agent to determine when the vector DB is needed for a query and when it should be skipped.

Will likely need to add another route option that constructs a new vector DB query from the chat-history context (rather than just the individual prompt on its own).
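For reference, here is a minimal sketch of what a routing agent along these lines could look like. routing_agent is only called in this commit, not defined in it, so the prompt wording and the extra query-rewrite route below are assumptions, not the repo's actual implementation; the sketch follows the pre-1.0 openai API style already used in app.py.

# Hypothetical routing agent sketch (not the code this commit calls).
import openai

def routing_agent(prompt, api_key, message_history):
    '''Return "1" to query the vector DB, "2" to answer from chat history,
    or "3" (proposed) to rewrite the prompt into a standalone DB query first.'''
    openai.api_key = api_key
    system_prompt = (
        "You route requests for a course-search assistant. Reply with one character:\n"
        "1 - the request needs fresh results from the class database\n"
        "2 - the request can be answered from the conversation alone\n"
        "3 - the request refers to earlier turns and should be rewritten "
        "into a standalone database query before retrieval"
    )
    response = openai.ChatCompletion.create(
        model="gpt-4",
        temperature=0,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "History: " + message_history + "\n\nRequest: " + prompt},
        ],
    )
    return response["choices"][0]["message"]["content"].strip()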

Files changed (1)

app.py +42 -53
app.py CHANGED
@@ -7,6 +7,8 @@ import os
 import json
 import getpass
 import openai
+
+from openai import OpenAI
 
 from langchain.vectorstores import Pinecone
 from langchain.embeddings import OpenAIEmbeddings
@@ -35,10 +37,6 @@ index = pinecone.Index(index_name)
 
 k = 5
 
-
-
-
-
 st.title("USC GPT - Find the perfect class")
 
 class_time = st.slider(
@@ -57,6 +55,23 @@ units = st.slider(
 assistant = st.chat_message("assistant")
 initial_message = "How can I help you today?"
 
+def get_rag_results(prompt):
+    '''
+    1. Remove filter criteria from the prompt to improve the RAG retrieval step.
+    2. Query the Pinecone DB and return the top 25 results by cosine similarity.
+    3. Rerank the vector DB results with a BERT-based cross-encoder.
+    '''
+    # filter_agent strips filter criteria; its output drives the embedding lookup
+    query = filter_agent(prompt, OPENAI_API)
+    response = index.query(
+        vector=embeddings.embed_query(query),
+        top_k=25,
+        include_metadata=True
+    )
+    response = reranker(query, response)  # BERT cross-encoder reranking
+
+    return response
+
 if "messages" not in st.session_state:
     st.session_state.messages = []
     with st.chat_message("assistant"):
@@ -72,59 +87,33 @@ if prompt := st.chat_input("What kind of class are you looking for?"):
     with st.chat_message("assistant"):
         message_placeholder = st.empty()
         full_response = ""
-        response = filter_agent(prompt, OPENAI_API)
-        query = response
-        response = index.query(
-            vector = embeddings.embed_query(query),
-            top_k = 25,
-            include_metadata = True
-        )
-        response = reranker(query, response)
-        result_query = 'Original Query:' + query + 'Query Results:' + str(response)
-        assistant_response = results_agent(result_query, OPENAI_API)
-
+
+        messages = [{"role": m["role"], "content": m["content"]}
+                    for m in st.session_state.messages]
+        message_history = " ".join([message["content"] for message in messages])
+
+        route = routing_agent(prompt, OPENAI_API, message_history)
+
+        if route == "1":
+            # Route 1: answer with results retrieved from the vector DB
+            rag_response = get_rag_results(prompt)
+            result_query = 'Original Query:' + prompt + 'Query Results:' + str(rag_response)
+            assistant_response = results_agent(result_query, OPENAI_API)
+        else:
+            # Route 2: answer from the chat history without querying the DB
+            assistant_response = openai.ChatCompletion.create(
+                model="gpt-4",
+                messages=[
+                    {"role": m["role"], "content": m["content"]}
+                    for m in st.session_state.messages
+                ]
+            )["choices"][0]["message"]["content"]
+
+        # Display the response regardless of route
         for chunk in assistant_response.split():
             full_response += chunk + " "
             time.sleep(0.05)
             message_placeholder.markdown(full_response + "▌")
         message_placeholder.markdown(full_response)
         st.session_state.messages.append({"role": "assistant", "content": full_response})
-
-
-# if prompt := st.chat_input("What kind of class are you looking for?"):
-#     # Display user message in chat message container
-#     with st.chat_message("user"):
-#         st.markdown(prompt)
-#     # Add user message to chat history
-#     st.session_state.messages.append({"role": "user", "content": prompt})
-
-#     response = filter_agent(prompt, OPENAI_API)
-#     query = response
-
-#     response = index.query(
-#         vector= embeddings.embed_query(query),
-#         # filter= build_filter(json),
-#         top_k=5,
-#         include_metadata=True
-#     )
-#     response = reranker(query, response)
-#     result_query = 'Original Query:' + query + 'Query Results:' + str(response)
-#     assistant_response = results_agent(result_query, OPENAI_API)
-
-#     if assistant_response:
-#         with st.chat_message("assistant"):
-#             message_placeholder = st.empty()
-#             full_response = ""
-#             # Simulate stream of response with milliseconds delay
-#             for chunk in assistant_response.split():
-#                 full_response += chunk + " "
-#                 time.sleep(0.05)
-#                 # Add a blinking cursor to simulate typing
-#                 message_placeholder.markdown(full_response + "▌")
-#             message_placeholder.markdown(full_response)
-#             # Add assistant response to chat history
-#             st.session_state.messages.append({"role": "assistant", "content": full_response})
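One inconsistency worth flagging in the new code: from openai import OpenAI is the v1 client class, but the non-DB route still calls the legacy openai.ChatCompletion.create module-level API, which was removed in openai 1.0. If the project is on the v1 SDK, the fallback route would need to look roughly like this sketch (assuming OPENAI_API holds the key, as elsewhere in app.py; the SDK version actually pinned by the repo is not visible in this diff):

# Sketch of the non-DB route on the openai>=1.0 client.
from openai import OpenAI

client = OpenAI(api_key=OPENAI_API)
completion = client.chat.completions.create(
    model="gpt-4",
    messages=[
        {"role": m["role"], "content": m["content"]}
        for m in st.session_state.messages
    ],
)
assistant_response = completion.choices[0].message.content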
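The commit annotates reranker(query, response) as a BERT cross-encoder, but the reranker itself is defined elsewhere in the repo. A minimal sketch of that kind of reranker, assuming the sentence-transformers CrossEncoder and Pinecone's match format with the passage text stored under a "text" metadata field, might look like:

# Hypothetical reranker sketch; the function name, model choice, and the
# "text" metadata field are assumptions, not the repo's actual implementation.
from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def reranker(query, pinecone_response, top_n=5):
    matches = pinecone_response["matches"]
    # Score each (query, passage) pair jointly with the cross-encoder.
    scores = cross_encoder.predict(
        [(query, match["metadata"]["text"]) for match in matches]
    )
    # Highest-scoring matches first; keep the top_n for the results agent.
    ranked = sorted(zip(scores, matches), key=lambda pair: float(pair[0]), reverse=True)
    return [match for _, match in ranked[:top_n]]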