hpoghos committed
Commit 2c5b700
1 Parent(s): 501ae46
Files changed (2)
  1. app.py +3 -2
  2. t2v_enhanced/model_init.py +1 -0
app.py CHANGED
@@ -76,16 +76,17 @@ def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, se
 
     n_autoreg_gen = (num_frames-8)//8
 
-    inference_generator = torch.Generator(device="cuda").manual_seed(seed)
-
     if model_name_stage1 == "ModelScopeT2V (text to video)":
+        inference_generator = torch.Generator(device=ms_model.device).manual_seed(seed)
         short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
     elif model_name_stage1 == "AnimateDiff (text to video)":
+        inference_generator = torch.Generator(device=ad_model.device).manual_seed(seed)
         short_video = ad_short_gen(prompt, ad_model, inference_generator, t, device)
     elif model_name_stage1 == "SVD (image to video)":
         # For cached examples
         if isinstance(image, dict):
             image = image["path"]
+        inference_generator = torch.Generator(device=svd_model.device).manual_seed(seed)
         short_video = svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t, device)
 
     stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, name, stream_cli, stream_model)
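This change moves the seeded torch.Generator out of the shared preamble and into each stage-1 branch, so the RNG is created on the device of the model that actually consumes it (ms_model, ad_model, or svd_model) rather than on a hard-coded "cuda". A minimal sketch of the pattern, assuming each model exposes a .device attribute as the diff does; seeded_generator is a hypothetical helper, not part of the repo:

import torch

def seeded_generator(model, seed: int) -> torch.Generator:
    # Build the RNG on the same device the model samples on; mixing a
    # CUDA generator with CPU tensors (or vice versa) raises a runtime
    # device-mismatch error inside the diffusion pipelines.
    return torch.Generator(device=model.device).manual_seed(seed)

# Hypothetical usage mirroring one branch of generate():
# inference_generator = seeded_generator(ms_model, seed)
# short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)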
t2v_enhanced/model_init.py CHANGED
@@ -119,6 +119,7 @@ def init_v2v_model(cfg, device):
 
     pipe_enhance.model.autoencoder = pipe_enhance.model.autoencoder.to(device)
     pipe_enhance.model.generator = pipe_enhance.model.generator.to(device)
+    pipe_enhance.model.generator = pipe_enhance.model.generator.half()
     pipe_enhance.model.negative_y = pipe_enhance.model.negative_y.to(device)
     pipe_enhance.model.cfg.max_frames = 10000
     return pipe_enhance
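This change casts the enhancement pipeline's generator to float16 immediately after moving it to the target device, roughly halving that module's VRAM footprint; the autoencoder and negative_y keep their original precision. A minimal sketch of the same idea for any torch.nn.Module; load_in_half is a hypothetical helper, and it assumes the surrounding pipeline feeds the module float16 inputs:

import torch.nn as nn

def load_in_half(module: nn.Module, device: str) -> nn.Module:
    # .to(device) moves parameters and buffers; .half() then converts
    # them to torch.float16 in place and returns the module, matching
    # the one-liner added in this commit.
    return module.to(device).half()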