Ellbendls committed on
Commit
fad0e0e
·
1 Parent(s): dd91ec8

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +165 -0
  2. constants.py +9 -0
  3. ingest.py +27 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import base64
4
+ import time
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+ from transformers import pipeline
7
+ import torch
8
+ import textwrap
9
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader, PDFMinerLoader
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain.embeddings import SentenceTransformerEmbeddings
12
+ from langchain.vectorstores import Chroma
13
+ from langchain.llms import HuggingFacePipeline
14
+ from langchain.chains import RetrievalQA
15
+ from constants import CHROMA_SETTINGS
16
+ from streamlit_chat import message
17
+
18
# Use the wide page layout for the whole app.
st.set_page_config(layout="wide")

# The tokenizer and the seq2seq model are both loaded from this checkpoint
# identifier at import time.
checkpoint = "LaMini-T5-738M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint, device_map="auto", torch_dtype=torch.float32
)

# Directory passed to Chroma as ``persist_directory`` during ingestion.
persist_directory = "db"
29
+
30
@st.cache_resource
def data_ingestion():
    """Embed every PDF found under ``docs`` into the persisted Chroma store.

    Walks the ``docs`` directory tree, loads each ``*.pdf`` with
    ``PDFMinerLoader``, splits the loaded documents into chunks, embeds the
    chunks with the ``all-MiniLM-L6-v2`` sentence-transformer model and
    persists the resulting Chroma collection to ``persist_directory``.

    Returns:
        None. When no PDF is found nothing is embedded or persisted.

    NOTE(review): the function takes no arguments and is wrapped in
    ``st.cache_resource``; presumably it therefore runs only once per server
    process, so PDFs uploaded later may not get indexed -- confirm.
    """
    documents = []
    for root, _dirs, files in os.walk("docs"):
        for file in files:
            if file.endswith(".pdf"):
                print(file)
                # Collect the pages of *every* PDF, not just one loader.
                documents.extend(PDFMinerLoader(os.path.join(root, file)).load())

    if not documents:
        # Nothing to index; avoid building an empty vector store.
        return

    # The overlap must stay well below the chunk size: an overlap equal to
    # the chunk size makes consecutive chunks near-duplicates of each other.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    # Create embeddings here.
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    # Create the vector store here and flush it to disk.
    db = Chroma.from_documents(
        texts,
        embeddings,
        persist_directory=persist_directory,
        client_settings=CHROMA_SETTINGS,
    )
    db.persist()
46
+
47
@st.cache_resource
def llm_pipeline():
    """Wrap the module-level model and tokenizer in a LangChain LLM.

    Builds a ``text2text-generation`` transformers pipeline around
    ``base_model`` / ``tokenizer`` with sampling enabled and returns it
    wrapped in a ``HuggingFacePipeline``.
    """
    # Generation settings forwarded verbatim to ``transformers.pipeline``.
    generation_options = {
        "max_length": 256,
        "do_sample": True,
        "temperature": 0.3,
        "top_p": 0.95,
    }
    text_generator = pipeline(
        'text2text-generation',
        model=base_model,
        tokenizer=tokenizer,
        **generation_options,
    )
    return HuggingFacePipeline(pipeline=text_generator)
60
+
61
@st.cache_resource
def qa_llm():
    """Assemble the RetrievalQA chain over the Chroma store in ``db``.

    Returns:
        A ``RetrievalQA`` chain of type ``"stuff"`` that uses the LLM from
        ``llm_pipeline()`` and a retriever over the Chroma collection, and
        that is configured to return its source documents.
    """
    language_model = llm_pipeline()
    embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_store = Chroma(
        persist_directory="db",
        embedding_function=embedder,
        client_settings=CHROMA_SETTINGS,
    )
    return RetrievalQA.from_chain_type(
        llm=language_model,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
        return_source_documents=True,
    )
74
+
75
def process_answer(instruction):
    """Run *instruction* through the retrieval QA chain and return the answer.

    Args:
        instruction: Input forwarded unchanged to the chain returned by
            ``qa_llm()``. The caller in this file passes
            ``{'query': <user text>}``.

    Returns:
        The value stored under the ``'result'`` key of the chain's output.
    """
    qa = qa_llm()
    generated_text = qa(instruction)
    return generated_text['result']
82
+
83
def get_file_size(file):
    """Return the size of a seekable file object, measured via its end offset.

    Args:
        file: An object exposing ``seek`` and ``tell``.

    Returns:
        The offset reported by ``tell()`` after seeking to the end.

    The stream is left positioned at offset 0 (not at its previous
    position), so a subsequent ``read()`` yields the whole content.
    """
    # Jump to the end and read the offset there.
    file.seek(0, os.SEEK_END)
    size_in_bytes = file.tell()
    # Rewind to the start for the next reader.
    file.seek(0)
    return size_in_bytes
88
+
89
@st.cache_data
def displayPDF(file):
    """Render the PDF stored at path *file* inline in the Streamlit page.

    The file is read in binary mode, base64-encoded and embedded as a
    ``data:`` URI inside an ``<iframe>`` that is emitted with
    ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        file: Path of the PDF to display.
    """
    # Read the raw bytes from the given path.
    with open(file, "rb") as pdf_handle:
        raw_bytes = pdf_handle.read()
    encoded_pdf = base64.b64encode(raw_bytes).decode('utf-8')

    # Embed the encoded PDF in an HTML iframe.
    pdf_display = F'<iframe src="data:application/pdf;base64,{encoded_pdf}" width="100%" height="600" type="application/pdf"></iframe>'

    # Hand the markup to Streamlit for display.
    st.markdown(pdf_display, unsafe_allow_html=True)
101
+
102
+ # Display conversation history using Streamlit messages
103
def display_conversation(history):
    """Render the chat transcript, one user/bot message pair per turn.

    Args:
        history: Mapping with parallel sequences under ``"past"`` (user
            messages) and ``"generated"`` (bot replies). One pair is shown
            per entry in ``"generated"``.
    """
    for turn, bot_reply in enumerate(history["generated"]):
        # Widget keys are derived from the turn index so each is unique.
        message(history["past"][turn], is_user=True, key=str(turn) + "_user")
        message(bot_reply, key=str(turn))
107
+
108
def main():
    """Render the ChatPDFv2 page: upload a PDF, preview it, index it, chat.

    Once a PDF has been uploaded it is written below ``docs/``, shown
    alongside its details in the left column, ingested via
    ``data_ingestion()``, and a chat box backed by ``process_answer()`` is
    shown in the right column. Chat history is kept in
    ``st.session_state`` under the ``"past"`` and ``"generated"`` keys.
    """
    st.markdown("<h1 style='text-align: center; color: blue;'>ChatPDFv2</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF</h2>", unsafe_allow_html=True)

    uploaded_file = st.file_uploader("", type=["pdf"])
    if uploaded_file is None:
        # Nothing else to render until a file has been provided.
        return

    file_details = {
        "Filename": uploaded_file.name,
        "File size": get_file_size(uploaded_file),
    }

    # The filename comes from the client: keep only its final path component
    # so a name such as "../../x.pdf" cannot be written outside ``docs``,
    # and make sure the target directory exists before opening the file.
    os.makedirs("docs", exist_ok=True)
    filepath = os.path.join("docs", os.path.basename(uploaded_file.name))
    with open(filepath, "wb") as temp_file:
        temp_file.write(uploaded_file.read())

    col1, col2 = st.columns([1, 2])
    with col1:
        st.markdown("<h4 style='color:black;'>File details</h4>", unsafe_allow_html=True)
        st.json(file_details)
        st.markdown("<h4 style='color:black;'>File preview</h4>", unsafe_allow_html=True)
        displayPDF(filepath)

    with col2:
        with st.spinner('Embeddings are in process...'):
            data_ingestion()
        st.success('Embeddings are created successfully!')
        st.markdown("<h4 style='color:black;'>Chat Here</h4>", unsafe_allow_html=True)

        user_input = st.text_input("", key="input")

        # Initialize session state for generated responses and past messages.
        if "generated" not in st.session_state:
            st.session_state["generated"] = ["I am ready to help you"]
        if "past" not in st.session_state:
            st.session_state["past"] = ["Hey there!"]

        # Query the QA chain with the user input and record the exchange.
        if user_input:
            answer = process_answer({'query': user_input})
            st.session_state["past"].append(user_input)
            st.session_state["generated"].append(answer)

        # Display conversation history using Streamlit messages.
        if st.session_state["generated"]:
            display_conversation(st.session_state)
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
165
+
constants.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from chromadb.config import Settings
3
+
4
# Chroma client settings: DuckDB+Parquet backend, data persisted under "db",
# anonymized telemetry switched off.
CHROMA_SETTINGS = Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="db",
    anonymized_telemetry=False,
)
ingest.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader, PDFMinerLoader
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from langchain.embeddings import SentenceTransformerEmbeddings
4
+ from langchain.vectorstores import Chroma
5
+ import os
6
+ from constants import CHROMA_SETTINGS
7
+
8
+ persist_directory = "db"
9
+
10
def main():
    """Embed every PDF found under ``docs`` into the persisted Chroma store.

    Walks the ``docs`` directory tree, loads each ``*.pdf`` with
    ``PDFMinerLoader``, splits the loaded documents into chunks, embeds the
    chunks with the ``all-MiniLM-L6-v2`` sentence-transformer model and
    persists the resulting Chroma collection to ``persist_directory``.

    Returns:
        None. When no PDF is found a notice is printed and nothing is
        embedded or persisted.
    """
    documents = []
    for root, _dirs, files in os.walk("docs"):
        for file in files:
            if file.endswith(".pdf"):
                print(file)
                # Collect the pages of *every* PDF, not just one loader.
                documents.extend(PDFMinerLoader(os.path.join(root, file)).load())

    if not documents:
        print("No PDF files found under 'docs'; nothing to ingest.")
        return

    # The overlap must stay well below the chunk size: an overlap equal to
    # the chunk size makes consecutive chunks near-duplicates of each other.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    # Create embeddings here.
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    # Create the vector store here and flush it to disk.
    db = Chroma.from_documents(
        texts,
        embeddings,
        persist_directory=persist_directory,
        client_settings=CHROMA_SETTINGS,
    )
    db.persist()
25
+
26
+ if __name__ == "__main__":
27
+ main()