m7mdal7aj committed
Commit: 33da84c
Parent: d148f27

Update my_model/KBVQA.py

Files changed (1): my_model/KBVQA.py (+7 −3)
my_model/KBVQA.py

@@ -224,7 +224,7 @@ class KBVQA:
         return p
 
     @staticmethod
-    def trim_objects(self, detected_objects_str):
+    def trim_objects(detected_objects_str):
         """
         Trim the last object from the detected objects string.
 
@@ -257,7 +257,9 @@ class KBVQA:
         prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
         num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
         self.current_prompt_length = num_tokens
-
+        if self.current_prompt_length > self.max_context_window:
+            trim = True
+            st.warning(f"Prompt length is {self.current_prompt_length}, which exceeds the maximum context window of LLaMA-2. Objects detected with low confidence will be removed one at a time until the prompt length is within the maximum context window ...")
         while self.current_prompt_length > self.max_context_window:
             detected_objects_str = self.trim_objects(detected_objects_str)
             prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
@@ -265,7 +267,9 @@ class KBVQA:
 
             if detected_objects_str == "":
                 break  # Break if no objects are left
-
+        if trim:
+            st.warning(f"New prompt length is: {self.current_prompt_length}")
+            trim = False
 
         model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
         free_gpu_resources()
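
Two notes on the changes above. First, the signature fix in the first hunk matters because @staticmethod suppresses instance binding: with the old (self, detected_objects_str) signature, the call self.trim_objects(detected_objects_str) inside the loop passes the objects string into the self slot and raises a TypeError for the missing second argument. A minimal demonstration with a hypothetical Demo class (not from the repo):

# Hypothetical demo of why `self` breaks a @staticmethod.
class Demo:
    @staticmethod
    def bad(self, s):   # old-style signature: `self` is just an ordinary first parameter
        return s

    @staticmethod
    def good(s):        # fixed signature
        return s

d = Demo()
print(d.good("cat, dog"))  # fine: prints "cat, dog"
d.bad("cat, dog")          # TypeError: bad() missing 1 required positional argument: 's'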
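
Second, the loop these hunks modify shrinks the prompt by dropping detected objects, one per pass, until the tokenized prompt fits the model's context window. Below is a self-contained sketch of that idea, under stated assumptions: trim_objects here is a hypothetical stand-in that treats the objects string as newline-separated entries sorted with the lowest-confidence one last (the real KBVQA.trim_objects may parse a different format), and the prompt builder and token counter are passed in as plain callables rather than the class's format_prompt and kbvqa_tokenizer.

# Sketch of the prompt-trimming loop, under the assumptions stated above.
def trim_objects(detected_objects_str: str) -> str:
    # Hypothetical stand-in: drop the last (lowest-confidence) object,
    # assuming one object per line.
    objects = detected_objects_str.strip().split("\n")
    return "\n".join(objects[:-1])

def fit_prompt(format_prompt, count_tokens, question, caption,
               detected_objects_str, max_context_window):
    # Rebuild the prompt after each trim until it fits the window
    # or no objects remain.
    prompt = format_prompt(question, caption=caption, objects=detected_objects_str)
    while count_tokens(prompt) > max_context_window:
        if detected_objects_str == "":
            break  # nothing left to drop; prompt stays over the window
        detected_objects_str = trim_objects(detected_objects_str)
        prompt = format_prompt(question, caption=caption, objects=detected_objects_str)
    return prompt

Usage with trivial stand-ins (whitespace token count, plain string prompt):

prompt = fit_prompt(
    format_prompt=lambda q, caption, objects: f"{caption}\n{objects}\n{q}",
    count_tokens=lambda p: len(p.split()),
    question="What is on the table?",
    caption="A kitchen scene.",
    detected_objects_str="table\ncup\nblurry thing",
    max_context_window=10,
)
# Drops "blurry thing" so the 10-token budget is met.

Trimming one object per iteration is a simple way to preserve the highest-confidence detections: only the least certain objects are sacrificed, and only as many as needed to fit the window.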