hysts commited on
Commit
700bc3e
1 Parent(s): 3e6e590

Enable to remove GPU after model training is done

Browse files
Files changed (4) hide show
  1. app.py +1 -1
  2. app_training.py +6 -3
  3. requirements.txt +1 -1
  4. trainer.py +13 -2
app.py CHANGED
@@ -43,7 +43,7 @@ def show_warning(warning_text: str) -> gr.Blocks:
43
 
44
 
45
  pipe = InferencePipeline(HF_TOKEN)
46
- trainer = Trainer()
47
 
48
  with gr.Blocks(css='style.css') as demo:
49
  if os.getenv('IS_SHARED_UI'):
 
43
 
44
 
45
  pipe = InferencePipeline(HF_TOKEN)
46
+ trainer = Trainer(HF_TOKEN)
47
 
48
  with gr.Blocks(css='style.css') as demo:
49
  if os.getenv('IS_SHARED_UI'):
app_training.py CHANGED
@@ -92,9 +92,10 @@ def create_training_demo(trainer: Trainer,
92
  - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
93
  ''')
94
 
95
- # TODO currently disabled
96
  remove_gpu_after_training = gr.Checkbox(
97
- label='Remove GPU after training', value=False, interactive=False)
 
 
98
  run_button = gr.Button('Start Training')
99
 
100
  with gr.Box():
@@ -125,12 +126,14 @@ def create_training_demo(trainer: Trainer,
125
  use_private_repo,
126
  delete_existing_repo,
127
  upload_to,
 
128
  ],
129
  outputs=output_message)
130
  return demo
131
 
132
 
133
  if __name__ == '__main__':
134
- trainer = Trainer()
 
135
  demo = create_training_demo(trainer)
136
  demo.queue(max_size=1).launch(share=False)
 
92
  - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
93
  ''')
94
 
 
95
  remove_gpu_after_training = gr.Checkbox(
96
+ label='Remove GPU after training',
97
+ value=False,
98
+ interactive=bool(os.getenv('SPACE_ID')))
99
  run_button = gr.Button('Start Training')
100
 
101
  with gr.Box():
 
126
  use_private_repo,
127
  delete_existing_repo,
128
  upload_to,
129
+ remove_gpu_after_training,
130
  ],
131
  outputs=output_message)
132
  return demo
133
 
134
 
135
  if __name__ == '__main__':
136
+ hf_token = os.getenv('HF_TOKEN')
137
+ trainer = Trainer(hf_token)
138
  demo = create_training_demo(trainer)
139
  demo.queue(max_size=1).launch(share=False)
requirements.txt CHANGED
@@ -4,7 +4,7 @@ datasets==2.8.0
4
  git+https://github.com/huggingface/diffusers@febaf863026bd014b7a14349336544fc109d0f57#egg=diffusers
5
  ftfy==6.1.1
6
  gradio==3.14.0
7
- huggingface-hub==0.11.1
8
  Pillow==9.4.0
9
  python-slugify==7.0.0
10
  tensorboard==2.11.2
 
4
  git+https://github.com/huggingface/diffusers@febaf863026bd014b7a14349336544fc109d0f57#egg=diffusers
5
  ftfy==6.1.1
6
  gradio==3.14.0
7
+ git+https://github.com/huggingface/huggingface_hub@bdb9d06b5e67269d702860ca60e1cdb106a66c91#egg=huggingface-hub
8
  Pillow==9.4.0
9
  python-slugify==7.0.0
10
  tensorboard==2.11.2
trainer.py CHANGED
@@ -11,6 +11,7 @@ import gradio as gr
11
  import PIL.Image
12
  import slugify
13
  import torch
 
14
 
15
  from constants import UploadTarget
16
 
@@ -30,6 +31,10 @@ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
30
 
31
 
32
  class Trainer:
 
 
 
 
33
  def prepare_dataset(self, instance_images: list, resolution: int,
34
  instance_data_dir: pathlib.Path) -> None:
35
  shutil.rmtree(instance_data_dir, ignore_errors=True)
@@ -64,6 +69,7 @@ class Trainer:
64
  use_private_repo: bool,
65
  delete_existing_repo: bool,
66
  upload_to: str,
 
67
  ) -> str:
68
  if not torch.cuda.is_available():
69
  raise gr.Error('CUDA is not available.')
@@ -116,8 +122,7 @@ class Trainer:
116
  if use_wandb:
117
  command += ' --report_to wandb'
118
  if upload_to_hub:
119
- hf_token = os.getenv('HF_TOKEN')
120
- command += f' --push_to_hub --hub_token {hf_token}'
121
  if use_private_repo:
122
  command += ' --private_repo'
123
  if delete_existing_repo:
@@ -127,6 +132,12 @@ class Trainer:
127
 
128
  subprocess.run(shlex.split(command))
129
 
 
 
 
 
 
 
130
  with open(output_dir / 'train.sh', 'w') as f:
131
  command_s = ' '.join(command.split())
132
  f.write(command_s)
 
11
  import PIL.Image
12
  import slugify
13
  import torch
14
+ from huggingface_hub import HfApi
15
 
16
  from constants import UploadTarget
17
 
 
31
 
32
 
33
  class Trainer:
34
+ def __init__(self, hf_token: str | None = None):
35
+ self.hf_token = hf_token
36
+ self.api = HfApi(token=hf_token)
37
+
38
  def prepare_dataset(self, instance_images: list, resolution: int,
39
  instance_data_dir: pathlib.Path) -> None:
40
  shutil.rmtree(instance_data_dir, ignore_errors=True)
 
69
  use_private_repo: bool,
70
  delete_existing_repo: bool,
71
  upload_to: str,
72
+ remove_gpu_after_training: bool,
73
  ) -> str:
74
  if not torch.cuda.is_available():
75
  raise gr.Error('CUDA is not available.')
 
122
  if use_wandb:
123
  command += ' --report_to wandb'
124
  if upload_to_hub:
125
+ command += f' --push_to_hub --hub_token {self.hf_token}'
 
126
  if use_private_repo:
127
  command += ' --private_repo'
128
  if delete_existing_repo:
 
132
 
133
  subprocess.run(shlex.split(command))
134
 
135
+ if remove_gpu_after_training:
136
+ space_id = os.getenv('SPACE_ID')
137
+ if space_id:
138
+ self.api.request_space_hardware(repo_id=space_id,
139
+ hardware='cpu-basic')
140
+
141
  with open(output_dir / 'train.sh', 'w') as f:
142
  command_s = ' '.join(command.split())
143
  f.write(command_s)