radames commited on
Commit
5564afe
1 Parent(s): 6846772

filter models

Browse files
Files changed (1) hide show
  1. visualizer_drag_gradio.py +18 -8
visualizer_drag_gradio.py CHANGED
@@ -30,6 +30,7 @@ args = parser.parse_args()
30
  cache_dir = args.cache_dir
31
 
32
  device = 'cuda'
 
33
 
34
 
35
  def reverse_point_pairs(points):
@@ -154,17 +155,22 @@ def preprocess_mask_info(global_state, image):
154
  return global_state
155
 
156
 
 
 
 
 
 
 
 
157
  valid_checkpoints_dict = {
158
- f.split('/')[-1].split('.')[0]: osp.join(cache_dir, f)
159
- for f in os.listdir(cache_dir)
160
- if (f.endswith('pkl') and osp.exists(osp.join(cache_dir, f)))
161
  }
162
- print(f'File under cache_dir ({cache_dir}):')
163
- print(os.listdir(cache_dir))
164
  print('Valid checkpoint file:')
165
  print(valid_checkpoints_dict)
166
 
167
- init_pkl = 'stylegan_human_v2_512'
168
 
169
  with gr.Blocks() as app:
170
  gr.Markdown("""
@@ -241,9 +247,13 @@ with gr.Blocks() as app:
241
  gr.Markdown(value='Latent', show_label=False)
242
 
243
  with gr.Column(scale=4, min_width=10):
244
- form_seed_number = gr.Number(
 
 
 
245
  value=global_state.value['params']['seed'],
246
  interactive=True,
 
247
  label="Seed",
248
  )
249
  form_lr_number = gr.Number(
@@ -865,5 +875,5 @@ with gr.Blocks() as app:
865
 
866
  print("SHAReD: Start app", parser.parse_args())
867
  gr.close_all()
868
- app.queue(concurrency_count=2, max_size=20, api_open=False)
869
  app.launch(share=args.share, show_api=False)
 
30
  cache_dir = args.cache_dir
31
 
32
  device = 'cuda'
33
+ IS_SPACE = "radames/DragGan" in os.environ.get('SPACE_ID', '')
34
 
35
 
36
  def reverse_point_pairs(points):
 
155
  return global_state
156
 
157
 
158
+ # filter large models running on SPACES
159
+ if IS_SPACE:
160
+ blocked_checkpoints = ["stylegan_human_v2_512.pkl",
161
+ "stylegan2_dogs_1024_pytorch.pkl"]
162
+ else:
163
+ blocked_checkpoints = []
164
+
165
  valid_checkpoints_dict = {
166
+ f.name.split('.')[0]: str(f)
167
+ for f in Path(cache_dir).glob('*.pkl')
168
+ if f.name not in blocked_checkpoints
169
  }
 
 
170
  print('Valid checkpoint file:')
171
  print(valid_checkpoints_dict)
172
 
173
+ init_pkl = 'stylegan2_lions_512_pytorch'
174
 
175
  with gr.Blocks() as app:
176
  gr.Markdown("""
 
247
  gr.Markdown(value='Latent', show_label=False)
248
 
249
  with gr.Column(scale=4, min_width=10):
250
+ form_seed_number = gr.Slider(
251
+ mininium=0,
252
+ maximum=2**32-1,
253
+ step=1,
254
  value=global_state.value['params']['seed'],
255
  interactive=True,
256
+ randomize=True,
257
  label="Seed",
258
  )
259
  form_lr_number = gr.Number(
 
875
 
876
  print("SHAReD: Start app", parser.parse_args())
877
  gr.close_all()
878
+ app.queue(concurrency_count=5, max_size=20, api_open=False)
879
  app.launch(share=args.share, show_api=False)