File size: 5,686 Bytes
18d1852
e46d486
18d1852
753c201
1d51bf5
18d1852
 
 
 
 
 
 
 
 
12f08dc
 
4170a5f
74d450c
b7c642c
f4bcc28
09ff02d
2152f1f
2957e90
1b71503
 
22d0357
e4eee9a
f72214b
18d1852
d80fd56
 
 
 
f4bcc28
 
 
18d1852
 
d182243
d0a09f4
d6a4897
5dd58f8
d182243
18d1852
 
3689b26
18d1852
 
 
 
d0e9fe6
753c201
 
 
cc825df
d9364fd
d0e9fe6
f72214b
 
6307a2e
12f08dc
6307a2e
f841c80
753c201
 
f72214b
 
 
 
 
 
 
87da9a2
753c201
18d1852
 
 
 
 
 
 
cd3678b
18d1852
 
 
1fc0405
18d1852
105e89e
18d1852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d182243
18d1852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import pandas as pd
import copy
import streamlit as st
from my_model.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model


class StateManager:
    """Manage the Streamlit session state used by the KBVQA application.

    All data lives in ``st.session_state``; the class itself holds no
    instance attributes. It groups helpers for initializing the state,
    rendering the settings widgets, loading/reloading the model and keeping
    per-image analysis data and question/answer history.
    """

    # Session-state keys listed by ``display_model_settings``.
    _MODEL_SETTING_KEYS = ('confidence_level', 'detection_model', 'method', 'kbvqa')

    def initialize_state(self) -> None:
        """Create the session-state entries this class relies on, if missing.

        Existing entries are left untouched, so calling this more than once
        does not reset anything.
        """
        if 'images_data' not in st.session_state:
            st.session_state['images_data'] = {}
        if 'kbvqa' not in st.session_state:
            st.session_state['kbvqa'] = None
        if 'button_label' not in st.session_state:
            st.session_state['button_label'] = "Load Model"
        if 'previous_state' not in st.session_state:
            st.session_state['previous_state'] = {}
        if 'settings_changed' not in st.session_state:
            # NOTE: reading the property evaluates it immediately, so this
            # stores a boolean snapshot taken now, not a live reference.
            # Use ``self.settings_changed`` for an up-to-date value.
            st.session_state['settings_changed'] = self.settings_changed

    def set_up_widgets(self) -> None:
        """Render the method, detection-model and confidence-level widgets.

        Each widget stores its value in the session state under its key
        ('method', 'detection_model' and 'confidence_level').
        """
        st.selectbox(
            "Choose a method:",
            ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
            index=0,
            key='method',
        )
        st.selectbox(
            "Choose a model for objects detection:",
            ["yolov5", "detic"],
            index=1,
            key='detection_model',
        )
        # The slider's default depends on the detection model chosen above.
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(
            text="Select minimum detection confidence level",
            min_value=0.1,
            max_value=0.9,
            value=default_confidence,
            step=0.1,
            slider_key_name='confidence_level',
        )

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
        """Render a slider stored under ``slider_key_name`` and return its value.

        Args:
            text: Label displayed next to the slider.
            min_value: Lowest selectable value.
            max_value: Highest selectable value.
            value: Value the slider is created with.
            step: Increment between selectable values.
            slider_key_name: Session-state key of the slider widget.

        Returns:
            Whatever ``st.slider`` returns for the rendered widget.
        """
        return st.slider(text, min_value, max_value, value, step, key=slider_key_name)

    @property
    def settings_changed(self) -> bool:
        """Whether a tracked setting differs from the last loaded snapshot."""
        return self.has_state_changed()

    def display_model_settings(self) -> None:
        """Show the current model-related session entries as a styled table."""
        st.write("#### Current Model Settings:")
        data = [
            {'Key': key, 'Value': str(value)}
            for key, value in st.session_state.items()
            if key in self._MODEL_SETTING_KEYS
        ]
        df = pd.DataFrame(data)
        styled_df = df.style.set_properties(
            **{'background-color': 'black', 'color': 'white', 'border-color': 'white'}
        ).set_table_styles(
            [{'selector': 'th', 'props': [('background-color', 'black'), ('font-weight', 'bold')]}]
        )
        st.table(styled_df)

    def display_session_state(self) -> None:
        """Show every session-state entry (values stringified) as a table."""
        st.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data)
        st.table(df)

    def load_model(self) -> None:
        """Load the KBVQA model and record the settings it was loaded with.

        On success the model is stored under 'kbvqa', the current 'method',
        'detection_model' and 'confidence_level' values are snapshotted into
        'previous_state', and the button label becomes "Reload Model".
        Any exception is reported through ``st.error`` instead of propagating.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Snapshot the settings so has_state_changed() can detect later edits.
            st.session_state['previous_state'] = {
                'method': st.session_state.method,
                'detection_model': st.session_state.detection_model,
                'confidence_level': st.session_state.confidence_level,
            }

            st.session_state['button_label'] = "Reload Model"
            # NOTE(review): looks like leftover debug output - confirm before removing.
            st.text('button changed')
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def has_state_changed(self) -> bool:
        """Return True if any tracked setting differs from 'previous_state'.

        Only the keys present in 'previous_state' are compared. A tracked key
        that is missing from the session state counts as a change instead of
        raising ``KeyError``. With an empty or missing snapshot the result is
        False.
        """
        previous_state = st.session_state.get('previous_state') or {}
        return any(
            st.session_state.get(key) != previous_value
            for key, previous_value in previous_state.items()
        )

    def get_model(self):
        """Retrieve the KBVQA model from the session state (None if absent)."""
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self) -> bool:
        """Return True if a model object is stored under 'kbvqa'."""
        return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None

    def reload_detection_model(self) -> None:
        """Reload only the detection model and re-apply the confidence level.

        Does nothing besides freeing GPU resources when no model is loaded.
        Any exception is reported through ``st.error`` instead of propagating.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded():
                prepare_kbvqa_model(only_reload_detection_model=True)
                # Read the slider value from the session state; the bare name
                # `confidence_level` used previously was undefined here.
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key, image, kbvqa) -> None:
        """Register ``image`` under ``image_key`` unless it is already known.

        A new entry starts with an empty caption, empty detected-objects
        string, empty QA history and ``analysis_done`` set to False.

        Args:
            image_key: Key identifying the image inside 'images_data'.
            image: The image object to store.
            kbvqa: Not used by this method; kept for interface compatibility.
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False,
            }

    def analyze_image(self, image, kbvqa):
        """Caption ``image`` and detect its objects using ``kbvqa``.

        The model works on a deep copy, so the caller's ``image`` object is
        not the one handed to ``kbvqa``.

        Args:
            image: The image to analyze.
            kbvqa: Object providing ``get_caption`` and ``detect_objects``.

        Returns:
            A tuple ``(caption, detected_objects_str, image_with_boxes)``.
        """
        img = copy.deepcopy(image)
        st.text("Analyzing the image .. ")
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer) -> None:
        """Append ``(question, answer)`` to the image's QA history.

        Silently does nothing when ``image_key`` is not registered.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """Return the 'images_data' mapping stored in the session state."""
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done) -> None:
        """Store analysis results for an already registered image.

        Silently does nothing when ``image_key`` is not registered.

        Args:
            image_key: Key identifying the image inside 'images_data'.
            caption: Caption to store for the image.
            detected_objects_str: Detected-objects description to store.
            analysis_done: Flag to store under 'analysis_done'.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done,
            })