m7mdal7aj commited on
Commit
7ea3839
1 Parent(s): 6d914c0

Update my_model/KBVQA.py

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +17 -11
my_model/KBVQA.py CHANGED
@@ -159,23 +159,29 @@ class KBVQA():
159
 
160
  return output_text.capitalize()
161
 
162
- def prepare_kbvqa_model(detection_model):
163
  free_gpu_resources()
164
  kbvqa = KBVQA()
165
  kbvqa.detection_model = detection_model
166
  # Progress bar for model loading
167
  with st.spinner('Loading model...'):
168
-
169
- progress_bar = st.progress(0)
170
 
171
- kbvqa.load_detector(kbvqa.detection_model)
172
- progress_bar.progress(33)
173
- kbvqa.load_caption_model()
174
- free_gpu_resources()
175
- progress_bar.progress(66)
176
- kbvqa.load_fine_tuned_model()
177
- free_gpu_resources()
178
- progress_bar.progress(100)
 
 
 
 
 
 
 
 
179
 
180
  if kbvqa.all_models_loaded:
181
  st.success('Model loaded successfully!')
 
159
 
160
  return output_text.capitalize()
161
 
162
+ def prepare_kbvqa_model(detection_model, only_reload_detection_model=False):
163
  free_gpu_resources()
164
  kbvqa = KBVQA()
165
  kbvqa.detection_model = detection_model
166
  # Progress bar for model loading
167
  with st.spinner('Loading model...'):
 
 
168
 
169
+ if not only_reload_detection_model:
170
+ progress_bar = st.progress(0)
171
+
172
+ kbvqa.load_detector(kbvqa.detection_model)
173
+ progress_bar.progress(33)
174
+ kbvqa.load_caption_model()
175
+ free_gpu_resources()
176
+ progress_bar.progress(66)
177
+ kbvqa.load_fine_tuned_model()
178
+ free_gpu_resources()
179
+ progress_bar.progress(100)
180
+
181
+ else:
182
+ progress_bar = st.progress(0)
183
+ kbvqa.load_detector(kbvqa.detection_model)
184
+ progress_bar.progress(100)
185
 
186
  if kbvqa.all_models_loaded:
187
  st.success('Model loaded successfully!')