Update app.py
app.py
CHANGED
@@ -1,14 +1,15 @@
 import torch
 
-from langchain import PromptTemplate
-from
-from
-from
+from langchain.prompts import PromptTemplate
+from langchain_community.document_loaders import JSONLoader
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import Chroma
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain.chains import LLMChain, StuffDocumentsChain, MapReduceDocumentsChain, ReduceDocumentsChain
-from langchain.
+from langchain.memory.buffer_window import ConversationBufferWindowMemory
+from langchain_community.llms import HuggingFaceHub
 
 import yt_dlp
 import json
@@ -16,12 +17,18 @@ import gc
 import gradio as gr
 from gradio_client import Client
 import datetime
+import os
+
+
 
 
 whisper_jax_api = 'https://sanchit-gandhi-whisper-jax.hf.space/'
 whisper_jax = Client(whisper_jax_api)
 
-def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):
+def transcribe_audio(audio_path,
+                     task='transcribe',
+                     return_timestamps=True) -> str:
+
     text, runtime = whisper_jax.predict(
         audio_path,
         task,
@@ -32,17 +39,18 @@ def transcribe_audio(audio_path, task='transcribe', return_timestamps=True):
 
 
 
-def format_whisper_jax_output(whisper_jax_output: str, max_duration: int=60) -> list:
 
-
-
+def format_whisper_jax_output(whisper_jax_output: str,
+                              max_duration: int = 60) -> list[dict]:
+
+    """Returns a list of dict with keys 'start', 'end', 'text'
     The segments from whisper jax output are merged to form paragraphs.
 
     `max_duration` controls how many seconds of the audio's transcripts are merged
 
     For example, if `max_duration`=60, in the final output, each segment is roughly
     60 seconds.
-
+    """
 
     final_output = []
     max_duration = datetime.timedelta(seconds=max_duration)
@@ -53,25 +61,30 @@ def format_whisper_jax_output(whisper_jax_output: str, max_duration: int=60) -> list:
     for i, seg in enumerate(segments):
 
         text = seg.split(']')[-1].strip()
-        end = datetime.datetime.strptime(seg[14:19], '%M:%S')
 
-
-
-
-
+        # Sometimes whisper jax returns None for timestamp
+        try:
+            end = datetime.datetime.strptime(seg[14:19], '%M:%S')
+        except ValueError:
+            end = current_start + max_duration
+
+        if (end-current_start >= max_duration) or (i == len(segments)-1):
+            # If we have exceeded max duration or at the last segment,
+            # stop merging and append to final_output.
+
             current_text += text
-            final_output.append({
-
-
-
+            final_output.append({
+                'start': current_start.strftime('%H:%M:%S'),
+                'end': end.strftime('%H:%M:%S'),
+                'text': current_text
+            })
 
             # Update current start and text
             current_start = end
             current_text = ''
 
         else:
-            # If we have not exceeded max duration,
-            # keep merging.
+            # If we have not exceeded max duration, keep merging.
             current_text += text
 
     return final_output
@@ -79,6 +92,7 @@ def format_whisper_jax_output(whisper_jax_output: str, max_duration: int=60) -> list:
 
 
 
+
 audio_file_number = 1
 def yt_audio_to_text(url: str,
                      max_duration: int = 60
@@ -91,7 +105,8 @@ def yt_audio_to_text(url: str,
 
     with yt_dlp.YoutubeDL({'extract_audio': True,
                            'format': 'bestaudio',
-                           'outtmpl': f'{audio_file_number}.mp3'
+                           'outtmpl': f'{audio_file_number}.mp3'
+                           }) as video:
 
         info_dict = video.extract_info(url, download=False)
         global video_title
@@ -113,11 +128,12 @@ def yt_audio_to_text(url: str,
 
 
 
+
 def metadata_func(record: dict, metadata: dict) -> dict:
 
     metadata['start'] = record.get('start')
     metadata['end'] = record.get('end')
-    metadata['source'] = metadata['start'] + '->' + metadata['end']
+    metadata['source'] = metadata['start'] + ' -> ' + metadata['end']
 
     return metadata
 
@@ -136,6 +152,7 @@ def load_data():
 
 
 
+
 embedding_model_name = 'sentence-transformers/all-mpnet-base-v2'
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 embedding_model_kwargs = {'device': device}
@@ -144,9 +161,9 @@ embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name,
                                    model_kwargs=embedding_model_kwargs)
 
 def create_vectordb(data, k: int):
-
+    """Returns a vector database, and its retriever
     `k` is the number of retrieved documents
-
+    """
 
     vectordb = Chroma.from_documents(documents=data, embedding=embeddings)
     retriever = vectordb.as_retriever(search_type='similarity',
@@ -155,8 +172,11 @@ def create_vectordb(data, k: int):
     return vectordb, retriever
 
 
+
+
 repo_id = 'mistralai/Mistral-7B-Instruct-v0.1'
-llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'
+llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={'max_new_tokens': 1000})
+
 
 
 # Map
@@ -168,11 +188,14 @@ map_prompt = PromptTemplate.from_template(map_template)
 map_chain = LLMChain(llm=llm, prompt=map_prompt)
 
 
+
 # Reduce
 reduce_template = """The following is a set of summaries:
 {docs}
 
-Take these and distill it into a final, consolidated summary of the main themes
+Take these and distill it into a final, consolidated summary of the main themes \
+in 150 words or less.
+
 Answer:"""
 
 reduce_prompt = PromptTemplate.from_template(reduce_template)
@@ -191,7 +214,7 @@ reduce_documents_chain = ReduceDocumentsChain(
     # If documents exceed context for `StuffDocumentsChain`
     collapse_documents_chain=combine_documents_chain,
     # The maximum number of tokens to group documents into.
-    token_max=4000
+    token_max=4000
 )
 
 
@@ -204,18 +227,20 @@ map_reduce_chain = MapReduceDocumentsChain(
     # The variable name in the llm_chain to put the documents in
     document_variable_name="docs",
     # Return the results of the map steps in the output
-    return_intermediate_steps=False
+    return_intermediate_steps=False
 )
 
-def get_summary():
-    summary = map_reduce_chain.
-    return summary
+def get_summary(documents) -> str:
+    summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
+    return summary['output_text'].strip()
+
+
 
 
 contextualise_q_prompt = PromptTemplate.from_template(
-
+    """Given a chat history and the latest user question \
 which might reference the chat history, formulate a standalone question \
-
+that can be understood without the chat history. Do NOT answer the question, \
 just reformulate it if needed and otherwise return it as is.
 
 Chat history: {chat_history}
@@ -223,13 +248,15 @@ contextualise_q_prompt = PromptTemplate.from_template(
 Question: {question}
 
 Answer:
-
+"""
 )
 
 contextualise_q_chain = contextualise_q_prompt | llm
 
+
+
 standalone_prompt = PromptTemplate.from_template(
-
+    """Given a chat history and the latest user question, \
 identify whether the question is a standalone question or the question \
 references the chat history. Answer 'yes' if the question is a standalone \
 question, and 'no' if the question references the chat history. Do not \
@@ -242,7 +269,7 @@ standalone_prompt = PromptTemplate.from_template(
 {question}
 
 Answer:
-
+"""
 )
 
 def format_output(answer: str) -> str:
@@ -252,8 +279,10 @@ def format_output(answer: str) -> str:
 standalone_chain = standalone_prompt | llm | format_output
 
 
+
+
 qa_prompt = PromptTemplate.from_template(
-
+    """You are an assistant for question-answering tasks. \
 ONLY use the following context to answer the question. \
 Do NOT answer with information that is not contained in \
 the context. If you don't know the answer, just say:\
@@ -266,181 +295,230 @@ qa_prompt = PromptTemplate.from_template(
 {question}
 
 Answer:
-
+"""
 )
 
 
-def format_docs(docs: list) -> str:
-    '''
-    Combine documents
-    '''
-    global sources
-    sources = [doc.metadata['start'] for doc in docs]
-
-    return '\n\n'.join(doc.page_content for doc in docs)
-
-
-def standalone_question(input_: dict) -> str:
-    '''
-    If the question is a not a standalone question, run contextualise_q_chain
-    '''
-    if input_['standalone']=='yes':
-        return contextualise_q_chain
-    else:
-        return input_['question']
-
-
-def format_answer(answer: str,
-                  n_sources: int=1,
-                  timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> str:
-
-    if 'cannot find the answer' in answer:
-        return answer.strip()
-    else:
-        timestamps = filter_timestamps(n_sources, timestamp_interval)
-        answer_with_sources = (answer.strip()
-                               + ' You can find more information at these timestamps: {}.'.format(', '.join(timestamps))
-                               )
-        return answer_with_sources
-
-
-def filter_timestamps(n_sources: int,
-                      timestamp_interval: datetime.timedelta=datetime.timedelta(minutes=5)) -> list:
-    '''Returns a list of timestamps with length `n_sources`.
-    The timestamps are at least an `timestamp_interval` apart.
-    This prevents returning a list of timestamps that are too
-    close together.
-    '''
-    sorted_timestamps = sorted(sources)
-    output = [sorted_timestamps[0]]
-    i=1
-    while len(output)<n_sources:
-        timestamp1 = datetime.datetime.strptime(output[-1], '%H:%M:%S')
-
-        try:
-            timestamp2 = datetime.datetime.strptime(sorted_timestamps[i], '%H:%M:%S')
-        except IndexError:
-            break
-
-        time_diff = timestamp2 - timestamp1
-
-        if time_diff>timestamp_interval:
-            output.append(str(timestamp2.time()))
-
-        i += 1
-
-    return output
-
-
-def setup_rag(url):
-    '''Given a YouTube url, set up the vector database and the RAG chain.
-    '''
-
-    yt_audio_to_text(url)
-
-
-
-
-    global retriever
-    _, retriever = create_vectordb(data, k)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    ai_msg = rag_chain.invoke({'question': question,
-                               'chat_history': chat_history
-                               })
-
-
-
-
-
-
-    return answer
-
-
-
-
-
-
-
-default_youtube_url = 'https://www.youtube.com/watch?v=4Bdc55j80l8'
-
-
-def greet():
-    summary = get_summary()
-    global gradio_chat_history
-    summary_message = f'Here is a summary of the video "{video_title}":'
-    gradio_chat_history.append((None, summary_message))
-    gradio_chat_history.append((None, summary))
-    greeting_message = f'You can ask me anything about the video. I will do my best to answer!'
-    gradio_chat_history.append((None, greeting_message))
-    return gradio_chat_history
-
-
-    global gradio_chat_history
-    gradio_chat_history.append((user_message, None))
-    return gradio_chat_history
-
-
-
-
-
-
-
-
-
-
-chat_history = []
-gradio_chat_history = []
-
-
-
-
-
-
-
-
-with gr.Row():
-    url_input = gr.Textbox(value=default_youtube_url,
-                           label='YouTube URL',
-                           scale=5)
-    button = gr.Button(value='Go', scale=1)
-
-chatbot = gr.Chatbot()
-user_message = gr.Textbox(label='Ask a question:')
-clear = gr.ClearButton([user_message, chatbot])
-
-
-# Actions
-button.click(setup_rag,
-             inputs=[url_input],
-             outputs=[url_input],
-             trigger_mode='once').then(greet,
-                                       inputs=[],
-                                       outputs=[chatbot])
-
-user_message.submit(question,
-                    inputs=[user_message],
-                    outputs=[chatbot]).then(respond,
-                                            inputs=[],
-                                            outputs=[user_message, chatbot])
-
-
-
-
+
+
+class YouTubeChatbot:
+
+    def __init__(self,
+                 n_sources: int,
+                 k: int,
+                 timestamp_interval: datetime.timedelta,
+                 memory: int,
+                 ):
+        self.n_sources = n_sources
+        self.k = k
+        self.timestamp_interval = timestamp_interval
+        self.chat_history = ConversationBufferWindowMemory(k=memory)
+
+
+    def format_docs(self, docs: list) -> str:
+        """Combine documents
+        """
+
+        self.sources = [doc.metadata['start'] for doc in docs]
+
+        return '\n\n'.join(doc.page_content for doc in docs)
+
+
+
+    def standalone_question(self, input_: dict) -> str:
+        """If the question is a not a standalone question,
+        run contextualise_q_chain.
+        """
+        if input_['standalone']=='yes':
+            return contextualise_q_chain
+        else:
+            return input_['question']
+
+
+    def format_answer(self, answer: str) -> str:
+
+        if 'cannot find the answer' in answer:
+            return answer.strip()
+        else:
+            timestamps = self.filter_timestamps()
+            answer_with_sources = (
+                answer.strip()
+                + ' You can find more information '\
+                'at these timestamps: {}.'.format(', '.join(timestamps))
+            )
+            return answer_with_sources
+
+
+    def filter_timestamps(self) -> list[str]:
+        """Returns a list of timestamps with length `n_sources`.
+        The timestamps are at least an `timestamp_interval` apart.
+        This prevents returning a list of timestamps that are too
+        close together.
+        """
+
+        sorted_timestamps = sorted(self.sources)
+        filtered_timestamps = [sorted_timestamps[0]]
+        i=1
+        while len(filtered_timestamps) < self.n_sources:
+            timestamp1 = datetime.datetime.strptime(filtered_timestamps[-1],
+                                                    '%H:%M:%S')
+
+            try:
+                timestamp2 = datetime.datetime.strptime(sorted_timestamps[i],
+                                                        '%H:%M:%S')
+            except IndexError:
+                break
+
+            time_diff = timestamp2 - timestamp1
+
+            if time_diff>=self.timestamp_interval:
+                filtered_timestamps.append(str(timestamp2.time()))
+
+            i += 1
+
+        return filtered_timestamps
+
+
+    def setup_chatbot(self, url: str) -> str:
+        """Given a YouTube url, set up the chatbot.
+        """
+
+        yt_audio_to_text(url)
+
+        self.data = load_data()
+
+        _, self.retriever = create_vectordb(self.data, self.k)
+
+
+        self.qa_chain = (
+            RunnablePassthrough.assign(standalone=standalone_chain)
+            | {'question':self.standalone_question,
+               'context':self.standalone_question|self.retriever|self.format_docs}
+            | qa_prompt
+            | llm)
+
+        return url
+
+
+
+    def get_answer(self, question: str) -> str:
+
+        try:
+            ai_msg = self.qa_chain.invoke({'question': question,
+                                           'chat_history': self.chat_history})
+        except AttributeError:
+            raise AttributeError("You haven't setup the chatbot yet. "
+                                 "Setup the chatbot by calling the "
+                                 "instance method `setup_chatbot`.")
+
+        answer = self.format_answer(ai_msg)
+
+        self.chat_history.save_context({'question':question},
+                                       {'answer':answer})
+
+        return answer
+
+
+
+
+
+
+
+class YouTubeChatbotApp(YouTubeChatbot):
+
+    def __init__(self,
+                 n_sources: int,
+                 k: int,
+                 timestamp_interval: datetime.timedelta,
+                 memory: int,
+                 default_youtube_url: str
+                 ):
+        super().__init__(n_sources, k, timestamp_interval, memory)
+        self.default_youtube_url = default_youtube_url
+        self.gradio_chat_history = []
+
+
+    def greet(self) -> list[tuple[str|None, str|None]]:
+        summary = get_summary(self.data)
+        summary_message = f'Here is a summary of the video "{video_title}":'
+        self.gradio_chat_history.append((None, summary_message))
+        self.gradio_chat_history.append((None, summary))
+        greeting_message = ('You can ask me anything about the video.'
+                            'I will do my best to answer!')
+        self.gradio_chat_history.append((None, greeting_message))
+        return self.gradio_chat_history
+
+
+    def question(self, user_message: str) -> list[tuple[str|None, str|None]]:
+        self.gradio_chat_history.append((user_message, None))
+        return '', self.gradio_chat_history
+
+
+    def respond(self) -> tuple[str, list[tuple[str|None, str|None]]]:
+        try:
+            ai_message = self.get_answer(self.gradio_chat_history[-1][0])
+        except AttributeError:
+            raise gr.Error('You need to process the video '
+                           'first by pressing the `Go` button.')
+
+        self.gradio_chat_history.append((None, ai_message))
+        return self.gradio_chat_history
+
+
+    def clear_chat_history(self) -> list:
+        self.chat_history.clear()
+        self.gradio_chat_history = []
+        return self.gradio_chat_history
+
+
+    def launch(self, **kwargs):
+
+        with gr.Blocks() as demo:
+
+            # Structure
+            with gr.Row():
+                url_input = gr.Textbox(value=self.default_youtube_url,
+                                       label='YouTube URL',
+                                       scale=5)
+                button = gr.Button(value='Go', scale=1)
+
+            chatbot = gr.Chatbot()
+            user_message = gr.Textbox(label='Ask a question:')
+            clear = gr.ClearButton([user_message, chatbot])
+
+
+            # Actions
+            button.click(self.clear_chat_history,
+                         inputs=[],
+                         outputs=[chatbot],
+                         trigger_mode='once'
+                         ).then(self.setup_chatbot,
+                                inputs=[url_input],
+                                outputs=[url_input]
+                                ).then(self.greet,
+                                       inputs=[],
+                                       outputs=[chatbot])
+
+            user_message.submit(self.question,
+                                inputs=[user_message],
+                                outputs=[user_message, chatbot]
+                                ).then(self.respond,
+                                       inputs=[],
+                                       outputs=[chatbot])
+
+            clear.click(self.clear_chat_history, inputs=[], outputs=[chatbot])
+
+
+        demo.launch(**kwargs)
+
+
+
+app = YouTubeChatbotApp(n_sources=3,
+                        k=5,
+                        timestamp_interval=datetime.timedelta(minutes=2),
+                        memory=5,
+                        default_youtube_url='https://www.youtube.com/watch?v=4Bdc55j80l8'
+                        )
+
+app.launch(debug=True)
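
A minimal usage sketch of the new class-based API introduced by this commit, for reference; it is not part of the diff, and it assumes the imports above resolve in your environment, that the whisper-jax Space is reachable, and that HUGGINGFACEHUB_API_TOKEN is set so HuggingFaceHub can call the Inference API:

# Hypothetical headless usage: process one video, then query it
# without launching the Gradio UI.
bot = YouTubeChatbot(n_sources=3,
                     k=5,
                     timestamp_interval=datetime.timedelta(minutes=2),
                     memory=5)
bot.setup_chatbot('https://www.youtube.com/watch?v=4Bdc55j80l8')  # transcribe + build vector DB
print(bot.get_answer('What is the video about?'))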
|