multimodalart HF staff commited on
Commit
cfadc6d
1 Parent(s): c9feef4

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +12 -7
trainer.py CHANGED
@@ -20,12 +20,12 @@ from utils import save_model_card
20
  sys.path.append('Tune-A-Video')
21
 
22
  URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
-
 
24
 
25
  class Trainer:
26
  def __init__(self, hf_token: str | None = None):
27
  self.hf_token = hf_token
28
- self.api = HfApi(token=hf_token)
29
  self.model_uploader = ModelUploader(hf_token)
30
 
31
  self.checkpoint_dir = pathlib.Path('checkpoints')
@@ -42,10 +42,10 @@ class Trainer:
42
  cwd=org_dir)
43
  return model_dir.as_posix()
44
 
45
- def join_model_library_org(self) -> None:
46
  subprocess.run(
47
  shlex.split(
48
- f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
49
  ))
50
 
51
  def run(
@@ -70,7 +70,10 @@ class Trainer:
70
  delete_existing_repo: bool,
71
  upload_to: str,
72
  remove_gpu_after_training: bool,
 
73
  ) -> str:
 
 
74
  if not torch.cuda.is_available():
75
  raise gr.Error('CUDA is not available.')
76
  if training_video is None:
@@ -143,14 +146,16 @@ class Trainer:
143
  repo_name=output_model_name,
144
  upload_to=upload_to,
145
  private=use_private_repo,
146
- delete_existing_repo=delete_existing_repo)
 
147
  print(upload_message)
148
  message = message + '\n' + upload_message
149
 
150
  if remove_gpu_after_training:
151
  space_id = os.getenv('SPACE_ID')
152
  if space_id:
153
- self.api.request_space_hardware(repo_id=space_id,
154
- hardware='cpu-basic')
 
155
 
156
  return message
 
20
  sys.path.append('Tune-A-Video')
21
 
22
  URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
+ ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
24
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
25
 
26
  class Trainer:
27
  def __init__(self, hf_token: str | None = None):
28
  self.hf_token = hf_token
 
29
  self.model_uploader = ModelUploader(hf_token)
30
 
31
  self.checkpoint_dir = pathlib.Path('checkpoints')
 
42
  cwd=org_dir)
43
  return model_dir.as_posix()
44
 
45
+ def join_model_library_org(token) -> None:
46
  subprocess.run(
47
  shlex.split(
48
+ f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
49
  ))
50
 
51
  def run(
 
70
  delete_existing_repo: bool,
71
  upload_to: str,
72
  remove_gpu_after_training: bool,
73
+ input_token: str,
74
  ) -> str:
75
+ if SPACE_ID != ORIGINAL_SPACE_ID:
76
+ raise gr.Error('This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU')
77
  if not torch.cuda.is_available():
78
  raise gr.Error('CUDA is not available.')
79
  if training_video is None:
 
146
  repo_name=output_model_name,
147
  upload_to=upload_to,
148
  private=use_private_repo,
149
+ delete_existing_repo=delete_existing_repo,
150
+ input_token=input_token)
151
  print(upload_message)
152
  message = message + '\n' + upload_message
153
 
154
  if remove_gpu_after_training:
155
  space_id = os.getenv('SPACE_ID')
156
  if space_id:
157
+ api = HfApi(token=hf_token if self.hf_token else input_token)
158
+ api.request_space_hardware(repo_id=space_id,
159
+ hardware='cpu-basic')
160
 
161
  return message