m7mdal7aj commited on
Commit
3fdd1d7
1 Parent(s): 502ab4a

Update my_model/state_manager.py

Browse files
Files changed (1) hide show
  1. my_model/state_manager.py +15 -1
my_model/state_manager.py CHANGED
@@ -24,7 +24,7 @@ class StateManager:
24
  def set_up_widgets(self):
25
 
26
  # Create two columns with different widths
27
- col1, col2 = st.columns([0.8, 0.2]) # Adjust the ratio as needed
28
  with col1:
29
  st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
30
  detection_model = st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
@@ -37,15 +37,19 @@ class StateManager:
37
  if show_model_settings:
38
  self.display_model_settings()
39
 
 
 
40
 
41
  def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
42
 
43
  return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
 
44
 
45
  @property
46
  def settings_changed(self):
47
  return self.has_state_changed()
48
 
 
49
  def display_model_settings(self):
50
  st.write("#### Current Model Settings:")
51
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
@@ -53,11 +57,13 @@ class StateManager:
53
  styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
54
  st.table(styled_df)
55
 
 
56
  def display_session_state(self):
57
  st.write("Current Model:")
58
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
59
  df = pd.DataFrame(data)
60
  st.table(df)
 
61
 
62
  def load_model(self):
63
  """Load the KBVQA model with specified settings."""
@@ -76,6 +82,7 @@ class StateManager:
76
  except Exception as e:
77
  st.error(f"Error loading model: {e}")
78
 
 
79
  # Function to check if any session state values have changed
80
  def has_state_changed(self):
81
  for key in st.session_state['previous_state']:
@@ -83,13 +90,16 @@ class StateManager:
83
  return True # Found a change
84
  else: return False # No changes found
85
 
 
86
  def get_model(self):
87
  """Retrieve the KBVQA model from the session state."""
88
  return st.session_state.get('kbvqa', None)
89
 
 
90
  def is_model_loaded(self):
91
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
92
 
 
93
  def reload_detection_model(self):
94
  try:
95
  free_gpu_resources()
@@ -112,6 +122,7 @@ class StateManager:
112
  'analysis_done': False
113
  }
114
 
 
115
  def analyze_image(self, image, kbvqa):
116
  img = copy.deepcopy(image)
117
  st.text("Analyzing the image .. ")
@@ -119,13 +130,16 @@ class StateManager:
119
  image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
120
  return caption, detected_objects_str, image_with_boxes
121
 
 
122
  def add_to_qa_history(self, image_key, question, answer):
123
  if image_key in st.session_state['images_data']:
124
  st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
125
 
 
126
  def get_images_data(self):
127
  return st.session_state['images_data']
128
 
 
129
  def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
130
  if image_key in st.session_state['images_data']:
131
  st.session_state['images_data'][image_key].update({
 
24
  def set_up_widgets(self):
25
 
26
  # Create two columns with different widths
27
+ col1, col2, col3 = st.columns([0.2, 0.6, 0.2]) # Adjust the ratio as needed
28
  with col1:
29
  st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
30
  detection_model = st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
 
37
  if show_model_settings:
38
  self.display_model_settings()
39
 
40
+ col3.header("COL3")
41
+
42
 
43
  def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
44
 
45
  return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
46
+
47
 
48
  @property
49
  def settings_changed(self):
50
  return self.has_state_changed()
51
 
52
+
53
  def display_model_settings(self):
54
  st.write("#### Current Model Settings:")
55
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
 
57
  styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
58
  st.table(styled_df)
59
 
60
+
61
  def display_session_state(self):
62
  st.write("Current Model:")
63
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
64
  df = pd.DataFrame(data)
65
  st.table(df)
66
+
67
 
68
  def load_model(self):
69
  """Load the KBVQA model with specified settings."""
 
82
  except Exception as e:
83
  st.error(f"Error loading model: {e}")
84
 
85
+
86
  # Function to check if any session state values have changed
87
  def has_state_changed(self):
88
  for key in st.session_state['previous_state']:
 
90
  return True # Found a change
91
  else: return False # No changes found
92
 
93
+
94
  def get_model(self):
95
  """Retrieve the KBVQA model from the session state."""
96
  return st.session_state.get('kbvqa', None)
97
 
98
+
99
  def is_model_loaded(self):
100
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
101
 
102
+
103
  def reload_detection_model(self):
104
  try:
105
  free_gpu_resources()
 
122
  'analysis_done': False
123
  }
124
 
125
+
126
  def analyze_image(self, image, kbvqa):
127
  img = copy.deepcopy(image)
128
  st.text("Analyzing the image .. ")
 
130
  image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
131
  return caption, detected_objects_str, image_with_boxes
132
 
133
+
134
  def add_to_qa_history(self, image_key, question, answer):
135
  if image_key in st.session_state['images_data']:
136
  st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
137
 
138
+
139
  def get_images_data(self):
140
  return st.session_state['images_data']
141
 
142
+
143
  def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
144
  if image_key in st.session_state['images_data']:
145
  st.session_state['images_data'][image_key].update({