m7mdal7aj commited on
Commit
7e54217
1 Parent(s): 69b926c

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +10 -2
my_model/tabs/run_inference.py CHANGED
@@ -4,6 +4,7 @@ import bitsandbytes
4
  import accelerate
5
  import scipy
6
  import copy
 
7
  from PIL import Image
8
  import torch.nn as nn
9
  import pandas as pd
@@ -32,6 +33,7 @@ class InferenceRunner(StateManager):
32
  # Display sample images as clickable thumbnails
33
  self.col1.write("Choose from sample images:")
34
  cols = self.col1.columns(len(self.sample_images))
 
35
  for idx, sample_image_path in enumerate(self.sample_images):
36
  with cols[idx]:
37
  image = Image.open(sample_image_path)
@@ -108,7 +110,7 @@ class InferenceRunner(StateManager):
108
  with st.container():
109
  nested_col11, nested_col12 = st.columns([0.5, 0.5])
110
  if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
111
-
112
  if st.session_state.button_label == "Load Model":
113
  if self.is_model_loaded():
114
  free_gpu_resources()
@@ -121,10 +123,12 @@ class InferenceRunner(StateManager):
121
 
122
  if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
123
  force_reload_full_model = True
 
124
 
125
  if load_fine_tuned_model:
126
  free_gpu_resources()
127
  self.load_model()
 
128
  st.session_state['loading_in_progress'] = False
129
 
130
  elif fine_tuned_model_already_loaded:
@@ -139,8 +143,11 @@ class InferenceRunner(StateManager):
139
 
140
  elif force_reload_full_model:
141
  free_gpu_resources()
 
142
  self.force_reload_model()
 
143
  st.session_state['loading_in_progress'] = False
 
144
 
145
  elif st.session_state.method == "In-Context Learning (n-shots)":
146
  self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
@@ -148,8 +155,9 @@ class InferenceRunner(StateManager):
148
 
149
 
150
  if self.is_model_loaded():
151
- st.session_state['loading_in_progress'] = False
152
  free_gpu_resources()
 
153
  self.image_qa_app(self.get_model())
154
  st.write(st.session_state['loading_in_progress'])
155
 
 
4
  import accelerate
5
  import scipy
6
  import copy
7
+ import time
8
  from PIL import Image
9
  import torch.nn as nn
10
  import pandas as pd
 
33
  # Display sample images as clickable thumbnails
34
  self.col1.write("Choose from sample images:")
35
  cols = self.col1.columns(len(self.sample_images))
36
+ st.write(st.session_state['loading_in_progress'])
37
  for idx, sample_image_path in enumerate(self.sample_images):
38
  with cols[idx]:
39
  image = Image.open(sample_image_path)
 
110
  with st.container():
111
  nested_col11, nested_col12 = st.columns([0.5, 0.5])
112
  if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
113
+ t1=time.time()
114
  if st.session_state.button_label == "Load Model":
115
  if self.is_model_loaded():
116
  free_gpu_resources()
 
123
 
124
  if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
125
  force_reload_full_model = True
126
+ t1=time.time()
127
 
128
  if load_fine_tuned_model:
129
  free_gpu_resources()
130
  self.load_model()
131
+
132
  st.session_state['loading_in_progress'] = False
133
 
134
  elif fine_tuned_model_already_loaded:
 
143
 
144
  elif force_reload_full_model:
145
  free_gpu_resources()
146
+
147
  self.force_reload_model()
148
+
149
  st.session_state['loading_in_progress'] = False
150
+ st.session_state['model_loaded'] = True
151
 
152
  elif st.session_state.method == "In-Context Learning (n-shots)":
153
  self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
 
155
 
156
 
157
  if self.is_model_loaded():
158
+ st.session_state['time_taken_to_load_model'] = time.time()-t1
159
  free_gpu_resources()
160
+ st.session_state['loading_in_progress'] = False
161
  self.image_qa_app(self.get_model())
162
  st.write(st.session_state['loading_in_progress'])
163