File size: 6,316 Bytes
18d1852
e46d486
18d1852
49e9e5b
1d51bf5
18d1852
11a17ef
6d4d5ac
18d1852
 
 
11a17ef
 
 
 
18d1852
 
 
 
 
12f08dc
 
4170a5f
74d450c
b7c642c
f4bcc28
09ff02d
2152f1f
2957e90
1b71503
bfdde42
11a17ef
 
 
 
 
18d1852
08655fb
 
11a17ef
 
 
 
3fdd1d7
d80fd56
bfdde42
 
 
 
 
 
3fdd1d7
d80fd56
f4bcc28
 
 
18d1852
3fdd1d7
18d1852
11a17ef
08655fb
d6a4897
5dd58f8
11a17ef
18d1852
3fdd1d7
18d1852
3689b26
18d1852
 
 
3fdd1d7
18d1852
d0e9fe6
753c201
 
 
cc825df
d9364fd
d6f8382
f72214b
 
6307a2e
12f08dc
6d4d5ac
 
753c201
 
f72214b
 
3fdd1d7
f72214b
 
 
 
 
87da9a2
753c201
3fdd1d7
18d1852
 
 
 
3fdd1d7
18d1852
 
 
3fdd1d7
cd3678b
18d1852
 
 
1fc0405
9a8f19a
11a17ef
18d1852
 
 
 
d6f8382
18d1852
 
 
 
 
 
 
 
 
 
3fdd1d7
18d1852
 
d182243
18d1852
 
 
 
3fdd1d7
18d1852
 
 
 
3fdd1d7
18d1852
 
 
3fdd1d7
18d1852
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import pandas as pd
import copy
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model




class StateManager:
    """Coordinate Streamlit widgets and session state for the KBVQA model.

    The class lays out three page columns, keeps the model, the user's
    settings and per-image data in ``st.session_state``, and offers helpers
    to load/reload the model and to record image analysis results.
    """

    # Session-state keys listed in the "Current Model Settings" table.
    _DISPLAYED_SETTINGS = (
        "confidence_level",
        'detection_model',
        'method',
        'kbvqa',
        'previous_state',
        'settings_changed',
    )

    def __init__(self):
        """Create three page columns with relative widths 0.2, 0.6 and 0.2."""
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self):
        """Seed ``st.session_state`` with defaults for keys not yet present."""
        defaults = {
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'previous_state': {},
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

        if "settings_changed" not in st.session_state:
            # NOTE: this stores a one-off boolean snapshot taken at
            # initialisation, not a live value; read the ``settings_changed``
            # property to get the current answer.
            st.session_state['settings_changed'] = self.settings_changed

    def set_up_widgets(self):
        """Render the settings widgets.

        Draws the method and detection-model select boxes and the confidence
        slider in the first column, and a checkbox in the third column that
        toggles the model-settings table.
        """
        self.col1.selectbox(
            "Choose a method:",
            ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
            index=0,
            key='method',
        )
        self.col1.selectbox(
            "Choose a model for objects detection:",
            ["yolov5", "detic"],
            index=1,
            key='detection_model',
        )

        # The slider's initial value depends on the selected detector.
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(
            text="Select minimum detection confidence level",
            min_value=0.1,
            max_value=0.9,
            value=default_confidence,
            step=0.1,
            slider_key_name='confidence_level',
            col=self.col1,
        )

        # Conditional display of model settings.
        if self.col3.checkbox("Show Model Settings", False):
            self.display_model_settings()

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
        """Render a slider and return its current value.

        Args:
            text: Label shown next to the slider.
            min_value: Lowest selectable value.
            max_value: Highest selectable value.
            value: Initial value.
            step: Increment between selectable values.
            slider_key_name: Session-state key under which the value is kept.
            col: Container whose ``slider`` method is used; when ``None`` the
                slider is drawn with ``st.slider``.

        Returns:
            Whatever the underlying ``slider`` call returns.
        """
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step, key=slider_key_name)

    @property
    def settings_changed(self):
        """Whether current settings differ from those saved in ``previous_state``."""
        return self.has_state_changed()

    def display_model_settings(self):
        """Show a styled table of the model-related session-state entries in column 3."""
        self.col3.write("##### Current Model Settings:")
        rows = [
            {'Key': key, 'Value': str(value)}
            for key, value in st.session_state.items()
            if key in self._DISPLAYED_SETTINGS
        ]
        df = pd.DataFrame(rows)
        cell_style = {'background-color': 'black', 'color': 'white', 'border-color': 'white'}
        header_style = [
            {'selector': 'th', 'props': [('background-color', 'black'), ('font-weight', 'bold')]}
        ]
        styled_df = df.style.set_properties(**cell_style).set_table_styles(header_style)
        self.col3.table(styled_df)

    def display_session_state(self):
        """Show every session-state entry (stringified) as a table."""
        st.write("Current Model:")
        rows = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        st.table(pd.DataFrame(rows))

    def load_model(self):
        """Load the KBVQA model with specified settings.

        Stores the model returned by ``prepare_kbvqa_model`` in the session
        state, applies the selected confidence level, snapshots the current
        settings into ``previous_state`` and relabels the button.  GPU
        resources are freed before and after loading.  Any exception is
        reported through ``st.error`` instead of being raised.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level

            # Remember the settings the model was loaded with so later
            # changes can be detected by ``has_state_changed``.
            st.session_state['previous_state'] = {
                'method': st.session_state.method,
                'detection_model': st.session_state.detection_model,
                'confidence_level': st.session_state.confidence_level,
            }

            st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def has_state_changed(self):
        """Check whether any saved setting differs from its current value.

        Returns:
            ``True`` if a key recorded in ``previous_state`` is missing from
            the session state or holds a different value; ``False`` otherwise
            (including when ``previous_state`` is empty).
        """
        previous = st.session_state['previous_state']
        # A key that has disappeared from the session state counts as a
        # change rather than raising ``KeyError``.
        return any(
            key not in st.session_state or st.session_state[key] != saved
            for key, saved in previous.items()
        )

    def get_model(self):
        """Retrieve the KBVQA model from the session state (``None`` if absent)."""
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self):
        """Return ``True`` when a non-``None`` model is stored in the session state."""
        return self.get_model() is not None

    def reload_detection_model(self):
        """Reload only the detection model and apply the current confidence level.

        Does nothing besides freeing GPU resources when no model is loaded.
        Any exception is reported through ``st.error`` instead of being raised.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded():
                # The return value is not used here; presumably the call
                # updates the model held in the session state in place
                # (TODO confirm against ``prepare_kbvqa_model``).
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.col1.success("Model reloaded with updated settings and ready for inference.")
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key, image, kbvqa):
        """Register an image in ``images_data`` if its key is not already there.

        Args:
            image_key: Key under which the image record is stored.
            image: The image object to keep in the record.
            kbvqa: Unused; kept so existing callers keep working.
        """
        images = st.session_state['images_data']
        if image_key not in images:
            images[image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False,
            }

    def analyze_image(self, image, kbvqa):
        """Caption an image and detect its objects using the given model.

        Args:
            image: Image to analyse; a deep copy is handed to the model so the
                caller's object is not the one passed on.
            kbvqa: Object providing ``get_caption`` and ``detect_objects``.

        Returns:
            A ``(caption, detected_objects_str, image_with_boxes)`` tuple.
        """
        img = copy.deepcopy(image)
        st.text("Analyzing the image .. ")
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer):
        """Append a ``(question, answer)`` pair to a known image's history.

        Unknown image keys are ignored.
        """
        images = st.session_state['images_data']
        if image_key in images:
            images[image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """Return the ``images_data`` mapping held in the session state."""
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """Store analysis results on a known image record.

        Unknown image keys are ignored.
        """
        images = st.session_state['images_data']
        if image_key in images:
            images[image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done,
            })