ww0 committed
Commit 50be202
1 Parent(s): 871d1d5

Update app.py

Files changed (1)
  1. app.py +253 -175
app.py CHANGED
@@ -1,14 +1,15 @@
 import torch

-from langchain import PromptTemplate
-from langchain.document_loaders import JSONLoader
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import Chroma
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain
-from langchain.llms import HuggingFaceHub

 import yt_dlp
 import json
@@ -16,12 +17,18 @@ import gc
 import gradio as gr
 from gradio_client import Client
 import datetime


 whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
 whisper_jax = Client(whisper_jax_api)

-def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):
     text, runtime = whisper_jax.predict(
         audio_path,
         task,
@@ -32,17 +39,18 @@ def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):



-def format_whisper_jax_output(whisper_jax_output: str, max_duration: int=60) -> list:

-    '''
-    Returns a list of dict with keys 'start', 'end', 'text'
     The segments from whisper jax output are merged to form paragraphs.

     `max_duration` controls how many seconds of the audio's transcripts are merged

     For example, if `max_duration`=60, in the final output, each segment is roughly
     60 seconds.
-    '''

     final_output = []
     max_duration = datetime.timedelta(seconds=max_duration)
@@ -53,25 +61,30 @@ def format_whisper_jax_output(whisper_jax_output: str, max_duration: int=60) ->
     for i, seg in enumerate(segments):

         text = seg.split(']')[-1].strip()
-        end = datetime.datetime.strptime(seg[14:19], '%M:%S')

-        if (end - current_start > max_duration) or (i == len(segments)-1):
-            # If we have exceeded max duration or
-            # at the last segment, stop merging
-            # and append to final_output
             current_text += text
-            final_output.append({'start': current_start.strftime('%H:%M:%S'),
-                                 'end': end.strftime('%H:%M:%S'),
-                                 'text': current_text
-                                 })

             # Update current start and text
             current_start = end
             current_text = ''

         else:
-            # If we have not exceeded max duration,
-            # keep merging.
             current_text += text

     return final_output
@@ -79,6 +92,7 @@ def format_whisper_jax_output(whisper_jax_output: str, max_duration: int=60) ->



 audio_file_number = 1
 def yt_audio_to_text(url: str,
                      max_duration: int = 60
@@ -91,7 +105,8 @@ def yt_audio_to_text(url: str,

     with yt_dlp.YoutubeDL({'extract_audio': True,
                            'format': 'bestaudio',
-                           'outtmpl': f'{audio_file_number}.mp3'}) as video:

         info_dict = video.extract_info(url, download=False)
         global video_title
@@ -113,11 +128,12 @@ def yt_audio_to_text(url: str,



 def metadata_func(record: dict, metadata: dict) -> dict:

     metadata['start'] = record.get('start')
     metadata['end'] = record.get('end')
-    metadata['source'] = metadata['start'] + '->' + metadata['end']

     return metadata

@@ -136,6 +152,7 @@ def load_data():



 embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 embedding_model_kwargs = {'device': device}
@@ -144,9 +161,9 @@ embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                    model_kwargs=embedding_model_kwargs)

 def create_vectordb(data, k: int):
-    '''
     `k` is the number of retrieved documents
-    '''

     vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
     retriever = vectordb.as_retriever(search_type='similarity',
@@ -155,8 +172,11 @@ def create_vectordb(data, k: int):
     return vectordb, retriever


 repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
-llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_length': 1024})


 # Map
@@ -168,11 +188,14 @@ map_prompt = PromptTemplate.from_template(map_template)
 map_chain = LLMChain(llm=llm, prompt=map_prompt)


 # Reduce
 reduce_template = """The following is a set of summaries:
 {docs}

-Take these and distill it into a final, consolidated summary of the main themes.
 Answer:"""

 reduce_prompt = PromptTemplate.from_template(reduce_template)
@@ -191,7 +214,7 @@ reduce_documents_chain = ReduceDocumentsChain(
     # If documents exceed context for `StuffDocumentsChain`
     collapse_documents_chain=combine_documents_chain,
     # The maximum number of tokens to group documents into.
-    token_max=4000,
 )


@@ -204,18 +227,20 @@ map_reduce_chain = MapReduceDocumentsChain(
     # The variable name in the llm_chain to put the documents in
     document_variable_name="docs",
     # Return the results of the map steps in the output
-    return_intermediate_steps=False,
 )

-def get_summary():
-    summary = map_reduce_chain.run(data)
-    return summary


 contextualise_q_prompt = PromptTemplate.from_template(
-    '''Given a chat history and the latest user question \
     which might reference the chat history, formulate a standalone question \
-    which can be understood without the chat history. Do NOT answer the question, \
     just reformulate it if needed and otherwise return it as is.

     Chat history: {chat_history}
@@ -223,13 +248,15 @@ contextualise_q_prompt = PromptTemplate.from_template(
     Question: {question}

     Answer:
-    '''
 )

 contextualise_q_chain = contextualise_q_prompt | llm

 standalone_prompt = PromptTemplate.from_template(
-    '''Given a chat history and the latest user question, \
     identify whether the question is a standalone question or the question \
     references the chat history. Answer 'yes' if the question is a standalone \
     question, and 'no' if the question references the chat history. Do not \
@@ -242,7 +269,7 @@ standalone_prompt = PromptTemplate.from_template(
     {question}

     Answer:
-    '''
 )

 def format_output(answer: str) -> str:
@@ -252,8 +279,10 @@ def format_output(answer: str) -> str:
 standalone_chain = standalone_prompt | llm | format_output


 qa_prompt = PromptTemplate.from_template(
-    '''You are an assistant for question-answering tasks. \
     ONLY use the following context to answer the question. \
     Do NOT answer with information that is not contained in \
     the context. If you don't know the answer, just say:\
@@ -266,181 +295,230 @@ qa_prompt = PromptTemplate.from_template(
     {question}

     Answer:
-    '''
 )


-def format_docs(docs: list) -> str:
-    '''
-    Combine documents
-    '''
-    global sources
-    sources = [doc.metadata['start'] for doc in docs]
-
-    return '\n\n'.join(doc.page_content for doc in docs)
-
-
-def standalone_question(input_: dict) -> str:
-    '''
-    If the question is a not a standalone question, run contextualise_q_chain
-    '''
-    if input_['standalone']=='yes':
-        return contextualise_q_chain
-    else:
-        return input_['question']
-
-
-def format_answer(answer: str,
-                  n_sources: int=1,
-                  timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> str:
-
-    if 'cannot find the answer' in answer:
-        return answer.strip()
-    else:
-        timestamps = filter_timestamps(n_sources, timestamp_interval)
-        answer_with_sources = (answer.strip()
-                               + ' You can find more information at these timestamps: {}.'.format(', '.join(timestamps))
-                               )
-        return answer_with_sources
-
-
-def filter_timestamps(n_sources: int,
-                      timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> list:
-    '''Returns a list of timestamps with length `n_sources`.
-    The timestamps are at least an `timestamp_interval` apart.
-    This prevents returning a list of timestamps that are too
-    close together.
-    '''
-    sorted_timestamps = sorted(sources)
-    output = [sorted_timestamps[0]]
-    i=1
-    while len(output)<n_sources:
-        timestamp1 = datetime.datetime.strptime(output[-1], '%H:%M:%S')

-        try:
-            timestamp2 = datetime.datetime.strptime(sorted_timestamps[i], '%H:%M:%S')
-        except IndexError:
-            break

-        time_diff = timestamp2 - timestamp1

-        if time_diff>timestamp_interval:
-            output.append(str(timestamp2.time()))

-        i += 1

-    return output


-def setup_rag(url):
-    '''Given a YouTube url, set up the vector database and the RAG chain.
-    '''

-    yt_audio_to_text(url)

-    global data
-    data = load_data()

-    global retriever
-    _, retriever = create_vectordb(data, k)

-    global rag_chain
-    rag_chain = (
-        RunnablePassthrough.assign(standalone=standalone_chain)
-        | {'question': standalone_question,
-           'context': standalone_question|retriever|format_docs
-           }
-        | qa_prompt
-        | llm
-    )

-    return url



-def get_answer(question: str) -> str:

-    global chat_history

-    ai_msg = rag_chain.invoke({'question': question,
-                               'chat_history': chat_history
-                               })

-    answer = format_answer(ai_msg, n_sources, timestamp_interval)

-    chat_history.extend([HumanMessage(content=question),
-                         AIMessage(content=answer)])

-    return answer


-# Chatbot settings
-n_sources = 3 # Number of sources provided in the answer
-k = 5 # Number of documents returned by the retriever
-timestamp_interval = datetime.timedelta(minutes=2)
-default_youtube_url = 'https://www.youtube.com/watch?v=4Bdc55j80l8'


-def greet():
-    summary = get_summary()
-    global gradio_chat_history
-    summary_message = f'Here is a summary of the video "{video_title}":'
-    gradio_chat_history.append((None, summary_message))
-    gradio_chat_history.append((None, summary))
-    greeting_message = f'You can ask me anything about the video. I will do my best to answer!'
-    gradio_chat_history.append((None, greeting_message))
-    return gradio_chat_history

-def question(user_message):
-    global gradio_chat_history
-    gradio_chat_history.append((user_message, None))
-    return gradio_chat_history

-def respond():
-    global gradio_chat_history
-    ai_message = get_answer(gradio_chat_history[-1][0])
-    gradio_chat_history.append((None, ai_message))
-    return '', gradio_chat_history

-def clear_chat_history():
-    global chat_history
-    global gradio_chat_history
-    chat_history = []
-    gradio_chat_history = []


-chat_history = []
-gradio_chat_history = []

-with gr.Blocks() as demo:

-    # Structure
-    with gr.Row():
-        url_input = gr.Textbox(value=default_youtube_url,
-                               label='YouTube URL',
-                               scale=5)
-        button = gr.Button(value='Go', scale=1)

-    chatbot = gr.Chatbot()
-    user_message = gr.Textbox(label='Ask a question:')
-    clear = gr.ClearButton([user_message, chatbot])


-    # Actions
-    button.click(setup_rag,
-                 inputs=[url_input],
-                 outputs=[url_input],
-                 trigger_mode='once').then(greet,
-                                           inputs=[],
-                                           outputs=[chatbot])

-    user_message.submit(question,
-                        inputs=[user_message],
-                        outputs=[chatbot]).then(respond,
-                                                inputs=[],
-                                                outputs=[user_message, chatbot])

-    clear.click(clear_chat_history)

-demo.launch()
 
 import torch

+from langchain.prompts import PromptTemplate
+from langchain_community.document_loaders import JSONLoader
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import Chroma
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain
+from langchain.memory.buffer_window import ConversationBufferWindowMemory
+from langchain_community.llms import HuggingFaceHub

 import yt_dlp
 import json
 
 import gradio as gr
 from gradio_client import Client
 import datetime
+import os
+
+


 whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
 whisper_jax = Client(whisper_jax_api)

+def transcribe_audio(audio_path,
+                     task='transcribe',
+                     return_timestamps=True) -> str:
+
     text, runtime = whisper_jax.predict(
         audio_path,
         task,
 




+def format_whisper_jax_output(whisper_jax_output: str,
+                              max_duration: int = 60) -> list[dict]:
+
+    """Returns a list of dict with keys 'start', 'end', 'text'
     The segments from whisper jax output are merged to form paragraphs.

     `max_duration` controls how many seconds of the audio's transcripts are merged

     For example, if `max_duration`=60, in the final output, each segment is roughly
     60 seconds.
+    """

     final_output = []
     max_duration = datetime.timedelta(seconds=max_duration)
 
     for i, seg in enumerate(segments):

         text = seg.split(']')[-1].strip()

+        # Sometimes whisper jax returns None for timestamp
+        try:
+            end = datetime.datetime.strptime(seg[14:19], '%M:%S')
+        except ValueError:
+            end = current_start + max_duration
+
+        if (end-current_start >= max_duration) or (i == len(segments)-1):
+            # If we have exceeded max duration or at the last segment,
+            # stop merging and append to final_output.
+
             current_text += text
+            final_output.append({
+                'start': current_start.strftime('%H:%M:%S'),
+                'end': end.strftime('%H:%M:%S'),
+                'text': current_text
+            })

             # Update current start and text
             current_start = end
             current_text = ''

         else:
+            # If we have not exceeded max duration, keep merging.
             current_text += text

     return final_output
 



+
 audio_file_number = 1
 def yt_audio_to_text(url: str,
                      max_duration: int = 60
 

     with yt_dlp.YoutubeDL({'extract_audio': True,
                            'format': 'bestaudio',
+                           'outtmpl': f'{audio_file_number}.mp3'
+                           }) as video:

         info_dict = video.extract_info(url, download=False)
         global video_title
 



+
 def metadata_func(record: dict, metadata: dict) -> dict:

     metadata['start'] = record.get('start')
     metadata['end'] = record.get('end')
+    metadata['source'] = metadata['start'] + ' -> ' + metadata['end']

     return metadata

 



+
 embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 embedding_model_kwargs = {'device': device}
 
                                    model_kwargs=embedding_model_kwargs)

 def create_vectordb(data, k: int):
+    """Returns a vector database, and its retriever
     `k` is the number of retrieved documents
+    """

     vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
     retriever = vectordb.as_retriever(search_type='similarity',
 
     return vectordb, retriever


+
+
 repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
+llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_new_tokens': 1000})
+


 # Map
 
 map_chain = LLMChain(llm=llm, prompt=map_prompt)


+
 # Reduce
 reduce_template = """The following is a set of summaries:
 {docs}

+Take these and distill it into a final, consolidated summary of the main themes \
+in 150 words or less.
+
 Answer:"""

 reduce_prompt = PromptTemplate.from_template(reduce_template)
 
     # If documents exceed context for `StuffDocumentsChain`
     collapse_documents_chain=combine_documents_chain,
     # The maximum number of tokens to group documents into.
+    token_max=4000
 )


 
     # The variable name in the llm_chain to put the documents in
     document_variable_name="docs",
     # Return the results of the map steps in the output
+    return_intermediate_steps=False
 )

+def get_summary(documents) -> str:
+    summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
+    return summary['output_text'].strip()
+
+


 contextualise_q_prompt = PromptTemplate.from_template(
+    """Given a chat history and the latest user question \
     which might reference the chat history, formulate a standalone question \
+    that can be understood without the chat history. Do NOT answer the question, \
     just reformulate it if needed and otherwise return it as is.

     Chat history: {chat_history}
 
     Question: {question}

     Answer:
+    """
 )

 contextualise_q_chain = contextualise_q_prompt | llm

+
+
 standalone_prompt = PromptTemplate.from_template(
+    """Given a chat history and the latest user question, \
     identify whether the question is a standalone question or the question \
     references the chat history. Answer 'yes' if the question is a standalone \
     question, and 'no' if the question references the chat history. Do not \
 
     {question}

     Answer:
+    """
 )

 def format_output(answer: str) -> str:
 
 standalone_chain = standalone_prompt | llm | format_output


+
+
 qa_prompt = PromptTemplate.from_template(
+    """You are an assistant for question-answering tasks. \
     ONLY use the following context to answer the question. \
     Do NOT answer with information that is not contained in \
     the context. If you don't know the answer, just say:\
 
     {question}

     Answer:
+    """
 )




+class YouTubeChatbot:
+
+    def __init__(self,
+                 n_sources: int,
+                 k: int,
+                 timestamp_interval: datetime.timedelta,
+                 memory: int,
+                 ):
+        self.n_sources = n_sources
+        self.k = k
+        self.timestamp_interval = timestamp_interval
+        self.chat_history = ConversationBufferWindowMemory(k=memory)


+    def format_docs(self, docs: list) -> str:
+        """Combine documents
+        """

+        self.sources = [doc.metadata['start'] for doc in docs]

+        return '\n\n'.join(doc.page_content for doc in docs)



+    def standalone_question(self, input_: dict) -> str:
+        """If the question is not a standalone question,
+        run contextualise_q_chain.
+        """
+        if input_['standalone']=='yes':
+            return contextualise_q_chain
+        else:
+            return input_['question']


+    def format_answer(self, answer: str) -> str:
+
+        if 'cannot find the answer' in answer:
+            return answer.strip()
+        else:
+            timestamps = self.filter_timestamps()
+            answer_with_sources = (
+                answer.strip()
+                + ' You can find more information '
+                'at these timestamps: {}.'.format(', '.join(timestamps))
+            )
+            return answer_with_sources
+
+
+    def filter_timestamps(self) -> list[str]:
+        """Returns a list of timestamps with length `n_sources`.
+        The timestamps are at least a `timestamp_interval` apart.
+        This prevents returning a list of timestamps that are too
+        close together.
+        """
+
+        sorted_timestamps = sorted(self.sources)
+        filtered_timestamps = [sorted_timestamps[0]]
+        i = 1
+        while len(filtered_timestamps) < self.n_sources:
+            timestamp1 = datetime.datetime.strptime(filtered_timestamps[-1],
+                                                    '%H:%M:%S')

+            try:
+                timestamp2 = datetime.datetime.strptime(sorted_timestamps[i],
+                                                        '%H:%M:%S')
+            except IndexError:
+                break

+            time_diff = timestamp2 - timestamp1

+            if time_diff >= self.timestamp_interval:
+                filtered_timestamps.append(str(timestamp2.time()))

+            i += 1

+        return filtered_timestamps

+    def setup_chatbot(self, url: str) -> str:
+        """Given a YouTube url, set up the chatbot.
+        """

+        yt_audio_to_text(url)
+
+        self.data = load_data()
+
+        _, self.retriever = create_vectordb(self.data, self.k)
+
+
+        self.qa_chain = (
+            RunnablePassthrough.assign(standalone=standalone_chain)
+            | {'question': self.standalone_question,
+               'context': self.standalone_question | self.retriever | self.format_docs}
+            | qa_prompt
+            | llm)
+
+        return url
+
+
+
+    def get_answer(self, question: str) -> str:
+
+        try:
+            ai_msg = self.qa_chain.invoke({'question': question,
+                                           'chat_history': self.chat_history})
+        except AttributeError:
+            raise AttributeError("You haven't setup the chatbot yet. "
+                                 "Setup the chatbot by calling the "
+                                 "instance method `setup_chatbot`.")
+
+        answer = self.format_answer(ai_msg)
+
+        self.chat_history.save_context({'question': question},
+                                       {'answer': answer})
+
+        return answer
+
+
+
+
+
+
+
+class YouTubeChatbotApp(YouTubeChatbot):
+
+    def __init__(self,
+                 n_sources: int,
+                 k: int,
+                 timestamp_interval: datetime.timedelta,
+                 memory: int,
+                 default_youtube_url: str
+                 ):
+        super().__init__(n_sources, k, timestamp_interval, memory)
+        self.default_youtube_url = default_youtube_url
+        self.gradio_chat_history = []
+
+
+    def greet(self) -> list[tuple[str|None, str|None]]:
+        summary = get_summary(self.data)
+        summary_message = f'Here is a summary of the video "{video_title}":'
+        self.gradio_chat_history.append((None, summary_message))
+        self.gradio_chat_history.append((None, summary))
+        greeting_message = ('You can ask me anything about the video. '
+                            'I will do my best to answer!')
+        self.gradio_chat_history.append((None, greeting_message))
+        return self.gradio_chat_history
+
+
+    def question(self, user_message: str) -> tuple[str, list[tuple[str|None, str|None]]]:
+        self.gradio_chat_history.append((user_message, None))
+        return '', self.gradio_chat_history
+
+
+    def respond(self) -> list[tuple[str|None, str|None]]:
+        try:
+            ai_message = self.get_answer(self.gradio_chat_history[-1][0])
+        except AttributeError:
+            raise gr.Error('You need to process the video '
+                           'first by pressing the `Go` button.')


+        self.gradio_chat_history.append((None, ai_message))
+        return self.gradio_chat_history


+    def clear_chat_history(self) -> list:
+        self.chat_history.clear()
+        self.gradio_chat_history = []
+        return self.gradio_chat_history


+ def launch(self, **kwargs):
 
 
 
 
 
 
 
 
476
 
477
+ with gr.Blocks() as demo:
 
 
 
478
 
479
+ # Structure
480
+ with gr.Row():
481
+ url_input = gr.Textbox(value=self.default_youtube_url,
482
+ label='YouTube URL',
483
+ scale=5)
484
+ button = gr.Button(value='Go', scale=1)
485
 
486
+ chatbot = gr.Chatbot()
487
+ user_message = gr.Textbox(label='Ask a question:')
488
+ clear = gr.ClearButton([user_message, chatbot])
 
 
489
 
490
 
491
+ # Actions
492
+ button.click(self.clear_chat_history,
493
+ inputs=[],
494
+ outputs=[chatbot],
495
+ trigger_mode='once'
496
+ ).then(self.setup_chatbot,
497
+ inputs=[url_input],
498
+ outputs=[url_input]
499
+ ).then(self.greet,
500
+ inputs=[],
501
+ outputs=[chatbot])
502
 
503
+ user_message.submit(self.question,
504
+ inputs=[user_message],
505
+ outputs=[user_message, chatbot]
506
+ ).then(self.respond,
507
+ inputs=[],
508
+ outputs=[chatbot])
509
 
510
+ clear.click(self.clear_chat_history, inputs=[], outputs=[chatbot])
 
 
 
 
 
511
 
 
 
 
512
 
513
+ demo.launch(**kwargs)
514
 
 
 
 
 
 
 
 
515
 
 
 
 
 
 
516
 
+app = YouTubeChatbotApp(n_sources=3,
+                        k=5,
+                        timestamp_interval=datetime.timedelta(minutes=2),
+                        memory=5,
+                        default_youtube_url='https://www.youtube.com/watch?v=4Bdc55j80l8'
+                        )

+app.launch(debug=True)
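
For reference, a minimal sketch of driving the new `YouTubeChatbot` class directly, without the Gradio UI. It only uses the constructor and the `setup_chatbot`/`get_answer` methods introduced in this commit; the `HUGGINGFACEHUB_API_TOKEN` environment variable is the standard requirement of the `HuggingFaceHub` LLM wrapper rather than something set in this file, and the token placeholder, URL, and question are illustrative values, not part of the commit.

import os
import datetime

# Assumption: HuggingFaceHub reads HUGGINGFACEHUB_API_TOKEN from the environment;
# replace the placeholder with a real token before running.
os.environ.setdefault('HUGGINGFACEHUB_API_TOKEN', '<your-hf-token>')

bot = YouTubeChatbot(n_sources=3,
                     k=5,
                     timestamp_interval=datetime.timedelta(minutes=2),
                     memory=5)

# Downloads the audio, transcribes it via the whisper-jax Space,
# and builds the vector store and QA chain for this video.
bot.setup_chatbot('https://www.youtube.com/watch?v=4Bdc55j80l8')

# The answer ends with up to `n_sources` transcript timestamps.
print(bot.get_answer('What is the video about?'))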