m7mdal7aj commited on
Commit
3eeb25f
1 Parent(s): 406614f

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +9 -10
my_model/tabs/run_inference.py CHANGED
@@ -17,7 +17,6 @@ from my_model.state_manager import StateManager
17
  from my_model.config import inference_config as config
18
 
19
 
20
-
21
  class InferenceRunner(StateManager):
22
  """
23
  Manages the user interface and interactions for running inference using the Streamlit-based Knowledge-Based Visual
@@ -244,16 +243,15 @@ class InferenceRunner(StateManager):
244
  reload_kbvqa = False
245
  reload_detection_model = False
246
  force_reload_full_model = False
247
-
248
 
249
  if self.is_model_loaded and self.settings_changed:
250
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
251
- # self.update_prev_state()
252
  st.session_state.button_label = (
253
  "Reload Model" if (self.is_model_loaded and
254
  st.session_state.kbvqa.detection_model != st.session_state['detection_model']) or
255
- (st.session_state['previous_state']['method'] is not None and
256
- st.session_state['method'] != st.session_state['previous_state']['method'])
257
  else "Load Model"
258
  )
259
 
@@ -269,10 +267,11 @@ class InferenceRunner(StateManager):
269
  fine_tuned_model_already_loaded = True
270
  else:
271
  load_fine_tuned_model = True
272
- elif st.session_state.button_label == "Reload Model"
273
- and st.session_state['method'] != st.session_state['previous_state']['method']: # check if the model size have changed
274
  force_reload_full_model = True
275
- elif (self.is_model_loaded and st.session_state.kbvqa.detection_model != st.session_state['detection_model']):
 
276
  reload_detection_model = True
277
  if nested_col12.button("Force Reload", on_click=self.disable_widgets,
278
  disabled=self.is_widget_disabled):
@@ -298,7 +297,7 @@ class InferenceRunner(StateManager):
298
  st.session_state['time_taken_to_load_model'] = int(time.time() - t1)
299
  st.session_state['loading_in_progress'] = False
300
  st.session_state['model_loaded'] = True
301
-
302
  elif st.session_state.method == "Vision-Language Embeddings Alignment":
303
  self.col1.warning(
304
  f'Model using {st.session_state.method} is desgined but requires large scale data and multiple '
@@ -308,7 +307,7 @@ class InferenceRunner(StateManager):
308
  st.write(st.session_state['previous_state']['method'])
309
  if st.session_state['kbvqa'] is not None:
310
  st.write(st.session_state['kbvqa'].kbvqa_model_name)
311
-
312
  if self.is_model_loaded:
313
  free_gpu_resources()
314
  st.session_state['loading_in_progress'] = False
 
17
  from my_model.config import inference_config as config
18
 
19
 
 
20
  class InferenceRunner(StateManager):
21
  """
22
  Manages the user interface and interactions for running inference using the Streamlit-based Knowledge-Based Visual
 
243
  reload_kbvqa = False
244
  reload_detection_model = False
245
  force_reload_full_model = False
 
246
 
247
  if self.is_model_loaded and self.settings_changed:
248
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
249
+ # self.update_prev_state()
250
  st.session_state.button_label = (
251
  "Reload Model" if (self.is_model_loaded and
252
  st.session_state.kbvqa.detection_model != st.session_state['detection_model']) or
253
+ (st.session_state['previous_state']['method'] is not None and
254
+ st.session_state['method'] != st.session_state['previous_state']['method'])
255
  else "Load Model"
256
  )
257
 
 
267
  fine_tuned_model_already_loaded = True
268
  else:
269
  load_fine_tuned_model = True
270
+ elif st.session_state.button_label == "Reload Model" and st.session_state['method'] != \
271
+ st.session_state['previous_state']['method']: # check if the model size have changed
272
  force_reload_full_model = True
273
+ elif (self.is_model_loaded and st.session_state.kbvqa.detection_model !=
274
+ st.session_state['detection_model']):
275
  reload_detection_model = True
276
  if nested_col12.button("Force Reload", on_click=self.disable_widgets,
277
  disabled=self.is_widget_disabled):
 
297
  st.session_state['time_taken_to_load_model'] = int(time.time() - t1)
298
  st.session_state['loading_in_progress'] = False
299
  st.session_state['model_loaded'] = True
300
+
301
  elif st.session_state.method == "Vision-Language Embeddings Alignment":
302
  self.col1.warning(
303
  f'Model using {st.session_state.method} is desgined but requires large scale data and multiple '
 
307
  st.write(st.session_state['previous_state']['method'])
308
  if st.session_state['kbvqa'] is not None:
309
  st.write(st.session_state['kbvqa'].kbvqa_model_name)
310
+
311
  if self.is_model_loaded:
312
  free_gpu_resources()
313
  st.session_state['loading_in_progress'] = False