Update my_model/state_manager.py
Browse files- my_model/state_manager.py +15 -1
my_model/state_manager.py
CHANGED
@@ -24,7 +24,7 @@ class StateManager:
|
|
24 |
def set_up_widgets(self):
|
25 |
|
26 |
# Create two columns with different widths
|
27 |
-
col1, col2 = st.columns([0.
|
28 |
with col1:
|
29 |
st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
|
30 |
detection_model = st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
|
@@ -37,15 +37,19 @@ class StateManager:
|
|
37 |
if show_model_settings:
|
38 |
self.display_model_settings()
|
39 |
|
|
|
|
|
40 |
|
41 |
def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
|
42 |
|
43 |
return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
|
|
|
44 |
|
45 |
@property
|
46 |
def settings_changed(self):
|
47 |
return self.has_state_changed()
|
48 |
|
|
|
49 |
def display_model_settings(self):
|
50 |
st.write("#### Current Model Settings:")
|
51 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
|
@@ -53,11 +57,13 @@ class StateManager:
|
|
53 |
styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
|
54 |
st.table(styled_df)
|
55 |
|
|
|
56 |
def display_session_state(self):
|
57 |
st.write("Current Model:")
|
58 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
|
59 |
df = pd.DataFrame(data)
|
60 |
st.table(df)
|
|
|
61 |
|
62 |
def load_model(self):
|
63 |
"""Load the KBVQA model with specified settings."""
|
@@ -76,6 +82,7 @@ class StateManager:
|
|
76 |
except Exception as e:
|
77 |
st.error(f"Error loading model: {e}")
|
78 |
|
|
|
79 |
# Function to check if any session state values have changed
|
80 |
def has_state_changed(self):
|
81 |
for key in st.session_state['previous_state']:
|
@@ -83,13 +90,16 @@ class StateManager:
|
|
83 |
return True # Found a change
|
84 |
else: return False # No changes found
|
85 |
|
|
|
86 |
def get_model(self):
|
87 |
"""Retrieve the KBVQA model from the session state."""
|
88 |
return st.session_state.get('kbvqa', None)
|
89 |
|
|
|
90 |
def is_model_loaded(self):
|
91 |
return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
|
92 |
|
|
|
93 |
def reload_detection_model(self):
|
94 |
try:
|
95 |
free_gpu_resources()
|
@@ -112,6 +122,7 @@ class StateManager:
|
|
112 |
'analysis_done': False
|
113 |
}
|
114 |
|
|
|
115 |
def analyze_image(self, image, kbvqa):
|
116 |
img = copy.deepcopy(image)
|
117 |
st.text("Analyzing the image .. ")
|
@@ -119,13 +130,16 @@ class StateManager:
|
|
119 |
image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
|
120 |
return caption, detected_objects_str, image_with_boxes
|
121 |
|
|
|
122 |
def add_to_qa_history(self, image_key, question, answer):
|
123 |
if image_key in st.session_state['images_data']:
|
124 |
st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
|
125 |
|
|
|
126 |
def get_images_data(self):
|
127 |
return st.session_state['images_data']
|
128 |
|
|
|
129 |
def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
|
130 |
if image_key in st.session_state['images_data']:
|
131 |
st.session_state['images_data'][image_key].update({
|
|
|
24 |
def set_up_widgets(self):
|
25 |
|
26 |
# Create two columns with different widths
|
27 |
+
col1, col2, col3 = st.columns([0.2, 0.6, 0.2]) # Adjust the ratio as needed
|
28 |
with col1:
|
29 |
st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
|
30 |
detection_model = st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
|
|
|
37 |
if show_model_settings:
|
38 |
self.display_model_settings()
|
39 |
|
40 |
+
col3.header("COL3")
|
41 |
+
|
42 |
|
43 |
def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
|
44 |
|
45 |
return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
|
46 |
+
|
47 |
|
48 |
@property
|
49 |
def settings_changed(self):
|
50 |
return self.has_state_changed()
|
51 |
|
52 |
+
|
53 |
def display_model_settings(self):
|
54 |
st.write("#### Current Model Settings:")
|
55 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
|
|
|
57 |
styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
|
58 |
st.table(styled_df)
|
59 |
|
60 |
+
|
61 |
def display_session_state(self):
|
62 |
st.write("Current Model:")
|
63 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
|
64 |
df = pd.DataFrame(data)
|
65 |
st.table(df)
|
66 |
+
|
67 |
|
68 |
def load_model(self):
|
69 |
"""Load the KBVQA model with specified settings."""
|
|
|
82 |
except Exception as e:
|
83 |
st.error(f"Error loading model: {e}")
|
84 |
|
85 |
+
|
86 |
# Function to check if any session state values have changed
|
87 |
def has_state_changed(self):
|
88 |
for key in st.session_state['previous_state']:
|
|
|
90 |
return True # Found a change
|
91 |
else: return False # No changes found
|
92 |
|
93 |
+
|
94 |
def get_model(self):
|
95 |
"""Retrieve the KBVQA model from the session state."""
|
96 |
return st.session_state.get('kbvqa', None)
|
97 |
|
98 |
+
|
99 |
def is_model_loaded(self):
|
100 |
return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
|
101 |
|
102 |
+
|
103 |
def reload_detection_model(self):
|
104 |
try:
|
105 |
free_gpu_resources()
|
|
|
122 |
'analysis_done': False
|
123 |
}
|
124 |
|
125 |
+
|
126 |
def analyze_image(self, image, kbvqa):
|
127 |
img = copy.deepcopy(image)
|
128 |
st.text("Analyzing the image .. ")
|
|
|
130 |
image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
|
131 |
return caption, detected_objects_str, image_with_boxes
|
132 |
|
133 |
+
|
134 |
def add_to_qa_history(self, image_key, question, answer):
|
135 |
if image_key in st.session_state['images_data']:
|
136 |
st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
|
137 |
|
138 |
+
|
139 |
def get_images_data(self):
|
140 |
return st.session_state['images_data']
|
141 |
|
142 |
+
|
143 |
def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
|
144 |
if image_key in st.session_state['images_data']:
|
145 |
st.session_state['images_data'][image_key].update({
|