File size: 6,311 Bytes
18d1852 e46d486 18d1852 49e9e5b 1d51bf5 18d1852 11a17ef 6d4d5ac 18d1852 11a17ef 18d1852 12f08dc 4170a5f 74d450c b7c642c f4bcc28 09ff02d 2152f1f 2957e90 1b71503 bfdde42 11a17ef 18d1852 08655fb 11a17ef 3fdd1d7 d80fd56 bfdde42 3fdd1d7 d80fd56 f4bcc28 18d1852 3fdd1d7 18d1852 11a17ef 08655fb d6a4897 5dd58f8 11a17ef 18d1852 3fdd1d7 18d1852 3689b26 18d1852 3fdd1d7 18d1852 d0e9fe6 753c201 cc825df d9364fd d6f8382 f72214b 6307a2e 12f08dc 6d4d5ac 753c201 f72214b 3fdd1d7 f72214b 87da9a2 753c201 3fdd1d7 18d1852 3fdd1d7 18d1852 3fdd1d7 cd3678b 18d1852 1fc0405 9a8f19a 11a17ef 18d1852 d6f8382 18d1852 3fdd1d7 18d1852 d182243 18d1852 3fdd1d7 18d1852 3fdd1d7 18d1852 3fdd1d7 18d1852 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import pandas as pd
import copy
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
class StateManager:
    """Hold Streamlit UI state, widget setup and the KBVQA model lifecycle.

    Persistent data is stored in ``st.session_state`` so it survives Streamlit
    reruns; the instance itself only keeps the three layout columns created in
    ``__init__``.
    """

    # Settings snapshotted into 'previous_state' by load_model() and compared
    # against the live widget values by has_state_changed().
    _TRACKED_SETTINGS = ('method', 'detection_model', 'confidence_level')

    # Session-state keys shown in the "Current Model Settings" table.
    _DISPLAYED_SETTINGS = ('confidence_level', 'detection_model', 'method', 'kbvqa',
                           'previous_state', 'settings_changed')

    def __init__(self):
        # Three columns with relative widths 0.2 / 0.6 / 0.2. col1 hosts the
        # selection widgets and col3 the settings view; col2 is not used here.
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self):
        """Create any missing session-state entries with their default values.

        Existing entries are left untouched, so this is safe to call on every rerun.
        """
        defaults = (
            ('images_data', {}),
            ('kbvqa', None),
            ('button_label', "Load Model"),
            ('previous_state', {}),
        )
        for key, default in defaults:
            if key not in st.session_state:
                st.session_state[key] = default
        if "settings_changed" not in st.session_state:
            # Must run after 'previous_state' exists: the property reads it.
            st.session_state['settings_changed'] = self.settings_changed

    def set_up_widgets(self):
        """Render the method/detector selectors, confidence slider and settings toggle."""
        self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
        self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
        # The slider's default depends on the detector just chosen (read back through its widget key).
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9,
                              value=default_confidence, step=0.1, slider_key_name='confidence_level', col=self.col1)
        # Conditional display of model settings.
        if self.col3.checkbox("Show Model Settings", False):
            self.display_model_settings()

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
        """Render a slider in ``col`` (or the main page when ``col`` is None) and return its value.

        The slider is registered under ``slider_key_name`` so its value is also
        available as ``st.session_state[slider_key_name]``.
        """
        target = st if col is None else col
        return target.slider(text, min_value, max_value, value, step, key=slider_key_name)

    @property
    def settings_changed(self):
        """Whether the live settings differ from those recorded at the last model load."""
        return self.has_state_changed()

    def display_model_settings(self):
        """Show the tracked settings and model entry from session state as a styled table in col3."""
        self.col3.write("##### Current Model Settings:")
        data = [{'Key': key, 'Value': str(value)}
                for key, value in st.session_state.items()
                if key in self._DISPLAYED_SETTINGS]
        df = pd.DataFrame(data)
        styled_df = df.style.set_properties(
            **{'background-color': 'black', 'color': 'white', 'border-color': 'white'}
        ).set_table_styles([{'selector': 'th', 'props': [('background-color', 'black'), ('font-weight', 'bold')]}])
        self.col3.table(styled_df)

    def display_session_state(self):
        """Dump every session-state entry (stringified) as a table on the main page."""
        st.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data)
        st.table(df)

    def load_model(self):
        """Load the KBVQA model with specified settings.

        On success, stores the model in ``st.session_state['kbvqa']``, records the
        settings it was loaded with in 'previous_state', and relabels the button to
        "Reload Model". Any exception is reported in the UI via ``st.error`` rather
        than raised.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Snapshot the settings the model was built with, for has_state_changed().
            st.session_state['previous_state'] = {key: st.session_state[key] for key in self._TRACKED_SETTINGS}
            st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def has_state_changed(self, session_state=None):
        """Return True if any setting recorded in 'previous_state' differs from its current value.

        Args:
            session_state: Mapping to inspect; defaults to ``st.session_state``.
                Accepting any mapping allows the check to run without a live app.

        Returns:
            True if at least one recorded setting differs from, or is missing from,
            the current state. False when all match or nothing has been recorded yet.
        """
        state = st.session_state if session_state is None else session_state
        previous = state.get('previous_state', {})
        missing = object()  # sentinel: an absent key always compares as "changed"
        return any(state.get(key, missing) != value for key, value in previous.items())

    def get_model(self):
        """Retrieve the KBVQA model from the session state (None if absent)."""
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self):
        """Return True when a non-None model is stored under 'kbvqa'."""
        return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None

    def reload_detection_model(self):
        """Reload only the detection part of the model and re-apply the confidence level.

        Does nothing but free GPU resources when no model is loaded. Any exception
        is reported in the UI via ``st.error`` rather than raised.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded():
                # NOTE(review): the return value is discarded; this assumes
                # prepare_kbvqa_model updates the session's model in place — confirm.
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.col1.success("Model reloaded with updated settings and ready for inference.")
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key, image, kbvqa):
        """Register ``image`` under ``image_key`` with empty analysis fields, if not already present.

        ``kbvqa`` is not used here; it is kept for call-site compatibility.
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False,
            }

    def analyze_image(self, image, kbvqa):
        """Caption ``image`` and run object detection on it.

        Works on a deep copy, so the caller's ``image`` object is never handed to
        the model calls.

        Returns:
            ``(caption, detected_objects_str, image_with_boxes)``.
        """
        img = copy.deepcopy(image)
        st.text("Analyzing the image .. ")
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer):
        """Append a ``(question, answer)`` pair to a registered image's history; unknown keys are ignored."""
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """Return the dict of per-image data stored in session state."""
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """Store analysis results for a registered image; unknown keys are ignored."""
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done,
            })
|