Upload app (2).py
app (2).py
ADDED
@@ -0,0 +1,426 @@
from setup_code import *  # setup_code.py supplies the shared imports (streamlit, pandas, OpenAI, Pinecone, transformers, ...), API keys, helper functions, and the Classify/Relevant_Documents agent classes


class Query_Agent:
    def __init__(self, pinecone_index, pinecone_index_python, openai_client) -> None:
        # Keep handles to both Pinecone indexes (ML text and Python code) and the OpenAI client
        self.pinecone_index = pinecone_index
        self.pinecone_index_python = pinecone_index_python
        self.openai_client = openai_client
        self.query_embedding = None
        self.codebert_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
        self.codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")
    def get_codebert_embedding(self, code: str):
        inputs = self.codebert_tokenizer(code, return_tensors="pt", max_length=512, truncation=True)
        outputs = self.codebert_model(**inputs)
        cb_embedding = outputs.last_hidden_state.mean(dim=1)  # mean-pool the token embeddings into one vector
        return cb_embedding.detach().numpy().tolist()[0]
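
    # Usage sketch (hedged): mean-pooling CodeBERT's last_hidden_state yields one
    # 768-dimensional vector per input, so the Python-code Pinecone index is assumed
    # to have been created with dimension=768. For example:
    #   agent = Query_Agent(pinecone_index, pinecone_index_python, openai_client)
    #   vec = agent.get_codebert_embedding("def add(a, b): return a + b")
    #   len(vec)  # -> 768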

    def get_openai_embedding(self, text, model="text-embedding-ada-002"):
        text = text.replace("\n", " ")
        return self.openai_client.embeddings.create(input=[text], model=model).data[0].embedding
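
    # Note (assumption): text-embedding-ada-002 returns 1536-dimensional vectors, so
    # the ML index ("index-600") is assumed to have been built with dimension=1536
    # using the same model, since Pinecone requires query and index dimensions to match.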

    def query_vector_store(self, query, query_topic: str, index=None, k=5) -> str:
        if index is None:
            index = self.pinecone_index

        # Embed the query with the model that matches the target index
        if query_topic == 'ml':
            self.query_embedding = self.get_openai_embedding(query)
        elif query_topic == 'python':
            index = self.pinecone_index_python
            self.query_embedding = self.get_codebert_embedding(query)

        def get_namespace(index):
            # Use the first namespace listed in the index stats
            stats = index.describe_index_stats()
            return list(stats['namespaces'].keys())[0]

        ns = get_namespace(index)

        if query_topic == 'ml':
            matches_text = get_top_k_text(index.query(
                namespace=ns,
                top_k=k,
                vector=self.query_embedding,
                include_values=True,
                include_metadata=True
            ))
        elif query_topic == 'python':
            matches_text = get_top_filename(index.query(
                namespace=ns,
                top_k=k,
                vector=self.query_embedding,
                include_values=True,
                include_metadata=True
            ))

        return matches_text
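
    # Note (assumption): index.query(...) returns a match set whose entries carry
    # 'metadata' (because include_metadata=True); get_top_k_text and get_top_filename
    # come from setup_code.py and are assumed to flatten the top-k matches' metadata
    # (text snippets or .py filenames) into a single string suitable for prompting.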

    def process_query_response(self, head_agent, user_query, query_topic):
        # Retrieve the history related to the query_topic
        conversation = []
        index = head_agent.pinecone_index

        if query_topic == "ml":
            conversation = Head_Agent.get_history_about('ml')
        elif query_topic == 'python':
            conversation = Head_Agent.get_history_about('python')
            index = head_agent.pinecone_index_python

        # Get matches from the vector store (Pinecone)
        user_query_plus_conversation = f"The current query is: {user_query}."
        if len(conversation) > 0:
            conversation_text = "\n".join(conversation)
            user_query_plus_conversation += f" The current conversation is: {conversation_text}"

        # self.query_embedding is set inside query_vector_store
        matches_text = self.query_vector_store(user_query_plus_conversation, query_topic, index)

        if head_agent.relevant_documents_agent.is_relevant(matches_text, user_query_plus_conversation) or contains_py_filename(matches_text):
            response = head_agent.answering_agent.generate_response(user_query, matches_text, conversation, head_agent.selected_mode)
        else:
            # Fall back to the model's general knowledge and flag the answer as external
            prompt_for_gpt = f"Return a response to this query: {user_query} in the context of this conversation: {conversation}. Please use language appropriate for a {head_agent.selected_mode}."
            response = get_completion(head_agent.openai_client, prompt_for_gpt)
            response = "[EXTERNAL] " + response

        return response
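
    # Routing summary: if the Relevant_Documents_Agent judges the retrieved matches
    # relevant (or any .py filename was retrieved), the Answering_Agent summarizes
    # them; otherwise the reply comes from GPT alone and is prefixed "[EXTERNAL]"
    # so the user can tell it was not grounded in the indexed documents.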

class Answering_Agent:
    def __init__(self, openai_client) -> None:
        self.client = openai_client

    def generate_response(self, query, docs, conv_history, selected_mode):
        prompt_for_gpt = f"Based on this text in angle brackets: <{docs}>, please summarize a response to this query: {query} in the context of this conversation: {conv_history}. Please use language appropriate for a {selected_mode}."
        return get_completion(self.client, prompt_for_gpt)

    def generate_response_topic(self, topic_desc, topic_text, conv_history, selected_mode):
        prompt_for_gpt = f"Please return a summary response on this topic: {topic_desc} using this text as best as possible: {topic_text} in the context of this conversation: {conv_history}. Please use language appropriate for a {selected_mode}."
        return get_completion(self.client, prompt_for_gpt)

    def generate_image(self, text):
        if DEBUG:
            return None, ""

        # Step 1: ask GPT to distill the response into a short DALL-E prompt
        dall_e_prompt_from_gpt = f"Based on this text, repeated here in double square brackets for your reference: [[{text}]], please generate a simple caption that I can use with dall-e to generate an instructional image."
        dall_e_text = get_completion(self.client, dall_e_prompt_from_gpt)

        # Log the generated DALL-E prompts for later inspection
        with open("dall_e_prompts.txt", "a") as f:
            f.write(f"{dall_e_text}\n\n")

        # Step 2: render the image with DALL-E
        image = Head_Agent.text_to_image(self.client, dall_e_text)

        # Step 3: once we have the image prompt, get a display caption from GPT
        image_caption_prompt = f"This text in double square brackets is used to prompt dall-e: [[{dall_e_text}]]. Please generate a simple caption that I can use to display with the image dall-e will create. Only return that caption."
        image_caption = get_completion(self.client, image_caption_prompt)

        return (image, image_caption)
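
    # Note: generate_image is a three-call chain (GPT -> DALL-E -> GPT). The DEBUG
    # short-circuit above appears intended to keep the chat loop testable without
    # the slow, paid image-generation calls; it returns (None, "") in that case.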

class Concepts_Agent:
    def __init__(self):
        # concepts_final.csv: one row per ML concept, with 'concept' and 'embedding'
        # columns plus five topic text/description columns (hard-coded Drive path)
        self._df = pd.read_csv("/content/gdrive/MyDrive/LLM_Winter2024/concepts_final.csv")
        # the 12x5 topic_matrix itself lives in st.session_state (see Head_Agent.get_default_value)

    def increase_cell(self, i, j):
        st.session_state.topic_matrix[i][j] += 1

    def display_topic_matrix(self):
        headers = [f"Topic {i}" for i in range(1, 6)]
        row_indices = [f"{self._df['concept'][i-1]}" for i in range(1, 13)]

        topic_df = pd.DataFrame(st.session_state.topic_matrix, row_indices, headers)
        st.table(topic_df)

        st.write(f"Total Topics covered: {sum(sum(row) for row in st.session_state.topic_matrix)}")

    def display_topic_matrix_star(self):
        headers = [f"Topic {i}" for i in range(1, 6)]
        row_indices = [f"{self._df['concept'][i-1]}" for i in range(1, 13)]

        # Replace a count of 1 with the Unicode star symbol
        topic_matrix_star = [[chr(9733) if val == 1 else val for val in row] for row in st.session_state.topic_matrix]

        topic_df = pd.DataFrame(topic_matrix_star, row_indices, headers)
        st.table(topic_df)

        st.write(f"Total Topics covered: {sum(sum(row) for row in st.session_state.topic_matrix)}")

    def display_topic_matrix_as_image(self):
        headers = [f"Topic {i}" for i in range(1, 6)]
        row_indices = [f"{self._df['concept'][i-1]}" for i in range(1, 13)]
        topic_df = pd.DataFrame(st.session_state.topic_matrix, row_indices, headers)

        df_html = topic_df.to_html(index=False)

        # Create an image of the table (note: PIL does not render HTML, so this
        # draws the raw HTML markup as literal text)
        image = Image.new("RGB", (800, 600), color="white")  # image size
        draw = ImageDraw.Draw(image)
        draw.text((10, 10), df_html, fill="black")  # position of the text in the image

        # Save the image to a byte stream
        image_byte_array = io.BytesIO()
        image.save(image_byte_array, format="PNG")
        image_byte_array.seek(0)

        # The byte stream can be passed straight to Streamlit as an image
        st.image(image_byte_array, caption="DataFrame as Image")
        return image_byte_array

    # For a given query_embedding, scan the concepts DataFrame and return the index
    # of the concept whose stored embedding has the highest cosine similarity to it.
    def find_top_concept_index(self, query_embedding):
        top_sim = 0
        top_concept_index = 0

        for index, row in self._df.iterrows():
            # embeddings are stored as string literals in the CSV
            float_array = np.array(ast.literal_eval(row['embedding'])).reshape(1, -1)
            qe_array = np.array(query_embedding).reshape(1, -1)

            sim = cosine_similarity(float_array, qe_array)

            if sim[0][0] > top_sim:
                top_sim = sim[0][0]
                top_concept_index = index

        return top_concept_index
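
    # Shape note: sklearn's cosine_similarity on two (1, d) arrays returns a (1, 1)
    # matrix, hence the sim[0][0] indexing. A minimal check of the idea:
    #   cosine_similarity(np.array([[1, 0]]), np.array([[1, 0]]))[0][0]  # -> 1.0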

    def get_top_k_text_list(self, matches, k):
        text_list = []
        for i in range(k):
            text_list.append(matches.get('matches')[i]['metadata']['text'])
        return text_list

    def write_to_file(self, filename):
        self._df.to_csv(filename, index=False)  # index=False avoids writing row indices

class Head_Agent:
    def __init__(self, openai_key, pinecone_key) -> None:
        # Store keys, create the API clients, and connect to both Pinecone indexes
        self.openai_key = openai_key
        self.pinecone_key = pinecone_key
        self.selected_mode = ""

        self.openai_client = OpenAI(api_key=self.openai_key)
        self.pc = Pinecone(api_key=self.pinecone_key)
        self.pinecone_index = self.pc.Index("index-600")
        self.pinecone_index_python = self.pc.Index("index-python-files")

        self.query_embedding_local = None
        self.setup_sub_agents()

    def setup_sub_agents(self):
        self.classify_agent = Classify_Agent(self.openai_client)
        self.query_agent = Query_Agent(self.pinecone_index, self.pinecone_index_python, self.openai_client)  # embeddings argument removed since it was unused
        self.answering_agent = Answering_Agent(self.openai_client)
        self.relevant_documents_agent = Relevant_Documents_Agent(self.openai_client)
        self.ca = Concepts_Agent()

    @staticmethod
    def get_conversation():
        # Full conversation history, regardless of topic
        return Head_Agent.get_history_about()

    @staticmethod
    def get_history_about(topic=None):
        history = []

        for message in st.session_state.messages:
            role = message["role"]
            content = message["content"]

            if topic is None:
                if role == "user":
                    history.append(f"{content} ")
            else:
                if message["topic"] == topic:
                    history.append(f"{content} ")

        # Keep only the two most recent entries to bound the prompt size
        history = history[-2:]

        return history
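
    # Note: with no topic, only user messages are returned; with a topic, both user
    # and assistant messages tagged with that topic are returned (every message is
    # stored with a "topic" field in main_loop and process_button_click).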

    @staticmethod
    def text_to_image(openai_client, text):
        model = "dall-e-3"
        with st.spinner("Generating ..."):
            response = openai_client.images.generate(
                model=model,
                prompt=text,
                n=1,
                size="1024x1024"
            )
            image_url = response.data[0].url
            with urllib.request.urlopen(image_url) as image_response:
                img = Image.open(BytesIO(image_response.read()))

        return img
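
    # Note: dall-e-3 accepts only 1024x1024, 1792x1024, or 1024x1792, so the
    # hard-coded 1024x1024 in the request is what takes effect (an earlier unused
    # size = "512x512" local would not be a valid dall-e-3 size in any case).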

    def get_default_value(self, variable):
        if variable == "openai_model": return "gpt-3.5-turbo"
        elif variable == "messages": return []
        elif variable == "stage": return 0
        elif variable == "query_embedding": return None
        elif variable == "topic_matrix": return [[0] * 5 for _ in range(12)]  # 12 concepts x 5 topics
        else:
            st.write(f"Error: get_default_value, variable not defined: {variable}")
            return None

    def initialize_session_state(self):
        session_state_variables = ["openai_model", "messages", "stage", "query_embedding", "topic_matrix"]
        for variable in session_state_variables:
            if variable not in st.session_state:
                st.session_state[variable] = self.get_default_value(variable)

    def display_selection_options(self):
        modes = ['college student', 'middle school student', '1st grade student', 'high school student', 'grad student']
        self.selected_mode = st.selectbox("Select your education level:", modes)

    def display_chat_messages(self):
        # Replay the chat history stored in session state
        for message in st.session_state.messages:
            if message["role"] == "assistant":
                with st.chat_message("assistant"):
                    st.write(message["content"])
                    if message['image'] is not None:
                        st.image(message['image'])
            else:
                with st.chat_message("user"):
                    st.write(message["content"])

    def main_loop(self):
        st.title("Machine Learning Text Guide Chatbot")

        self.initialize_session_state()
        self.display_selection_options()
        self.display_chat_messages()

        ### Wait for user input ###
        if user_query := st.chat_input("What would you like to chat about?"):
            with st.chat_message("user"):
                st.write(user_query)

            with st.chat_message("assistant"):
                response = ""
                topic = None
                image = None
                caption = ""
                st.session_state.stage = 0

                # Send the new query plus the recent conversation to the classifier
                # to work out the user's intention
                conversation = self.get_conversation()
                user_query_plus_conversation = f"The current query is: {user_query}. The current conversation is: {conversation}"
                classify_query = self.classify_agent.classify_query(user_query_plus_conversation)

                if classify_query == general_greeting_num:
                    response = "How can I assist you today?"
                elif classify_query == general_question_num:
                    response = "Please ask a question about Machine Learning or Python Code."
                elif classify_query == obnoxious_num:
                    response = "Please don't be obnoxious."
                elif classify_query == progress_num:
                    self.ca.display_topic_matrix_star()
                elif classify_query == default_num:
                    response = "I'm not sure how to respond to that."
                elif classify_query == machine_learning_num:
                    response = self.query_agent.process_query_response(self, user_query, 'ml')
                    st.session_state.query_embedding = self.query_agent.get_openai_embedding(user_query)
                    image, caption = self.answering_agent.generate_image(response)
                    topic = "ml"
                    st.session_state.stage = 1
                elif classify_query == python_code_num:
                    response = self.query_agent.process_query_response(self, user_query, 'python')
                    image, caption = self.answering_agent.generate_image(response)
                    topic = "python"
                    st.session_state.stage = 0
                else:
                    response = "I'm not sure how to respond to that."

                # Display the AI response (and image, if one was generated)
                st.write(response)
                if image and caption != "":
                    st.image(image, caption)

            st.session_state.messages.append({"role": "user", "content": user_query, "topic": topic, "image": None})
            st.session_state.messages.append({"role": "assistant", "content": response, "topic": topic, "image": image})

        if st.session_state.stage == 1:
            # Streamlit reruns this script from the top whenever a button is clicked,
            # so anything needed across the rerun (like the query embedding) must
            # live in st.session_state rather than in a local variable.

            # Use st.session_state.query_embedding to find the matching concept
            top_concept_index = self.ca.find_top_concept_index(st.session_state.query_embedding)
            concept_name = self.ca._df['concept'][top_concept_index]

            st.write(f"Your question is associated with the Fundamental Concept in Machine Learning: {concept_name}.\n\n")
            st.write(f"Here are some topics you can explore to help you learn about {concept_name}; pick one.")

            # Offer a button for each of this concept's five topics that is still unexplored
            matrix_row = st.session_state.topic_matrix[top_concept_index]
            for i in range(5):
                topic_desc = self.ca._df[f'topic_{i}_desc'][top_concept_index]
                if matrix_row[i] == 0 and st.session_state.stage:
                    if st.button(topic_desc):
                        process_button_click(self, i, topic_desc, top_concept_index)
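
        # Rerun note: clicking a button reruns the script; chat_input returns None,
        # the block above runs again (stage is still 1), and Streamlit reports the
        # clicked button as True, which is when process_button_click finally fires
        # and resets stage to 0.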


def process_button_click(head, button_index, topic_desc, top_concept_index):
    with st.chat_message("user"):
        st.write(topic_desc)

    # Store the embedding of the chosen topic description so the next rerun
    # associates follow-up work with this topic
    st.session_state.query_embedding = head.query_agent.get_openai_embedding(topic_desc)

    topic_text_index = 'topic_' + str(button_index)
    topic_text = head.ca._df[topic_text_index][top_concept_index]

    response = head.answering_agent.generate_response_topic(topic_desc, topic_text, head.get_conversation(), head.selected_mode)
    image, caption = head.answering_agent.generate_image(topic_text)
    topic = topic_desc

    # Mark this concept/topic cell as covered
    st.session_state.topic_matrix[top_concept_index][button_index] += 1

    st.write(response)
    if image and caption != "":
        st.image(image, caption)

    # Add the response & image to the message history
    st.session_state.messages.append({"role": "user", "content": topic_desc, "topic": "ml", "image": None})
    st.session_state.messages.append({"role": "assistant", "content": response, "topic": topic, "image": image})

    st.session_state.stage = 0


if __name__ == "__main__":
    head_agent = Head_Agent(OPENAI_KEY, pc_apikey)
    DEBUG = False

    head_agent.main_loop()
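
# Usage note (assumption): as a Streamlit app, this file is meant to be launched with
#   streamlit run "app (2).py"
# with OPENAI_KEY and pc_apikey defined in setup_code.py; running it under plain
# `python` will not render the chat UI.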