m7mdal7aj commited on
Commit
40d77a8
·
verified ·
1 Parent(s): 9d9c49e

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +30 -20
my_model/tabs/run_inference.py CHANGED
@@ -12,6 +12,7 @@ from my_model.object_detection import detect_and_draw_objects
12
  from my_model.captioner.image_captioning import get_caption
13
  from my_model.utilities.gen_utilities import free_gpu_resources
14
  from my_model.state_manager import StateManager
 
15
 
16
 
17
  class InferenceRunner(StateManager):
@@ -19,8 +20,8 @@ class InferenceRunner(StateManager):
19
 
20
  super().__init__()
21
  self.initialize_state()
22
- self.sample_images = [
23
- "Files/sample1.jpg", "Files/sample2.jpg", "Files/sample3.jpg"]
24
 
25
  def answer_question(self, caption, detected_objects_str, question, model):
26
  free_gpu_resources()
@@ -76,16 +77,32 @@ class InferenceRunner(StateManager):
76
 
77
  if image_data['analysis_done']:
78
  st.session_state['loading_in_progress'] = False
79
- question = nested_col22.text_input(f"Ask a question about this image ({image_key[-11:]}):", key=f'question_{image_key}')
80
- if nested_col22.button('Get Answer', key=f'answer_{image_key}', on_click=self.disable_widgets, disabled=self.is_widget_disabled):
81
-
82
- if question not in [q for q, _ in qa_history]:
83
- answer = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
84
- st.session_state['loading_in_progress'] = False
85
- self.add_to_qa_history(image_key, question, answer)
86
- else: nested_col22.warning("This questions has already been answered.")
87
-
88
- st.session_state['loading_in_progress'] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Display Q&A history for each image
91
  for num, (q, a) in enumerate(qa_history):
@@ -95,16 +112,9 @@ class InferenceRunner(StateManager):
95
  pass
96
 
97
 
98
-
99
-
100
-
101
-
102
-
103
-
104
 
105
  def run_inference(self):
106
- if st.session_state['loading_in_progress']:
107
- st.rerun()
108
  self.set_up_widgets()
109
 
110
  load_fine_tuned_model = False
 
12
  from my_model.captioner.image_captioning import get_caption
13
  from my_model.utilities.gen_utilities import free_gpu_resources
14
  from my_model.state_manager import StateManager
15
+ from my_model.config import inference_config as config
16
 
17
 
18
  class InferenceRunner(StateManager):
 
20
 
21
  super().__init__()
22
  self.initialize_state()
23
+ self.sample_images = config.SAMPLE_IMAGES
24
+
25
 
26
  def answer_question(self, caption, detected_objects_str, question, model):
27
  free_gpu_resources()
 
77
 
78
  if image_data['analysis_done']:
79
  st.session_state['loading_in_progress'] = False
80
+ sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
81
+ selected_question = nested_col22.selectbox(
82
+ "Select a sample question or type your own:",
83
+ ["Custom question..."] + sample_questions,
84
+ key=f'sample_question_{image_key}')
85
+
86
+ # Text input for custom question
87
+ custom_question = nested_col22.text_input(
88
+ "Or ask your own question:",
89
+ key=f'custom_question_{image_key}')
90
+ # Use the selected sample question or the custom question
91
+ question = custom_question if selected_question == "Custom question..." else selected_question
92
+
93
+
94
+ if not question:
95
+ nested_col22.warning("Please select or enter a question.")
96
+ else:
97
+ if question in [q for q, _ in qa_history]:
98
+ nested_col22.warning("This question has already been answered.")
99
+ else:
100
+ if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
101
+
102
+ answer = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
103
+ st.session_state['loading_in_progress'] = False
104
+ self.add_to_qa_history(image_key, question, answer)
105
+
106
 
107
  # Display Q&A history for each image
108
  for num, (q, a) in enumerate(qa_history):
 
112
  pass
113
 
114
 
 
 
 
 
 
 
115
 
116
  def run_inference(self):
117
+
 
118
  self.set_up_widgets()
119
 
120
  load_fine_tuned_model = False