linoyts (HF staff) committed
Commit 19efc84
1 Parent(s): c719a09

Update app.py

Files changed (1):
  1. app.py (+35, -16)
app.py CHANGED
@@ -10,6 +10,19 @@ import utils
 from inversion_utils import inversion_forward_process, inversion_reverse_process
 
 
+# current_loaded_model = "cvssp/audioldm2-music"
+# # current_loaded_model = "cvssp/audioldm2-music"
+
+# ldm_stable = load_model(current_loaded_model, device, 200) # deafult model
+LDM2 = "cvssp/audioldm2"
+MUSIC = "cvssp/audioldm2-music"
+LDM2_LARGE = "cvssp/audioldm2-large"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ldm2 = load_model(model_id=LDM2, device=device)
+ldm2_large = load_model(model_id=LDM2_LARGE, device=device)
+ldm2_music = load_model(model_id= MUSIC, device=device)
+
+
 def randomize_seed_fn(seed, randomize_seed):
     if randomize_seed:
         seed = random.randint(0, np.iinfo(np.int32).max)
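All three AudioLDM2 checkpoints are now loaded eagerly at startup, so switching models in the UI picks an already-resident pipeline instead of triggering a reload mid-request. The same lookup could be expressed with a dict keyed by checkpoint ID; a minimal sketch under the diff's own names (the MODELS registry and get_pipeline helper are hypothetical, not part of this commit):

    # Hypothetical alternative to the three module-level globals: one registry
    # built with the same load_model calls this commit introduces.
    MODELS = {
        checkpoint: load_model(model_id=checkpoint, device=device)
        for checkpoint in (LDM2, LDM2_LARGE, MUSIC)
    }

    def get_pipeline(model_id):
        # Fall back to the music checkpoint, mirroring the else-branch in edit().
        return MODELS.get(model_id, MODELS[MUSIC])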
@@ -17,7 +30,7 @@ def randomize_seed_fn(seed, randomize_seed):
     return seed
 
 
-def invert(x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable):
+def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable):
     ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)
 
     with inference_mode():
@@ -34,7 +47,7 @@ def invert(x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable)
 
 
 
-def sample(zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # , ldm_stable):
+def sample(ldm_stable, zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # , ldm_stable):
     # reverse process (via Zs and wT)
     tstart = torch.tensor(tstart, dtype=torch.int)
     skip = steps - tstart
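With invert and sample both taking the pipeline as an explicit first parameter, callers choose a checkpoint per request instead of mutating a global (the trailing "# , ldm_stable):" comments are leftovers from the old signature). A rough usage sketch; the prompts and scale values below are made-up illustration values:

    # Made-up values, showing the new explicit-pipeline call style.
    zs_tensor, wts_tensor = invert(ldm_stable=ldm2_music, x0=x0,
                                   prompt_src="a solo piano melody",
                                   num_diffusion_steps=200,
                                   cfg_scale_src=3.5)
    output = sample(ldm2_music, zs_tensor, wts_tensor, 200,
                    prompt_tar="a jazz trio", tstart=180,
                    cfg_scale_tar=12.0)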
@@ -79,13 +92,22 @@ def edit(input_audio,
          t_start=90,
          randomize_seed=True):
 
-    global ldm_stable, current_loaded_model
-    print(f'current loaded model: {ldm_stable.model_id}')
-    if model_id != current_loaded_model:
-        print(f'Changing model to {model_id}...')
-        current_loaded_model = model_id
-        ldm_stable = None
-        ldm_stable = load_model(model_id, device, steps)
+    # global ldm_stable, current_loaded_model
+    # print(f'current loaded model: {ldm_stable.model_id}')
+    # if model_id != current_loaded_model:
+    #     print(f'Changing model to {model_id}...')
+    #     current_loaded_model = model_id
+    #     ldm_stable = None
+    #     ldm_stable = load_model(model_id, device)
+    print(model_id)
+    if model_id == LDM2:
+        ldm_stable = ldm2
+    elif model_id == LDM2_LARGE:
+        ldm_stable = ldm2_large
+    else: # MUSIC
+        ldm_stable = ldm2_music
+
+
 
     # If the inversion was done for a different model, we need to re-run the inversion
     if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
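The guard kept as context below the new if/elif selection re-runs inversion whenever the cached latents came from a different checkpoint. Its body is outside this hunk; presumably it just flips the flag, roughly:

    # Assumed body of the guard (not shown in this diff): mark the cached
    # inversion stale so it is recomputed with the newly selected pipeline.
    if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
        do_inversion = True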
@@ -94,7 +116,7 @@
     x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
 
     if do_inversion or randomize_seed: # always re-run inversion
-        zs_tensor, wts_tensor = invert(x0=x0, prompt_src=source_prompt,
+        zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
                                        num_diffusion_steps=steps,
                                        cfg_scale_src=cfg_scale_src)
         wts = gr.State(value=wts_tensor)
@@ -105,16 +127,13 @@
     # make sure t_start is in the right limit
     t_start = change_tstart_range(t_start, steps)
 
-    output = sample(zs.value, wts.value, steps, prompt_tar=target_prompt, tstart=t_start,
+    output = sample(ldm_stable, zs.value, wts.value, steps, prompt_tar=target_prompt, tstart=t_start,
                     cfg_scale_tar=cfg_scale_tar)
 
     return output, wts, zs, saved_inv_model, do_inversion
 
 
-current_loaded_model = "cvssp/audioldm2-music"
-# current_loaded_model = "cvssp/audioldm2-music"
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-ldm_stable = load_model(current_loaded_model, device, 200) # deafult model
+
 
 
 def get_example():
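The context lines in the two hunks above show the inverted latents being handed back wrapped in gr.State (wts = gr.State(value=wts_tensor)), which keeps one copy per browser session rather than in a shared global. A self-contained toy example of that per-session caching pattern (names and values are illustrative only):

    import gradio as gr

    with gr.Blocks() as demo:
        cache = gr.State(value=None)  # one copy per connected session
        btn = gr.Button("Run")
        out = gr.Textbox()

        def run(cached):
            # Reuse the expensive result if this session already has one.
            if cached is None:
                cached = "expensive result"
            return cached, cached

        btn.click(run, inputs=[cache], outputs=[out, cache])

    demo.launch()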
@@ -267,7 +286,7 @@ with gr.Blocks(css='style.css') as demo:
     input_audio.change(fn=reset_do_inversion, outputs=[do_inversion])
     src_prompt.change(fn=reset_do_inversion, outputs=[do_inversion])
     model_id.change(fn=reset_do_inversion, outputs=[do_inversion])
-    # steps.change(fn=change_tstart_range, inputs=[steps], outputs=[t_start])
+    steps.change(fn=reset_do_inversion, outputs=[do_inversion])
 
     gr.Examples(
         label="Examples",
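Finally, steps.change is wired to reset_do_inversion (replacing the commented-out change_tstart_range wiring), since latents inverted at one step count cannot be reused at another. reset_do_inversion itself is not shown in this diff; given how it feeds the do_inversion state for the other inputs, it presumably just returns a truthy flag, e.g.:

    # Assumed helper (not part of this diff): any change to the input audio,
    # source prompt, model, or step count marks the cached inversion as stale.
    def reset_do_inversion():
        return True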
 