lev1 commited on
Commit
e882c67
·
verified ·
1 Parent(s): 9c17006

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -29,6 +29,7 @@ args = parser.parse_args()
29
  Path(args.where_to_log).mkdir(parents=True, exist_ok=True)
30
  result_fol = Path(args.where_to_log).absolute()
31
  device = args.device
 
32
 
33
 
34
  # --------------------------
@@ -40,10 +41,10 @@ cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-
40
  # --------------------------
41
  # ----- Initialization -----
42
  # --------------------------
43
- ms_model = init_modelscope(device)
44
  # # zs_model = init_zeroscope(device)
45
- ad_model = init_animatediff(device)
46
- svd_model = init_svd(device)
47
  sdxl_model = init_sdxl(device)
48
 
49
  ckpt_file_streaming_t2v = Path("t2v_enhanced/checkpoints/streaming_t2v.ckpt").absolute()
@@ -73,14 +74,20 @@ def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, se
73
  inference_generator = torch.Generator(device="cuda").manual_seed(seed)
74
 
75
  if model_name_stage1 == "ModelScopeT2V (text to video)":
 
76
  short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
 
77
  elif model_name_stage1 == "AnimateDiff (text to video)":
 
78
  short_video = ad_short_gen(prompt, ad_model, inference_generator, t, device)
 
79
  elif model_name_stage1 == "SVD (image to video)":
80
  # For cached examples
81
  if isinstance(image, dict):
82
  image = image["path"]
 
83
  short_video = svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t, device)
 
84
 
85
  stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, name, stream_cli, stream_model)
86
  video_path = opj(where_to_log, name+".mp4")
 
29
  Path(args.where_to_log).mkdir(parents=True, exist_ok=True)
30
  result_fol = Path(args.where_to_log).absolute()
31
  device = args.device
32
+ device_cpu = "cpu"
33
 
34
 
35
  # --------------------------
 
41
  # --------------------------
42
  # ----- Initialization -----
43
  # --------------------------
44
+ ms_model = init_modelscope(device_cpu)
45
  # # zs_model = init_zeroscope(device)
46
+ ad_model = init_animatediff(device_cpu)
47
+ svd_model = init_svd(device_cpu)
48
  sdxl_model = init_sdxl(device)
49
 
50
  ckpt_file_streaming_t2v = Path("t2v_enhanced/checkpoints/streaming_t2v.ckpt").absolute()
 
74
  inference_generator = torch.Generator(device="cuda").manual_seed(seed)
75
 
76
  if model_name_stage1 == "ModelScopeT2V (text to video)":
77
+ ms_model.to(device)
78
  short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
79
+ ms_model.to(device_cpu)
80
  elif model_name_stage1 == "AnimateDiff (text to video)":
81
+ ad_model.to(device)
82
  short_video = ad_short_gen(prompt, ad_model, inference_generator, t, device)
83
+ ad_model.to(device_cpu)
84
  elif model_name_stage1 == "SVD (image to video)":
85
  # For cached examples
86
  if isinstance(image, dict):
87
  image = image["path"]
88
+ svd_model.to(device)
89
  short_video = svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t, device)
90
+ svd_model.to(device_cpu)
91
 
92
  stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, name, stream_cli, stream_model)
93
  video_path = opj(where_to_log, name+".mp4")