m7mdal7aj committed on
Commit d72aea6
1 Parent(s): 677e938

Update my_model/KBVQA.py

Files changed (1)
  1. my_model/KBVQA.py +4 -1
my_model/KBVQA.py CHANGED
@@ -34,6 +34,7 @@ class KBVQA:
         kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
         bnb_config (BitsAndBytesConfig): Configuration for BitsAndBytes optimized model.
         access_token (str): Access token for Hugging Face API.
+        current_prompt_length (int): Length of the current prompt in tokens.
 
     Methods:
         create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
@@ -66,6 +67,7 @@ class KBVQA:
         self.kbvqa_model: Optional[AutoModelForCausalLM] = None
         self.bnb_config: BitsAndBytesConfig = self.create_bnb_config()
         self.access_token: str = config.HUGGINGFACE_TOKEN
+        self.current_prompt_length = None
 
 
     def create_bnb_config(self) -> BitsAndBytesConfig:
@@ -227,8 +229,9 @@ class KBVQA:
 
         prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
         num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
+        self.current_prompt_length = num_tokens
         if num_tokens > self.max_context_window:
-            st.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
+            st.warning(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
             return
 
         model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
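
For readers outside the diff, here is a minimal, self-contained sketch of the prompt-length guard this commit touches, assuming a Hugging Face tokenizer and a Streamlit front end; check_prompt_fits and its parameters are illustrative names, not taken from the repository:

from typing import Optional

import streamlit as st
from transformers import PreTrainedTokenizerBase


def check_prompt_fits(prompt: str,
                      tokenizer: PreTrainedTokenizerBase,
                      max_context_window: int) -> Optional[int]:
    # Count tokens the same way the patched method does: tokenize, then take the length.
    num_tokens = len(tokenizer.tokenize(prompt))
    if num_tokens > max_context_window:
        # st.warning renders a highlighted warning box in the Streamlit UI;
        # the st.write it replaces printed the same message as plain body text.
        st.warning(f"Prompt too long with {num_tokens} tokens, "
                   "consider increasing the confidence threshold for the object detector")
        return None
    return num_tokens

Storing the count on the instance, as the new self.current_prompt_length attribute does, keeps the prompt length available to the UI even when generation is skipped.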