m7mdal7aj commited on
Commit
3e07f22
1 Parent(s): 0cdf3dd

Update my_model/KBVQA.py

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +11 -3
my_model/KBVQA.py CHANGED
@@ -176,12 +176,16 @@ class KBVQA:
176
  free_gpu_resources()
177
  if self.kbvqa_model is not None:
178
  del self.kbvqa_model
 
179
  if self.captioner is not None:
180
  del self.captioner
 
181
  if self.detector is not None:
182
  del self.detector
183
-
 
184
  free_gpu_resources()
 
185
 
186
 
187
  def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: str = None, objects: Optional[str] = None) -> str:
@@ -253,7 +257,7 @@ class KBVQA:
253
 
254
  return output_text.capitalize()
255
 
256
- def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
257
  """
258
  Prepares the KBVQA model for use, including loading necessary sub-models.
259
 
@@ -269,7 +273,11 @@ def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
269
  kbvqa.detection_model = st.session_state.detection_model
270
  # Progress bar for model loading
271
  with kbvqa.col1:
272
- with st.spinner('Loading model.. this should take no more than a few minutes!'):
 
 
 
 
273
  if not only_reload_detection_model:
274
  progress_bar = st.progress(0)
275
  kbvqa.load_detector(kbvqa.detection_model)
 
176
  free_gpu_resources()
177
  if self.kbvqa_model is not None:
178
  del self.kbvqa_model
179
+ free_gpu_resources()
180
  if self.captioner is not None:
181
  del self.captioner
182
+ free_gpu_resources()
183
  if self.detector is not None:
184
  del self.detector
185
+ free_gpu_resources()
186
+
187
  free_gpu_resources()
188
+ prepare_kbvqa_model(only_reload_detection_model=False, force_reload=True)
189
 
190
 
191
  def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: str = None, objects: Optional[str] = None) -> str:
 
257
 
258
  return output_text.capitalize()
259
 
260
+ def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = True) -> KBVQA:
261
  """
262
  Prepares the KBVQA model for use, including loading necessary sub-models.
263
 
 
273
  kbvqa.detection_model = st.session_state.detection_model
274
  # Progress bar for model loading
275
  with kbvqa.col1:
276
+ if force_reload:
277
+ loading_message = 'Force Reloading model.. this should take no more than a few minutes!'
278
+ else: loading_message = 'Looading model.. this should take no more than a few minutes!'
279
+
280
+ with st.spinner(loading_message):
281
  if not only_reload_detection_model:
282
  progress_bar = st.progress(0)
283
  kbvqa.load_detector(kbvqa.detection_model)