m7mdal7aj committed on
Commit
8286a8f
1 Parent(s): 7138ab3

Update my_model/state_manager.py

Browse files
Files changed (1) hide show
  1. my_model/state_manager.py +89 -97
my_model/state_manager.py CHANGED
@@ -1,5 +1,6 @@
1
  # This module contains the StateManager class.
2
- # The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run, and test the models.
 
3
 
4
 
5
  import pandas as pd
@@ -14,14 +15,18 @@ from my_model.KBVQA import KBVQA, prepare_kbvqa_model
14
 
15
  class StateManager:
16
  """
17
- Manages the user interface and session state for the Streamlit-based Knowledge-Based Visual Question Answering (KBVQA) application.
 
18
 
19
- This class includes methods to initialize the session state, set up various UI widgets for model selection and settings,
 
20
  manage the loading and reloading of the KBVQA model, and handle the processing and analysis of images.
21
  It tracks changes to the application's state to ensure the correct configuration is maintained.
22
- Additionally, it provides methods to display the current model settings and the complete application state within the Streamlit interface.
 
23
 
24
- The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run, and test the models.
 
25
 
26
  Attributes:
27
  col1 (streamlit.columns): The first column in the Streamlit layout.
@@ -29,28 +34,26 @@ class StateManager:
29
  col3 (streamlit.columns): The third column in the Streamlit layout.
30
  """
31
 
32
-
33
  def __init__(self) -> None:
34
  """
35
  Initializes the StateManager instance, setting up the Streamlit columns for the user interface.
36
  """
37
-
38
  # Create three columns with different widths
39
- self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])
40
 
41
-
42
  def initialize_state(self) -> None:
43
  """
44
  Initializes the Streamlit session state with default values for various keys.
45
  """
46
-
47
  if "previous_state" not in st.session_state:
48
  st.session_state['previous_state'] = {'method': None, 'detection_model': None, 'confidence_level': None}
49
  if 'images_data' not in st.session_state:
50
  st.session_state['images_data'] = {}
51
  if 'kbvqa' not in st.session_state:
52
  st.session_state['kbvqa'] = None
53
- if "button_label" not in st.session_state:
54
  st.session_state['button_label'] = "Load Model"
55
  if 'loading_in_progress' not in st.session_state:
56
  st.session_state['loading_in_progress'] = False
@@ -65,7 +68,6 @@ class StateManager:
65
  if 'model_loaded' not in st.session_state:
66
  st.session_state['model_loaded'] = self.is_model_loaded
67
 
68
-
69
  def set_up_widgets(self) -> None:
70
  """
71
  Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
@@ -74,18 +76,22 @@ class StateManager:
74
  None
75
  """
76
 
77
- self.col1.selectbox("Choose a model:", ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"], index=1, key='method', disabled=self.is_widget_disabled)
78
- detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model', disabled=self.is_widget_disabled)
 
 
 
79
  default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
80
- self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)
 
81
 
82
  # Conditional display of model settings
83
- show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
84
  if show_model_settings:
85
  self.display_model_settings
86
 
87
-
88
- def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float, slider_key_name: str, col=None) -> None:
89
  """
90
  Creates a slider widget with the specified parameters, optionally placing it in a specific column.
91
 
@@ -101,47 +107,46 @@ class StateManager:
101
  Returns:
102
  None
103
  """
104
-
105
  if col is None:
106
- return st.slider(text, min_value, max_value, value, step, key=slider_key_name, disabled=self.is_widget_disabledd)
 
107
  else:
108
- return col.slider(text, min_value, max_value, value, step, key=slider_key_name, disabled=self.is_widget_disabled)
 
109
 
110
-
111
  @property
112
  def is_widget_disabled(self) -> bool:
113
  """
114
  Checks if widgets should be disabled based on the 'loading_in_progress' state.
115
-
116
  Returns:
117
  bool: True if widgets should be disabled, False otherwise.
118
  """
119
-
120
  return st.session_state['loading_in_progress']
121
 
122
  def disable_widgets(self) -> None:
123
  """
124
  Disables widgets by setting the 'loading_in_progress' state to True.
125
-
126
  Returns:
127
  None
128
  """
129
-
130
  st.session_state['loading_in_progress'] = True
131
 
132
-
133
  @property
134
  def settings_changed(self) -> bool:
135
  """
136
  Checks if any model settings have changed compared to the previous state.
137
-
138
  Returns:
139
  bool: True if any setting has changed, False otherwise.
140
  """
141
-
142
  return self.has_state_changed()
143
 
144
-
145
  @property
146
  def confidance_change(self) -> bool:
147
  """
@@ -150,10 +155,9 @@ class StateManager:
150
  Returns:
151
  bool: True if the confidence level has changed, False otherwise.
152
  """
153
-
154
  return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]
155
 
156
-
157
  def update_prev_state(self) -> None:
158
  """
159
  Updates the 'previous_state' in the session state with the current state values.
@@ -161,15 +165,14 @@ class StateManager:
161
  Returns:
162
  None
163
  """
164
-
165
  for key in st.session_state['previous_state']:
166
  st.session_state['previous_state'][key] = st.session_state[key]
167
 
168
-
169
  def load_model(self) -> None:
170
  """
171
  Loads the KBVQA model based on the chosen method and settings.
172
-
173
  - Frees GPU resources before loading.
174
  - Calls `prepare_kbvqa_model` to create the model.
175
  - Sets the detection confidence level on the model object.
@@ -179,7 +182,7 @@ class StateManager:
179
  Returns:
180
  None
181
  """
182
-
183
  try:
184
  free_gpu_resources()
185
  st.session_state['kbvqa'] = prepare_kbvqa_model()
@@ -190,14 +193,14 @@ class StateManager:
190
  st.session_state['button_label'] = "Reload Model"
191
  free_gpu_resources()
192
  free_gpu_resources()
193
-
194
  except Exception as e:
195
  st.error(f"Error loading model: {e}")
196
-
197
 
198
  def force_reload_model(self) -> None:
199
  """
200
- Forces a reload of all models, freeing up GPU resources. This method deletes the current models and calls `free_gpu_resources`.
 
201
 
202
  - Deletes the current KBVQA model from the session state.
203
  - Calls `prepare_kbvqa_model` with `force_reload=True` to reload the model.
@@ -207,7 +210,7 @@ class StateManager:
207
  Returns:
208
  None
209
  """
210
-
211
  try:
212
  self.delete_model()
213
  free_gpu_resources()
@@ -215,14 +218,13 @@ class StateManager:
215
  st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
216
  # Update the previous state with current session state values
217
  self.update_prev_state()
218
-
219
  st.session_state['model_loaded'] = True
220
  free_gpu_resources()
221
  except Exception as e:
222
  st.error(f"Error reloading model: {e}")
223
  free_gpu_resources()
224
 
225
-
226
  def delete_model(self) -> None:
227
  """
228
  This method deletes the current models and calls `free_gpu_resources`.
@@ -230,9 +232,9 @@ class StateManager:
230
  Returns:
231
  None
232
  """
233
-
234
  free_gpu_resources()
235
-
236
  if self.is_model_loaded:
237
  try:
238
  del st.session_state['kbvqa']
@@ -242,12 +244,11 @@ class StateManager:
242
  free_gpu_resources()
243
  free_gpu_resources()
244
  pass
245
-
246
-
247
  def has_state_changed(self) -> bool:
248
  """
249
  Compares current session state with the previous state to identify changes.
250
-
251
  Returns:
252
  bool: True if any change is found, False otherwise.
253
  """
@@ -255,41 +256,38 @@ class StateManager:
255
  if key == 'confidence_level':
256
  continue # confidence_level tracker is separate
257
  if key in st.session_state and st.session_state[key] != st.session_state['previous_state'][key]:
258
-
259
  return True # Found a change
260
- else: return False # No changes found
 
261
 
262
-
263
  def get_model(self) -> KBVQA.KBVQA():
264
  """
265
  Retrieves the KBVQA model from the session state.
266
-
267
  Returns:
268
  KBVQA: The loaded KBVQA model, or None if not loaded.
269
  """
270
-
271
  return st.session_state.get('kbvqa', None)
272
 
273
-
274
  @property
275
  def is_model_loaded(self) -> bool:
276
  """
277
  Checks if the KBVQA model is loaded in the session state.
278
-
279
  Returns:
280
  bool: True if the model is loaded, False otherwise.
281
  """
282
-
283
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None and \
284
- st.session_state.kbvqa.all_models_loaded \
285
- and (st.session_state['previous_state']['method'] is not None
286
- and st.session_state['method'] == st.session_state['previous_state']['method'])
287
 
288
-
289
  def reload_detection_model(self) -> None:
290
  """
291
  Reloads only the detection model of the KBVQA model with updated settings.
292
-
293
  - Frees GPU resources before reloading.
294
  - Checks if the model is already loaded.
295
  - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
@@ -299,7 +297,7 @@ class StateManager:
299
  Returns:
300
  None
301
  """
302
-
303
  try:
304
  free_gpu_resources()
305
  if self.is_model_loaded:
@@ -313,18 +311,18 @@ class StateManager:
313
  except Exception as e:
314
  st.error(f"Error reloading detection model: {e}")
315
 
316
-
317
  def process_new_image(self, image_key: str, image) -> None:
318
  """
319
- Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session state.
320
-
 
321
  This dictionary stores information about each processed image, including:
322
  - `image`: The original image data.
323
  - `caption`: Generated caption for the image.
324
  - `detected_objects_str`: String representation of detected objects.
325
  - `qa_history`: List of questions and answers related to the image.
326
  - `analysis_done`: Flag indicating if analysis is complete.
327
-
328
  Args:
329
  image_key (str): Unique key for the image.
330
  image (obj): The uploaded image data.
@@ -332,7 +330,7 @@ class StateManager:
332
  Returns:
333
  None
334
  """
335
-
336
  if image_key not in st.session_state['images_data']:
337
  st.session_state['images_data'][image_key] = {
338
  'image': image,
@@ -342,8 +340,6 @@ class StateManager:
342
  'analysis_done': False
343
  }
344
 
345
-
346
-
347
  def analyze_image(self, image) -> Tuple[str, str, object]:
348
  """
349
  Analyzes the image using the KBVQA model.
@@ -352,7 +348,7 @@ class StateManager:
352
  - Displays a "Analyzing the image .." message.
353
  - Calls KBVQA methods to generate a caption and detect objects.
354
  - Returns the generated caption, detected objects string, and image with bounding boxes.
355
-
356
  Args:
357
  image (obj): The image data to analyze.
358
 
@@ -368,11 +364,10 @@ class StateManager:
368
  free_gpu_resources()
369
  return caption, detected_objects_str, image_with_boxes
370
 
371
-
372
  def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
373
  """
374
  Adds a question-answer pair to the QA history of a specific image, to be used as a history tracker.
375
-
376
  Args:
377
  image_key (str): Unique key for the image.
378
  question (str): The question asked about the image.
@@ -382,26 +377,25 @@ class StateManager:
382
  Returns:
383
  None
384
  """
385
-
386
  if image_key in st.session_state['images_data']:
387
  st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))
388
 
389
-
390
  def get_images_data(self) -> Dict:
391
  """
392
  Returns the dictionary containing processed image data from the session state.
393
-
394
  Returns:
395
  dict: The dictionary storing information about processed images.
396
  """
397
-
398
  return st.session_state['images_data']
399
-
400
-
401
  def update_image_data(self, image_key: str, caption: str, detected_objects_str: str, analysis_done: bool) -> None:
402
  """
403
- Updates the information stored for a specific image in the `images_data` dictionary in the application session state.
404
-
 
405
  Args:
406
  image_key (str): Unique key for the image.
407
  caption (str): The generated caption for the image.
@@ -418,21 +412,20 @@ class StateManager:
418
  'analysis_done': analysis_done
419
  })
420
 
421
-
422
  def resize_image(self, image_input, new_width: Optional[int] = None, new_height: Optional[int] = None) -> Image:
423
  """
424
  Resizes an image. If only new_width is provided, the height is adjusted to maintain aspect ratio.
425
  If both new_width and new_height are provided, the image is resized to those dimensions.
426
-
427
  Args:
428
  image_input (PIL.Image.Image): The image to resize.
429
  new_width (int, optional): The target width of the image.
430
  new_height (int, optional): The target height of the image.
431
-
432
  Returns:
433
  PIL.Image.Image: The resized image.
434
  """
435
-
436
  img = copy.deepcopy(image_input)
437
  if isinstance(img, str):
438
  # Open the image from a file path
@@ -442,7 +435,7 @@ class StateManager:
442
  image = img
443
  else:
444
  raise ValueError("image_input must be a file path or a PIL Image object")
445
-
446
  if new_width is not None and new_height is None:
447
  # Calculate new height to maintain aspect ratio
448
  original_width, original_height = image.size
@@ -455,17 +448,15 @@ class StateManager:
455
  new_width = int(original_width * ratio)
456
  elif new_width is None and new_height is None:
457
  raise ValueError("At least one of new_width or new_height must be provided")
458
-
459
  # Resize the image
460
  resized_image = image.resize((new_width, new_height))
461
  return resized_image
462
 
463
-
464
-
465
  def display_message(self, message: str, message_type: str) -> None:
466
  """
467
  Displays a message in the Streamlit interface based on the specified message type.
468
-
469
  Args:
470
  message (str): The message to display.
471
  message_type (str): The type of message ('warning', 'text', 'success', 'write', or 'error').
@@ -473,18 +464,18 @@ class StateManager:
473
  Returns:
474
  None
475
  """
476
-
477
  if message_type == "warning":
478
  st.warning(message)
479
  elif message_type == "text":
480
  st.text(message)
481
  elif message_type == "success":
482
- st.success(messae)
483
  elif message_type == "write":
484
  st.write(message)
485
- else: st.error("Message type unknown")
486
-
487
-
488
  @property
489
  def display_model_settings(self) -> None:
490
  """
@@ -494,22 +485,23 @@ class StateManager:
494
  None
495
  """
496
  self.col3.write("##### Current Model Settings:")
497
- data = [{'Setting': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', 'loading_in_progress', 'model_loaded', 'time_taken_to_load_model', 'images_data' ]]
 
 
498
  df = pd.DataFrame(data).reset_index(drop=True)
499
  return self.col3.write(df)
500
 
501
-
502
  def display_session_state(self, col) -> None:
503
  """
504
  Displays a table of the complete application state in the specified column.
505
-
506
  Args:
507
  col (streamlit.columns.Column): The Streamlit column to display the session state.
508
 
509
  Returns:
510
  None
511
  """
512
-
513
  col.write("Current Model:")
514
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
515
  df = pd.DataFrame(data).reset_index(drop=True)
 
1
  # This module contains the StateManager class.
2
+ # The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run,
3
+ # and test the models.
4
 
5
 
6
  import pandas as pd
 
15
 
16
  class StateManager:
17
  """
18
+ Manages the user interface and session state for the Streamlit-based Knowledge-Based Visual Question Answering
19
+ (KBVQA) application.
20
 
21
+ This class includes methods to initialize the session state, set up various UI widgets for model selection and
22
+ settings,
23
  manage the loading and reloading of the KBVQA model, and handle the processing and analysis of images.
24
  It tracks changes to the application's state to ensure the correct configuration is maintained.
25
+ Additionally, it provides methods to display the current model settings and the complete application state within
26
+ the Streamlit interface.
27
 
28
+ The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run,
29
+ and test the models.
30
 
31
  Attributes:
32
  col1 (streamlit.columns): The first column in the Streamlit layout.
 
34
  col3 (streamlit.columns): The third column in the Streamlit layout.
35
  """
36
 
 
37
  def __init__(self) -> None:
38
  """
39
  Initializes the StateManager instance, setting up the Streamlit columns for the user interface.
40
  """
41
+
42
  # Create three columns with different widths
43
+ self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])
44
 
 
45
  def initialize_state(self) -> None:
46
  """
47
  Initializes the Streamlit session state with default values for various keys.
48
  """
49
+
50
  if "previous_state" not in st.session_state:
51
  st.session_state['previous_state'] = {'method': None, 'detection_model': None, 'confidence_level': None}
52
  if 'images_data' not in st.session_state:
53
  st.session_state['images_data'] = {}
54
  if 'kbvqa' not in st.session_state:
55
  st.session_state['kbvqa'] = None
56
+ if "button_label" not in st.session_state:
57
  st.session_state['button_label'] = "Load Model"
58
  if 'loading_in_progress' not in st.session_state:
59
  st.session_state['loading_in_progress'] = False
 
68
  if 'model_loaded' not in st.session_state:
69
  st.session_state['model_loaded'] = self.is_model_loaded
70
 
 
71
  def set_up_widgets(self) -> None:
72
  """
73
  Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
 
76
  None
77
  """
78
 
79
+ self.col1.selectbox("Choose a model:",
80
+ ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"],
81
+ index=1, key='method', disabled=self.is_widget_disabled)
82
+ detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1,
83
+ key='detection_model', disabled=self.is_widget_disabled)
84
  default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
85
+ self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9,
86
+ value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)
87
 
88
  # Conditional display of model settings
89
+ show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
90
  if show_model_settings:
91
  self.display_model_settings
92
 
93
+ def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float,
94
+ slider_key_name: str, col=None) -> None:
95
  """
96
  Creates a slider widget with the specified parameters, optionally placing it in a specific column.
97
 
 
107
  Returns:
108
  None
109
  """
110
+
111
  if col is None:
112
+ return st.slider(text, min_value, max_value, value, step, key=slider_key_name,
113
+ disabled=self.is_widget_disabledd)
114
  else:
115
+ return col.slider(text, min_value, max_value, value, step, key=slider_key_name,
116
+ disabled=self.is_widget_disabled)
117
 
 
118
  @property
119
  def is_widget_disabled(self) -> bool:
120
  """
121
  Checks if widgets should be disabled based on the 'loading_in_progress' state.
122
+
123
  Returns:
124
  bool: True if widgets should be disabled, False otherwise.
125
  """
126
+
127
  return st.session_state['loading_in_progress']
128
 
129
  def disable_widgets(self) -> None:
130
  """
131
  Disables widgets by setting the 'loading_in_progress' state to True.
132
+
133
  Returns:
134
  None
135
  """
136
+
137
  st.session_state['loading_in_progress'] = True
138
 
 
139
  @property
140
  def settings_changed(self) -> bool:
141
  """
142
  Checks if any model settings have changed compared to the previous state.
143
+
144
  Returns:
145
  bool: True if any setting has changed, False otherwise.
146
  """
147
+
148
  return self.has_state_changed()
149
 
 
150
  @property
151
  def confidance_change(self) -> bool:
152
  """
 
155
  Returns:
156
  bool: True if the confidence level has changed, False otherwise.
157
  """
158
+
159
  return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]
160
 
 
161
  def update_prev_state(self) -> None:
162
  """
163
  Updates the 'previous_state' in the session state with the current state values.
 
165
  Returns:
166
  None
167
  """
168
+
169
  for key in st.session_state['previous_state']:
170
  st.session_state['previous_state'][key] = st.session_state[key]
171
 
 
172
  def load_model(self) -> None:
173
  """
174
  Loads the KBVQA model based on the chosen method and settings.
175
+
176
  - Frees GPU resources before loading.
177
  - Calls `prepare_kbvqa_model` to create the model.
178
  - Sets the detection confidence level on the model object.
 
182
  Returns:
183
  None
184
  """
185
+
186
  try:
187
  free_gpu_resources()
188
  st.session_state['kbvqa'] = prepare_kbvqa_model()
 
193
  st.session_state['button_label'] = "Reload Model"
194
  free_gpu_resources()
195
  free_gpu_resources()
196
+
197
  except Exception as e:
198
  st.error(f"Error loading model: {e}")
 
199
 
200
  def force_reload_model(self) -> None:
201
  """
202
+ Forces a reload of all models, freeing up GPU resources. This method deletes the current models and calls
203
+ `free_gpu_resources`.
204
 
205
  - Deletes the current KBVQA model from the session state.
206
  - Calls `prepare_kbvqa_model` with `force_reload=True` to reload the model.
 
210
  Returns:
211
  None
212
  """
213
+
214
  try:
215
  self.delete_model()
216
  free_gpu_resources()
 
218
  st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
219
  # Update the previous state with current session state values
220
  self.update_prev_state()
221
+
222
  st.session_state['model_loaded'] = True
223
  free_gpu_resources()
224
  except Exception as e:
225
  st.error(f"Error reloading model: {e}")
226
  free_gpu_resources()
227
 
 
228
  def delete_model(self) -> None:
229
  """
230
  This method deletes the current models and calls `free_gpu_resources`.
 
232
  Returns:
233
  None
234
  """
235
+
236
  free_gpu_resources()
237
+
238
  if self.is_model_loaded:
239
  try:
240
  del st.session_state['kbvqa']
 
244
  free_gpu_resources()
245
  free_gpu_resources()
246
  pass
247
+
 
248
  def has_state_changed(self) -> bool:
249
  """
250
  Compares current session state with the previous state to identify changes.
251
+
252
  Returns:
253
  bool: True if any change is found, False otherwise.
254
  """
 
256
  if key == 'confidence_level':
257
  continue # confidence_level tracker is separate
258
  if key in st.session_state and st.session_state[key] != st.session_state['previous_state'][key]:
 
259
  return True # Found a change
260
+ else:
261
+ return False # No changes found
262
 
 
263
  def get_model(self) -> KBVQA.KBVQA():
264
  """
265
  Retrieves the KBVQA model from the session state.
266
+
267
  Returns:
268
  KBVQA: The loaded KBVQA model, or None if not loaded.
269
  """
270
+
271
  return st.session_state.get('kbvqa', None)
272
 
 
273
  @property
274
  def is_model_loaded(self) -> bool:
275
  """
276
  Checks if the KBVQA model is loaded in the session state.
277
+
278
  Returns:
279
  bool: True if the model is loaded, False otherwise.
280
  """
281
+
282
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None and \
283
+ st.session_state.kbvqa.all_models_loaded \
284
+ and (st.session_state['previous_state']['method'] is not None
285
+ and st.session_state['method'] == st.session_state['previous_state']['method'])
286
 
 
287
  def reload_detection_model(self) -> None:
288
  """
289
  Reloads only the detection model of the KBVQA model with updated settings.
290
+
291
  - Frees GPU resources before reloading.
292
  - Checks if the model is already loaded.
293
  - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
 
297
  Returns:
298
  None
299
  """
300
+
301
  try:
302
  free_gpu_resources()
303
  if self.is_model_loaded:
 
311
  except Exception as e:
312
  st.error(f"Error reloading detection model: {e}")
313
 
 
314
  def process_new_image(self, image_key: str, image) -> None:
315
  """
316
+ Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session
317
+ state.
318
+
319
  This dictionary stores information about each processed image, including:
320
  - `image`: The original image data.
321
  - `caption`: Generated caption for the image.
322
  - `detected_objects_str`: String representation of detected objects.
323
  - `qa_history`: List of questions and answers related to the image.
324
  - `analysis_done`: Flag indicating if analysis is complete.
325
+
326
  Args:
327
  image_key (str): Unique key for the image.
328
  image (obj): The uploaded image data.
 
330
  Returns:
331
  None
332
  """
333
+
334
  if image_key not in st.session_state['images_data']:
335
  st.session_state['images_data'][image_key] = {
336
  'image': image,
 
340
  'analysis_done': False
341
  }
342
 
 
 
343
  def analyze_image(self, image) -> Tuple[str, str, object]:
344
  """
345
  Analyzes the image using the KBVQA model.
 
348
  - Displays a "Analyzing the image .." message.
349
  - Calls KBVQA methods to generate a caption and detect objects.
350
  - Returns the generated caption, detected objects string, and image with bounding boxes.
351
+
352
  Args:
353
  image (obj): The image data to analyze.
354
 
 
364
  free_gpu_resources()
365
  return caption, detected_objects_str, image_with_boxes
366
 
 
367
  def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
368
  """
369
  Adds a question-answer pair to the QA history of a specific image, to be used as a history tracker.
370
+
371
  Args:
372
  image_key (str): Unique key for the image.
373
  question (str): The question asked about the image.
 
377
  Returns:
378
  None
379
  """
380
+
381
  if image_key in st.session_state['images_data']:
382
  st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))
383
 
 
384
  def get_images_data(self) -> Dict:
385
  """
386
  Returns the dictionary containing processed image data from the session state.
387
+
388
  Returns:
389
  dict: The dictionary storing information about processed images.
390
  """
391
+
392
  return st.session_state['images_data']
393
+
 
394
  def update_image_data(self, image_key: str, caption: str, detected_objects_str: str, analysis_done: bool) -> None:
395
  """
396
+ Updates the information stored for a specific image in the `images_data` dictionary in the application session
397
+ state.
398
+
399
  Args:
400
  image_key (str): Unique key for the image.
401
  caption (str): The generated caption for the image.
 
412
  'analysis_done': analysis_done
413
  })
414
 
 
415
  def resize_image(self, image_input, new_width: Optional[int] = None, new_height: Optional[int] = None) -> Image:
416
  """
417
  Resizes an image. If only new_width is provided, the height is adjusted to maintain aspect ratio.
418
  If both new_width and new_height are provided, the image is resized to those dimensions.
419
+
420
  Args:
421
  image_input (PIL.Image.Image): The image to resize.
422
  new_width (int, optional): The target width of the image.
423
  new_height (int, optional): The target height of the image.
424
+
425
  Returns:
426
  PIL.Image.Image: The resized image.
427
  """
428
+
429
  img = copy.deepcopy(image_input)
430
  if isinstance(img, str):
431
  # Open the image from a file path
 
435
  image = img
436
  else:
437
  raise ValueError("image_input must be a file path or a PIL Image object")
438
+
439
  if new_width is not None and new_height is None:
440
  # Calculate new height to maintain aspect ratio
441
  original_width, original_height = image.size
 
448
  new_width = int(original_width * ratio)
449
  elif new_width is None and new_height is None:
450
  raise ValueError("At least one of new_width or new_height must be provided")
451
+
452
  # Resize the image
453
  resized_image = image.resize((new_width, new_height))
454
  return resized_image
455
 
 
 
456
  def display_message(self, message: str, message_type: str) -> None:
457
  """
458
  Displays a message in the Streamlit interface based on the specified message type.
459
+
460
  Args:
461
  message (str): The message to display.
462
  message_type (str): The type of message ('warning', 'text', 'success', 'write', or 'error').
 
464
  Returns:
465
  None
466
  """
467
+
468
  if message_type == "warning":
469
  st.warning(message)
470
  elif message_type == "text":
471
  st.text(message)
472
  elif message_type == "success":
473
+ st.success(message)
474
  elif message_type == "write":
475
  st.write(message)
476
+ else:
477
+ st.error("Message type unknown")
478
+
479
  @property
480
  def display_model_settings(self) -> None:
481
  """
 
485
  None
486
  """
487
  self.col3.write("##### Current Model Settings:")
488
+ data = [{'Setting': key, 'Value': str(value)} for key, value in st.session_state.items() if
489
+ key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed',
490
+ 'loading_in_progress', 'model_loaded', 'time_taken_to_load_model', 'images_data']]
491
  df = pd.DataFrame(data).reset_index(drop=True)
492
  return self.col3.write(df)
493
 
 
494
  def display_session_state(self, col) -> None:
495
  """
496
  Displays a table of the complete application state in the specified column.
497
+
498
  Args:
499
  col (streamlit.columns.Column): The Streamlit column to display the session state.
500
 
501
  Returns:
502
  None
503
  """
504
+
505
  col.write("Current Model:")
506
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
507
  df = pd.DataFrame(data).reset_index(drop=True)