m7mdal7aj commited on
Commit
7138ab3
1 Parent(s): 44cee00

Update my_model/state_manager.py

Browse files
Files changed (1) hide show
  1. my_model/state_manager.py +119 -49
my_model/state_manager.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import pandas as pd
2
  import copy
3
  import time
@@ -8,25 +12,25 @@ from my_model.utilities.gen_utilities import free_gpu_resources
8
  from my_model.KBVQA import KBVQA, prepare_kbvqa_model
9
 
10
 
 
 
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- class StateManager:
14
 
15
- # Hints for methods
16
- # initialize_state: Initializes default values for session state.
17
- # set_up_widgets: Creates UI elements for model selection and settings.
18
- # set_slider_value: Generates a slider widget for numerical input.
19
- # is_widget_disabled: Returns True if UI elements should be disabled.
20
- # disable_widgets: Disables interactive UI elements during processing.
21
- # settings_changed: Checks if any model settings have changed.
22
- # confidance_change: Determines if the confidence level setting has changed.
23
- # display_model_settings: Shows current model settings in the UI.
24
- # display_session_state: Displays the current state of the application.
25
- # update_prev_state: Updates the record of the previous application state.
26
- # force_reload_model: Reloads the model, clearing and resetting necessary states.
27
-
28
-
29
- def __init__(self):
30
  """
31
  Initializes the StateManager instance, setting up the Streamlit columns for the user interface.
32
  """
@@ -34,10 +38,12 @@ class StateManager:
34
  # Create three columns with different widths
35
  self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])
36
 
37
- def initialize_state(self):
 
38
  """
39
  Initializes the Streamlit session state with default values for various keys.
40
  """
 
41
  if "previous_state" not in st.session_state:
42
  st.session_state['previous_state'] = {'method': None, 'detection_model': None, 'confidence_level': None}
43
  if 'images_data' not in st.session_state:
@@ -59,9 +65,13 @@ class StateManager:
59
  if 'model_loaded' not in st.session_state:
60
  st.session_state['model_loaded'] = self.is_model_loaded
61
 
 
62
  def set_up_widgets(self) -> None:
63
  """
64
  Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
 
 
 
65
  """
66
 
67
  self.col1.selectbox("Choose a model:", ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"], index=1, key='method', disabled=self.is_widget_disabled)
@@ -70,14 +80,11 @@ class StateManager:
70
  self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)
71
 
72
  # Conditional display of model settings
73
-
74
-
75
  show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
76
  if show_model_settings:
77
  self.display_model_settings
78
 
79
-
80
-
81
  def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float, slider_key_name: str, col=None) -> None:
82
  """
83
  Creates a slider widget with the specified parameters, optionally placing it in a specific column.
@@ -90,6 +97,9 @@ class StateManager:
90
  step (float): Step size for the slider.
91
  slider_key_name (str): Unique key for the slider.
92
  col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in main area).
 
 
 
93
  """
94
 
95
  if col is None:
@@ -99,30 +109,41 @@ class StateManager:
99
 
100
 
101
  @property
102
- def is_widget_disabled(self):
 
 
 
 
 
 
 
103
  return st.session_state['loading_in_progress']
104
 
105
- def disable_widgets(self):
106
  """
107
  Disables widgets by setting the 'loading_in_progress' state to True.
 
 
 
108
  """
109
 
110
  st.session_state['loading_in_progress'] = True
111
 
112
 
113
  @property
114
- def settings_changed(self):
115
  """
116
  Checks if any model settings have changed compared to the previous state.
117
 
118
  Returns:
119
  bool: True if any setting has changed, False otherwise.
120
  """
 
121
  return self.has_state_changed()
122
 
123
 
124
  @property
125
- def confidance_change(self):
126
  """
127
  Checks if the confidence level setting has changed compared to the previous state.
128
 
@@ -133,9 +154,12 @@ class StateManager:
133
  return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]
134
 
135
 
136
- def update_prev_state(self):
137
  """
138
  Updates the 'previous_state' in the session state with the current state values.
 
 
 
139
  """
140
 
141
  for key in st.session_state['previous_state']:
@@ -151,6 +175,9 @@ class StateManager:
151
  - Sets the detection confidence level on the model object.
152
  - Updates previous state with current settings for change detection.
153
  - Updates the button label to "Reload Model".
 
 
 
154
  """
155
 
156
  try:
@@ -166,6 +193,7 @@ class StateManager:
166
 
167
  except Exception as e:
168
  st.error(f"Error loading model: {e}")
 
169
 
170
  def force_reload_model(self) -> None:
171
  """
@@ -175,9 +203,11 @@ class StateManager:
175
  - Calls `prepare_kbvqa_model` with `force_reload=True` to reload the model.
176
  - Updates the detection confidence level on the model object.
177
  - Displays a success message if the model is reloaded successfully.
 
 
 
178
  """
179
 
180
-
181
  try:
182
  self.delete_model()
183
  free_gpu_resources()
@@ -191,10 +221,14 @@ class StateManager:
191
  except Exception as e:
192
  st.error(f"Error reloading model: {e}")
193
  free_gpu_resources()
 
194
 
195
  def delete_model(self) -> None:
196
  """
197
  This method deletes the current models and calls `free_gpu_resources`.
 
 
 
198
  """
199
 
200
  free_gpu_resources()
@@ -210,11 +244,10 @@ class StateManager:
210
  pass
211
 
212
 
213
- # Function to check if any session state values have changed
214
  def has_state_changed(self) -> bool:
215
  """
216
  Compares current session state with the previous state to identify changes.
217
-
218
  Returns:
219
  bool: True if any change is found, False otherwise.
220
  """
@@ -227,14 +260,17 @@ class StateManager:
227
  else: return False # No changes found
228
 
229
 
230
- def get_model(self) -> KBVQA:
231
  """
232
- Retrieve the KBVQA model from the session state.
233
-
234
- Returns: KBVQA object: The loaded KBVQA model, or None if not loaded.
 
235
  """
 
236
  return st.session_state.get('kbvqa', None)
237
 
 
238
  @property
239
  def is_model_loaded(self) -> bool:
240
  """
@@ -243,6 +279,7 @@ class StateManager:
243
  Returns:
244
  bool: True if the model is loaded, False otherwise.
245
  """
 
246
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None and \
247
  st.session_state.kbvqa.all_models_loaded \
248
  and (st.session_state['previous_state']['method'] is not None
@@ -258,6 +295,9 @@ class StateManager:
258
  - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
259
  - Updates detection confidence level on the model object.
260
  - Displays a success message if model is reloaded successfully.
 
 
 
261
  """
262
 
263
  try:
@@ -288,7 +328,9 @@ class StateManager:
288
  Args:
289
  image_key (str): Unique key for the image.
290
  image (obj): The uploaded image data.
291
-
 
 
292
  """
293
 
294
  if image_key not in st.session_state['images_data']:
@@ -300,6 +342,7 @@ class StateManager:
300
  'analysis_done': False
301
  }
302
 
 
303
 
304
  def analyze_image(self, image) -> Tuple[str, str, object]:
305
  """
@@ -313,7 +356,6 @@ class StateManager:
313
  Args:
314
  image (obj): The image data to analyze.
315
 
316
-
317
  Returns:
318
  tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
319
  """
@@ -329,28 +371,34 @@ class StateManager:
329
 
330
  def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
331
  """
332
- Adds a question-answer pair to the QA history of a specific image, to be used as hitory tracker.
333
 
334
  Args:
335
  image_key (str): Unique key for the image.
336
  question (str): The question asked about the image.
337
  answer (str): The answer generated by the KBVQA model.
 
 
 
 
338
  """
 
339
  if image_key in st.session_state['images_data']:
340
  st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))
341
 
342
 
343
- def get_images_data(self):
344
  """
345
  Returns the dictionary containing processed image data from the session state.
346
 
347
  Returns:
348
  dict: The dictionary storing information about processed images.
349
  """
 
350
  return st.session_state['images_data']
351
 
352
 
353
- def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
354
  """
355
  Updates the information stored for a specific image in the `images_data` dictionary in the application session state.
356
 
@@ -359,6 +407,9 @@ class StateManager:
359
  caption (str): The generated caption for the image.
360
  detected_objects_str (str): String representation of detected objects.
361
  analysis_done (bool): Flag indicating if analysis of the image is complete.
 
 
 
362
  """
363
  if image_key in st.session_state['images_data']:
364
  st.session_state['images_data'][image_key].update({
@@ -368,18 +419,18 @@ class StateManager:
368
  })
369
 
370
 
371
- def resize_image(self, image_input, new_width=None, new_height=None):
372
  """
373
- Resize an image. If only new_width is provided, the height is adjusted to maintain aspect ratio.
374
  If both new_width and new_height are provided, the image is resized to those dimensions.
375
 
376
  Args:
377
- image (PIL.Image.Image): The image to resize.
378
- new_width (int, optional): The target width of the image.
379
- new_height (int, optional): The target height of the image.
380
 
381
  Returns:
382
- PIL.Image.Image: The resized image.
383
  """
384
 
385
  img = copy.deepcopy(image_input)
@@ -411,7 +462,18 @@ class StateManager:
411
 
412
 
413
 
414
- def display_message(self, message, message_type):
 
 
 
 
 
 
 
 
 
 
 
415
  if message_type == "warning":
416
  st.warning(message)
417
  elif message_type == "text":
@@ -424,10 +486,12 @@ class StateManager:
424
 
425
 
426
  @property
427
- def display_model_settings(self):
428
  """
429
  Displays a table of current model settings in the third column.
430
-
 
 
431
  """
432
  self.col3.write("##### Current Model Settings:")
433
  data = [{'Setting': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', 'loading_in_progress', 'model_loaded', 'time_taken_to_load_model', 'images_data' ]]
@@ -435,9 +499,15 @@ class StateManager:
435
  return self.col3.write(df)
436
 
437
 
438
- def display_session_state(self, col):
439
  """
440
- Displays a table of the complete application state..
 
 
 
 
 
 
441
  """
442
 
443
  col.write("Current Model:")
 
1
+ # This module contains the StateManager class.
2
+ # The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run, and test the models.
3
+
4
+
5
  import pandas as pd
6
  import copy
7
  import time
 
12
  from my_model.KBVQA import KBVQA, prepare_kbvqa_model
13
 
14
 
15
+ class StateManager:
16
+ """
17
+ Manages the user interface and session state for the Streamlit-based Knowledge-Based Visual Question Answering (KBVQA) application.
18
 
19
+ This class includes methods to initialize the session state, set up various UI widgets for model selection and settings,
20
+ manage the loading and reloading of the KBVQA model, and handle the processing and analysis of images.
21
+ It tracks changes to the application's state to ensure the correct configuration is maintained.
22
+ Additionally, it provides methods to display the current model settings and the complete application state within the Streamlit interface.
23
+
24
+ The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run, and test the models.
25
+
26
+ Attributes:
27
+ col1 (streamlit.columns): The first column in the Streamlit layout.
28
+ col2 (streamlit.columns): The second column in the Streamlit layout.
29
+ col3 (streamlit.columns): The third column in the Streamlit layout.
30
+ """
31
 
 
32
 
33
+ def __init__(self) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  """
35
  Initializes the StateManager instance, setting up the Streamlit columns for the user interface.
36
  """
 
38
  # Create three columns with different widths
39
  self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])
40
 
41
+
42
+ def initialize_state(self) -> None:
43
  """
44
  Initializes the Streamlit session state with default values for various keys.
45
  """
46
+
47
  if "previous_state" not in st.session_state:
48
  st.session_state['previous_state'] = {'method': None, 'detection_model': None, 'confidence_level': None}
49
  if 'images_data' not in st.session_state:
 
65
  if 'model_loaded' not in st.session_state:
66
  st.session_state['model_loaded'] = self.is_model_loaded
67
 
68
+
69
  def set_up_widgets(self) -> None:
70
  """
71
  Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
72
+
73
+ Returns:
74
+ None
75
  """
76
 
77
  self.col1.selectbox("Choose a model:", ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"], index=1, key='method', disabled=self.is_widget_disabled)
 
80
  self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)
81
 
82
  # Conditional display of model settings
 
 
83
  show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
84
  if show_model_settings:
85
  self.display_model_settings
86
 
87
+
 
88
  def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float, slider_key_name: str, col=None) -> None:
89
  """
90
  Creates a slider widget with the specified parameters, optionally placing it in a specific column.
 
97
  step (float): Step size for the slider.
98
  slider_key_name (str): Unique key for the slider.
99
  col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in main area).
100
+
101
+ Returns:
102
+ None
103
  """
104
 
105
  if col is None:
 
109
 
110
 
111
  @property
112
+ def is_widget_disabled(self) -> bool:
113
+ """
114
+ Checks if widgets should be disabled based on the 'loading_in_progress' state.
115
+
116
+ Returns:
117
+ bool: True if widgets should be disabled, False otherwise.
118
+ """
119
+
120
  return st.session_state['loading_in_progress']
121
 
122
+ def disable_widgets(self) -> None:
123
  """
124
  Disables widgets by setting the 'loading_in_progress' state to True.
125
+
126
+ Returns:
127
+ None
128
  """
129
 
130
  st.session_state['loading_in_progress'] = True
131
 
132
 
133
  @property
134
+ def settings_changed(self) -> bool:
135
  """
136
  Checks if any model settings have changed compared to the previous state.
137
 
138
  Returns:
139
  bool: True if any setting has changed, False otherwise.
140
  """
141
+
142
  return self.has_state_changed()
143
 
144
 
145
  @property
146
+ def confidance_change(self) -> bool:
147
  """
148
  Checks if the confidence level setting has changed compared to the previous state.
149
 
 
154
  return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]
155
 
156
 
157
+ def update_prev_state(self) -> None:
158
  """
159
  Updates the 'previous_state' in the session state with the current state values.
160
+
161
+ Returns:
162
+ None
163
  """
164
 
165
  for key in st.session_state['previous_state']:
 
175
  - Sets the detection confidence level on the model object.
176
  - Updates previous state with current settings for change detection.
177
  - Updates the button label to "Reload Model".
178
+
179
+ Returns:
180
+ None
181
  """
182
 
183
  try:
 
193
 
194
  except Exception as e:
195
  st.error(f"Error loading model: {e}")
196
+
197
 
198
  def force_reload_model(self) -> None:
199
  """
 
203
  - Calls `prepare_kbvqa_model` with `force_reload=True` to reload the model.
204
  - Updates the detection confidence level on the model object.
205
  - Displays a success message if the model is reloaded successfully.
206
+
207
+ Returns:
208
+ None
209
  """
210
 
 
211
  try:
212
  self.delete_model()
213
  free_gpu_resources()
 
221
  except Exception as e:
222
  st.error(f"Error reloading model: {e}")
223
  free_gpu_resources()
224
+
225
 
226
  def delete_model(self) -> None:
227
  """
228
  This method deletes the current models and calls `free_gpu_resources`.
229
+
230
+ Returns:
231
+ None
232
  """
233
 
234
  free_gpu_resources()
 
244
  pass
245
 
246
 
 
247
  def has_state_changed(self) -> bool:
248
  """
249
  Compares current session state with the previous state to identify changes.
250
+
251
  Returns:
252
  bool: True if any change is found, False otherwise.
253
  """
 
260
  else: return False # No changes found
261
 
262
 
263
+ def get_model(self) -> KBVQA.KBVQA():
264
  """
265
+ Retrieves the KBVQA model from the session state.
266
+
267
+ Returns:
268
+ KBVQA: The loaded KBVQA model, or None if not loaded.
269
  """
270
+
271
  return st.session_state.get('kbvqa', None)
272
 
273
+
274
  @property
275
  def is_model_loaded(self) -> bool:
276
  """
 
279
  Returns:
280
  bool: True if the model is loaded, False otherwise.
281
  """
282
+
283
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None and \
284
  st.session_state.kbvqa.all_models_loaded \
285
  and (st.session_state['previous_state']['method'] is not None
 
295
  - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
296
  - Updates detection confidence level on the model object.
297
  - Displays a success message if model is reloaded successfully.
298
+
299
+ Returns:
300
+ None
301
  """
302
 
303
  try:
 
328
  Args:
329
  image_key (str): Unique key for the image.
330
  image (obj): The uploaded image data.
331
+
332
+ Returns:
333
+ None
334
  """
335
 
336
  if image_key not in st.session_state['images_data']:
 
342
  'analysis_done': False
343
  }
344
 
345
+
346
 
347
  def analyze_image(self, image) -> Tuple[str, str, object]:
348
  """
 
356
  Args:
357
  image (obj): The image data to analyze.
358
 
 
359
  Returns:
360
  tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
361
  """
 
371
 
372
  def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
373
  """
374
+ Adds a question-answer pair to the QA history of a specific image, to be used as a history tracker.
375
 
376
  Args:
377
  image_key (str): Unique key for the image.
378
  question (str): The question asked about the image.
379
  answer (str): The answer generated by the KBVQA model.
380
+ prompt_length (int): The length of the prompt used for generating the answer.
381
+
382
+ Returns:
383
+ None
384
  """
385
+
386
  if image_key in st.session_state['images_data']:
387
  st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))
388
 
389
 
390
+ def get_images_data(self) -> Dict:
391
  """
392
  Returns the dictionary containing processed image data from the session state.
393
 
394
  Returns:
395
  dict: The dictionary storing information about processed images.
396
  """
397
+
398
  return st.session_state['images_data']
399
 
400
 
401
+ def update_image_data(self, image_key: str, caption: str, detected_objects_str: str, analysis_done: bool) -> None:
402
  """
403
  Updates the information stored for a specific image in the `images_data` dictionary in the application session state.
404
 
 
407
  caption (str): The generated caption for the image.
408
  detected_objects_str (str): String representation of detected objects.
409
  analysis_done (bool): Flag indicating if analysis of the image is complete.
410
+
411
+ Returns:
412
+ None
413
  """
414
  if image_key in st.session_state['images_data']:
415
  st.session_state['images_data'][image_key].update({
 
419
  })
420
 
421
 
422
+ def resize_image(self, image_input, new_width: Optional[int] = None, new_height: Optional[int] = None) -> Image:
423
  """
424
+ Resizes an image. If only new_width is provided, the height is adjusted to maintain aspect ratio.
425
  If both new_width and new_height are provided, the image is resized to those dimensions.
426
 
427
  Args:
428
+ image_input (PIL.Image.Image): The image to resize.
429
+ new_width (int, optional): The target width of the image.
430
+ new_height (int, optional): The target height of the image.
431
 
432
  Returns:
433
+ PIL.Image.Image: The resized image.
434
  """
435
 
436
  img = copy.deepcopy(image_input)
 
462
 
463
 
464
 
465
+ def display_message(self, message: str, message_type: str) -> None:
466
+ """
467
+ Displays a message in the Streamlit interface based on the specified message type.
468
+
469
+ Args:
470
+ message (str): The message to display.
471
+ message_type (str): The type of message ('warning', 'text', 'success', 'write', or 'error').
472
+
473
+ Returns:
474
+ None
475
+ """
476
+
477
  if message_type == "warning":
478
  st.warning(message)
479
  elif message_type == "text":
 
486
 
487
 
488
  @property
489
+ def display_model_settings(self) -> None:
490
  """
491
  Displays a table of current model settings in the third column.
492
+
493
+ Returns:
494
+ None
495
  """
496
  self.col3.write("##### Current Model Settings:")
497
  data = [{'Setting': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', 'loading_in_progress', 'model_loaded', 'time_taken_to_load_model', 'images_data' ]]
 
499
  return self.col3.write(df)
500
 
501
 
502
+ def display_session_state(self, col) -> None:
503
  """
504
+ Displays a table of the complete application state in the specified column.
505
+
506
+ Args:
507
+ col (streamlit.columns.Column): The Streamlit column to display the session state.
508
+
509
+ Returns:
510
+ None
511
  """
512
 
513
  col.write("Current Model:")