m7mdal7aj committed
Commit 571fea7
1 Parent(s): 1d707c1

Update my_model/tabs/run_inference.py

Files changed (1)
  1. my_model/tabs/run_inference.py +95 -27
my_model/tabs/run_inference.py CHANGED
@@ -18,13 +18,12 @@ from my_model.config import inference_config as config
 class InferenceRunner(StateManager):
 
     """
-    InferenceRunner manages the user interface and interactions for a Streamlit-based
-    Knowledge-Based Visual Question Answering (KBVQA) application. It handles image uploads,
-    displays sample images, and facilitates the question-answering process using the KBVQA model.
-    it inherits the StateManager class.
+    Manages the user interface and interactions for a Streamlit-based Knowledge-Based Visual Question Answering (KBVQA) application.
+    This class handles image uploads, displays sample images, and facilitates the question-answering process using the KBVQA model.
+    Inherits from the StateManager class.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         """
         Initializes the InferenceRunner instance, setting up the necessary state.
         """
@@ -32,16 +31,17 @@ class InferenceRunner(StateManager):
         super().__init__()
 
 
-    def answer_question(self, caption, detected_objects_str, question):
+    def answer_question(self, caption: str, detected_objects_str: str, question: str) -> Tuple[str, int]:
         """
-        Generates an answer to a given question based on the image's caption and detected objects.
+        Generates an answer to a user's question based on the image's caption and detected objects.
 
         Args:
-            caption (str): The caption generated for the image.
-            detected_objects_str (str): String representation of objects detected in the image.
-            question (str): The user's question about the image.
+            caption (str): Caption generated for the image.
+            detected_objects_str (str): String representation of detected objects in the image.
+            question (str): User's question about the image.
+
         Returns:
-            str: The generated answer to the question.
+            tuple: A tuple containing the answer to the question and the prompt length.
         """
         free_gpu_resources()
         answer = st.session_state.kbvqa.generate_answer(question, caption, detected_objects_str)
@@ -50,7 +50,11 @@ class InferenceRunner(StateManager):
         return answer, prompt_length
 
 
-    def display_sample_images(self):
+    def display_sample_images(self) -> None:
+        """
+        Displays sample images as clickable thumbnails for the user to select.
+        """
+
         self.col1.write("Choose from sample images:")
         cols = self.col1.columns(len(config.SAMPLE_IMAGES))
         for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
@@ -61,18 +65,39 @@ class InferenceRunner(StateManager):
                 if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
                     self.process_new_image(sample_image_path, image)
 
-    def handle_image_upload(self):
+    def handle_image_upload(self) -> None:
+        """
+        Provides an image uploader widget for the user to upload their own images.
+        """
         uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
         if uploaded_image is not None:
             self.process_new_image(uploaded_image.name, Image.open(uploaded_image))
 
-    def display_image_and_analysis(self, image_key, image_data, nested_col21, nested_col22):
+    def display_image_and_analysis(self, image_key: str, image_data: dict, nested_col21, nested_col22) -> None:
+        """
+        Displays the uploaded or selected image and provides an option to analyze the image.
+
+        Args:
+            image_key (str): Unique key identifying the image.
+            image_data (dict): Data associated with the image.
+            nested_col21 (streamlit column): Column for displaying the image.
+            nested_col22 (streamlit column): Column for displaying the analysis button.
+        """
 
         image_for_display = self.resize_image(image_data['image'], 600)
         nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
         self.handle_analysis_button(image_key, image_data, nested_col22)
 
-    def handle_analysis_button(self, image_key, image_data, nested_col22):
+    def handle_analysis_button(self, image_key: str, image_data: dict, nested_col22) -> None:
+        """
+        Provides an 'Analyze Image' button and processes the image analysis upon click.
+
+        Args:
+            image_key (str): Unique key identifying the image.
+            image_data (dict): Data associated with the image.
+            nested_col22 (streamlit column): Column for displaying the analysis button.
+        """
+
         if not image_data['analysis_done'] or self.settings_changed or self.confidance_change:
             nested_col22.text("Please click 'Analyze Image'..")
             analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_{st.session_state.confidence_level}'
@@ -81,29 +106,63 @@ class InferenceRunner(StateManager):
                 self.update_image_data(image_key, caption, detected_objects_str, True)
                 st.session_state['loading_in_progress'] = False
 
-    def handle_question_answering(self, image_key, image_data, nested_col22):
-        # Initialize qa_history for each image
-        #qa_history = image_data.get('qa_history', [])
+    def handle_question_answering(self, image_key: str, image_data: dict, nested_col22) -> None:
+        """
+        Manages the question-answering interface for each image.
+
+        Args:
+            image_key (str): Unique key identifying the image.
+            image_data (dict): Data associated with the image.
+            nested_col22 (streamlit column): Column for displaying the question-answering interface.
+        """
+
         if image_data['analysis_done']:
             self.display_question_answering_interface(image_key, image_data, nested_col22)
 
         if self.settings_changed or self.confidance_change:
             nested_col22.warning("Confidence level changed, please click 'Analyze Image'.")
 
-    def display_question_answering_interface(self, image_key, image_data, nested_col22):
+    def display_question_answering_interface(self, image_key: str, image_data: Dict, nested_col22: st.columns) -> None:
+        """
+        Displays the interface for question answering, including sample questions and a custom question input.
+
+        Args:
+            image_key (str): Unique key identifying the image.
+            image_data (dict): Data associated with the image.
+            nested_col22 (streamlit column): The column where the interface will be displayed.
+        """
 
         sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
         selected_question = nested_col22.selectbox("Select a sample question or type your own:", ["Custom question..."] + sample_questions, key=f'sample_question_{image_key}')
-        custom_question = nested_col22.text_input("Or ask your own question:", key=f'custom_question_{image_key}')
-        question = custom_question if selected_question == "Custom question..." else selected_question
+
+        # Display custom question input only if "Custom question..." is selected
+        question = selected_question
+        if selected_question == "Custom question...":
+            custom_question = nested_col22.text_input("Or ask your own question:", key=f'custom_question_{image_key}')
+            question = custom_question
+
         self.process_question(image_key, question, image_data, nested_col22)
-
+
         qa_history = image_data.get('qa_history', [])
         for num, (q, a, p) in enumerate(qa_history):
             nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
 
+
 
-    def process_question(self, image_key, question, image_data, nested_col22):
+    def process_question(self, image_key: str, question: str, image_data: Dict, nested_col22: st.columns) -> None:
+        """
+        Processes the user's question, generates an answer, and updates the question-answer history.
+
+        Args:
+            image_key (str): Unique key identifying the image.
+            question (str): The question asked by the user.
+            image_data (Dict): Data associated with the image.
+            nested_col22 (streamlit column): The column where the answer will be displayed.
+
+        This method checks if the question is new or if settings have changed, and if so, generates an answer using the KBVQA model.
+        It then updates the question-answer history for the image.
+        """
+
         qa_history = image_data.get('qa_history', [])
         if question and (question not in [q for q, _, _ in qa_history] or self.settings_changed or self.confidance_change):
             if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
@@ -111,7 +170,14 @@ class InferenceRunner(StateManager):
                 self.add_to_qa_history(image_key, question, answer, prompt_length)
                 # nested_col22.text(f"Q: {question}\nA: {answer}\nPrompt Length: {prompt_length}")
 
-    def image_qa_app(self):
+    def image_qa_app(self) -> None:
+        """
+        Main application interface for image-based question answering.
+
+        This method orchestrates the display of sample images, handles image uploads, and facilitates the question-answering process.
+        It iterates through each image in the session state, displaying the image and providing interfaces for image analysis and question answering.
+        """
+
         self.display_sample_images()
         self.handle_image_upload()
         self.display_session_state()
@@ -126,9 +192,10 @@ class InferenceRunner(StateManager):
 
     def run_inference(self):
         """
-        Sets up the widgets and manages the inference process. This method handles model loading,
-        reloading, and the overall flow of the inference process based on user interactions.
+        Sets up widgets and manages the inference process, including model loading and reloading,
+        based on user interactions.
 
+        This method orchestrates the overall flow of the inference process.
         """
 
         self.set_up_widgets()
@@ -195,6 +262,7 @@ class InferenceRunner(StateManager):
         if self.is_model_loaded:
            free_gpu_resources()
            st.session_state['loading_in_progress'] = False
-            self.image_qa_app()
+
+            self.image_qa_app() # this is the main Q/A Application
 
 
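Beyond the added docstrings and type hints, the main behavioural change in this commit is that the custom-question text input is now rendered only when "Custom question..." is chosen in the selectbox, instead of always appearing alongside it. The snippet below is a minimal, self-contained sketch of that widget pattern so it can be tried outside the app; the file name, the placeholder SAMPLE_QUESTIONS list, and the widget keys are illustrative assumptions and not part of the repository, and the new Tuple/Dict annotations in the diff are assumed to be covered by a `from typing import Tuple, Dict` import elsewhere in the module.

# custom_question_sketch.py -- hypothetical stand-alone demo, not part of the repository.
# Run with: streamlit run custom_question_sketch.py
import streamlit as st

# Placeholder sample questions; the real app reads these from config.SAMPLE_QUESTIONS.
SAMPLE_QUESTIONS = ["What is shown in the image?", "How many objects were detected?"]

selected_question = st.selectbox(
    "Select a sample question or type your own:",
    ["Custom question..."] + SAMPLE_QUESTIONS,
    key="sample_question_demo",
)

# Mirror the commit's logic: default to the selected sample question and only
# render the free-text input when the "Custom question..." option is chosen.
question = selected_question
if selected_question == "Custom question...":
    question = st.text_input("Or ask your own question:", key="custom_question_demo")

if question and question != "Custom question...":
    st.write(f"Question that would be passed on for answering: {question}")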