File size: 3,816 Bytes
710db2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367d27b
f58bc94
e511c40
710db2e
 
 
 
e511c40
 
 
 
 
 
 
 
8681a91
 
367d27b
710db2e
e511c40
710db2e
 
 
 
13a85d2
 
710db2e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from huggingface_hub import InferenceClient
import gradio as gr

# API token for the Hugging Face Inference API, read from the "apiToken"
# environment variable (None if the variable is unset).
HF_token = os.getenv("apiToken")

# Embedding model used to embed queries against the FAISS index.
# NOTE(review): presumably this must be the same embedding model the
# "faiss_index" was built with — confirm against the indexing script.
embeddings = HuggingFaceEmbeddings()
# Load the prebuilt vector index from the local "faiss_index" directory.
# NOTE(review): FAISS.load_local unpickles the stored docstore; newer LangChain
# releases refuse to do so unless allow_dangerous_deserialization=True is
# passed. Only load index files you created yourself, and confirm which
# LangChain version this Space pins.
vectorstore = FAISS.load_local("faiss_index", embeddings)
# Plain similarity search returning the top 2 documents per query.
retriever=vectorstore.as_retriever(search_type="similarity", search_kwargs={"k":2})

# Model that writes the final answers (used by generate()).
chat_client = InferenceClient(
    model="HuggingFaceH4/zephyr-7b-alpha",
    token=HF_token
)

# Model that rewrites follow-up questions into standalone queries
# (used by reformulate_query()).
reform_client = InferenceClient(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    token=HF_token
)

def format_prompt(message, history, docs=None):
    """Build a Zephyr-style chat prompt grounded in retrieved website content.

    Args:
        message: The user's current message (in ``generate`` this is the
            reformulated, standalone query).
        history: Previous turns as ``(user_prompt, bot_response)`` pairs.
        docs: Optional pre-retrieved documents (objects exposing a
            ``page_content`` attribute). When omitted, the module-level
            ``retriever`` is queried with *message*.

    Returns:
        The prompt string, ending with an open ``<|assistant|>`` turn for the
        model to complete.
    """
    if docs is None:
        docs = retriever.get_relevant_documents(message)
    context = "\n".join(doc.page_content for doc in docs)

    # Zephyr's chat format puts the system turn first. The retrieved context is
    # placed inside that turn so it is not left dangling outside any role tags.
    system = (
        "<|system|>\nYou are a helpful virtual assistant for People's Insurance PLC "
        "that answer user's questions using website content.\n"
        f"website content:\n{context}</s>\n"
    )
    turns = [system]
    for user_prompt, bot_response in history:
        turns.append(f"<|user|>\n{user_prompt}</s>\n")
        turns.append(f"<|assistant|>\n{bot_response}</s>\n")
    # Leave the assistant turn open for the model to complete.
    turns.append(f"<|user|>\n{message}</s>\n<|assistant|>\n")
    return "".join(turns)

def query_reform(message, history):
    """Build the Mistral-Instruct prompt that rewrites *message* as a standalone query.

    Only the most recent non-empty user message and the most recent non-empty
    AI response from *history* are used as context. They are searched for
    independently, newest turn first, so they may come from different turns.

    Args:
        message: The follow-up user message to rewrite.
        history: Previous turns as ``[user_message, ai_response]`` pairs; either
            element may be empty or ``None``.

    Returns:
        The prompt string to send to the reformulation model.
    """
    previous_user_input = ""
    previous_ai_response = ""

    for turn in reversed(history):
        user_msg, ai_msg = turn[0], turn[1]
        # Truthiness skips both "" and None entries.
        if not previous_user_input and user_msg:
            previous_user_input = user_msg
        if not previous_ai_response and ai_msg:
            previous_ai_response = ai_msg
        if previous_user_input and previous_ai_response:
            break

    print("Original Question:", message)
    print("Last interaction:\nuser:", previous_user_input, "\nAI:", previous_ai_response)

    # Mistral-Instruct expects the entire instruction, including the chat
    # history and the follow-up, inside a single [INST] ... [/INST] block; the
    # text after [/INST] primes the start of the model's answer.
    new_prompt = (
        "[INST] Given the following conversation and a follow-up message, rephrase the "
        "follow-up user message to be a standalone message. If the follow-up message is "
        "not a question, keep it unchanged.\n"
        f"Chat History:\nUser: {previous_user_input}\nAI: {previous_ai_response}\n"
        f"Follow-up User message: {message} [/INST]\n"
        "Rewritten User message:"
    )

    return new_prompt

def reformulate_query(question, history):
    """Rewrite *question* into a standalone query using the recent chat history.

    With an empty history there is nothing for a follow-up to refer back to,
    so *question* is returned unchanged without calling the model.

    Args:
        question: The latest user message.
        history: Previous turns as ``[user_message, ai_response]`` pairs.

    Returns:
        The model's rewrite with surrounding whitespace removed, or *question*
        itself if the model returned only whitespace.
    """
    if not history:
        return question

    reformulated_query = reform_client.text_generation(
        query_reform(question, history),
        temperature=0.1,
        max_new_tokens=50,
        top_p=0.9,
        repetition_penalty=1.0,
    ).strip()

    print("Reformulated Query:", reformulated_query)

    # An empty rewrite would make both retrieval and the final prompt useless.
    return reformulated_query or question

def generate(
    prompt, history, temperature=0.9, max_new_tokens=500, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a retrieval-grounded answer to *prompt* for ``gr.ChatInterface``.

    The user message is first rewritten into a standalone query. Website
    documents relevant to that query are then retrieved, and the chat model's
    answer is streamed back.

    Args:
        prompt: The latest user message.
        history: Previous turns as ``(user_message, bot_response)`` pairs.
        temperature: Sampling temperature; values below 0.01 are clamped to 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty passed to the model.

    Yields:
        The answer accumulated so far after each streamed token. The final
        yield is the full answer followed by a "Sources:" list, added only when
        a retrieved document carries a ``source`` in its metadata.
    """
    # Sampling is enabled below, so keep the temperature strictly positive.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        # Normalise types in case the UI supplies numbers as floats/strings.
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,  # fixed seed, presumably for reproducible sampling
    )

    reformed_query = reformulate_query(prompt, history)

    # Retrieved here only to cite sources.
    # NOTE(review): format_prompt() runs the same retrieval again for the prompt
    # context; consider passing these docs through to avoid the duplicate query.
    docs = retriever.get_relevant_documents(reformed_query)

    # dict.fromkeys drops repeated sources while keeping retrieval order.
    source_names = list(dict.fromkeys(
        doc.metadata["source"] for doc in docs if "source" in doc.metadata
    ))
    sources = (
        "\nSources:\n" + "".join(f"{name}\n" for name in source_names)
        if source_names
        else ""
    )

    formatted_prompt = format_prompt(reformed_query, history)

    stream = chat_client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for response in stream:
        # With details=True every token is reported, including special ones
        # such as the end-of-sequence marker; don't show those to the user.
        if response.token.special:
            continue
        output += response.token.text
        yield output

    yield output + sources

# Chat UI wired to the streaming generate() handler.
demo = gr.ChatInterface(
    generate,
    title="People's Insurance PLC chatbot",
    theme="Monochrome",
    # Clickable starter questions shown with the chat box.
    examples=[
        ["What is the contact number of People's Insurance?"],
        ["What are the available plans?"],
    ],
)

# Guarded so importing this module (e.g. from tests or Gradio's reload mode)
# does not start a server; running the script directly still launches it.
if __name__ == "__main__":
    # Enable the request queue; generator (streaming) handlers depend on it in
    # Gradio 3.x.
    demo.queue().launch(debug=True)