File size: 3,595 Bytes
aa774c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#import os
#os.system("bash setup.sh")

import streamlit as st
# import fitz  # PyMuPDF for extracting text from PDFs
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
import torch
import re
import transformers
from torch import bfloat16
from langchain_community.document_loaders import DirectoryLoader

# Initialize embeddings and ChromaDB.
#
# Streamlit re-executes this whole script on every user interaction, so the
# expensive steps (loading the embedding model, parsing the PDFs, embedding
# the chunks) are wrapped in ``st.cache_resource`` and run once per server
# process instead of once per chat message.
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}


@st.cache_resource
def _load_embeddings(name, embed_kwargs):
    """Load the sentence-transformers embedding model *name* once per process."""
    return HuggingFaceEmbeddings(model_name=name, model_kwargs=embed_kwargs)


def _chunk_id(chunk):
    """Return a deterministic Chroma id for *chunk*, derived from its source path and text."""
    import hashlib

    key = f"{chunk.metadata.get('source', '')}\n{chunk.page_content}"
    return hashlib.sha256(key.encode("utf-8")).hexdigest()


@st.cache_resource
def _load_books_db(_embeddings, pdf_dir="./pdf", persist_directory="./pdf_db"):
    """Split every PDF under *pdf_dir* into chunks and index them in a persistent Chroma store.

    ``_embeddings`` starts with an underscore so Streamlit excludes it from
    the cache key (the embeddings object need not be hashable).

    Returns:
        The Chroma vector store backed by *persist_directory*.
    """
    loader = DirectoryLoader(pdf_dir, glob="**/*.pdf", recursive=True, use_multithreading=True)
    docs = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    all_splits = text_splitter.split_documents(docs)
    if not all_splits:
        # Nothing new to index (e.g. ./pdf is empty): just open what is persisted.
        return Chroma(persist_directory=persist_directory, embedding_function=_embeddings)
    # Deterministic ids turn re-indexing into an overwrite of the same entries
    # rather than an append of fresh random-id copies, so app restarts do not
    # pile duplicate chunks into pdf_db. Identical (source, text) chunks
    # collapse to one entry because ids must be unique within a batch.
    # NOTE(review): a pdf_db written by the previous version may already hold
    # duplicates under random ids; delete the directory once to rebuild cleanly.
    unique_chunks = {_chunk_id(chunk): chunk for chunk in all_splits}
    return Chroma.from_documents(
        documents=list(unique_chunks.values()),
        embedding=_embeddings,
        ids=list(unique_chunks.keys()),
        persist_directory=persist_directory,
    )


embeddings = _load_embeddings(model_name, model_kwargs)
books_db = _load_books_db(embeddings)
books_db_client = books_db.as_retriever()

# Initialize the model and tokenizer.
model_name = "stabilityai/stablelm-zephyr-3b"


@st.cache_resource
def _load_llm(name, device):
    """Load the 4-bit quantized causal LM *name* and wrap it as a LangChain LLM.

    Cached with ``st.cache_resource`` because Streamlit re-runs this script on
    every interaction; without caching the model would be reloaded for every
    chat message.

    Args:
        name: Hugging Face model id.
        device: ``"cuda"`` or ``"cpu"``, used as the ``device_map``.

    Returns:
        A ``HuggingFacePipeline`` wrapping a text-generation pipeline.
    """
    bnb_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # NOTE(review): this config-level max_new_tokens (1024) appears to be
    # overridden by the pipeline's max_new_tokens=256 below; confirm intent.
    model_config = transformers.AutoConfig.from_pretrained(name, max_new_tokens=1024)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        name,
        trust_remote_code=True,
        config=model_config,
        quantization_config=bnb_config,
        device_map=device,
    )
    tokenizer = AutoTokenizer.from_pretrained(name)
    query_pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # Keep the prompt in the output; the chat handler extracts the answer
        # that follows the "Helpful Answer:" marker.
        return_full_text=True,
        # No torch_dtype here: the model is already instantiated and its
        # compute dtype comes from bnb_config (bfloat16).
        device_map=device,
        # temperature/top_p/top_k only apply when sampling is enabled; without
        # do_sample=True generation is greedy and silently ignores them.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        max_new_tokens=256,
    )
    return HuggingFacePipeline(pipeline=query_pipeline)


llm = _load_llm(model_name, device)

# "stuff" chain: retrieved chunks are concatenated into a single prompt.
books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True
)

st.title("RAG System with ChromaDB")

# Seed the conversation only once: st.session_state persists across reruns
# of this script within the same browser session.
if 'messages' not in st.session_state:
    st.session_state.messages = [{
        'role': 'assistant',
        # The documents are read from the ./pdf directory at startup; the UI
        # has no upload widget, so do not tell users to upload files.
        "content": 'Hello! Ask me anything about the content of the PDF files.',
    }]

# Function to retrieve answer using the RAG system
def test_rag(qa, query):
    return qa.run(query)

def _helpful_answer(raw_output):
    """Return the text after the first "Helpful Answer:" marker, or a fallback message."""
    found = re.search(r"Helpful Answer:(.*)", raw_output, re.DOTALL)
    if found is None:
        return "No helpful answer found."
    return found.group(1).strip()


user_prompt = st.chat_input("Ask me anything about the content of the PDF(s):")
if user_prompt:
    history = st.session_state.messages
    history.append({'role': 'user', "content": user_prompt})
    raw_answer = test_rag(books_db_client_retriever, user_prompt)
    # The generation pipeline echoes the prompt (return_full_text=True), so
    # keep only what follows the answer marker.
    history.append({'role': 'assistant', "content": _helpful_answer(raw_answer)})

# Render the whole conversation, including any turn added above, on every rerun.
for entry in st.session_state.messages:
    role, text = entry['role'], entry['content']
    with st.chat_message(role):
        st.write(text)