hilamanor commited on
Commit
fc577e0
1 Parent(s): 7e02fda
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -96,7 +96,7 @@ def edit(
96
  elif model_id == LDM2_LARGE:
97
  ldm_stable = ldm2_large
98
  else: # MUSIC
99
- ldm_stable = ldm2_music
100
 
101
  ldm_stable.model.scheduler.set_timesteps(steps, device=device)
102
 
@@ -113,7 +113,7 @@ def edit(
113
  # do_inversion = True
114
  # Too much time has passed
115
  if wts is None or zs is None:
116
- do_inversion = True
117
 
118
  if do_inversion or randomize_seed: # always re-run inversion
119
  zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
@@ -131,10 +131,10 @@ def edit(
131
  saved_inv_model = model_id
132
  do_inversion = False
133
  else:
134
- # wtszs = torch.load(wtszs_file, map_location=device)
135
- # # wtszs = torch.load(wtszs_file.f, map_location=device)
136
- # wts_tensor = wtszs['wts']
137
- # zs_tensor = wtszs['zs']
138
  wts_tensor = wts.to(device)
139
  zs_tensor = zs.to(device)
140
 
@@ -224,7 +224,7 @@ change <code style="display:inline; background-color: lightgrey; ">duration = mi
224
 
225
  """
226
 
227
- with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
228
  def reset_do_inversion(do_inversion_user, do_inversion):
229
  # do_inversion = gr.State(value=True)
230
  do_inversion = True
@@ -235,13 +235,13 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
235
  def clear_do_inversion_user(do_inversion_user):
236
  do_inversion_user = False
237
  return do_inversion_user
238
-
239
  def post_match_do_inversion(do_inversion_user, do_inversion):
240
  if do_inversion_user:
241
  do_inversion = True
242
  do_inversion_user = False
243
  return do_inversion_user, do_inversion
244
-
245
  gr.HTML(intro)
246
  wts = gr.State()
247
  zs = gr.State()
@@ -329,7 +329,7 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
329
  saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
330
  ).then(post_match_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion]
331
  ).then(lambda x: (demo.temp_file_sets.append(set([str(gr.utils.abspath(x))])) if type(x) is str else None),
332
- inputs=wtszs)
333
 
334
  # demo.move_resource_to_block_cache(wtszs.value)
335
 
 
96
  elif model_id == LDM2_LARGE:
97
  ldm_stable = ldm2_large
98
  else: # MUSIC
99
+ ldm_stable = ldm2_music
100
 
101
  ldm_stable.model.scheduler.set_timesteps(steps, device=device)
102
 
 
113
  # do_inversion = True
114
  # Too much time has passed
115
  if wts is None or zs is None:
116
+ do_inversion = True
117
 
118
  if do_inversion or randomize_seed: # always re-run inversion
119
  zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
 
131
  saved_inv_model = model_id
132
  do_inversion = False
133
  else:
134
+ # wtszs = torch.load(wtszs_file, map_location=device)
135
+ # # wtszs = torch.load(wtszs_file.f, map_location=device)
136
+ # wts_tensor = wtszs['wts']
137
+ # zs_tensor = wtszs['zs']
138
  wts_tensor = wts.to(device)
139
  zs_tensor = zs.to(device)
140
 
 
224
 
225
  """
226
 
227
+ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
228
  def reset_do_inversion(do_inversion_user, do_inversion):
229
  # do_inversion = gr.State(value=True)
230
  do_inversion = True
 
235
  def clear_do_inversion_user(do_inversion_user):
236
  do_inversion_user = False
237
  return do_inversion_user
238
+
239
  def post_match_do_inversion(do_inversion_user, do_inversion):
240
  if do_inversion_user:
241
  do_inversion = True
242
  do_inversion_user = False
243
  return do_inversion_user, do_inversion
244
+
245
  gr.HTML(intro)
246
  wts = gr.State()
247
  zs = gr.State()
 
329
  saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
330
  ).then(post_match_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion]
331
  ).then(lambda x: (demo.temp_file_sets.append(set([str(gr.utils.abspath(x))])) if type(x) is str else None),
332
+ inputs=wtszs)
333
 
334
  # demo.move_resource_to_block_cache(wtszs.value)
335