hysts HF staff commited on
Commit
307f13f
1 Parent(s): c11895e
Files changed (3) hide show
  1. app_upload.py +8 -3
  2. trainer.py +19 -17
  3. uploader.py +0 -1
app_upload.py CHANGED
@@ -13,9 +13,14 @@ from utils import find_exp_dirs
13
 
14
 
15
  class LoRAModelUploader(Uploader):
16
- def upload_lora_model(self, folder_path: str, repo_name: str,
17
- upload_to: str, private: bool,
18
- delete_existing_repo: bool) -> str:
 
 
 
 
 
19
  if not repo_name:
20
  repo_name = pathlib.Path(folder_path).name
21
  repo_name = slugify.slugify(repo_name)
 
13
 
14
 
15
  class LoRAModelUploader(Uploader):
16
+ def upload_lora_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ ) -> str:
24
  if not repo_name:
25
  repo_name = pathlib.Path(folder_path).name
26
  repo_name = slugify.slugify(repo_name)
trainer.py CHANGED
@@ -13,7 +13,7 @@ import slugify
13
  import torch
14
  from huggingface_hub import HfApi
15
 
16
- from constants import UploadTarget
17
 
18
 
19
  def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
@@ -32,8 +32,8 @@ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
32
 
33
  class Trainer:
34
  def __init__(self, hf_token: str | None = None):
35
- self.hf_token = hf_token
36
  self.api = HfApi(token=hf_token)
 
37
 
38
  def prepare_dataset(self, instance_images: list, resolution: int,
39
  instance_data_dir: pathlib.Path) -> None:
@@ -91,8 +91,7 @@ class Trainer:
91
  output_dir = repo_dir / 'experiments' / output_model_name
92
  if overwrite_existing_model or upload_to_hub:
93
  shutil.rmtree(output_dir, ignore_errors=True)
94
- if not upload_to_hub:
95
- output_dir.mkdir(parents=True)
96
 
97
  instance_data_dir = repo_dir / 'training_data' / output_model_name
98
  self.prepare_dataset(instance_images, resolution, instance_data_dir)
@@ -121,16 +120,23 @@ class Trainer:
121
  command += ' --use_8bit_adam'
122
  if use_wandb:
123
  command += ' --report_to wandb'
124
- if upload_to_hub:
125
- command += f' --push_to_hub --hub_token {self.hf_token}'
126
- if use_private_repo:
127
- command += ' --private_repo'
128
- if delete_existing_repo:
129
- command += ' --delete_existing_repo'
130
- if upload_to == UploadTarget.LORA_LIBRARY.value:
131
- command += ' --upload_to_lora_library'
132
 
 
 
 
133
  subprocess.run(shlex.split(command))
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  if remove_gpu_after_training:
136
  space_id = os.getenv('SPACE_ID')
@@ -138,8 +144,4 @@ class Trainer:
138
  self.api.request_space_hardware(repo_id=space_id,
139
  hardware='cpu-basic')
140
 
141
- with open(output_dir / 'train.sh', 'w') as f:
142
- command_s = ' '.join(command.split())
143
- f.write(command_s)
144
-
145
- return 'Training completed!'
 
13
  import torch
14
  from huggingface_hub import HfApi
15
 
16
+ from app_upload import LoRAModelUploader
17
 
18
 
19
  def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
 
32
 
33
  class Trainer:
34
  def __init__(self, hf_token: str | None = None):
 
35
  self.api = HfApi(token=hf_token)
36
+ self.model_uploader = LoRAModelUploader(hf_token)
37
 
38
  def prepare_dataset(self, instance_images: list, resolution: int,
39
  instance_data_dir: pathlib.Path) -> None:
 
91
  output_dir = repo_dir / 'experiments' / output_model_name
92
  if overwrite_existing_model or upload_to_hub:
93
  shutil.rmtree(output_dir, ignore_errors=True)
94
+ output_dir.mkdir(parents=True)
 
95
 
96
  instance_data_dir = repo_dir / 'training_data' / output_model_name
97
  self.prepare_dataset(instance_images, resolution, instance_data_dir)
 
120
  command += ' --use_8bit_adam'
121
  if use_wandb:
122
  command += ' --report_to wandb'
 
 
 
 
 
 
 
 
123
 
124
+ with open(output_dir / 'train.sh', 'w') as f:
125
+ command_s = ' '.join(command.split())
126
+ f.write(command_s)
127
  subprocess.run(shlex.split(command))
128
+ message = 'Training completed!'
129
+ print(message)
130
+
131
+ if upload_to_hub:
132
+ upload_message = self.model_uploader.upload_lora_model(
133
+ folder_path=output_dir.as_posix(),
134
+ repo_name=output_model_name,
135
+ upload_to=upload_to,
136
+ private=use_private_repo,
137
+ delete_existing_repo=delete_existing_repo)
138
+ print(upload_message)
139
+ message = message + '\n' + upload_message
140
 
141
  if remove_gpu_after_training:
142
  space_id = os.getenv('SPACE_ID')
 
144
  self.api.request_space_hardware(repo_id=space_id,
145
  hardware='cpu-basic')
146
 
147
+ return message
 
 
 
 
uploader.py CHANGED
@@ -35,5 +35,4 @@ class Uploader:
35
  message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
36
  except Exception as e:
37
  message = str(e)
38
- print(message)
39
  return message
 
35
  message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
36
  except Exception as e:
37
  message = str(e)
 
38
  return message