m7mdal7aj commited on
Commit
0f0c882
·
verified ·
1 Parent(s): 4382e6a

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +41 -71
my_model/tabs/run_inference.py CHANGED
@@ -51,7 +51,6 @@ class InferenceRunner(StateManager):
51
 
52
 
53
  def display_sample_images(self):
54
- # Display sample images as clickable thumbnails
55
  self.col1.write("Choose from sample images:")
56
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
57
  for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
@@ -62,83 +61,54 @@ class InferenceRunner(StateManager):
62
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
63
  self.process_new_image(sample_image_path, image)
64
 
65
- def image_qa_app(self):
66
- """
67
- Main application interface for image-based question answering. It handles displaying
68
- of sample images, uploading of new images, and facilitates the QA process.
69
- """
70
-
71
- self.display_sample_images()
72
- # Image uploader
73
  uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
74
  if uploaded_image is not None:
75
  self.process_new_image(uploaded_image.name, Image.open(uploaded_image))
76
 
77
- # Display and interact with each uploaded/selected image
78
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  with self.col2:
80
  for image_key, image_data in self.get_images_data().items():
81
-
82
  with st.container():
83
- nested_col21, nested_col22 = st.columns([0.65, 0.35])
84
- image_for_display = self.resize_image(image_data['image'], 600)
85
- nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
86
-
87
- if not image_data['analysis_done'] or self.settings_changed or self.confidance_change: # if not done analysis before or even done but settings changed, then we need to analyze again
88
-
89
- nested_col22.text("Please click 'Analyze Image'..")
90
- free_gpu_resources()
91
- with nested_col22:
92
-
93
- analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_{st.session_state.confidence_level}' # unique key for each click
94
-
95
- if st.button('Analyze Image', key=analyze_button_key, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
96
- caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
97
-
98
- self.update_image_data(image_key, caption, detected_objects_str, True)
99
- st.session_state['loading_in_progress'] = False
100
- free_gpu_resources()
101
-
102
-
103
- # Initialize qa_history for each image
104
- qa_history = image_data.get('qa_history', [])
105
-
106
- if image_data['analysis_done']:
107
-
108
- free_gpu_resources()
109
- if self.confidance_change:
110
-
111
- nested_col22.warning("If you change the Confidence level, please click analyze again.")
112
-
113
- st.session_state['loading_in_progress'] = False
114
- sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
115
- selected_question = nested_col22.selectbox(
116
- "Select a sample question or type your own:",
117
- ["Custom question..."] + sample_questions,
118
- key=f'sample_question_{image_key}')
119
-
120
- # Text input for custom question
121
- custom_question = nested_col22.text_input(
122
- "Or ask your own question:",
123
- key=f'custom_question_{image_key}')
124
- # Use the selected sample question or the custom question
125
- question = custom_question if selected_question == "Custom question..." else selected_question
126
-
127
- if question in [q for q, _, _ in qa_history] and not self.settings_changed and not self.confidance_change:
128
- nested_col22.warning("This question has already been answered.")
129
- else:
130
- if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
131
- free_gpu_resources()
132
-
133
- answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
134
- st.session_state['loading_in_progress'] = False
135
- self.add_to_qa_history(image_key, question, answer, prompt_length)
136
-
137
- # Display Q&A history and prompts lengths for each image
138
- for num, (q, a, p) in enumerate(qa_history):
139
- nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
140
- free_gpu_resources()
141
-
142
 
143
 
144
 
 
51
 
52
 
53
  def display_sample_images(self):
 
54
  self.col1.write("Choose from sample images:")
55
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
56
  for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
 
61
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
62
  self.process_new_image(sample_image_path, image)
63
 
64
+ def handle_image_upload(self):
 
 
 
 
 
 
 
65
  uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
66
  if uploaded_image is not None:
67
  self.process_new_image(uploaded_image.name, Image.open(uploaded_image))
68
 
69
+ def display_image_and_analysis(self, image_key, image_data):
70
+ nested_col21, nested_col22 = st.columns([0.65, 0.35])
71
+ image_for_display = self.resize_image(image_data['image'], 600)
72
+ nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
73
+ self.handle_analysis_button(image_key, image_data, nested_col22)
74
+
75
+ def handle_analysis_button(self, image_key, image_data, nested_col22):
76
+ if not image_data['analysis_done'] or self.settings_changed or self.confidance_change:
77
+ nested_col22.text("Please click 'Analyze Image'..")
78
+ analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_{st.session_state.confidence_level}'
79
+ if st.button('Analyze Image', key=analyze_button_key, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
80
+ caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
81
+ self.update_image_data(image_key, caption, detected_objects_str, True)
82
+ st.session_state['loading_in_progress'] = False
83
+
84
+ def handle_question_answering(self, image_key, image_data, nested_col22):
85
+ if image_data['analysis_done']:
86
+ self.display_question_answering_interface(image_key, image_data, nested_col22)
87
+
88
+ def display_question_answering_interface(self, image_key, image_data, nested_col22):
89
+ sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
90
+ selected_question = nested_col22.selectbox("Select a sample question or type your own:", ["Custom question..."] + sample_questions, key=f'sample_question_{image_key}')
91
+ custom_question = nested_col22.text_input("Or ask your own question:", key=f'custom_question_{image_key}')
92
+ question = custom_question if selected_question == "Custom question..." else selected_question
93
+ self.process_question(image_key, question, image_data, nested_col22)
94
+
95
+ def process_question(self, image_key, question, image_data, nested_col22):
96
+ qa_history = image_data.get('qa_history', [])
97
+ if question and (question not in [q for q, _, _ in qa_history] or self.settings_changed or self.confidance_change):
98
+ if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
99
+ answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
100
+ self.add_to_qa_history(image_key, question, answer, prompt_length)
101
+ nested_col22.text(f"Q: {question}\nA: {answer}\nPrompt Length: {prompt_length}")
102
+
103
+ def image_qa_app(self):
104
+ self.display_sample_images()
105
+ self.handle_image_upload()
106
+ self.display_session_state()
107
  with self.col2:
108
  for image_key, image_data in self.get_images_data().items():
 
109
  with st.container():
110
+ self.display_image_and_analysis(image_key, image_data)
111
+ self.handle_question_answering(image_key, image_data, nested_col22)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114