# Custom-QandA / app.py

import os
from tempfile import NamedTemporaryFile

import streamlit as st
from langchain.chains import LLMChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
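
# Streamlit app that lets a user upload a PDF, indexes its pages in a Chroma
# vector store with HuggingFace embeddings, and answers questions about the
# document with an LLM served from the Hugging Face Hub.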

# Function to save the uploaded PDF to a temporary file and return its path
def save_uploaded_file(uploaded_file):
    with NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
        temp_file.write(uploaded_file.read())
    return temp_file.name

# Initialize the model outside the main flow so it is cached across Streamlit reruns
@st.cache(allow_output_mutation=True)
def initialize_model():
    # HuggingFaceHub LLM for the question-answering task; a near-zero
    # temperature keeps the answers close to deterministic
    llm = HuggingFaceHub(
        repo_id="Manoj21k/GPT4ALL",
        model_kwargs={"temperature": 1e-10},
    )
    return llm
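
# Note: st.cache with allow_output_mutation is deprecated on newer Streamlit
# releases; st.cache_resource is the usual replacement for caching a model handle.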

# Streamlit UI
st.title("PDF Question Answering App")

uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])

if uploaded_file is not None:
    # Save the uploaded file to a temporary location
    temp_file_path = save_uploaded_file(uploaded_file)

    # Load the PDF and split it into per-page documents
    loader = PyPDFLoader(temp_file_path)
    pages = loader.load_and_split()

    # Embed the pages and index them in a Chroma vector store
    embed = HuggingFaceEmbeddings()
    db = Chroma.from_documents(pages, embed)

    # Load the model using the cached function
    llm = initialize_model()

    # Define a function to get answers
    def get_answer(question):
        # Retrieve the four most relevant pages and concatenate their text
        docs = db.similarity_search(question, k=4)
        context = "".join(doc.page_content for doc in docs)

        # Rough character-based truncation so the prompt stays within the
        # model's context window
        max_seq_length = 512
        context = context[:max_seq_length]

        # Prompt template
        template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Answer: """
        prompt = PromptTemplate(
            template=template, input_variables=["context", "question"]
        ).partial(context=context)
        llm_chain = LLMChain(prompt=prompt, llm=llm)
        output = llm_chain.run(question=question)

        # Extract the text that follows "Answer:", up to the next newline if any
        answer_index = output.find("Answer:")
        if answer_index == -1:
            return output.strip()
        answer_start = answer_index + len("Answer:")
        next_line_index = output.find("\n", answer_start)
        if next_line_index == -1:
            next_line_index = len(output)
        return output[answer_start:next_line_index].strip()

    question = st.text_input("Enter your question:")
    if st.button("Get Answer"):
        answer = get_answer(question)
        st.write("Answer:")
        st.write(answer)

    # Cleanup: delete the temporary PDF file
    os.remove(temp_file_path)
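
# To run locally (assuming streamlit, langchain, chromadb, pypdf, and
# sentence-transformers are installed):
#   streamlit run app.py
# The HuggingFaceHub LLM also expects a HUGGINGFACEHUB_API_TOKEN environment
# variable (or an explicit huggingfacehub_api_token argument).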