ihsanvp committed on
Commit
998bf52
1 Parent(s): 7acc91c

fix: progress updater

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +24 -6
  3. utils.py +7 -0
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  .env
2
  __pycache__/
3
  *.mp4
4
- *.jpg
 
 
1
  .env
2
  __pycache__/
3
  *.mp4
4
+ *.jpg
5
+ test.py
app.py CHANGED
@@ -4,13 +4,16 @@ import torchvision
4
  from diffusers import I2VGenXLPipeline, DiffusionPipeline
5
  from torchvision.transforms.functional import to_tensor
6
  from PIL import Image
 
7
 
8
  if gr.NO_RELOAD:
9
- n_steps = 50
 
10
  high_noise_frac = 0.8
11
  negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
12
  generator = torch.manual_seed(8888)
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
14
  print("Device:", device)
15
 
16
  base = DiffusionPipeline.from_pretrained(
@@ -41,17 +44,27 @@ def generate(prompt: str, progress=gr.Progress()):
41
  progress((0, 100), desc="Starting..")
42
  image = base(
43
  prompt=prompt,
44
- num_inference_steps=n_steps,
45
  denoising_end=high_noise_frac,
46
  output_type="latent",
47
- callback_on_step_end=lambda p, s, t, d: progress((s, 100), desc="Generating first frame..."),
 
 
 
 
 
48
  ).images[0]
49
  image = refiner(
50
  prompt=prompt,
51
- num_inference_steps=n_steps,
52
  denoising_start=high_noise_frac,
53
  image=image,
54
- callback_on_step_end=lambda p, s, t, d: progress((s+40, 100), desc="Refining first frame..."),
 
 
 
 
 
55
  ).images[0]
56
  image = to_tensor(image)
57
  frames: list[Image.Image] = pipeline(
@@ -62,7 +75,12 @@ def generate(prompt: str, progress=gr.Progress()):
62
  guidance_scale=9.0,
63
  generator=generator,
64
  decode_chunk_size=10,
65
- callback_on_step_end=lambda p, s, t, d: progress((s+50, 100), desc="Generating video..."),
 
 
 
 
 
66
  ).frames[0]
67
  frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
68
  frames = torch.stack(frames)
 
4
  from diffusers import I2VGenXLPipeline, DiffusionPipeline
5
  from torchvision.transforms.functional import to_tensor
6
  from PIL import Image
7
+ from utils import create_progress_updater
8
 
9
  if gr.NO_RELOAD:
10
+ n_sdxl_steps = 50
11
+ n_i2v_steps = 50
12
  high_noise_frac = 0.8
13
  negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
14
  generator = torch.manual_seed(8888)
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ total_steps = n_sdxl_steps + n_i2v_steps
17
  print("Device:", device)
18
 
19
  base = DiffusionPipeline.from_pretrained(
 
44
  progress((0, 100), desc="Starting..")
45
  image = base(
46
  prompt=prompt,
47
+ num_inference_steps=n_sdxl_steps,
48
  denoising_end=high_noise_frac,
49
  output_type="latent",
50
+ callback_on_step_end=create_progress_updater(
51
+ start=0,
52
+ total=total_steps,
53
+ desc="Generating first frame...",
54
+ progress=progress,
55
+ ),
56
  ).images[0]
57
  image = refiner(
58
  prompt=prompt,
59
+ num_inference_steps=n_sdxl_steps,
60
  denoising_start=high_noise_frac,
61
  image=image,
62
+ callback_on_step_end=create_progress_updater(
63
+ start=n_sdxl_steps * high_noise_frac,
64
+ total=total_steps,
65
+ desc="Refining first frame...",
66
+ progress=progress,
67
+ ),
68
  ).images[0]
69
  image = to_tensor(image)
70
  frames: list[Image.Image] = pipeline(
 
75
  guidance_scale=9.0,
76
  generator=generator,
77
  decode_chunk_size=10,
78
+ callback_on_step_end=create_progress_updater(
79
+ start=n_sdxl_steps,
80
+ total=total_steps,
81
+ desc="Generating video...",
82
+ progress=progress,
83
+ ),
84
  ).frames[0]
85
  frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
86
  frames = torch.stack(frames)
utils.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from gradio import Progress
2
+
3
def create_progress_updater(
    start: float,
    total: int,
    desc: str,
    progress: "Progress",
):
    """Build a ``callback_on_step_end`` hook that reports pipeline progress.

    Several pipelines run back to back and share one progress bar, so each
    one gets its own hook with a different ``start`` offset into the
    overall step budget.

    Args:
        start: Number of steps already finished by earlier stages. May be
            a float (e.g. ``n_steps * fraction``); it is rounded to the
            nearest whole step once, when the hook is created.
        total: Total number of steps across all stages.
        desc: Label passed to the progress tracker on every update.
        progress: Callable invoked as ``progress((done, total), desc=desc)``.

    Returns:
        A callable ``updater(pipe, step, timestep, callback_kwargs)`` that
        reports progress and returns ``callback_kwargs`` unchanged.
    """
    # Round once up front so the tracker always receives an integer index,
    # even when the caller derived the offset from a fractional split.
    offset = round(start)

    def updater(pipe, step, timestep, callback_kwargs):
        # The hook fires after a step has finished and ``step`` is assumed
        # to be its zero-based index (diffusers convention), so the number
        # of completed steps is ``step + 1``. Clamp so the bar never goes
        # past 100% if a stage runs longer than budgeted.
        completed = min(offset + step + 1, total)
        progress((completed, total), desc=desc)
        # Hand the kwargs back untouched, as the original hook did.
        return callback_kwargs

    return updater