PeterL1n committed
Commit
f3a1f2e
1 Parent(s): 87d5fe9
Files changed (1)
  1. app.py +33 -20
app.py CHANGED
@@ -11,53 +11,66 @@ from safetensors.torch import load_file
 from PIL import Image
 
 # Constants
-base = "frankjoshua/toonyou_beta6"
-loaded = None
+bases = {
+    "ToonYou": "frankjoshua/toonyou_beta6",
+    "epiCRealism": "emilianJR/epiCRealism"
+}
+step_loaded = None
+base_loaded = "ToonYou"
 
 # Ensure model and scheduler are initialized in GPU-enabled function
-if torch.cuda.is_available():
-    device = "cuda"
-    dtype = torch.float16
-    adapter = MotionAdapter().to(device, dtype)
-    pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
-    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
-else:
+if not torch.cuda.is_available():
     raise NotImplementedError("No GPU detected!")
 
+device = "cuda"
+dtype = torch.float16
+pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
 # Function
 @spaces.GPU(enable_queue=True)
-def generate_image(prompt, step):
-    global loaded
+def generate_image(prompt, base, step):
+    global step_loaded, base_loaded
     print(prompt, step)
 
-    if loaded != step:
+    if step_loaded != step:
         repo = "ByteDance/AnimateDiff-Lightning"
         ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
         pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
-        loaded = step
+        step_loaded = step
+
+    if base_loaded != base:
+        pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
+        base_loaded = base
 
     output = pipe(prompt=prompt, guidance_scale=1.0, num_inference_steps=step)
-
     name = str(uuid.uuid4()).replace("-", "")
     path = f"/tmp/{name}.mp4"
     export_to_video(output.frames[0], path, fps=10)
     return path
 
 
-
 # Gradio Interface
-
 with gr.Blocks(css="style.css") as demo:
     gr.HTML("<h1><center>AnimateDiff-Lightning ⚡</center></h1>")
     gr.HTML("<p><center>Lightning-fast text-to-video generation</center></p><p><center><a href='https://huggingface.co/ByteDance/AnimateDiff-Lightning'>https://huggingface.co/ByteDance/AnimateDiff-Lightning</a></center></p>")
     with gr.Group():
         with gr.Row():
             prompt = gr.Textbox(
-                label='Enter your prompt (English)',
+                label='Prompt (English)',
                 scale=8
             )
-            ckpt = gr.Dropdown(
-                label='Select inference steps',
+            select_base = gr.Dropdown(
+                label='Base model',
+                choices=[
+                    "ToonYou",
+                    "epiCRealism",
+                ],
+                value=base_loaded,
+                interactive=True
+            )
+            select_step = gr.Dropdown(
+                label='Inference steps',
                 choices=[
                     ('1-Step', 1),
                     ('2-Step', 2),
@@ -77,12 +90,12 @@ with gr.Blocks(css="style.css") as demo:
 
     prompt.submit(
         fn=generate_image,
-        inputs=[prompt, ckpt],
+        inputs=[prompt, select_base, select_step],
         outputs=video,
     )
     submit.click(
         fn=generate_image,
-        inputs=[prompt, ckpt],
+        inputs=[prompt, select_base, select_step],
         outputs=video,
     )
 
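The core of this commit is a lazy weight-swap: a single pipeline stays resident on the GPU, and only the UNet's state dict is overwritten when a request asks for a different step count or base model. Below is a minimal standalone sketch of that pattern outside Gradio; the repos, scheduler settings, and `strict=False` loads come from the diff, while `ensure_weights`, the `current_*` globals, and the demo prompt at the end are illustrative names, not part of the app.

```python
# Minimal sketch of the commit's lazy weight-swap pattern (no Gradio).
# `ensure_weights`, `current_step`, `current_base` are illustrative names.
import torch
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

bases = {
    "ToonYou": "frankjoshua/toonyou_beta6",
    "epiCRealism": "emilianJR/epiCRealism",
}
device, dtype = "cuda", torch.float16

# Build the pipeline once; later requests only swap UNet weights in place.
pipe = AnimateDiffPipeline.from_pretrained(bases["ToonYou"], torch_dtype=dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)

current_step, current_base = None, "ToonYou"

def ensure_weights(base: str, step: int) -> None:
    """Reload UNet weights only when the requested step count or base changes."""
    global current_step, current_base
    if current_step != step:
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        path = hf_hub_download("ByteDance/AnimateDiff-Lightning", ckpt)
        # strict=False: the Lightning checkpoint covers only a subset of UNet keys.
        pipe.unet.load_state_dict(load_file(path, device=device), strict=False)
        current_step = step
    if current_base != base:
        path = hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin")
        pipe.unet.load_state_dict(torch.load(path, map_location=device), strict=False)
        current_base = base

# Illustrative usage: swap to epiCRealism with the 4-step checkpoint, then render.
ensure_weights("epiCRealism", 4)
frames = pipe(prompt="a corgi running on the beach",
              guidance_scale=1.0, num_inference_steps=4).frames[0]
export_to_video(frames, "/tmp/demo.mp4", fps=10)
```

Because both loads use `strict=False`, each swap overwrites only the keys present in its checkpoint, which is what lets the step cache and the base cache vary independently without rebuilding the pipeline.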
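On the UI side, the step dropdown relies on Gradio's (label, value) tuple form for `choices`, so the callback receives the integer step count rather than the display string. A tiny self-contained demo of that behavior, assuming a hypothetical echo app that is not part of the commit:

```python
# Demo: gr.Dropdown (label, value) tuples deliver the value to the callback.
import gradio as gr

with gr.Blocks() as demo:
    step = gr.Dropdown(
        label="Inference steps",
        choices=[("1-Step", 1), ("2-Step", 2), ("4-Step", 4), ("8-Step", 8)],
        value=4,
    )
    out = gr.Textbox(label="Callback received")
    # The handler gets the int 4, not the string "4-Step".
    step.change(fn=lambda s: f"{s!r} ({type(s).__name__})", inputs=step, outputs=out)

demo.launch()
```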