m7mdal7aj commited on
Commit
1948116
·
verified ·
1 Parent(s): 247b080

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +82 -92
my_model/tabs/run_inference.py CHANGED
@@ -10,99 +10,89 @@ import pandas as pd
10
  from my_model.object_detection import detect_and_draw_objects
11
  from my_model.captioner.image_captioning import get_caption
12
  from my_model.gen_utilities import free_gpu_resources
13
- from my_model.KBVQA import KBVQA, prepare_kbvqa_model
14
  from my_model.state_manager import StateManager
15
 
16
  state_manager = StateManager()
17
 
18
- def answer_question(caption, detected_objects_str, question, model):
19
- free_gpu_resources()
20
- answer = model.generate_answer(question, caption, detected_objects_str)
21
- free_gpu_resources()
22
- return answer
23
-
24
-
25
- # Sample images (assuming these are paths to your sample images)
26
- sample_images = ["Files/sample1.jpg", "Files/sample2.jpg", "Files/sample3.jpg",
27
- "Files/sample4.jpg", "Files/sample5.jpg", "Files/sample6.jpg",
28
- "Files/sample7.jpg"]
29
-
30
-
31
-
32
-
33
- def image_qa_app(kbvqa):
34
- # Display sample images as clickable thumbnails
35
- st.write("Choose from sample images:")
36
- cols = st.columns(len(sample_images))
37
- for idx, sample_image_path in enumerate(sample_images):
38
- with cols[idx]:
39
- image = Image.open(sample_image_path)
40
- st.image(image, use_column_width=True)
41
- if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
42
- state_manager.process_new_image(sample_image_path, image, kbvqa)
43
-
44
- # Image uploader
45
- uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
46
- if uploaded_image is not None:
47
- state_manager.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
48
-
49
- # Display and interact with each uploaded/selected image
50
- for image_key, image_data in state_manager.get_images_data().items():
51
- st.image(image_data['image'], caption=f'Uploaded Image: {image_key[-11:]}', use_column_width=True)
52
- if not image_data['analysis_done']:
53
- st.text("Cool image, please click 'Analyze Image'..")
54
- if st.button('Analyze Image', key=f'analyze_{image_key}'):
55
- caption, detected_objects_str, image_with_boxes = state_manager.analyze_image(image_data['image'], kbvqa)
56
- state_manager.update_image_data(image_key, caption, detected_objects_str, True)
57
-
58
- # Initialize qa_history for each image
59
- qa_history = image_data.get('qa_history', [])
60
-
61
- if image_data['analysis_done']:
62
- question = st.text_input(f"Ask a question about this image ({image_key[-11:]}):", key=f'question_{image_key}')
63
- if st.button('Get Answer', key=f'answer_{image_key}'):
64
- if question not in [q for q, _ in qa_history]:
65
- answer = answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
66
- state_manager.add_to_qa_history(image_key, question, answer)
67
-
68
- # Display Q&A history for each image
69
- for q, a in qa_history:
70
- st.text(f"Q: {q}\nA: {a}\n")
71
-
72
-
73
-
74
-
75
-
76
- def run_inference():
77
-
78
- st.title("Run Inference")
79
- state_manager.initialize_state()
80
- state_manager.set_up_widgets()
81
- st.session_state['settings_changed'] = state_manager.has_state_changed()
82
- if st.session_state['settings_changed']:
83
- st.warning("Model settings have changed, please reload the model, this will take a second .. ")
84
-
85
- st.session_state.button_label = "Reload Model" if state_manager.is_model_loaded() and state_manager.settings_changed else "Load Model"
86
- # state_manager.display_session_state()
87
-
88
-
89
- if st.session_state.method == "Fine-Tuned Model":
90
- if st.button(st.session_state.button_label):
91
- if st.session_state.button_label == "Load Model":
92
- if state_manager.is_model_loaded():
93
- st.text("Model already loaded and no settings were changed:)")
94
-
95
- else: state_manager.load_model()
96
-
97
- else:
98
- state_manager.reload_detection_model()
99
-
100
-
101
- if state_manager.is_model_loaded() and st.session_state.kbvqa.all_models_loaded:
102
- image_qa_app(state_manager.get_model())
103
-
104
- else:
105
- st.write(f'Model using {st.session_state.method} is not deplyed yet, will be ready later.')
106
-
107
-
108
-
 
10
  from my_model.object_detection import detect_and_draw_objects
11
  from my_model.captioner.image_captioning import get_caption
12
  from my_model.gen_utilities import free_gpu_resources
13
+ #from my_model.KBVQA import KBVQA, prepare_kbvqa_model
14
  from my_model.state_manager import StateManager
15
 
16
  state_manager = StateManager()
17
 
18
+ class InferenceRunner(StateManager):
19
+ def __init__(self):
20
+ super().__init__()
21
+ self.sample_images = [
22
+ "Files/sample1.jpg", "Files/sample2.jpg", "Files/sample3.jpg",
23
+ "Files/sample4.jpg", "Files/sample5.jpg", "Files/sample6.jpg",
24
+ "Files/sample7.jpg"
25
+ ]
26
+
27
+ def answer_question(self, caption, detected_objects_str, question, model):
28
+ free_gpu_resources()
29
+ answer = model.generate_answer(question, caption, detected_objects_str)
30
+ free_gpu_resources()
31
+ return answer
32
+
33
+
34
+ def image_qa_app(self, kbvqa):
35
+ # Display sample images as clickable thumbnails
36
+ st.write("Choose from sample images:")
37
+ cols = st.columns(len(self.sample_images))
38
+ for idx, sample_image_path in enumerate(self.sample_images):
39
+ with cols[idx]:
40
+ image = Image.open(sample_image_path)
41
+ st.image(image, use_column_width=True)
42
+ if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
43
+ self.process_new_image(sample_image_path, image, kbvqa)
44
+
45
+ # Image uploader
46
+ uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
47
+ if uploaded_image is not None:
48
+ self.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
49
+
50
+ # Display and interact with each uploaded/selected image
51
+ for image_key, image_data in self.get_images_data().items():
52
+ st.image(image_data['image'], caption=f'Uploaded Image: {image_key[-11:]}', use_column_width=True)
53
+ if not image_data['analysis_done']:
54
+ st.text("Cool image, please click 'Analyze Image'..")
55
+ if st.button('Analyze Image', key=f'analyze_{image_key}'):
56
+ caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
57
+ self.update_image_data(image_key, caption, detected_objects_str, True)
58
+
59
+ # Initialize qa_history for each image
60
+ qa_history = image_data.get('qa_history', [])
61
+
62
+ if image_data['analysis_done']:
63
+ question = st.text_input(f"Ask a question about this image ({image_key[-11:]}):", key=f'question_{image_key}')
64
+ if st.button('Get Answer', key=f'answer_{image_key}'):
65
+ if question not in [q for q, _ in qa_history]:
66
+ answer = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
67
+ self.add_to_qa_history(image_key, question, answer)
68
+
69
+ # Display Q&A history for each image
70
+ for q, a in qa_history:
71
+ st.text(f"Q: {q}\nA: {a}\n")
72
+
73
+
74
+
75
+ def run_inference(self):
76
+ st.title("Run Inference")
77
+ self.initialize_state()
78
+ self.set_up_widgets()
79
+ st.session_state['settings_changed'] = self.has_state_changed()
80
+ if st.session_state['settings_changed']:
81
+ st.warning("Model settings have changed, please reload the model, this will take a second .. ")
82
+
83
+ st.session_state.button_label = "Reload Model" if self.is_model_loaded() and self.settings_changed else "Load Model"
84
+
85
+ if st.session_state.method == "Fine-Tuned Model":
86
+ if st.button(st.session_state.button_label):
87
+ if st.session_state.button_label == "Load Model":
88
+ if self.is_model_loaded():
89
+ st.text("Model already loaded and no settings were changed:)")
90
+ else:
91
+ self.load_model()
92
+ else:
93
+ self.reload_detection_model()
94
+
95
+ if self.is_model_loaded() and st.session_state.kbvqa.all_models_loaded:
96
+ self.image_qa_app(self.get_model())
97
+ else:
98
+ st.write(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')