ww0 committed on
Commit 059dde7
Parent: 7fea2c0

Update app.py

Files changed (1):
  1. app.py +357 -225
app.py CHANGED
@@ -7,43 +7,47 @@ from langchain_community.vectorstores import Chroma
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
  from langchain_core.messages import AIMessage, HumanMessage
- from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain
  from langchain.memory.buffer_window import ConversationBufferWindowMemory
  from langchain_community.llms import HuggingFaceHub

  import yt_dlp
  import json
  import gc
- import gradio as gr
- from gradio_client import Client
  import datetime
  import os



-
- whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
  whisper_jax = Client(whisper_jax_api)

- def transcribe_audio(audio_path,
-                      task='transcribe',
-                      return_timestamps=True) -> str:

      text, runtime = whisper_jax.predict(
          audio_path,
          task,
          return_timestamps,
-         api_name='/predict_1',
      )
      return text


-
-
- def format_whisper_jax_output(whisper_jax_output: str,
-                               max_duration: int = 60) -> list[dict]:
-
-     """Returns a list of dict with keys 'start', 'end', 'text'
      The segments from whisper jax output are merged to form paragraphs.

      `max_duration` controls how many seconds of the audio's transcripts are merged
@@ -54,68 +58,65 @@ def format_whisper_jax_output(whisper_jax_output: str,

      final_output = []
      max_duration = datetime.timedelta(seconds=max_duration)
-     segments = whisper_jax_output.split('\n')
-     current_start = datetime.datetime.strptime('00:00', '%M:%S')
-     current_text = ''

      for i, seg in enumerate(segments):
-
-         text = seg.split(']')[-1].strip()

          # Sometimes whisper jax returns None for timestamp
          try:
-             end = datetime.datetime.strptime(seg[14:19], '%M:%S')
          except ValueError:
              end = current_start + max_duration

-         if (end-current_start >= max_duration) or (i == len(segments)-1):
-             # If we have exceeded max duration or at the last segment,
-             # stop merging and append to final_output.
-
-             current_text += text
-             final_output.append({
-                 'start': current_start.strftime('%H:%M:%S'),
-                 'end': end.strftime('%H:%M:%S'),
-                 'text': current_text
-             })
-
-             # Update current start and text
-             current_start = end
-             current_text = ''

          else:
-             # If we have not exceeded max duration, keep merging.
-             current_text += text
-
-     return final_output
-
-



- audio_file_number = 1
- def yt_audio_to_text(url: str,
-                      max_duration: int = 60
-                      ):

-     global audio_file_number
-
      progress = gr.Progress()
      progress(0.1)

-     with yt_dlp.YoutubeDL({'extract_audio': True,
-                            'format': 'bestaudio',
-                            'outtmpl': f'{audio_file_number}.mp3'
-                            }) as video:
-
          info_dict = video.extract_info(url, download=False)
          global video_title
-         video_title = info_dict['title']
          video.download(url)

      progress(0.4)
-     audio_file = f'{audio_file_number}.mp3'
-     audio_file_number += 1

      result = transcribe_audio(audio_file, return_timestamps=True)
      progress(0.7)
@@ -123,61 +124,83 @@ def yt_audio_to_text(url: str,
      result = format_whisper_jax_output(result, max_duration=max_duration)
      progress(0.9)

-     with open('audio.json', 'w') as f:
          json.dump(result, f)




- def metadata_func(record: dict, metadata: dict) -> dict:

-     metadata['start'] = record.get('start')
-     metadata['end'] = record.get('end')
-     metadata['source'] = metadata['start'] + ' -> ' + metadata['end']

      return metadata


  def load_data():
      loader = JSONLoader(
-         file_path='audio.json',
-         jq_schema='.[]',
-         content_key='text',
-         metadata_func=metadata_func
      )

      data = loader.load()

      return data




- embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- embedding_model_kwargs = {'device': device}

- embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
-                                    model_kwargs=embedding_model_kwargs)

- def create_vectordb(data, k: int):
-     """Returns a vector database, and its retriever
-     `k` is the number of retrieved documents
      """

-     vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
-     retriever = vectordb.as_retriever(search_type='similarity',
-                                       search_kwargs={'k': k})

-     return vectordb, retriever




- repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
- llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_new_tokens': 1000})



  # Map
  map_template = """Summarise the following text:
@@ -187,8 +210,6 @@ Answer:"""
  map_prompt = PromptTemplate.from_template(map_template)
  map_chain = LLMChain(llm=llm, prompt=map_prompt)

-
-
  # Reduce
  reduce_template = """The following is a set of summaries:
  {docs}
@@ -214,7 +235,7 @@ reduce_documents_chain = ReduceDocumentsChain(
      # If documents exceed context for `StuffDocumentsChain`
      collapse_documents_chain=combine_documents_chain,
      # The maximum number of tokens to group documents into.
-     token_max=4000
  )

@@ -227,21 +248,24 @@ map_reduce_chain = MapReduceDocumentsChain(
      # The variable name in the llm_chain to put the documents in
      document_variable_name="docs",
      # Return the results of the map steps in the output
-     return_intermediate_steps=False
  )

  def get_summary(documents) -> str:
      summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
-     return summary['output_text'].strip()




  contextualise_q_prompt = PromptTemplate.from_template(
      """Given a chat history and the latest user question \
-     which might reference the chat history, formulate a standalone question \
-     that can be understood without the chat history. Do NOT answer the question, \
-     just reformulate it if needed and otherwise return it as is.

      Chat history: {chat_history}
@@ -255,12 +279,15 @@ contextualise_q_chain = contextualise_q_prompt | llm



  standalone_prompt = PromptTemplate.from_template(
      """Given a chat history and the latest user question, \
-     identify whether the question is a standalone question or the question \
-     references the chat history. Answer 'yes' if the question is a standalone \
-     question, and 'no' if the question references the chat history. Do not \
-     answer anything other than 'yes' or 'no'.

      Chat history:
      {chat_history}
@@ -272,14 +299,19 @@ standalone_prompt = PromptTemplate.from_template(
      """
  )

  def format_output(answer: str) -> str:
-     # All lower case and remove all whitespace
-     return ''.join(answer.lower().split())

  standalone_chain = standalone_prompt | llm | format_output




  qa_prompt = PromptTemplate.from_template(
      """You are an assistant for question-answering tasks. \
@@ -298,226 +330,326 @@ qa_prompt = PromptTemplate.from_template(
      """
  )

-
-
-
  class YouTubeChatbot:
-
-     def __init__(self,
-                  n_sources: int,
-                  k: int,
-                  timestamp_interval: datetime.timedelta,
-                  memory: int,
-                  ):
          self.n_sources = n_sources
-         self.k = k
          self.timestamp_interval = timestamp_interval
          self.chat_history = ConversationBufferWindowMemory(k=memory)


      def format_docs(self, docs: list) -> str:
-         """Combine documents
          """

-         self.sources = [doc.metadata['start'] for doc in docs]
-
-         return '\n\n'.join(doc.page_content for doc in docs)
-


      def standalone_question(self, input_: dict) -> str:
-         """If the question is a not a standalone question,
          run contextualise_q_chain.
          """
-         if input_['standalone']=='yes':
              return contextualise_q_chain
          else:
-             return input_['question']


      def format_answer(self, answer: str) -> str:
-
-         if 'cannot find the answer' in answer:
              return answer.strip()
          else:
              timestamps = self.filter_timestamps()
              answer_with_sources = (
-                 answer.strip()
-                 + ' You can find more information '\
-                 'at these timestamps: {}.'.format(', '.join(timestamps))
-             )
              return answer_with_sources


      def filter_timestamps(self) -> list[str]:
-         """Returns a list of timestamps with length `n_sources`.
-         The timestamps are at least an `timestamp_interval` apart.
-         This prevents returning a list of timestamps that are too
-         close together.
          """

-         sorted_timestamps = sorted(self.sources)
-         filtered_timestamps = [sorted_timestamps[0]]
-         i=1
-         while len(filtered_timestamps) < self.n_sources:
-             timestamp1 = datetime.datetime.strptime(filtered_timestamps[-1],
-                                                     '%H:%M:%S')

              try:
-                 timestamp2 = datetime.datetime.strptime(sorted_timestamps[i],
-                                                         '%H:%M:%S')
              except IndexError:
                  break

-             time_diff = timestamp2 - timestamp1

-             if time_diff>=self.timestamp_interval:
-                 filtered_timestamps.append(str(timestamp2.time()))

              i += 1

          return filtered_timestamps


-     def setup_chatbot(self, url: str) -> str:
-         """Given a YouTube url, set up the chatbot.
          """
-
          yt_audio_to_text(url)

-         self.data = load_data()

-         _, self.retriever = create_vectordb(self.data, self.k)


-         self.qa_chain = (
              RunnablePassthrough.assign(standalone=standalone_chain)
-             | {'question':self.standalone_question,
-                'context':self.standalone_question|self.retriever|self.format_docs}
              | qa_prompt
-             | llm)

-         return url



-     def get_answer(self, question: str) -> str:

          try:
-             ai_msg = self.qa_chain.invoke({'question': question,
-                                            'chat_history': self.chat_history})
          except AttributeError:
-             raise AttributeError("You haven't setup the chatbot yet. "
-                                  "Setup the chatbot by calling the "
-                                  "instance method `setup_chatbot`.")

-         answer = self.format_answer(ai_msg)

-         self.chat_history.save_context({'question':question},
-                                        {'answer':answer})

          return answer
-




  class YouTubeChatbotApp(YouTubeChatbot):

-     def __init__(self,
-                  n_sources: int,
-                  k: int,
-                  timestamp_interval: datetime.timedelta,
-                  memory: int,
-                  default_youtube_url: str
-                  ):
-         super().__init__(n_sources, k, timestamp_interval, memory)
          self.default_youtube_url = default_youtube_url
-         self.gradio_chat_history = []


-     def greet(self) -> list[tuple[str|None, str|None]]:
-         summary = get_summary(self.data)
-         summary_message = f'Here is a summary of the video "{video_title}":'
-         self.gradio_chat_history.append((None, summary_message))
-         self.gradio_chat_history.append((None, summary))
-         greeting_message = ('You can ask me anything about the video. '
-                             'I will do my best to answer!')
-         self.gradio_chat_history.append((None, greeting_message))
-         return self.gradio_chat_history


-     def question(self, user_message: str) -> list[tuple[str|None, str|None]]:
-         self.gradio_chat_history.append((user_message, None))
-         return '', self.gradio_chat_history


-     def respond(self) -> tuple[str, list[tuple[str|None, str|None]]]:
          try:
-             ai_message = self.get_answer(self.gradio_chat_history[-1][0])
          except AttributeError:
-             raise gr.Error('You need to process the video '
-                            'first by pressing the `Go` button.')


-         self.gradio_chat_history.append((None, ai_message))
-         return self.gradio_chat_history


-     def clear_chat_history(self) -> list:
-         self.chat_history.clear()
-         self.gradio_chat_history = []
-         return self.gradio_chat_history


      def launch(self, **kwargs):
-
          with gr.Blocks() as demo:

-             # Structure
              with gr.Row():
-                 url_input = gr.Textbox(value=self.default_youtube_url,
-                                        label='YouTube URL',
-                                        scale=5)
-                 button = gr.Button(value='Go', scale=1)
-
-             chatbot = gr.Chatbot()
-             user_message = gr.Textbox(label='Ask a question:')
-             clear = gr.ClearButton([user_message, chatbot])
-
-
-             # Actions
-             button.click(self.clear_chat_history,
-                          inputs=[],
-                          outputs=[chatbot],
-                          trigger_mode='once'
-                          ).then(self.setup_chatbot,
-                                 inputs=[url_input],
-                                 outputs=[url_input]
-                                 ).then(self.greet,
-                                        inputs=[],
-                                        outputs=[chatbot])
-
-             user_message.submit(self.question,
-                                 inputs=[user_message],
-                                 outputs=[user_message, chatbot]
-                                 ).then(self.respond,
-                                        inputs=[],
-                                        outputs=[chatbot])
-
-             clear.click(self.clear_chat_history, inputs=[], outputs=[chatbot])
-

          demo.launch(**kwargs)


-
  if __name__ == "__main__":
-     app = YouTubeChatbotApp(n_sources=3,
-                             k=5,
-                             timestamp_interval=datetime.timedelta(minutes=2),
-                             memory=5,
-                             default_youtube_url='https://www.youtube.com/watch?v=4Bdc55j80l8'
-                             )
-
-     app.launch()

app.py (updated file):

  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
  from langchain_core.messages import AIMessage, HumanMessage

  from langchain.memory.buffer_window import ConversationBufferWindowMemory
  from langchain_community.llms import HuggingFaceHub
+ from langchain.chains import (
+     LLMChain,
+     StuffDocumentsChain,
+     MapReduceDocumentsChain,
+     ReduceDocumentsChain,
+ )

+ from gradio_client import Client
+ import gradio as gr
  import yt_dlp
  import json
  import gc
  import datetime
  import os
+ import numpy as np


+ """Prepare data"""

+ whisper_jax_api = "https://sanchit-gandhi-whisper-jax.hf.space/"
  whisper_jax = Client(whisper_jax_api)


+ def transcribe_audio(audio_path, task="transcribe", return_timestamps=True) -> str:
      text, runtime = whisper_jax.predict(
          audio_path,
          task,
          return_timestamps,
+         api_name="/predict_1",
      )
      return text

+ def format_whisper_jax_output(
+     whisper_jax_output: str, max_duration: int = 60
+ ) -> list[dict]:
+     """Whisper JAX outputs are in the format
+     '[00:00.000 -> 00:00.000] text\n[00:00.000 -> 00:00.000] text'.
+
+     Returns a list of dict with keys 'start', 'end', 'text'
      The segments from whisper jax output are merged to form paragraphs.

      `max_duration` controls how many seconds of the audio's transcripts are merged

      final_output = []
      max_duration = datetime.timedelta(seconds=max_duration)
+     segments = whisper_jax_output.split("\n")
+     current_start = datetime.datetime.strptime("00:00", "%M:%S")
+     current_text = ""

      for i, seg in enumerate(segments):
+         text = seg.split("]")[-1].strip()
+         current_text += " " + text

          # Sometimes whisper jax returns None for timestamp
          try:
+             end = datetime.datetime.strptime(seg[14:19], "%M:%S")
          except ValueError:
              end = current_start + max_duration

+         if i == len(segments) - 1:
+             final_output.append(
+                 {
+                     "start": current_start.strftime("%H:%M:%S"),
+                     "end": end.strftime("%H:%M:%S"),
+                     "text": current_text.strip(),
+                 }
+             )

          else:
+             if end - current_start >= max_duration and current_text[-1] == ".":
+                 # If we have exceeded max duration, check whether we have
+                 # reached the end of a sentence. If not, keep merging.
+                 final_output.append(
+                     {
+                         "start": current_start.strftime("%H:%M:%S"),
+                         "end": end.strftime("%H:%M:%S"),
+                         "text": current_text.strip(),
+                     }
+                 )

+                 # Update current start and text
+                 current_start = end
+                 current_text = ""

+     return final_output

+ def yt_audio_to_text(url: str, max_duration: int = 60):
+     """Given a YouTube url, download audio and transcribe it to text. Reformat
+     the output from Whisper JAX and save the final result in a json file.
+     """

      progress = gr.Progress()
      progress(0.1)

+     with yt_dlp.YoutubeDL(
+         {"extract_audio": True, "format": "bestaudio", "outtmpl": "audio.mp3"}
+     ) as video:
          info_dict = video.extract_info(url, download=False)
          global video_title
+         video_title = info_dict["title"]
          video.download(url)

      progress(0.4)
+     audio_file = "audio.mp3"

      result = transcribe_audio(audio_file, return_timestamps=True)
      progress(0.7)

      result = format_whisper_jax_output(result, max_duration=max_duration)
      progress(0.9)

+     with open("audio.json", "w") as f:
          json.dump(result, f)

131
 
132
 
133
 
 
134
 
135
+ """Load data"""
136
+
137
+ def metadata_func(record: dict, metadata: dict) -> dict:
138
+ """This function is used to tell the Langchain loader the keys that
139
+ contain metadata and extract them.
140
+ """
141
+ metadata["start"] = record.get("start")
142
+ metadata["end"] = record.get("end")
143
+ metadata["source"] = metadata["start"] + " -> " + metadata["end"]
144
 
145
  return metadata
146
 
147
 
148
  def load_data():
149
  loader = JSONLoader(
150
+ file_path="audio.json",
151
+ jq_schema=".[]",
152
+ content_key="text",
153
+ metadata_func=metadata_func,
154
  )
155
 
156
  data = loader.load()
157
+ os.remove("audio.json")
158
 
159
  return data
160
 
161
 
162
 
163
 
164
+ """Create embeddings and vector store"""
 
 
165
 
166
+ embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
167
+ device = "cuda" if torch.cuda.is_available() else "cpu"
168
+ embedding_model_kwargs = {"device": device}
169
 
170
+ embeddings = HuggingFaceEmbeddings(
171
+ model_name=embedding_model_name, model_kwargs=embedding_model_kwargs
172
+ )
173
+
174
+
175
+ def create_vectordb(data, n_retrieved_docs: int, collection_name="YouTube"):
176
+ """Returns a retriever which is used to fetch relevant documents from
177
+ the vector database.
178
+
179
+ `n_retrieved_docs` is the number of retrieved documents.
180
  """
181
 
182
+ vectordb = Chroma.from_documents(
183
+ documents=data, embedding=embeddings, collection_name=collection_name
184
+ )
185
+ n_docs = len(vectordb.get()["ids"])
186
+ retriever = vectordb.as_retriever(
187
+ search_type="mmr", search_kwargs={"k": n_retrieved_docs, "fetch_k": n_docs}
188
+ )
189
+
190
+ return retriever
191
+
192
+
193
 
 
194
 
195
+ """Load LLM"""
196
 
197
+ repo_id = "mistralai/Mistral-7B-Instruct-v0.1"
198
+ llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"max_new_tokens": 1000})
199
 
200
 
 
 
201
 
202
 
203
+ """Summarisation"""
204
 
205
  # Map
206
  map_template = """Summarise the following text:
 
210
  map_prompt = PromptTemplate.from_template(map_template)
211
  map_chain = LLMChain(llm=llm, prompt=map_prompt)
212
 
 
 
213
  # Reduce
214
  reduce_template = """The following is a set of summaries:
215
  {docs}
 
235
  # If documents exceed context for `StuffDocumentsChain`
236
  collapse_documents_chain=combine_documents_chain,
237
  # The maximum number of tokens to group documents into.
238
+ token_max=4000,
239
  )
240
 
241
 
 
248
  # The variable name in the llm_chain to put the documents in
249
  document_variable_name="docs",
250
  # Return the results of the map steps in the output
251
+ return_intermediate_steps=False,
252
  )
253
 
254
+
255
  def get_summary(documents) -> str:
256
  summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
257
+ return summary["output_text"].strip()
258
 
259
 
260
 
261
+ """Contextualising the question"""
262
 
263
  contextualise_q_prompt = PromptTemplate.from_template(
264
  """Given a chat history and the latest user question \
265
+ which might reference the chat history, formulate a \
266
+ standalone question that can be understood without \
267
+ the chat history. Do NOT answer the question, just \
268
+ reformulate it if needed and otherwise return it as is.
269
 
270
  Chat history: {chat_history}
271
 
 
279
 
280
 
281
 
282
+ """Standalone question chain"""
283
+
284
  standalone_prompt = PromptTemplate.from_template(
285
  """Given a chat history and the latest user question, \
286
+ identify whether the question is a standalone question \
287
+ or the question references the chat history. Answer 'yes' \
288
+ if the question is a standalone question, and 'no' if the \
289
+ question references the chat history. Do not answer \
290
+ anything other than 'yes' or 'no'.
291
 
292
  Chat history:
293
  {chat_history}
 
299
  """
300
  )
301
 
302
+
303
  def format_output(answer: str) -> str:
304
+ """All lower case and remove all whitespace to ensure
305
+ that the answer given by the LLM is either 'yes' or 'no'.
306
+ """
307
+ return "".join(answer.lower().split())
308
+
309
 
310
  standalone_chain = standalone_prompt | llm | format_output
311
 
312
 
313
 
314
+ """Q&A chain"""
315
 
316
  qa_prompt = PromptTemplate.from_template(
317
  """You are an assistant for question-answering tasks. \
 
330
  """
331
  )
332
 
 
 
 
333
  class YouTubeChatbot:
+     instance_count = 0
+
+     def __init__(
+         self,
+         n_sources: int = 3,
+         n_retrieved_docs: int = 5,
+         timestamp_interval: datetime.timedelta = datetime.timedelta(minutes=2),
+         memory: int = 5,
+     ):
+         YouTubeChatbot.instance_count += 1
+         self.chatbot_id = YouTubeChatbot.instance_count
          self.n_sources = n_sources
+         self.n_retrieved_docs = n_retrieved_docs
          self.timestamp_interval = timestamp_interval
          self.chat_history = ConversationBufferWindowMemory(k=memory)
+         self.retriever = None
+         self.qa_chain = None


      def format_docs(self, docs: list) -> str:
+         """Combine documents into a single string which will be included
+         in the prompt given to the LLM.
          """
+         self.sources = [doc.metadata["start"] for doc in docs]

+         return "\n\n".join(doc.page_content for doc in docs)


      def standalone_question(self, input_: dict) -> str:
+         """If the question is a not a standalone question,
          run contextualise_q_chain.
          """
+         if input_["standalone"] == "yes":
              return contextualise_q_chain
          else:
+             return input_["question"]


      def format_answer(self, answer: str) -> str:
+         """Add timestamps to answers.
+         """
+         if "cannot find the answer" in answer:
              return answer.strip()
          else:
              timestamps = self.filter_timestamps()
              answer_with_sources = (
+                 answer.strip() + " You can find more information "
+                 "at these timestamps: {}.".format(", ".join(timestamps))
+             )
              return answer_with_sources


      def filter_timestamps(self) -> list[str]:
+         """Returns a list of timestamps with length less or
+         equal to `n_sources`. The timestamps are at least an
+         `timestamp_interval` apart. This prevents returning
+         a list of timestamps that are too close together.
          """

+         filtered_timestamps = np.array(
+             [datetime.datetime.strptime(self.sources[0], "%H:%M:%S")]
+         )

+         i = 1
+
+         while len(filtered_timestamps) < self.n_sources:
              try:
+                 new_timestamp = datetime.datetime.strptime(self.sources[i], "%H:%M:%S")
              except IndexError:
                  break

+             absolute_time_difference = abs(new_timestamp - filtered_timestamps)

+             if all(absolute_time_difference >= self.timestamp_interval):
+                 filtered_timestamps = np.append(filtered_timestamps, new_timestamp)

              i += 1

+         filtered_timestamps = [
+             timestamp.strftime("%H:%M:%S") for timestamp in filtered_timestamps
+         ]
+         filtered_timestamps.sort()
+
          return filtered_timestamps


+     def process_video(self, url: str, data=None, retriever=None):
+         """Given a YouTube URL, transcribe YouTube audio to text.
+         Then set up the vector database.
          """
          yt_audio_to_text(url)
+         data = load_data()

+         if retriever is not None:
+             # If we already have documents in the vector store, delete them.
+             ids = retriever.vectorstore.get()["ids"]
+             retriever.vectorstore.delete(ids)

+         retriever = create_vectordb(
+             data, self.n_retrieved_docs,
+             collection_name=f"Chatbot{self.chatbot_id}"
+         )

+         return url, data, retriever
+
+     def setup_qa_chain(self, retriever, qa_chain=None):
+         qa_chain = (
              RunnablePassthrough.assign(standalone=standalone_chain)
+             | {
+                 "question": self.standalone_question,
+                 "context": self.standalone_question | retriever | self.format_docs,
+             }
              | qa_prompt
+             | llm
+         )

+         return retriever, qa_chain


+     def setup_chatbot(self, url: str):
+         _, _, self.retriever = self.process_video(url=url, retriever=self.retriever)
+         _, self.qa_chain = self.setup_qa_chain(retriever=self.retriever)


+     def get_answer(self, question: str) -> str:
          try:
+             ai_msg = self.qa_chain.invoke(
+                 {"question": question, "chat_history": self.chat_history}
+             )
          except AttributeError:
+             raise AttributeError(
+                 "You haven't setup the chatbot yet. "
+                 "Setup the chatbot by calling the "
+                 "instance method `setup_chatbot`."
+             )

+         self.chat_history.save_context({"question": question}, {"answer": ai_msg})

+         answer = self.format_answer(ai_msg)

          return answer



+ """Web app"""

  class YouTubeChatbotApp(YouTubeChatbot):
+     def __init__(
+         self,
+         n_sources: int,
+         n_retrieved_docs: int,
+         timestamp_interval: datetime.timedelta,
+         memory: int,
+         default_youtube_url: str,
+     ):
+         super().__init__(n_sources, n_retrieved_docs, timestamp_interval, memory)

          self.default_youtube_url = default_youtube_url
+         self.memory = memory
+         self.chat_history = None
+         self.data = None
+         self.retriever = None
+         self.qa_chain = None
+
+         # Gradio components
+         self.url_input = None
+         self.url_button = None
+         self.app_chat_history = None
+         self.chatbot = None
+         self.user_input = None
+         self.clear_button = None
+
+     def greet(self, data, app_chat_history) -> dict:
+         """Summarise the video and greet the user.
+         """
+         summary_message = f'Here is a summary of the video "{video_title}":'
+         app_chat_history.append((None, summary_message))

+         summary = get_summary(data)
+         self.data = gr.State(None)
+         app_chat_history.append((None, summary))

+         greeting_message = (
+             "You can ask me anything about the video. " "I will do my best to answer!"
+         )
+         app_chat_history.append((None, greeting_message))

+         return {self.app_chat_history: app_chat_history, self.chatbot: app_chat_history}

+     def question(self, user_question: str, app_chat_history) -> dict:
+         """Display the question asked by the user in the chat window,
+         and delete from the input textbox.
+         """
+         app_chat_history.append((user_question, None))
+         return {
+             self.user_input: "",
+             self.app_chat_history: app_chat_history,
+             self.chatbot: app_chat_history,
+         }

+     def respond(self, qa_chain, chat_history, app_chat_history) -> dict:
+         """Respond to user's latest question"""
+         question = app_chat_history[-1][0]

          try:
+             ai_msg = qa_chain.invoke(
+                 {"question": question, "chat_history": chat_history}
+             )
          except AttributeError:
+             raise gr.Error(
+                 "You need to process the video " "first by pressing the `Go` button."
+             )

+         chat_history.save_context({"question": question}, {"answer": ai_msg})

+         answer = self.format_answer(ai_msg)

+         app_chat_history.append((None, answer))

+         return {
+             self.qa_chain: qa_chain,
+             self.chat_history: chat_history,
+             self.app_chat_history: app_chat_history,
+             self.chatbot: app_chat_history,
+         }

+     def clear_chat_history(self, chat_history, app_chat_history):
+         chat_history.clear()
+         app_chat_history = []
+         return {
+             self.chat_history: chat_history,
+             self.app_chat_history: app_chat_history,
+             self.chatbot: app_chat_history,
+         }

      def launch(self, **kwargs):
          with gr.Blocks() as demo:
+             self.chat_history = gr.State(ConversationBufferWindowMemory(k=self.memory))
+             self.app_chat_history = gr.State([])
+             self.data = gr.State()
+             self.retriever = gr.State()
+             self.qa_chain = gr.State()

+             # App structure
              with gr.Row():
+                 self.url_input = gr.Textbox(
+                     value=self.default_youtube_url, label="YouTube URL", scale=5
+                 )
+                 self.url_button = gr.Button(value="Go", scale=1)
+
+             self.chatbot = gr.Chatbot()
+             self.user_input = gr.Textbox(label="Ask a question:")
+             self.clear_button = gr.Button(value="Clear")
+
+
+             # App actions
+
+             # When a new url is given, clear past chat history and process
+             # the new video. Set up the Q&A chain with the new video's data.
+             # Provide a summary of the new video.
+             self.url_button.click(
+                 self.clear_chat_history,
+                 inputs=[self.chat_history, self.app_chat_history],
+                 outputs=[self.chat_history, self.app_chat_history, self.chatbot],
+                 trigger_mode="once",
+             ).then(
+                 self.process_video,
+                 inputs=[self.url_input, self.data, self.retriever],
+                 outputs=[self.url_input, self.data, self.retriever],
+             ).then(
+                 self.setup_qa_chain,
+                 inputs=[self.retriever, self.qa_chain],
+                 outputs=[self.retriever, self.qa_chain],
+             ).then(
+                 self.greet,
+                 inputs=[self.data, self.app_chat_history],
+                 outputs=[self.app_chat_history, self.chatbot],
+             )
+
+             # When a user asks a question, display the question in the chat
+             # window and remove it from the text input area. Then respond
+             # with the Q&A chain.
+             self.user_input.submit(
+                 self.question,
+                 inputs=[self.user_input, self.app_chat_history],
+                 outputs=[self.user_input, self.app_chat_history, self.chatbot],
+                 queue=False,
+             ).then(
+                 self.respond,
+                 inputs=[self.qa_chain, self.chat_history, self.app_chat_history],
+                 outputs=[
+                     self.qa_chain,
+                     self.chat_history,
+                     self.app_chat_history,
+                     self.chatbot,
+                 ],
+             )
+
+             # When the `Clear` button is clicked, clear the chat history from
+             # the chat window.
+             self.clear_button.click(
+                 self.clear_chat_history,
+                 inputs=[self.chat_history, self.app_chat_history],
+                 outputs=[self.chat_history, self.app_chat_history, self.chatbot],
+                 queue=False,
+             )

          demo.launch(**kwargs)


  if __name__ == "__main__":
+     app = YouTubeChatbotApp(
+         n_sources=3,
+         n_retrieved_docs=5,
+         timestamp_interval=datetime.timedelta(minutes=2),
+         memory=5,
+         default_youtube_url="https://www.youtube.com/watch?v=4Bdc55j80l8",
+     )
+
+     app.launch()
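
Side note for reviewers: the new format_whisper_jax_output docstring describes the raw Whisper JAX segment format. The snippet below is a minimal, self-contained sketch (not part of the commit; the two-segment transcript string is made up) of how each segment's end timestamp and text are pulled out before the merging step.

# Illustrative sketch only, not part of app.py.
import datetime

sample = (
    "[00:00.000 -> 00:04.000]  Hello and welcome to the video.\n"
    "[00:04.000 -> 00:09.000]  Today we will build a chatbot."
)

for seg in sample.split("\n"):
    text = seg.split("]")[-1].strip()                       # transcript text after the timestamps
    end = datetime.datetime.strptime(seg[14:19], "%M:%S")   # end timestamp, the MM:SS slice
    print(end.strftime("%H:%M:%S"), text)

Running this prints "00:00:04 Hello and welcome to the video." and "00:00:09 Today we will build a chatbot.", which is the per-segment information the updated merging logic accumulates into paragraphs keyed by 'start', 'end', and 'text'.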