hpoghos commited on
Commit
38f7b32
1 Parent(s): da5ac73

Update t2v_enhanced/model_init.py

Browse files
Files changed (1) hide show
  1. t2v_enhanced/model_init.py +14 -3
t2v_enhanced/model_init.py CHANGED
@@ -106,7 +106,18 @@ def init_streamingt2v_model(ckpt_file, result_fol):
106
 
107
 
108
  # Initialize Stage-3 model.
109
- def init_v2v_model(cfg):
110
  model_id = cfg['model_id']
111
- pipe_enhance = pipeline(task="video-to-video", model=model_id, model_revision='v1.1.0', device='cuda')
112
- return pipe_enhance
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
  # Initialize Stage-3 model.
109
+ def init_v2v_model(cfg, device):
110
  model_id = cfg['model_id']
111
+ pipe_enhance = pipeline(task="video-to-video", model=model_id, model_revision='v1.1.0', device='cpu')
112
+ pipe_enhance.device = device
113
+
114
+ pipe_enhance.model = pipe_enhance.model.to(device)
115
+ pipe_enhance.model.device = device
116
+
117
+ pipe_enhance.model.clip_encoder.model = pipe_enhance.model.clip_encoder.model.to(device)
118
+ pipe_enhance.model.clip_encoder.device = device
119
+
120
+ pipe_enhance.model.autoencoder = pipe_enhance.model.autoencoder.to(device)
121
+ pipe_enhance.model.generator = pipe_enhance.model.generator.to(device)
122
+ pipe_enhance.model.negative_y = pipe_enhance.model.negative_y.to(device)
123
+ return pipe_enhance