File size: 3,653 Bytes
31f4c0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
from PyPDF2 import PdfReader
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from langchain.docstore.document import Document
from langchain_groq import ChatGroq

# Function to extract text from a PDF
def get_pdf_text(pdf_path):
    """Return the text of every page in the PDF at *pdf_path*, concatenated.

    Pages are read in document order. A page whose extraction result is
    falsy (``None`` or an empty string) contributes nothing, and no
    separator is inserted between pages.
    """
    pdf = PdfReader(pdf_path)
    extracted = (page.extract_text() for page in pdf.pages)
    return "".join(piece for piece in extracted if piece)

# Chat model used to answer questions. It is constructed once, when this
# module is imported, with a fixed model name and sampling temperature.
# NOTE(review): no API key is passed here, so ChatGroq presumably reads it
# from the environment (e.g. GROQ_API_KEY) — confirm before deploying.
llm = ChatGroq(model="llama3-8b-8192", temperature=0.7)

# Function to initialize chatbot with or without a PDF
def initialize_chatbot(pdf_path=None):
    """Build a conversational retrieval chain, optionally grounded in a PDF.

    Args:
        pdf_path: Path of the PDF to index. When falsy, no document is read.

    Returns:
        A ``(chatbot, db, text_documents)`` tuple: the
        ``ConversationalRetrievalChain``, the FAISS vector store behind its
        retriever, and the list of chunk documents produced from the PDF.
        ``text_documents`` is empty when no PDF was given or when the PDF
        yielded no text.
    """
    embed = HuggingFaceHubEmbeddings(model="sentence-transformers/all-MiniLM-L6-v2")

    text_documents = []
    if pdf_path:
        texts = get_pdf_text(pdf_path)
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=20,
            length_function=len,
            separators=["\n"]
        )
        text_documents = splitter.create_documents([texts])

    # The vector store needs at least one document to be built. Use a
    # placeholder both when there is no PDF and when the PDF produced no
    # chunks (e.g. a scanned document with no extractable text); previously
    # the latter case passed an empty list to FAISS and failed.
    index_documents = text_documents or [Document(page_content="dummy")]
    db = FAISS.from_documents(index_documents, embed)

    # The human turn carries both the question and the retrieved context;
    # "context" must match document_variable_name below.
    template = [
        ("system", "You are a helpful Assistant. If you don't get the answer from the pdf, then answer from your knowledge base."),
        ("human", "{question}\n\nContext:\n{context}")
    ]

    prompt = ChatPromptTemplate.from_messages(template)

    chatbot = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=db.as_retriever(),
        combine_docs_chain_kwargs={"prompt": prompt, "document_variable_name": "context"},
        verbose=False
    )
    return chatbot, db, text_documents

# Define the main function for the Streamlit app
def main():
    """Render the Streamlit chat page and answer questions about an optional PDF.

    Streamlit re-executes this function from the top on every widget
    interaction. Anything that has to survive between interactions (the
    retrieval chain and the conversation history) is therefore kept in
    ``st.session_state`` rather than in local variables, which would be
    recreated empty on each run.
    """
    st.title("Chat with Your Assistant")

    # File upload for PDF
    uploaded_file = st.file_uploader("Upload PDF File", type=['pdf'])

    # Identify the current document so the chain is rebuilt only when the
    # user uploads, replaces, or removes a file — not on every interaction.
    source_key = None
    if uploaded_file is not None:
        source_key = (uploaded_file.name, uploaded_file.size)

    needs_rebuild = (
        "chatbot" not in st.session_state
        or st.session_state.get("source_key") != source_key
    )
    if needs_rebuild:
        if uploaded_file is not None:
            # Persist the upload to disk because initialize_chatbot takes a path.
            pdf_path = "uploaded_pdf.pdf"
            with open(pdf_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            chatbot, _db, _text_documents = initialize_chatbot(pdf_path)
        else:
            chatbot, _db, _text_documents = initialize_chatbot()

        st.session_state["chatbot"] = chatbot
        st.session_state["source_key"] = source_key
        # A different document starts a fresh conversation.
        st.session_state["chat_history"] = []

    if uploaded_file is not None:
        st.success("PDF uploaded successfully! Now ask a question.")
    else:
        st.info("Please upload a PDF file to start.")

    # User input for questions
    user_input = st.text_input("You:", "")

    if st.button("Send"):
        if not user_input.strip():
            st.warning("Please enter a question.")
        else:
            with st.spinner('Thinking...'):
                # query() appends the new turn to the list it is given, so
                # passing the session-held list keeps history across reruns.
                response = query(
                    user_input,
                    st.session_state["chat_history"],
                    st.session_state["chatbot"],
                )
                st.text_area("Assistant:", value=response, height=200, max_chars=None)

# Function to handle chatbot query
def query(question, chat_history, chatbot):
    """Ask *chatbot* a question and record the exchange in *chat_history*.

    Args:
        question: The user's question text.
        chat_history: List of prior turns; passed to the chatbot as-is and
            then extended in place with the new (question, answer) pair.
        chatbot: Callable taking a dict with "question" and "chat_history"
            keys and returning a mapping that has an "answer" entry.

    Returns:
        The chatbot's answer.
    """
    payload = {"question": question, "chat_history": chat_history}
    answer = chatbot(payload)["answer"]

    turn = (f'Question: {question}', f'Answer: {answer}')
    chat_history.append(turn)
    return answer

# Run the app: call main() only when this file is executed as the main
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()