File size: 3,649 Bytes
033ca0b
 
 
 
 
 
5057f9b
033ca0b
 
 
 
bb8e417
033ca0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8f4754
 
 
033ca0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17b9777
033ca0b
7f62156
033ca0b
 
 
 
 
c8f4754
033ca0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03d82d4
033ca0b
d17e699
 
3adc1d4
7f62156
 
752b8f7
d17e699
3adc1d4
1e0e5ba
033ca0b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# NOTE(review): `ast.List` is the AST node class, not `typing.List`; it is
# unused below and was almost certainly auto-imported by mistake — verify and drop.
from ast import List
from langchain.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import dotenv
from langchain.prompts import PromptTemplate
import gradio as gr
# NOTE(review): re-imports PromptTemplate (shadowing the identical import
# above) and pulls in LLMChain, which is never used in this file.
from langchain import PromptTemplate, LLMChain
import requests
from fastembed.embedding import FlagEmbedding as Embedding
import numpy as np
import os
from langchain.schema.messages import HumanMessage


# Load variables from a local .env file into the process environment.
dotenv.load_dotenv()

# Bearer token and URL for the (private) Hugging Face inference endpoint
# that serves the chat model; every request below goes through these.
api_token = os.environ.get("API_TOKEN")
API_URL = "https://vpb8x4glbmizmiya.eu-west-1.aws.endpoints.huggingface.cloud"
headers = {
    "Authorization": f"Bearer {api_token}",
    "Content-Type": "application/json",
}


def query(payload):
    """POST *payload* as JSON to the HF inference endpoint and return the
    decoded JSON response body.

    Uses the module-level ``API_URL`` and bearer-token ``headers``.

    Args:
        payload: JSON-serialisable dict in the HF text-generation format
            (``{"inputs": ..., "parameters": {...}}``).

    Returns:
        The endpoint's JSON response — a list of dicts on success, or an
        error dict on failure; callers must handle both shapes.
    """
    # Explicit timeout so a hung endpoint cannot block the UI handler
    # forever; 60 s allows for max_new_tokens=1024 generations.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
    return response.json()


def get_top_k(query_embedding, embeddings, documents, k=3):
    """Return the ``k`` documents most similar to the query embedding.

    Similarity is the dot product of each document embedding with the query
    embedding (equivalent to cosine similarity when the vectors are
    L2-normalised, as BGE embeddings typically are — TODO confirm).

    Args:
        query_embedding: 1-D array-like of shape (dim,).
        embeddings: 2-D array-like of shape (n_docs, dim); row i corresponds
            to ``documents[i]``.
        documents: sequence of documents parallel to ``embeddings``.
        k: number of results to return. Clamped to ``len(documents)`` so an
            oversized k no longer raises IndexError (the original bug).

    Returns:
        List of the top-k documents, best match first.
    """
    scores = np.dot(embeddings, query_embedding)
    # Indices of the highest scores, best first; slice AFTER reversing so we
    # take the largest k, clamped to the number of available documents.
    k = min(k, len(documents))
    top_indices = np.argsort(scores)[::-1][:k]

    result = []
    for rank, idx in enumerate(top_indices, start=1):
        print(f"Rank {rank}: {documents[idx]}", "\n")
        result.append(documents[idx])

    return result


# Prompt skeleton for the RAG answer: the retrieved chunks are interpolated
# into {context} and the user's message into {question} before the whole
# string is sent to the model endpoint.
prompt_template = """
You are the helpful assistant representing the company Philip Morris.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use the following pieces of context to answer the question at the end.
Think step by step in your answer.
Only answer the given question.

Context:
{context}

Question: {question}
Answer:
"""


# Reusable template object; .format(context=..., question=...) fills the slots.
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

# Startup ingestion: load every .txt file under ./documents, split each into
# 500-character chunks with 150-character overlap, and embed every chunk once.
# `texts` and `embeddings` stay parallel: embeddings[i] embeds texts[i].
loader = DirectoryLoader("./documents", glob="**/*.txt", show_progress=True)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
texts = text_splitter.split_documents(docs)

# BGE base English embedding model; inputs are truncated at 512 tokens.
embedding_model = Embedding(model_name="BAAI/bge-base-en", max_length=512)
embeddings = list(embedding_model.embed([text.page_content for text in texts]))


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=800)
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """Gradio submit handler: retrieve context for *message*, query the
        LLM endpoint, and append the (question, answer) pair to the history.

        Returns ("", updated_history) so the textbox is cleared after submit.
        """
        # Embed the user message and fetch the 2 most similar chunks.
        message_embedding = list(embedding_model.embed([message]))[0]
        result_docs = get_top_k(message_embedding, embeddings, texts, k=2)

        human_message = HumanMessage(
            content=PROMPT.format(context=result_docs, question=message)
        )

        print("Question: ", human_message)
        output = query(
            {
                "inputs": human_message.content,
                "parameters": {
                    "temperature": 0.9,
                    "top_p": 0.95,
                    "repetition_penalty": 1.2,
                    "top_k": 50,
                    "truncate": 1000,
                    "max_new_tokens": 1024,
                },
            }
        )
        print("Response: ", output, "\n")
        bot_message = ""

        # HF endpoints return [{"generated_text": ...}] on success but a bare
        # dict such as {"error": "..."} on failure; the original unguarded
        # output[0]["generated_text"] crashed the handler on any endpoint
        # error. Guard both shapes and surface failures in the chat instead.
        if (
            isinstance(output, list)
            and output
            and isinstance(output[0], dict)
            and output[0].get("generated_text")
        ):
            bot_message = output[0]["generated_text"]
            bot_message += "\n \n"
            bot_message += "Document sources"
            bot_message += "\n \n"

            for i, doc in enumerate(result_docs):
                bot_message += f"⚫️ Source {i+1}: {doc.page_content}\n Document link: N/A Page: N/A \n"
        else:
            bot_message = f"Sorry, the model endpoint returned an unexpected response: {output}"

        chat_history.append((message, bot_message))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])


# Start the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()