KB-VQA-E / my_model /state_manager.py
m7mdal7aj's picture
Update my_model/state_manager.py
31da5e4 verified
raw
history blame
No virus
18 kB
import pandas as pd
import copy
import time
from PIL import Image
from typing import Tuple, Dict
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
class StateManager:
    """
    Manages the Streamlit session state and UI widgets of the KB-VQA application.

    Method overview:
        initialize_state: Initializes default values for session state.
        set_up_widgets: Creates UI elements for model selection and settings.
        set_slider_value: Creates a slider widget for numerical input.
        is_widget_disabled: True while loading is in progress (UI elements should be disabled).
        disable_widgets: Disables interactive UI elements during processing.
        settings_changed: Checks if any model setting (other than confidence) has changed.
        confidance_change / confidence_change: Checks if the confidence level setting has changed.
        update_prev_state: Updates the record of the previous application state.
        load_model / force_reload_model / reload_detection_model / delete_model: Model lifecycle.
        get_model / is_model_loaded: Access to the KBVQA model stored in session state.
        process_new_image / analyze_image / add_to_qa_history / get_images_data / update_image_data:
            Per-image bookkeeping in `st.session_state['images_data']`.
        resize_image: Resizes an image, optionally preserving its aspect ratio.
        display_message / display_model_settings / display_session_state: UI output helpers.
    """

    # Session-state keys rendered in the "Current Model Settings" table.
    _DISPLAYED_SETTING_KEYS = frozenset({
        "confidence_level", "detection_model", "method", "kbvqa", "previous_state",
        "settings_changed", "loading_in_progress", "model_loaded",
        "time_taken_to_load_model", "images_data",
    })

    def __init__(self):
        """
        Initializes the StateManager instance, creating three Streamlit columns
        (relative widths 0.2 / 0.6 / 0.2) for the user interface.
        """
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self) -> None:
        """
        Initializes the Streamlit session state with default values for keys that are not yet present.
        Existing values are never overwritten.
        """
        # Built per call so the mutable defaults ({} / previous_state dict) are fresh objects.
        defaults = {
            'previous_state': {'method': None, 'detection_model': None, 'confidence_level': None},
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'loading_in_progress': False,
            'load_button_clicked': False,
            'force_reload_button_clicked': False,
            'time_taken_to_load_model': None,
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value
        # Derived flags read the defaults above ('previous_state', 'kbvqa'), so they are computed last.
        if "settings_changed" not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in st.session_state:
            st.session_state['model_loaded'] = self.is_model_loaded

    def set_up_widgets(self) -> None:
        """
        Sets up user interface widgets for selecting the method, the detection model and the
        detection confidence, and shows the model settings table when its checkbox is ticked.
        """
        self.col1.selectbox("Choose a method:", ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Learning Visual Embeddings"], index=0, key='method', disabled=self.is_widget_disabled)
        detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model', disabled=self.is_widget_disabled)
        # The default slider value depends on the chosen detector.
        default_confidence = 0.2 if detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)
        show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
        if show_model_settings:
            # display_model_settings is a property: accessing it renders the table.
            _ = self.display_model_settings

    def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float, slider_key_name: str, col=None) -> float:
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique session-state key for the slider.
            col (streamlit column, optional): Column to place the slider in. Defaults to None (main area).

        Returns:
            The value currently selected on the slider.
        """
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step, key=slider_key_name, disabled=self.is_widget_disabled)

    @property
    def is_widget_disabled(self) -> bool:
        """
        Returns True while loading is in progress ('loading_in_progress' session flag), i.e. when
        interactive widgets should be disabled.
        """
        return st.session_state['loading_in_progress']

    def disable_widgets(self) -> None:
        """
        Disables widgets by setting the 'loading_in_progress' state to True.
        """
        st.session_state['loading_in_progress'] = True

    @property
    def settings_changed(self) -> bool:
        """
        Checks if any model setting (other than the confidence level) has changed compared to the previous state.

        Returns:
            bool: True if any setting has changed, False otherwise.
        """
        return self.has_state_changed()

    @property
    def confidance_change(self) -> bool:
        """
        Checks if the confidence level setting has changed compared to the previous state.

        Returns:
            bool: True if the confidence level has changed, False otherwise.
        """
        return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]

    # Correctly spelled alias; the original (misspelled) name is kept for existing callers.
    confidence_change = confidance_change

    def update_prev_state(self) -> None:
        """
        Updates the 'previous_state' in the session state with the current values of its tracked keys.
        """
        for key in st.session_state['previous_state']:
            st.session_state['previous_state'][key] = st.session_state[key]

    def load_model(self) -> None:
        """
        Loads the KBVQA model based on the chosen method and settings.

        - Frees GPU resources before loading.
        - Calls `prepare_kbvqa_model` to create the model.
        - Sets the detection confidence level on the model object.
        - Updates previous state with current settings for change detection.
        - Sets 'model_loaded' to True and the button label to "Reload Model".
        - On failure, shows an error and frees GPU resources.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Record the settings used for this load so later changes can be detected.
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")
            # Mirror force_reload_model: release whatever a partial load may have allocated.
            free_gpu_resources()

    def force_reload_model(self) -> None:
        """
        Forces a reload of all models, freeing up GPU resources.

        - Deletes the current KBVQA model from the session state (see `delete_model`).
        - Calls `prepare_kbvqa_model` with `force_reload=True` to reload the model.
        - Updates the detection confidence level on the model object and the previous state.
        - Sets 'model_loaded' to True.
        - On failure, shows an error message and frees GPU resources.
        """
        try:
            self.delete_model()
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model(force_reload=True)
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Record the settings used for this load so later changes can be detected.
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading model: {e}")
            free_gpu_resources()

    def delete_model(self) -> None:
        """
        Deletes the KBVQA model from the session state (only if `is_model_loaded` is True)
        and calls `free_gpu_resources` before and after the deletion.
        """
        free_gpu_resources()
        if self.is_model_loaded:
            try:
                del st.session_state['kbvqa']
            except KeyError:
                # Already removed; deletion is best-effort.
                pass
            finally:
                free_gpu_resources()

    def has_state_changed(self) -> bool:
        """
        Compares the current session state with the previous state to identify changes.
        'confidence_level' is skipped because it is tracked separately (`confidance_change`);
        keys not yet present in the session state are ignored.

        Returns:
            bool: True if any change is found, False otherwise.
        """
        for key, previous_value in st.session_state['previous_state'].items():
            if key == 'confidence_level':
                continue  # confidence_level tracker is separate
            if key in st.session_state and st.session_state[key] != previous_value:
                return True  # Found a change
        # Only reached after every tracked key has been checked.
        return False

    def get_model(self) -> KBVQA:
        """
        Retrieve the KBVQA model from the session state.

        Returns:
            KBVQA object: The stored KBVQA model, or None if absent.
        """
        return st.session_state.get('kbvqa', None)

    @property
    def is_model_loaded(self) -> bool:
        """
        Checks if the KBVQA model is present in the session state and reports all of its models loaded.

        Returns:
            bool: True if the model is loaded, False otherwise.
        """
        return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None and st.session_state.kbvqa.all_models_loaded

    def reload_detection_model(self) -> None:
        """
        Reloads only the detection model of the KBVQA model with updated settings.

        - Frees GPU resources before reloading.
        - Does nothing further unless the model is already loaded.
        - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
        - Updates the detection confidence level on the model object and the previous state.
        - Displays a success message and sets the button label to "Reload Model".
        - Shows an error message if anything fails.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded:
                # NOTE(review): the return value is discarded; presumably prepare_kbvqa_model updates
                # the model held in session state in place — confirm.
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.col1.success("Model reloaded with updated settings and ready for inference.")
                # Must be called: without it previous_state keeps the old settings.
                self.update_prev_state()
                st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key: str, image) -> None:
        """
        Creates an entry for a newly uploaded image in the `images_data` dictionary of the session state.
        Existing entries are left untouched.

        Each entry stores:
            - `image`: The original image data.
            - `caption`: Generated caption for the image.
            - `detected_objects_str`: String representation of detected objects.
            - `qa_history`: List of questions and answers related to the image.
            - `analysis_done`: Flag indicating if analysis is complete.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image) -> Tuple[str, str, object]:
        """
        Analyzes the image using the KBVQA model.

        - Works on a deep copy so the original image is not modified.
        - Calls the model's `get_caption` and `detect_objects` methods.
        - Frees GPU resources afterwards.

        Args:
            image (obj): The image data to analyze.

        Returns:
            tuple: The generated caption, the detected objects string, and the image with bounding boxes.
        """
        img = copy.deepcopy(image)
        caption = st.session_state['kbvqa'].get_caption(img)
        image_with_boxes, detected_objects_str = st.session_state['kbvqa'].detect_objects(img)
        free_gpu_resources()
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
        """
        Appends a (question, answer, prompt_length) entry to the QA history of a specific image.
        Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
            prompt_length (int): Length of the prompt used to produce the answer.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))

    def get_images_data(self) -> Dict:
        """
        Returns the dictionary containing processed image data from the session state.

        Returns:
            dict: The dictionary storing information about processed images.
        """
        return st.session_state['images_data']

    def update_image_data(self, image_key: str, caption: str, detected_objects_str: str, analysis_done: bool) -> None:
        """
        Updates the information stored for a specific image in the `images_data` dictionary.
        Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })

    @staticmethod
    def _target_size(size, new_width, new_height):
        """Return the (width, height) to resize to, deriving a missing dimension from the aspect ratio of `size`."""
        original_width, original_height = size
        if new_width is not None and new_height is None:
            new_height = int(original_height * (new_width / original_width))
        elif new_width is None and new_height is not None:
            new_width = int(original_width * (new_height / original_height))
        elif new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")
        return new_width, new_height

    def resize_image(self, image_input, new_width=None, new_height=None):
        """
        Resize an image. If only one of new_width / new_height is provided, the other is computed
        to maintain the aspect ratio; if both are provided, the image is resized to exactly those dimensions.
        The input image is not modified.

        Args:
            image_input (str or PIL.Image.Image): A file path or a PIL image to resize.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: The resized image.

        Raises:
            ValueError: If image_input is neither a path string nor a PIL image, or if both sizes are None.
        """
        if isinstance(image_input, str):
            # The context manager closes the file; resize() returns an independent image.
            with Image.open(image_input) as image:
                return image.resize(self._target_size(image.size, new_width, new_height))
        if isinstance(image_input, Image.Image):
            # resize() returns a new image and leaves image_input untouched, so no copy is needed.
            return image_input.resize(self._target_size(image_input.size, new_width, new_height))
        raise ValueError("image_input must be a file path or a PIL Image object")

    def display_message(self, message, message_type):
        """
        Displays `message` using the Streamlit element matching `message_type`
        ("warning", "text", "success" or "write"); any other type shows an error.
        """
        if message_type == "warning":
            st.warning(message)
        elif message_type == "text":
            st.text(message)
        elif message_type == "success":
            st.success(message)
        elif message_type == "write":
            st.write(message)
        else:
            st.error("Message type unknown")

    @property
    def display_model_settings(self):
        """
        Displays a table of current model settings in the third column.
        Rows follow the session-state iteration order.
        """
        self.col3.write("##### Current Model Settings:")
        data = [{'Setting': key, 'Value': str(value)} for key, value in st.session_state.items() if key in self._DISPLAYED_SETTING_KEYS]
        df = pd.DataFrame(data).reset_index(drop=True)
        return self.col3.write(df)

    def display_session_state(self, col):
        """
        Displays a table of the complete application state.

        Args:
            col: Streamlit container (e.g. a column) to render the table in.
        """
        col.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data).reset_index(drop=True)
        col.write(df)