File size: 12,218 Bytes
18d1852
e46d486
0a62769
18d1852
49e9e5b
1d51bf5
18d1852
11a17ef
6d4d5ac
18d1852
 
 
11a17ef
 
 
 
18d1852
 
 
 
 
12f08dc
 
4170a5f
74d450c
b7c642c
f4bcc28
09ff02d
2152f1f
2957e90
1b71503
cf8c147
 
 
bfdde42
11a17ef
 
 
 
18d1852
08655fb
 
11a17ef
 
 
 
3fdd1d7
d80fd56
bfdde42
 
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
bfdde42
 
 
 
3fdd1d7
d80fd56
f4bcc28
 
cf8c147
 
 
 
 
 
f4bcc28
18d1852
3fdd1d7
18d1852
cf8c147
 
 
 
 
11a17ef
08655fb
d6a4897
5dd58f8
11a17ef
18d1852
3fdd1d7
18d1852
cf8c147
 
 
 
3689b26
18d1852
 
 
3fdd1d7
18d1852
d0e9fe6
cf8c147
 
 
 
 
 
 
 
 
 
753c201
 
cc825df
d9364fd
d6f8382
f72214b
 
6307a2e
12f08dc
6d4d5ac
 
753c201
 
f72214b
 
3fdd1d7
f72214b
 
cf8c147
 
 
 
 
 
f72214b
 
 
87da9a2
753c201
3fdd1d7
18d1852
cf8c147
 
 
 
 
18d1852
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
18d1852
 
3fdd1d7
cd3678b
cf8c147
 
 
 
 
 
 
 
 
 
18d1852
 
 
1fc0405
9a8f19a
11a17ef
18d1852
 
 
 
d6f8382
18d1852
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d1852
 
 
 
 
 
 
 
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d1852
d182243
18d1852
 
 
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
 
 
18d1852
 
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
18d1852
 
a9a7f39
 
cf8c147
5313b6c
cf8c147
 
5313b6c
 
 
cf8c147
 
5313b6c
cf8c147
6fe12a5
5313b6c
cf8c147
5313b6c
 
 
 
 
 
 
 
436f3b4
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
 
 
 
18d1852
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import pandas as pd
import copy
from PIL import Image
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model




class StateManager:

    def __init__(self):
        """Creates the three-column page layout and keeps a handle to each column."""
        # Relative widths: two narrow side columns around a wide centre column.
        column_widths = [0.2, 0.6, 0.2]
        left, centre, right = st.columns(column_widths)
        self.col1 = left
        self.col2 = centre
        self.col3 = right

    def initialize_state(self):
        """
        Seeds the Streamlit session state with a default for every key that is not present yet.

        Keys that already exist are left untouched. 'settings_changed' is seeded last because
        evaluating `self.settings_changed` reads 'previous_state' from the session state.
        """
        session = st.session_state

        # Built from fresh literals on every call, so the two empty dicts are distinct objects.
        defaults = (
            ('images_data', {}),
            ('kbvqa', None),
            ('button_label', "Load Model"),
            ('previous_state', {}),
        )
        for name, initial_value in defaults:
            if name not in session:
                session[name] = initial_value

        # The property is only evaluated when the key is missing; the stored value is a
        # snapshot taken at that moment.
        if 'settings_changed' not in session:
            session['settings_changed'] = self.settings_changed
            


    def set_up_widgets(self):
        """
        Renders the settings widgets and, on request, the model settings table.

        - First column: method selector (key 'method'), detection model selector
          (key 'detection_model') and the confidence slider (key 'confidence_level').
        - Third column: a "Show Model Settings" checkbox; when ticked, the settings table is shown.
        """
        left = self.col1
        left.selectbox(
            "Choose a method:",
            ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
            index=0,
            key='method',
        )
        left.selectbox(
            "Choose a model for objects detection:",
            ["yolov5", "detic"],
            index=1,
            key='detection_model',
        )

        # The slider's initial value depends on the selected detection model.
        if st.session_state.detection_model == "yolov5":
            default_confidence = 0.2
        else:
            default_confidence = 0.4

        self.set_slider_value(
            text="Select minimum detection confidence level",
            min_value=0.1,
            max_value=0.9,
            value=default_confidence,
            step=0.1,
            slider_key_name='confidence_level',
            col=left,
        )

        # Conditional display of model settings
        if self.col3.checkbox("Show Model Settings", False):
            self.display_model_settings()


        
    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
        """
        Creates a slider widget, either in the main area or inside a given column.

        Args:
            text (str): Label displayed next to the slider.
            min_value (float): Lowest selectable value.
            max_value (float): Highest selectable value.
            value (float): Initial value of the slider.
            step (float): Increment between selectable values.
            slider_key_name (str): Key under which the widget is registered.
            col (streamlit.columns.Column, optional): Column to render the slider in.
                Defaults to None, which renders it in the main area.

        Returns:
            Whatever the underlying `slider(...)` call returns.
        """
        # `st` is only looked up when no column was supplied.
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step, key=slider_key_name)

        
    @property
    def settings_changed(self):
        """
        Checks if any model settings have changed compared to the previous state.
    
        Returns:
            bool: True if any setting has changed, False otherwise.
        """
        return self.has_state_changed()

    
    def display_model_settings(self):
        """
        Displays a Key/Value table of the model-related session state entries in the third column.

        Only a fixed set of session state keys is shown. Values are converted with `str()`,
        and the table is styled through a pandas Styler (black background, white text, bold header).
        """
        shown_keys = ("confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed')

        self.col3.write("##### Current Model Settings:")

        # Rows follow the iteration order of the session state, not the order of `shown_keys`.
        rows = []
        for name, current in st.session_state.items():
            if name in shown_keys:
                rows.append({'Key': name, 'Value': str(current)})

        cell_style = {'background-color': 'black', 'color': 'white', 'border-color': 'white'}
        header_style = [{'selector': 'th', 'props': [('background-color', 'black'), ('font-weight', 'bold')]}]

        frame = pd.DataFrame(rows)
        styled_table = frame.style.set_properties(**cell_style).set_table_styles(header_style)
        self.col3.table(styled_table)

    
    def display_session_state(self):
        """
        Displays every entry of the session state as a Key/Value table in the main area.

        Values are converted with `str()` before being shown.
        """
        st.write("Current Model:")

        rows = []
        for name, current in st.session_state.items():
            rows.append({'Key': name, 'Value': str(current)})

        st.table(pd.DataFrame(rows))
        

    def load_model(self):
        """
        Loads the KBVQA model and records the settings it was loaded with.

        Steps, all inside one try block:
        - Frees GPU resources before loading.
        - Stores the result of `prepare_kbvqa_model()` under 'kbvqa' in the session state.
        - Sets the detection confidence level on the stored model object.
        - Saves the current 'method', 'detection_model' and 'confidence_level' as
          'previous_state' for later change detection.
        - Changes the button label to "Reload Model" and frees GPU resources again.

        Any exception raised along the way is reported with `st.error` instead of propagating.
        """
        session = st.session_state
        tracked_settings = ('method', 'detection_model', 'confidence_level')

        try:
            free_gpu_resources()

            # NOTE(review): prepare_kbvqa_model() takes no arguments here; presumably it reads
            # the selected settings from the session state itself - confirm in my_model.KBVQA.
            session['kbvqa'] = prepare_kbvqa_model()
            session['kbvqa'].detection_confidence = session.confidence_level

            # Snapshot of the settings the model was loaded with.
            session['previous_state'] = {name: getattr(session, name) for name in tracked_settings}

            session['button_label'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

            
    # Function to check if any session state values have changed
    def has_state_changed(self):
        """
        Compares the current session state with the snapshot stored under 'previous_state'.

        Only the keys present in `st.session_state['previous_state']` are compared, so an
        empty snapshot reports no change. Each of those keys is read from the session state
        with plain indexing and is therefore expected to exist there.

        Returns:
            bool: True if any compared value differs from its snapshot, False otherwise.
        """
        previous = st.session_state['previous_state']
        # any() stops at the first differing value and yields False for an empty snapshot,
        # which is what the former loop with its break-less `else` clause computed.
        return any(st.session_state[key] != saved for key, saved in previous.items())

    
    def get_model(self):
        """
        Fetches the KBVQA model stored in the session state.

        Returns:
            The object stored under 'kbvqa', or None when that key is absent.
        """
        model = st.session_state.get('kbvqa')
        return model

    
    def is_model_loaded(self):
        """
        Tells whether a KBVQA model is currently held in the session state.

        Returns:
            bool: True when the 'kbvqa' key exists and its value is not None, False otherwise.
        """
        if 'kbvqa' not in st.session_state:
            return False
        return st.session_state['kbvqa'] is not None

    
    def reload_detection_model(self):
        """
        Reloads only the detection part of the KBVQA model with the current settings.

        - Frees GPU resources before and after the reload.
        - Does nothing else when no model is loaded.
        - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
        - Applies the current confidence level to the model held in the session state.
        - Shows a success message in the first column once the reload is done.

        Any exception raised along the way is reported with `st.error` instead of propagating.
        """
        try:
            free_gpu_resources()

            model_available = self.is_model_loaded()
            if model_available:
                # NOTE(review): the return value is discarded; presumably the call updates the
                # model held in the session state in place - confirm in my_model.KBVQA.
                prepare_kbvqa_model(only_reload_detection_model=True)

                confidence = st.session_state.confidence_level
                st.session_state['kbvqa'].detection_confidence = confidence

                self.col1.success("Model reloaded with updated settings and ready for inference.")

            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    
    def process_new_image(self, image_key, image, kbvqa):
        """
        Registers a newly uploaded image in the `images_data` dictionary of the session state.

        An image whose key is already registered is left untouched. A new entry starts with:
            - `image`: The image as passed in.
            - `caption`: Empty string.
            - `detected_objects_str`: Empty string.
            - `qa_history`: Empty list.
            - `analysis_done`: False.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.
            kbvqa (KBVQA object): The loaded KBVQA model (not used by this method).
        """
        images = st.session_state['images_data']
        if image_key in images:
            return

        images[image_key] = {
            'image': image,
            'caption': '',
            'detected_objects_str': '',
            'qa_history': [],
            'analysis_done': False
        }

    
    def analyze_image(self, image, kbvqa):
        """
        Runs captioning and object detection on a copy of the image.

        - Deep-copies the image so the caller's object is not handed to the model.
        - Shows an "Analyzing the image .. " message.
        - Calls `kbvqa.get_caption` and `kbvqa.detect_objects` on the copy.

        Args:
            image (obj): The image data to analyze.
            kbvqa (KBVQA object): The loaded KBVQA model.

        Returns:
            tuple: (caption, detected objects string, image with bounding boxes).
        """
        working_copy = copy.deepcopy(image)
        st.text("Analyzing the image .. ")

        caption = kbvqa.get_caption(working_copy)

        # detect_objects yields the annotated image first and the objects string second;
        # the returned tuple uses a different order.
        detection_result = kbvqa.detect_objects(working_copy)
        image_with_boxes, detected_objects_str = detection_result

        return caption, detected_objects_str, image_with_boxes

    
    def add_to_qa_history(self, image_key, question, answer):
        """
        Appends a (question, answer) pair to the QA history of a registered image.

        Nothing happens when the image key is not present in `images_data`.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
        """
        images = st.session_state['images_data']
        if image_key not in images:
            return

        qa_pair = (question, answer)
        images[image_key]['qa_history'].append(qa_pair)

    
    def get_images_data(self):
        """
        Gives access to the per-image records kept in the session state.

        Returns:
            dict: The object stored under 'images_data' (the stored object itself, not a copy).
        """
        images = st.session_state['images_data']
        return images

    
    def resize_image(self, image_input, new_width, new_height):
        """
        Resize an image given either a file path or a PIL Image object.

        Args:
            image_input (str, os.PathLike or Image.Image): The image file path or a PIL Image object.
            new_width (int): The new width of the image.
            new_height (int): The new height of the image.

        Returns:
            Image.Image: The resized PIL Image object (a new image; the input is not modified).

        Raises:
            ValueError: If `image_input` is neither a file path nor a PIL Image object.
        """
        import os  # local import: only needed for the os.PathLike check below

        target_size = (new_width, new_height)

        if isinstance(image_input, (str, os.PathLike)):
            # Open the image from a file path; the context manager closes the underlying
            # file once the resized copy has been produced.
            with Image.open(image_input) as opened_image:
                return opened_image.resize(target_size)

        if isinstance(image_input, Image.Image):
            # A caller-owned image is used directly and deliberately left open.
            return image_input.resize(target_size)

        raise ValueError("image_input must be a file path or a PIL Image object")
    
    
    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """
        Overwrites the analysis fields of a registered image in the `images_data` dictionary.

        Nothing happens when the image key is not present in `images_data`.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.
        """
        images = st.session_state['images_data']
        if image_key not in images:
            return

        record = images[image_key]
        record['caption'] = caption
        record['detected_objects_str'] = detected_objects_str
        record['analysis_done'] = analysis_done