m7mdal7aj commited on
Commit
fc498e0
·
verified ·
1 Parent(s): 453b185

Update my_model/KBVQA.py

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +300 -200
my_model/KBVQA.py CHANGED
@@ -1,7 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
- import copy
4
- import os
5
  from PIL import Image
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
7
  from typing import Tuple, Optional
@@ -11,11 +49,10 @@ from my_model.detector.object_detection import ObjectDetector
11
  import my_model.config.kbvqa_config as config
12
 
13
 
14
-
15
  class KBVQA:
16
  """
17
- The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.
18
- It integrates various components such as an image captioning model, object detection model, and a fine-tuned
19
  language model (LLAMA2) on OK-VQA dataset for generating answers to visual questions.
20
 
21
  Attributes:
@@ -49,14 +86,17 @@ class KBVQA:
49
  generate_answer: Generates an answer to a given question using the KBVQA model.
50
  """
51
 
52
- def __init__(self):
 
 
 
53
 
54
  if st.session_state["method"] == "7b-Fine-Tuned Model":
55
  self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME_7b
56
  elif st.session_state["method"] == "13b-Fine-Tuned Model":
57
  self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME_13b
58
  self.quantization: str = config.QUANTIZATION
59
- self.max_context_window: int = config.MAX_CONTEXT_WINDOW
60
  self.add_eos_token: bool = config.ADD_EOS_TOKEN
61
  self.trust_remote: bool = config.TRUST_REMOTE
62
  self.use_fast: bool = config.USE_FAST
@@ -70,234 +110,270 @@ class KBVQA:
70
  self.bnb_config: BitsAndBytesConfig = self.create_bnb_config()
71
  self.access_token: str = config.HUGGINGFACE_TOKEN
72
  self.current_prompt_length = None
73
-
74
-
75
- def create_bnb_config(self) -> BitsAndBytesConfig:
76
- """
77
- Creates a BitsAndBytes configuration based on the quantization setting.
78
- Returns:
79
- BitsAndBytesConfig: Configuration for BitsAndBytes optimized model.
80
- """
81
- if self.quantization == '4bit':
82
- return BitsAndBytesConfig(
83
- load_in_4bit=True,
84
- bnb_4bit_use_double_quant=True,
85
- bnb_4bit_quant_type="nf4",
86
- bnb_4bit_compute_dtype=torch.bfloat16
87
- )
88
- elif self.quantization == '8bit':
89
- return BitsAndBytesConfig(
90
- load_in_8bit=True,
91
- bnb_8bit_use_double_quant=True,
92
- bnb_8bit_quant_type="nf4",
93
- bnb_8bit_compute_dtype=torch.bfloat16
94
- )
95
-
96
-
97
- def load_caption_model(self) -> None:
98
- """
99
- Loads the image captioning model into the KBVQA instance.
100
- """
101
-
102
- self.captioner = ImageCaptioningModel()
103
- self.captioner.load_model()
104
- free_gpu_resources()
105
 
106
- def get_caption(self, img: Image.Image) -> str:
107
- """
108
- Generates a caption for a given image using the image captioning model.
109
 
110
- Args:
111
- img (PIL.Image.Image): The image for which to generate a caption.
 
 
 
 
112
 
113
- Returns:
114
- str: The generated caption for the image.
115
- """
116
- caption = self.captioner.generate_caption(img)
117
- free_gpu_resources()
118
- return caption
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- def load_detector(self, model: str) -> None:
121
- """
122
- Loads the object detection model.
123
 
124
- Args:
125
- model (str): The name of the object detection model to load.
126
- """
127
 
128
- self.detector = ObjectDetector()
129
- self.detector.load_model(model)
130
- free_gpu_resources()
131
 
132
- def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
133
- """
134
- Detects objects in a given image using the loaded object detection model.
135
 
136
- Args:
137
- img (PIL.Image.Image): The image in which to detect objects.
138
 
139
- Returns:
140
- tuple: A tuple containing the image with detected objects drawn and a string representation of detected objects.
141
- """
142
-
143
- image = self.detector.process_image(img)
144
- free_gpu_resources()
145
- detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=st.session_state['confidence_level'])
146
- free_gpu_resources()
147
- image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
148
- free_gpu_resources()
149
- return image_with_boxes, detected_objects_string
150
 
151
- def load_fine_tuned_model(self) -> None:
152
- """
153
- Loads the fine-tuned KBVQA model along with its tokenizer.
154
- """
155
-
156
- self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
157
- device_map="auto",
158
- low_cpu_mem_usage=True,
159
- quantization_config=self.bnb_config,
160
- token=self.access_token)
161
 
162
- free_gpu_resources()
163
-
164
- self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
165
- use_fast=self.use_fast,
166
- low_cpu_mem_usage=True,
167
- trust_remote_code=self.trust_remote,
168
- add_eos_token=self.add_eos_token,
169
- token=self.access_token)
170
- free_gpu_resources()
171
 
172
- @property
173
- def all_models_loaded(self):
174
- """
175
- Checks if all the required models (KBVQA, captioner, detector) are loaded.
176
 
177
- Returns:
178
- bool: True if all models are loaded, False otherwise.
179
- """
180
-
181
- return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
182
-
183
 
 
 
 
184
 
185
- def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: str = None, objects: Optional[str] = None) -> str:
186
- """
187
- Formats the prompt for the KBVQA model based on the provided parameters.
188
 
189
- Args:
190
- current_query (str): The current question to be answered.
191
- history (str, optional): The history of previous interactions.
192
- sys_prompt (str, optional): The system prompt or instructions for the model.
193
- caption (str, optional): The caption of the image.
194
- objects (str, optional): The detected objects in the image.
195
 
196
- Returns:
197
- str: The formatted prompt for the KBVQA model.
198
- """
199
-
200
- B_SENT = '<s>'
201
- E_SENT = '</s>'
202
- B_INST = '[INST]'
203
- E_INST = '[/INST]'
204
- B_SYS = '<<SYS>>\n'
205
- E_SYS = '\n<</SYS>>\n\n'
206
- B_CAP = '[CAP]'
207
- E_CAP = '[/CAP]'
208
- B_QES = '[QES]'
209
- E_QES = '[/QES]'
210
- B_OBJ = '[OBJ]'
211
- E_OBJ = '[/OBJ]'
212
- current_query = current_query.strip()
213
- if sys_prompt is None:
214
- sys_prompt = config.SYSTEM_PROMPT.strip()
215
- if history is None:
216
- if objects is None:
217
- p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
218
- else:
219
- p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  else:
221
- p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
222
-
223
- return p
224
 
225
- @staticmethod
226
- def trim_objects(detected_objects_str):
227
- """
228
- Trim the last object from the detected objects string.
229
-
230
- Args:
231
- - detected_objects_str (str): String containing detected objects.
232
-
233
- Returns:
234
- - (str): The string with the last object removed.
235
- """
236
- objects = detected_objects_str.strip().split("\n")
237
- if len(objects) >= 1:
238
- return "\n".join(objects[:-1])
239
- return ""
240
 
241
- def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
242
- """
243
- Generates an answer to a given question using the KBVQA model.
244
 
245
- Args:
246
- question (str): The question to be answered.
247
- caption (str): The caption of the image related to the question.
248
- detected_objects_str (str): The string representation of detected objects in the image.
 
249
 
250
- Returns:
251
- str: The generated answer to the question.
252
- """
253
-
254
-
255
- free_gpu_resources()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
257
- num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
258
- self.current_prompt_length = num_tokens
 
 
 
 
259
  trim = False
260
- if self.current_prompt_length > self.max_context_window:
261
- trim = True
262
- st.warning(f"Prompt length is {self.current_prompt_length} which is larger than the maximum context window of LLaMA-2, objects detected with low confidence will be removed one at a time until the prompt length is within the maximum context window ...")
263
- while self.current_prompt_length > self.max_context_window:
264
- detected_objects_str = self.trim_objects(detected_objects_str)
265
- prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
266
- self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
267
-
268
- if detected_objects_str == "":
269
- break # Break if no objects are left
270
- if trim:
271
- st.warning(f"New prompt length is: {self.current_prompt_length}")
272
- trim = False
273
-
274
- model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
275
- free_gpu_resources()
276
- input_ids = model_inputs["input_ids"]
277
- output_ids = self.kbvqa_model.generate(input_ids)
278
- free_gpu_resources()
279
- index = input_ids.shape[1] # needed to avoid printing the input prompt
280
- history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)
281
- output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)
282
 
283
- return output_text.capitalize()
 
 
 
 
 
 
 
 
 
 
284
 
285
  def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = False) -> KBVQA:
286
  """
287
  Prepares the KBVQA model for use, including loading necessary sub-models.
288
 
 
 
289
  Args:
290
  only_reload_detection_model (bool): If True, only the object detection model is reloaded.
 
291
 
292
  Returns:
293
  KBVQA: An instance of the KBVQA model ready for inference.
294
  """
295
-
296
  if force_reload:
297
  free_gpu_resources()
298
  loading_message = 'Reloading model.. this should take no more than 2 or 3 minutes!'
299
  try:
300
- del kbvqa
301
  free_gpu_resources()
302
  free_gpu_resources()
303
  except:
@@ -305,14 +381,15 @@ def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload:
305
  free_gpu_resources()
306
  pass
307
  free_gpu_resources()
308
-
309
- else: loading_message = 'Looading model.. this should take no more than 2 or 3 minutes!'
 
310
 
311
  free_gpu_resources()
312
  kbvqa = KBVQA()
313
  kbvqa.detection_model = st.session_state.detection_model
314
  # Progress bar for model loading
315
-
316
  with st.spinner(loading_message):
317
  if not only_reload_detection_model:
318
  progress_bar = st.progress(0)
@@ -330,11 +407,34 @@ def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload:
330
  progress_bar = st.progress(0)
331
  kbvqa.load_detector(kbvqa.detection_model)
332
  progress_bar.progress(100)
333
-
334
  if kbvqa.all_models_loaded:
335
  st.success('Model loaded successfully and ready for inferecne!')
336
  kbvqa.kbvqa_model.eval()
337
  free_gpu_resources()
338
  return kbvqa
339
 
340
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Main script for KBVQA: Knowledge-Based Visual Question Answering Module
2
+
3
+ # This module is the central component for implementing the designed model architecture for the Knowledge-Based Visual
4
+ # Question Answering (KB-VQA) project. It integrates various sub-modules, including image captioning, object detection,
5
+ # and a fine-tuned language model, to provide a comprehensive solution for answering questions based on visual input.
6
+
7
+ # --- Description ---
8
+ # **KBVQA class**:
9
+ # The KBVQA class encapsulates the functionality needed to perform visual question answering using a combination of
10
+ # multimodal models.
11
+ # The class handles the following tasks:
12
+ # - Loading and managing a fine-tuned language model (LLaMA-2) for question answering.
13
+ # - Integrating an image captioning model to generate descriptive captions for input images.
14
+ # - Utilizing an object detection model to identify and describe objects within the images.
15
+ # - Formatting and generating prompts for the language model based on the image captions and detected objects.
16
+ # - Providing methods to analyze images and generate answers to user-provided questions.
17
+
18
+ # **prepare_kbvqa_model function**:
19
+ # - The prepare_kbvqa_model function orchestrates the loading and initialization of the KBVQA class, ensuring it is
20
+ # ready for inference.
21
+
22
+ # ---Instructions---
23
+ # **Model Preparation**:
24
+ # Use the prepare_kbvqa_model function to prepare and initialize the KBVQA system, ensuring all required models are
25
+ # loaded and ready for use.
26
+
27
+ # **Image Processing and Question Answering**:
28
+ # Use the get_caption method to generate captions for input images.
29
+ # Use the detect_objects method to identify and describe objects in the images.
30
+ # Use the generate_answer method to answer questions based on the image captions and detected objects.
31
+
32
+ # This module forms the backbone of the KB-VQA project, integrating advanced models to provide an end-to-end solution
33
+ # for visual question answering tasks.
34
+ # Ensure all dependencies are installed and the required configuration file is in place before running this script.
35
+ # The configurations for the KBVQA class are defined in the 'my_model/config/kbvqa_config.py' file.
36
+
37
+ # ---------- Please run this module to utilize the full KB-VQA functionality ----------#
38
+ # ---------- Please ensure this is run on a GPU ----------#
39
+
40
+
41
  import streamlit as st
42
  import torch
 
 
43
  from PIL import Image
44
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
45
  from typing import Tuple, Optional
 
49
  import my_model.config.kbvqa_config as config
50
 
51
 
 
52
  class KBVQA:
53
  """
54
+ The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.
55
+ It integrates various components such as an image captioning model, object detection model, and a fine-tuned
56
  language model (LLAMA2) on OK-VQA dataset for generating answers to visual questions.
57
 
58
  Attributes:
 
86
  generate_answer: Generates an answer to a given question using the KBVQA model.
87
  """
88
 
89
+ def __init__(self) -> None:
90
+ """
91
+ Initializes the KBVQA instance with configuration parameters.
92
+ """
93
 
94
  if st.session_state["method"] == "7b-Fine-Tuned Model":
95
  self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME_7b
96
  elif st.session_state["method"] == "13b-Fine-Tuned Model":
97
  self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME_13b
98
  self.quantization: str = config.QUANTIZATION
99
+ self.max_context_window: int = config.MAX_CONTEXT_WINDOW # set to 4,000 tokens
100
  self.add_eos_token: bool = config.ADD_EOS_TOKEN
101
  self.trust_remote: bool = config.TRUST_REMOTE
102
  self.use_fast: bool = config.USE_FAST
 
110
  self.bnb_config: BitsAndBytesConfig = self.create_bnb_config()
111
  self.access_token: str = config.HUGGINGFACE_TOKEN
112
  self.current_prompt_length = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
 
 
 
114
 
115
+ def create_bnb_config(self) -> BitsAndBytesConfig:
116
+ """
117
+ Creates a BitsAndBytes configuration based on the quantization setting.
118
+ Returns:
119
+ BitsAndBytesConfig: Configuration for BitsAndBytes optimized model.
120
+ """
121
 
122
+ if self.quantization == '4bit':
123
+ return BitsAndBytesConfig(
124
+ load_in_4bit=True,
125
+ bnb_4bit_use_double_quant=True,
126
+ bnb_4bit_quant_type="nf4",
127
+ bnb_4bit_compute_dtype=torch.bfloat16
128
+ )
129
+ elif self.quantization == '8bit':
130
+ return BitsAndBytesConfig(
131
+ load_in_8bit=True,
132
+ bnb_8bit_use_double_quant=True,
133
+ bnb_8bit_quant_type="nf4",
134
+ bnb_8bit_compute_dtype=torch.bfloat16
135
+ )
136
+
137
+
138
+ def load_caption_model(self) -> None:
139
+ """
140
+ Loads the image captioning model into the KBVQA instance.
141
 
142
+ Returns:
143
+ None
144
+ """
145
 
146
+ self.captioner = ImageCaptioningModel()
147
+ self.captioner.load_model()
148
+ free_gpu_resources()
149
 
 
 
 
150
 
151
+ def get_caption(self, img: Image.Image) -> str:
152
+ """
153
+ Generates a caption for a given image using the image captioning model.
154
 
155
+ Args:
156
+ img (PIL.Image.Image): The image for which to generate a caption.
157
 
158
+ Returns:
159
+ str: The generated caption for the image.
160
+ """
161
+ caption = self.captioner.generate_caption(img)
162
+ free_gpu_resources()
163
+ return caption
 
 
 
 
 
164
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ def load_detector(self, model: str) -> None:
167
+ """
168
+ Loads the object detection model.
 
 
 
 
 
 
169
 
170
+ Args:
171
+ model (str): The name of the object detection model to load.
 
 
172
 
173
+ Returns:
174
+ None
175
+ """
 
 
 
176
 
177
+ self.detector = ObjectDetector()
178
+ self.detector.load_model(model)
179
+ free_gpu_resources()
180
 
 
 
 
181
 
182
+ def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
183
+ """
184
+ Detects objects in a given image using the loaded object detection model.
 
 
 
185
 
186
+ Args:
187
+ img (PIL.Image.Image): The image in which to detect objects.
188
+
189
+ Returns:
190
+ tuple: A tuple containing the image with detected objects drawn and a string representation of detected objects.
191
+ """
192
+
193
+ image = self.detector.process_image(img)
194
+ free_gpu_resources()
195
+ detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=st.session_state[
196
+ 'confidence_level'])
197
+ free_gpu_resources()
198
+ image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
199
+ free_gpu_resources()
200
+ return image_with_boxes, detected_objects_string
201
+
202
+
203
+ def load_fine_tuned_model(self) -> None:
204
+ """
205
+ Loads the fine-tuned KBVQA model along with its tokenizer.
206
+
207
+ Returns:
208
+ None
209
+ """
210
+
211
+ self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
212
+ device_map="auto",
213
+ low_cpu_mem_usage=True,
214
+ quantization_config=self.bnb_config,
215
+ token=self.access_token)
216
+
217
+ free_gpu_resources()
218
+
219
+ self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
220
+ use_fast=self.use_fast,
221
+ low_cpu_mem_usage=True,
222
+ trust_remote_code=self.trust_remote,
223
+ add_eos_token=self.add_eos_token,
224
+ token=self.access_token)
225
+ free_gpu_resources()
226
+
227
+
228
+ @property
229
+ def all_models_loaded(self) -> bool:
230
+ """
231
+ Checks if all the required models (KBVQA, captioner, detector) are loaded.
232
+
233
+ Returns:
234
+ bool: True if all models are loaded, False otherwise.
235
+ """
236
+
237
+ return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
238
+
239
+
240
+ def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None,
241
+ caption: str = None, objects: Optional[str] = None) -> str:
242
+ """
243
+ Formats the prompt for the KBVQA model based on the provided parameters.
244
+
245
+ This implements the Prompt Engineering Module of the Overall KB-VQA Archetecture.
246
+
247
+ Args:
248
+ current_query (str): The current question to be answered.
249
+ history (str, optional): The history of previous interactions.
250
+ sys_prompt (str, optional): The system prompt or instructions for the model.
251
+ caption (str, optional): The caption of the image.
252
+ objects (str, optional): The detected objects in the image.
253
+
254
+ Returns:
255
+ str: The formatted prompt for the KBVQA model.
256
+ """
257
+
258
+ # These are the special tokens designed for the model to be fine-tuned on.
259
+ B_CAP = '[CAP]'
260
+ E_CAP = '[/CAP]'
261
+ B_QES = '[QES]'
262
+ E_QES = '[/QES]'
263
+ B_OBJ = '[OBJ]'
264
+ E_OBJ = '[/OBJ]'
265
+
266
+ # These are the default special tokens of LLaMA-2 Chat Model.
267
+ B_SENT = '<s>'
268
+ E_SENT = '</s>'
269
+ B_INST = '[INST]'
270
+ E_INST = '[/INST]'
271
+ B_SYS = '<<SYS>>\n'
272
+ E_SYS = '\n<</SYS>>\n\n'
273
+
274
+ current_query = current_query.strip()
275
+ if sys_prompt is None:
276
+ sys_prompt = config.SYSTEM_PROMPT.strip()
277
+
278
+ # History can be used to facilitate multi turn chat, not used for the Run Inference tool within the demo app.
279
+ if history is None:
280
+ if objects is None:
281
+ p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
282
  else:
283
+ p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""
284
+ else:
285
+ p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
286
 
287
+ return p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
 
 
 
289
 
290
+ @staticmethod
291
+ def trim_objects(detected_objects_str: str) -> str:
292
+ """
293
+ Trim the last object from the detected objects string.
294
+ This is implemented to ensure that the prompt length is within the context window, threshold set to 4,000 tokens.
295
 
296
+ Args:
297
+ detected_objects_str (str): String containing detected objects.
298
+
299
+ Returns:
300
+ str: The string with the last object removed.
301
+ """
302
+
303
+ objects = detected_objects_str.strip().split("\n")
304
+ if len(objects) >= 1:
305
+ return "\n".join(objects[:-1])
306
+ return ""
307
+
308
+
309
+ def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
310
+ """
311
+ Generates an answer to a given question using the KBVQA model.
312
+
313
+ Args:
314
+ question (str): The question to be answered.
315
+ caption (str): The caption of the image related to the question.
316
+ detected_objects_str (str): The string representation of detected objects in the image.
317
+
318
+ Returns:
319
+ str: The generated answer to the question.
320
+ """
321
+
322
+ free_gpu_resources()
323
+ prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
324
+ num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
325
+ self.current_prompt_length = num_tokens
326
+ trim = False # flag used to check if prompt trim is required or no.
327
+ # max_context_window is set to 4,000 tokens, refer to the config file.
328
+ if self.current_prompt_length > self.max_context_window:
329
+ trim = True
330
+ st.warning(
331
+ f"Prompt length is {self.current_prompt_length} which is larger than the maximum context window of LLaMA-2,"
332
+ f" objects detected with low confidence will be removed one at a time until the prompt length is within the"
333
+ f" maximum context window ...")
334
+ # an object is trimmed from the bottom of the list until the overall prompt length is within the context window.
335
+ while self.current_prompt_length > self.max_context_window:
336
+ detected_objects_str = self.trim_objects(detected_objects_str)
337
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
338
+ self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
339
+
340
+ if detected_objects_str == "":
341
+ break # Break if no objects are left
342
+ if trim:
343
+ st.warning(f"New prompt length is: {self.current_prompt_length}")
344
  trim = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
+ model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
347
+ free_gpu_resources()
348
+ input_ids = model_inputs["input_ids"]
349
+ output_ids = self.kbvqa_model.generate(input_ids)
350
+ free_gpu_resources()
351
+ index = input_ids.shape[1] # needed to avoid printing the input prompt
352
+ history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)
353
+ output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)
354
+
355
+ return output_text.capitalize()
356
+
357
 
358
  def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = False) -> KBVQA:
359
  """
360
  Prepares the KBVQA model for use, including loading necessary sub-models.
361
 
362
+ This serves as the main function for loading and reloading the KB-VQA model.
363
+
364
  Args:
365
  only_reload_detection_model (bool): If True, only the object detection model is reloaded.
366
+ force_reload (bool): If True, forces the reload of all models.
367
 
368
  Returns:
369
  KBVQA: An instance of the KBVQA model ready for inference.
370
  """
371
+
372
  if force_reload:
373
  free_gpu_resources()
374
  loading_message = 'Reloading model.. this should take no more than 2 or 3 minutes!'
375
  try:
376
+ del st.session_state['kbvqa']
377
  free_gpu_resources()
378
  free_gpu_resources()
379
  except:
 
381
  free_gpu_resources()
382
  pass
383
  free_gpu_resources()
384
+
385
+ else:
386
+ loading_message = 'Looading model.. this should take no more than 2 or 3 minutes!'
387
 
388
  free_gpu_resources()
389
  kbvqa = KBVQA()
390
  kbvqa.detection_model = st.session_state.detection_model
391
  # Progress bar for model loading
392
+
393
  with st.spinner(loading_message):
394
  if not only_reload_detection_model:
395
  progress_bar = st.progress(0)
 
407
  progress_bar = st.progress(0)
408
  kbvqa.load_detector(kbvqa.detection_model)
409
  progress_bar.progress(100)
410
+
411
  if kbvqa.all_models_loaded:
412
  st.success('Model loaded successfully and ready for inferecne!')
413
  kbvqa.kbvqa_model.eval()
414
  free_gpu_resources()
415
  return kbvqa
416
 
417
+
418
+ if __name__ == "__main__":
419
+ pass
420
+
421
+ #### Example on how to use the module ####
422
+
423
+ # Prepare the KBVQA model
424
+ # kbvqa = prepare_kbvqa_model()
425
+
426
+ # Load an image
427
+ # image = Image.open('path_to_image.jpg')
428
+
429
+ # Generate a caption for the image
430
+ # caption = kbvqa.get_caption(image)
431
+
432
+ # Detect objects in the image
433
+ # image_with_boxes, detected_objects_str = kbvqa.detect_objects(image)
434
+
435
+ # Generate an answer to a question about the image
436
+ # question = "What is the object in the image?"
437
+ # answer = kbvqa.generate_answer(question, caption, detected_objects_str)
438
+
439
+ # print(f"Answer: {answer}")
440
+