|
import streamlit as st |
|
from langchain.document_loaders import PyPDFLoader, DirectoryLoader |
|
from langchain import PromptTemplate |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.vectorstores import FAISS |
|
from langchain.llms import CTransformers |
|
from langchain.chains import RetrievalQA |
|
import chainlit as cl |
|
|
|
# Directory of the prebuilt FAISS index; passed to FAISS.load_local when the
# vector store is loaded.
DB_FAISS_PATH = 'vectorstore/db_faiss'


# Prompt text for the RetrievalQA "stuff" chain. {context} and {question} are
# the template variables declared in set_custom_prompt().
# NOTE(review): the blank lines inside this literal are reproduced as they
# appear in the extracted source -- confirm they match the original file.
custom_prompt_template = """Use the following pieces of information to answer the user's question.

If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}

Question: {question}

Only return the helpful answer below and nothing else.

Helpful answer:

"""
|
|
|
def set_custom_prompt():
    """Build the prompt used by the retrieval QA chain.

    Returns:
        A ``PromptTemplate`` wrapping ``custom_prompt_template`` that expects
        the ``context`` and ``question`` variables.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
|
|
|
|
|
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a "stuff" RetrievalQA chain over a vector store.

    Args:
        llm: Language model that produces the final answer.
        prompt: Prompt template handed to the underlying "stuff" chain.
        db: Vector store whose retriever supplies context (top 2 matches).

    Returns:
        A RetrievalQA chain configured to also return its source documents.
    """
    top_two = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=top_two,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
|
|
|
|
|
def load_llm(max_new_tokens, temperature):
    """Load the local Llama-2 7B chat model through ctransformers.

    Args:
        max_new_tokens: Maximum number of tokens to generate per answer.
        temperature: Sampling temperature used during generation.

    Returns:
        A LangChain ``CTransformers`` LLM instance.
    """
    # LangChain's CTransformers wrapper forwards generation settings to the
    # ctransformers model through its ``config`` dict. Passing them as
    # top-level keyword arguments does not configure generation, so the UI
    # sliders would have no effect.
    llm = CTransformers(
        model="llama-2-7b-chat.ggmlv3.q8_0.bin",
        model_type="llama",
        config={
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
    )
    return llm
|
|
|
|
|
@st.cache_resource
def _load_vectorstore():
    """Load the sentence-transformer embeddings and FAISS index once per process."""
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    return FAISS.load_local(DB_FAISS_PATH, embeddings)


@st.cache_resource(max_entries=1)
def _cached_llm(max_new_tokens, temperature):
    """Return the LLM for these settings, keeping only the latest one in memory."""
    return load_llm(max_new_tokens, temperature)


def qa_bot(max_new_tokens, temperature):
    """Build the retrieval QA chain for the given generation settings.

    The vector store and the LLM are cached with ``st.cache_resource`` because
    Streamlit reruns the script on each interaction. Reloading the embedding
    model, the FAISS index and a 7B model every time is very slow. Only one
    LLM configuration is cached at a time to bound memory use.

    Args:
        max_new_tokens: Maximum number of tokens the LLM may generate.
        temperature: Sampling temperature for the LLM.

    Returns:
        A RetrievalQA chain that also returns its source documents.
    """
    db = _load_vectorstore()
    llm = _cached_llm(max_new_tokens, temperature)
    qa_prompt = set_custom_prompt()
    return retrieval_qa_chain(llm, qa_prompt, db)
|
|
|
def _clear_question():
    """Reset the question text box.

    Used as the Clear button's ``on_click`` callback. Callbacks run before the
    script reruns, so the keyed ``text_input`` renders empty on the next run.
    """
    st.session_state["question"] = ""


def main():
    """Render the Streamlit UI and answer the user's question on demand.

    The UI provides:
    - sliders for the generation settings;
    - a text box for the question;
    - "Ask", which runs the retrieval QA chain and shows the answer and its
      source documents;
    - "Clear", which empties the question box.
    """
    st.title("Pinaki's LLM")

    max_new_tokens = st.slider("Max New Tokens", min_value=1, max_value=1000, value=512)
    temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, step=0.1, value=0.5)

    # The key lets _clear_question reset this widget through session state.
    user_input = st.text_input("Enter your question:", key="question")

    if st.button("Ask"):
        if not user_input.strip():
            st.warning("Please enter a question first.")
        else:
            # Build the chain only when a question is actually asked.
            # Streamlit reruns this whole script on every widget interaction,
            # and building the chain may load the embeddings, the FAISS index
            # and the LLM.
            qa_result = qa_bot(max_new_tokens, temperature)
            with st.spinner("Generating answer..."):
                response = qa_result({'query': user_input})
            answer = response["result"]
            sources = response["source_documents"]

            st.write("Answer:", answer)
            if sources:
                st.write("Sources:", sources)
            else:
                st.write("No sources found")

    # Clearing via a callback avoids creating a second text_input with the
    # same label, which Streamlit rejects as a duplicate widget ID.
    st.button("Clear", on_click=_clear_question)
|
|
|
# Entry point: render the app when this file is executed as the main module.
if __name__ == "__main__":

    main()
|
|