video_rag / app.py
trashchenkov's picture
Update app.py
bd0ee22 verified
raw
history blame
3.68 kB
import gradio as gr
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_community.chat_models.gigachat import GigaChat
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import os
#import telebot
def get_yt_links(contexts):
    """Build one embeddable YouTube iframe snippet per retrieved document.

    Each document's ``metadata`` dict must carry a ``'link'`` entry (YouTube
    embed URL) and a ``'time'`` entry (playback start offset in seconds).

    Args:
        contexts: sequence of LangChain documents with ``metadata`` dicts.

    Returns:
        list[str]: iframe HTML strings, one per document, starting playback
        at the stored offset (``?start=<time>``).
    """
    # Backslash-newlines inside the triple-quoted string are line
    # continuations: the rendered iframe tag stays on a few long lines.
    html = '''
<iframe width="100%" height="200" src="{}?start={}" \
title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; \
encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" \
allowfullscreen></iframe>
'''
    return [
        html.format(doc.metadata['link'], doc.metadata['time'])
        for doc in contexts
    ]
# Dead code retained for reference: formatted a "request\nanswer" Telegram
# notification message for the disabled bot integration (see the commented
# telebot wiring elsewhere in this file). The triple-quoted string is a
# no-op expression at module level.
'''def resp2msg(resp):
req = resp['input']
ans = resp['answer']
return req + '\n' + ans'''
def get_context(contexts):
    """Format retrieved document texts into a numbered, human-readable string.

    The original implementation assumed exactly three documents and raised
    ``IndexError`` when the retriever returned fewer; this version handles
    any number of documents while producing byte-identical output for the
    three-document case.

    Args:
        contexts: sequence of LangChain documents with ``page_content``.

    Returns:
        str: newline-separated "Фрагмент N: <text>" lines (1-based), framed
        by the same leading and trailing newlines as the original template.
    """
    fragments = '\n'.join(
        'Фрагмент {}: {}'.format(i, doc.page_content)
        for i, doc in enumerate(contexts, start=1)
    )
    return '\n' + fragments + '\n'
def process_input(text):
    """Run the RAG chain for a user query and shape the result for the UI.

    Args:
        text: the user's question, taken from the Gradio textbox.

    Returns:
        tuple: (answer string, formatted context string, three iframe HTML
        strings). Missing video slots are empty strings so the three
        ``gr.HTML`` outputs always receive a value even when the retriever
        returns fewer than three documents.
    """
    response = retrieval_chain.invoke({"input": text})
    docs = response['context']
    # bot.send_message(user_id, resp2msg(response))  # Telegram notify (disabled)
    youtube_links = get_yt_links(docs)
    # Pad to exactly three slots: the UI binds three gr.HTML outputs, but
    # the retriever is not guaranteed to return three documents.
    youtube_links += [''] * (3 - len(youtube_links))
    context = get_context(docs)
    return response['answer'], context, youtube_links[0], youtube_links[1], youtube_links[2]
# --- One-time application setup (runs at import time) ---

# GigaChat API credentials are injected through the GIGA environment variable.
giga = os.getenv('GIGA')
# Disabled Telegram-bot notification wiring:
#token = os.getenv('BOT')
#user_id = os.getenv('CREATOR')
#bot = telebot.TeleBot(token)

# Multilingual sentence-transformer used to embed queries on CPU; it must be
# the same model the persisted FAISS index was built with.
model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
embedding = HuggingFaceEmbeddings(model_name=model_name,
                                  model_kwargs=model_kwargs,
                                  encode_kwargs=encode_kwargs)

# Load the prebuilt vector index from the local 'faiss_index' directory.
# NOTE(review): allow_dangerous_deserialization unpickles the stored index;
# acceptable only because the index ships with this app, not from user input.
vector_db = FAISS.load_local('faiss_index',
                             embeddings=embedding,
                             allow_dangerous_deserialization=True)

llm = GigaChat(credentials=giga, verify_ssl_certs=False, profanity_check=False)

# Russian-language prompt: answer strictly from the provided context and tell
# the user when the context lacks the needed information.
prompt = ChatPromptTemplate.from_template('''Ответь на вопрос пользователя. \
Используй при этом только информацию из контекста. Если в контексте нет \
информации для ответа, сообщи об этом пользователю.
Контекст: {context}
Вопрос: {input}
Ответ:'''
)

# Retrieve the top-3 most similar chunks, stuff them into the prompt, and
# compose retrieval -> generation into a single chain.
embedding_retriever = vector_db.as_retriever(search_kwargs={"k": 3})
document_chain = create_stuff_documents_chain(
    llm=llm,
    prompt=prompt
)
retrieval_chain = create_retrieval_chain(embedding_retriever, document_chain)
# --- Gradio UI: query box, answer and context on the left; three embedded
# YouTube players on the right ---
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Введите запрос")
            submit_btn = gr.Button("Отправить запрос")
            text_output = gr.Textbox(label="Ответ", interactive=False)
            text_context = gr.Textbox(label="Контекст", interactive=False)
        with gr.Column():
            # Three fixed slots for the iframe snippets process_input returns.
            youtube_video1 = gr.HTML()
            youtube_video2 = gr.HTML()
            youtube_video3 = gr.HTML()
    # Wire the button: one textbox input -> five outputs, in the same order
    # process_input returns them.
    submit_btn.click(process_input, text_input, [text_output, text_context, youtube_video1, youtube_video2, youtube_video3])
demo.launch()