import re
import tempfile
from pathlib import Path

import chromadb
import streamlit as st
from unidecode import unidecode
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
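
# The UI below offers a model choice via `list_llm_simple`, which the original
# snippet never defines. Minimal assumption: only the Mixtral endpoint is wired
# up in initialize_llmchain(), so the selector lists just that repo id.
list_llm_simple = ["mistralai/Mixtral-8x7B-Instruct-v0.1"]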

# Function to load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, 
        chunk_overlap=chunk_overlap
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
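
# Usage sketch (hypothetical file names): each returned split is a LangChain
# Document carrying the chunk text and its source metadata.
#   splits = load_doc(["report.pdf", "notes.pdf"], chunk_size=600, chunk_overlap=40)
#   print(splits[0].page_content[:80], splits[0].metadata)  # e.g. {'source': 'report.pdf', 'page': 0}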

# Function to create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb
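
# The original snippet calls create_collection_name() without defining it.
# A minimal sketch, assuming the intent is to derive a ChromaDB-safe collection
# name (3-63 characters, no spaces or special characters) from the first file.
def create_collection_name(filepath):
    collection_name = unidecode(Path(filepath).stem)  # drop path, extension, accents
    collection_name = re.sub(r"[^A-Za-z0-9]+", "-", collection_name)  # keep only safe characters
    collection_name = collection_name[:50].strip("-")  # enforce maximum length, clean ends
    if len(collection_name) < 3:
        collection_name += "-collection"  # enforce minimum length
    return collection_name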

# Initialize Langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        # HuggingFaceEndpoint calls the hosted Inference API and reads the
        # HUGGINGFACEHUB_API_TOKEN environment variable for authentication.
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
        )
    # Add elif branches for other LLM models here...
    else:
        raise ValueError(f"Unsupported LLM model: {llm_model}")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
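
# Usage sketch (assumes a vector_db built by create_db() and a
# HUGGINGFACEHUB_API_TOKEN set in the environment):
#   chain = initialize_llmchain("mistralai/Mixtral-8x7B-Instruct-v0.1",
#                               temperature=0.7, max_tokens=1024, top_k=3,
#                               vector_db=vector_db)
#   result = chain({"question": "What is this document about?"})
#   print(result["answer"], [doc.metadata for doc in result["source_documents"]])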

# Process uploaded PDFs and initialize the vector database. Streamlit's
# UploadedFile objects live in memory, so write each to a temp file for PyPDFLoader.
def process_documents(list_file_obj, chunk_size, chunk_overlap):
    list_file_path = []
    for file_obj in list_file_obj:
        if file_obj is None:
            continue
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp.write(file_obj.getbuffer())
            list_file_path.append(tmp.name)
    collection_name = create_collection_name(list_file_obj[0].name)
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db

# Streamlit app
def main():
    st.title("PDF-based Chatbot")
    st.write("Ask any questions about your PDF documents")

    # Streamlit reruns the script on every widget interaction, so the vector
    # database and QA chain are kept in st.session_state to survive reruns.
    if "vector_db" not in st.session_state:
        st.session_state.vector_db = None
    if "qa_chain" not in st.session_state:
        st.session_state.qa_chain = None

    # Step 1: Upload PDF documents
    uploaded_files = st.file_uploader(
        "Upload your PDF documents (single or multiple)",
        type=["pdf"],
        accept_multiple_files=True,
    )

    # Step 2: Process documents and initialize the vector database
    if uploaded_files:
        chunk_size = st.slider("Chunk size", min_value=100, max_value=1000, value=600, step=20)
        chunk_overlap = st.slider("Chunk overlap", min_value=10, max_value=200, value=40, step=10)
        if st.button("Generate Vector Database"):
            st.session_state.vector_db = process_documents(uploaded_files, chunk_size, chunk_overlap)
            st.success("Vector database generated successfully!")

    # Step 3: Initialize the QA chain with the selected LLM model
    if st.session_state.vector_db is not None:
        st.header("Initialize Question Answering (QA) Chain")
        llm_model = st.selectbox("Choose LLM Model", list_llm_simple)
        temperature = st.slider("Temperature", min_value=0.01, max_value=1.0, value=0.7, step=0.1)
        max_tokens = st.slider("Max Tokens", min_value=224, max_value=4096, value=1024, step=32)
        top_k = st.slider("Top-k Samples", min_value=1, max_value=10, value=3, step=1)
        if st.button("Initialize QA Chain"):
            st.session_state.qa_chain = initialize_llmchain(
                llm_model, temperature, max_tokens, top_k, st.session_state.vector_db
            )
            st.success("QA Chain initialized successfully!")

    # Step 4: Chatbot interaction
    if st.session_state.qa_chain is not None:
        st.header("Chatbot")
        message = st.text_input("Type your message here")
        if st.button("Submit") and message:
            # ConversationalRetrievalChain takes a dict with a "question" key;
            # chat history is supplied by the chain's memory.
            response = st.session_state.qa_chain({"question": message})
            st.write(f"Chatbot Response: {response['answer']}")

if __name__ == "__main__":
    main()