hilamanor commited on
Commit
7e02fda
·
1 Parent(s): c533e68

fix temp files cache and move to ZeroGPU

Browse files
Files changed (1) hide show
  1. app.py +56 -40
app.py CHANGED
@@ -1,12 +1,18 @@
 
 
 
 
 
1
  import gradio as gr
2
  import random
3
  import torch
4
- import os
5
  from torch import inference_mode
6
- from tempfile import NamedTemporaryFile
 
7
  import numpy as np
8
  from models import load_model
9
  import utils
 
10
  from inversion_utils import inversion_forward_process, inversion_reverse_process
11
 
12
 
@@ -31,7 +37,7 @@ def randomize_seed_fn(seed, randomize_seed):
31
 
32
 
33
  def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable):
34
- ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)
35
 
36
  with inference_mode():
37
  w0 = ldm_stable.vae_encode(x0)
@@ -67,21 +73,22 @@ def sample(ldm_stable, zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # ,
67
 
68
  return (16000, audio.squeeze().cpu().numpy())
69
 
70
-
71
- def edit(cache_dir,
72
- input_audio,
73
- model_id: str,
74
- do_inversion: bool,
75
- wtszs_file: str,
76
- # wts: gr.State, zs: gr.State,
77
- saved_inv_model: str,
78
- source_prompt="",
79
- target_prompt="",
80
- steps=200,
81
- cfg_scale_src=3.5,
82
- cfg_scale_tar=12,
83
- t_start=45,
84
- randomize_seed=True):
 
85
 
86
  print(model_id)
87
  if model_id == LDM2:
@@ -89,7 +96,9 @@ def edit(cache_dir,
89
  elif model_id == LDM2_LARGE:
90
  ldm_stable = ldm2_large
91
  else: # MUSIC
92
- ldm_stable = ldm2_music
 
 
93
 
94
  # If the inversion was done for a different model, we need to re-run the inversion
95
  if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
@@ -99,29 +108,35 @@ def edit(cache_dir,
99
  raise gr.Error('Input audio missing!')
100
  x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
101
 
102
- if not (do_inversion or randomize_seed):
103
- if not os.path.exists(wtszs_file):
104
- do_inversion = True
105
  # Too much time has passed
 
 
106
 
107
  if do_inversion or randomize_seed: # always re-run inversion
108
  zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
109
  num_diffusion_steps=steps,
110
  cfg_scale_src=cfg_scale_src)
111
- f = NamedTemporaryFile("wb", dir=cache_dir, suffix=".pth", delete=False)
112
- torch.save({'wts': wts_tensor, 'zs': zs_tensor}, f.name)
113
- wtszs_file = f.name
114
  # wtszs_file = gr.State(value=f.name)
115
  # wts = gr.State(value=wts_tensor)
 
 
116
  # zs = gr.State(value=zs_tensor)
117
  # demo.move_resource_to_block_cache(f.name)
118
  saved_inv_model = model_id
119
  do_inversion = False
120
  else:
121
- wtszs = torch.load(wtszs_file, map_location=device)
122
- # wtszs = torch.load(wtszs_file.f, map_location=device)
123
- wts_tensor = wtszs['wts']
124
- zs_tensor = wtszs['zs']
 
 
125
 
126
  # make sure t_start is in the right limit
127
  # t_start = change_tstart_range(t_start, steps)
@@ -129,7 +144,8 @@ def edit(cache_dir,
129
  output = sample(ldm_stable, zs_tensor, wts_tensor, steps, prompt_tar=target_prompt,
130
  tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar)
131
 
132
- return output, wtszs_file, saved_inv_model, do_inversion
 
133
 
134
 
135
  def get_example():
@@ -208,7 +224,7 @@ change <code style="display:inline; background-color: lightgrey; ">duration = mi
208
 
209
  """
210
 
211
- with gr.Blocks(css='style.css', delete_cache=(3600, 3600)) as demo:
212
  def reset_do_inversion(do_inversion_user, do_inversion):
213
  # do_inversion = gr.State(value=True)
214
  do_inversion = True
@@ -219,18 +235,18 @@ with gr.Blocks(css='style.css', delete_cache=(3600, 3600)) as demo:
219
  def clear_do_inversion_user(do_inversion_user):
220
  do_inversion_user = False
221
  return do_inversion_user
 
222
  def post_match_do_inversion(do_inversion_user, do_inversion):
223
  if do_inversion_user:
224
  do_inversion = True
225
  do_inversion_user = False
226
  return do_inversion_user, do_inversion
227
 
228
-
229
  gr.HTML(intro)
230
- # wts = gr.State()
231
- # zs = gr.State()
232
  wtszs = gr.State()
233
- cache_dir = gr.State(demo.GRADIO_CACHE)
234
  saved_inv_model = gr.State()
235
  # current_loaded_model = gr.State(value="cvssp/audioldm2-music")
236
  # ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
@@ -293,13 +309,13 @@ with gr.Blocks(css='style.css', delete_cache=(3600, 3600)) as demo:
293
  outputs=[seed], queue=False).then(
294
  fn=clear_do_inversion_user, inputs=[do_inversion_user], outputs=[do_inversion_user]).then(
295
  fn=edit,
296
- inputs=[cache_dir,
297
  input_audio,
298
  model_id,
299
  do_inversion,
300
  # current_loaded_model, ldm_stable,
301
- # wts, zs,
302
- wtszs,
303
  saved_inv_model,
304
  src_prompt,
305
  tar_prompt,
@@ -309,7 +325,7 @@ with gr.Blocks(css='style.css', delete_cache=(3600, 3600)) as demo:
309
  t_start,
310
  randomize_seed
311
  ],
312
- outputs=[output_audio, wtszs,
313
  saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
314
  ).then(post_match_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion]
315
  ).then(lambda x: (demo.temp_file_sets.append(set([str(gr.utils.abspath(x))])) if type(x) is str else None),
@@ -332,4 +348,4 @@ with gr.Blocks(css='style.css', delete_cache=(3600, 3600)) as demo:
332
  )
333
 
334
  demo.queue()
335
- demo.launch()
 
1
+ # Will be fixed soon, but meanwhile:
2
+ import os
3
+ if os.getenv('SPACES_ZERO_GPU') == "true":
4
+ os.environ['SPACES_ZERO_GPU'] = "1"
5
+
6
  import gradio as gr
7
  import random
8
  import torch
 
9
  from torch import inference_mode
10
+ # from tempfile import NamedTemporaryFile
11
+ from typing import Optional
12
  import numpy as np
13
  from models import load_model
14
  import utils
15
+ import spaces
16
  from inversion_utils import inversion_forward_process, inversion_reverse_process
17
 
18
 
 
37
 
38
 
39
  def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable):
40
+ # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)
41
 
42
  with inference_mode():
43
  w0 = ldm_stable.vae_encode(x0)
 
73
 
74
  return (16000, audio.squeeze().cpu().numpy())
75
 
76
+ @spaces.GPU
77
+ def edit(
78
+ # cache_dir,
79
+ input_audio,
80
+ model_id: str,
81
+ do_inversion: bool,
82
+ # wtszs_file: str,
83
+ wts: Optional[torch.Tensor], zs: Optional[torch.Tensor],
84
+ saved_inv_model: str,
85
+ source_prompt="",
86
+ target_prompt="",
87
+ steps=200,
88
+ cfg_scale_src=3.5,
89
+ cfg_scale_tar=12,
90
+ t_start=45,
91
+ randomize_seed=True):
92
 
93
  print(model_id)
94
  if model_id == LDM2:
 
96
  elif model_id == LDM2_LARGE:
97
  ldm_stable = ldm2_large
98
  else: # MUSIC
99
+ ldm_stable = ldm2_music
100
+
101
+ ldm_stable.model.scheduler.set_timesteps(steps, device=device)
102
 
103
  # If the inversion was done for a different model, we need to re-run the inversion
104
  if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
 
108
  raise gr.Error('Input audio missing!')
109
  x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
110
 
111
+ # if not (do_inversion or randomize_seed):
112
+ # if not os.path.exists(wtszs_file):
113
+ # do_inversion = True
114
  # Too much time has passed
115
+ if wts is None or zs is None:
116
+ do_inversion = True
117
 
118
  if do_inversion or randomize_seed: # always re-run inversion
119
  zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
120
  num_diffusion_steps=steps,
121
  cfg_scale_src=cfg_scale_src)
122
+ # f = NamedTemporaryFile("wb", dir=cache_dir, suffix=".pth", delete=False)
123
+ # torch.save({'wts': wts_tensor, 'zs': zs_tensor}, f.name)
124
+ # wtszs_file = f.name
125
  # wtszs_file = gr.State(value=f.name)
126
  # wts = gr.State(value=wts_tensor)
127
+ wts = wts_tensor
128
+ zs = zs_tensor
129
  # zs = gr.State(value=zs_tensor)
130
  # demo.move_resource_to_block_cache(f.name)
131
  saved_inv_model = model_id
132
  do_inversion = False
133
  else:
134
+ # wtszs = torch.load(wtszs_file, map_location=device)
135
+ # # wtszs = torch.load(wtszs_file.f, map_location=device)
136
+ # wts_tensor = wtszs['wts']
137
+ # zs_tensor = wtszs['zs']
138
+ wts_tensor = wts.to(device)
139
+ zs_tensor = zs.to(device)
140
 
141
  # make sure t_start is in the right limit
142
  # t_start = change_tstart_range(t_start, steps)
 
144
  output = sample(ldm_stable, zs_tensor, wts_tensor, steps, prompt_tar=target_prompt,
145
  tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar)
146
 
147
+ return output, wts.cpu(), zs.cpu(), saved_inv_model, do_inversion
148
+ # return output, wtszs_file, saved_inv_model, do_inversion
149
 
150
 
151
  def get_example():
 
224
 
225
  """
226
 
227
+ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
228
  def reset_do_inversion(do_inversion_user, do_inversion):
229
  # do_inversion = gr.State(value=True)
230
  do_inversion = True
 
235
  def clear_do_inversion_user(do_inversion_user):
236
  do_inversion_user = False
237
  return do_inversion_user
238
+
239
  def post_match_do_inversion(do_inversion_user, do_inversion):
240
  if do_inversion_user:
241
  do_inversion = True
242
  do_inversion_user = False
243
  return do_inversion_user, do_inversion
244
 
 
245
  gr.HTML(intro)
246
+ wts = gr.State()
247
+ zs = gr.State()
248
  wtszs = gr.State()
249
+ # cache_dir = gr.State(demo.GRADIO_CACHE)
250
  saved_inv_model = gr.State()
251
  # current_loaded_model = gr.State(value="cvssp/audioldm2-music")
252
  # ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
 
309
  outputs=[seed], queue=False).then(
310
  fn=clear_do_inversion_user, inputs=[do_inversion_user], outputs=[do_inversion_user]).then(
311
  fn=edit,
312
+ inputs=[#cache_dir,
313
  input_audio,
314
  model_id,
315
  do_inversion,
316
  # current_loaded_model, ldm_stable,
317
+ wts, zs,
318
+ # wtszs,
319
  saved_inv_model,
320
  src_prompt,
321
  tar_prompt,
 
325
  t_start,
326
  randomize_seed
327
  ],
328
+ outputs=[output_audio, wts, zs, # wtszs,
329
  saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
330
  ).then(post_match_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion]
331
  ).then(lambda x: (demo.temp_file_sets.append(set([str(gr.utils.abspath(x))])) if type(x) is str else None),
 
348
  )
349
 
350
  demo.queue()
351
+ demo.launch(state_session_capacity=15)