hysts (HF staff) committed on
Commit
db1e5fb
1 Parent(s): b083a19

Update to Diffusers

Files changed (16)
  1. .gitignore +2 -1
  2. .gitmodules +0 -3
  3. .pre-commit-config.yaml +2 -0
  4. README.md +4 -3
  5. app.py +11 -228
  6. app_inference.py +150 -0
  7. app_training.py +128 -0
  8. app_upload.py +95 -0
  9. constants.py +6 -0
  10. inference.py +45 -48
  11. lora +0 -1
  12. requirements.txt +9 -7
  13. train_dreambooth_lora.py +1018 -0
  14. trainer.py +67 -54
  15. uploader.py +35 -16
  16. utils.py +18 -0
.gitignore CHANGED
@@ -1,5 +1,6 @@
 training_data/
-results/
+experiments/
+wandb/
 
 
 # Byte-compiled / optimized / DLL files
.gitmodules DELETED
@@ -1,3 +0,0 @@
-[submodule "lora"]
-	path = lora
-	url = https://github.com/cloneofsimo/lora
.pre-commit-config.yaml CHANGED
@@ -1,3 +1,4 @@
+exclude: train_dreambooth_lora.py
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v4.2.0
@@ -28,6 +29,7 @@ repos:
   hooks:
   - id: mypy
     args: ['--ignore-missing-imports']
+    additional_dependencies: ['types-python-slugify']
 - repo: https://github.com/google/yapf
   rev: v0.32.0
   hooks:
README.md CHANGED
@@ -1,10 +1,11 @@
 ---
-title: LoRA + SD Training
-emoji: 🏢
+title: LoRA DreamBooth Training UI
+emoji:
 colorFrom: red
 colorTo: purple
 sdk: gradio
-sdk_version: 3.12.0
+sdk_version: 3.16.2
+python_version: 3.10.9
 app_file: app.py
 pinned: false
 license: mit
app.py CHANGED
@@ -1,27 +1,21 @@
 #!/usr/bin/env python
-"""Unofficial demo app for https://github.com/cloneofsimo/lora.
-
-The code in this repo is partly adapted from the following repository:
-https://huggingface.co/spaces/multimodalart/dreambooth-training/tree/a00184917aa273c6d8adab08d5deb9b39b997938
-The license of the original code is MIT, which is specified in the README.md.
-"""
 
 from __future__ import annotations
 
 import os
-import pathlib
 
 import gradio as gr
 import torch
 
+from app_inference import create_inference_demo
+from app_training import create_training_demo
+from app_upload import create_upload_demo
 from inference import InferencePipeline
 from trainer import Trainer
-from uploader import upload
 
-TITLE = '# LoRA + StableDiffusion Training UI'
-DESCRIPTION = 'This is an unofficial demo for [https://github.com/cloneofsimo/lora](https://github.com/cloneofsimo/lora).'
+TITLE = '# LoRA DreamBooth Training UI'
 
-ORIGINAL_SPACE_ID = 'hysts/LoRA-SD-training'
+ORIGINAL_SPACE_ID = 'hysts/LoRA-DreamBooth-Training-UI'
 SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
 SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
 
@@ -29,7 +23,6 @@ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI.
 '''
 if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
     SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
-
 else:
     SETTINGS = 'Settings'
 CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
@@ -39,6 +32,8 @@ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
 </center>
 '''
 
+HF_TOKEN = os.getenv('HF_TOKEN')
+
 
 def show_warning(warning_text: str) -> gr.Blocks:
     with gr.Blocks() as demo:
@@ -47,217 +42,7 @@ def show_warning(warning_text: str) -> gr.Blocks:
     return demo
 
 
-def update_output_files() -> dict:
-    paths = sorted(pathlib.Path('results').glob('*.pt'))
-    paths = [path.as_posix() for path in paths]  # type: ignore
-    return gr.update(value=paths or None)
-
-
-def create_training_demo(trainer: Trainer,
-                         pipe: InferencePipeline) -> gr.Blocks:
-    with gr.Blocks() as demo:
-        base_model = gr.Dropdown(
-            choices=['stabilityai/stable-diffusion-2-1-base'],
-            value='stabilityai/stable-diffusion-2-1-base',
-            label='Base Model',
-            visible=False)
-        resolution = gr.Dropdown(choices=['512'],
-                                 value='512',
-                                 label='Resolution',
-                                 visible=False)
-
-        with gr.Row():
-            with gr.Box():
-                gr.Markdown('Training Data')
-                concept_images = gr.Files(label='Images for your concept')
-                concept_prompt = gr.Textbox(label='Concept Prompt',
-                                            max_lines=1)
-                gr.Markdown('''
-                    - Upload images of the style you are planning on training on.
-                    - For a concept prompt, use a unique, made up word to avoid collisions.
-                    ''')
-            with gr.Box():
-                gr.Markdown('Training Parameters')
-                num_training_steps = gr.Number(
-                    label='Number of Training Steps', value=1000, precision=0)
-                learning_rate = gr.Number(label='Learning Rate', value=0.0001)
-                train_text_encoder = gr.Checkbox(label='Train Text Encoder',
-                                                 value=True)
-                learning_rate_text = gr.Number(
-                    label='Learning Rate for Text Encoder', value=0.00005)
-                gradient_accumulation = gr.Number(
-                    label='Number of Gradient Accumulation',
-                    value=1,
-                    precision=0)
-                fp16 = gr.Checkbox(label='FP16', value=True)
-                use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
-                gr.Markdown('''
-                    - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
-                    - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
-                    - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
-                    ''')
-
-        run_button = gr.Button('Start Training')
-        with gr.Box():
-            with gr.Row():
-                check_status_button = gr.Button('Check Training Status')
-                with gr.Column():
-                    with gr.Box():
-                        gr.Markdown('Message')
-                        training_status = gr.Markdown()
-                    output_files = gr.Files(label='Trained Weight Files')
-
-        run_button.click(fn=pipe.clear)
-        run_button.click(fn=trainer.run,
-                         inputs=[
-                             base_model,
-                             resolution,
-                             concept_images,
-                             concept_prompt,
-                             num_training_steps,
-                             learning_rate,
-                             train_text_encoder,
-                             learning_rate_text,
-                             gradient_accumulation,
-                             fp16,
-                             use_8bit_adam,
-                         ],
-                         outputs=[
-                             training_status,
-                             output_files,
-                         ],
-                         queue=False)
-        check_status_button.click(fn=trainer.check_if_running,
-                                  inputs=None,
-                                  outputs=training_status,
-                                  queue=False)
-        check_status_button.click(fn=update_output_files,
-                                  inputs=None,
-                                  outputs=output_files,
-                                  queue=False)
-    return demo
-
-
-def find_weight_files() -> list[str]:
-    curr_dir = pathlib.Path(__file__).parent
-    paths = sorted(curr_dir.rglob('*.pt'))
-    paths = [path for path in paths if not path.stem.endswith('.text_encoder')]
-    return [path.relative_to(curr_dir).as_posix() for path in paths]
-
-
-def reload_lora_weight_list() -> dict:
-    return gr.update(choices=find_weight_files())
-
-
-def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
-    with gr.Blocks() as demo:
-        with gr.Row():
-            with gr.Column():
-                base_model = gr.Dropdown(
-                    choices=['stabilityai/stable-diffusion-2-1-base'],
-                    value='stabilityai/stable-diffusion-2-1-base',
-                    label='Base Model',
-                    visible=False)
-                reload_button = gr.Button('Reload Weight List')
-                lora_weight_name = gr.Dropdown(choices=find_weight_files(),
-                                               value='lora/lora_disney.pt',
-                                               label='LoRA Weight File')
-                prompt = gr.Textbox(
-                    label='Prompt',
-                    max_lines=1,
-                    placeholder='Example: "style of sks, baby lion"')
-                alpha = gr.Slider(label='Alpha',
-                                  minimum=0,
-                                  maximum=2,
-                                  step=0.05,
-                                  value=1)
-                alpha_for_text = gr.Slider(label='Alpha for Text Encoder',
-                                           minimum=0,
-                                           maximum=2,
-                                           step=0.05,
-                                           value=1)
-                seed = gr.Slider(label='Seed',
-                                 minimum=0,
-                                 maximum=100000,
-                                 step=1,
-                                 value=1)
-                with gr.Accordion('Other Parameters', open=False):
-                    num_steps = gr.Slider(label='Number of Steps',
-                                          minimum=0,
-                                          maximum=100,
-                                          step=1,
-                                          value=50)
-                    guidance_scale = gr.Slider(label='CFG Scale',
-                                               minimum=0,
-                                               maximum=50,
-                                               step=0.1,
-                                               value=7)
-
-                run_button = gr.Button('Generate')
-
-                gr.Markdown('''
-                    - Models with names starting with "lora/" are the pretrained models provided in the [original repo](https://github.com/cloneofsimo/lora), and the ones with names starting with "results/" are your trained models.
-                    - After training, you can press "Reload Weight List" button to load your trained model names.
-                    - The pretrained models for "disney", "illust" and "pop" are trained with the concept prompt "style of sks".
-                    - The pretrained model for "kiriko" is trained with the concept prompt "game character bnha". For this model, the text encoder is also trained.
-                    ''')
-            with gr.Column():
-                result = gr.Image(label='Result')
-
-        reload_button.click(fn=reload_lora_weight_list,
-                            inputs=None,
-                            outputs=lora_weight_name)
-        prompt.submit(fn=pipe.run,
-                      inputs=[
-                          base_model,
-                          lora_weight_name,
-                          prompt,
-                          alpha,
-                          alpha_for_text,
-                          seed,
-                          num_steps,
-                          guidance_scale,
-                      ],
-                      outputs=result,
-                      queue=False)
-        run_button.click(fn=pipe.run,
-                         inputs=[
-                             base_model,
-                             lora_weight_name,
-                             prompt,
-                             alpha,
-                             alpha_for_text,
-                             seed,
-                             num_steps,
-                             guidance_scale,
-                         ],
-                         outputs=result,
-                         queue=False)
-    return demo
-
-
-def create_upload_demo() -> gr.Blocks:
-    with gr.Blocks() as demo:
-        model_name = gr.Textbox(label='Model Name')
-        hf_token = gr.Textbox(
-            label='Hugging Face Token (with write permission)')
-        upload_button = gr.Button('Upload')
-        with gr.Box():
-            gr.Markdown('Message')
-            result = gr.Markdown()
-        gr.Markdown('''
-            - You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
-            - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
-            ''')
-
-        upload_button.click(fn=upload,
-                            inputs=[model_name, hf_token],
-                            outputs=result)
-
-    return demo
-
-
-pipe = InferencePipeline()
+pipe = InferencePipeline(HF_TOKEN)
 trainer = Trainer()
 
 with gr.Blocks(css='style.css') as demo:
@@ -267,14 +52,12 @@ with gr.Blocks(css='style.css') as demo:
     show_warning(CUDA_NOT_AVAILABLE_WARNING)
 
     gr.Markdown(TITLE)
-    gr.Markdown(DESCRIPTION)
-
     with gr.Tabs():
         with gr.TabItem('Train'):
             create_training_demo(trainer, pipe)
         with gr.TabItem('Test'):
-            create_inference_demo(pipe)
+            create_inference_demo(pipe, HF_TOKEN)
         with gr.TabItem('Upload'):
-            create_upload_demo()
+            create_upload_demo(HF_TOKEN)
 
-demo.queue(default_enabled=False).launch(share=False)
+demo.queue(max_size=1).launch(share=False)
app_inference.py ADDED
@@ -0,0 +1,150 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from inference import InferencePipeline
11
+ from utils import find_exp_dirs
12
+
13
+ SAMPLE_MODEL_IDS = ['patrickvonplaten/lora_dreambooth_dog_example']
14
+
15
+
16
+ class ModelSource(enum.Enum):
17
+ SAMPLE = 'Sample'
18
+ HUB_LIB = 'Hub (lora-library)'
19
+ LOCAL = 'Local'
20
+
21
+
22
+ class InferenceUtil:
23
+ def __init__(self, hf_token: str | None):
24
+ self.hf_token = hf_token
25
+
26
+ @staticmethod
27
+ def load_sample_lora_model_list():
28
+ return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])
29
+
30
+ def load_hub_lora_model_list(self) -> dict:
31
+ api = HfApi(token=self.hf_token)
32
+ choices = [
33
+ info.modelId for info in api.list_models(author='lora-library')
34
+ ]
35
+ return gr.update(choices=choices,
36
+ value=choices[0] if choices else None)
37
+
38
+ @staticmethod
39
+ def load_local_lora_model_list() -> dict:
40
+ choices = find_exp_dirs()
41
+ return gr.update(choices=choices,
42
+ value=choices[0] if choices else None)
43
+
44
+ def reload_lora_model_list(self, model_source: str) -> dict:
45
+ if model_source == ModelSource.SAMPLE.value:
46
+ return self.load_sample_lora_model_list()
47
+ elif model_source == ModelSource.HUB_LIB.value:
48
+ return self.load_hub_lora_model_list()
49
+ elif model_source == ModelSource.LOCAL.value:
50
+ return self.load_local_lora_model_list()
51
+ else:
52
+ raise ValueError
53
+
54
+ def load_model_info(self, lora_model_id: str) -> tuple[str, str]:
55
+ try:
56
+ card = InferencePipeline.get_model_card(lora_model_id,
57
+ self.hf_token)
58
+ except Exception:
59
+ return '', ''
60
+ base_model = getattr(card.data, 'base_model', '')
61
+ instance_prompt = getattr(card.data, 'instance_prompt', '')
62
+ return base_model, instance_prompt
63
+
64
+
65
+ def create_inference_demo(pipe: InferencePipeline,
66
+ hf_token: str | None = None) -> gr.Blocks:
67
+ app = InferenceUtil(hf_token)
68
+
69
+ with gr.Blocks() as demo:
70
+ with gr.Row():
71
+ with gr.Column():
72
+ with gr.Box():
73
+ model_source = gr.Radio(
74
+ label='Model Source',
75
+ choices=[_.value for _ in ModelSource],
76
+ value=ModelSource.SAMPLE.value)
77
+ reload_button = gr.Button('Reload Model List')
78
+ lora_model_id = gr.Dropdown(label='LoRA Model ID',
79
+ choices=SAMPLE_MODEL_IDS,
80
+ value=SAMPLE_MODEL_IDS[0])
81
+ with gr.Accordion(
82
+ label=
83
+ 'Model info (Base model and instance prompt used for training)',
84
+ open=False):
85
+ with gr.Row():
86
+ base_model_used_for_training = gr.Text(
87
+ label='Base model', interactive=False)
88
+ instance_prompt_used_for_training = gr.Text(
89
+ label='Instance prompt', interactive=False)
90
+ prompt = gr.Textbox(
91
+ label='Prompt',
92
+ max_lines=1,
93
+ placeholder='Example: "A picture of a sks dog in a bucket"'
94
+ )
95
+ seed = gr.Slider(label='Seed',
96
+ minimum=0,
97
+ maximum=100000,
98
+ step=1,
99
+ value=0)
100
+ with gr.Accordion('Other Parameters', open=False):
101
+ num_steps = gr.Slider(label='Number of Steps',
102
+ minimum=0,
103
+ maximum=100,
104
+ step=1,
105
+ value=25)
106
+ guidance_scale = gr.Slider(label='CFG Scale',
107
+ minimum=0,
108
+ maximum=50,
109
+ step=0.1,
110
+ value=7.5)
111
+
112
+ run_button = gr.Button('Generate')
113
+
114
+ gr.Markdown('''
115
+ - After training, you can press "Reload Model List" button to load your trained model names.
116
+ ''')
117
+ with gr.Column():
118
+ result = gr.Image(label='Result')
119
+
120
+ model_source.change(fn=app.reload_lora_model_list,
121
+ inputs=model_source,
122
+ outputs=lora_model_id)
123
+ reload_button.click(fn=app.reload_lora_model_list,
124
+ inputs=model_source,
125
+ outputs=lora_model_id)
126
+ lora_model_id.change(fn=app.load_model_info,
127
+ inputs=lora_model_id,
128
+ outputs=[
129
+ base_model_used_for_training,
130
+ instance_prompt_used_for_training,
131
+ ])
132
+ inputs = [
133
+ lora_model_id,
134
+ prompt,
135
+ seed,
136
+ num_steps,
137
+ guidance_scale,
138
+ ]
139
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
140
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
141
+ return demo
142
+
143
+
144
+ if __name__ == '__main__':
145
+ import os
146
+
147
+ hf_token = os.getenv('HF_TOKEN')
148
+ pipe = InferencePipeline(hf_token)
149
+ demo = create_inference_demo(pipe, hf_token)
150
+ demo.queue(max_size=10).launch(share=False)
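The "Model info" accordion above is filled from the LoRA repo's model card: `load_model_info` reads the `base_model` and `instance_prompt` fields that `train_dreambooth_lora.py` writes into the card's YAML front matter. A minimal sketch of that lookup outside the Gradio app (using the sample model ID above; both fields fall back to empty strings if the card does not carry them):

```python
from huggingface_hub import ModelCard

# Same lookup as InferenceUtil.load_model_info, run standalone.
card = ModelCard.load('patrickvonplaten/lora_dreambooth_dog_example')
base_model = getattr(card.data, 'base_model', '')            # base checkpoint the LoRA was trained on
instance_prompt = getattr(card.data, 'instance_prompt', '')  # e.g. "a photo of sks dog"
print(base_model, instance_prompt)
```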
app_training.py ADDED
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+
7
+ from constants import UploadTarget
8
+ from inference import InferencePipeline
9
+ from trainer import Trainer
10
+
11
+
12
+ def create_training_demo(trainer: Trainer,
13
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
14
+ with gr.Blocks() as demo:
15
+ with gr.Row():
16
+ with gr.Column():
17
+ with gr.Box():
18
+ gr.Markdown('Training Data')
19
+ instance_images = gr.Files(label='Instance images')
20
+ instance_prompt = gr.Textbox(label='Instance prompt',
21
+ max_lines=1)
22
+ gr.Markdown('''
23
+ - Upload images of the style you are planning on training on.
24
+ - For an instance prompt, use a unique, made up word to avoid collisions.
25
+ ''')
26
+ with gr.Box():
27
+ gr.Markdown('Output Model')
28
+ output_model_name = gr.Text(label='Name of your model',
29
+ max_lines=1)
30
+ delete_existing_model = gr.Checkbox(
31
+ label='Delete existing model of the same name',
32
+ value=False)
33
+ validation_prompt = gr.Text(label='Validation Prompt')
34
+ with gr.Box():
35
+ gr.Markdown('Upload Settings')
36
+ with gr.Row():
37
+ upload_to_hub = gr.Checkbox(
38
+ label='Upload model to Hub', value=False)
39
+ use_private_repo = gr.Checkbox(label='Private',
40
+ value=False)
41
+ delete_existing_repo = gr.Checkbox(
42
+ label='Delete existing repo of the same name',
43
+ value=False)
44
+ upload_to = gr.Radio(
45
+ label='Upload to',
46
+ choices=[_.value for _ in UploadTarget],
47
+ value=UploadTarget.PERSONAL_PROFILE.value)
48
+
49
+ with gr.Box():
50
+ gr.Markdown('Training Parameters')
51
+ with gr.Row():
52
+ base_model = gr.Text(
53
+ label='Base Model',
54
+ value='stabilityai/stable-diffusion-2-1-base',
55
+ max_lines=1)
56
+ resolution = gr.Dropdown(choices=['512', '768'],
57
+ value='512',
58
+ label='Resolution')
59
+ num_training_steps = gr.Number(
60
+ label='Number of Training Steps', value=1000, precision=0)
61
+ learning_rate = gr.Number(label='Learning Rate', value=0.0001)
62
+ gradient_accumulation = gr.Number(
63
+ label='Number of Gradient Accumulation',
64
+ value=1,
65
+ precision=0)
66
+ seed = gr.Slider(label='Seed',
67
+ minimum=0,
68
+ maximum=100000,
69
+ step=1,
70
+ value=0)
71
+ fp16 = gr.Checkbox(label='FP16', value=True)
72
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
73
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
74
+ value=100,
75
+ precision=0)
76
+ use_wandb = gr.Checkbox(label='Use W&B', value=False)
77
+ validation_epochs = gr.Number(label='Validation Epochs',
78
+ value=100,
79
+ precision=0)
80
+ gr.Markdown('''
81
+ - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
82
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
83
+ - You need to set the environment variable `WANDB_API_KEY` if you'd like to use W&B. See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
84
+ - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
85
+ ''')
86
+
87
+ # TODO currently disabled
88
+ remove_gpu_after_training = gr.Checkbox(
89
+ label='Remove GPU after training', value=False, interactive=False)
90
+ run_button = gr.Button('Start Training')
91
+
92
+ with gr.Box():
93
+ gr.Markdown('Message')
94
+ message = gr.Markdown()
95
+
96
+ if pipe is not None:
97
+ run_button.click(fn=pipe.clear)
98
+ run_button.click(fn=trainer.run,
99
+ inputs=[
100
+ instance_images,
101
+ instance_prompt,
102
+ output_model_name,
103
+ delete_existing_model,
104
+ validation_prompt,
105
+ base_model,
106
+ resolution,
107
+ num_training_steps,
108
+ learning_rate,
109
+ gradient_accumulation,
110
+ seed,
111
+ fp16,
112
+ use_8bit_adam,
113
+ checkpointing_steps,
114
+ use_wandb,
115
+ validation_epochs,
116
+ upload_to_hub,
117
+ use_private_repo,
118
+ delete_existing_repo,
119
+ upload_to,
120
+ ],
121
+ outputs=message)
122
+ return demo
123
+
124
+
125
+ if __name__ == '__main__':
126
+ trainer = Trainer()
127
+ demo = create_training_demo(trainer)
128
+ demo.queue(max_size=1).launch(share=False)
app_upload.py ADDED
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class LoRAModelUploader(Uploader):
16
+ def upload_lora_model(self, folder_path: str, repo_name: str,
17
+ upload_to: str, private: bool,
18
+ delete_existing_repo: bool) -> str:
19
+ if not repo_name:
20
+ repo_name = pathlib.Path(folder_path).name
21
+ repo_name = slugify.slugify(repo_name)
22
+
23
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
24
+ organization = ''
25
+ elif upload_to == UploadTarget.LORA_LIBRARY.value:
26
+ organization = 'lora-library'
27
+ else:
28
+ raise ValueError
29
+
30
+ return self.upload(folder_path,
31
+ repo_name,
32
+ organization=organization,
33
+ private=private,
34
+ delete_existing_repo=delete_existing_repo)
35
+
36
+
37
+ def load_local_lora_model_list() -> dict:
38
+ choices = find_exp_dirs(ignore_repo=True)
39
+ return gr.update(choices=choices, value=choices[0] if choices else None)
40
+
41
+
42
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
43
+ uploader = LoRAModelUploader(hf_token)
44
+ model_dirs = find_exp_dirs(ignore_repo=True)
45
+
46
+ with gr.Blocks() as demo:
47
+ with gr.Box():
48
+ gr.Markdown('Local Models')
49
+ reload_button = gr.Button('Reload Model List')
50
+ model_dir = gr.Dropdown(
51
+ label='LoRA Model ID',
52
+ choices=model_dirs,
53
+ value=model_dirs[0] if model_dirs else None)
54
+ gr.Markdown(
55
+ '- Models uploaded in training time will not be shown here.')
56
+ with gr.Box():
57
+ gr.Markdown('Upload Settings')
58
+ with gr.Row():
59
+ use_private_repo = gr.Checkbox(label='Private', value=False)
60
+ delete_existing_repo = gr.Checkbox(
61
+ label='Delete existing repo of the same name', value=False)
62
+ upload_to = gr.Radio(label='Upload to',
63
+ choices=[_.value for _ in UploadTarget],
64
+ value=UploadTarget.PERSONAL_PROFILE.value)
65
+ model_name = gr.Textbox(label='Model Name')
66
+ upload_button = gr.Button('Upload')
67
+ gr.Markdown('''
68
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{your_username}/{model_name}) or to the public [LoRA Concepts Library](https://huggingface.co/lora-library) (i.e. https://huggingface.co/lora-library/{model_name}).
69
+ ''')
70
+ with gr.Box():
71
+ gr.Markdown('Message')
72
+ message = gr.Markdown()
73
+
74
+ reload_button.click(fn=load_local_lora_model_list,
75
+ inputs=None,
76
+ outputs=model_dir)
77
+ upload_button.click(fn=uploader.upload_lora_model,
78
+ inputs=[
79
+ model_dir,
80
+ model_name,
81
+ upload_to,
82
+ use_private_repo,
83
+ delete_existing_repo,
84
+ ],
85
+ outputs=message)
86
+
87
+ return demo
88
+
89
+
90
+ if __name__ == '__main__':
91
+ import os
92
+
93
+ hf_token = os.getenv('HF_TOKEN')
94
+ demo = create_upload_demo(hf_token)
95
+ demo.queue(max_size=1).launch(share=False)
constants.py ADDED
@@ -0,0 +1,6 @@
+import enum
+
+
+class UploadTarget(enum.Enum):
+    PERSONAL_PROFILE = 'Personal Profile'
+    LORA_LIBRARY = 'LoRA Library'
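`UploadTarget` is shared by the training and upload tabs: the "Upload to" radio buttons are built from its values, and `LoRAModelUploader.upload_lora_model` maps `LORA_LIBRARY` to the `lora-library` organization on the Hub. A tiny illustration of that mapping (the helper function here is hypothetical, not part of the commit):

```python
from constants import UploadTarget

def target_organization(upload_to: str) -> str:
    # '' means the user's own namespace; mirrors LoRAModelUploader.upload_lora_model.
    return 'lora-library' if upload_to == UploadTarget.LORA_LIBRARY.value else ''

print(target_organization(UploadTarget.PERSONAL_PROFILE.value))  # ''
print(target_organization(UploadTarget.LORA_LIBRARY.value))      # 'lora-library'
```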
inference.py CHANGED
@@ -2,78 +2,77 @@ from __future__ import annotations
 
 import gc
 import pathlib
-import sys
 
 import gradio as gr
 import PIL.Image
 import torch
-from diffusers import StableDiffusionPipeline
-
-sys.path.insert(0, 'lora')
-from lora_diffusion import monkeypatch_lora, tune_lora_scale
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from huggingface_hub import ModelCard
 
 
 class InferencePipeline:
-    def __init__(self):
+    def __init__(self, hf_token: str | None = None):
+        self.hf_token = hf_token
         self.pipe = None
         self.device = torch.device(
             'cuda:0' if torch.cuda.is_available() else 'cpu')
-        self.weight_path = None
+        self.lora_model_id = None
+        self.base_model_id = None
 
     def clear(self) -> None:
-        self.weight_path = None
+        self.lora_model_id = None
+        self.base_model_id = None
         del self.pipe
         self.pipe = None
         torch.cuda.empty_cache()
         gc.collect()
 
     @staticmethod
-    def get_lora_weight_path(name: str) -> pathlib.Path:
-        curr_dir = pathlib.Path(__file__).parent
-        return curr_dir / name
+    def check_if_model_is_local(lora_model_id: str) -> bool:
+        return pathlib.Path(lora_model_id).exists()
 
     @staticmethod
-    def get_lora_text_encoder_weight_path(path: pathlib.Path) -> str:
-        parent_dir = path.parent
-        stem = path.stem
-        text_encoder_filename = f'{stem}.text_encoder.pt'
-        path = parent_dir / text_encoder_filename
-        return path.as_posix() if path.exists() else ''
-
-    def load_pipe(self, model_id: str, lora_filename: str) -> None:
-        weight_path = self.get_lora_weight_path(lora_filename)
-        if weight_path == self.weight_path:
-            return
-        self.weight_path = weight_path
-        lora_weight = torch.load(self.weight_path, map_location=self.device)
-
-        if self.device.type == 'cpu':
-            pipe = StableDiffusionPipeline.from_pretrained(model_id)
+    def get_model_card(model_id: str,
+                       hf_token: str | None = None) -> ModelCard:
+        if InferencePipeline.check_if_model_is_local(model_id):
+            card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
         else:
-            pipe = StableDiffusionPipeline.from_pretrained(
-                model_id, torch_dtype=torch.float16)
-            pipe = pipe.to(self.device)
+            card_path = model_id
+        return ModelCard.load(card_path, token=hf_token)
 
-        monkeypatch_lora(pipe.unet, lora_weight)
-
-        lora_text_encoder_weight_path = self.get_lora_text_encoder_weight_path(
-            weight_path)
-        if lora_text_encoder_weight_path:
-            lora_text_encoder_weight = torch.load(
-                lora_text_encoder_weight_path, map_location=self.device)
-            monkeypatch_lora(pipe.text_encoder,
-                             lora_text_encoder_weight,
-                             target_replace_module=['CLIPAttention'])
+    @staticmethod
+    def get_base_model_info(lora_model_id: str,
+                            hf_token: str | None = None) -> str:
+        card = InferencePipeline.get_model_card(lora_model_id, hf_token)
+        return card.data.base_model
 
-        self.pipe = pipe
+    def load_pipe(self, lora_model_id: str) -> None:
+        if lora_model_id == self.lora_model_id:
+            return
+        base_model_id = self.get_base_model_info(lora_model_id, self.hf_token)
+        if base_model_id != self.base_model_id:
+            if self.device.type == 'cpu':
+                pipe = DiffusionPipeline.from_pretrained(
+                    base_model_id, use_auth_token=self.hf_token)
+            else:
+                pipe = DiffusionPipeline.from_pretrained(
+                    base_model_id,
+                    torch_dtype=torch.float16,
+                    use_auth_token=self.hf_token)
+                pipe = pipe.to(self.device)
+            pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+                pipe.scheduler.config)
+            self.pipe = pipe
+        self.pipe.unet.load_attn_procs(  # type: ignore
+            lora_model_id, use_auth_token=self.hf_token)
+
+        self.lora_model_id = lora_model_id  # type: ignore
+        self.base_model_id = base_model_id  # type: ignore
 
     def run(
         self,
-        base_model: str,
-        lora_weight_name: str,
+        lora_model_id: str,
         prompt: str,
-        alpha: float,
-        alpha_for_text: float,
         seed: int,
         n_steps: int,
         guidance_scale: float,
@@ -81,11 +80,9 @@ class InferencePipeline:
         if not torch.cuda.is_available():
             raise gr.Error('CUDA is not available.')
 
-        self.load_pipe(base_model, lora_weight_name)
+        self.load_pipe(lora_model_id)
 
         generator = torch.Generator(device=self.device).manual_seed(seed)
-        tune_lora_scale(self.pipe.unet, alpha)  # type: ignore
-        tune_lora_scale(self.pipe.text_encoder, alpha_for_text)  # type: ignore
         out = self.pipe(prompt,
                         num_inference_steps=n_steps,
                         guidance_scale=guidance_scale,
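For reference, the diffusers-native path that replaces the old `monkeypatch_lora`/`tune_lora_scale` calls condenses to the following minimal sketch, assuming a CUDA device and a LoRA repo whose attention weights were saved by the training script (the model IDs are the sample defaults used elsewhere in this commit):

```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

base_model_id = 'stabilityai/stable-diffusion-2-1-base'
lora_model_id = 'patrickvonplaten/lora_dreambooth_dog_example'

# Load the base model, swap in the faster multistep scheduler, then attach the LoRA weights.
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to('cuda')
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.unet.load_attn_procs(lora_model_id)

generator = torch.Generator(device='cuda').manual_seed(0)
image = pipe('A picture of a sks dog in a bucket',
             num_inference_steps=25,
             guidance_scale=7.5,
             generator=generator).images[0]
image.save('out.png')
```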
lora DELETED
@@ -1 +0,0 @@
-Subproject commit 26787a09bff4ebcb08f0ad4e848b67bce4389a7a
 
requirements.txt CHANGED
@@ -1,10 +1,12 @@
 accelerate==0.15.0
-bitsandbytes==0.35.4
-diffusers==0.10.2
+bitsandbytes==0.36.0.post2
+datasets==2.8.0
+git+https://github.com/huggingface/diffusers@a66f2baeb782e091dde4e1e6394e46f169e5ba58#egg=diffusers
 ftfy==6.1.1
-Pillow==9.3.0
-torch==1.13.0
-torchvision==0.14.0
+gradio==3.14.0
+Pillow==9.4.0
+python-slugify==7.0.0
+torch==1.13.1
+torchvision==0.14.1
 transformers==4.25.1
-triton==2.0.0.dev20220701
-xformers==0.0.13
+wandb==0.13.9
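Note that `diffusers` is now installed from source, pinned to commit `a66f2baeb782e091dde4e1e6394e46f169e5ba58` — the same revision that `train_dreambooth_lora.py` below was adapted from — so the vendored training script and the library API stay in sync.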
 
train_dreambooth_lora.py ADDED
@@ -0,0 +1,1018 @@
1
+ #!/usr/bin/env python
2
+ # This file is adapted from https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/examples/dreambooth/train_dreambooth_lora.py
3
+ # The original license is as below.
4
+ #
5
+ # coding=utf-8
6
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+
19
+ import argparse
20
+ import hashlib
21
+ import logging
22
+ import math
23
+ import os
24
+ import warnings
25
+ from pathlib import Path
26
+ from typing import Optional
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.utils.checkpoint
31
+ from torch.utils.data import Dataset
32
+
33
+ import datasets
34
+ import diffusers
35
+ import transformers
36
+ from accelerate import Accelerator
37
+ from accelerate.logging import get_logger
38
+ from accelerate.utils import set_seed
39
+ from diffusers import (
40
+ AutoencoderKL,
41
+ DDPMScheduler,
42
+ DiffusionPipeline,
43
+ DPMSolverMultistepScheduler,
44
+ UNet2DConditionModel,
45
+ )
46
+ from diffusers.loaders import AttnProcsLayers
47
+ from diffusers.models.cross_attention import LoRACrossAttnProcessor
48
+ from diffusers.optimization import get_scheduler
49
+ from diffusers.utils import check_min_version, is_wandb_available
50
+ from diffusers.utils.import_utils import is_xformers_available
51
+ from huggingface_hub import HfFolder, Repository, create_repo, delete_repo, whoami
52
+ from PIL import Image
53
+ from torchvision import transforms
54
+ from tqdm.auto import tqdm
55
+ from transformers import AutoTokenizer, PretrainedConfig
56
+
57
+
58
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
59
+ check_min_version("0.12.0.dev0")
60
+
61
+ logger = get_logger(__name__)
62
+
63
+
64
+ def save_model_card(repo_name, base_model, instance_prompt, test_prompt="", images=None, repo_folder=""):
65
+ img_str = f"Test prompt: {test_prompt}\n" if test_prompt else ""
66
+ for i, image in enumerate(images or []):
67
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
68
+ img_str += f"![img_{i}](./image_{i}.png)\n"
69
+
70
+ yaml = f"""
71
+ ---
72
+ license: creativeml-openrail-m
73
+ base_model: {base_model}
74
+ instance_prompt: {instance_prompt}
75
+ tags:
76
+ - stable-diffusion
77
+ - stable-diffusion-diffusers
78
+ - text-to-image
79
+ - diffusers
80
+ inference: true
81
+ ---
82
+ """
83
+ model_card = f"""
84
+ # LoRA DreamBooth - {repo_name}
85
+
86
+ These are LoRA adaption weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained on the instance prompt "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.\n
87
+ {img_str}
88
+ """
89
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
90
+ f.write(yaml + model_card)
91
+
92
+
93
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
94
+ text_encoder_config = PretrainedConfig.from_pretrained(
95
+ pretrained_model_name_or_path,
96
+ subfolder="text_encoder",
97
+ revision=revision,
98
+ )
99
+ model_class = text_encoder_config.architectures[0]
100
+
101
+ if model_class == "CLIPTextModel":
102
+ from transformers import CLIPTextModel
103
+
104
+ return CLIPTextModel
105
+ elif model_class == "RobertaSeriesModelWithTransformation":
106
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
107
+
108
+ return RobertaSeriesModelWithTransformation
109
+ else:
110
+ raise ValueError(f"{model_class} is not supported.")
111
+
112
+
113
+ def parse_args(input_args=None):
114
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
115
+ parser.add_argument(
116
+ "--pretrained_model_name_or_path",
117
+ type=str,
118
+ default=None,
119
+ required=True,
120
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
121
+ )
122
+ parser.add_argument(
123
+ "--revision",
124
+ type=str,
125
+ default=None,
126
+ required=False,
127
+ help="Revision of pretrained model identifier from huggingface.co/models.",
128
+ )
129
+ parser.add_argument(
130
+ "--tokenizer_name",
131
+ type=str,
132
+ default=None,
133
+ help="Pretrained tokenizer name or path if not the same as model_name",
134
+ )
135
+ parser.add_argument(
136
+ "--instance_data_dir",
137
+ type=str,
138
+ default=None,
139
+ required=True,
140
+ help="A folder containing the training data of instance images.",
141
+ )
142
+ parser.add_argument(
143
+ "--class_data_dir",
144
+ type=str,
145
+ default=None,
146
+ required=False,
147
+ help="A folder containing the training data of class images.",
148
+ )
149
+ parser.add_argument(
150
+ "--instance_prompt",
151
+ type=str,
152
+ default=None,
153
+ required=True,
154
+ help="The prompt with identifier specifying the instance",
155
+ )
156
+ parser.add_argument(
157
+ "--class_prompt",
158
+ type=str,
159
+ default=None,
160
+ help="The prompt to specify images in the same class as provided instance images.",
161
+ )
162
+ parser.add_argument(
163
+ "--validation_prompt",
164
+ type=str,
165
+ default=None,
166
+ help="A prompt that is used during validation to verify that the model is learning.",
167
+ )
168
+ parser.add_argument(
169
+ "--num_validation_images",
170
+ type=int,
171
+ default=4,
172
+ help="Number of images that should be generated during validation with `validation_prompt`.",
173
+ )
174
+ parser.add_argument(
175
+ "--validation_epochs",
176
+ type=int,
177
+ default=50,
178
+ help=(
179
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
180
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
181
+ ),
182
+ )
183
+ parser.add_argument(
184
+ "--with_prior_preservation",
185
+ default=False,
186
+ action="store_true",
187
+ help="Flag to add prior preservation loss.",
188
+ )
189
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
190
+ parser.add_argument(
191
+ "--num_class_images",
192
+ type=int,
193
+ default=100,
194
+ help=(
195
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
196
+ " class_data_dir, additional images will be sampled with class_prompt."
197
+ ),
198
+ )
199
+ parser.add_argument(
200
+ "--output_dir",
201
+ type=str,
202
+ default="lora-dreambooth-model",
203
+ help="The output directory where the model predictions and checkpoints will be written.",
204
+ )
205
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
206
+ parser.add_argument(
207
+ "--resolution",
208
+ type=int,
209
+ default=512,
210
+ help=(
211
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
212
+ " resolution"
213
+ ),
214
+ )
215
+ parser.add_argument(
216
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
217
+ )
218
+ parser.add_argument(
219
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
220
+ )
221
+ parser.add_argument(
222
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
223
+ )
224
+ parser.add_argument("--num_train_epochs", type=int, default=1)
225
+ parser.add_argument(
226
+ "--max_train_steps",
227
+ type=int,
228
+ default=None,
229
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
230
+ )
231
+ parser.add_argument(
232
+ "--checkpointing_steps",
233
+ type=int,
234
+ default=500,
235
+ help=(
236
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
237
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
238
+ " training using `--resume_from_checkpoint`."
239
+ ),
240
+ )
241
+ parser.add_argument(
242
+ "--resume_from_checkpoint",
243
+ type=str,
244
+ default=None,
245
+ help=(
246
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
247
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
248
+ ),
249
+ )
250
+ parser.add_argument(
251
+ "--gradient_accumulation_steps",
252
+ type=int,
253
+ default=1,
254
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
255
+ )
256
+ parser.add_argument(
257
+ "--gradient_checkpointing",
258
+ action="store_true",
259
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
260
+ )
261
+ parser.add_argument(
262
+ "--learning_rate",
263
+ type=float,
264
+ default=5e-4,
265
+ help="Initial learning rate (after the potential warmup period) to use.",
266
+ )
267
+ parser.add_argument(
268
+ "--scale_lr",
269
+ action="store_true",
270
+ default=False,
271
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
272
+ )
273
+ parser.add_argument(
274
+ "--lr_scheduler",
275
+ type=str,
276
+ default="constant",
277
+ help=(
278
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
279
+ ' "constant", "constant_with_warmup"]'
280
+ ),
281
+ )
282
+ parser.add_argument(
283
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
284
+ )
285
+ parser.add_argument(
286
+ "--lr_num_cycles",
287
+ type=int,
288
+ default=1,
289
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
290
+ )
291
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
292
+ parser.add_argument(
293
+ "--dataloader_num_workers",
294
+ type=int,
295
+ default=0,
296
+ help=(
297
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
298
+ ),
299
+ )
300
+ parser.add_argument(
301
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
302
+ )
303
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
304
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
305
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
306
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
307
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
308
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
309
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
310
+ parser.add_argument(
311
+ "--hub_model_id",
312
+ type=str,
313
+ default=None,
314
+ help="The name of the repository to keep in sync with the local `output_dir`.",
315
+ )
316
+ parser.add_argument(
317
+ "--logging_dir",
318
+ type=str,
319
+ default="logs",
320
+ help=(
321
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
322
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
323
+ ),
324
+ )
325
+ parser.add_argument(
326
+ "--allow_tf32",
327
+ action="store_true",
328
+ help=(
329
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
330
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
331
+ ),
332
+ )
333
+ parser.add_argument(
334
+ "--report_to",
335
+ type=str,
336
+ default="tensorboard",
337
+ help=(
338
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
339
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
340
+ ),
341
+ )
342
+ parser.add_argument(
343
+ "--mixed_precision",
344
+ type=str,
345
+ default=None,
346
+ choices=["no", "fp16", "bf16"],
347
+ help=(
348
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
349
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
350
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
351
+ ),
352
+ )
353
+ parser.add_argument(
354
+ "--prior_generation_precision",
355
+ type=str,
356
+ default=None,
357
+ choices=["no", "fp32", "fp16", "bf16"],
358
+ help=(
359
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
360
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
361
+ ),
362
+ )
363
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
364
+ parser.add_argument(
365
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
366
+ )
367
+ parser.add_argument("--private_repo", action="store_true")
368
+ parser.add_argument("--delete_existing_repo", action="store_true")
369
+ parser.add_argument("--upload_to_lora_library", action="store_true")
370
+
371
+ if input_args is not None:
372
+ args = parser.parse_args(input_args)
373
+ else:
374
+ args = parser.parse_args()
375
+
376
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
377
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
378
+ args.local_rank = env_local_rank
379
+
380
+ if args.with_prior_preservation:
381
+ if args.class_data_dir is None:
382
+ raise ValueError("You must specify a data directory for class images.")
383
+ if args.class_prompt is None:
384
+ raise ValueError("You must specify prompt for class images.")
385
+ else:
386
+ # logger is not available yet
387
+ if args.class_data_dir is not None:
388
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
389
+ if args.class_prompt is not None:
390
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
391
+
392
+ return args
393
+
394
+
395
+ class DreamBoothDataset(Dataset):
396
+ """
397
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
398
+ It pre-processes the images and the tokenizes prompts.
399
+ """
400
+
401
+ def __init__(
402
+ self,
403
+ instance_data_root,
404
+ instance_prompt,
405
+ tokenizer,
406
+ class_data_root=None,
407
+ class_prompt=None,
408
+ size=512,
409
+ center_crop=False,
410
+ ):
411
+ self.size = size
412
+ self.center_crop = center_crop
413
+ self.tokenizer = tokenizer
414
+
415
+ self.instance_data_root = Path(instance_data_root)
416
+ if not self.instance_data_root.exists():
417
+ raise ValueError("Instance images root doesn't exists.")
418
+
419
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
420
+ self.num_instance_images = len(self.instance_images_path)
421
+ self.instance_prompt = instance_prompt
422
+ self._length = self.num_instance_images
423
+
424
+ if class_data_root is not None:
425
+ self.class_data_root = Path(class_data_root)
426
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
427
+ self.class_images_path = list(self.class_data_root.iterdir())
428
+ self.num_class_images = len(self.class_images_path)
429
+ self._length = max(self.num_class_images, self.num_instance_images)
430
+ self.class_prompt = class_prompt
431
+ else:
432
+ self.class_data_root = None
433
+
434
+ self.image_transforms = transforms.Compose(
435
+ [
436
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
437
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
438
+ transforms.ToTensor(),
439
+ transforms.Normalize([0.5], [0.5]),
440
+ ]
441
+ )
442
+
443
+ def __len__(self):
444
+ return self._length
445
+
446
+ def __getitem__(self, index):
447
+ example = {}
448
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
449
+ if not instance_image.mode == "RGB":
450
+ instance_image = instance_image.convert("RGB")
451
+ example["instance_images"] = self.image_transforms(instance_image)
452
+ example["instance_prompt_ids"] = self.tokenizer(
453
+ self.instance_prompt,
454
+ truncation=True,
455
+ padding="max_length",
456
+ max_length=self.tokenizer.model_max_length,
457
+ return_tensors="pt",
458
+ ).input_ids
459
+
460
+ if self.class_data_root:
461
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
462
+ if not class_image.mode == "RGB":
463
+ class_image = class_image.convert("RGB")
464
+ example["class_images"] = self.image_transforms(class_image)
465
+ example["class_prompt_ids"] = self.tokenizer(
466
+ self.class_prompt,
467
+ truncation=True,
468
+ padding="max_length",
469
+ max_length=self.tokenizer.model_max_length,
470
+ return_tensors="pt",
471
+ ).input_ids
472
+
473
+ return example
474
+
475
+
476
+ def collate_fn(examples, with_prior_preservation=False):
477
+ input_ids = [example["instance_prompt_ids"] for example in examples]
478
+ pixel_values = [example["instance_images"] for example in examples]
479
+
480
+ # Concat class and instance examples for prior preservation.
481
+ # We do this to avoid doing two forward passes.
482
+ if with_prior_preservation:
483
+ input_ids += [example["class_prompt_ids"] for example in examples]
484
+ pixel_values += [example["class_images"] for example in examples]
485
+
486
+ pixel_values = torch.stack(pixel_values)
487
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
488
+
489
+ input_ids = torch.cat(input_ids, dim=0)
490
+
491
+ batch = {
492
+ "input_ids": input_ids,
493
+ "pixel_values": pixel_values,
494
+ }
495
+ return batch
496
+
497
+
498
+ class PromptDataset(Dataset):
499
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
500
+
501
+ def __init__(self, prompt, num_samples):
502
+ self.prompt = prompt
503
+ self.num_samples = num_samples
504
+
505
+ def __len__(self):
506
+ return self.num_samples
507
+
508
+ def __getitem__(self, index):
509
+ example = {}
510
+ example["prompt"] = self.prompt
511
+ example["index"] = index
512
+ return example
513
+
514
+
515
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
516
+ if token is None:
517
+ token = HfFolder.get_token()
518
+ if organization is None:
519
+ username = whoami(token)["name"]
520
+ return f"{username}/{model_id}"
521
+ else:
522
+ return f"{organization}/{model_id}"
523
+
524
+
525
+ def main(args):
526
+ logging_dir = Path(args.output_dir, args.logging_dir)
527
+
528
+ accelerator = Accelerator(
529
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
530
+ mixed_precision=args.mixed_precision,
531
+ log_with=args.report_to,
532
+ logging_dir=logging_dir,
533
+ )
534
+
535
+ if args.report_to == "wandb":
536
+ if not is_wandb_available():
537
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
538
+ import wandb
539
+
540
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
541
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
542
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
543
+ # Make one log on every process with the configuration for debugging.
544
+ logging.basicConfig(
545
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
546
+ datefmt="%m/%d/%Y %H:%M:%S",
547
+ level=logging.INFO,
548
+ )
549
+ logger.info(accelerator.state, main_process_only=False)
550
+ if accelerator.is_local_main_process:
551
+ datasets.utils.logging.set_verbosity_warning()
552
+ transformers.utils.logging.set_verbosity_warning()
553
+ diffusers.utils.logging.set_verbosity_info()
554
+ else:
555
+ datasets.utils.logging.set_verbosity_error()
556
+ transformers.utils.logging.set_verbosity_error()
557
+ diffusers.utils.logging.set_verbosity_error()
558
+
559
+ # If passed along, set the training seed now.
560
+ if args.seed is not None:
561
+ set_seed(args.seed)
562
+
563
+ # Generate class images if prior preservation is enabled.
564
+ if args.with_prior_preservation:
565
+ class_images_dir = Path(args.class_data_dir)
566
+ if not class_images_dir.exists():
567
+ class_images_dir.mkdir(parents=True)
568
+ cur_class_images = len(list(class_images_dir.iterdir()))
569
+
570
+ if cur_class_images < args.num_class_images:
571
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
572
+ if args.prior_generation_precision == "fp32":
573
+ torch_dtype = torch.float32
574
+ elif args.prior_generation_precision == "fp16":
575
+ torch_dtype = torch.float16
576
+ elif args.prior_generation_precision == "bf16":
577
+ torch_dtype = torch.bfloat16
578
+ pipeline = DiffusionPipeline.from_pretrained(
579
+ args.pretrained_model_name_or_path,
580
+ torch_dtype=torch_dtype,
581
+ safety_checker=None,
582
+ revision=args.revision,
583
+ )
584
+ pipeline.set_progress_bar_config(disable=True)
585
+
586
+ num_new_images = args.num_class_images - cur_class_images
587
+ logger.info(f"Number of class images to sample: {num_new_images}.")
588
+
589
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
590
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
591
+
592
+ sample_dataloader = accelerator.prepare(sample_dataloader)
593
+ pipeline.to(accelerator.device)
594
+
595
+ for example in tqdm(
596
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
597
+ ):
598
+ images = pipeline(example["prompt"]).images
599
+
600
+ for i, image in enumerate(images):
601
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
602
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
603
+ image.save(image_filename)
604
+
605
+ del pipeline
606
+ if torch.cuda.is_available():
607
+ torch.cuda.empty_cache()
608
+
609
+ # Handle the repository creation
610
+ if accelerator.is_main_process:
611
+ if args.push_to_hub:
612
+ if args.hub_model_id is None:
613
+ organization = 'lora-library' if args.upload_to_lora_library else None
614
+ repo_name = get_full_repo_name(Path(args.output_dir).name, organization=organization, token=args.hub_token)
615
+ else:
616
+ repo_name = args.hub_model_id
617
+
618
+ if args.delete_existing_repo:
619
+ try:
620
+ delete_repo(repo_name, token=args.hub_token)
621
+ except Exception:
622
+ pass
623
+ create_repo(repo_name, token=args.hub_token, private=args.private_repo)
624
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
625
+
626
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
627
+ if "step_*" not in gitignore:
628
+ gitignore.write("step_*\n")
629
+ if "epoch_*" not in gitignore:
630
+ gitignore.write("epoch_*\n")
631
+ elif args.output_dir is not None:
632
+ os.makedirs(args.output_dir, exist_ok=True)
633
+
634
+ # Load the tokenizer
635
+ if args.tokenizer_name:
636
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
637
+ elif args.pretrained_model_name_or_path:
638
+ tokenizer = AutoTokenizer.from_pretrained(
639
+ args.pretrained_model_name_or_path,
640
+ subfolder="tokenizer",
641
+ revision=args.revision,
642
+ use_fast=False,
643
+ )
644
+
645
+ # import correct text encoder class
646
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
647
+
648
+ # Load scheduler and models
649
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
650
+ text_encoder = text_encoder_cls.from_pretrained(
651
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
652
+ )
653
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
654
+ unet = UNet2DConditionModel.from_pretrained(
655
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
656
+ )
657
+
658
+ # We only train the additional adapter LoRA layers
659
+ vae.requires_grad_(False)
660
+ text_encoder.requires_grad_(False)
661
+ unet.requires_grad_(False)
662
+
663
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
664
+ # as these models are only used for inference; keeping weights in full precision is not required.
665
+ weight_dtype = torch.float32
666
+ if accelerator.mixed_precision == "fp16":
667
+ weight_dtype = torch.float16
668
+ elif accelerator.mixed_precision == "bf16":
669
+ weight_dtype = torch.bfloat16
670
+
671
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
672
+ unet.to(accelerator.device, dtype=weight_dtype)
673
+ vae.to(accelerator.device, dtype=weight_dtype)
674
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
675
+
676
+ if args.enable_xformers_memory_efficient_attention:
677
+ if is_xformers_available():
678
+ unet.enable_xformers_memory_efficient_attention()
679
+ else:
680
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
681
+
682
+ # now we will add new LoRA weights to the attention layers
683
+ # It's important to realize here how many attention weights will be added and what their sizes are.
685
+ # The sizes of the attention layers are determined by only two variables:
685
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
686
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
687
+
688
+ # Let's first see how many attention processors we will have to set.
689
+ # For Stable Diffusion, it should be equal to:
690
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
691
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
692
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
693
+ # => 32 layers
694
+
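+ # Illustrative note (not executed anywhere): each LoRA processor adds a trainable low-rank
+ # update on top of the frozen attention projections, roughly
+ #   W_eff(x) = W(x) + B(A(x))   with A: hidden -> rank and B: rank -> hidden,
+ # so only the small A/B matrices receive gradients while the original UNet weights stay frozen.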
695
+ # Set correct lora layers
696
+ lora_attn_procs = {}
697
+ for name in unet.attn_processors.keys():
698
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
699
+ if name.startswith("mid_block"):
700
+ hidden_size = unet.config.block_out_channels[-1]
701
+ elif name.startswith("up_blocks"):
702
+ block_id = int(name[len("up_blocks.")])
703
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
704
+ elif name.startswith("down_blocks"):
705
+ block_id = int(name[len("down_blocks.")])
706
+ hidden_size = unet.config.block_out_channels[block_id]
707
+
708
+ lora_attn_procs[name] = LoRACrossAttnProcessor(
709
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
710
+ )
711
+
712
+ unet.set_attn_processor(lora_attn_procs)
713
+ lora_layers = AttnProcsLayers(unet.attn_processors)
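+ # AttnProcsLayers is a thin nn.Module wrapper around the processor dict; only its parameters
+ # (the LoRA matrices) are optimized, gradient-clipped and checkpointed below.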
714
+
715
+ accelerator.register_for_checkpointing(lora_layers)
716
+
717
+ if args.scale_lr:
718
+ args.learning_rate = (
719
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
720
+ )
721
+
722
+ # Enable TF32 for faster training on Ampere GPUs,
723
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
724
+ if args.allow_tf32:
725
+ torch.backends.cuda.matmul.allow_tf32 = True
726
+
731
+
732
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
733
+ if args.use_8bit_adam:
734
+ try:
735
+ import bitsandbytes as bnb
736
+ except ImportError:
737
+ raise ImportError(
738
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
739
+ )
740
+
741
+ optimizer_class = bnb.optim.AdamW8bit
742
+ else:
743
+ optimizer_class = torch.optim.AdamW
744
+
745
+ # Optimizer creation
746
+ optimizer = optimizer_class(
747
+ lora_layers.parameters(),
748
+ lr=args.learning_rate,
749
+ betas=(args.adam_beta1, args.adam_beta2),
750
+ weight_decay=args.adam_weight_decay,
751
+ eps=args.adam_epsilon,
752
+ )
753
+
754
+ # Dataset and DataLoaders creation:
755
+ train_dataset = DreamBoothDataset(
756
+ instance_data_root=args.instance_data_dir,
757
+ instance_prompt=args.instance_prompt,
758
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
759
+ class_prompt=args.class_prompt,
760
+ tokenizer=tokenizer,
761
+ size=args.resolution,
762
+ center_crop=args.center_crop,
763
+ )
764
+
765
+ train_dataloader = torch.utils.data.DataLoader(
766
+ train_dataset,
767
+ batch_size=args.train_batch_size,
768
+ shuffle=True,
769
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
770
+ num_workers=args.dataloader_num_workers,
771
+ )
772
+
773
+ # Scheduler and math around the number of training steps.
774
+ overrode_max_train_steps = False
775
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
776
+ if args.max_train_steps is None:
777
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
778
+ overrode_max_train_steps = True
779
+
780
+ lr_scheduler = get_scheduler(
781
+ args.lr_scheduler,
782
+ optimizer=optimizer,
783
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
784
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
785
+ num_cycles=args.lr_num_cycles,
786
+ power=args.lr_power,
787
+ )
788
+
789
+ # Prepare everything with our `accelerator`.
790
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
791
+ lora_layers, optimizer, train_dataloader, lr_scheduler
792
+ )
793
+
794
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
795
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
796
+ if overrode_max_train_steps:
797
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
798
+ # Afterwards we recalculate our number of training epochs
799
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
800
+
801
+ # We need to initialize the trackers we use, and also store our configuration.
802
+ # The trackers initialize automatically on the main process.
803
+ if accelerator.is_main_process:
804
+ accelerator.init_trackers("dreambooth-lora", config=vars(args))
805
+
806
+ # Train!
807
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
808
+
809
+ logger.info("***** Running training *****")
810
+ logger.info(f" Num examples = {len(train_dataset)}")
811
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
812
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
813
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
814
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
815
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
816
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
817
+ global_step = 0
818
+ first_epoch = 0
819
+
820
+ # Potentially load in the weights and states from a previous save
821
+ if args.resume_from_checkpoint:
822
+ if args.resume_from_checkpoint != "latest":
823
+ path = os.path.basename(args.resume_from_checkpoint)
824
+ else:
825
+ # Get the most recent checkpoint
826
+ dirs = os.listdir(args.output_dir)
827
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
828
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
829
+ path = dirs[-1]
830
+ accelerator.print(f"Resuming from checkpoint {path}")
831
+ accelerator.load_state(os.path.join(args.output_dir, path))
832
+ global_step = int(path.split("-")[1])
833
+
834
+ resume_global_step = global_step * args.gradient_accumulation_steps
835
+ first_epoch = resume_global_step // num_update_steps_per_epoch
836
+ resume_step = resume_global_step % num_update_steps_per_epoch
837
+
838
+ # Only show the progress bar once on each machine.
839
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
840
+ progress_bar.set_description("Steps")
841
+
842
+ for epoch in range(first_epoch, args.num_train_epochs):
843
+ unet.train()
844
+ for step, batch in enumerate(train_dataloader):
845
+ # Skip steps until we reach the resumed step
846
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
847
+ if step % args.gradient_accumulation_steps == 0:
848
+ progress_bar.update(1)
849
+ continue
850
+
851
+ with accelerator.accumulate(unet):
852
+ # Convert images to latent space
853
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
854
+ latents = latents * 0.18215
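+ # 0.18215 is the Stable Diffusion VAE scaling factor; it rescales the latents to roughly
+ # unit variance before noise is added.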
855
+
856
+ # Sample noise that we'll add to the latents
857
+ noise = torch.randn_like(latents)
858
+ bsz = latents.shape[0]
859
+ # Sample a random timestep for each image
860
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
861
+ timesteps = timesteps.long()
862
+
863
+ # Add noise to the latents according to the noise magnitude at each timestep
864
+ # (this is the forward diffusion process)
865
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
866
+
867
+ # Get the text embedding for conditioning
868
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
869
+
870
+ # Predict the noise residual
871
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
872
+
873
+ # Get the target for loss depending on the prediction type
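+ # ("epsilon" targets the sampled noise itself; "v_prediction" targets the velocity
+ # alpha_t * noise - sigma_t * latents used by v-parameterized checkpoints.)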
874
+ if noise_scheduler.config.prediction_type == "epsilon":
875
+ target = noise
876
+ elif noise_scheduler.config.prediction_type == "v_prediction":
877
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
878
+ else:
879
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
880
+
881
+ if args.with_prior_preservation:
882
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
883
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
884
+ target, target_prior = torch.chunk(target, 2, dim=0)
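+ # The collate_fn concatenates instance and class examples along the batch dimension,
+ # so chunking in half separates instance predictions from prior-preservation predictions.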
885
+
886
+ # Compute instance loss
887
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
888
+
889
+ # Compute prior loss
890
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
891
+
892
+ # Add the prior loss to the instance loss.
893
+ loss = loss + args.prior_loss_weight * prior_loss
894
+ else:
895
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
896
+
897
+ accelerator.backward(loss)
898
+ if accelerator.sync_gradients:
899
+ params_to_clip = lora_layers.parameters()
900
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
901
+ optimizer.step()
902
+ lr_scheduler.step()
903
+ optimizer.zero_grad()
904
+
905
+ # Checks if the accelerator has performed an optimization step behind the scenes
906
+ if accelerator.sync_gradients:
907
+ progress_bar.update(1)
908
+ global_step += 1
909
+
910
+ if global_step % args.checkpointing_steps == 0:
911
+ if accelerator.is_main_process:
912
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
913
+ accelerator.save_state(save_path)
914
+ logger.info(f"Saved state to {save_path}")
915
+
916
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
917
+ progress_bar.set_postfix(**logs)
918
+ accelerator.log(logs, step=global_step)
919
+
920
+ if global_step >= args.max_train_steps:
921
+ break
922
+
923
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
924
+ logger.info(
925
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
926
+ f" {args.validation_prompt}."
927
+ )
928
+ # create pipeline
929
+ pipeline = DiffusionPipeline.from_pretrained(
930
+ args.pretrained_model_name_or_path,
931
+ unet=accelerator.unwrap_model(unet),
932
+ text_encoder=accelerator.unwrap_model(text_encoder),
933
+ revision=args.revision,
934
+ torch_dtype=weight_dtype,
935
+ )
936
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
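+ # DPMSolverMultistepScheduler is a fast solver, so the 25 inference steps used below are
+ # usually enough for reasonable validation samples.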
937
+ pipeline = pipeline.to(accelerator.device)
938
+ pipeline.set_progress_bar_config(disable=True)
939
+
940
+ # run inference
941
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
942
+ prompt = args.num_validation_images * [args.validation_prompt]
943
+ images = pipeline(prompt, num_inference_steps=25, generator=generator).images
944
+
945
+ for tracker in accelerator.trackers:
946
+ if tracker.name == "wandb":
947
+ tracker.log(
948
+ {
949
+ "validation": [
950
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
951
+ for i, image in enumerate(images)
952
+ ]
953
+ }
954
+ )
955
+
956
+ del pipeline
957
+ torch.cuda.empty_cache()
958
+
959
+ # Save the lora layers
960
+ accelerator.wait_for_everyone()
961
+ if accelerator.is_main_process:
962
+ unet = unet.to(torch.float32)
963
+ unet.save_attn_procs(args.output_dir)
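+ # save_attn_procs() stores only the LoRA attention weights (pytorch_lora_weights.bin),
+ # not the full UNet, which is what keeps the trained models small enough to upload.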
964
+
965
+ # Final inference
966
+ # Load previous pipeline
967
+ pipeline = DiffusionPipeline.from_pretrained(
968
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
969
+ )
970
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
971
+ pipeline = pipeline.to(accelerator.device)
972
+
973
+ # load attention processors
974
+ pipeline.unet.load_attn_procs(args.output_dir)
975
+
976
+ # run inference
977
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
978
+ prompt = args.num_validation_images * [args.validation_prompt]
979
+ images = pipeline(prompt, num_inference_steps=25, generator=generator).images
980
+
981
+ for tracker in accelerator.trackers:
982
+ if tracker.name == "wandb":
983
+ tracker.log(
984
+ {
985
+ "test": [
986
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
987
+ for i, image in enumerate(images)
988
+ ]
989
+ }
990
+ )
991
+
992
+ if args.push_to_hub:
993
+ save_model_card(
994
+ repo_name,
995
+ base_model=args.pretrained_model_name_or_path,
996
+ instance_prompt=args.instance_prompt,
997
+ test_prompt=args.validation_prompt,
998
+ images=images,
999
+ repo_folder=args.output_dir,
1000
+ )
1001
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1002
+ else:
1003
+ repo_name = Path(args.output_dir).name
1004
+ save_model_card(
1005
+ repo_name,
1006
+ base_model=args.pretrained_model_name_or_path,
1007
+ instance_prompt=args.instance_prompt,
1008
+ test_prompt=args.validation_prompt,
1009
+ images=images,
1010
+ repo_folder=args.output_dir,
1011
+ )
1012
+
1013
+ accelerator.end_training()
1014
+
1015
+
1016
+ if __name__ == "__main__":
1017
+ args = parse_args()
1018
+ main(args)
trainer.py CHANGED
@@ -1,5 +1,6 @@
1
  from __future__ import annotations
2
 
 
3
  import os
4
  import pathlib
5
  import shlex
@@ -8,9 +9,10 @@ import subprocess
8
 
9
  import gradio as gr
10
  import PIL.Image
 
11
  import torch
12
 
13
- os.environ['PYTHONPATH'] = f'lora:{os.getenv("PYTHONPATH", "")}'
14
 
15
 
16
  def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
@@ -28,94 +30,105 @@ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
28
 
29
 
30
  class Trainer:
31
- def __init__(self):
32
- self.is_running = False
33
- self.is_running_message = 'Another training is in progress.'
34
-
35
- self.output_dir = pathlib.Path('results')
36
- self.instance_data_dir = self.output_dir / 'training_data'
37
-
38
- def check_if_running(self) -> dict:
39
- if self.is_running:
40
- return gr.update(value=self.is_running_message)
41
- else:
42
- return gr.update(value='No training is running.')
43
-
44
- def cleanup_dirs(self) -> None:
45
- shutil.rmtree(self.output_dir, ignore_errors=True)
46
-
47
- def prepare_dataset(self, concept_images: list, resolution: int) -> None:
48
- self.instance_data_dir.mkdir(parents=True)
49
- for i, temp_path in enumerate(concept_images):
50
  image = PIL.Image.open(temp_path.name)
51
  image = pad_image(image)
52
  image = image.resize((resolution, resolution))
53
  image = image.convert('RGB')
54
- out_path = self.instance_data_dir / f'{i:03d}.jpg'
55
  image.save(out_path, format='JPEG', quality=100)
56
 
57
  def run(
58
  self,
59
  base_model: str,
60
  resolution_s: str,
61
- concept_images: list | None,
62
- concept_prompt: str,
63
  n_steps: int,
64
  learning_rate: float,
65
- train_text_encoder: bool,
66
- learning_rate_text: float,
67
  gradient_accumulation: int,
 
68
  fp16: bool,
69
  use_8bit_adam: bool,
70
- ) -> tuple[dict, list[pathlib.Path]]:
71
  if not torch.cuda.is_available():
72
  raise gr.Error('CUDA is not available.')
73
-
74
- if self.is_running:
75
- return gr.update(value=self.is_running_message), []
76
-
77
- if concept_images is None:
78
  raise gr.Error('You need to upload images.')
79
- if not concept_prompt:
80
- raise gr.Error('The concept prompt is missing.')
 
 
81
 
82
  resolution = int(resolution_s)
83
 
84
- self.cleanup_dirs()
85
- self.prepare_dataset(concept_images, resolution)
86
 
87
  command = f'''
88
- accelerate launch lora/train_lora_dreambooth.py \
89
  --pretrained_model_name_or_path={base_model} \
90
- --instance_data_dir={self.instance_data_dir} \
91
- --output_dir={self.output_dir} \
92
- --instance_prompt="{concept_prompt}" \
93
  --resolution={resolution} \
94
  --train_batch_size=1 \
95
  --gradient_accumulation_steps={gradient_accumulation} \
96
  --learning_rate={learning_rate} \
97
  --lr_scheduler=constant \
98
  --lr_warmup_steps=0 \
99
- --max_train_steps={n_steps}
100
  '''
101
  if fp16:
102
  command += ' --mixed_precision fp16'
103
  if use_8bit_adam:
104
  command += ' --use_8bit_adam'
105
- if train_text_encoder:
106
- command += f' --train_text_encoder --learning_rate_text={learning_rate_text} --color_jitter'
107
-
108
- with open(self.output_dir / 'train.sh', 'w') as f:
109
  command_s = ' '.join(command.split())
110
  f.write(command_s)
111
 
112
- self.is_running = True
113
- res = subprocess.run(shlex.split(command))
114
- self.is_running = False
115
-
116
- if res.returncode == 0:
117
- result_message = 'Training Completed!'
118
- else:
119
- result_message = 'Training Failed!'
120
- weight_paths = sorted(self.output_dir.glob('*.pt'))
121
- return gr.update(value=result_message), weight_paths
1
  from __future__ import annotations
2
 
3
+ import datetime
4
  import os
5
  import pathlib
6
  import shlex
9
 
10
  import gradio as gr
11
  import PIL.Image
12
+ import slugify
13
  import torch
14
 
15
+ from constants import UploadTarget
16
 
17
 
18
  def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
30
 
31
 
32
  class Trainer:
33
+ def prepare_dataset(self, instance_images: list, resolution: int,
34
+ instance_data_dir: pathlib.Path) -> None:
35
+ shutil.rmtree(instance_data_dir, ignore_errors=True)
36
+ instance_data_dir.mkdir(parents=True)
37
+ for i, temp_path in enumerate(instance_images):
38
  image = PIL.Image.open(temp_path.name)
39
  image = pad_image(image)
40
  image = image.resize((resolution, resolution))
41
  image = image.convert('RGB')
42
+ out_path = instance_data_dir / f'{i:03d}.jpg'
43
  image.save(out_path, format='JPEG', quality=100)
44
 
45
  def run(
46
  self,
47
+ instance_images: list | None,
48
+ instance_prompt: str,
49
+ output_model_name: str,
50
+ overwrite_existing_model: bool,
51
+ validation_prompt: str,
52
  base_model: str,
53
  resolution_s: str,
 
 
54
  n_steps: int,
55
  learning_rate: float,
 
 
56
  gradient_accumulation: int,
57
+ seed: int,
58
  fp16: bool,
59
  use_8bit_adam: bool,
60
+ checkpointing_steps: int,
61
+ use_wandb: bool,
62
+ validation_epochs: int,
63
+ upload_to_hub: bool,
64
+ use_private_repo: bool,
65
+ delete_existing_repo: bool,
66
+ upload_to: str,
67
+ ) -> str:
68
  if not torch.cuda.is_available():
69
  raise gr.Error('CUDA is not available.')
70
+ if instance_images is None:
71
  raise gr.Error('You need to upload images.')
72
+ if not instance_prompt:
73
+ raise gr.Error('The instance prompt is missing.')
74
+ if not validation_prompt:
75
+ raise gr.Error('The validation prompt is missing.')
76
 
77
  resolution = int(resolution_s)
78
 
79
+ if not output_model_name:
80
+ output_model_name = datetime.datetime.now().strftime(
81
+ '%Y-%m-%d-%H-%M-%S')
82
+ output_model_name = slugify.slugify(output_model_name)
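+ # slugify turns the name into lowercase ASCII with hyphens, so it is safe to use both as a
+ # local directory name and as a Hub repo name.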
83
+
84
+ repo_dir = pathlib.Path(__file__).parent
85
+ output_dir = repo_dir / 'experiments' / output_model_name
86
+ if overwrite_existing_model or upload_to_hub:
87
+ shutil.rmtree(output_dir, ignore_errors=True)
88
+ if not upload_to_hub:
89
+ output_dir.mkdir(parents=True)
90
+
91
+ instance_data_dir = repo_dir / 'training_data' / output_model_name
92
+ self.prepare_dataset(instance_images, resolution, instance_data_dir)
93
 
94
  command = f'''
95
+ accelerate launch train_dreambooth_lora.py \
96
  --pretrained_model_name_or_path={base_model} \
97
+ --instance_data_dir={instance_data_dir} \
98
+ --output_dir={output_dir} \
99
+ --instance_prompt="{instance_prompt}" \
100
  --resolution={resolution} \
101
  --train_batch_size=1 \
102
  --gradient_accumulation_steps={gradient_accumulation} \
103
  --learning_rate={learning_rate} \
104
  --lr_scheduler=constant \
105
  --lr_warmup_steps=0 \
106
+ --max_train_steps={n_steps} \
107
+ --checkpointing_steps={checkpointing_steps} \
108
+ --validation_prompt="{validation_prompt}" \
109
+ --validation_epochs={validation_epochs} \
110
+ --seed={seed}
111
  '''
112
  if fp16:
113
  command += ' --mixed_precision fp16'
114
  if use_8bit_adam:
115
  command += ' --use_8bit_adam'
116
+ if use_wandb:
117
+ command += ' --report_to wandb'
118
+ if upload_to_hub:
119
+ hf_token = os.getenv('HF_TOKEN')
120
+ command += f' --push_to_hub --hub_token {hf_token}'
121
+ if use_private_repo:
122
+ command += ' --private_repo'
123
+ if delete_existing_repo:
124
+ command += ' --delete_existing_repo'
125
+ if upload_to == UploadTarget.LORA_LIBRARY.value:
126
+ command += ' --upload_to_lora_library'
127
+
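+ # For reference, the assembled command looks roughly like this (illustrative values):
+ #   accelerate launch train_dreambooth_lora.py \
+ #     --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
+ #     --instance_data_dir=training_data/my-model --output_dir=experiments/my-model \
+ #     --instance_prompt="a photo of sks dog" --resolution=512 --train_batch_size=1 \
+ #     --learning_rate=1e-4 --max_train_steps=1000 --seed=0 --mixed_precision fp16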
128
+ subprocess.run(shlex.split(command))
129
+
130
+ with open(output_dir / 'train.sh', 'w') as f:
131
  command_s = ' '.join(command.split())
132
  f.write(command_s)
133
 
134
+ return 'Training completed!'
uploader.py CHANGED
@@ -1,20 +1,39 @@
1
- import gradio as gr
 
2
  from huggingface_hub import HfApi
3
 
4
 
5
- def upload(model_name: str, hf_token: str) -> None:
6
- api = HfApi(token=hf_token)
7
- user_name = api.whoami()['name']
8
- model_id = f'{user_name}/{model_name}'
9
- try:
10
- api.create_repo(model_id, repo_type='model', private=True)
11
- api.upload_folder(repo_id=model_id,
12
- folder_path='results',
13
- path_in_repo='results',
14
- repo_type='model')
15
- url = f'https://huggingface.co/{model_id}'
16
- message = f'Your model was successfully uploaded to [{url}]({url}).'
17
- except Exception as e:
18
- message = str(e)
19
 
20
- return gr.update(value=message, visible=True)
1
+ from __future__ import annotations
2
+
3
  from huggingface_hub import HfApi
4
 
5
 
6
+ class Uploader:
7
+ def __init__(self, hf_token: str | None):
8
+ self.api = HfApi(token=hf_token)
9
+
10
+ def get_username(self) -> str:
11
+ return self.api.whoami()['name']
12
 
13
+ def upload(self,
14
+ folder_path: str,
15
+ repo_name: str,
16
+ organization: str = '',
17
+ repo_type: str = 'model',
18
+ private: bool = True,
19
+ delete_existing_repo: bool = False) -> str:
20
+ if not organization:
21
+ organization = self.get_username()
22
+ repo_id = f'{organization}/{repo_name}'
23
+ if delete_existing_repo:
24
+ try:
25
+ self.api.delete_repo(repo_id, repo_type=repo_type)
26
+ except Exception:
27
+ pass
28
+ try:
29
+ self.api.create_repo(repo_id, repo_type=repo_type, private=private)
30
+ self.api.upload_folder(repo_id=repo_id,
31
+ folder_path=folder_path,
32
+ path_in_repo='.',
33
+ repo_type=repo_type)
34
+ url = f'https://huggingface.co/{repo_id}'
35
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
36
+ except Exception as e:
37
+ message = str(e)
38
+ print(message)
39
+ return message
utils.py ADDED
@@ -0,0 +1,18 @@
1
+ import pathlib
2
+
3
+
4
+ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
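+ # Returns experiment directories (relative to the repo root) that contain trained LoRA
+ # weights; with ignore_repo=True, directories that are already git repos (e.g. local clones
+ # of Hub repos) are skipped.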
5
+ repo_dir = pathlib.Path(__file__).parent
6
+ exp_root_dir = repo_dir / 'experiments'
7
+ if not exp_root_dir.exists():
8
+ return []
9
+ exp_dirs = sorted(exp_root_dir.glob('*'))
10
+ exp_dirs = [
11
+ exp_dir for exp_dir in exp_dirs
12
+ if (exp_dir / 'pytorch_lora_weights.bin').exists()
13
+ ]
14
+ if ignore_repo:
15
+ exp_dirs = [
16
+ exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
17
+ ]
18
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]