m7mdal7aj commited on
Commit
61347b1
1 Parent(s): fbd7ace

Update my_model/KBVQA.py

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +29 -29
my_model/KBVQA.py CHANGED
@@ -50,7 +50,7 @@ class KBVQA:
50
 
51
  def __init__(self):
52
 
53
- self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])
54
  self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME
55
  self.quantization: str = config.QUANTIZATION
56
  self.max_context_window: int = config.MAX_CONTEXT_WINDOW
@@ -245,7 +245,7 @@ class KBVQA:
245
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
246
  num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
247
  if num_tokens > self.max_context_window:
248
- self.col2.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
249
  return
250
 
251
  model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
@@ -272,33 +272,33 @@ def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload:
272
  kbvqa = KBVQA()
273
  kbvqa.detection_model = st.session_state.detection_model
274
  # Progress bar for model loading
275
- with kbvqa.col1:
276
- if force_reload:
277
- self.delete_model()
278
- loading_message = 'Force Reloading model.. this should take no more than a few minutes!'
279
- else: loading_message = 'Looading model.. this should take no more than a few minutes!'
280
-
281
- with st.spinner(loading_message):
282
- if not only_reload_detection_model:
283
- progress_bar = st.progress(0)
284
- kbvqa.load_detector(kbvqa.detection_model)
285
- progress_bar.progress(33)
286
- kbvqa.load_caption_model()
287
- free_gpu_resources()
288
- progress_bar.progress(75)
289
- st.text('Almost there :)')
290
- kbvqa.load_fine_tuned_model()
291
- free_gpu_resources()
292
- progress_bar.progress(100)
293
- else:
294
- progress_bar = st.progress(0)
295
- kbvqa.load_detector(kbvqa.detection_model)
296
- progress_bar.progress(100)
297
-
298
- if kbvqa.all_models_loaded:
299
- st.success('Model loaded successfully and ready for inferecne!')
300
- kbvqa.kbvqa_model.eval()
301
  free_gpu_resources()
302
- return kbvqa
 
 
 
 
 
 
 
 
 
 
303
 
304
 
 
50
 
51
  def __init__(self):
52
 
53
+ # self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])
54
  self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME
55
  self.quantization: str = config.QUANTIZATION
56
  self.max_context_window: int = config.MAX_CONTEXT_WINDOW
 
245
  prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
246
  num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
247
  if num_tokens > self.max_context_window:
248
+ st.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
249
  return
250
 
251
  model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
 
272
  kbvqa = KBVQA()
273
  kbvqa.detection_model = st.session_state.detection_model
274
  # Progress bar for model loading
275
+
276
+ if force_reload:
277
+ self.delete_model()
278
+ loading_message = 'Force Reloading model.. this should take no more than a few minutes!'
279
+ else: loading_message = 'Looading model.. this should take no more than a few minutes!'
280
+
281
+ with st.spinner(loading_message):
282
+ if not only_reload_detection_model:
283
+ progress_bar = st.progress(0)
284
+ kbvqa.load_detector(kbvqa.detection_model)
285
+ progress_bar.progress(33)
286
+ kbvqa.load_caption_model()
287
+ free_gpu_resources()
288
+ progress_bar.progress(75)
289
+ st.text('Almost there :)')
290
+ kbvqa.load_fine_tuned_model()
 
 
 
 
 
 
 
 
 
 
291
  free_gpu_resources()
292
+ progress_bar.progress(100)
293
+ else:
294
+ progress_bar = st.progress(0)
295
+ kbvqa.load_detector(kbvqa.detection_model)
296
+ progress_bar.progress(100)
297
+
298
+ if kbvqa.all_models_loaded:
299
+ st.success('Model loaded successfully and ready for inferecne!')
300
+ kbvqa.kbvqa_model.eval()
301
+ free_gpu_resources()
302
+ return kbvqa
303
 
304