m7mdal7aj commited on
Commit
b5978b1
1 Parent(s): 990f892

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +16 -5
my_model/tabs/run_inference.py CHANGED
@@ -61,6 +61,7 @@ class InferenceRunner(StateManager):
61
 
62
 
63
  # Display sample images as clickable thumbnails
 
64
  self.col1.write("Choose from sample images:")
65
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
66
  for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
@@ -80,17 +81,20 @@ class InferenceRunner(StateManager):
80
  self.display_session_state()
81
  with self.col2:
82
  for image_key, image_data in self.get_images_data().items():
 
83
  with st.container():
84
  nested_col21, nested_col22 = st.columns([0.65, 0.35])
85
  image_for_display = self.resize_image(image_data['image'], 600)
86
  nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
87
  if not image_data['analysis_done'] or self.settings_changed or self.confidance_change: # if not done analysis before or even done but settings changed, then we need to analyze again
88
-
89
  nested_col22.text("Please click 'Analyze Image'..")
90
  free_gpu_resources()
91
  with nested_col22:
 
92
  analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_{st.session_state.confidence_level}' # unique key for each click
93
  st.write(analyze_button_key)
 
94
  if st.button('Analyze Image', key=analyze_button_key, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
95
  st.text("AAAAAAAAAAAAAAAAAAAAA")
96
  caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
@@ -99,16 +103,18 @@ class InferenceRunner(StateManager):
99
  self.update_image_data(image_key, caption, detected_objects_str, True)
100
  st.session_state['loading_in_progress'] = False
101
  free_gpu_resources()
 
102
 
103
  # Initialize qa_history for each image
104
  qa_history = image_data.get('qa_history', [])
105
-
106
  if image_data['analysis_done']:
 
107
  free_gpu_resources()
108
  if self.confidance_change:
109
-
110
  nested_col22.warning("Confidence level changed, please click analyze again.")
111
- self.update_prev_state()
112
  st.session_state['loading_in_progress'] = False
113
  sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
114
  selected_question = nested_col22.selectbox(
@@ -131,6 +137,7 @@ class InferenceRunner(StateManager):
131
  else:
132
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
133
  free_gpu_resources()
 
134
  answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
135
  st.session_state['loading_in_progress'] = False
136
  self.add_to_qa_history(image_key, question, answer, prompt_length)
@@ -139,6 +146,8 @@ class InferenceRunner(StateManager):
139
  for num, (q, a, p) in enumerate(qa_history):
140
  nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
141
  free_gpu_resources()
 
 
142
 
143
 
144
  def run_inference(self):
@@ -149,6 +158,7 @@ class InferenceRunner(StateManager):
149
  """
150
 
151
  self.set_up_widgets()
 
152
  load_fine_tuned_model = False
153
  fine_tuned_model_already_loaded = False
154
  reload_detection_model = False
@@ -158,9 +168,10 @@ class InferenceRunner(StateManager):
158
  if self.is_model_loaded and self.settings_changed:
159
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
160
  self.update_prev_state()
 
161
 
162
  st.session_state.button_label = "Reload Model" if self.is_model_loaded and self.settings_changed else "Load Model"
163
-
164
  with self.col1:
165
  if st.session_state.method == "Fine-Tuned Model":
166
  with st.container():
 
61
 
62
 
63
  # Display sample images as clickable thumbnails
64
+ st.write("D")
65
  self.col1.write("Choose from sample images:")
66
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
67
  for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
 
81
  self.display_session_state()
82
  with self.col2:
83
  for image_key, image_data in self.get_images_data().items():
84
+ st.write("E")
85
  with st.container():
86
  nested_col21, nested_col22 = st.columns([0.65, 0.35])
87
  image_for_display = self.resize_image(image_data['image'], 600)
88
  nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
89
  if not image_data['analysis_done'] or self.settings_changed or self.confidance_change: # if not done analysis before or even done but settings changed, then we need to analyze again
90
+ st.write("F")
91
  nested_col22.text("Please click 'Analyze Image'..")
92
  free_gpu_resources()
93
  with nested_col22:
94
+ st.write("G")
95
  analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_{st.session_state.confidence_level}' # unique key for each click
96
  st.write(analyze_button_key)
97
+ st.write("H")
98
  if st.button('Analyze Image', key=analyze_button_key, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
99
  st.text("AAAAAAAAAAAAAAAAAAAAA")
100
  caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
 
103
  self.update_image_data(image_key, caption, detected_objects_str, True)
104
  st.session_state['loading_in_progress'] = False
105
  free_gpu_resources()
106
+ st.write("II")
107
 
108
  # Initialize qa_history for each image
109
  qa_history = image_data.get('qa_history', [])
110
+ st.write("J")
111
  if image_data['analysis_done']:
112
+ st.write("K")
113
  free_gpu_resources()
114
  if self.confidance_change:
115
+ st.write("L")
116
  nested_col22.warning("Confidence level changed, please click analyze again.")
117
+
118
  st.session_state['loading_in_progress'] = False
119
  sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
120
  selected_question = nested_col22.selectbox(
 
137
  else:
138
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
139
  free_gpu_resources()
140
+ st.write("M")
141
  answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
142
  st.session_state['loading_in_progress'] = False
143
  self.add_to_qa_history(image_key, question, answer, prompt_length)
 
146
  for num, (q, a, p) in enumerate(qa_history):
147
  nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
148
  free_gpu_resources()
149
+ st.write("N")
150
+ self.update_prev_state()
151
 
152
 
153
  def run_inference(self):
 
158
  """
159
 
160
  self.set_up_widgets()
161
+ st.write("A")
162
  load_fine_tuned_model = False
163
  fine_tuned_model_already_loaded = False
164
  reload_detection_model = False
 
168
  if self.is_model_loaded and self.settings_changed:
169
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
170
  self.update_prev_state()
171
+ st.write("B")
172
 
173
  st.session_state.button_label = "Reload Model" if self.is_model_loaded and self.settings_changed else "Load Model"
174
+ st.write("C")
175
  with self.col1:
176
  if st.session_state.method == "Fine-Tuned Model":
177
  with st.container():