mintaeng committed
Commit
e6db427
1 Parent(s): e99b871

Create pdfchatbot.py

Files changed (1)
  1. pdfchatbot.py +193 -0
pdfchatbot.py ADDED
@@ -0,0 +1,193 @@
+ import yaml
+ import fitz
+ import torch
+ import gradio as gr
+ from PIL import Image
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import Chroma
+ from langchain.llms import HuggingFacePipeline
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.prompts import PromptTemplate
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ class PDFChatBot:
+     def __init__(self, config_path="../config.yaml"):
+         """
+         Initialize the PDFChatBot instance.
+
+         Parameters:
+             config_path (str): Path to the configuration file (default is "../config.yaml").
+         """
+         self.processed = False
+         self.page = 0
+         self.chat_history = []
+         self.config = self.load_config(config_path)
+         # Initialize other attributes to None
+         self.prompt = None
+         self.documents = None
+         self.embeddings = None
+         self.vectordb = None
+         self.tokenizer = None
+         self.model = None
+         self.pipeline = None
+         self.chain = None
+
+     def load_config(self, file_path):
+         """
+         Load configuration from a YAML file.
+
+         Parameters:
+             file_path (str): Path to the YAML configuration file.
+
+         Returns:
+             dict: Configuration as a dictionary, or None if the file cannot be parsed.
+         """
+         with open(file_path, 'r') as stream:
+             try:
+                 config = yaml.safe_load(stream)
+                 return config
+             except yaml.YAMLError as exc:
+                 print(f"Error loading configuration: {exc}")
+                 return None
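+
+     # Note: the rest of this class reads three keys from config.yaml:
+     # "modelEmbeddings", "autoTokenizer" and "autoModelForCausalLM".
+     # A matching file could look like this (the model names below are
+     # placeholders, not part of this commit):
+     #   modelEmbeddings: sentence-transformers/all-MiniLM-L6-v2
+     #   autoTokenizer: meta-llama/Llama-2-7b-chat-hf
+     #   autoModelForCausalLM: meta-llama/Llama-2-7b-chat-hf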
+
+     def add_text(self, history, text):
+         """
+         Add user-entered text to the chat history.
+
+         Parameters:
+             history (list): List of chat history tuples.
+             text (str): User-entered text.
+
+         Returns:
+             list: Updated chat history.
+         """
+         if not text:
+             raise gr.Error('Enter text')
+         history.append((text, ''))
+         return history
+
+     def create_prompt_template(self):
+         """
+         Create the prompt template used to condense the chat history and
+         follow-up question into a standalone question.
+         """
+         template = (
+             "The assistant should provide detailed explanations. "
+             "Combine the chat history and follow up question into "
+             "a standalone question.\n"
+             "Chat History: {chat_history}\n"
+             "Follow up question: {question}"
+         )
+         self.prompt = PromptTemplate.from_template(template)
+
+     def load_embeddings(self):
+         """
+         Load the embedding model named by "modelEmbeddings" in the config file.
+         """
+         self.embeddings = HuggingFaceEmbeddings(model_name=self.config.get("modelEmbeddings"))
+
+     def load_vectordb(self):
+         """
+         Build the Chroma vector database from the loaded documents and embeddings.
+         """
+         self.vectordb = Chroma.from_documents(self.documents, self.embeddings)
+
+     def load_tokenizer(self):
+         """
+         Load the tokenizer named by "autoTokenizer" in the config file.
+         """
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.get("autoTokenizer"))
+
+     def load_model(self):
+         """
+         Load the causal language model named by "autoModelForCausalLM" in the config file.
+         """
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.config.get("autoModelForCausalLM"),
+             device_map='auto',
+             torch_dtype=torch.float32,
+             token=True,
+             load_in_4bit=True  # 4-bit quantization; requires bitsandbytes and a CUDA device
+         )
+
+     def create_pipeline(self):
+         """
+         Create a text-generation pipeline from the loaded model and tokenizer.
+         """
+         pipe = pipeline(
+             model=self.model,
+             task='text-generation',
+             tokenizer=self.tokenizer,
+             max_new_tokens=200
+         )
+         self.pipeline = HuggingFacePipeline(pipeline=pipe)
+
+     def create_chain(self):
+         """
+         Create the ConversationalRetrievalChain tying together the LLM,
+         the retriever and the condense-question prompt.
+         """
+         self.chain = ConversationalRetrievalChain.from_llm(
+             self.pipeline,
+             chain_type="stuff",
+             retriever=self.vectordb.as_retriever(search_kwargs={"k": 1}),
+             condense_question_prompt=self.prompt,
+             return_source_documents=True
+         )
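+
+     # Given the wiring above, a single chain call roughly does three things:
+     #   1. condense (chat_history, question) into a standalone question via self.prompt
+     #   2. retrieve the single closest document chunk (k=1) from Chroma
+     #   3. "stuff" that chunk into the LLM prompt to generate the answer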
+
+     def process_file(self, file):
+         """
+         Process the uploaded PDF file and initialize the necessary components:
+         prompt, documents, embeddings, vector database, tokenizer, LLM and chain.
+
+         Parameters:
+             file (FileStorage): The uploaded PDF file.
+         """
+         self.create_prompt_template()
+         self.documents = PyPDFLoader(file.name).load()
+         self.load_embeddings()
+         self.load_vectordb()
+         self.load_tokenizer()
+         self.load_model()
+         self.create_pipeline()
+         self.create_chain()
+
+     def generate_response(self, history, query, file):
+         """
+         Generate a response based on the user's query and the chat history.
+
+         Parameters:
+             history (list): List of chat history tuples.
+             query (str): User's query.
+             file (FileStorage): The uploaded PDF file.
+
+         Returns:
+             tuple: Updated chat history and a single space (to clear the input box).
+         """
+         if not query:
+             raise gr.Error(message='Submit a question')
+         if not file:
+             raise gr.Error(message='Upload a PDF')
+         if not self.processed:
+             self.process_file(file)
+             self.processed = True
+
+         result = self.chain({"question": query, 'chat_history': self.chat_history}, return_only_outputs=True)
+         self.chat_history.append((query, result["answer"]))
+         # Remember which page the top source chunk came from, for render_file()
+         self.page = result['source_documents'][0].metadata['page']
+
+         # Append the answer to the last history entry (Gradio passes history
+         # entries back as lists, so in-place mutation is safe here)
+         for char in result['answer']:
+             history[-1][-1] += char
+         return history, " "
+
+     def render_file(self, file):
+         """
+         Render the current page of a PDF file as an image.
+
+         Parameters:
+             file (FileStorage): The PDF file.
+
+         Returns:
+             PIL.Image.Image: The rendered page as an image.
+         """
+         doc = fitz.open(file.name)
+         page = doc[self.page]
+         # Render at 300 DPI (PDF points are 1/72 inch)
+         pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
+         image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
+         return image
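
For reference, a minimal app.py wiring this class into a Gradio Blocks UI could look like the sketch below. It is not part of this commit; the component names and layout are assumptions based on the method signatures above.

# app.py: illustrative only, assuming config.yaml sits one directory up
import gradio as gr
from pdfchatbot import PDFChatBot

pdf_chatbot = PDFChatBot()

with gr.Blocks() as demo:
    with gr.Row():
        chat_history = gr.Chatbot(value=[])
        show_img = gr.Image(label='PDF Preview')
    with gr.Row():
        text_input = gr.Textbox(placeholder='Ask a question about the PDF')
        uploaded_pdf = gr.UploadButton('Upload a PDF', file_types=['.pdf'])

    # Show the first page as soon as a file is uploaded (self.page starts at 0)
    uploaded_pdf.upload(pdf_chatbot.render_file, inputs=[uploaded_pdf], outputs=[show_img])

    # On submit: record the question, generate the answer, then show the source page
    text_input.submit(
        pdf_chatbot.add_text, [chat_history, text_input], [chat_history]
    ).then(
        pdf_chatbot.generate_response, [chat_history, text_input, uploaded_pdf], [chat_history, text_input]
    ).then(
        pdf_chatbot.render_file, [uploaded_pdf], [show_img]
    )

demo.launch()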