ihsanvp commited on
Commit
0593b2c
1 Parent(s): 448a859

fix: out of memory error

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -10,6 +10,8 @@ if gr.NO_RELOAD:
10
  high_noise_frac = 0.8
11
  negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
12
  generator = torch.manual_seed(8888)
 
 
13
 
14
  base = DiffusionPipeline.from_pretrained(
15
  "stabilityai/stable-diffusion-xl-base-1.0",
@@ -35,18 +37,21 @@ if gr.NO_RELOAD:
35
  refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
36
  pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
37
 
38
- def generate(prompt: str):
 
39
  image = base(
40
  prompt=prompt,
41
  num_inference_steps=n_steps,
42
  denoising_end=high_noise_frac,
43
  output_type="latent",
 
44
  ).images[0]
45
  image = refiner(
46
  prompt=prompt,
47
  num_inference_steps=n_steps,
48
  denoising_start=high_noise_frac,
49
  image=image,
 
50
  ).images[0]
51
  image = to_tensor(image)
52
  frames: list[Image.Image] = pipeline(
@@ -56,6 +61,8 @@ def generate(prompt: str):
56
  negative_prompt=negative_prompt,
57
  guidance_scale=9.0,
58
  generator=generator,
 
 
59
  ).frames[0]
60
  frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
61
  frames = torch.stack(frames)
 
10
  high_noise_frac = 0.8
11
  negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
12
  generator = torch.manual_seed(8888)
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print("Device:", device)
15
 
16
  base = DiffusionPipeline.from_pretrained(
17
  "stabilityai/stable-diffusion-xl-base-1.0",
 
37
  refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
38
  pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
39
 
40
+ def generate(prompt: str, progress=gr.Progress()):
41
+ progress((0, 100), desc="Starting..")
42
  image = base(
43
  prompt=prompt,
44
  num_inference_steps=n_steps,
45
  denoising_end=high_noise_frac,
46
  output_type="latent",
47
+ callback=lambda s, t, d: progress((s, 100), desc="Generating first frame..."),
48
  ).images[0]
49
  image = refiner(
50
  prompt=prompt,
51
  num_inference_steps=n_steps,
52
  denoising_start=high_noise_frac,
53
  image=image,
54
+ callback=lambda s, t, d: progress((s+40, 100), desc="Refining first frame..."),
55
  ).images[0]
56
  image = to_tensor(image)
57
  frames: list[Image.Image] = pipeline(
 
61
  negative_prompt=negative_prompt,
62
  guidance_scale=9.0,
63
  generator=generator,
64
+ decode_chunk_size=10,
65
+ callback=lambda s, t, d: progress((s+50, 100), desc="Generating video..."),
66
  ).frames[0]
67
  frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
68
  frames = torch.stack(frames)