Update app.py
Add a routing agent to help determine when the vector DB is needed for a query and when it should be avoided.
We will likely need to add another route option that helps us construct a new query to the vector DB from the chat-history context (rather than just the individual prompt on its own).
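The diff below calls `routing_agent(prompt, OPENAI_API, message_history)` but does not define it in this file. A minimal sketch of what such a route classifier might look like, assuming the legacy `openai` (pre-1.0) SDK that the rest of app.py uses; the system-prompt wording and the "1"/"2" labels are assumptions, not the committed code:

```python
import openai

def routing_agent(prompt, openai_api_key, message_history):
    """Return "1" when the course-catalog vector DB should be queried,
    "2" when the chat model can answer from conversation context alone.
    Hypothetical sketch: the prompt wording and labels are assumptions."""
    openai.api_key = openai_api_key
    system_prompt = (
        "You route queries for a course-finding assistant. Reply with exactly "
        "one character: '1' if answering requires searching the course catalog, "
        "'2' if the chat history already contains enough information.\n\n"
        "Chat history: " + message_history
    )
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=0,  # deterministic routing decision
    )
    return response["choices"][0]["message"]["content"].strip()
```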
app.py
CHANGED
@@ -7,6 +7,8 @@ import os
 import json
 import getpass
 import openai
+
+from openai import OpenAI
 
 from langchain.vectorstores import Pinecone
 from langchain.embeddings import OpenAIEmbeddings
@@ -35,10 +37,6 @@ index = pinecone.Index(index_name)
 
 k = 5
 
-
-
-
-
 st.title("USC GPT - Find the perfect class")
 
 class_time = st.slider(
@@ -57,6 +55,23 @@ units = st.slider(
 assistant = st.chat_message("assistant")
 initial_message = "How can I help you today?"
 
+def get_rag_results(prompt):
+    '''
+    1. Remove filters from the prompt to optimize the success of the RAG step.
+    2. Query the Pinecone DB and return the top 25 results based on cosine similarity.
+    3. Rerank the results from the vector DB using a BERT-based cross encoder.
+    '''
+    query = prompt
+    query = filter_agent(query, OPENAI_API)  # strip filter criteria from the prompt
+    response = index.query(
+        vector=embeddings.embed_query(query),
+        top_k=25,
+        include_metadata=True
+    )
+    response = reranker(query, response)  # BERT cross encoder for reranking
+
+    return response
+
 if "messages" not in st.session_state:
     st.session_state.messages = []
     with st.chat_message("assistant"):
@@ -72,59 +87,33 @@ if prompt := st.chat_input("What kind of class are you looking for?"):
     with st.chat_message("assistant"):
         message_placeholder = st.empty()
         full_response = ""
-
-
-
-
-
-
-
-
-
-
-
+
+        messages = [{"role": m["role"], "content": m["content"]}
+                    for m in st.session_state.messages]
+        message_history = " ".join([message["content"] for message in messages])
+
+        route = routing_agent(prompt, OPENAI_API, message_history)
+
+        if route == "1":
+            ## Option 1: answer with help from the vector DB
+            rag_response = get_rag_results(prompt)
+            result_query = 'Original Query:' + prompt + 'Query Results:' + str(rag_response)
+            assistant_response = results_agent(result_query, OPENAI_API)
+        else:
+            ## Option 2: answer without accessing the database
+            assistant_response = openai.ChatCompletion.create(
+                model="gpt-4",
+                messages=[
+                    {"role": m["role"], "content": m["content"]}
+                    for m in st.session_state.messages
+                ]
+            )["choices"][0]["message"]["content"]
+
+        ## Display the response regardless of route
         for chunk in assistant_response.split():
             full_response += chunk + " "
             time.sleep(0.05)
             message_placeholder.markdown(full_response + "▌")
         message_placeholder.markdown(full_response)
         st.session_state.messages.append({"role": "assistant", "content": full_response})
-
-
-
-
-
-# if prompt := st.chat_input("What kind of class are you looking for?"):
-#     # Display user message in chat message container
-#     with st.chat_message("user"):
-#         st.markdown(prompt)
-#     # Add user message to chat history
-#     st.session_state.messages.append({"role": "user", "content": prompt})
-
-#     response = filter_agent(prompt, OPENAI_API)
-#     query = response
-
-#     response = index.query(
-#         vector= embeddings.embed_query(query),
-#         # filter= build_filter(json),
-#         top_k=5,
-#         include_metadata=True
-#     )
-#     response = reranker(query, response)
-#     result_query = 'Original Query:' + query + 'Query Results:' + str(response)
-#     assistant_response = results_agent(result_query, OPENAI_API)
-
-#     if assistant_response:
-#         with st.chat_message("assistant"):
-#             message_placeholder = st.empty()
-#             full_response = ""
-#             # Simulate stream of response with milliseconds delay
-#             for chunk in assistant_response.split():
-#                 full_response += chunk + " "
-#                 time.sleep(0.05)
-#                 # Add a blinking cursor to simulate typing
-#                 message_placeholder.markdown(full_response + "▌")
-#             message_placeholder.markdown(full_response)
-#             # Add assistant response to chat history
-#             st.session_state.messages.append({"role": "assistant", "content": full_response})
 
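Note that the new `from openai import OpenAI` import belongs to the 1.x SDK, while the non-DB branch above calls the legacy pre-1.0 `openai.ChatCompletion` interface; only one of the two styles works against a given installed version. If the project moves fully to the 1.x client, the fallback call would look roughly like this (a sketch, assuming `OPENAI_API` holds the API key):

```python
from openai import OpenAI

client = OpenAI(api_key=OPENAI_API)  # assumes OPENAI_API holds the key
completion = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": m["role"], "content": m["content"]}
              for m in st.session_state.messages],
)
assistant_response = completion.choices[0].message.content
```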
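`reranker` is likewise defined outside this file. One plausible shape for the BERT-based cross-encoder step the `get_rag_results` docstring describes, using `sentence-transformers`; the model name, the `"text"` metadata field, and the final top-5 cut are assumptions, not the committed code:

```python
from sentence_transformers import CrossEncoder

# Hypothetical sketch of the reranking step; the model name and the metadata
# layout of the Pinecone matches are assumptions.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def reranker(query, pinecone_response, top_n=5):
    """Re-score Pinecone matches with a cross encoder and keep the best top_n."""
    matches = pinecone_response["matches"]
    pairs = [(query, m["metadata"]["text"]) for m in matches]
    scores = cross_encoder.predict(pairs)  # one relevance score per (query, doc) pair
    ranked = sorted(zip(scores, matches), key=lambda pair: pair[0], reverse=True)
    return [match for _, match in ranked[:top_n]]
```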