Update app.py

app.py CHANGED

# Imports needed by the code below (the diff omits the unchanged top of app.py).
import torch
from langchain_community.document_loaders import JSONLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain_community.llms import HuggingFaceHub
from langchain.chains import (
    LLMChain,
    StuffDocumentsChain,
    MapReduceDocumentsChain,
    ReduceDocumentsChain,
)

from gradio_client import Client
import gradio as gr
import yt_dlp
import json
import gc
import datetime
import os
import numpy as np


"""Prepare data"""

whisper_jax_api = "https://sanchit-gandhi-whisper-jax.hf.space/"
whisper_jax = Client(whisper_jax_api)


def transcribe_audio(audio_path, task="transcribe", return_timestamps=True) -> str:
    text, runtime = whisper_jax.predict(
        audio_path,
        task,
        return_timestamps,
        api_name="/predict_1",
    )
    return text
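
# Note: whisper_jax.predict returns a (text, runtime) pair; only the
# transcription text is used downstream.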


def format_whisper_jax_output(
    whisper_jax_output: str, max_duration: int = 60
) -> list[dict]:
    """Whisper JAX outputs are in the format
    '[00:00.000 -> 00:00.000] text\n[00:00.000 -> 00:00.000] text'.

    Returns a list of dict with keys 'start', 'end', 'text'.
    The segments from whisper jax output are merged to form paragraphs.

    `max_duration` controls how many seconds of the audio's transcripts are merged
    ...
    """

    final_output = []
    max_duration = datetime.timedelta(seconds=max_duration)
    segments = whisper_jax_output.split("\n")
    current_start = datetime.datetime.strptime("00:00", "%M:%S")
    current_text = ""

    for i, seg in enumerate(segments):
        text = seg.split("]")[-1].strip()
        current_text += " " + text

        # Sometimes whisper jax returns None for timestamp
        try:
            end = datetime.datetime.strptime(seg[14:19], "%M:%S")
        except ValueError:
            end = current_start + max_duration

        if i == len(segments) - 1:
            final_output.append(
                {
                    "start": current_start.strftime("%H:%M:%S"),
                    "end": end.strftime("%H:%M:%S"),
                    "text": current_text.strip(),
                }
            )
        else:
            if end - current_start >= max_duration and current_text[-1] == ".":
                # If we have exceeded max duration, check whether we have
                # reached the end of a sentence. If not, keep merging.
                final_output.append(
                    {
                        "start": current_start.strftime("%H:%M:%S"),
                        "end": end.strftime("%H:%M:%S"),
                        "text": current_text.strip(),
                    }
                )

                # Update current start and text
                current_start = end
                current_text = ""

    return final_output
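
# Illustration (not executed by the app): with max_duration=60, the input
# '[00:00.000 -> 00:05.000] Hello world.\n[00:05.000 -> 00:10.000] Bye.'
# merges into one paragraph:
# [{'start': '00:00:00', 'end': '00:00:10', 'text': 'Hello world. Bye.'}]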


def yt_audio_to_text(url: str, max_duration: int = 60):
    """Given a YouTube url, download audio and transcribe it to text. Reformat
    the output from Whisper JAX and save the final result in a json file.
    """

    progress = gr.Progress()
    progress(0.1)

    with yt_dlp.YoutubeDL(
        {"extract_audio": True, "format": "bestaudio", "outtmpl": "audio.mp3"}
    ) as video:
        info_dict = video.extract_info(url, download=False)
        global video_title
        video_title = info_dict["title"]
        video.download([url])

    progress(0.4)
    audio_file = "audio.mp3"

    result = transcribe_audio(audio_file, return_timestamps=True)
    progress(0.7)

    result = format_whisper_jax_output(result, max_duration=max_duration)
    progress(0.9)

    with open("audio.json", "w") as f:
        json.dump(result, f)

    os.remove(audio_file)


"""Load data"""

def metadata_func(record: dict, metadata: dict) -> dict:
    """This function is used to tell the Langchain loader the keys that
    contain metadata and extract them.
    """
    metadata["start"] = record.get("start")
    metadata["end"] = record.get("end")
    metadata["source"] = metadata["start"] + " -> " + metadata["end"]

    return metadata


def load_data():
    loader = JSONLoader(
        file_path="audio.json",
        jq_schema=".[]",
        content_key="text",
        metadata_func=metadata_func,
    )

    data = loader.load()
    os.remove("audio.json")

    return data
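
# For reference: a record like {'start': '00:00:00', 'end': '00:01:00',
# 'text': '...'} is loaded as a Document whose metadata carries 'start',
# 'end' and 'source' == '00:00:00 -> 00:01:00'.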


"""Create embeddings and vector store"""

embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model_kwargs = {"device": device}

embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_name, model_kwargs=embedding_model_kwargs
)


def create_vectordb(data, n_retrieved_docs: int, collection_name="YouTube"):
    """Returns a retriever which is used to fetch relevant documents from
    the vector database.

    `n_retrieved_docs` is the number of retrieved documents.
    """

    vectordb = Chroma.from_documents(
        documents=data, embedding=embeddings, collection_name=collection_name
    )
    n_docs = len(vectordb.get()["ids"])
    retriever = vectordb.as_retriever(
        search_type="mmr", search_kwargs={"k": n_retrieved_docs, "fetch_k": n_docs}
    )

    return retriever
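
# Illustrative use of the retriever (not part of the app flow); MMR search
# draws from up to `fetch_k` candidates and returns `k` diverse documents:
#     retriever = create_vectordb(load_data(), n_retrieved_docs=5)
#     docs = retriever.get_relevant_documents("What is the video about?")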


"""Load LLM"""

repo_id = "mistralai/Mistral-7B-Instruct-v0.1"
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"max_new_tokens": 1000})


"""Summarisation"""

# Map
map_template = """Summarise the following text:
{docs}
Answer:"""
map_prompt = PromptTemplate.from_template(map_template)
map_chain = LLMChain(llm=llm, prompt=map_prompt)

# Reduce
reduce_template = """The following is a set of summaries:
{docs}
...
"""
reduce_prompt = PromptTemplate.from_template(reduce_template)
reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

# Takes the list of mapped summaries, combines them into a single string,
# and passes that to reduce_chain
combine_documents_chain = StuffDocumentsChain(
    llm_chain=reduce_chain, document_variable_name="docs"
)

# Combines and iteratively reduces the mapped documents
reduce_documents_chain = ReduceDocumentsChain(
    combine_documents_chain=combine_documents_chain,
    # If documents exceed context for `StuffDocumentsChain`
    collapse_documents_chain=combine_documents_chain,
    # The maximum number of tokens to group documents into.
    token_max=4000,
)


# Combine documents by mapping a chain over them, then combining results
map_reduce_chain = MapReduceDocumentsChain(
    # Map chain
    llm_chain=map_chain,
    # Reduce chain
    reduce_documents_chain=reduce_documents_chain,
    # The variable name in the llm_chain to put the documents in
    document_variable_name="docs",
    # Return the results of the map steps in the output
    return_intermediate_steps=False,
)


def get_summary(documents) -> str:
    summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
    return summary["output_text"].strip()


"""Contextualising the question"""

contextualise_q_prompt = PromptTemplate.from_template(
    """Given a chat history and the latest user question \
which might reference the chat history, formulate a \
standalone question that can be understood without \
the chat history. Do NOT answer the question, just \
reformulate it if needed and otherwise return it as is.

Chat history: {chat_history}

Question: {question}
...
"""
)

contextualise_q_chain = contextualise_q_prompt | llm


"""Standalone question chain"""

standalone_prompt = PromptTemplate.from_template(
    """Given a chat history and the latest user question, \
identify whether the question is a standalone question \
or the question references the chat history. Answer 'yes' \
if the question is a standalone question, and 'no' if the \
question references the chat history. Do not answer \
anything other than 'yes' or 'no'.

Chat history:
{chat_history}

Question:
{question}
...
"""
)


def format_output(answer: str) -> str:
    """All lower case and remove all whitespace to ensure
    that the answer given by the LLM is either 'yes' or 'no'.
    """
    return "".join(answer.lower().split())


standalone_chain = standalone_prompt | llm | format_output


"""Q&A chain"""

qa_prompt = PromptTemplate.from_template(
    """You are an assistant for question-answering tasks. \
...

Context: {context}

Question: {question}
...
"""
)


class YouTubeChatbot:
    instance_count = 0

    def __init__(
        self,
        n_sources: int = 3,
        n_retrieved_docs: int = 5,
        timestamp_interval: datetime.timedelta = datetime.timedelta(minutes=2),
        memory: int = 5,
    ):
        YouTubeChatbot.instance_count += 1
        self.chatbot_id = YouTubeChatbot.instance_count
        self.n_sources = n_sources
        self.n_retrieved_docs = n_retrieved_docs
        self.timestamp_interval = timestamp_interval
        self.chat_history = ConversationBufferWindowMemory(k=memory)
        self.retriever = None
        self.qa_chain = None

    def format_docs(self, docs: list) -> str:
        """Combine documents into a single string which will be included
        in the prompt given to the LLM.
        """
        self.sources = [doc.metadata["start"] for doc in docs]

        return "\n\n".join(doc.page_content for doc in docs)

    def standalone_question(self, input_: dict) -> str:
        """If the question is not a standalone question,
        run contextualise_q_chain.
        """
        if input_["standalone"] == "no":
            return contextualise_q_chain
        else:
            return input_["question"]

    def format_answer(self, answer: str) -> str:
        """Add timestamps to answers."""
        if "cannot find the answer" in answer:
            return answer.strip()
        else:
            timestamps = self.filter_timestamps()
            answer_with_sources = (
                answer.strip() + " You can find more information "
                "at these timestamps: {}.".format(", ".join(timestamps))
            )
            return answer_with_sources

    def filter_timestamps(self) -> list[str]:
        """Returns a list of timestamps with length less than or
        equal to `n_sources`. The timestamps are at least
        `timestamp_interval` apart. This prevents returning
        a list of timestamps that are too close together.
        """

        filtered_timestamps = np.array(
            [datetime.datetime.strptime(self.sources[0], "%H:%M:%S")]
        )

        i = 1

        while len(filtered_timestamps) < self.n_sources:
            try:
                new_timestamp = datetime.datetime.strptime(self.sources[i], "%H:%M:%S")
            except IndexError:
                break

            absolute_time_difference = abs(new_timestamp - filtered_timestamps)

            if all(absolute_time_difference >= self.timestamp_interval):
                filtered_timestamps = np.append(filtered_timestamps, new_timestamp)

            i += 1

        filtered_timestamps = [
            timestamp.strftime("%H:%M:%S") for timestamp in filtered_timestamps
        ]
        filtered_timestamps.sort()

        return filtered_timestamps
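
    # Illustration: with sources ['00:01:00', '00:01:30', '00:10:00'],
    # n_sources=3 and a 2-minute timestamp_interval, '00:01:30' is dropped
    # (only 30s after '00:01:00'), giving ['00:01:00', '00:10:00'].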

    def process_video(self, url: str, data=None, retriever=None):
        """Given a YouTube URL, transcribe YouTube audio to text.
        Then set up the vector database.
        """
        yt_audio_to_text(url)
        data = load_data()

        if retriever is not None:
            # If we already have documents in the vector store, delete them.
            ids = retriever.vectorstore.get()["ids"]
            retriever.vectorstore.delete(ids)

        retriever = create_vectordb(
            data, self.n_retrieved_docs,
            collection_name=f"Chatbot{self.chatbot_id}"
        )

        return url, data, retriever

    def setup_qa_chain(self, retriever, qa_chain=None):
        qa_chain = (
            RunnablePassthrough.assign(standalone=standalone_chain)
            | {
                "question": self.standalone_question,
                "context": self.standalone_question | retriever | self.format_docs,
            }
            | qa_prompt
            | llm
        )

        return retriever, qa_chain
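
    # Data flow: the chain receives {"question", "chat_history"};
    # standalone_chain adds a "standalone" flag ('yes'/'no'),
    # standalone_question passes the question through or rewrites it with
    # contextualise_q_chain, the retriever supplies context, and
    # qa_prompt | llm produces the answer.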

    def setup_chatbot(self, url: str):
        _, _, self.retriever = self.process_video(url=url, retriever=self.retriever)
        _, self.qa_chain = self.setup_qa_chain(retriever=self.retriever)

    def get_answer(self, question: str) -> str:
        try:
            ai_msg = self.qa_chain.invoke(
                {"question": question, "chat_history": self.chat_history}
            )
        except AttributeError:
            raise AttributeError(
                "You haven't setup the chatbot yet. "
                "Setup the chatbot by calling the "
                "instance method `setup_chatbot`."
            )

        self.chat_history.save_context({"question": question}, {"answer": ai_msg})

        answer = self.format_answer(ai_msg)

        return answer
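
# Minimal usage sketch without the web app (illustrative URL placeholder):
#     bot = YouTubeChatbot()
#     bot.setup_chatbot("https://www.youtube.com/watch?v=...")
#     print(bot.get_answer("What is the video about?"))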


"""Web app"""

class YouTubeChatbotApp(YouTubeChatbot):
    def __init__(
        self,
        n_sources: int,
        n_retrieved_docs: int,
        timestamp_interval: datetime.timedelta,
        memory: int,
        default_youtube_url: str,
    ):
        super().__init__(n_sources, n_retrieved_docs, timestamp_interval, memory)
        self.default_youtube_url = default_youtube_url
        self.memory = memory
        self.chat_history = None
        self.data = None
        self.retriever = None
        self.qa_chain = None

        # Gradio components
        self.url_input = None
        self.url_button = None
        self.app_chat_history = None
        self.chatbot = None
        self.user_input = None
        self.clear_button = None

    def greet(self, data, app_chat_history) -> dict:
        """Summarise the video and greet the user."""
        summary_message = f'Here is a summary of the video "{video_title}":'
        app_chat_history.append((None, summary_message))

        summary = get_summary(data)
        self.data = gr.State(None)
        app_chat_history.append((None, summary))

        greeting_message = (
            "You can ask me anything about the video. "
            "I will do my best to answer!"
        )
        app_chat_history.append((None, greeting_message))

        return {self.app_chat_history: app_chat_history, self.chatbot: app_chat_history}

    def question(self, user_question: str, app_chat_history) -> dict:
        """Display the question asked by the user in the chat window,
        and delete it from the input textbox.
        """
        app_chat_history.append((user_question, None))
        return {
            self.user_input: "",
            self.app_chat_history: app_chat_history,
            self.chatbot: app_chat_history,
        }

    def respond(self, qa_chain, chat_history, app_chat_history) -> dict:
        """Respond to the user's latest question."""
        question = app_chat_history[-1][0]

        try:
            ai_msg = qa_chain.invoke(
                {"question": question, "chat_history": chat_history}
            )
        except AttributeError:
            raise gr.Error(
                "You need to process the video first by pressing the `Go` button."
            )

        chat_history.save_context({"question": question}, {"answer": ai_msg})

        answer = self.format_answer(ai_msg)

        app_chat_history.append((None, answer))

        return {
            self.qa_chain: qa_chain,
            self.chat_history: chat_history,
            self.app_chat_history: app_chat_history,
            self.chatbot: app_chat_history,
        }

    def clear_chat_history(self, chat_history, app_chat_history):
        chat_history.clear()
        app_chat_history = []
        return {
            self.chat_history: chat_history,
            self.app_chat_history: app_chat_history,
            self.chatbot: app_chat_history,
        }

    def launch(self, **kwargs):
        with gr.Blocks() as demo:
            self.chat_history = gr.State(ConversationBufferWindowMemory(k=self.memory))
            self.app_chat_history = gr.State([])
            self.data = gr.State()
            self.retriever = gr.State()
            self.qa_chain = gr.State()

            # App structure
            with gr.Row():
                self.url_input = gr.Textbox(
                    value=self.default_youtube_url, label="YouTube URL", scale=5
                )
                self.url_button = gr.Button(value="Go", scale=1)

            self.chatbot = gr.Chatbot()
            self.user_input = gr.Textbox(label="Ask a question:")
            self.clear_button = gr.Button(value="Clear")

            # App actions

            # When a new url is given, clear past chat history and process
            # the new video. Set up the Q&A chain with the new video's data.
            # Provide a summary of the new video.
            self.url_button.click(
                self.clear_chat_history,
                inputs=[self.chat_history, self.app_chat_history],
                outputs=[self.chat_history, self.app_chat_history, self.chatbot],
                trigger_mode="once",
            ).then(
                self.process_video,
                inputs=[self.url_input, self.data, self.retriever],
                outputs=[self.url_input, self.data, self.retriever],
            ).then(
                self.setup_qa_chain,
                inputs=[self.retriever, self.qa_chain],
                outputs=[self.retriever, self.qa_chain],
            ).then(
                self.greet,
                inputs=[self.data, self.app_chat_history],
                outputs=[self.app_chat_history, self.chatbot],
            )

            # When a user asks a question, display the question in the chat
            # window and remove it from the text input area. Then respond
            # with the Q&A chain.
            self.user_input.submit(
                self.question,
                inputs=[self.user_input, self.app_chat_history],
                outputs=[self.user_input, self.app_chat_history, self.chatbot],
                queue=False,
            ).then(
                self.respond,
                inputs=[self.qa_chain, self.chat_history, self.app_chat_history],
                outputs=[
                    self.qa_chain,
                    self.chat_history,
                    self.app_chat_history,
                    self.chatbot,
                ],
            )

            # When the `Clear` button is clicked, clear the chat history from
            # the chat window.
            self.clear_button.click(
                self.clear_chat_history,
                inputs=[self.chat_history, self.app_chat_history],
                outputs=[self.chat_history, self.app_chat_history, self.chatbot],
                queue=False,
            )

        demo.launch(**kwargs)


if __name__ == "__main__":
    app = YouTubeChatbotApp(
        n_sources=3,
        n_retrieved_docs=5,
        timestamp_interval=datetime.timedelta(minutes=2),
        memory=5,
        default_youtube_url="https://www.youtube.com/watch?v=4Bdc55j80l8",
    )

    app.launch()