AlanFeder committed on
Commit
9ee33af
1 Parent(s): a974aa1

Upload 2 files

Files changed (2)
  1. b1_all_rag_fns.py +426 -0
  2. gradio_served1.py +57 -0
b1_all_rag_fns.py ADDED
@@ -0,0 +1,426 @@
+ import io
+ import json
+
+ import numpy as np
+ import requests
+
+
+ def import_talk_info() -> list[dict]:
+     """
+     Import the talk metadata from the repository.
+
+     Returns:
+         list[dict]: A list of talk-info dictionaries.
+     """
+     target_file_url = "https://raw.githubusercontent.com/AlanFeder/rgov-2024/main/data/rgov_talks.json"
+
+     response = requests.get(target_file_url)
+     response.raise_for_status()  # Ensure we notice if the download fails
+     return response.json()
+
+
+ def import_embeds() -> np.ndarray:
+     """
+     Import the talk embeddings from the repository.
+
+     Returns:
+         np.ndarray: The embeddings, one row per talk.
+     """
+     target_file_url = (
+         "https://raw.githubusercontent.com/AlanFeder/rgov-2024/main/data/embeds.csv"
+     )
+
+     response = requests.get(target_file_url)
+     response.raise_for_status()
+
+     # Use numpy.genfromtxt to read the CSV data from the response text
+     data = np.genfromtxt(io.StringIO(response.text), delimiter=",")
+
+     return data
+
+
+ def import_data() -> tuple[list[dict], np.ndarray]:
+     """
+     Import all data from files.
+
+     Returns:
+         tuple[list[dict], np.ndarray]: The talk info and the embeddings.
+     """
+     talk_info = import_talk_info()
+     embeds = import_embeds()
+
+     return talk_info, embeds
+
+
+ def do_1_embed(lt: str, oai_api_key: str) -> np.ndarray:
+     """
+     Generate an embedding for a single text using the OpenAI API.
+
+     Args:
+         lt (str): The text to generate an embedding for.
+         oai_api_key (str): The OpenAI API key.
+
+     Returns:
+         np.ndarray: The generated embedding.
+     """
+     # OpenAI API endpoint for embeddings
+     url = "https://api.openai.com/v1/embeddings"
+
+     # Headers for the API request
+     headers = {
+         "Content-Type": "application/json",
+         "Authorization": f"Bearer {oai_api_key}",
+     }
+
+     # Request payload
+     payload = {"input": lt, "model": "text-embedding-3-small"}
+
+     # Make the API request
+     response = requests.post(url, headers=headers, data=json.dumps(payload))
+
+     # Fail loudly rather than silently returning None
+     if response.status_code != 200:
+         raise RuntimeError(f"Error: {response.status_code}\n{response.text}")
+
+     # Parse the JSON response and extract the embedding
+     embed_response = response.json()
+     here_embed = np.array(embed_response["data"][0]["embedding"])
+
+     return here_embed
+
+
+ def do_sort(
+     embed_q: np.ndarray, embed_talks: np.ndarray, list_talk_ids: list[str]
+ ) -> list[dict[str, str | float]]:
+     """
+     Sort the talks by the cosine similarity of their embeddings to the query embedding.
+
+     Args:
+         embed_q (np.ndarray): The query embedding.
+         embed_talks (np.ndarray): The talk embeddings, one row per talk.
+         list_talk_ids (list[str]): The talk IDs, in the same row order as embed_talks.
+
+     Returns:
+         list[dict[str, str | float]]: Talk IDs and similarity scores, sorted from
+             most to least similar.
+     """
+     # Cosine similarities between the query embedding and the talk embeddings
+     # (a plain dot product, since OpenAI embeddings are unit-length)
+     cos_sims = np.dot(embed_talks, embed_q)
+
+     # Indices of the talks, from most to least similar
+     best_match_video_ids = np.argsort(-cos_sims)
+
+     # Pair each talk ID with its similarity score, in sorted order
+     sorted_vids = [
+         {"id0": list_talk_ids[i], "score": cos_sims[i]} for i in best_match_video_ids
+     ]
+
+     return sorted_vids
+
+
+ def limit_docs(
+     sorted_vids: list[dict],
+     talk_info: dict,
+     n_results: int,
+ ) -> list[dict]:
+     """
+     Limit the retrieved documents to the top matches that clear a score threshold.
+
+     Args:
+         sorted_vids (list[dict]): Talk IDs and similarity scores, sorted from most
+             to least similar.
+         talk_info (dict): Talk info, keyed by talk ID.
+         n_results (int): The maximum number of documents to keep.
+
+     Returns:
+         list[dict]: The top documents, with their IDs, scores, and talk info.
+     """
+     # Get the top n_results documents
+     top_vids = sorted_vids[:n_results]
+
+     # Set the threshold 0.2 below the top score, clamped to the range [0.2, 0.6]
+     top_score = top_vids[0]["score"]
+     score_thresh = max(min(0.6, top_score - 0.2), 0.2)
+
+     # Keep the documents at or above the threshold, merged with their talk info
+     keep_texts = []
+     for my_vid in top_vids:
+         if my_vid["score"] >= score_thresh:
+             vid_data = talk_info[my_vid["id0"]]
+             vid_data = {**vid_data, **my_vid}
+             keep_texts.append(vid_data)
+
+     return keep_texts
+
+
+ def do_retrieval(
+     query0: str,
+     n_results: int,
+     oai_api_key: str,
+     embeds: np.ndarray,
+     talk_info: list[dict],
+ ) -> list[dict]:
+     """
+     Retrieve the documents most relevant to the user's query.
+
+     Args:
+         query0 (str): The user's query.
+         n_results (int): The number of documents to retrieve.
+         oai_api_key (str): The OpenAI API key.
+         embeds (np.ndarray): The talk embeddings.
+         talk_info (list[dict]): The talk info.
+
+     Returns:
+         list[dict]: The retrieved documents.
+     """
+     # Generate an embedding for the query
+     arr_q = do_1_embed(query0, oai_api_key=oai_api_key)
+
+     # Reformat the talk info into an ID list and an ID-keyed lookup
+     talk_ids = [ti["id0"] for ti in talk_info]
+     talk_info = {ti["id0"]: ti for ti in talk_info}
+
+     # Sort documents by their cosine similarity to the query embedding
+     sorted_vids = do_sort(embed_q=arr_q, embed_talks=embeds, list_talk_ids=talk_ids)
+
+     # Limit the retrieved documents based on a score threshold
+     keep_texts = limit_docs(
+         sorted_vids=sorted_vids, talk_info=talk_info, n_results=n_results
+     )
+
+     return keep_texts
+
+
+ SYSTEM_PROMPT = """
+ You are an AI assistant that helps answer questions by searching through video transcripts.
+ I have retrieved the transcripts most likely to answer the user's question.
+ Carefully read through the transcripts to find information that helps answer the question.
+ Be brief - your response should not be more than two paragraphs.
+ Only use information directly stated in the provided transcripts to answer the question.
+ Do not add any information or make any claims that are not explicitly supported by the transcripts.
+ If the transcripts do not contain enough information to answer the question, state that you do not have enough information to provide a complete answer.
+ Format the response clearly. If only one of the transcripts answers the question, don't reference the others or explain why their content is irrelevant.
+ Do not speak in the first person. DO NOT write a letter or add an introduction or salutation.
+ Reference the speaker's name when you say what they said.
+ """
+
+
+ def set_messages(system_prompt: str, user_prompt: str) -> list[dict[str, str]]:
+     """
+     Set the messages for the chat completion.
+
+     Args:
+         system_prompt (str): The system prompt.
+         user_prompt (str): The user prompt.
+
+     Returns:
+         list[dict[str, str]]: The system and user messages.
+     """
+     messages1 = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+
+     return messages1
+
+
+ def make_user_prompt(question: str, keep_texts: list[dict]) -> str:
+     """
+     Create the user prompt from the question and the retrieved transcripts.
+
+     Args:
+         question (str): The user's question.
+         keep_texts (list[dict]): The retrieved transcripts.
+
+     Returns:
+         str: The user prompt.
+     """
+     user_prompt = f"""
+ Question: {question}
+ ==============================
+ """
+     if len(keep_texts) > 0:
+         list_strs = []
+         for i, tx_val in enumerate(keep_texts):
+             text0 = tx_val["transcript"]
+             speaker_name = tx_val["Speaker"]
+             list_strs.append(
+                 f"Video Transcript {i+1}\nSpeaker: {speaker_name}\n{text0}"
+             )
+         user_prompt += "\n-------\n".join(list_strs)
+         user_prompt += """
+ ==============================
+ After analyzing the above video transcripts, please provide a helpful answer to my question. Remember to stay within two paragraphs.
+ Address the response to me directly. Do not use any information not explicitly supported by the transcripts. Remember to reference the speaker's name."""
+     else:
+         # If no relevant transcripts are found, ask for a default response
+         user_prompt += "No relevant video transcripts were found. Please just return a result that says something like 'I'm sorry, but the answer to {Question} was not found in the transcripts from the R/Gov Conference'"
+
+     return user_prompt
+
+
+ def parse_1_query_stream(response):
+     """Yield the content chunks from a streaming (server-sent events) response."""
+     # Check if the request was successful
+     if response.status_code == 200:
+         for line in response.iter_lines():
+             if line:
+                 line = line.decode("utf-8")
+                 if line.startswith("data: "):
+                     data = line[6:]  # Remove the 'data: ' prefix
+                     if data != "[DONE]":
+                         try:
+                             chunk = json.loads(data)
+                             content = chunk["choices"][0]["delta"].get("content", "")
+                             if content:
+                                 yield content
+                         except json.JSONDecodeError:
+                             yield f"Error decoding JSON: {data}"
+     else:
+         yield f"Error: {response.status_code}\n{response.text}"
+
+
+ def parse_1_query_no_stream(response):
+     """Return the completion text from a non-streaming response."""
+     if response.status_code == 200:
+         try:
+             response1 = response.json()
+             completion = response1["choices"][0]["message"]["content"]
+             return completion
+         except json.JSONDecodeError:
+             return f"Error decoding JSON: {response.text}"
+     else:
+         return f"Error: {response.status_code}\n{response.text}"
+
+
+ def do_1_query(
+     messages1: list[dict[str, str]], oai_api_key: str, stream: bool, model_name: str
+ ):
+     """
+     Generate a response using the specified chat completion model.
+
+     Args:
+         messages1 (list[dict[str, str]]): The messages for the chat completion.
+         oai_api_key (str): The OpenAI API key.
+         stream (bool): Whether to stream the response.
+         model_name (str): The chat completion model to use.
+     """
+     # OpenAI API endpoint for chat completions
+     url = "https://api.openai.com/v1/chat/completions"
+
+     # Headers for the API request
+     headers = {
+         "Content-Type": "application/json",
+         "Authorization": f"Bearer {oai_api_key}",
+     }
+     if stream:
+         headers["Accept"] = "text/event-stream"  # Required for streaming
+
+     # Request payload
+     payload = {
+         "model": model_name,
+         "messages": messages1,
+         "seed": 18,
+         "temperature": 0,
+         "stream": stream,
+     }
+
+     # Make the API request
+     response = requests.post(
+         url, headers=headers, data=json.dumps(payload), stream=stream
+     )
+
+     # Parse either the streamed or the complete response
+     if stream:
+         response1 = parse_1_query_stream(response)
+     else:
+         response1 = parse_1_query_no_stream(response)
+
+     return response1
+
+
+ def do_generation(
+     query1: str, keep_texts: list[dict], oai_api_key: str, stream: bool, model_name: str
+ ):
+     """
+     Generate the chatbot response.
+
+     Args:
+         query1 (str): The user's query.
+         keep_texts (list[dict]): The retrieved relevant texts.
+         oai_api_key (str): The OpenAI API key.
+         stream (bool): Whether to stream the response.
+         model_name (str): The chat completion model to use.
+
+     Returns:
+         The generated response: a generator of text chunks if streaming, otherwise a string.
+     """
+     user_prompt = make_user_prompt(query1, keep_texts=keep_texts)
+     messages1 = set_messages(SYSTEM_PROMPT, user_prompt)
+     response = do_1_query(
+         messages1, oai_api_key=oai_api_key, stream=stream, model_name=model_name
+     )
+
+     return response
+
+
+ def calc_cost(
+     prompt_tokens: int, completion_tokens: int, embedding_tokens: int
+ ) -> float:
+     """
+     Calculate the cost in cents based on the number of prompt, completion, and embedding tokens.
+
+     Args:
+         prompt_tokens (int): The number of tokens in the prompt.
+         completion_tokens (int): The number of tokens in the completion.
+         embedding_tokens (int): The number of tokens in the embedding.
+
+     Returns:
+         float: The cost in cents.
+     """
+     prompt_cost = prompt_tokens / 2000
+     completion_cost = 3 * completion_tokens / 2000
+     embedding_cost = embedding_tokens / 500000
+
+     cost_cents = prompt_cost + completion_cost + embedding_cost
+
+     return cost_cents
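+
+ # Note: the cost factors above appear to correspond to $5 / 1M prompt tokens,
+ # $15 / 1M completion tokens (gpt-4o list rates), and $0.02 / 1M embedding
+ # tokens (text-embedding-3-small), expressed in cents. For example, with
+ # hypothetical counts of 1,000 prompt, 500 completion, and 20 embedding
+ # tokens: 1000/2000 + 3 * 500/2000 + 20/500000 ≈ 1.25 cents.
+
+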
+ def do_rag(
+     user_input: str,
+     oai_api_key: str,
+     model_name: str,
+     stream: bool = False,
+     n_results: int = 3,
+ ):
+     """
+     Run the full RAG pipeline: load the data, retrieve, then generate.
+     """
+     # Load the data
+     talk_info, embeds = import_data()
+
+     # Retrieve the most relevant talks
+     retrieved_docs = do_retrieval(
+         query0=user_input,
+         n_results=n_results,
+         oai_api_key=oai_api_key,
+         embeds=embeds,
+         talk_info=talk_info,
+     )
+
+     # Generate the answer from the retrieved talks
+     response = do_generation(
+         query1=user_input,
+         keep_texts=retrieved_docs,
+         model_name=model_name,
+         oai_api_key=oai_api_key,
+         stream=stream,
+     )
+
+     return response, retrieved_docs
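A minimal usage sketch of `do_rag` (not part of the commit; the question string is hypothetical, and a valid key is assumed in the `OPENAI_API_KEY` environment variable):

```python
import os

from b1_all_rag_fns import do_rag

# Non-streaming: returns the full answer plus the retrieved talks
answer, docs = do_rag(
    user_input="What did speakers say about Shiny?",  # hypothetical question
    oai_api_key=os.environ["OPENAI_API_KEY"],
    model_name="gpt-4o-mini",
    stream=False,
    n_results=3,
)
print(answer)

# Streaming: the first return value is a generator of text chunks
chunks, docs = do_rag(
    user_input="What did speakers say about Shiny?",
    oai_api_key=os.environ["OPENAI_API_KEY"],
    model_name="gpt-4o-mini",
    stream=True,
)
for chunk in chunks:
    print(chunk, end="", flush=True)
```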
gradio_served1.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import sys
+
+ import gradio as gr
+ from dotenv import load_dotenv
+
+ # Get the directory of the current script
+ current_dir = os.path.dirname(__file__)
+
+ # Move up to the parent directory and then to the cousin folder
+ cousin_folder = os.path.join(current_dir, "..", "b1_rag_fns")
+
+ # Add the cousin folder to sys.path so it can be imported
+ sys.path.append(os.path.abspath(cousin_folder))
+
+ from b1_all_rag_fns import do_rag
+
+ # Load environment variables (including OPENAI_API_KEY) from a .env file
+ load_dotenv()
+
+
+ def gr_ch_if(user_input: str, history):
+     oai_api_key = os.getenv("OPENAI_API_KEY")
+     response, _ = do_rag(
+         user_input,
+         stream=False,
+         n_results=3,
+         model_name="gpt-4o-mini",
+         oai_api_key=oai_api_key,
+     )
+     return response
+
+
+ with gr.Blocks() as demo:
+     gr.ChatInterface(
+         fn=gr_ch_if,
+         # type="messages",
+         title="Use Gradio to Run RAG on the previous R/Gov Talks - Chat Interface 1",
+     )
+
+     # Add the static markdown at the bottom
+     gr.Markdown(
+         """
+ This Gradio app was created for Alan Feder's [talk at the 2024 R/Gov Conference](https://rstats.ai/gov.html). \n\n The Github repository that houses all the code is [here](https://github.com/AlanFeder/rgov-2024) -- feel free to fork it and use it on your own!
+ """
+     )
+     # Gradio has no Divider or Subheader components; use Markdown instead
+     gr.Markdown("---")
+     gr.Markdown("## Contact me!")
+     gr.Image("AJF_Headshot.jpg", width=60)
+     gr.Markdown(
+         """
+ [Email](mailto:AlanFeder@gmail.com) | [Website](https://www.alanfeder.com/) | [LinkedIn](https://www.linkedin.com/in/alanfeder/) | [GitHub](https://github.com/AlanFeder)
+ """
+     )
+
+ if __name__ == "__main__":
+     demo.launch(
+         share=True,
+         favicon_path="https://raw.githubusercontent.com/AlanFeder/rgov-2024/refs/heads/main/favicon_io/favicon.ico",
+     )
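Since `do_rag` already supports streaming, a generator-based handler would stream tokens into the chat window; a hedged sketch (not in this commit; assumes the same imports and environment as gradio_served1.py):

```python
def gr_ch_if_stream(user_input: str, history):
    # With stream=True, do_rag returns a generator of text chunks
    chunks, _ = do_rag(
        user_input,
        stream=True,
        n_results=3,
        model_name="gpt-4o-mini",
        oai_api_key=os.getenv("OPENAI_API_KEY"),
    )
    partial = ""
    for chunk in chunks:
        partial += chunk
        yield partial  # gr.ChatInterface renders each yielded value as the reply so far
```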