hysts HF staff commited on
Commit
5a1aaa1
1 Parent(s): 07f8fd9

Pause Space

Browse files
Files changed (2) hide show
  1. app_training.py +3 -3
  2. trainer.py +4 -6
app_training.py CHANGED
@@ -105,8 +105,8 @@ def create_training_demo(trainer: Trainer,
105
  choices=[_.value for _ in UploadTarget],
106
  value=UploadTarget.MODEL_LIBRARY.value)
107
 
108
- remove_gpu_after_training = gr.Checkbox(
109
- label='Remove GPU after training',
110
  value=False,
111
  interactive=bool(os.getenv('SPACE_ID')),
112
  visible=False)
@@ -143,7 +143,7 @@ def create_training_demo(trainer: Trainer,
143
  use_private_repo,
144
  delete_existing_repo,
145
  upload_to,
146
- remove_gpu_after_training,
147
  hf_token,
148
  ])
149
  return demo
 
105
  choices=[_.value for _ in UploadTarget],
106
  value=UploadTarget.MODEL_LIBRARY.value)
107
 
108
+ pause_space_after_training = gr.Checkbox(
109
+ label='Pause this Space after training',
110
  value=False,
111
  interactive=bool(os.getenv('SPACE_ID')),
112
  visible=False)
 
143
  use_private_repo,
144
  delete_existing_repo,
145
  upload_to,
146
+ pause_space_after_training,
147
  hf_token,
148
  ])
149
  return demo
trainer.py CHANGED
@@ -60,7 +60,7 @@ class Trainer:
60
  use_private_repo: bool,
61
  delete_existing_repo: bool,
62
  upload_to: str,
63
- remove_gpu_after_training: bool,
64
  hf_token: str,
65
  ) -> None:
66
  if not torch.cuda.is_available():
@@ -140,9 +140,7 @@ class Trainer:
140
  with open(self.log_file, 'a') as f:
141
  f.write(upload_message)
142
 
143
- if remove_gpu_after_training:
144
- space_id = os.getenv('SPACE_ID')
145
- if space_id:
146
  api = HfApi(token=os.getenv('HF_TOKEN') or hf_token)
147
- api.request_space_hardware(repo_id=space_id,
148
- hardware='cpu-basic')
 
60
  use_private_repo: bool,
61
  delete_existing_repo: bool,
62
  upload_to: str,
63
+ pause_space_after_training: bool,
64
  hf_token: str,
65
  ) -> None:
66
  if not torch.cuda.is_available():
 
140
  with open(self.log_file, 'a') as f:
141
  f.write(upload_message)
142
 
143
+ if pause_space_after_training:
144
+ if space_id := os.getenv('SPACE_ID'):
 
145
  api = HfApi(token=os.getenv('HF_TOKEN') or hf_token)
146
+ api.pause_space(repo_id=space_id)