m7mdal7aj commited on
Commit
61e10b7
·
verified ·
1 Parent(s): 1a0e500

Update my_model/KBVQA.py

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +133 -28
my_model/KBVQA.py CHANGED
@@ -3,30 +3,67 @@ import torch
3
  import copy
4
  import os
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6
- from typing import Optional
7
  from my_model.gen_utilities import free_gpu_resources
8
  from my_model.captioner.image_captioning import ImageCaptioningModel
9
  from my_model.object_detection import ObjectDetector
 
10
 
11
 
12
  class KBVQA():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def __init__(self):
15
- self.kbvqa_model_name = "m7mdal7aj/fine_tunned_llama_2_merged"
16
- self.quantization='4bit'
17
- self.bnb_config = self.create_bnb_config()
18
- self.max_context_window = 4000
19
- self.add_eos_token = False
20
- self.trust_remote = False
21
- self.use_fast = True
 
22
  self.kbvqa_tokenizer = None
23
  self.captioner = None
24
  self.detector = None
25
  self.detection_model = None
26
  self.detection_confidence = None
27
  self.kbvqa_model = None
28
- self.access_token = os.getenv("HUGGINGFACE_TOKEN")
29
- # self.kbvqa_model_loaded = self.all_models_loaded()
 
30
 
31
 
32
  def create_bnb_config(self) -> BitsAndBytesConfig:
@@ -51,27 +88,59 @@ class KBVQA():
51
  )
52
 
53
 
54
- def load_caption_model(self):
 
 
 
 
55
  self.captioner = ImageCaptioningModel()
56
  self.captioner.load_model()
57
 
58
- def get_caption(self, img):
 
 
 
 
 
 
 
 
 
59
 
60
  return self.captioner.generate_caption(img)
61
 
62
- def load_detector(self, model):
 
 
 
 
 
 
63
 
64
  self.detector = ObjectDetector()
65
  self.detector.load_model(model)
66
 
67
- def detect_objects(self, img):
 
 
 
 
 
 
 
 
 
 
68
  image = self.detector.process_image(img)
69
  detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=self.detection_confidence)
70
  image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
71
  return image_with_boxes, detected_objects_string
72
 
73
- def load_fine_tuned_model(self):
74
-
 
 
 
75
  self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
76
  device_map="auto",
77
  low_cpu_mem_usage=True,
@@ -88,9 +157,20 @@ class KBVQA():
88
 
89
  @property
90
  def all_models_loaded(self):
 
 
 
 
 
 
 
91
  return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
92
 
93
  def force_reload_model(self):
 
 
 
 
94
  free_gpu_resources()
95
  if self.kbvqa_model is not None:
96
  del self.kbvqa_model
@@ -101,17 +181,24 @@ class KBVQA():
101
 
102
  free_gpu_resources()
103
 
104
-
105
-
106
-
107
-
108
 
 
 
 
109
 
 
 
 
 
 
 
110
 
111
- def format_prompt(self, current_query, history = None , sys_prompt=None, caption=None, objects=None):
 
 
112
 
113
  if sys_prompt is None:
114
- sys_prompt = "You are a helpful, respectful and honest assistant for visual question answering. you are provided with a caption of an image and a list of objects detected in the image along with their bounding boxes and level of certainty, you will output an answer to the given questions in no more than one sentence. Use logical reasoning to reach to the answer, but do not output your reasoning process unless asked for it. If provided, you will use the [CAP] and [/CAP] tags to indicate the begining and end of the caption respectively. If provided you will use the [OBJ] and [/OBJ] tags to indicate the begining and end of the list of detected objects in the image along with their bounding boxes respectively.if provided, you will use [QES] and [/QES] tags to indicate the begining and end of the question respectively."
115
 
116
  B_SENT = '<s>'
117
  E_SENT = '</s>'
@@ -126,7 +213,6 @@ class KBVQA():
126
  B_OBJ = '[OBJ]'
127
  E_OBJ = '[/OBJ]'
128
 
129
-
130
  current_query = current_query.strip()
131
  sys_prompt = sys_prompt.strip()
132
 
@@ -138,11 +224,21 @@ class KBVQA():
138
  else:
139
  p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
140
 
141
-
142
  return p
143
 
144
 
145
- def generate_answer(self, question, caption, detected_objects_str,):
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
148
  num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
@@ -159,7 +255,17 @@ class KBVQA():
159
 
160
  return output_text.capitalize()
161
 
162
- def prepare_kbvqa_model(only_reload_detection_model=False):
 
 
 
 
 
 
 
 
 
 
163
  free_gpu_resources()
164
  kbvqa = KBVQA()
165
  kbvqa.detection_model = st.session_state.detection_model
@@ -177,12 +283,11 @@ def prepare_kbvqa_model(only_reload_detection_model=False):
177
  kbvqa.load_fine_tuned_model()
178
  free_gpu_resources()
179
  progress_bar.progress(100)
180
-
181
  else:
182
  progress_bar = st.progress(0)
183
  kbvqa.load_detector(kbvqa.detection_model)
184
  progress_bar.progress(100)
185
-
186
  if kbvqa.all_models_loaded:
187
  st.success('Model loaded successfully and ready for inferecne!')
188
  kbvqa.kbvqa_model.eval()
 
3
  import copy
4
  import os
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6
+ from typing import Tuple, Optional
7
  from my_model.gen_utilities import free_gpu_resources
8
  from my_model.captioner.image_captioning import ImageCaptioningModel
9
  from my_model.object_detection import ObjectDetector
10
+ import my_model.config.kbvqa_config as config
11
 
12
 
13
  class KBVQA():
14
+ """
15
+ The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.
16
+ It integrates various components such as an image captioning model, object detection model, and a fine-tuned
17
+ language model (LLAMA2) on OK-VQA dataset for generating answers to visual questions.
18
+
19
+ Attributes:
20
+ kbvqa_model_name (str): Name of the fine-tuned language model used for KBVQA.
21
+ quantization (str): The quantization setting for the model (e.g., '4bit', '8bit').
22
+ max_context_window (int): The maximum number of tokens allowed in the model's context window.
23
+ add_eos_token (bool): Flag to indicate whether to add an end-of-sentence token to the tokenizer.
24
+ trust_remote (bool): Flag to indicate whether to trust remote code when using the tokenizer.
25
+ use_fast (bool): Flag to indicate whether to use the fast version of the tokenizer.
26
+ low_cpu_mem_usage (bool): Flag to optimize model loading for low CPU memory usage.
27
+ kbvqa_tokenizer (Optional[AutoTokenizer]): The tokenizer for the KBVQA model.
28
+ captioner (Optional[ImageCaptioningModel]): The model used for generating image captions.
29
+ detector (Optional[ObjectDetector]): The object detection model.
30
+ detection_model (Optional[str]): The name of the object detection model.
31
+ detection_confidence (Optional[float]): The confidence threshold for object detection.
32
+ kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
33
+ bnb_config (BitsAndBytesConfig): Configuration for BitsAndBytes optimized model.
34
+ access_token (str): Access token for Hugging Face API.
35
+
36
+ Methods:
37
+ create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
38
+ load_caption_model: Loads the image captioning model.
39
+ get_caption: Generates a caption for a given image.
40
+ load_detector: Loads the object detection model.
41
+ detect_objects: Detects objects in a given image.
42
+ load_fine_tuned_model: Loads the fine-tuned KBVQA model along with its tokenizer.
43
+ all_models_loaded: Checks if all the required models are loaded.
44
+ force_reload_model: Forces a reload of all models, freeing up GPU resources.
45
+ format_prompt: Formats the prompt for the KBVQA model.
46
+ generate_answer: Generates an answer to a given question using the KBVQA model.
47
+ """
48
 
49
  def __init__(self):
50
+
51
+ self.kbvqa_model_name = config.KBVQA_MODEL_NAME
52
+ self.quantization = config.QUANTIZATION
53
+ self.max_context_window = config.MAX_CONTEXT_WINDOW
54
+ self.add_eos_token = config.ADD_EOS_TOKEN
55
+ self.trust_remote = config.TRUST_REMOTE
56
+ self.use_fast = config.USE_FAST
57
+ self.low_cpu_mem_usage=config.LOW_CPU_MEM_USAGE
58
  self.kbvqa_tokenizer = None
59
  self.captioner = None
60
  self.detector = None
61
  self.detection_model = None
62
  self.detection_confidence = None
63
  self.kbvqa_model = None
64
+ self.bnb_config = self.create_bnb_config()
65
+ self.access_token = config.HUGGINGFACE_TOKEN
66
+
67
 
68
 
69
  def create_bnb_config(self) -> BitsAndBytesConfig:
 
88
  )
89
 
90
 
91
+ def load_caption_model(self) -> None:
92
+ """
93
+ Loads the image captioning model into the KBVQA instance.
94
+ """
95
+
96
  self.captioner = ImageCaptioningModel()
97
  self.captioner.load_model()
98
 
99
+ def get_caption(self, img: Image.Image) -> str:
100
+ """
101
+ Generates a caption for a given image using the image captioning model.
102
+
103
+ Args:
104
+ img (PIL.Image.Image): The image for which to generate a caption.
105
+
106
+ Returns:
107
+ str: The generated caption for the image.
108
+ """
109
 
110
  return self.captioner.generate_caption(img)
111
 
112
+ def load_detector(self, model: str) -> None:
113
+ """
114
+ Loads the object detection model.
115
+
116
+ Args:
117
+ model (str): The name of the object detection model to load.
118
+ """
119
 
120
  self.detector = ObjectDetector()
121
  self.detector.load_model(model)
122
 
123
+ def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
124
+ """
125
+ Detects objects in a given image using the loaded object detection model.
126
+
127
+ Args:
128
+ img (PIL.Image.Image): The image in which to detect objects.
129
+
130
+ Returns:
131
+ tuple: A tuple containing the image with detected objects drawn and a string representation of detected objects.
132
+ """
133
+
134
  image = self.detector.process_image(img)
135
  detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=self.detection_confidence)
136
  image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
137
  return image_with_boxes, detected_objects_string
138
 
139
+ def load_fine_tuned_model(self) -> None:
140
+ """
141
+ Loads the fine-tuned KBVQA model along with its tokenizer.
142
+ """
143
+
144
  self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
145
  device_map="auto",
146
  low_cpu_mem_usage=True,
 
157
 
158
  @property
159
  def all_models_loaded(self):
160
+ """
161
+ Checks if all the required models (KBVQA, captioner, detector) are loaded.
162
+
163
+ Returns:
164
+ bool: True if all models are loaded, False otherwise.
165
+ """
166
+
167
  return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
168
 
169
  def force_reload_model(self):
170
+ """
171
+ Forces a reload of all models, freeing up GPU resources. This method deletes the current models and calls `free_gpu_resources`.
172
+ """
173
+
174
  free_gpu_resources()
175
  if self.kbvqa_model is not None:
176
  del self.kbvqa_model
 
181
 
182
  free_gpu_resources()
183
 
 
 
 
 
184
 
185
+ def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: str = None, objects: Optional[str] = None) -> str:
186
+ """
187
+ Formats the prompt for the KBVQA model based on the provided parameters.
188
 
189
+ Args:
190
+ current_query (str): The current question to be answered.
191
+ history (str, optional): The history of previous interactions.
192
+ sys_prompt (str, optional): The system prompt or instructions for the model.
193
+ caption (str, optional): The caption of the image.
194
+ objects (str, optional): The detected objects in the image.
195
 
196
+ Returns:
197
+ str: The formatted prompt for the KBVQA model.
198
+ """
199
 
200
  if sys_prompt is None:
201
+ sys_prompt = config.SYSTEM_PROMPT
202
 
203
  B_SENT = '<s>'
204
  E_SENT = '</s>'
 
213
  B_OBJ = '[OBJ]'
214
  E_OBJ = '[/OBJ]'
215
 
 
216
  current_query = current_query.strip()
217
  sys_prompt = sys_prompt.strip()
218
 
 
224
  else:
225
  p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
226
 
 
227
  return p
228
 
229
 
230
+ def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
231
+ """
232
+ Generates an answer to a given question using the KBVQA model.
233
+
234
+ Args:
235
+ question (str): The question to be answered.
236
+ caption (str): The caption of the image related to the question.
237
+ detected_objects_str (str): The string representation of detected objects in the image.
238
+
239
+ Returns:
240
+ str: The generated answer to the question.
241
+ """
242
 
243
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
244
  num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
 
255
 
256
  return output_text.capitalize()
257
 
258
+ def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
259
+ """
260
+ Prepares the KBVQA model for use, including loading necessary sub-models.
261
+
262
+ Args:
263
+ only_reload_detection_model (bool): If True, only the object detection model is reloaded.
264
+
265
+ Returns:
266
+ KBVQA: An instance of the KBVQA model ready for inference.
267
+ """
268
+
269
  free_gpu_resources()
270
  kbvqa = KBVQA()
271
  kbvqa.detection_model = st.session_state.detection_model
 
283
  kbvqa.load_fine_tuned_model()
284
  free_gpu_resources()
285
  progress_bar.progress(100)
 
286
  else:
287
  progress_bar = st.progress(0)
288
  kbvqa.load_detector(kbvqa.detection_model)
289
  progress_bar.progress(100)
290
+
291
  if kbvqa.all_models_loaded:
292
  st.success('Model loaded successfully and ready for inferecne!')
293
  kbvqa.kbvqa_model.eval()