ShaoTengLiu commited on
Commit
e3712a5
1 Parent(s): f5c12d4
Files changed (1) hide show
  1. trainer.py +10 -5
trainer.py CHANGED
@@ -11,6 +11,7 @@ import sys
11
  import gradio as gr
12
  import slugify
13
  import torch
 
14
  from huggingface_hub import HfApi
15
  from omegaconf import OmegaConf
16
 
@@ -33,16 +34,20 @@ class Trainer:
33
  self.checkpoint_dir = pathlib.Path('checkpoints')
34
  self.checkpoint_dir.mkdir(exist_ok=True)
35
 
36
- def download_base_model(self, base_model_id: str) -> str:
37
  model_dir = self.checkpoint_dir / base_model_id
38
  if not model_dir.exists():
39
  org_name = base_model_id.split('/')[0]
40
  org_dir = self.checkpoint_dir / org_name
41
  org_dir.mkdir(exist_ok=True)
42
  print(f'https://huggingface.co/{base_model_id}')
43
- subprocess.run(shlex.split(
44
- f'git clone https://huggingface.co/{base_model_id}'),
45
- cwd=org_dir)
 
 
 
 
46
  return model_dir.as_posix()
47
 
48
  def join_model_library_org(self, token: str) -> None:
@@ -241,7 +246,7 @@ class Trainer:
241
  self.hf_token if self.hf_token else input_token)
242
 
243
  config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
244
- config.pretrained_model_path = self.download_base_model(tuned_model)
245
  config.output_dir = output_dir.as_posix()
246
  config.train_data.video_path = training_video.name # type: ignore
247
  config.train_data.prompt = training_prompt
 
11
  import gradio as gr
12
  import slugify
13
  import torch
14
+ import huggingface_hub
15
  from huggingface_hub import HfApi
16
  from omegaconf import OmegaConf
17
 
 
34
  self.checkpoint_dir = pathlib.Path('checkpoints')
35
  self.checkpoint_dir.mkdir(exist_ok=True)
36
 
37
+ def download_base_model(self, base_model_id: str, token=None) -> str:
38
  model_dir = self.checkpoint_dir / base_model_id
39
  if not model_dir.exists():
40
  org_name = base_model_id.split('/')[0]
41
  org_dir = self.checkpoint_dir / org_name
42
  org_dir.mkdir(exist_ok=True)
43
  print(f'https://huggingface.co/{base_model_id}')
44
+ try:
45
+ subprocess.run(shlex.split(
46
+ f'git clone https://huggingface.co/{base_model_id}'),
47
+ cwd=org_dir)
48
+ except:
49
+ temp_path = huggingface_hub.snapshot_download(base_model_id, use_auth_token=token)
50
+ subprocess.run(shlex.split(f'mv {temp_path} {org_dir}'))
51
  return model_dir.as_posix()
52
 
53
  def join_model_library_org(self, token: str) -> None:
 
246
  self.hf_token if self.hf_token else input_token)
247
 
248
  config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
249
+ config.pretrained_model_path = self.download_base_model(tuned_model, token=input_token)
250
  config.output_dir = output_dir.as_posix()
251
  config.train_data.video_path = training_video.name # type: ignore
252
  config.train_data.prompt = training_prompt