File size: 3,113 Bytes
4e3ae23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f1620e
4e3ae23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain.chat_models.gigachat import GigaChat
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import os
import telebot


def get_yt_links(contexts):
    """Build embeddable YouTube iframe snippets for retrieved contexts.

    Each context is expected to carry ``metadata['link']`` (video URL) and
    ``metadata['time']`` (start offset, presumably seconds -- TODO confirm
    against the indexing pipeline).  Returns one HTML string per context.
    """
    template = '''
    <iframe width="100%" height="200" src="{}?start={}" \
    title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; \
    encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" \
    allowfullscreen></iframe>
    '''
    return [
        template.format(ctx.metadata['link'], ctx.metadata['time'])
        for ctx in contexts
    ]


def process_input(text):
    """Run the retrieval chain on *text* and return the answer plus three
    YouTube iframe HTML snippets for the Gradio outputs.

    Side effect: forwards the raw chain response to the creator's Telegram
    chat via ``bot`` (monitoring/debugging channel).
    """
    response = retrieval_chain.invoke({"input": text})
    bot.send_message(int(user_id), str(response))
    youtube_links = get_yt_links(response['context'])
    # Pad to exactly three entries so the fixed three HTML outputs never
    # raise IndexError when the retriever returns fewer than k=3 documents.
    youtube_links += [''] * (3 - len(youtube_links))
    return response['answer'], youtube_links[0], youtube_links[1], youtube_links[2]

# --- Configuration and service wiring (runs at import time) ---
giga = os.getenv('GIGA')        # GigaChat credentials
token = os.getenv('BOT')        # Telegram bot token
user_id = os.getenv('CREATOR')  # Telegram chat id that receives raw responses
bot = telebot.TeleBot(token)

# Multilingual sentence-embedding model used for both indexing and querying.
model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
# BUG FIX: torch device strings are 'cuda' / 'cpu'; 'gpu' is not a valid
# device and makes the embedding model fail to load.
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': False}
embedding = HuggingFaceEmbeddings(model_name=model_name,
                                  model_kwargs=model_kwargs,
                                  encode_kwargs=encode_kwargs)

# Pre-built FAISS index on disk; deserialization is opted into explicitly.
vector_db = FAISS.load_local('faiss_index',
                            embeddings=embedding,
                            allow_dangerous_deserialization=True)
llm = GigaChat(credentials=giga, verify_ssl_certs=False, profanity_check=False)

prompt = ChatPromptTemplate.from_template('''Ответь на вопрос пользователя. \
Используй при этом только информацию из контекста. Если в контексте нет \
информации для ответа, сообщи об этом пользователю.
Контекст: {context}
Вопрос: {input}
Ответ:'''
)

# BUG FIX: the vector store is bound to `vector_db` above; `vector_store`
# was an undefined name and raised NameError at import.
embedding_retriever = vector_db.as_retriever(search_kwargs={"k": 3})

document_chain = create_stuff_documents_chain(
    llm=llm,
    prompt=prompt
    )

retrieval_chain = create_retrieval_chain(embedding_retriever, document_chain)

# --- Gradio UI: query box on the left, top-3 source videos on the right ---
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Введите запрос")
            submit_btn = gr.Button("Отправить запрос")
            text_output = gr.Textbox(label="Ответ", interactive=False)

        with gr.Column():
            # Placeholders that receive iframe HTML from process_input.
            youtube_video1 = gr.HTML()
            youtube_video2 = gr.HTML()
            youtube_video3 = gr.HTML()

    submit_btn.click(process_input, text_input,
                     [text_output, youtube_video1, youtube_video2, youtube_video3])

# BUG FIX: launch after the Blocks context closes so the app graph is fully
# built before the server starts (previously `launch()` ran inside `with`).
demo.launch()