File size: 6,311 Bytes
18d1852
e46d486
18d1852
49e9e5b
1d51bf5
18d1852
 
 
 
 
 
 
 
 
12f08dc
 
4170a5f
74d450c
b7c642c
f4bcc28
09ff02d
2152f1f
2957e90
1b71503
d6f8382
b84905e
3fdd1d7
bfdde42
 
 
 
 
 
18d1852
08655fb
 
bfdde42
 
 
 
3fdd1d7
d80fd56
bfdde42
 
 
 
 
 
3fdd1d7
d80fd56
f4bcc28
 
 
18d1852
3fdd1d7
18d1852
bfdde42
08655fb
d6a4897
5dd58f8
8050435
18d1852
3fdd1d7
18d1852
3689b26
18d1852
 
 
3fdd1d7
18d1852
d0e9fe6
753c201
 
 
cc825df
d9364fd
d6f8382
f72214b
 
6307a2e
12f08dc
6307a2e
f841c80
753c201
 
f72214b
 
3fdd1d7
f72214b
 
 
 
 
87da9a2
753c201
3fdd1d7
18d1852
 
 
 
3fdd1d7
18d1852
 
 
3fdd1d7
cd3678b
18d1852
 
 
1fc0405
9a8f19a
 
18d1852
 
 
 
d6f8382
18d1852
 
 
 
 
 
 
 
 
 
3fdd1d7
18d1852
 
d182243
18d1852
 
 
 
3fdd1d7
18d1852
 
 
 
3fdd1d7
18d1852
 
 
3fdd1d7
18d1852
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import pandas as pd
import copy
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model


class StateManager:
    """Manage Streamlit session state for the KB-VQA demo UI.

    The class holds no instance attributes: every value lives in
    ``st.session_state``. Any number of ``StateManager`` objects therefore
    read and write the same data.
    """

    # Widget-backed settings whose values are recorded when a model is loaded
    # and later compared by ``has_state_changed``.
    _TRACKED_SETTINGS = ('method', 'detection_model', 'confidence_level')

    # Session-state keys shown in the "Current Model Settings" table.
    _DISPLAYED_SETTINGS = (
        "confidence_level", 'detection_model', 'method',
        'kbvqa', 'previous_state', 'settings_changed',
    )

    def initialize_state(self):
        """Seed ``st.session_state`` with defaults for any missing keys.

        Existing values are never overwritten, so calling this on every
        rerun is safe.
        """
        # A fresh dict literal is built on each call, so the mutable
        # defaults ({}) are never shared between sessions.
        simple_defaults = {
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'previous_state': {},
        }
        for key, default in simple_defaults.items():
            if key not in st.session_state:
                st.session_state[key] = default
        # Seeded last because it compares against 'previous_state' above.
        if "settings_changed" not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed

    def set_up_widgets(self):
        """Render the model-selection widgets and the optional settings table.

        The widgets are bound to the session-state keys 'method',
        'detection_model' and 'confidence_level'.
        """
        # Three columns: controls on the left, an unused wide centre column,
        # and the settings table on the right.
        col1, _, col3 = st.columns([0.2, 0.6, 0.2])  # Adjust the ratio as needed

        with col1:
            st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
            st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
            # The slider's initial position depends on the selected detector.
            default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
            self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.1, slider_key_name='confidence_level')

        # Conditional display of model settings
        with col3:
            if st.checkbox("Show Model Settings", False):
                self.display_model_settings()

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
        """Render a slider bound to ``slider_key_name`` and return its value.

        Args:
            text: Slider label.
            min_value, max_value, value, step: Passed positionally to ``slider``.
            slider_key_name: Session-state key the slider is bound to.
            col: Optional container (e.g. a column) to render into; defaults
                to the top-level ``st`` module.

        Returns:
            Whatever the container's ``slider`` call returns.
        """
        target = st if col is None else col
        return target.slider(text, min_value, max_value, value, step, key=slider_key_name)

    @property
    def settings_changed(self):
        """Whether tracked settings differ from those the model was loaded with."""
        return self.has_state_changed()

    def display_model_settings(self):
        """Render a styled table of the keys listed in ``_DISPLAYED_SETTINGS``."""
        st.write("##### Current Model Settings:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in self._DISPLAYED_SETTINGS]
        df = pd.DataFrame(data)
        styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
        # ``st.table`` writes into the currently active container. The caller
        # invokes this inside ``with col3:``, so the table lands in that column
        # (``col3`` itself is not in scope here).
        st.table(styled_df)

    def display_session_state(self):
        """Render every session-state key/value pair as a plain table."""
        st.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data)
        st.table(df)

    def load_model(self):
        """Load the KBVQA model with the current widget settings.

        Stores the result of ``prepare_kbvqa_model()`` in
        ``st.session_state['kbvqa']`` and sets its ``detection_confidence``.
        Records the tracked settings in 'previous_state' and switches the
        button label to "Reload Model". Finally refreshes 'settings_changed'.
        Any exception is reported via ``st.error`` instead of being raised.
        """
        try:
            free_gpu_resources()
            try:
                st.session_state['kbvqa'] = prepare_kbvqa_model()
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level

                # Snapshot the settings this model was built with.
                st.session_state['previous_state'] = {key: st.session_state[key] for key in self._TRACKED_SETTINGS}

                st.session_state['button_label'] = "Reload Model"
                # Keep the stored flag in sync with the fresh snapshot.
                st.session_state['settings_changed'] = self.has_state_changed()
            finally:
                # Release GPU resources even when loading fails part-way.
                free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def has_state_changed(self):
        """Return True if any setting recorded in 'previous_state' has a new value.

        Returns False when nothing has been recorded yet. A recorded key that
        is absent from session state is compared as ``None``.
        """
        previous = st.session_state.get('previous_state', {})
        return any(st.session_state.get(key) != value for key, value in previous.items())

    def get_model(self):
        """Retrieve the KBVQA model from the session state (``None`` if absent)."""
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self):
        """Return True if a non-None 'kbvqa' entry exists in session state."""
        return st.session_state.get('kbvqa') is not None

    def reload_detection_model(self):
        """Re-prepare only the detection part of an already-loaded model.

        Does nothing beyond freeing GPU resources if no model is loaded.
        Errors are reported via ``st.error``.
        """
        try:
            free_gpu_resources()
            try:
                if self.is_model_loaded():
                    # NOTE(review): the return value is discarded. This assumes
                    # prepare_kbvqa_model updates the stored model in place;
                    # confirm against its implementation.
                    prepare_kbvqa_model(only_reload_detection_model=True)
                    st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                    st.success("Model reloaded with updated settings and ready for inference.")
            finally:
                # Release GPU resources even when reloading fails part-way.
                free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key, image, kbvqa):
        """Register ``image`` under ``image_key`` with empty analysis fields.

        Already-registered keys are left untouched. ``kbvqa`` is currently
        unused and kept only for interface compatibility.
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image, kbvqa):
        """Caption ``image`` and run object detection on it.

        Both model calls receive a deep copy, so the caller's image object is
        never handed to the model methods.

        Returns:
            ``(caption, detected_objects_str, image_with_boxes)``.
        """
        img = copy.deepcopy(image)
        st.text("Analyzing the image .. ")
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer):
        """Append ``(question, answer)`` to a registered image's Q&A history."""
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """Return the per-image data dict stored in session state."""
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """Overwrite analysis fields of a registered image; unknown keys are ignored."""
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })