Haoxin Chen commited on
Commit
15190a9
1 Parent(s): a8f3a29

fix ckpt path

Browse files
Files changed (2) hide show
  1. i2v_test.py +1 -1
  2. t2v_test.py +8 -8
i2v_test.py CHANGED
@@ -68,7 +68,7 @@ class Image2Video():
68
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
69
 
70
  def download_model(self):
71
- REPO_ID = 'VideoCrafter/Image2Video-512-v1.0'
72
  filename_list = ['model.ckpt']
73
  if not os.path.exists('./checkpoints/i2v_512_v1/'):
74
  os.makedirs('./checkpoints/i2v_512_v1/')
 
68
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
69
 
70
  def download_model(self):
71
+ REPO_ID = 'VideoCrafter/Image2Video-512'
72
  filename_list = ['model.ckpt']
73
  if not os.path.exists('./checkpoints/i2v_512_v1/'):
74
  os.makedirs('./checkpoints/i2v_512_v1/')
t2v_test.py CHANGED
@@ -12,8 +12,8 @@ class Text2Video():
12
  self.result_dir = result_dir
13
  if not os.path.exists(self.result_dir):
14
  os.mkdir(self.result_dir)
15
- ckpt_path='checkpoints/base_512_v1/model.ckpt'
16
- config_file='configs/inference_t2v_512_v1.0.yaml'
17
  config = OmegaConf.load(config_file)
18
  model_config = config.pop("model", OmegaConf.create())
19
  model_config['params']['unet_config']['params']['use_checkpoint']=False
@@ -39,7 +39,7 @@ class Text2Video():
39
  batch_size=1
40
  channels = model.model.diffusion_model.in_channels
41
  frames = model.temporal_length
42
- h, w = 320 // 8, 512 // 8
43
  noise_shape = [batch_size, channels, frames, h, w]
44
 
45
  #prompts = batch_size * [""]
@@ -59,15 +59,15 @@ class Text2Video():
59
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
60
 
61
  def download_model(self):
62
- REPO_ID = 'VideoCrafter/Text2Video-512-v1'
63
  filename_list = ['model.ckpt']
64
- if not os.path.exists('./checkpoints/base_512_v1/'):
65
- os.makedirs('./checkpoints/base_512_v1/')
66
  for filename in filename_list:
67
- local_file = os.path.join('./checkpoints/base_512_v1/', filename)
68
 
69
  if not os.path.exists(local_file):
70
- hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/base_512_v1/', local_dir_use_symlinks=False)
71
 
72
 
73
  if __name__ == '__main__':
 
12
  self.result_dir = result_dir
13
  if not os.path.exists(self.result_dir):
14
  os.mkdir(self.result_dir)
15
+ ckpt_path='checkpoints/base_1024_v1/model.ckpt'
16
+ config_file='configs/inference_t2v_1024_v1.0.yaml'
17
  config = OmegaConf.load(config_file)
18
  model_config = config.pop("model", OmegaConf.create())
19
  model_config['params']['unet_config']['params']['use_checkpoint']=False
 
39
  batch_size=1
40
  channels = model.model.diffusion_model.in_channels
41
  frames = model.temporal_length
42
+ h, w = 576 // 8, 1024 // 8
43
  noise_shape = [batch_size, channels, frames, h, w]
44
 
45
  #prompts = batch_size * [""]
 
59
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
60
 
61
  def download_model(self):
62
+ REPO_ID = 'VideoCrafter/Text2Video-1024'
63
  filename_list = ['model.ckpt']
64
+ if not os.path.exists('./checkpoints/base_1024_v1/'):
65
+ os.makedirs('./checkpoints/base_1024_v1/')
66
  for filename in filename_list:
67
+ local_file = os.path.join('./checkpoints/base_1024_v1/', filename)
68
 
69
  if not os.path.exists(local_file):
70
+ hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/base_1024_v1/', local_dir_use_symlinks=False)
71
 
72
 
73
  if __name__ == '__main__':