m7mdal7aj commited on
Commit
e57843e
1 Parent(s): 7391509

Update my_model/KBVQA.py

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +20 -1
my_model/KBVQA.py CHANGED
@@ -21,6 +21,7 @@ class KBVQA():
21
  self.kbvqa_tokenizer = None
22
  self.captioner = None
23
  self.detector = None
 
24
  self.kbvqa_model = None
25
  self.access_token = os.getenv("HUGGINGFACE_TOKEN")
26
  # self.kbvqa_model_loaded = self.all_models_loaded()
@@ -87,6 +88,22 @@ class KBVQA():
87
  def all_models_loaded(self):
88
  return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
 
92
  def format_prompt(self, current_query, history = None , sys_prompt=None, caption=None, objects=None):
@@ -144,11 +161,13 @@ class KBVQA():
144
  def prepare_kbvqa_model(detection_model):
145
  free_gpu_resources()
146
  kbvqa = KBVQA()
 
147
  # Progress bar for model loading
148
  with st.spinner('Loading model...'):
149
 
150
  progress_bar = st.progress(0)
151
- kbvqa.load_detector(detection_model)
 
152
  progress_bar.progress(33)
153
  kbvqa.load_caption_model()
154
  free_gpu_resources()
 
21
  self.kbvqa_tokenizer = None
22
  self.captioner = None
23
  self.detector = None
24
+ sel.detection_model = None
25
  self.kbvqa_model = None
26
  self.access_token = os.getenv("HUGGINGFACE_TOKEN")
27
  # self.kbvqa_model_loaded = self.all_models_loaded()
 
88
  def all_models_loaded(self):
89
  return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
90
 
91
+ def force_reload_model(self):
92
+ free_gpu_resources()
93
+ if self.kbvqa_model is not None:
94
+ del self.kbvqa_model
95
+ if self.captioner is not None:
96
+ del self.captioner
97
+ if self.detector is not None:
98
+ del self.detector
99
+
100
+ free_gpu_resources()
101
+
102
+
103
+
104
+
105
+
106
+
107
 
108
 
109
  def format_prompt(self, current_query, history = None , sys_prompt=None, caption=None, objects=None):
 
161
  def prepare_kbvqa_model(detection_model):
162
  free_gpu_resources()
163
  kbvqa = KBVQA()
164
+ kbvqa.detection_model = detection_model
165
  # Progress bar for model loading
166
  with st.spinner('Loading model...'):
167
 
168
  progress_bar = st.progress(0)
169
+
170
+ kbvqa.load_detector(kbvqa.detection_model)
171
  progress_bar.progress(33)
172
  kbvqa.load_caption_model()
173
  free_gpu_resources()