m7mdal7aj committed on
Commit
df65239
1 Parent(s): 09f5cd2

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +10 -12
my_model/tabs/run_inference.py CHANGED
@@ -33,7 +33,7 @@ class InferenceRunner(StateManager):
33
  # self.initialize_state()
34
 
35
 
36
- def answer_question(self, caption, detected_objects_str, question, model):
37
  """
38
  Generates an answer to a given question based on the image's caption and detected objects.
39
 
@@ -41,27 +41,25 @@ class InferenceRunner(StateManager):
41
  caption (str): The caption generated for the image.
42
  detected_objects_str (str): String representation of objects detected in the image.
43
  question (str): The user's question about the image.
44
- model (KBVQA): The loaded KBVQA model used for generating the answer.
45
 
46
  Returns:
47
  str: The generated answer to the question.
48
  """
49
  free_gpu_resources()
50
- answer = model.generate_answer(question, caption, detected_objects_str)
51
  prompt_length = model.current_prompt_length
52
  free_gpu_resources()
53
  return answer, prompt_length
54
 
55
 
56
- def image_qa_app(self, kbvqa):
57
  """
58
  Main application interface for image-based question answering. It handles displaying
59
  of sample images, uploading of new images, and facilitates the QA process.
60
-
61
- Args:
62
- kbvqa (KBVQA): The loaded KBVQA model used for image analysis and question answering.
63
  """
64
 
 
65
  # Display sample images as clickable thumbnails
66
  self.col1.write("Choose from sample images:")
67
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
@@ -71,12 +69,12 @@ class InferenceRunner(StateManager):
71
  image_for_display = self.resize_image(sample_image_path, 80, 80)
72
  st.image(image_for_display)
73
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
74
- self.process_new_image(sample_image_path, image, kbvqa)
75
 
76
  # Image uploader
77
  uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
78
  if uploaded_image is not None:
79
- self.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
80
 
81
  # Display and interact with each uploaded/selected image
82
  self.display_session_state()
@@ -91,7 +89,7 @@ class InferenceRunner(StateManager):
91
  with nested_col22:
92
  if st.button('Analyze Image', key=f'analyze_{image_key}', on_click=self.disable_widgets, disabled=self.is_widget_disabled):
93
 
94
- caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
95
  self.update_image_data(image_key, caption, detected_objects_str, True)
96
  st.session_state['loading_in_progress'] = False
97
 
@@ -121,7 +119,7 @@ class InferenceRunner(StateManager):
121
  else:
122
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
123
 
124
- answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
125
  st.session_state['loading_in_progress'] = False
126
  self.add_to_qa_history(image_key, question, answer, prompt_length)
127
 
@@ -198,6 +196,6 @@ class InferenceRunner(StateManager):
198
  if self.is_model_loaded:
199
  free_gpu_resources()
200
  st.session_state['loading_in_progress'] = False
201
- self.image_qa_app(self.get_model())
202
 
203
 
 
33
  # self.initialize_state()
34
 
35
 
36
+ def answer_question(self, caption, detected_objects_str, question):
37
  """
38
  Generates an answer to a given question based on the image's caption and detected objects.
39
 
 
41
  caption (str): The caption generated for the image.
42
  detected_objects_str (str): String representation of objects detected in the image.
43
  question (str): The user's question about the image.
44
+
45
 
46
  Returns:
47
  str: The generated answer to the question.
48
  """
49
  free_gpu_resources()
50
+ answer = self.session_state.kbvqa.generate_answer(question, caption, detected_objects_str)
51
  prompt_length = model.current_prompt_length
52
  free_gpu_resources()
53
  return answer, prompt_length
54
 
55
 
56
+ def image_qa_app(self):
57
  """
58
  Main application interface for image-based question answering. It handles displaying
59
  of sample images, uploading of new images, and facilitates the QA process.
 
 
 
60
  """
61
 
62
+
63
  # Display sample images as clickable thumbnails
64
  self.col1.write("Choose from sample images:")
65
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
 
69
  image_for_display = self.resize_image(sample_image_path, 80, 80)
70
  st.image(image_for_display)
71
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
72
+ self.process_new_image(sample_image_path, image)
73
 
74
  # Image uploader
75
  uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
76
  if uploaded_image is not None:
77
+ self.process_new_image(uploaded_image.name, Image.open(uploaded_image))
78
 
79
  # Display and interact with each uploaded/selected image
80
  self.display_session_state()
 
89
  with nested_col22:
90
  if st.button('Analyze Image', key=f'analyze_{image_key}', on_click=self.disable_widgets, disabled=self.is_widget_disabled):
91
 
92
+ caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
93
  self.update_image_data(image_key, caption, detected_objects_str, True)
94
  st.session_state['loading_in_progress'] = False
95
 
 
119
  else:
120
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
121
 
122
+ answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
123
  st.session_state['loading_in_progress'] = False
124
  self.add_to_qa_history(image_key, question, answer, prompt_length)
125
 
 
196
  if self.is_model_loaded:
197
  free_gpu_resources()
198
  st.session_state['loading_in_progress'] = False
199
+ self.image_qa_app()
200
 
201