Files changed (1)
  1. app.py +46 -44
app.py CHANGED
@@ -21,38 +21,44 @@ step_loaded = None
 base_loaded = "Realistic"
 motion_loaded = None
 
-# Ensure model and scheduler are initialized in GPU-enabled function
+# Ensure GPU availability
 if not torch.cuda.is_available():
     raise NotImplementedError("No GPU detected!")
 
 device = "cuda"
 dtype = torch.float16
+
+# Load initial pipeline
+print("Loading AnimateDiff pipeline...")
 pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+print("Pipeline loaded successfully.")
 
 # Safety checkers
 from transformers import CLIPFeatureExtractor
-
 feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
 
-# Function
-@spaces.GPU(duration=30,queue=False)
+# Video Generation Function
+@spaces.GPU(duration=30, queue=False)
 def generate_image(prompt, base="Realistic", motion="", step=8, progress=gr.Progress()):
     global step_loaded
     global base_loaded
     global motion_loaded
-    print(prompt, base, step)
+    print(f"Generating video for: Prompt='{prompt}', Base='{base}', Motion='{motion}', Steps='{step}'")
 
+    # Load step-specific model
     if step_loaded != step:
         repo = "ByteDance/AnimateDiff-Lightning"
         ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
         pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
         step_loaded = step
 
+    # Load base model
     if base_loaded != base:
         pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
         base_loaded = base
 
+    # Load motion adapter
     if motion_loaded != motion:
         pipe.unload_lora_weights()
         if motion != "":
@@ -60,37 +66,44 @@ def generate_image(prompt, base="Realistic", motion="", step=8, progress=gr.Prog
             pipe.set_adapters(["motion"], [0.7])
         motion_loaded = motion
 
+    # Video parameters: 30-second duration
+    fps = 10
+    duration = 30  # seconds
+    total_frames = fps * duration  # 300 frames for 30s at 10 FPS
+
     progress((0, step))
-    def progress_callback(i, t, z):
-        progress((i+1, step))
 
-    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1)
+    def progress_callback(i, t, z):
+        progress((i + 1, step))
+
+    # Generate video frames
+    output_frames = []
+    for frame in range(total_frames):
+        output = pipe(
+            prompt=prompt,
+            guidance_scale=1.2,
+            num_inference_steps=step,
+            callback=progress_callback,
+            callback_steps=1
+        )
+        output_frames.extend(output.frames[0])  # Collect frames
 
+    # Export to video
     name = str(uuid.uuid4()).replace("-", "")
     path = f"/tmp/{name}.mp4"
-    export_to_video(output.frames[0], path, fps=10)
+    export_to_video(output_frames, path, fps=fps)
     return path
 
-
 # Gradio Interface
 with gr.Blocks(css="style.css") as demo:
-    gr.HTML(
-        "<h1><center>Textual Imagination : A Text To Video Synthesis</center></h1>"
-    )
+    gr.HTML("<h1><center>Textual Imagination: A Text To Video Synthesis</center></h1>")
     with gr.Group():
         with gr.Row():
-            prompt = gr.Textbox(
-                label='Prompt'
-            )
+            prompt = gr.Textbox(label='Prompt', placeholder="Enter your video description here...")
         with gr.Row():
             select_base = gr.Dropdown(
                 label='Base model',
-                choices=[
-                    "Cartoon",
-                    "Realistic",
-                    "3d",
-                    "Anime",
-                ],
+                choices=["Cartoon", "Realistic", "3d", "Anime"],
                 value=base_loaded,
                 interactive=True
             )
@@ -112,38 +125,27 @@ with gr.Blocks(css="style.css") as demo:
             )
             select_step = gr.Dropdown(
                 label='Inference steps',
-                choices=[
-                    ('1-Step', 1),
-                    ('2-Step', 2),
-                    ('4-Step', 4),
-                    ('8-Step', 8),
-                ],
+                choices=[('1-Step', 1), ('2-Step', 2), ('4-Step', 4), ('8-Step', 8)],
                 value=4,
                 interactive=True
             )
-            submit = gr.Button(
-                scale=1,
-                variant='primary'
-            )
+            submit = gr.Button(scale=1, variant='primary')
+
     video = gr.Video(
-        label='AnimateDiff-Lightning',
+        label='Generated Video',
         autoplay=True,
         height=512,
         width=512,
         elem_id="video_output"
     )
 
-    gr.on(triggers=[
-        submit.click,
-        prompt.submit
-    ],
-    fn = generate_image,
-    inputs = [prompt, select_base, select_motion, select_step],
-    outputs = [video],
-    api_name = "instant_video",
-    queue = False
+    gr.on(
+        triggers=[submit.click, prompt.submit],
+        fn=generate_image,
+        inputs=[prompt, select_base, select_motion, select_step],
+        outputs=[video],
+        api_name="instant_video",
+        queue=False
     )
 
 demo.queue().launch()
-
-Translate
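
For reference, a minimal sketch of how the endpoint registered above via `gr.on(..., api_name="instant_video")` could be called remotely with `gradio_client` once the Space is running. The Space ID and the example inputs below are placeholders, not part of this change, and the exact return shape depends on the gradio_client version.

```python
# Hypothetical client-side call; "user/space-name" is a placeholder Space ID.
from gradio_client import Client

client = Client("user/space-name")
result = client.predict(
    "a rocket launching into space, cinematic",  # prompt
    "Realistic",                                 # base model
    "",                                          # motion LoRA ("" = none)
    4,                                           # inference steps
    api_name="/instant_video",
)
print(result)  # path to the downloaded .mp4 (shape may vary by gradio_client version)
```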