from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import chainlit as cl
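
# Retrieval-augmented QA bot: a local FAISS index supplies context passages,
# a quantized Llama 2 model (loaded via CTransformers) generates the answer,
# and Chainlit provides the chat UI.
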
DB_FAISS_PATH = "vectorstores/db_faiss"
input_variables = ["Context", "question"]

custom_prompt_template = """Use the following pieces of information to answer the user's question.

Context: {Context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""

def set_custom_prompt():
    """Wrap the template above in a PromptTemplate for the QA chain."""
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=input_variables)
    return prompt

def load_llm():
    """Load the quantized Llama 2 chat model through CTransformers."""
    llm = CTransformers(
        # old model: llama-2-7b-chat.ggmlv3.q8_0.bin
        model="llama-2-7b-chat.ggmlv3.q4_0.bin",
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5,
    )
    return llm
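
# The GGML weights are expected on the path above; a matching file can be
# downloaded from Hugging Face (e.g. TheBloke/Llama-2-7B-Chat-GGML).
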
def retrieval_qa_chain(llm, prompt, db):
    """Build a 'stuff' RetrievalQA chain over the top-2 retrieved chunks."""
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True,
        # 'Context' matches the variable name used in the custom prompt
        chain_type_kwargs={"prompt": prompt, "document_variable_name": "Context"},
    )
    return qa_chain

def qa_bot():
    """Assemble the pipeline: embeddings, FAISS index, LLM, prompt, QA chain."""
    # MiniLM sentence-transformer embeddings, computed on CPU
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, db)
    return qa

def final_result(query):
    """Run a single query through a freshly built QA chain."""
    qa_result = qa_bot()
    response = qa_result({"query": query})
    return response
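
# final_result returns the chain's raw response dict: the generated text is
# under "result" and the retrieved passages under "source_documents".
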
### Chainlit handlers ###

@cl.on_chat_start
async def start():
    """Build the QA chain once per session and greet the user."""
    chain = qa_bot()
    cl.user_session.set("chain", chain)
    msg = cl.Message(content="Starting the bot...")
    await msg.send()
    msg.content = "Hi, what is your query?"
    await msg.update()

@cl.on_message
async def main(message):
    """Answer each incoming message with the session's QA chain."""
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    # Older Chainlit passes the message as a plain string; on newer versions
    # it is a cl.Message, so pass message.content instead.
    res = await chain.acall(message, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]
    if sources:
        answer += "\nSources: " + str(sources)
    else:
        answer += "\nSources: No sources found"
    await cl.Message(content=answer).send()
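
# To launch the chat UI, run (the script name is an assumption):
#   chainlit run model.py -w

# Quick smoke test without the UI; the sample question is illustrative only.
if __name__ == "__main__":
    print(final_result("What is hypertension?")["result"])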