Commit 5eb79d4
Committed by hysts (HF staff)
1 parent: bd420fb
Files changed (4):
  1. app_training.py +4 -4
  2. app_upload.py +6 -6
  3. trainer.py +4 -5
  4. uploader.py +2 -2
app_training.py CHANGED
@@ -48,9 +48,9 @@ def create_training_demo(trainer: Trainer,
                                  label='Resolution',
                                  visible=False)
 
-        input_token = gr.Text(label='Hugging Face Write Token',
-                              placeholder='',
-                              visible=False if hf_token else True)
+        input_hf_token = gr.Text(label='Hugging Face Write Token',
+                                 placeholder='',
+                                 visible=hf_token is None)
         with gr.Accordion('Advanced settings', open=False):
             num_training_steps = gr.Number(
                 label='Number of Training Steps',
@@ -150,7 +150,7 @@ def create_training_demo(trainer: Trainer,
             delete_existing_repo,
             upload_to,
             remove_gpu_after_training,
-            input_token,
+            input_hf_token,
         ])
     return demo
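On the training tab the change is not just a rename: the token field's visibility now reads visible=hf_token is None instead of visible=False if hf_token else True. A minimal standalone sketch of that pattern (the HF_TOKEN lookup and the surrounding layout here are illustrative, not the Space's actual code):

    import os

    import gradio as gr

    hf_token = os.getenv('HF_TOKEN')  # None when no write token is configured for the Space

    with gr.Blocks() as demo:
        # The textbox is rendered only when no token is available from the
        # environment, so users of a duplicated Space can paste their own.
        input_hf_token = gr.Text(label='Hugging Face Write Token',
                                 placeholder='',
                                 visible=hf_token is None)

    if __name__ == '__main__':
        demo.launch()

The two expressions agree except when HF_TOKEN is set to an empty string: the old conditional would still show the field, while hf_token is None keeps it hidden.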
app_upload.py CHANGED
@@ -20,7 +20,7 @@ class ModelUploader(Uploader):
         upload_to: str,
         private: bool,
         delete_existing_repo: bool,
-        input_token: str | None = None,
+        input_hf_token: str | None = None,
     ) -> str:
         if not folder_path:
             raise ValueError
@@ -40,7 +40,7 @@ class ModelUploader(Uploader):
             organization=organization,
             private=private,
             delete_existing_repo=delete_existing_repo,
-            input_token=input_token)
+            input_hf_token=input_hf_token)
 
 
 def load_local_model_list() -> dict:
@@ -70,9 +70,9 @@ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
                              choices=[_.value for _ in UploadTarget],
                              value=UploadTarget.MODEL_LIBRARY.value)
         model_name = gr.Textbox(label='Model Name')
-        input_token = gr.Text(label='Hugging Face Write Token',
-                              placeholder='',
-                              visible=False if hf_token else True)
+        input_hf_token = gr.Text(label='Hugging Face Write Token',
+                                 placeholder='',
+                                 visible=False if hf_token else True)
         upload_button = gr.Button('Upload')
         gr.Markdown(f'''
     - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
@@ -91,7 +91,7 @@ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
             upload_to,
             use_private_repo,
             delete_existing_repo,
-            input_token,
+            input_hf_token,
         ],
         outputs=output_message)
 
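The renamed textbox only reaches ModelUploader.upload because it sits in the button's inputs list; Gradio passes component values positionally, so renaming the matching Python parameter is for readability rather than correctness. A rough sketch of that wiring with a stub callback (not the Space's real uploader):

    import gradio as gr


    def fake_upload(model_name: str, input_hf_token: str) -> str:
        # Stand-in for ModelUploader.upload: just report which token would be used.
        token_source = 'pasted token' if input_hf_token else 'token from the environment'
        return f'Would upload {model_name!r} using the {token_source}.'


    with gr.Blocks() as demo:
        model_name = gr.Textbox(label='Model Name')
        input_hf_token = gr.Text(label='Hugging Face Write Token', placeholder='')
        upload_button = gr.Button('Upload')
        output_message = gr.Textbox(label='Output')
        # The order of `inputs` determines the order of the callback's arguments.
        upload_button.click(fn=fake_upload,
                            inputs=[model_name, input_hf_token],
                            outputs=output_message)

    if __name__ == '__main__':
        demo.launch()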
trainer.py CHANGED
@@ -72,7 +72,7 @@ class Trainer:
         delete_existing_repo: bool,
         upload_to: str,
         remove_gpu_after_training: bool,
-        input_token: str,
+        input_hf_token: str,
     ) -> None:
         if not torch.cuda.is_available():
             raise gr.Error('CUDA is not available.')
@@ -98,7 +98,7 @@ class Trainer:
 
         if upload_to_hub:
             self.join_model_library_org(
-                self.hf_token if self.hf_token else input_token)
+                self.hf_token if self.hf_token else input_hf_token)
 
         config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
         config.pretrained_model_path = self.download_base_model(base_model)
@@ -152,14 +152,13 @@ class Trainer:
             upload_to=upload_to,
             private=use_private_repo,
             delete_existing_repo=delete_existing_repo,
-            input_token=input_token)
+            input_hf_token=input_hf_token)
         with open(self.log_file, 'a') as f:
             f.write(upload_message)
 
         if remove_gpu_after_training:
             space_id = os.getenv('SPACE_ID')
             if space_id:
-                api = HfApi(
-                    token=self.hf_token if self.hf_token else input_token)
+                api = HfApi(token=self.hf_token or input_hf_token)
                 api.request_space_hardware(repo_id=space_id,
                                            hardware='cpu-basic')
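
The last hunk collapses the two-line HfApi(...) construction into one and keeps the same fallback: prefer the token configured on the Space, otherwise use the one typed into the form. A standalone sketch of the post-training hardware downgrade (the release_gpu helper is illustrative; HfApi.request_space_hardware and the SPACE_ID variable are real huggingface_hub / Spaces features):

    import os

    from huggingface_hub import HfApi


    def release_gpu(hf_token: str | None, input_hf_token: str | None) -> None:
        # SPACE_ID is set automatically inside a running Space.
        space_id = os.getenv('SPACE_ID')
        if not space_id:
            return
        # Prefer the Space's own write token, fall back to the user-supplied one.
        api = HfApi(token=hf_token or input_hf_token)
        # Switch the Space back to free CPU hardware once training has finished.
        api.request_space_hardware(repo_id=space_id, hardware='cpu-basic')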
uploader.py CHANGED
@@ -14,9 +14,9 @@ class Uploader:
                repo_type: str = 'model',
                private: bool = True,
                delete_existing_repo: bool = False,
-               input_token: str | None = None) -> str:
+               input_hf_token: str | None = None) -> str:
 
-        api = HfApi(token=self.hf_token if self.hf_token else input_token)
+        api = HfApi(token=self.hf_token or input_hf_token)
 
         if not folder_path:
             raise ValueError
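
In uploader.py (and in the trainer's hardware call above) x if x else y is replaced with x or y; the two are equivalent here, since or returns its first operand when it is truthy and the second otherwise, so the rewrite is purely a readability change. A quick check with made-up token values:

    cases = [('space-token', None), (None, 'user-token'), ('', 'user-token'), (None, None)]
    for hf_token, input_hf_token in cases:
        old_style = hf_token if hf_token else input_hf_token
        new_style = hf_token or input_hf_token
        assert old_style == new_style
        print(f'{hf_token!r} / {input_hf_token!r} -> {new_style!r}')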