ww0 committed on
Commit
a2db357
1 Parent(s): b15fe6b

Create app.py

Files changed (1)
  1. app.py +445 -0
app.py ADDED
@@ -0,0 +1,445 @@
import torch

from langchain import PromptTemplate
from langchain.document_loaders import JSONLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.vectorstores import Chroma
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain

import yt_dlp
import json
import gc
import gradio as gr
from gradio_client import Client
import datetime


whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
whisper_jax = Client(whisper_jax_api)

def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):
    text, runtime = whisper_jax.predict(
        audio_path,
        task,
        return_timestamps,
        api_name='/predict_1',
    )
    return text
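
# The Space returns the whole transcript as one string of timestamped segments,
# assumed to look like '[00:00.000 -> 00:14.500]  First sentence...' with one
# segment per line; format_whisper_jax_output() below relies on this layout
# (it slices the end timestamp out of each line at a fixed position).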


def format_whisper_jax_output(whisper_jax_output: str, max_duration: int = 60) -> list:
    '''
    Returns a list of dicts with the keys 'start', 'end' and 'text'.
    The segments from the Whisper JAX output are merged to form paragraphs.

    `max_duration` controls roughly how many seconds of transcript are merged
    into each paragraph. For example, if `max_duration=60`, each segment in the
    final output covers roughly 60 seconds of audio.
    '''
    final_output = []
    max_duration = datetime.timedelta(seconds=max_duration)
    segments = whisper_jax_output.split('\n')
    current_start = datetime.datetime.strptime('00:00', '%M:%S')
    current_text = ''

    for i, seg in enumerate(segments):

        text = seg.split(']')[-1].strip()
        end = datetime.datetime.strptime(seg[14:19], '%M:%S')

        if (end - current_start > max_duration) or (i == len(segments) - 1):
            # If we have exceeded max duration or are at the last segment,
            # stop merging and append to final_output.
            current_text += text
            final_output.append({'start': current_start.strftime('%H:%M:%S'),
                                 'end': end.strftime('%H:%M:%S'),
                                 'text': current_text
                                 })

            # Update current start and text
            current_start = end
            current_text = ''

        else:
            # If we have not exceeded max duration, keep merging.
            current_text += text

    return final_output
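

# Download the best available audio stream for a YouTube URL with yt-dlp,
# transcribe it with the Whisper JAX Space, and cache the merged transcript to
# audio.json for the loader below. The counter gives each download a unique filename.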
audio_file_number = 1
def yt_audio_to_text(url: str,
                     max_duration: int = 60
                     ):

    global audio_file_number
    global progress
    progress = gr.Progress()
    progress(0.1)

    with yt_dlp.YoutubeDL({'extract_audio': True,
                           'format': 'bestaudio',
                           'outtmpl': f'{audio_file_number}.mp3'}) as video:

        info_dict = video.extract_info(url, download=False)
        global video_title
        video_title = info_dict['title']
        video.download([url])

    progress(0.4)
    audio_file = f'{audio_file_number}.mp3'
    audio_file_number += 1

    result = transcribe_audio(audio_file, return_timestamps=True)
    progress(0.7)

    result = format_whisper_jax_output(result, max_duration=max_duration)
    progress(0.9)

    with open('audio.json', 'w') as f:
        json.dump(result, f)
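

# Each transcript paragraph becomes one LangChain Document; its start and end times
# are kept as metadata so that answers can point back to timestamps in the video.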
def metadata_func(record: dict, metadata: dict) -> dict:

    metadata['start'] = record.get('start')
    metadata['end'] = record.get('end')
    metadata['source'] = metadata['start'] + '->' + metadata['end']

    return metadata


def load_data():
    loader = JSONLoader(
        file_path='audio.json',
        jq_schema='.[]',
        content_key='text',
        metadata_func=metadata_func
    )

    data = loader.load()

    return data


embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model_kwargs = {'device': device}

embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                   model_kwargs=embedding_model_kwargs)

def create_vectordb(data, k: int):
    '''
    `k` is the number of retrieved documents
    '''
    vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
    retriever = vectordb.as_retriever(search_type='similarity',
                                      search_kwargs={'k': k})

    return vectordb, retriever
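
# A fresh, in-memory Chroma collection is built for each video; similarity search
# then returns the k most relevant transcript paragraphs for every question.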


repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_length': 1024})
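# Note: HuggingFaceHub calls the hosted Inference API, so a HUGGINGFACEHUB_API_TOKEN
# is assumed to be available in the environment (or passed via `huggingfacehub_api_token=`).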
159
+
160
+
161
+ # Map
162
+ map_template = """Summarise the following text:
163
+ {docs}
164
+
165
+ Answer:"""
166
+ map_prompt = PromptTemplate.from_template(map_template)
167
+ map_chain = LLMChain(llm=llm, prompt=map_prompt)
168
+
169
+
170
+ # Reduce
171
+ reduce_template = """The following is a set of summaries:
172
+ {docs}
173
+
174
+ Take these and distill it into a final, consolidated summary of the main themes.
175
+ Answer:"""
176
+
177
+ reduce_prompt = PromptTemplate.from_template(reduce_template)
178
+ reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
179
+
180
+ # Takes a list of documents, combines them into a single string, and passes this to llm
181
+ combine_documents_chain = StuffDocumentsChain(
182
+ llm_chain=reduce_chain, document_variable_name="docs"
183
+ )
184
+
185
+
186
+ # Combines and iteravely reduces the mapped documents
187
+ reduce_documents_chain = ReduceDocumentsChain(
188
+ # This is final chain that is called.
189
+ combine_documents_chain=combine_documents_chain,
190
+ # If documents exceed context for `StuffDocumentsChain`
191
+ collapse_documents_chain=combine_documents_chain,
192
+ # The maximum number of tokens to group documents into.
193
+ token_max=4000,
194
+ )
195
+
196
+
197
+ # Combining documents by mapping a chain over them, then combining results
198
+ map_reduce_chain = MapReduceDocumentsChain(
199
+ # Map chain
200
+ llm_chain=map_chain,
201
+ # Reduce chain
202
+ reduce_documents_chain=reduce_documents_chain,
203
+ # The variable name in the llm_chain to put the documents in
204
+ document_variable_name="docs",
205
+ # Return the results of the map steps in the output
206
+ return_intermediate_steps=False,
207
+ )
208
+
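

# Summarise the whole transcript; `data` (the list of transcript Documents) is set
# globally in setup_rag() before this is called.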
def get_summary():
    summary = map_reduce_chain.run(data)
    return summary


contextualise_q_prompt = PromptTemplate.from_template(
    '''Given a chat history and the latest user question \
which might reference the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.

Chat history: {chat_history}

Question: {question}

Answer:
'''
)

contextualise_q_chain = contextualise_q_prompt | llm

standalone_prompt = PromptTemplate.from_template(
    '''Given a chat history and the latest user question, \
identify whether the question is a standalone question or the question \
references the chat history. Answer 'yes' if the question is a standalone \
question, and 'no' if the question references the chat history. Do not \
answer anything other than 'yes' or 'no'.

Chat history:
{chat_history}

Question:
{question}

Answer:
'''
)

def format_output(answer: str) -> str:
    # Lower-case the answer and remove all whitespace
    return ''.join(answer.lower().split())

standalone_chain = standalone_prompt | llm | format_output


qa_prompt = PromptTemplate.from_template(
    '''You are an assistant for question-answering tasks. \
ONLY use the following context to answer the question. \
Do NOT answer with information that is not contained in \
the context. If you don't know the answer, just say: \
"Sorry, I cannot find the answer to that question in the video."

Context:
{context}

Question:
{question}

Answer:
'''
)


def format_docs(docs: list) -> str:
    '''
    Combine the retrieved documents into a single string for the prompt.
    '''
    global sources
    sources = [doc.metadata['start'] for doc in docs]

    return '\n\n'.join(doc.page_content for doc in docs)
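
# `sources` is a module-level variable so that filter_timestamps() below can cite the
# start times of the retrieved paragraphs in the final answer.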


def standalone_question(input_: dict) -> str:
    '''
    If the question is not a standalone question (i.e. it references the chat
    history), run contextualise_q_chain to rewrite it; otherwise pass the
    question through unchanged.
    '''
    if input_['standalone'] == 'no':
        return contextualise_q_chain
    else:
        return input_['question']


def format_answer(answer: str,
                  n_sources: int = 1,
                  timestamp_interval: datetime.timedelta = datetime.timedelta(minutes=5)) -> str:
    # If the model could not find an answer, return it as is;
    # otherwise append the most relevant timestamps as sources.
    if 'cannot find the answer' in answer:
        return answer.strip()
    else:
        timestamps = filter_timestamps(n_sources, timestamp_interval)
        answer_with_sources = (answer.strip()
                               + ' You can find more information at these timestamps: {}.'.format(', '.join(timestamps))
                               )
        return answer_with_sources


def filter_timestamps(n_sources: int,
                      timestamp_interval: datetime.timedelta = datetime.timedelta(minutes=5)) -> list:
    '''Returns a list of at most `n_sources` timestamps.
    The timestamps are at least `timestamp_interval` apart, which prevents
    returning a list of timestamps that are too close together.
    '''
    sorted_timestamps = sorted(sources)
    output = [sorted_timestamps[0]]
    i = 1
    while len(output) < n_sources:
        timestamp1 = datetime.datetime.strptime(output[-1], '%H:%M:%S')

        try:
            timestamp2 = datetime.datetime.strptime(sorted_timestamps[i], '%H:%M:%S')
        except IndexError:
            break

        time_diff = timestamp2 - timestamp1

        if time_diff > timestamp_interval:
            output.append(str(timestamp2.time()))

        i += 1

    return output


def setup_rag(url):
    '''Given a YouTube url, set up the vector database and the RAG chain.
    '''
    yt_audio_to_text(url)

    global data
    data = load_data()

    global retriever
    _, retriever = create_vectordb(data, k)

    global rag_chain
    rag_chain = (
        RunnablePassthrough.assign(standalone=standalone_chain)
        | {'question': standalone_question,
           'context': standalone_question | retriever | format_docs
           }
        | qa_prompt
        | llm
    )

    return url
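
# Data flow of rag_chain: the input is {'question', 'chat_history'}. standalone_chain
# first labels the question 'yes'/'no' (standalone or not); standalone_question() then
# either forwards the question unchanged or lets contextualise_q_chain rewrite it using
# the chat history. The resulting question retrieves context from the vector store, and
# qa_prompt plus the LLM produce the final answer.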


def get_answer(question: str) -> str:

    global chat_history

    ai_msg = rag_chain.invoke({'question': question,
                               'chat_history': chat_history
                               })

    answer = format_answer(ai_msg, n_sources, timestamp_interval)

    chat_history.extend([HumanMessage(content=question),
                         AIMessage(content=answer)])

    return answer


# Chatbot settings
n_sources = 3  # Number of sources provided in the answer
k = 5  # Number of documents returned by the retriever
timestamp_interval = datetime.timedelta(minutes=2)
default_youtube_url = 'https://www.youtube.com/watch?v=4Bdc55j80l8'
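

# Gradio callbacks: greet() posts the video summary once the RAG chain is ready,
# question() echoes the user's message into the chat window, and respond() generates
# the assistant's answer.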
def greet():
    summary = get_summary()
    global gradio_chat_history
    summary_message = f'Here is a summary of the video "{video_title}":'
    gradio_chat_history.append((None, summary_message))
    gradio_chat_history.append((None, summary))
    greeting_message = 'You can ask me anything about the video. I will do my best to answer!'
    gradio_chat_history.append((None, greeting_message))
    return gradio_chat_history

def question(user_message):
    global gradio_chat_history
    gradio_chat_history.append((user_message, None))
    return gradio_chat_history

def respond():
    global gradio_chat_history
    ai_message = get_answer(gradio_chat_history[-1][0])
    gradio_chat_history.append((None, ai_message))
    return '', gradio_chat_history

def clear_chat_history():
    global chat_history
    global gradio_chat_history
    chat_history = []
    gradio_chat_history = []


chat_history = []
gradio_chat_history = []

with gr.Blocks() as demo:

    # Structure
    with gr.Row():
        url_input = gr.Textbox(value=default_youtube_url,
                               label='YouTube URL',
                               scale=5)
        button = gr.Button(value='Go', scale=1)

    chatbot = gr.Chatbot()
    user_message = gr.Textbox(label='Ask a question:')
    clear = gr.ClearButton([user_message, chatbot])


    # Actions
    button.click(setup_rag,
                 inputs=[url_input],
                 outputs=[url_input],
                 trigger_mode='once').then(greet,
                                           inputs=[],
                                           outputs=[chatbot])

    user_message.submit(question,
                        inputs=[user_message],
                        outputs=[chatbot]).then(respond,
                                                inputs=[],
                                                outputs=[user_message, chatbot])

    clear.click(clear_chat_history)


demo.launch()
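
# To try the app locally (assuming the dependencies above are installed and an
# API token is configured), run `python app.py` and open the local URL that
# Gradio prints to the console.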