m7mdal7aj committed on
Commit
f958c4b
1 Parent(s): 810a2b0

Update my_model/KBVQA.py

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +5 -5
my_model/KBVQA.py CHANGED
@@ -9,7 +9,7 @@ from my_model.utilities.gen_utilities import free_gpu_resources
9
  from my_model.captioner.image_captioning import ImageCaptioningModel
10
  from my_model.object_detection import ObjectDetector
11
  import my_model.config.kbvqa_config as config
12
- from my_model.state_manager import StateManager
13
 
14
 
15
  class KBVQA(StateManager):
@@ -50,7 +50,7 @@ class KBVQA(StateManager):
50
 
51
  def __init__(self):
52
 
53
- super().__init__()
54
  self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME
55
  self.quantization: str = config.QUANTIZATION
56
  self.max_context_window: int = config.MAX_CONTEXT_WINDOW
@@ -241,7 +241,7 @@ class KBVQA(StateManager):
241
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
242
  num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
243
  if num_tokens > self.max_context_window:
244
- st.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
245
  return
246
 
247
  model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
@@ -268,7 +268,7 @@ def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
268
  kbvqa = KBVQA()
269
  kbvqa.detection_model = st.session_state.detection_model
270
  # Progress bar for model loading
271
- with st.spinner('Loading model...'):
272
 
273
  if not only_reload_detection_model:
274
  self.col1.text('this should take no more than a few minutes!')
@@ -283,7 +283,7 @@ def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
283
  free_gpu_resources()
284
  progress_bar.progress(100)
285
  else:
286
- progress_bar = st.progress(0)
287
  kbvqa.load_detector(kbvqa.detection_model)
288
  progress_bar.progress(100)
289
 
 
9
  from my_model.captioner.image_captioning import ImageCaptioningModel
10
  from my_model.object_detection import ObjectDetector
11
  import my_model.config.kbvqa_config as config
12
+
13
 
14
 
15
  class KBVQA(StateManager):
 
50
 
51
  def __init__(self):
52
 
53
+ self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])
54
  self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME
55
  self.quantization: str = config.QUANTIZATION
56
  self.max_context_window: int = config.MAX_CONTEXT_WINDOW
 
241
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
242
  num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
243
  if num_tokens > self.max_context_window:
244
+ self.col2.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
245
  return
246
 
247
  model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
 
268
  kbvqa = KBVQA()
269
  kbvqa.detection_model = st.session_state.detection_model
270
  # Progress bar for model loading
271
+ with self.col1.spinner('Loading model...'):
272
 
273
  if not only_reload_detection_model:
274
  self.col1.text('this should take no more than a few minutes!')
 
283
  free_gpu_resources()
284
  progress_bar.progress(100)
285
  else:
286
+ progress_bar = self.col1.progress(0)
287
  kbvqa.load_detector(kbvqa.detection_model)
288
  progress_bar.progress(100)
289