long1104 committed
Commit
46b3bd4
1 Parent(s): 9587bef

Upload app (2).py

Files changed (1)
  1. app (2).py +426 -0
app (2).py ADDED
@@ -0,0 +1,426 @@
+ from setup_code import *  # imports helpers, agent classes, constants, and API keys from setup_code.py
+
+ class Query_Agent:
+     def __init__(self, pinecone_index, pinecone_index_python, openai_client) -> None:
+         # Store the two Pinecone indexes (ML text and Python code) and the OpenAI client
+         self.pinecone_index = pinecone_index
+         self.pinecone_index_python = pinecone_index_python
+         self.openai_client = openai_client
+         self.query_embedding = None
+         self.codebert_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
+         self.codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")
+
+     def get_codebert_embedding(self, code: str):
+         # Embed a code snippet with CodeBERT and mean-pool the token embeddings into one vector
+         inputs = self.codebert_tokenizer(code, return_tensors="pt", max_length=512, truncation=True)
+         outputs = self.codebert_model(**inputs)
+         cb_embedding = outputs.last_hidden_state.mean(dim=1)  # a simple way to pool the embeddings
+         return cb_embedding.detach().numpy().tolist()[0]
+
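+     # Editorial sketch (not in the original): since this method is used for inference only,
+     # the forward pass could be wrapped in torch.no_grad() to avoid building a gradient graph,
+     # assuming torch is importable (it is a dependency of transformers), e.g.:
+     #     with torch.no_grad():
+     #         outputs = self.codebert_model(**inputs)
+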
+     def get_openai_embedding(self, text, model="text-embedding-ada-002"):
+         text = text.replace("\n", " ")
+         return self.openai_client.embeddings.create(input=[text], model=model).data[0].embedding
+
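+     # Editorial note (assumption): text-embedding-ada-002 vectors are 1536-dimensional while the
+     # mean-pooled CodeBERT vectors are 768-dimensional, which is presumably why the ML text and
+     # the Python code are kept in two separate Pinecone indexes.
+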
+     def query_vector_store(self, query, query_topic: str, index=None, k=5) -> str:
+         if index is None:
+             index = self.pinecone_index
+
+         # Pick the embedding model (and index) that matches the topic
+         if query_topic == 'ml':
+             self.query_embedding = self.get_openai_embedding(query)
+         elif query_topic == 'python':
+             index = self.pinecone_index_python
+             self.query_embedding = self.get_codebert_embedding(query)
+
+         def get_namespace(index):
+             # Use the first namespace reported by the index stats
+             stats = index.describe_index_stats()
+             return list(stats['namespaces'].keys())[0]
+
+         ns = get_namespace(index)
+
+         matches_text = ""
+         if query_topic == 'ml':
+             matches_text = get_top_k_text(index.query(
+                 namespace=ns,
+                 top_k=k,
+                 vector=self.query_embedding,
+                 include_values=True,
+                 include_metadata=True
+             ))
+         elif query_topic == 'python':
+             matches_text = get_top_filename(index.query(
+                 namespace=ns,
+                 top_k=k,
+                 vector=self.query_embedding,
+                 include_values=True,
+                 include_metadata=True
+             ))
+
+         return matches_text
+
+     def process_query_response(self, head_agent, user_query, query_topic):
+         # Retrieve the conversation history related to the query_topic
+         conversation = []
+         index = head_agent.pinecone_index
+
+         if query_topic == "ml":
+             conversation = Head_Agent.get_history_about('ml')
+         elif query_topic == 'python':
+             conversation = Head_Agent.get_history_about('python')
+             index = head_agent.pinecone_index_python
+
+         # Get matches from the vector store (Pinecone)
+         user_query_plus_conversation = f"The current query is: {user_query}"
+         if len(conversation) > 0:
+             conversation_text = "\n".join(conversation)
+             user_query_plus_conversation += f" The current conversation is: {conversation_text}"
+
+         # self.query_embedding is set inside query_vector_store
+         matches_text = self.query_vector_store(user_query_plus_conversation, query_topic, index)
+
+         if head_agent.relevant_documents_agent.is_relevant(matches_text, user_query_plus_conversation) or contains_py_filename(matches_text):
+             response = head_agent.answering_agent.generate_response(user_query, matches_text, conversation, head_agent.selected_mode)
+         else:
+             # Fall back to a plain GPT completion and tag the answer as not grounded in retrieved documents
+             prompt_for_gpt = f"Return a response to this query: {user_query} in the context of this conversation: {conversation}. Please use language appropriate for a {head_agent.selected_mode}."
+             response = get_completion(head_agent.openai_client, prompt_for_gpt)
+             response = "[EXTERNAL] " + response
+
+         return response
+
+ class Answering_Agent:
+     def __init__(self, openai_client) -> None:
+         self.client = openai_client
+
+     def generate_response(self, query, docs, conv_history, selected_mode):
+         prompt_for_gpt = f"Based on this text in angle brackets: <{docs}>, please summarize a response to this query: {query} in the context of this conversation: {conv_history}. Please use language appropriate for a {selected_mode}."
+         return get_completion(self.client, prompt_for_gpt)
+
+     def generate_response_topic(self, topic_desc, topic_text, conv_history, selected_mode):
+         prompt_for_gpt = f"Please return a summary response on this topic: {topic_desc}, using this text as best as possible: {topic_text}, in the context of this conversation: {conv_history}. Please use language appropriate for a {selected_mode}."
+         return get_completion(self.client, prompt_for_gpt)
+
+     def generate_image(self, text):
+         if DEBUG:
+             return None, ""
+
+         # Ask GPT for a short DALL-E prompt based on the response text
+         dall_e_prompt_from_gpt = f"Based on this text, repeated here in double square brackets for your reference: [[{text}]], please generate a simple caption that I can use with dall-e to generate an instructional image."
+         dall_e_text = get_completion(self.client, dall_e_prompt_from_gpt)
+
+         # Log the generated DALL-E prompt
+         with open("dall_e_prompts.txt", "a") as f:
+             f.write(f"{dall_e_text}\n\n")
+
+         # Get the image from DALL-E
+         image = Head_Agent.text_to_image(self.client, dall_e_text)
+
+         # Once we have the DALL-E prompt, get a display caption from GPT
+         image_caption_prompt = f"This text in double square brackets is used to prompt dall-e: [[{dall_e_text}]]. Please generate a simple caption that I can use to display with the image dall-e will create. Only return that caption."
+         image_caption = get_completion(self.client, image_caption_prompt)
+         # st.write(f"image_caption_prompt: {image_caption_prompt}")
+         return (image, image_caption)
+
+ class Concepts_Agent:
+     def __init__(self):
+         # Concepts CSV: one row per concept, with an embedding column and five topic columns
+         self._df = pd.read_csv("/content/gdrive/MyDrive/LLM_Winter2024/concepts_final.csv")
+         # self.topic_matrix = [[0] * 5 for _ in range(12)]
+
+     def increase_cell(self, i, j):
+         st.session_state.topic_matrix[i][j] += 1
+
+     def display_topic_matrix(self):
+         headers = [f"Topic {i}" for i in range(1, 6)]
+         row_indices = [f"{self._df['concept'][i-1]}" for i in range(1, 13)]
+
+         topic_df = pd.DataFrame(st.session_state.topic_matrix, row_indices, headers)
+         st.table(topic_df)
+
+         st.write(f"Total Topics covered: {sum(sum(row) for row in st.session_state.topic_matrix)}")
+
+     def display_topic_matrix_star(self):
+         headers = [f"Topic {i}" for i in range(1, 6)]
+         row_indices = [f"{self._df['concept'][i-1]}" for i in range(1, 13)]
+
+         # Replace a count of 1 with the Unicode star symbol
+         topic_matrix_star = [[chr(9733) if val == 1 else val for val in row] for row in st.session_state.topic_matrix]
+
+         topic_df = pd.DataFrame(topic_matrix_star, row_indices, headers)
+         st.table(topic_df)
+
+         st.write(f"Total Topics covered: {sum(sum(row) for row in st.session_state.topic_matrix)}")
+
+     def display_topic_matrix_as_image(self):
+         headers = [f"Topic {i}" for i in range(1, 6)]
+         row_indices = [f"{self._df['concept'][i-1]}" for i in range(1, 13)]
+         topic_df = pd.DataFrame(st.session_state.topic_matrix, row_indices, headers)
+
+         df_html = topic_df.to_html(index=False)
+
+         # Draw the table onto a blank image. Note: ImageDraw.text renders the raw HTML string
+         # literally, so this shows the markup as text rather than a rendered table.
+         image = Image.new("RGB", (800, 600), color="white")  # define image size
+         draw = ImageDraw.Draw(image)
+         draw.text((10, 10), df_html, fill="black")  # position of the text in the image
+
+         # Save the image to a byte stream
+         image_byte_array = io.BytesIO()
+         image.save(image_byte_array, format="PNG")
+         image_byte_array.seek(0)
+
+         # The byte stream can now be passed to Streamlit as an image
+         st.image(image_byte_array, caption="DataFrame as Image")
+         return image_byte_array
+
+     # For each query_embedding, look through the DataFrame of concepts, take the cosine
+     # similarity of the query_embedding with each concept's stored embedding, and return
+     # the index of the most similar concept.
+     def find_top_concept_index(self, query_embedding):
+         top_sim = 0
+         top_concept_index = 0
+
+         qe_array = np.array(query_embedding).reshape(1, -1)
+
+         for index, row in self._df.iterrows():
+             # The 'embedding' column is stored as a string in the CSV, so parse it first
+             float_array = np.array(ast.literal_eval(row['embedding'])).reshape(1, -1)
+
+             sim = cosine_similarity(float_array, qe_array)
+
+             if sim[0][0] > top_sim:
+                 top_sim = sim[0][0]
+                 top_concept_index = index
+
+         return top_concept_index
+
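+     # Editorial sketch (not in the original): the loop above could be vectorized by stacking
+     # all concept embeddings once and taking an argmax over the similarities, e.g.:
+     #     emb_matrix = np.vstack([np.array(ast.literal_eval(e)) for e in self._df['embedding']])
+     #     sims = cosine_similarity(emb_matrix, np.array(query_embedding).reshape(1, -1)).ravel()
+     #     top_concept_index = int(sims.argmax())
+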
+     def get_top_k_text_list(self, matches, k):
+         # Pull the 'text' metadata field from the top-k Pinecone matches
+         text_list = []
+         for i in range(0, k):
+             text_list.append(matches.get('matches')[i]['metadata']['text'])
+         return text_list
+
+     def write_to_file(self, filename):
+         self._df.to_csv(filename, index=False)  # index=False avoids writing row indices
+
+ class Head_Agent:
+     def __init__(self, openai_key, pinecone_key) -> None:
+         # Set up the API clients, the two Pinecone indexes, and the sub-agents
+         self.openai_key = openai_key
+         self.pinecone_key = pinecone_key
+         self.selected_mode = ""
+
+         self.openai_client = OpenAI(api_key=self.openai_key)
+         self.pc = Pinecone(api_key=self.pinecone_key)
+         self.pinecone_index = self.pc.Index("index-600")
+         self.pinecone_index_python = self.pc.Index("index-python-files")
+
+         self.query_embedding_local = None
+         self.setup_sub_agents()
+
+     def setup_sub_agents(self):
+         self.classify_agent = Classify_Agent(self.openai_client)
+         self.query_agent = Query_Agent(self.pinecone_index, self.pinecone_index_python, self.openai_client)  # embeddings argument removed since it was not used
+         self.answering_agent = Answering_Agent(self.openai_client)
+         self.relevant_documents_agent = Relevant_Documents_Agent(self.openai_client)
+         self.ca = Concepts_Agent()
+
+     @staticmethod
+     def get_conversation():
+         # Conversation history with no topic filter
+         return Head_Agent.get_history_about()
+
+     @staticmethod
+     def get_history_about(topic=None):
+         history = []
+
+         for message in st.session_state.messages:
+             role = message["role"]
+             content = message["content"]
+
+             if topic is None:
+                 if role == "user":
+                     history.append(f"{content} ")
+             else:
+                 if message["topic"] == topic:
+                     history.append(f"{content} ")
+
+         # st.write(f"user history in get_conversation is {history}")
+
+         # Keep only the two most recent entries
+         history = history[-2:]
+
+         return history
+
+     @staticmethod
+     def text_to_image(openai_client, text):
+         model = "dall-e-3"
+         size = "1024x1024"  # dall-e-3 does not support 512x512, so use its smallest size
+         with st.spinner("Generating ..."):
+             response = openai_client.images.generate(
+                 model=model,
+                 prompt=text,
+                 n=1,
+                 size=size
+             )
+             image_url = response.data[0].url
+             with urllib.request.urlopen(image_url) as image_response:
+                 img = Image.open(BytesIO(image_response.read()))
+
+         return img
+
+     def get_default_value(self, variable):
+         if variable == "openai_model": return "gpt-3.5-turbo"
+         elif variable == "messages": return []
+         elif variable == "stage": return 0
+         elif variable == "query_embedding": return None
+         elif variable == "topic_matrix": return [[0] * 5 for _ in range(12)]
+         else:
+             st.write(f"Error: get_default_value, variable not defined: {variable}")
+             return None
+
+     def initialize_session_state(self):
+         session_state_variables = ["openai_model", "messages", "stage", "query_embedding", "topic_matrix"]
+         for variable in session_state_variables:
+             if variable not in st.session_state:
+                 st.session_state[variable] = self.get_default_value(variable)
+
+     def display_selection_options(self):
+         modes = ['college student', 'middle school student', '1st grade student', 'high school student', 'grad student']
+         self.selected_mode = st.selectbox("Select your education level:", modes)
+
+     def display_chat_messages(self):
+         # Display existing chat messages
+         for message in st.session_state.messages:
+             if message["role"] == "assistant":
+                 with st.chat_message("assistant"):
+                     st.write(message["content"])
+                     if message['image'] is not None:
+                         st.image(message['image'])
+             else:
+                 with st.chat_message("user"):
+                     st.write(message["content"])
+
+     def main_loop(self):
+         st.title("Machine Learning Text Guide Chatbot")
+
+         self.initialize_session_state()
+         self.display_selection_options()
+         self.display_chat_messages()
+
+         ### Wait for user input ###
+         if user_query := st.chat_input("What would you like to chat about?"):
+             with st.chat_message("user"):
+                 st.write(user_query)
+
+             with st.chat_message("assistant"):
+                 response = ""
+                 topic = None
+                 image = None
+                 caption = ""
+                 st.session_state.stage = 0
+
+                 # Combine the current conversation with the new user query to check the user's intention
+                 conversation = self.get_conversation()
+                 user_query_plus_conversation = f"The current query is: {user_query}. The current conversation is: {conversation}"
+                 classify_query = self.classify_agent.classify_query(user_query_plus_conversation)
+
+                 if classify_query == general_greeting_num:
+                     response = "How can I assist you today?"
+                 elif classify_query == general_question_num:
+                     response = "Please ask a question about Machine Learning or Python Code."
+                 elif classify_query == obnoxious_num:
+                     response = "Please don't be obnoxious."
+                 elif classify_query == progress_num:
+                     self.ca.display_topic_matrix_star()
+                 elif classify_query == default_num:
+                     response = "I'm not sure how to respond to that."
+                 elif classify_query == machine_learning_num:
+                     response = self.query_agent.process_query_response(self, user_query, 'ml')
+                     st.session_state.query_embedding = self.query_agent.get_openai_embedding(user_query)
+                     image, caption = self.answering_agent.generate_image(response)
+                     topic = "ml"
+                     st.session_state.stage = 1
+                 elif classify_query == python_code_num:
+                     response = self.query_agent.process_query_response(self, user_query, 'python')
+                     image, caption = self.answering_agent.generate_image(response)
+                     topic = "python"
+                     st.session_state.stage = 0
+                 else:
+                     response = "I'm not sure how to respond to that."
+
+                 # Display the response (and image, if one was generated)
+                 st.write(response)
+                 if image and caption != "":
+                     st.image(image, caption)
+
+             st.session_state.messages.append({"role": "user", "content": user_query, "topic": topic, "image": None})
+             st.session_state.messages.append({"role": "assistant", "content": response, "topic": topic, "image": image})
+
+         ### Follow-up stage: suggest topics for the matched ML concept ###
+         if st.session_state.stage == 1:
+
+             # Note: clicking an st.button triggers a Streamlit rerun from the top of the script,
+             # so local variables such as query_embedding_local are lost between runs; that is why
+             # st.session_state.query_embedding is used to look up the concept here.
+             top_concept_index = self.ca.find_top_concept_index(st.session_state.query_embedding)
+             concept_name = self.ca._df['concept'][top_concept_index]
+
+             st.write(f"Your question is associated with the Fundamental Concept in Machine Learning: {concept_name}.\n\n")
+             st.write(f"Here are some topics you can explore to help you learn about {concept_name}; pick one.")
+
+             response = ""
+             image = None
+             topic = ""
+             topic_descs = [self.ca._df[f'topic_{i}_desc'][top_concept_index] for i in range(5)]
+
+             matrix_row = st.session_state.topic_matrix[top_concept_index]
+
+             # Offer a button for each topic in this concept that has not been covered yet
+             for i, topic_desc in enumerate(topic_descs):
+                 if matrix_row[i] == 0 and st.session_state.stage:
+                     if st.button(topic_desc):
+                         process_button_click(self, i, topic_desc, top_concept_index)
+
+ def process_button_click(head, button_index, topic_desc, top_concept_index):
+     with st.chat_message("user"):
+         st.write(topic_desc)
+
+     # Store the embedding of the chosen topic description as the new query embedding
+     st.session_state.query_embedding = head.query_agent.get_openai_embedding(topic_desc)
+
+     topic_text_index = 'topic_' + str(button_index)
+     topic_text = head.ca._df[topic_text_index][top_concept_index]
+
+     response = head.answering_agent.generate_response_topic(topic_desc, topic_text, head.get_conversation(), head.selected_mode)
+     image, caption = head.answering_agent.generate_image(topic_text)
+     topic = topic_desc
+
+     # Mark this topic as covered for the concept
+     st.session_state.topic_matrix[top_concept_index][button_index] += 1
+
+     st.write(response)
+     if image and caption != "":
+         st.image(image, caption)
+
+     # Add the response and image to the message history
+     st.session_state.messages.append({"role": "user", "content": topic_desc, "topic": "ml", "image": None})
+     st.session_state.messages.append({"role": "assistant", "content": response, "topic": topic, "image": image})
+
+     st.session_state.stage = 0
+
+
+ if __name__ == "__main__":
+     head_agent = Head_Agent(OPENAI_KEY, pc_apikey)
+     DEBUG = False
+
+     head_agent.main_loop()
+     # main()
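+
+ # Editorial note (assumptions): OPENAI_KEY, pc_apikey, the Classify_Agent and
+ # Relevant_Documents_Agent classes, the classification constants (*_num), and helpers such as
+ # get_completion, get_top_k_text, get_top_filename, and contains_py_filename are expected to be
+ # provided by setup_code.py. As a Streamlit app, this file would typically be launched with:
+ #     streamlit run "app (2).py"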