Linoy Tsaban commited on
Commit
e1cb97f
1 Parent(s): d91ace1

Update inversion_utils.py

Browse files
Files changed (1) hide show
  1. inversion_utils.py +9 -6
inversion_utils.py CHANGED
@@ -7,6 +7,7 @@ import torchvision.transforms as T
7
  import os
8
  import yaml
9
  import numpy as np
 
10
 
11
 
12
  def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
@@ -129,10 +130,11 @@ def get_variance(model, timestep): #, prev_timestep):
129
 
130
  def inversion_forward_process(model, x0,
131
  etas = None,
132
- prog_bar = False,
133
  prompt = "",
134
  cfg_scale = 3.5,
135
- num_inference_steps=50, eps = None):
 
136
 
137
  if not prompt=="":
138
  text_embeddings = encode_text(model, prompt)
@@ -155,7 +157,7 @@ def inversion_forward_process(model, x0,
155
 
156
  t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
157
  xt = x0
158
- op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
159
 
160
  for t in op:
161
  idx = t_to_idx[int(t)]
@@ -241,10 +243,11 @@ def inversion_reverse_process(model,
241
  etas = 0,
242
  prompts = "",
243
  cfg_scales = None,
244
- prog_bar = False,
245
  zs = None,
246
  controller=None,
247
- asyrp = False):
 
248
 
249
  batch_size = len(prompts)
250
 
@@ -259,7 +262,7 @@ def inversion_reverse_process(model,
259
  timesteps = model.scheduler.timesteps.to(model.device)
260
 
261
  xt = xT.expand(batch_size, -1, -1, -1)
262
- op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
263
 
264
  t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
265
 
 
7
  import os
8
  import yaml
9
  import numpy as np
10
+ import gradio as gr
11
 
12
 
13
  def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
 
130
 
131
  def inversion_forward_process(model, x0,
132
  etas = None,
133
+ prog_bar = True,
134
  prompt = "",
135
  cfg_scale = 3.5,
136
+ num_inference_steps=50, eps = None
137
+ progress=gr.Progress()):
138
 
139
  if not prompt=="":
140
  text_embeddings = encode_text(model, prompt)
 
157
 
158
  t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
159
  xt = x0
160
+ op = progress.tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
161
 
162
  for t in op:
163
  idx = t_to_idx[int(t)]
 
243
  etas = 0,
244
  prompts = "",
245
  cfg_scales = None,
246
+ prog_bar = True,
247
  zs = None,
248
  controller=None,
249
+ asyrp = False,
250
+ progress=gr.Progress()):
251
 
252
  batch_size = len(prompts)
253
 
 
262
  timesteps = model.scheduler.timesteps.to(model.device)
263
 
264
  xt = xT.expand(batch_size, -1, -1, -1)
265
+ op = progress.tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
266
 
267
  t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
268