|
from langchain_core.prompts import PromptTemplate |
|
import os |
|
from langchain_community.embeddings import HuggingFaceBgeEmbeddings |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.llms.ctransformers import CTransformers |
|
|
|
from langchain.chains.retrieval_qa.base import RetrievalQA |
|
import streamlit as st |
|
|
|
# Directory holding the FAISS index that qa_bot() loads with FAISS.load_local.
DB_FAISS_PATH = 'vectorstores/'

# Prompt for the "stuff" RetrievalQA chain: {context} receives the retrieved
# document text and {question} the user's query (see set_custom_prompt).
# The trailing "Helpful answer:" cue gives "below" something to point at, so
# the model continues with just the answer.
custom_prompt_template = '''Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
'''
|
|
|
def set_custom_prompt():
    """Build the PromptTemplate used by the retrieval QA chain.

    The template text is the module-level ``custom_prompt_template``; it
    declares two input variables, ``context`` and ``question``.

    Returns:
        PromptTemplate: the QA prompt for vector-store retrieval.
    """
    variables = ['context', 'question']
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=variables,
    )
|
|
|
|
|
def load_llm(model='MaziyarPanahi/BioMistral-7B-GGUF', model_type='llama',
             max_new_tokens=512, temperature=0.5, model_file=None):
    """Load the answer-generating language model through ctransformers.

    Args:
        model: Hugging Face repo id or local path of the model.
        model_type: ctransformers architecture name passed as ``model_type``.
        max_new_tokens: upper bound on tokens generated per answer.
        temperature: sampling temperature.
        model_file: specific file inside ``model`` (e.g. one GGUF
            quantization); ``None`` keeps ctransformers' default selection.

    Returns:
        CTransformers: a LangChain LLM wrapping the loaded model.
    """
    # NOTE(review): BioMistral is a Mistral-architecture model; confirm that
    # 'llama' is the intended model_type for the installed ctransformers.
    # Generation settings go in ``config``. CTransformers forwards that dict
    # to ctransformers, whereas unknown top-level kwargs never reach the model.
    llm = CTransformers(
        model=model,
        model_type=model_type,
        model_file=model_file,
        config={'max_new_tokens': max_new_tokens, 'temperature': temperature},
    )
    return llm
|
|
|
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a "stuff" RetrievalQA chain over a vector store.

    Args:
        llm: the language model that writes the answer.
        prompt: prompt template handed to the document-combining step.
        db: vector store; its retriever returns the top 2 matches per query.

    Returns:
        RetrievalQA: chain whose output also carries the source documents.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    chain_kwargs = {'prompt': prompt}
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs=chain_kwargs,
    )
|
|
|
@st.cache_resource
def qa_bot():
    """Build the retrieval QA chain: embeddings, FAISS index, LLM and prompt.

    Decorated with ``st.cache_resource`` so the embedding model, the FAISS
    index and the LLM are loaded once per server process. Without it they
    would reload on every Streamlit rerun and every query. Later calls
    return the same chain object.

    Returns:
        RetrievalQA: the chain produced by ``retrieval_qa_chain``.
    """
    # NOTE(review): the cached chain is shared by all sessions; confirm the
    # ctransformers model tolerates concurrent calls from several users.
    # NOTE(review): HuggingFaceBgeEmbeddings prepends a BGE-style query
    # instruction to queries; confirm that matches how this index was built
    # with the (non-BGE) pubmedbert model.
    embeddings = HuggingFaceBgeEmbeddings(
        model_name='NeuML/pubmedbert-base-embeddings',
        model_kwargs={'device': 'cpu'},
    )
    # load_local unpickles the stored index metadata. Enabling
    # allow_dangerous_deserialization is only safe if the files under
    # DB_FAISS_PATH were built locally and are trusted.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    return retrieval_qa_chain(llm, qa_prompt, db)
|
|
|
def final_result(query, qa=None):
    """Answer a single question with the retrieval QA chain.

    Args:
        query: the user's question.
        qa: optional prebuilt chain to reuse; when omitted, ``qa_bot()``
            is called to obtain one.

    Returns:
        dict: the chain output. For chains built by ``retrieval_qa_chain``,
        the answer is under ``'result'`` and the retrieved documents are
        under ``'source_documents'``.
    """
    if qa is None:
        qa = qa_bot()
    # Chain.invoke is the supported entry point; calling the chain object
    # directly (Chain.__call__) is deprecated in LangChain.
    return qa.invoke({'query': query})
|
|
|
|
|
import streamlit as st |
|
|
|
|
|
# Build the QA chain once per script run and query it directly, so a button
# click does not rebuild the embeddings, index and LLM a second time.
bot = qa_bot()

st.title('Medical Chatbot')

user_query = st.text_input("Please enter your question:")

if st.button('Get Answer'):
    if user_query:
        response = bot.invoke({'query': user_query})
        # The chain always returns a dict, so check the answer text itself
        # to decide whether anything useful came back.
        answer = response.get('result') or ''
        if answer.strip():
            st.write("### Answer")
            st.write(answer)

            if 'source_documents' in response:
                st.write("### Source Document Information")
                for index, doc in enumerate(response['source_documents']):
                    # Turn literal backslash-n sequences in the stored text
                    # into real line breaks.
                    formatted_content = doc.page_content.replace("\\n", "\n")
                    st.write("#### Document Content")
                    # A per-document key prevents Streamlit's DuplicateWidgetID
                    # error when two retrieved chunks have identical content.
                    st.text_area(label="Page Content", value=formatted_content,
                                 height=300, key=f"source_doc_{index}")

                    # Metadata keys depend on the loader used at indexing
                    # time; fall back instead of raising KeyError.
                    source = doc.metadata.get('source', 'unknown')
                    page = doc.metadata.get('page', 'unknown')
                    st.write(f"Source: {source}")
                    st.write(f"Page Number: {page}")
        else:
            st.write("Sorry, I couldn't find an answer to your question.")
    else:
        st.write("Please enter a question to get an answer.")