m7mdal7aj committed
Commit a97003a
1 Parent(s): d8f32a4

Update my_model/KBVQA.py

Files changed (1):
  1. my_model/KBVQA.py +240 -240
my_model/KBVQA.py CHANGED
@@ -112,248 +112,248 @@ class KBVQA:
        self.current_prompt_length = None

    def create_bnb_config(self) -> BitsAndBytesConfig:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            BitsAndBytesConfig: Configuration for the quantized model.
        """

        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        elif self.quantization == '8bit':
            # Double quantization and the "nf4" quant type are 4-bit-only options;
            # 8-bit loading uses LLM.int8() and only needs load_in_8bit.
            return BitsAndBytesConfig(load_in_8bit=True)

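For context on the 4-bit settings: "nf4" selects the 4-bit NormalFloat quantization type, double quantization additionally quantizes the quantization constants to save memory, and bfloat16 is the dtype used for compute. A minimal sketch of how the returned config is consumed (this mirrors load_fine_tuned_model further down; kbvqa stands in for a KBVQA instance):

# Illustrative only.
bnb_config = kbvqa.create_bnb_config()
model = AutoModelForCausalLM.from_pretrained(kbvqa.kbvqa_model_name,
                                             quantization_config=bnb_config,
                                             device_map="auto")
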
    def load_caption_model(self) -> None:
        """
        Loads the image captioning model into the KBVQA instance.

        Returns:
            None
        """

        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()
        free_gpu_resources()

    def get_caption(self, img: Image.Image) -> str:
        """
        Generates a caption for a given image using the image captioning model.

        Args:
            img (PIL.Image.Image): The image for which to generate a caption.

        Returns:
            str: The generated caption for the image.
        """

        caption = self.captioner.generate_caption(img)
        free_gpu_resources()
        return caption

    def load_detector(self, model: str) -> None:
        """
        Loads the object detection model.

        Args:
            model (str): The name of the object detection model to load.

        Returns:
            None
        """

        self.detector = ObjectDetector()
        self.detector.load_model(model)
        free_gpu_resources()

    def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
        """
        Detects objects in a given image using the loaded object detection model.

        Args:
            img (PIL.Image.Image): The image in which to detect objects.

        Returns:
            tuple: The image with detected objects drawn, and a string representation of the detected objects.
        """

        image = self.detector.process_image(img)
        free_gpu_resources()
        detected_objects_string, detected_objects_list = self.detector.detect_objects(
            image, threshold=st.session_state['confidence_level'])
        free_gpu_resources()
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        free_gpu_resources()
        return image_with_boxes, detected_objects_string

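Note that the detection threshold is read from Streamlit session state rather than passed as an argument, so the caller must set it first. A minimal sketch, assuming kbvqa is a KBVQA instance with the detector loaded and img is a PIL image (the 0.4 value is only an example; in the demo app this is presumably set by a UI control):

# Illustrative only.
st.session_state['confidence_level'] = 0.4
image_with_boxes, objects_str = kbvqa.detect_objects(img)
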
    def load_fine_tuned_model(self) -> None:
        """
        Loads the fine-tuned KBVQA model along with its tokenizer.

        Returns:
            None
        """

        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
                                                                device_map="auto",
                                                                low_cpu_mem_usage=True,
                                                                quantization_config=self.bnb_config,
                                                                token=self.access_token)
        free_gpu_resources()

        # low_cpu_mem_usage is a model-loading option and has no effect on tokenizers.
        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
                                                             use_fast=self.use_fast,
                                                             trust_remote_code=self.trust_remote,
                                                             add_eos_token=self.add_eos_token,
                                                             token=self.access_token)
        free_gpu_resources()

    @property
    def all_models_loaded(self) -> bool:
        """
        Checks whether all the required models (KBVQA, captioner, detector) are loaded.

        Returns:
            bool: True if all models are loaded, False otherwise.
        """

        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None,
                      caption: Optional[str] = None, objects: Optional[str] = None) -> str:
        """
        Formats the prompt for the KBVQA model based on the provided parameters.

        This implements the Prompt Engineering Module of the overall KB-VQA architecture.

        Args:
            current_query (str): The current question to be answered.
            history (str, optional): The history of previous interactions.
            sys_prompt (str, optional): The system prompt or instructions for the model.
            caption (str, optional): The caption of the image.
            objects (str, optional): The detected objects in the image.

        Returns:
            str: The formatted prompt for the KBVQA model.
        """

        # These are the special tokens the model was fine-tuned on.
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'

        # These are the default special tokens of the LLaMA-2 Chat model.
        B_SENT = '<s>'
        E_SENT = '</s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'

        current_query = current_query.strip()
        if sys_prompt is None:
            sys_prompt = config.SYSTEM_PROMPT.strip()

        # History can be used to facilitate multi-turn chat; it is not used by the Run Inference tool in the demo app.
        if history is None:
            if objects is None:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
            else:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""
        else:
            p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""

        return p

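Expanded, the single-turn branch with detected objects yields a string of this shape (brace-delimited names are placeholders; the actual system prompt comes from config.SYSTEM_PROMPT):

<s>[INST] <<SYS>>
{sys_prompt}
<</SYS>>

[CAP]{caption}[/CAP][OBJ]{objects}[/OBJ][QES]taking into consideration the objects with high certainty, {current_query}[/QES][/INST]
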
    @staticmethod
    def trim_objects(detected_objects_str: str) -> str:
        """
        Trims the last object from the detected objects string.
        This is used to keep the prompt length within the context window (threshold set to 4,000 tokens).

        Args:
            detected_objects_str (str): String containing detected objects, one per line.

        Returns:
            str: The string with the last object removed.
        """

        objects = detected_objects_str.strip().split("\n")
        if len(objects) >= 1:
            return "\n".join(objects[:-1])
        return ""

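For example, with one object per line (the warning in generate_answer suggests these are ordered so that low-confidence detections sit at the bottom), repeated calls peel objects off the end:

# Illustrative only -- the object lines are made up.
s = "person\nhorse\nkite"
s = KBVQA.trim_objects(s)   # "person\nhorse"
s = KBVQA.trim_objects(s)   # "person"
s = KBVQA.trim_objects(s)   # ""
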
    def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
        """
        Generates an answer to a given question using the KBVQA model.

        Args:
            question (str): The question to be answered.
            caption (str): The caption of the image related to the question.
            detected_objects_str (str): The string representation of detected objects in the image.

        Returns:
            str: The generated answer to the question.
        """

        free_gpu_resources()
        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
        num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
        self.current_prompt_length = num_tokens
        trim = False  # flag indicating whether the prompt had to be trimmed.
        # max_context_window is set to 4,000 tokens; refer to the config file.
        if self.current_prompt_length > self.max_context_window:
            trim = True
            st.warning(
                f"Prompt length is {self.current_prompt_length}, which is larger than the maximum context window of"
                f" LLaMA-2. Objects detected with low confidence will be removed one at a time until the prompt length"
                f" fits within the maximum context window ...")
        # Objects are trimmed from the bottom of the list until the overall prompt length fits the context window.
        while self.current_prompt_length > self.max_context_window:
            detected_objects_str = self.trim_objects(detected_objects_str)
            prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
            self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
            if detected_objects_str == "":
                break  # Break if no objects are left.
        if trim:
            st.warning(f"New prompt length is: {self.current_prompt_length}")
            trim = False

        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
        free_gpu_resources()
        input_ids = model_inputs["input_ids"]
        output_ids = self.kbvqa_model.generate(input_ids)
        free_gpu_resources()
        index = input_ids.shape[1]  # used to skip the input prompt when decoding.
        history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)

        return output_text.capitalize()

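Putting the pieces together, a minimal sketch of the intended flow, assuming the models have been loaded via prepare_kbvqa_model (defined below) and the Streamlit state is initialized; the image path and question are hypothetical:

# Illustrative only.
kbvqa = prepare_kbvqa_model()
img = Image.open("example.jpg")
caption = kbvqa.get_caption(img)
st.session_state['confidence_level'] = 0.4
image_with_boxes, objects_str = kbvqa.detect_objects(img)
answer = kbvqa.generate_answer("What is the person holding?", caption, objects_str)
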
def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = False) -> KBVQA:
    """
 