m7mdal7aj committed on
Commit d148f27
1 Parent(s): bbdd166

Update my_model/KBVQA.py

Files changed (1)
  1. my_model/KBVQA.py +27 -4
my_model/KBVQA.py CHANGED
@@ -222,7 +222,22 @@ class KBVQA:
         p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
 
         return p
-
+
+    @staticmethod
+    def trim_objects(detected_objects_str):
+        """
+        Trim the last object from the detected objects string.
+
+        Args:
+            - detected_objects_str (str): String containing detected objects.
+
+        Returns:
+            - (str): The string with the last object removed.
+        """
+        objects = detected_objects_str.strip().split("\n")
+        if len(objects) >= 1:
+            return "\n".join(objects[:-1])
+        return ""
 
     def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
         """
@@ -236,13 +251,21 @@ class KBVQA:
         Returns:
             str: The generated answer to the question.
         """
+
+
         free_gpu_resources()
         prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
         num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
         self.current_prompt_length = num_tokens
-        if num_tokens > self.max_context_window:
-            st.warning(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
-            return
+
+        while self.current_prompt_length > self.max_context_window:
+            detected_objects_str = self.trim_objects(detected_objects_str)
+            prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
+            self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
+
+            if detected_objects_str == "":
+                break  # Break if no objects are left
+
 
         model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
         free_gpu_resources()
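
For context on what this change does: instead of aborting with a Streamlit warning when the prompt exceeds the model's context window, generate_answer now drops detected objects one line at a time until the prompt fits (or no objects remain). Below is a minimal standalone sketch of that trim-until-fits pattern. The names count_tokens, build_prompt, and fit_prompt are hypothetical stand-ins for len(self.kbvqa_tokenizer.tokenize(...)), self.format_prompt, and the new while loop; the whitespace tokenizer is for illustration only.

def count_tokens(text: str) -> int:
    # Hypothetical tokenizer: whitespace split stands in for
    # len(self.kbvqa_tokenizer.tokenize(text)).
    return len(text.split())

def trim_objects(detected_objects_str: str) -> str:
    # Same idea as the new KBVQA.trim_objects: drop the last line,
    # assuming one detected object per line.
    objects = detected_objects_str.strip().split("\n")
    return "\n".join(objects[:-1])

def build_prompt(question: str, caption: str, objects: str) -> str:
    # Simplified stand-in for KBVQA.format_prompt.
    return f"caption: {caption}\nobjects:\n{objects}\nquestion: {question}"

def fit_prompt(question: str, caption: str, objects: str, max_context_window: int) -> str:
    # Trim objects until the prompt fits, mirroring the new while loop.
    prompt = build_prompt(question, caption, objects)
    while count_tokens(prompt) > max_context_window:
        objects = trim_objects(objects)
        prompt = build_prompt(question, caption, objects)
        if objects == "":
            break  # nothing left to trim; the prompt may still be too long
    return prompt

if __name__ == "__main__":
    objects = "cat 0.98\ndog 0.95\nchair 0.60\nlamp 0.41"
    print(fit_prompt("What is on the chair?", "a living room", objects, max_context_window=18))

Note that the loop can still exit with an over-long prompt once every object has been trimmed; the break only prevents an infinite loop, so code after this point should be prepared for that case.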