import os
from tempfile import NamedTemporaryFile

import streamlit as st
from langchain import PromptTemplate, LLMChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.vectorstores import Chroma
# Function to save the uploaded PDF to a temporary file
def save_uploaded_file(uploaded_file):
    # delete=False keeps the file on disk so PyPDFLoader can read it later;
    # it is removed explicitly at the end of the script
    with NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
        temp_file.write(uploaded_file.read())
        return temp_file.name
# Initialize the model and other resources outside the main flow
@st.cache_resource  # cache the LLM across Streamlit reruns (requires Streamlit >= 1.18)
def initialize_model():
    # Initialize the HuggingFaceHub LLM with the appropriate settings
    llm = HuggingFaceHub(
        repo_id="Manoj21k/GPT4ALL",
        model_kwargs={"temperature": 1e-10}
    )
    return llm
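# Note: HuggingFaceHub reads the Hugging Face API token from the
# HUGGINGFACEHUB_API_TOKEN environment variable (or a huggingfacehub_api_token
# argument); on Spaces this is typically configured as a repository secret.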
# Streamlit UI
st.title("PDF Question Answering App")

uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])

if uploaded_file is not None:
    # Save the uploaded file to a temporary location
    temp_file_path = save_uploaded_file(uploaded_file)

    # Load the PDF document using PyPDFLoader and split it into pages
    loader = PyPDFLoader(temp_file_path)
    pages = loader.load_and_split()

    # Initialize embeddings and build the Chroma vector store
    embed = HuggingFaceEmbeddings()
    db = Chroma.from_documents(pages, embed)

    # Load the model using the cached function
    llm = initialize_model()
    # Define a function to get answers
    def get_answer(question):
        # Retrieve the most similar pages and concatenate them into one context string
        docs = db.similarity_search(question, k=4)
        context = "".join(doc.page_content for doc in docs)

        # Truncate the context; adjust this limit to match your model's maximum sequence length
        max_seq_length = 512
        context = context[:max_seq_length]

        # Prompt template
        template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Answer: """
        prompt = PromptTemplate(template=template, input_variables=["context", "question"]).partial(context=context)
        llm_chain = LLMChain(prompt=prompt, llm=llm)

        # The context is already bound to the prompt, so only the question is passed here;
        # generation limits such as max_length belong in model_kwargs when creating the LLM
        output = llm_chain.run(question)

        # Extract the text that follows "Answer:" in the model output
        answer_index = output.find("Answer:")
        if answer_index == -1:
            return output.strip()
        next_line_index = output.find("\n", answer_index)
        if next_line_index == -1:
            next_line_index = len(output)
        return output[answer_index + len("Answer:"):next_line_index].strip()
    question = st.text_input("Enter your question:")
    if st.button("Get Answer"):
        answer = get_answer(question)
        st.write("Answer:")
        st.write(answer)

    # Cleanup: delete the temporary file
    os.remove(temp_file_path)
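# To run locally (assuming this file is saved as app.py and the dependencies
# streamlit, langchain, chromadb, pypdf, sentence-transformers and
# huggingface_hub are installed):
#   streamlit run app.py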