fffiloni commited on
Commit
c02f410
1 Parent(s): 05436b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import os
3
  import torch
4
- from diffusers import CogVideoXImageToVideoPipeline
5
  from diffusers.utils import export_to_video, load_image
6
  from datetime import datetime
7
 
@@ -22,9 +22,14 @@ hf_hub_download(
22
  local_dir="checkpoints"
23
  )
24
 
25
- pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
 
 
 
 
 
26
 
27
- def infer(prompt, image_path, orbit_type):
28
  lora_path = None
29
  if orbit_type == "Left":
30
  lora_path = "checkpoints/orbit_left_lora_weights.safetensors"
 
1
  import gradio as gr
2
  import os
3
  import torch
4
+ from diffusers import CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
5
  from diffusers.utils import export_to_video, load_image
6
  from datetime import datetime
7
 
 
22
  local_dir="checkpoints"
23
  )
24
 
25
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained(
26
+ "THUDM/CogVideoX-5b-I2V",
27
+ transformer=CogVideoXTransformer3DModel.from_pretrained(
28
+ "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
29
+ ),
30
+ torch_dtype=torch.bfloat16)
31
 
32
+ def infer(prompt, image_path, orbit_type, progress=gr.Progress(track_tqdm=True)):
33
  lora_path = None
34
  if orbit_type == "Left":
35
  lora_path = "checkpoints/orbit_left_lora_weights.safetensors"