wangfuyun commited on
Commit
84b1cf9
1 Parent(s): 0cf1eff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -149,7 +149,7 @@ class AnimateController:
149
  for key in f.keys():
150
  self.lora_model_state_dict[key] = f.get_tensor(key)
151
  return gr.Dropdown.update()
152
-
153
  def animate(
154
  self,
155
  lora_alpha_slider,
@@ -174,8 +174,8 @@ class AnimateController:
174
  **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
175
  ).to("cuda")
176
 
177
- pipeline.unet = convert_lcm_lora(copy.deepcopy(
178
- self.unet), self.lcm_lora_path, spatial_lora_slider)
179
 
180
  pipeline.to("cuda")
181
 
@@ -185,15 +185,19 @@ class AnimateController:
185
  torch.seed()
186
  seed = torch.initial_seed()
187
 
188
- sample = pipeline(
189
- prompt_textbox,
190
- negative_prompt=negative_prompt_textbox,
191
- num_inference_steps=sample_step_slider,
192
- guidance_scale=cfg_scale_slider,
193
- width=width_slider,
194
- height=height_slider,
195
- video_length=length_slider,
196
- ).videos
 
 
 
 
197
 
198
  save_sample_path = os.path.join(
199
  self.savedir_sample, f"{sample_idx}.mp4")
 
149
  for key in f.keys():
150
  self.lora_model_state_dict[key] = f.get_tensor(key)
151
  return gr.Dropdown.update()
152
+ @torch.no_grad()
153
  def animate(
154
  self,
155
  lora_alpha_slider,
 
174
  **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
175
  ).to("cuda")
176
 
177
+ original_state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}
178
+ pipeline.unet = convert_lcm_lora(pipeline.unet, self.lcm_lora_path, spatial_lora_slider)
179
 
180
  pipeline.to("cuda")
181
 
 
185
  torch.seed()
186
  seed = torch.initial_seed()
187
 
188
+ with torch.autocast("cuda"):
189
+ sample = pipeline(
190
+ prompt_textbox,
191
+ negative_prompt=negative_prompt_textbox,
192
+ num_inference_steps=sample_step_slider,
193
+ guidance_scale=cfg_scale_slider,
194
+ width=width_slider,
195
+ height=height_slider,
196
+ video_length=length_slider,
197
+ ).videos
198
+
199
+ pipeline.unet.load(original_state_dict)
200
+ del original_state_dict
201
 
202
  save_sample_path = os.path.join(
203
  self.savedir_sample, f"{sample_idx}.mp4")