Spaces:
Runtime error
Runtime error
import streamlit as st | |
from langchain.document_loaders import PyPDFLoader, DirectoryLoader | |
from langchain import PromptTemplate | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.vectorstores import FAISS | |
from langchain.llms import CTransformers | |
from langchain.chains import RetrievalQA | |
import chainlit as cl | |
# Directory holding the saved FAISS index that qa_bot() loads.
DB_FAISS_PATH = 'vectorstore/db_faiss'

# Prompt text for the QA chain. The {context} and {question} placeholders
# are the two input variables declared in set_custom_prompt().
custom_prompt_template = (
    "Use the following pieces of information to answer the user's question.\n"
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Only return the helpful answer below and nothing else.\n"
    "Helpful answer:\n"
)
def set_custom_prompt():
    """Return the PromptTemplate used by the retrieval QA chain.

    The text comes from the module-level ``custom_prompt_template`` and is
    declared to take exactly two variables: ``context`` and ``question``.
    """
    return PromptTemplate(
        input_variables=['context', 'question'],
        template=custom_prompt_template,
    )
# Retrieval QA Chain
def retrieval_qa_chain(llm, prompt, db, k=2):
    """Create a RetrievalQA chain that answers questions from ``db``.

    Args:
        llm: Language model that writes the final answer.
        prompt: PromptTemplate for the "stuff" chain; it must accept the
            ``context`` and ``question`` variables.
        db: Vector store providing ``as_retriever()`` (a FAISS store in
            this app).
        k: Number of documents retrieved per query. Defaults to 2, the
            value that used to be hard-coded.

    Returns:
        A chain that is called with ``{'query': ...}``. Because
        ``return_source_documents=True``, its response carries both
        ``result`` and ``source_documents``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k!r}")
    retriever = db.as_retriever(search_kwargs={'k': k})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',  # put all retrieved chunks into a single prompt
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
# Loading the model
def load_llm(model_path="llama-2-7b-chat.ggmlv3.q8_0.bin",
             max_new_tokens=512,
             temperature=0.5):
    """Load the locally downloaded Llama model via LangChain's CTransformers.

    Args:
        model_path: Path to the local GGML model file.
        max_new_tokens: Maximum number of tokens generated per answer.
        temperature: Sampling temperature.

    Returns:
        A LangChain LLM object suitable for ``retrieval_qa_chain``.
    """
    # CTransformers forwards generation settings to the underlying
    # ctransformers model only through its ``config`` dict. Passed as
    # top-level keyword arguments they are not applied.
    return CTransformers(
        model=model_path,
        model_type="llama",
        config={
            'max_new_tokens': max_new_tokens,
            'temperature': temperature,
        },
    )
# QA Model Function
@st.cache_resource
def qa_bot():
    """Build the retrieval-QA pipeline: embeddings, FAISS index, LLM and prompt.

    Streamlit re-executes the whole script on every widget interaction.
    Without caching, the embedding model, the FAISS index and the LLM would
    be reloaded on every rerun. ``st.cache_resource`` keeps a single shared
    instance across reruns and sessions instead.
    NOTE(review): the cached chain is shared between concurrent sessions;
    confirm the LLM backend tolerates that.

    Returns:
        The RetrievalQA chain built by ``retrieval_qa_chain``.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    # NOTE: FAISS.load_local unpickles the docstore stored with the index.
    # Only point DB_FAISS_PATH at an index you built yourself.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    return retrieval_qa_chain(llm, qa_prompt, db)
def main():
    """Render the chat page and answer the user's question when "Ask" is pressed.

    Shows the model's answer and the source documents returned by the chain.
    Empty or whitespace-only questions are rejected with a warning instead
    of triggering an LLM call.
    """
    st.title("AI ChatBot LLM")
    qa_result = qa_bot()
    user_input = st.text_input("Enter your question:")
    if not st.button("Ask"):
        return

    question = user_input.strip()
    if not question:
        # Avoid running a slow model call for an empty query.
        st.warning("Please enter a question first.")
        return

    with st.spinner("Generating answer..."):
        response = qa_result({'query': question})
    answer = response["result"]
    sources = response["source_documents"]
    st.write("Answer:", answer)
    if sources:
        st.write("Sources:", sources)
    else:
        st.write("No sources found")


if __name__ == "__main__":
    main()