thuanz123 committed
Commit 2815e7c
Parent: 6d0ef31

Upload 14 files

Files changed (14)
  1. LICENSE +21 -0
  2. README.md +5 -4
  3. app.py +75 -0
  4. app_inference.py +164 -0
  5. app_training.py +145 -0
  6. app_upload.py +100 -0
  7. constants.py +6 -0
  8. inference.py +80 -0
  9. requirements.txt +14 -0
  10. style.css +3 -0
  11. train_dreambooth.py +1020 -0
  12. trainer.py +173 -0
  13. uploader.py +42 -0
  14. utils.py +59 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 hysts
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,10 +1,11 @@
  ---
  title: DreamBooth Training UI
- emoji: 👀
- colorFrom: pink
- colorTo: blue
+ emoji:
+ colorFrom: red
+ colorTo: purple
  sdk: gradio
- sdk_version: 3.50.2
+ sdk_version: 3.16.2
+ python_version: 3.10.9
  app_file: app.py
  pinned: false
  license: mit
app.py ADDED
@@ -0,0 +1,75 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import os
+
+ import gradio as gr
+ import torch
+
+ from app_inference import create_inference_demo
+ from app_training import create_training_demo
+ from app_upload import create_upload_demo
+ from inference import InferencePipeline
+ from trainer import Trainer
+
+ TITLE = '# DreamBooth Training UI'
+
+ ORIGINAL_SPACE_ID = 'dreambooth-library/DreamBooth-Training-UI'
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate it and use it with a paid private T4 GPU.
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
+ '''
+
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
+     SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
+ else:
+     SETTINGS = 'Settings'
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
+ <center>
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
+ "T4 small" is sufficient to run this demo.
+ </center>
+ '''
+
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please set it to a Hugging Face token with write permission.
+ <center>
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
+ </center>
+ '''
+
+ HF_TOKEN = os.getenv('HF_TOKEN')
+
+
+ def show_warning(warning_text: str) -> gr.Blocks:
+     with gr.Blocks() as demo:
+         with gr.Box():
+             gr.Markdown(warning_text)
+     return demo
+
+
+ pipe = InferencePipeline(HF_TOKEN)
+ trainer = Trainer(HF_TOKEN)
+
+ with gr.Blocks(css='style.css') as demo:
+     if os.getenv('IS_SHARED_UI'):
+         show_warning(SHARED_UI_WARNING)
+     if not torch.cuda.is_available():
+         show_warning(CUDA_NOT_AVAILABLE_WARNING)
+     if not HF_TOKEN:
+         show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
+
+     gr.Markdown(TITLE)
+     with gr.Tabs():
+         with gr.TabItem('Train'):
+             create_training_demo(trainer, pipe)
+         with gr.TabItem('Test'):
+             create_inference_demo(pipe, HF_TOKEN)
+         with gr.TabItem('Upload'):
+             gr.Markdown('''
+             - You can use this tab to upload models later if you choose not to upload them at training time, or if the upload failed during training.
+             ''')
+             create_upload_demo(HF_TOKEN)
+
+ demo.queue(max_size=1).launch(share=True)
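
For local testing outside Spaces, the same entry point can be launched directly. A minimal sketch, assuming the pinned requirements are installed and a write-scoped token is available (the token value below is a placeholder):

```python
import os
import subprocess

os.environ['HF_TOKEN'] = 'hf_xxx'  # placeholder; use a real token with write scope
# app.py builds the Train/Test/Upload tabs at import time and calls demo.launch().
subprocess.run(['python', 'app.py'], check=True)
```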
app_inference.py ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from inference import InferencePipeline
11
+ from utils import find_exp_dirs
12
+
13
+ SAMPLE_MODEL_IDS = [
14
+ 'koala2/dreambooth-dog-v2',
15
+ 'lambdalabs/dreambooth-avatar',
16
+ ]
17
+
18
+
19
+ class ModelSource(enum.Enum):
20
+ SAMPLE = 'Sample'
21
+ HUB_LIB = 'Hub (dreambooth-library)'
22
+ LOCAL = 'Local'
23
+
24
+
25
+ class InferenceUtil:
26
+ def __init__(self, hf_token: str | None):
27
+ self.hf_token = hf_token
28
+
29
+ @staticmethod
30
+ def load_sample_model_list():
31
+ return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])
32
+
33
+ def load_hub_model_list(self) -> dict:
34
+ api = HfApi(token=self.hf_token)
35
+ choices = [
36
+ info.modelId for info in api.list_models(author='dreambooth-library')
37
+ ]
38
+ return gr.update(choices=choices,
39
+ value=choices[0] if choices else None)
40
+
41
+ @staticmethod
42
+ def load_local_model_list() -> dict:
43
+ choices = find_exp_dirs()
44
+ return gr.update(choices=choices,
45
+ value=choices[0] if choices else None)
46
+
47
+ def reload_model_list(self, model_source: str) -> dict:
48
+ if model_source == ModelSource.SAMPLE.value:
49
+ return self.load_sample_model_list()
50
+ elif model_source == ModelSource.HUB_LIB.value:
51
+ return self.load_hub_model_list()
52
+ elif model_source == ModelSource.LOCAL.value:
53
+ return self.load_local_model_list()
54
+ else:
55
+ raise ValueError
56
+
57
+ def load_model_info(self, model_id: str) -> str:
58
+ try:
59
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
60
+ except Exception:
61
+ return ''
62
+ instance_prompt = getattr(card.data, 'instance_prompt', '')
63
+ return instance_prompt
64
+
65
+ def reload_model_list_and_update_model_info(
66
+ self, model_source: str
67
+ ) -> tuple[dict, str]:
68
+ model_list_update = self.reload_model_list(model_source)
69
+ model_list = model_list_update['choices']
70
+ model_info = self.load_model_info(model_list[0] if model_list else '')
71
+ return model_list_update, model_info
72
+
73
+
74
+ def create_inference_demo(pipe: InferencePipeline,
75
+ hf_token: str | None = None) -> gr.Blocks:
76
+ app = InferenceUtil(hf_token)
77
+
78
+ with gr.Blocks() as demo:
79
+ with gr.Row():
80
+ with gr.Column():
81
+ with gr.Box():
82
+ model_source = gr.Radio(
83
+ label='Model Source',
84
+ choices=[_.value for _ in ModelSource],
85
+ value=ModelSource.SAMPLE.value)
86
+ reload_button = gr.Button('Reload Model List')
87
+ model_id = gr.Dropdown(label='Model ID',
88
+ choices=SAMPLE_MODEL_IDS,
89
+ value=SAMPLE_MODEL_IDS[0])
90
+ with gr.Accordion(
91
+ label=
92
+ 'Model info (instance prompt used for training)',
93
+ open=False):
94
+ with gr.Row():
95
+ instance_prompt_used_for_training = gr.Text(
96
+ label='Instance prompt', interactive=False)
97
+ prompt = gr.Textbox(
98
+ label='Prompt',
99
+ max_lines=1,
100
+ placeholder='Example: "A picture of a {}dog in a bucket"'
101
+ )
102
+ seed = gr.Slider(label='Seed',
103
+ minimum=0,
104
+ maximum=100000,
105
+ step=1,
106
+ value=0)
107
+ with gr.Accordion('Other Parameters', open=False):
108
+ num_steps = gr.Slider(label='Number of Steps',
109
+ minimum=0,
110
+ maximum=100,
111
+ step=1,
112
+ value=25)
113
+ guidance_scale = gr.Slider(label='CFG Scale',
114
+ minimum=0,
115
+ maximum=50,
116
+ step=0.1,
117
+ value=7.5)
118
+
119
+ run_button = gr.Button('Generate')
120
+
121
+ gr.Markdown('''
122
+ - After training, you can press the "Reload Model List" button to load your trained model names.
123
+ ''')
124
+ with gr.Column():
125
+ result = gr.Image(label='Result')
126
+
127
+ model_source.change(
128
+ fn=app.reload_model_list_and_update_model_info,
129
+ inputs=model_source,
130
+ outputs=[
131
+ model_id,
132
+ instance_prompt_used_for_training,
133
+ ])
134
+ reload_button.click(
135
+ fn=app.reload_model_list_and_update_model_info,
136
+ inputs=model_source,
137
+ outputs=[
138
+ model_id,
139
+ instance_prompt_used_for_training,
140
+ ])
141
+ model_id.change(fn=app.load_model_info,
142
+ inputs=model_id,
143
+ outputs=[
144
+ instance_prompt_used_for_training,
145
+ ])
146
+ inputs = [
147
+ model_id,
148
+ prompt,
149
+ seed,
150
+ num_steps,
151
+ guidance_scale,
152
+ ]
153
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
154
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
155
+ return demo
156
+
157
+
158
+ if __name__ == '__main__':
159
+ import os
160
+
161
+ hf_token = os.getenv('HF_TOKEN')
162
+ pipe = InferencePipeline(hf_token)
163
+ demo = create_inference_demo(pipe, hf_token)
164
+ demo.queue(max_size=10).launch(share=False)
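
The "Hub (dreambooth-library)" source above is populated by a plain `HfApi` listing; the same query can be run standalone (anonymous access is assumed, which is enough for public repos):

```python
from huggingface_hub import HfApi

api = HfApi()
# The IDs returned here are what would fill the "Model ID" dropdown.
choices = [info.modelId for info in api.list_models(author='dreambooth-library')]
print(choices[:5])
```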
app_training.py ADDED
@@ -0,0 +1,145 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ with gr.Box():
20
+ gr.Markdown('Training Data')
21
+ instance_images = gr.Files(label='Instance images')
22
+ training_prompt = gr.Textbox(label='Training prompt', max_lines=1,
23
+ placeholder='Example: "A photo of a {}dog"')
24
+ gr.Markdown('''
25
+ - Upload images of the style you plan to train on.
26
+ - For a training prompt, please refer to the example.
27
+ ''')
28
+ with gr.Box():
29
+ gr.Markdown('Output Model')
30
+ output_model_name = gr.Text(label='Name of your model',
31
+ max_lines=1)
32
+ delete_existing_model = gr.Checkbox(
33
+ label='Delete existing model of the same name',
34
+ value=False)
35
+ validation_prompt = gr.Text(label='Validation Prompt',
36
+ placeholder='Example: "A photo of a {}dog in a bucket"')
37
+ with gr.Box():
38
+ gr.Markdown('Upload Settings')
39
+ with gr.Row():
40
+ upload_to_hub = gr.Checkbox(
41
+ label='Upload model to Hub', value=True)
42
+ use_private_repo = gr.Checkbox(label='Private',
43
+ value=True)
44
+ delete_existing_repo = gr.Checkbox(
45
+ label='Delete existing repo of the same name',
46
+ value=False)
47
+ upload_to = gr.Radio(
48
+ label='Upload to',
49
+ choices=[_.value for _ in UploadTarget],
50
+ value=UploadTarget.DREAMBOOTH_LIBRARY.value)
51
+ gr.Markdown('''
52
+ - By default, trained models will be uploaded to [DreamBooth Library](https://huggingface.co/dreambooth-library).
53
+ - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
54
+ ''')
55
+
56
+ with gr.Box():
57
+ gr.Markdown('Training Parameters')
58
+ with gr.Row():
59
+ base_model = gr.Text(
60
+ label='Base Model',
61
+ value='stabilityai/stable-diffusion-2-1-base',
62
+ max_lines=1)
63
+ resolution = gr.Dropdown(choices=['512', '768'],
64
+ value='512',
65
+ label='Resolution')
66
+ num_training_steps = gr.Number(
67
+ label='Number of Training Steps', value=1000, precision=0)
68
+ learning_rate = gr.Number(label='Learning Rate', value=0.0001)
69
+ gradient_accumulation = gr.Number(
70
+ label='Number of Gradient Accumulation Steps',
71
+ value=1,
72
+ precision=0)
73
+ seed = gr.Slider(label='Seed',
74
+ minimum=0,
75
+ maximum=100000,
76
+ step=1,
77
+ value=0)
78
+ fp16 = gr.Checkbox(label='FP16', value=True)
79
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
80
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
81
+ value=100,
82
+ precision=0)
83
+ use_wandb = gr.Checkbox(label='Use W&B',
84
+ value=False,
85
+ interactive=bool(
86
+ os.getenv('WANDB_API_KEY')))
87
+ validation_epochs = gr.Number(label='Validation Epochs',
88
+ value=100,
89
+ precision=0)
90
+ gr.Markdown('''
91
+ - The base model must be compatible with the [diffusers](https://github.com/huggingface/diffusers) library.
92
+ - It takes a few minutes to download the base model first.
93
+ - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
94
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
95
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
96
+ - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
97
+ - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
98
+ ''')
99
+
100
+ remove_gpu_after_training = gr.Checkbox(
101
+ label='Remove GPU after training',
102
+ value=False,
103
+ interactive=bool(os.getenv('SPACE_ID')),
104
+ visible=False)
105
+ run_button = gr.Button('Start Training')
106
+
107
+ with gr.Box():
108
+ gr.Markdown('Output message')
109
+ output_message = gr.Markdown()
110
+
111
+ if pipe is not None:
112
+ run_button.click(fn=pipe.clear)
113
+ run_button.click(fn=trainer.run,
114
+ inputs=[
115
+ instance_images,
116
+ training_prompt,
117
+ output_model_name,
118
+ delete_existing_model,
119
+ validation_prompt,
120
+ base_model,
121
+ resolution,
122
+ num_training_steps,
123
+ learning_rate,
124
+ gradient_accumulation,
125
+ seed,
126
+ fp16,
127
+ use_8bit_adam,
128
+ checkpointing_steps,
129
+ use_wandb,
130
+ validation_epochs,
131
+ upload_to_hub,
132
+ use_private_repo,
133
+ delete_existing_repo,
134
+ upload_to,
135
+ remove_gpu_after_training,
136
+ ],
137
+ outputs=output_message)
138
+ return demo
139
+
140
+
141
+ if __name__ == '__main__':
142
+ hf_token = os.getenv('HF_TOKEN')
143
+ trainer = Trainer(hf_token)
144
+ demo = create_training_demo(trainer)
145
+ demo.queue(max_size=1).launch(share=False)
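
As the notes above say, the "Base Model" field accepts any repo that the diffusers library can load. A quick compatibility check, as a sketch (it downloads several GB of weights on first run; fp16 assumes a CUDA GPU):

```python
import torch
from diffusers import DiffusionPipeline

# If this loads cleanly, the same ID should work as the DreamBooth base model.
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-2-1-base',  # the default base model above
    torch_dtype=torch.float16)
print(type(pipe).__name__)
```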
app_upload.py ADDED
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ ) -> str:
24
+ if not folder_path:
25
+ raise ValueError
26
+ if not repo_name:
27
+ repo_name = pathlib.Path(folder_path).name
28
+ repo_name = slugify.slugify(repo_name)
29
+
30
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
+ organization = ''
32
+ elif upload_to == UploadTarget.DREAMBOOTH_LIBRARY.value:
33
+ organization = 'dreambooth-library'
34
+ else:
35
+ raise ValueError
36
+
37
+ return self.upload(folder_path,
38
+ repo_name,
39
+ organization=organization,
40
+ private=private,
41
+ delete_existing_repo=delete_existing_repo)
42
+
43
+
44
+ def load_local_model_list() -> dict:
45
+ choices = find_exp_dirs(ignore_repo=True)
46
+ return gr.update(choices=choices, value=choices[0] if choices else None)
47
+
48
+
49
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
+ uploader = ModelUploader(hf_token)
51
+ model_dirs = find_exp_dirs(ignore_repo=True)
52
+
53
+ with gr.Blocks() as demo:
54
+ with gr.Box():
55
+ gr.Markdown('Local Models')
56
+ reload_button = gr.Button('Reload Model List')
57
+ model_dir = gr.Dropdown(
58
+ label='Model names',
59
+ choices=model_dirs,
60
+ value=model_dirs[0] if model_dirs else None)
61
+ with gr.Box():
62
+ gr.Markdown('Upload Settings')
63
+ with gr.Row():
64
+ use_private_repo = gr.Checkbox(label='Private', value=True)
65
+ delete_existing_repo = gr.Checkbox(
66
+ label='Delete existing repo of the same name', value=False)
67
+ upload_to = gr.Radio(label='Upload to',
68
+ choices=[_.value for _ in UploadTarget],
69
+ value=UploadTarget.DREAMBOOTH_LIBRARY.value)
70
+ model_name = gr.Textbox(label='Model Name')
71
+ upload_button = gr.Button('Upload')
72
+ gr.Markdown('''
73
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{your_username}/{model_name}) or to the public [DreamBooth Concepts Library](https://huggingface.co/dreambooth-library) (i.e. https://huggingface.co/dreambooth-library/{model_name}).
74
+ ''')
75
+ with gr.Box():
76
+ gr.Markdown('Output message')
77
+ output_message = gr.Markdown()
78
+
79
+ reload_button.click(fn=load_local_model_list,
80
+ inputs=None,
81
+ outputs=model_dir)
82
+ upload_button.click(fn=uploader.upload_model,
83
+ inputs=[
84
+ model_dir,
85
+ model_name,
86
+ upload_to,
87
+ use_private_repo,
88
+ delete_existing_repo,
89
+ ],
90
+ outputs=output_message)
91
+
92
+ return demo
93
+
94
+
95
+ if __name__ == '__main__':
96
+ import os
97
+
98
+ hf_token = os.getenv('HF_TOKEN')
99
+ demo = create_upload_demo(hf_token)
100
+ demo.queue(max_size=1).launch(share=False)
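
The same upload path can be exercised without the UI; a sketch, assuming a trained model folder already exists locally (the folder name and token are placeholders, and the return value comes from `Uploader.upload` in uploader.py, which is not shown in this hunk):

```python
from app_upload import ModelUploader
from constants import UploadTarget

uploader = ModelUploader('hf_xxx')  # placeholder write-scoped token
message = uploader.upload_model(
    folder_path='results/my-dog-model',  # hypothetical local output directory
    repo_name='my-dog-model',
    upload_to=UploadTarget.DREAMBOOTH_LIBRARY.value,
    private=True,
    delete_existing_repo=False,
)
print(message)
```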
constants.py ADDED
@@ -0,0 +1,6 @@
+ import enum
+
+
+ class UploadTarget(enum.Enum):
+     PERSONAL_PROFILE = 'Personal Profile'
+     DREAMBOOTH_LIBRARY = 'DreamBooth Library'
inference.py ADDED
@@ -0,0 +1,80 @@
+ from __future__ import annotations
+
+ import gc
+ import pathlib
+
+ import gradio as gr
+ import PIL.Image
+ import torch
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+ from huggingface_hub import ModelCard
+
+
+ class InferencePipeline:
+     def __init__(self, hf_token: str | None = None):
+         self.hf_token = hf_token
+         self.pipe = None
+         self.device = torch.device(
+             'cuda:0' if torch.cuda.is_available() else 'cpu')
+         self.model_id = None
+
+     def clear(self) -> None:
+         self.model_id = None
+         del self.pipe
+         self.pipe = None
+         torch.cuda.empty_cache()
+         gc.collect()
+
+     @staticmethod
+     def check_if_model_is_local(model_id: str) -> bool:
+         return pathlib.Path(model_id).exists()
+
+     @staticmethod
+     def get_model_card(model_id: str,
+                        hf_token: str | None = None) -> ModelCard:
+         if InferencePipeline.check_if_model_is_local(model_id):
+             card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
+         else:
+             card_path = model_id
+         return ModelCard.load(card_path, token=hf_token)
+
+     def load_pipe(self, model_id: str) -> None:
+         if model_id == self.model_id:
+             return
+
+         if self.device.type == 'cpu':
+             pipe = DiffusionPipeline.from_pretrained(
+                 model_id, use_auth_token=self.hf_token)
+         else:
+             pipe = DiffusionPipeline.from_pretrained(
+                 model_id, torch_dtype=torch.float16,
+                 use_auth_token=self.hf_token)
+             pipe = pipe.to(self.device)
+         pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+             pipe.scheduler.config)
+         self.pipe = pipe
+
+         pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
+         self.model_id = model_id  # type: ignore
+
+     def run(
+         self,
+         model_id: str,
+         prompt: str,
+         seed: int,
+         n_steps: int,
+         guidance_scale: float,
+     ) -> PIL.Image.Image:
+         if not torch.cuda.is_available():
+             raise gr.Error('CUDA is not available.')
+
+         self.load_pipe(model_id)
+
+         generator = torch.Generator(device=self.device).manual_seed(seed)
+         out = self.pipe(
+             prompt.format("sks "),
+             num_inference_steps=n_steps,
+             guidance_scale=guidance_scale,
+             generator=generator,
+         )  # type: ignore
+         return out.images[0]
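
For reference, the pipeline can also be driven outside the Gradio UI. A minimal sketch, assuming a CUDA GPU (`run` raises otherwise); the model ID is one of the samples from app_inference.py and the prompt is a placeholder, with `{}` substituted by the hard-coded `'sks '` instance token:

```python
from inference import InferencePipeline

pipe = InferencePipeline(hf_token=None)  # public models need no token
image = pipe.run(
    model_id='lambdalabs/dreambooth-avatar',  # sample model ID from the dropdown
    prompt='A picture of a {}person in space',  # '{}' becomes 'sks '
    seed=0,
    n_steps=25,
    guidance_scale=7.5,
)
image.save('result.png')
```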
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ accelerate==0.15.0
+ bitsandbytes==0.36.0.post2
+ datasets==2.8.0
+ git+https://github.com/huggingface/diffusers@31be42209ddfdb69d9640a777b32e9b5c6259bf0#egg=diffusers
+ ftfy==6.1.1
+ gradio==3.16.2
+ huggingface-hub==0.12.0
+ Pillow==9.4.0
+ python-slugify==7.0.0
+ tensorboard==2.11.2
+ torch==1.13.1
+ torchvision==0.14.1
+ transformers==4.26.0
+ wandb==0.13.9
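
The pins above target the Gradio 3.16 stack with diffusers installed from a fixed commit. A quick post-install sanity check, as a sketch (assumes `pip install -r requirements.txt` has already been run):

```python
import diffusers
import gradio
import torch

print(gradio.__version__)         # expected: 3.16.2 per the pin above
print(diffusers.__version__)      # installed from the pinned git commit
print(torch.cuda.is_available())  # training and inference here assume a CUDA GPU
```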
style.css ADDED
@@ -0,0 +1,3 @@
+ h1 {
+   text-align: center;
+ }
train_dreambooth.py ADDED
@@ -0,0 +1,1020 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ #
4
+ # This file is adapted from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth.py
5
+ # The original license is as below:
6
+ #
7
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+
20
+ import argparse
21
+ import hashlib
22
+ import itertools
23
+ import logging
24
+ import math
25
+ import os
26
+ import warnings
27
+ from pathlib import Path
28
+ from typing import Optional
29
+
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn.functional as F
33
+ import torch.utils.checkpoint
34
+ from torch.utils.data import Dataset
35
+
36
+ import diffusers
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from accelerate.utils import set_seed
41
+ from diffusers import (
42
+ AutoencoderKL,
43
+ DDPMScheduler,
44
+ DiffusionPipeline,
45
+ DPMSolverMultistepScheduler,
46
+ UNet2DConditionModel,
47
+ )
48
+ from diffusers.optimization import get_scheduler
49
+ from diffusers.utils import check_min_version, is_wandb_available
50
+ from diffusers.utils.import_utils import is_xformers_available
51
+ from huggingface_hub import HfFolder, Repository, create_repo, whoami
52
+ from PIL import Image
53
+ from torchvision import transforms
54
+ from tqdm.auto import tqdm
55
+ from transformers import AutoTokenizer, PretrainedConfig
56
+
57
+
58
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
59
+ check_min_version("0.12.0.dev0")
60
+
61
+ logger = get_logger(__name__)
62
+
63
+
64
+ def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
65
+ img_str = ""
66
+ for i, image in enumerate(images):
67
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
68
+ img_str += f"![img_{i}](./image_{i}.png)\n"
69
+
70
+ yaml = f"""
71
+ ---
72
+ license: creativeml-openrail-m
73
+ base_model: {base_model}
74
+ tags:
75
+ - stable-diffusion
76
+ - stable-diffusion-diffusers
77
+ - text-to-image
78
+ - diffusers
79
+ - dreambooth
80
+ inference: true
81
+ ---
82
+ """
83
+ model_card = f"""
84
+ # DreamBooth - {repo_name}
85
+
86
+ These are DreamBooth weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
87
+ {img_str}
88
+ """
89
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
90
+ f.write(yaml + model_card)
91
+
92
+
93
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
94
+ text_encoder_config = PretrainedConfig.from_pretrained(
95
+ pretrained_model_name_or_path,
96
+ subfolder="text_encoder",
97
+ revision=revision,
98
+ )
99
+ model_class = text_encoder_config.architectures[0]
100
+
101
+ if model_class == "CLIPTextModel":
102
+ from transformers import CLIPTextModel
103
+
104
+ return CLIPTextModel
105
+ elif model_class == "RobertaSeriesModelWithTransformation":
106
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
107
+
108
+ return RobertaSeriesModelWithTransformation
109
+ else:
110
+ raise ValueError(f"{model_class} is not supported.")
111
+
112
+
113
+ def parse_args(input_args=None):
114
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
115
+ parser.add_argument(
116
+ "--pretrained_model_name_or_path",
117
+ type=str,
118
+ default=None,
119
+ required=True,
120
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
121
+ )
122
+ parser.add_argument(
123
+ "--revision",
124
+ type=str,
125
+ default=None,
126
+ required=False,
127
+ help="Revision of pretrained model identifier from huggingface.co/models.",
128
+ )
129
+ parser.add_argument(
130
+ "--tokenizer_name",
131
+ type=str,
132
+ default=None,
133
+ help="Pretrained tokenizer name or path if not the same as model_name",
134
+ )
135
+ parser.add_argument(
136
+ "--instance_data_dir",
137
+ type=str,
138
+ default=None,
139
+ required=True,
140
+ help="A folder containing the training data of instance images.",
141
+ )
142
+ parser.add_argument(
143
+ "--class_data_dir",
144
+ type=str,
145
+ default=None,
146
+ required=False,
147
+ help="A folder containing the training data of class images.",
148
+ )
149
+ parser.add_argument(
150
+ "--instance_prompt",
151
+ type=str,
152
+ default=None,
153
+ required=True,
154
+ help="The prompt with identifier specifying the instance",
155
+ )
156
+ parser.add_argument(
157
+ "--class_prompt",
158
+ type=str,
159
+ default=None,
160
+ help="The prompt to specify images in the same class as provided instance images.",
161
+ )
162
+ parser.add_argument(
163
+ "--validation_prompt",
164
+ type=str,
165
+ default=None,
166
+ help="A prompt that is used during validation to verify that the model is learning.",
167
+ )
168
+ parser.add_argument(
169
+ "--num_validation_images",
170
+ type=int,
171
+ default=4,
172
+ help="Number of images that should be generated during validation with `validation_prompt`.",
173
+ )
174
+ parser.add_argument(
175
+ "--validation_epochs",
176
+ type=int,
177
+ default=50,
178
+ help=(
179
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
180
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
181
+ ),
182
+ )
183
+ parser.add_argument(
184
+ "--with_prior_preservation",
185
+ default=False,
186
+ action="store_true",
187
+ help="Flag to add prior preservation loss.",
188
+ )
189
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
190
+ parser.add_argument(
191
+ "--num_class_images",
192
+ type=int,
193
+ default=100,
194
+ help=(
195
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
196
+ " class_data_dir, additional images will be sampled with class_prompt."
197
+ ),
198
+ )
199
+ parser.add_argument(
200
+ "--output_dir",
201
+ type=str,
202
+ default="dreambooth-model",
203
+ help="The output directory where the model predictions and checkpoints will be written.",
204
+ )
205
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
206
+ parser.add_argument(
207
+ "--resolution",
208
+ type=int,
209
+ default=512,
210
+ help=(
211
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
212
+ " resolution"
213
+ ),
214
+ )
215
+ parser.add_argument(
216
+ "--center_crop",
217
+ default=False,
218
+ action="store_true",
219
+ help=(
220
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
221
+ " cropped. The images will be resized to the resolution first before cropping."
222
+ ),
223
+ )
224
+ parser.add_argument(
225
+ "--train_text_encoder",
226
+ action="store_true",
227
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
228
+ )
229
+ parser.add_argument(
230
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
231
+ )
232
+ parser.add_argument(
233
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
234
+ )
235
+ parser.add_argument("--num_train_epochs", type=int, default=1)
236
+ parser.add_argument(
237
+ "--max_train_steps",
238
+ type=int,
239
+ default=None,
240
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
241
+ )
242
+ parser.add_argument(
243
+ "--checkpointing_steps",
244
+ type=int,
245
+ default=500,
246
+ help=(
247
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
248
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
249
+ " training using `--resume_from_checkpoint`."
250
+ ),
251
+ )
252
+ parser.add_argument(
253
+ "--resume_from_checkpoint",
254
+ type=str,
255
+ default=None,
256
+ help=(
257
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
258
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
259
+ ),
260
+ )
261
+ parser.add_argument(
262
+ "--gradient_accumulation_steps",
263
+ type=int,
264
+ default=1,
265
+ help="Number of update steps to accumulate before performing a backward/update pass.",
266
+ )
267
+ parser.add_argument(
268
+ "--gradient_checkpointing",
269
+ action="store_true",
270
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
271
+ )
272
+ parser.add_argument(
273
+ "--learning_rate",
274
+ type=float,
275
+ default=5e-4,
276
+ help="Initial learning rate (after the potential warmup period) to use.",
277
+ )
278
+ parser.add_argument(
279
+ "--scale_lr",
280
+ action="store_true",
281
+ default=False,
282
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
283
+ )
284
+ parser.add_argument(
285
+ "--lr_scheduler",
286
+ type=str,
287
+ default="constant",
288
+ help=(
289
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
290
+ ' "constant", "constant_with_warmup"]'
291
+ ),
292
+ )
293
+ parser.add_argument(
294
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
295
+ )
296
+ parser.add_argument(
297
+ "--lr_num_cycles",
298
+ type=int,
299
+ default=1,
300
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
301
+ )
302
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
303
+ parser.add_argument(
304
+ "--dataloader_num_workers",
305
+ type=int,
306
+ default=0,
307
+ help=(
308
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
309
+ ),
310
+ )
311
+ parser.add_argument(
312
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
313
+ )
314
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
315
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
316
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
317
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
318
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
319
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
320
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
321
+ parser.add_argument(
322
+ "--hub_model_id",
323
+ type=str,
324
+ default=None,
325
+ help="The name of the repository to keep in sync with the local `output_dir`.",
326
+ )
327
+ parser.add_argument(
328
+ "--logging_dir",
329
+ type=str,
330
+ default="logs",
331
+ help=(
332
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
333
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
334
+ ),
335
+ )
336
+ parser.add_argument(
337
+ "--allow_tf32",
338
+ action="store_true",
339
+ help=(
340
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
341
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
342
+ ),
343
+ )
344
+ parser.add_argument(
345
+ "--report_to",
346
+ type=str,
347
+ default="tensorboard",
348
+ help=(
349
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
350
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
351
+ ),
352
+ )
353
+ parser.add_argument(
354
+ "--mixed_precision",
355
+ type=str,
356
+ default=None,
357
+ choices=["no", "fp16", "bf16"],
358
+ help=(
359
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
360
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
361
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
362
+ ),
363
+ )
364
+ parser.add_argument(
365
+ "--prior_generation_precision",
366
+ type=str,
367
+ default=None,
368
+ choices=["no", "fp32", "fp16", "bf16"],
369
+ help=(
370
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
371
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
372
+ ),
373
+ )
374
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
375
+ parser.add_argument(
376
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
377
+ )
378
+ parser.add_argument(
379
+ "--set_grads_to_none",
380
+ action="store_true",
381
+ help=(
382
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
383
+ " behaviors, so disable this argument if it causes any problems. More info:"
384
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
385
+ ),
386
+ )
387
+
388
+ if input_args is not None:
389
+ args = parser.parse_args(input_args)
390
+ else:
391
+ args = parser.parse_args()
392
+
393
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
394
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
395
+ args.local_rank = env_local_rank
396
+
397
+ if args.with_prior_preservation:
398
+ if args.class_data_dir is None:
399
+ raise ValueError("You must specify a data directory for class images.")
400
+ if args.class_prompt is None:
401
+ raise ValueError("You must specify prompt for class images.")
402
+ else:
403
+ # logger is not available yet
404
+ if args.class_data_dir is not None:
405
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
406
+ if args.class_prompt is not None:
407
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
408
+
409
+ return args
410
+
411
+
412
+ class DreamBoothDataset(Dataset):
413
+ """
414
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
415
+ It pre-processes the images and tokenizes the prompts.
416
+ """
417
+
418
+ def __init__(
419
+ self,
420
+ instance_data_root,
421
+ instance_prompt,
422
+ tokenizer,
423
+ class_data_root=None,
424
+ class_prompt=None,
425
+ size=512,
426
+ center_crop=False,
427
+ ):
428
+ self.size = size
429
+ self.center_crop = center_crop
430
+ self.tokenizer = tokenizer
431
+
432
+ self.instance_data_root = Path(instance_data_root)
433
+ if not self.instance_data_root.exists():
434
+ raise ValueError("Instance images root doesn't exist.")
435
+
436
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
437
+ self.num_instance_images = len(self.instance_images_path)
438
+ self.instance_prompt = instance_prompt
439
+ self._length = self.num_instance_images
440
+
441
+ if class_data_root is not None:
442
+ self.class_data_root = Path(class_data_root)
443
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
444
+ self.class_images_path = list(self.class_data_root.iterdir())
445
+ self.num_class_images = len(self.class_images_path)
446
+ self._length = max(self.num_class_images, self.num_instance_images)
447
+ self.class_prompt = class_prompt
448
+ else:
449
+ self.class_data_root = None
450
+
451
+ self.image_transforms = transforms.Compose(
452
+ [
453
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
454
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
455
+ transforms.ToTensor(),
456
+ transforms.Normalize([0.5], [0.5]),
457
+ ]
458
+ )
459
+
460
+ def __len__(self):
461
+ return self._length
462
+
463
+ def __getitem__(self, index):
464
+ example = {}
465
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
466
+ if not instance_image.mode == "RGB":
467
+ instance_image = instance_image.convert("RGB")
468
+ example["instance_images"] = self.image_transforms(instance_image)
469
+ example["instance_prompt_ids"] = self.tokenizer(
470
+ self.instance_prompt,
471
+ truncation=True,
472
+ padding="max_length",
473
+ max_length=self.tokenizer.model_max_length,
474
+ return_tensors="pt",
475
+ ).input_ids
476
+
477
+ if self.class_data_root:
478
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
479
+ if not class_image.mode == "RGB":
480
+ class_image = class_image.convert("RGB")
481
+ example["class_images"] = self.image_transforms(class_image)
482
+ example["class_prompt_ids"] = self.tokenizer(
483
+ self.class_prompt,
484
+ truncation=True,
485
+ padding="max_length",
486
+ max_length=self.tokenizer.model_max_length,
487
+ return_tensors="pt",
488
+ ).input_ids
489
+
490
+ return example
491
+
492
+
493
+ def collate_fn(examples, with_prior_preservation=False):
494
+ input_ids = [example["instance_prompt_ids"] for example in examples]
495
+ pixel_values = [example["instance_images"] for example in examples]
496
+
497
+ # Concat class and instance examples for prior preservation.
498
+ # We do this to avoid doing two forward passes.
499
+ if with_prior_preservation:
500
+ input_ids += [example["class_prompt_ids"] for example in examples]
501
+ pixel_values += [example["class_images"] for example in examples]
502
+
503
+ pixel_values = torch.stack(pixel_values)
504
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
505
+
506
+ input_ids = torch.cat(input_ids, dim=0)
507
+
508
+ batch = {
509
+ "input_ids": input_ids,
510
+ "pixel_values": pixel_values,
511
+ }
512
+ return batch
513
+
514
+
515
+ class PromptDataset(Dataset):
516
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
517
+
518
+ def __init__(self, prompt, num_samples):
519
+ self.prompt = prompt
520
+ self.num_samples = num_samples
521
+
522
+ def __len__(self):
523
+ return self.num_samples
524
+
525
+ def __getitem__(self, index):
526
+ example = {}
527
+ example["prompt"] = self.prompt
528
+ example["index"] = index
529
+ return example
530
+
531
+
532
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
533
+ if token is None:
534
+ token = HfFolder.get_token()
535
+ if organization is None:
536
+ username = whoami(token)["name"]
537
+ return f"{username}/{model_id}"
538
+ else:
539
+ return f"{organization}/{model_id}"
540
+
541
+
542
+ def main(args):
543
+ logging_dir = Path(args.output_dir, args.logging_dir)
544
+
545
+ accelerator = Accelerator(
546
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
547
+ mixed_precision=args.mixed_precision,
548
+ log_with=args.report_to,
549
+ logging_dir=logging_dir,
550
+ )
551
+
552
+ if args.report_to == "wandb":
553
+ if not is_wandb_available():
554
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
555
+ import wandb
556
+
557
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
558
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
559
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
560
+ # Make one log on every process with the configuration for debugging.
561
+ logging.basicConfig(
562
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
563
+ datefmt="%m/%d/%Y %H:%M:%S",
564
+ level=logging.INFO,
565
+ )
566
+ logger.info(accelerator.state, main_process_only=False)
567
+ if accelerator.is_local_main_process:
568
+ transformers.utils.logging.set_verbosity_warning()
569
+ diffusers.utils.logging.set_verbosity_info()
570
+ else:
571
+ transformers.utils.logging.set_verbosity_error()
572
+ diffusers.utils.logging.set_verbosity_error()
573
+
574
+ # If passed along, set the training seed now.
575
+ if args.seed is not None:
576
+ set_seed(args.seed)
577
+
578
+ # Generate class images if prior preservation is enabled.
579
+ if args.with_prior_preservation:
580
+ class_images_dir = Path(args.class_data_dir)
581
+ if not class_images_dir.exists():
582
+ class_images_dir.mkdir(parents=True)
583
+ cur_class_images = len(list(class_images_dir.iterdir()))
584
+
585
+ if cur_class_images < args.num_class_images:
586
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
587
+ if args.prior_generation_precision == "fp32":
588
+ torch_dtype = torch.float32
589
+ elif args.prior_generation_precision == "fp16":
590
+ torch_dtype = torch.float16
591
+ elif args.prior_generation_precision == "bf16":
592
+ torch_dtype = torch.bfloat16
593
+ pipeline = DiffusionPipeline.from_pretrained(
594
+ args.pretrained_model_name_or_path,
595
+ torch_dtype=torch_dtype,
596
+ safety_checker=None,
597
+ revision=args.revision,
598
+ )
599
+ pipeline.set_progress_bar_config(disable=True)
600
+
601
+ num_new_images = args.num_class_images - cur_class_images
602
+ logger.info(f"Number of class images to sample: {num_new_images}.")
603
+
604
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
605
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
606
+
607
+ sample_dataloader = accelerator.prepare(sample_dataloader)
608
+ pipeline.to(accelerator.device)
609
+
610
+ for example in tqdm(
611
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
612
+ ):
613
+ images = pipeline(example["prompt"]).images
614
+
615
+ for i, image in enumerate(images):
616
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
617
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
618
+ image.save(image_filename)
619
+
620
+ del pipeline
621
+ if torch.cuda.is_available():
622
+ torch.cuda.empty_cache()
623
+
624
+ # Handle the repository creation
625
+ if accelerator.is_main_process:
626
+ if args.push_to_hub:
627
+ if args.hub_model_id is None:
628
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
629
+ else:
630
+ repo_name = args.hub_model_id
631
+
632
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
633
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
634
+
635
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
636
+ if "step_*" not in gitignore:
637
+ gitignore.write("step_*\n")
638
+ if "epoch_*" not in gitignore:
639
+ gitignore.write("epoch_*\n")
640
+ elif args.output_dir is not None:
641
+ os.makedirs(args.output_dir, exist_ok=True)
642
+
643
+ # Load the tokenizer
644
+ if args.tokenizer_name:
645
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
646
+ elif args.pretrained_model_name_or_path:
647
+ tokenizer = AutoTokenizer.from_pretrained(
648
+ args.pretrained_model_name_or_path,
649
+ subfolder="tokenizer",
650
+ revision=args.revision,
651
+ use_fast=False,
652
+ )
653
+
654
+ # import correct text encoder class
655
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
656
+
657
+ # Load scheduler and models
658
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
659
+ text_encoder = text_encoder_cls.from_pretrained(
660
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
661
+ )
662
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
663
+ unet = UNet2DConditionModel.from_pretrained(
664
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
665
+ )
666
+
667
+ # Freeze vae and (optional) text_encoder
668
+ vae.requires_grad_(False)
669
+ if not args.train_text_encoder:
670
+ text_encoder.requires_grad_(False)
671
+
672
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
673
+ # as these models are only used for inference, keeping weights in full precision is not required.
674
+ weight_dtype = torch.float32
675
+ if accelerator.mixed_precision == "fp16":
676
+ weight_dtype = torch.float16
677
+ elif accelerator.mixed_precision == "bf16":
678
+ weight_dtype = torch.bfloat16
679
+
680
+ # Move vae and (optional) text_encoder to device and cast to weight_dtype
681
+ vae.to(accelerator.device, dtype=weight_dtype)
682
+ if not args.train_text_encoder:
683
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
684
+
685
+ if args.enable_xformers_memory_efficient_attention:
686
+ if is_xformers_available():
687
+ unet.enable_xformers_memory_efficient_attention()
688
+ else:
689
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
690
+
691
+ if args.gradient_checkpointing:
692
+ unet.enable_gradient_checkpointing()
693
+ if args.train_text_encoder:
694
+ text_encoder.gradient_checkpointing_enable()
695
+
696
+ if args.scale_lr:
697
+ args.learning_rate = (
698
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
699
+ )
700
+
701
+ # Enable TF32 for faster training on Ampere GPUs,
702
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
703
+ if args.allow_tf32:
704
+ torch.backends.cuda.matmul.allow_tf32 = True
705
+
711
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
712
+ if args.use_8bit_adam:
713
+ try:
714
+ import bitsandbytes as bnb
715
+ except ImportError:
716
+ raise ImportError(
717
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
718
+ )
719
+
720
+ optimizer_class = bnb.optim.AdamW8bit
721
+ else:
722
+ optimizer_class = torch.optim.AdamW
723
+
724
+ # Optimizer creation
725
+ params_to_optimize = (
726
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
727
+ )
728
+ optimizer = optimizer_class(
729
+ params_to_optimize,
730
+ lr=args.learning_rate,
731
+ betas=(args.adam_beta1, args.adam_beta2),
732
+ weight_decay=args.adam_weight_decay,
733
+ eps=args.adam_epsilon,
734
+ )
735
+
736
+ # Dataset and DataLoaders creation:
737
+ train_dataset = DreamBoothDataset(
738
+ instance_data_root=args.instance_data_dir,
739
+ instance_prompt=args.instance_prompt,
740
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
741
+ class_prompt=args.class_prompt,
742
+ tokenizer=tokenizer,
743
+ size=args.resolution,
744
+ center_crop=args.center_crop,
745
+ )
746
+
747
+ train_dataloader = torch.utils.data.DataLoader(
748
+ train_dataset,
749
+ batch_size=args.train_batch_size,
750
+ shuffle=True,
751
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
752
+ num_workers=args.dataloader_num_workers,
753
+ )
754
+
755
+ # Scheduler and math around the number of training steps.
756
+ overrode_max_train_steps = False
757
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
758
+ if args.max_train_steps is None:
759
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
760
+ overrode_max_train_steps = True
761
+
762
+ lr_scheduler = get_scheduler(
763
+ args.lr_scheduler,
764
+ optimizer=optimizer,
765
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
766
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
767
+ num_cycles=args.lr_num_cycles,
768
+ power=args.lr_power,
769
+ )
770
+
771
+ # Prepare everything with our `accelerator`.
772
+ if args.train_text_encoder:
773
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
774
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
775
+ )
776
+ else:
777
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
778
+ unet, optimizer, train_dataloader, lr_scheduler
779
+ )
780
+
781
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
782
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
783
+ if overrode_max_train_steps:
784
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
785
+ # Afterwards we recalculate our number of training epochs
786
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
787
+
788
+ # We need to initialize the trackers we use, and also store our configuration.
789
+ # The trackers initializes automatically on the main process.
790
+ if accelerator.is_main_process:
791
+ accelerator.init_trackers("dreambooth", config=vars(args))
792
+
793
+ # Train!
794
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
795
+
796
+ logger.info("***** Running training *****")
797
+ logger.info(f" Num examples = {len(train_dataset)}")
798
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
799
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
800
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
801
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
802
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
803
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
804
+ global_step = 0
805
+ first_epoch = 0
806
+
807
+ # Potentially load in the weights and states from a previous save
808
+ if args.resume_from_checkpoint:
809
+ if args.resume_from_checkpoint != "latest":
810
+ path = os.path.basename(args.resume_from_checkpoint)
811
+ else:
812
+ # Get the most recent checkpoint
813
+ dirs = os.listdir(args.output_dir)
814
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
815
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
816
+ path = dirs[-1] if len(dirs) > 0 else None
817
+
818
+ if path is None:
819
+ accelerator.print(
820
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
821
+ )
822
+ args.resume_from_checkpoint = None
823
+ else:
824
+ accelerator.print(f"Resuming from checkpoint {path}")
825
+ accelerator.load_state(os.path.join(args.output_dir, path))
826
+ global_step = int(path.split("-")[1])
827
+
828
+ resume_global_step = global_step * args.gradient_accumulation_steps
829
+ first_epoch = global_step // num_update_steps_per_epoch
830
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
831
+
832
+ # Only show the progress bar once on each machine.
833
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
834
+ progress_bar.set_description("Steps")
835
+
836
+ for epoch in range(first_epoch, args.num_train_epochs):
837
+ unet.train()
838
+ if args.train_text_encoder:
839
+ text_encoder.train()
840
+ for step, batch in enumerate(train_dataloader):
841
+ # Skip steps until we reach the resumed step
842
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
843
+ if step % args.gradient_accumulation_steps == 0:
844
+ progress_bar.update(1)
845
+ continue
846
+
847
+ with accelerator.accumulate(unet):
848
+ # Convert images to latent space
849
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
850
+ latents = latents * 0.18215
851
+
852
+ # Sample noise that we'll add to the latents
853
+ noise = torch.randn_like(latents)
854
+ bsz = latents.shape[0]
855
+ # Sample a random timestep for each image
856
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
857
+ timesteps = timesteps.long()
858
+
859
+ # Add noise to the latents according to the noise magnitude at each timestep
860
+ # (this is the forward diffusion process)
861
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
862
+
863
+ # Get the text embedding for conditioning
864
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
865
+
866
+ # Predict the noise residual
867
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
868
+
869
+ # Get the target for loss depending on the prediction type
870
+ if noise_scheduler.config.prediction_type == "epsilon":
871
+ target = noise
872
+ elif noise_scheduler.config.prediction_type == "v_prediction":
873
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
874
+ else:
875
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
876
+
877
+ if args.with_prior_preservation:
878
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
879
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
880
+ target, target_prior = torch.chunk(target, 2, dim=0)
881
+
882
+ # Compute instance loss
883
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
884
+
885
+ # Compute prior loss
886
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
887
+
888
+ # Add the prior loss to the instance loss.
889
+ loss = loss + args.prior_loss_weight * prior_loss
890
+ else:
891
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
892
+
893
+ accelerator.backward(loss)
894
+ if accelerator.sync_gradients:
895
+ params_to_clip = (
896
+ itertools.chain(unet.parameters(), text_encoder.parameters())
897
+ if args.train_text_encoder
898
+ else unet.parameters()
899
+ )
900
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
901
+ optimizer.step()
902
+ lr_scheduler.step()
903
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
904
+
905
+ # Checks if the accelerator has performed an optimization step behind the scenes
906
+ if accelerator.sync_gradients:
907
+ progress_bar.update(1)
908
+ global_step += 1
909
+
910
+ if global_step % args.checkpointing_steps == 0:
911
+ if accelerator.is_main_process:
912
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
913
+ accelerator.save_state(save_path)
914
+ logger.info(f"Saved state to {save_path}")
915
+
916
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
917
+ progress_bar.set_postfix(**logs)
918
+ accelerator.log(logs, step=global_step)
919
+
920
+ if global_step >= args.max_train_steps:
921
+ break
922
+
923
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
924
+ logger.info(
925
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
926
+ f" {args.validation_prompt}."
927
+ )
928
+ # create pipeline
929
+ pipeline = DiffusionPipeline.from_pretrained(
930
+ args.pretrained_model_name_or_path,
931
+ unet=accelerator.unwrap_model(unet),
932
+ text_encoder=accelerator.unwrap_model(text_encoder),
933
+ revision=args.revision,
934
+ torch_dtype=weight_dtype,
935
+ )
936
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
937
+ pipeline = pipeline.to(accelerator.device)
938
+ pipeline.set_progress_bar_config(disable=True)
939
+
940
+ # run inference
941
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
942
+ prompt = args.num_validation_images * [args.validation_prompt]
943
+ images = pipeline(prompt, num_inference_steps=25, generator=generator).images
944
+
945
+ for tracker in accelerator.trackers:
946
+ if tracker.name == "tensorboard":
947
+ np_images = np.stack([np.asarray(img) for img in images])
948
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
949
+ if tracker.name == "wandb":
950
+ tracker.log(
951
+ {
952
+ "validation": [
953
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
954
+ for i, image in enumerate(images)
955
+ ]
956
+ }
957
+ )
958
+
959
+ del pipeline
960
+ torch.cuda.empty_cache()
961
+
962
+ # Save the dreambooth model
963
+ accelerator.wait_for_everyone()
964
+ if accelerator.is_main_process:
965
+ # Final inference
966
+ # Load previous pipeline
967
+ pipeline = DiffusionPipeline.from_pretrained(
968
+ args.pretrained_model_name_or_path,
969
+ unet=accelerator.unwrap_model(unet),
970
+ text_encoder=accelerator.unwrap_model(text_encoder),
971
+ revision=args.revision,
972
+ torch_dtype=weight_dtype,
973
+ )
974
+ pipeline.save_pretrained(args.output_dir)
975
+
976
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
977
+ pipeline = pipeline.to(accelerator.device)
978
+
979
+ # run inference
980
+ if args.validation_prompt and args.num_validation_images > 0:
981
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
982
+ prompt = args.num_validation_images * [args.validation_prompt]
983
+ images = pipeline(prompt, num_inference_steps=25, generator=generator).images
984
+
985
+ test_image_dir = Path(args.output_dir) / 'test_images'
986
+ test_image_dir.mkdir()
987
+ for i, image in enumerate(images):
988
+ out_path = test_image_dir / f'image_{i}.png'
989
+ image.save(out_path)
990
+
991
+ for tracker in accelerator.trackers:
992
+ if tracker.name == "tensorboard":
993
+ np_images = np.stack([np.asarray(img) for img in images])
994
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
995
+ if tracker.name == "wandb":
996
+ tracker.log(
997
+ {
998
+ "test": [
999
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1000
+ for i, image in enumerate(images)
1001
+ ]
1002
+ }
1003
+ )
1004
+
1005
+ if args.push_to_hub:
1006
+ save_model_card(
1007
+ repo_name,
1008
+ images=images,
1009
+ base_model=args.pretrained_model_name_or_path,
1010
+ prompt=args.instance_prompt,
1011
+ repo_folder=args.output_dir,
1012
+ )
1013
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1014
+
1015
+ accelerator.end_training()
1016
+
1017
+
1018
+ if __name__ == "__main__":
1019
+ args = parse_args()
1020
+ main(args)
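Once `main()` finishes, the full pipeline (UNet, text encoder, VAE, tokenizer, scheduler) has been written to `args.output_dir` via `pipeline.save_pretrained`. The following is a minimal sketch, not part of the committed files, of reloading that output for inference; the directory name and prompt are placeholders:

```python
# Sketch only: reload a finished DreamBooth run saved by train_dreambooth.py.
# 'experiments/my-model' and the prompt are placeholders, not values from this Space.
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained('experiments/my-model', torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to('cuda')
image = pipe('A photo of sks dog in a bucket', num_inference_steps=25).images[0]
image.save('sample.png')
```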
trainer.py ADDED
@@ -0,0 +1,173 @@
+ from __future__ import annotations
+
+ import datetime
+ import os
+ import pathlib
+ import shlex
+ import shutil
+ import subprocess
+
+ import gradio as gr
+ import PIL.Image
+ import slugify
+ import torch
+ from huggingface_hub import HfApi
+
+ from app_upload import ModelUploader
+ from utils import save_model_card
+
+ URL_TO_JOIN_LIBRARY_ORG = 'https://huggingface.co/organizations/dreambooth-library/share/RIODZCCDCvwZLCSxdsNkRYXSfeuTGBgKqp'
+
+
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
+ w, h = image.size
+ if w == h:
+ return image
+ elif w > h:
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
+ new_image.paste(image, (0, (w - h) // 2))
+ return new_image
+ else:
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
+ new_image.paste(image, ((h - w) // 2, 0))
+ return new_image
+
+
+ class Trainer:
+ def __init__(self, hf_token: str | None = None):
+ self.hf_token = hf_token
+ self.api = HfApi(token=hf_token)
+ self.model_uploader = ModelUploader(hf_token)
+
+ def prepare_dataset(self, instance_images: list, resolution: int,
+ instance_data_dir: pathlib.Path) -> None:
+ shutil.rmtree(instance_data_dir, ignore_errors=True)
+ instance_data_dir.mkdir(parents=True)
+ for i, temp_path in enumerate(instance_images):
+ image = PIL.Image.open(temp_path.name)
+ image = pad_image(image)
+ image = image.resize((resolution, resolution))
+ image = image.convert('RGB')
+ out_path = instance_data_dir / f'{i:03d}.jpg'
+ image.save(out_path, format='JPEG', quality=100)
+
+ def join_library_org(self) -> None:
+ subprocess.run(
+ shlex.split(
+ f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_LIBRARY_ORG}'
+ ))
+
+ def run(
+ self,
+ instance_images: list | None,
+ training_prompt: str,
+ output_model_name: str,
+ overwrite_existing_model: bool,
+ validation_prompt: str,
+ base_model: str,
+ resolution_s: str,
+ n_steps: int,
+ learning_rate: float,
+ gradient_accumulation: int,
+ seed: int,
+ fp16: bool,
+ use_8bit_adam: bool,
+ checkpointing_steps: int,
+ use_wandb: bool,
+ validation_epochs: int,
+ upload_to_hub: bool,
+ use_private_repo: bool,
+ delete_existing_repo: bool,
+ upload_to: str,
+ remove_gpu_after_training: bool,
+ ) -> str:
+ if not torch.cuda.is_available():
+ raise gr.Error('CUDA is not available.')
+ if instance_images is None:
+ raise gr.Error('You need to upload images.')
+ if not training_prompt:
+ raise gr.Error('The instance prompt is missing.')
+ if not validation_prompt:
+ raise gr.Error('The validation prompt is missing.')
+
+ resolution = int(resolution_s)
+
+ if not output_model_name:
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
+ output_model_name = f'dreambooth-{timestamp}'
+ output_model_name = slugify.slugify(output_model_name)
+
+ repo_dir = pathlib.Path(__file__).parent
+ output_dir = repo_dir / 'experiments' / output_model_name
+ if overwrite_existing_model or upload_to_hub:
+ shutil.rmtree(output_dir, ignore_errors=True)
+ output_dir.mkdir(parents=True)
+
+ instance_data_dir = repo_dir / 'training_data' / output_model_name
+ class_data_dir = repo_dir / 'class_data' / output_model_name
+ self.prepare_dataset(instance_images, resolution, instance_data_dir)
+
+ if upload_to_hub:
+ self.join_library_org()
+
+ command = f'''
+ python train_dreambooth.py \
+ --pretrained_model_name_or_path={base_model} \
+ --train_text_encoder \
+ --instance_data_dir={instance_data_dir} \
+ --class_data_dir={class_data_dir} \
+ --output_dir={output_dir} \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="{training_prompt.format("sks ")}" \
+ --class_prompt="{training_prompt.format("")}" \
+ --resolution={resolution} \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps={gradient_accumulation} --gradient_checkpointing \
+ --learning_rate={learning_rate} \
+ --lr_scheduler=constant \
+ --lr_warmup_steps=0 \
+ --set_grads_to_none \
+ --num_class_images=200 \
+ --max_train_steps={n_steps} \
+ --checkpointing_steps={checkpointing_steps} \
+ --validation_prompt="{validation_prompt.format("sks ")}" \
+ --validation_epochs={validation_epochs} \
+ --seed={seed}
+ '''
+ if fp16:
+ command += ' --mixed_precision fp16'
+ if use_8bit_adam:
+ command += ' --use_8bit_adam'
+ if use_wandb:
+ command += ' --report_to wandb'
+
+ with open(output_dir / 'train.sh', 'w') as f:
+ command_s = ' '.join(command.split())
+ f.write(command_s)
+ subprocess.run(shlex.split(command))
+ save_model_card(save_dir=output_dir,
+ base_model=base_model,
+ instance_prompt=training_prompt.format("sks "),
+ test_prompt=validation_prompt.format("sks "),
+ test_image_dir='test_images')
+
+ message = 'Training completed!'
+ print(message)
+
+ if upload_to_hub:
+ upload_message = self.model_uploader.upload_model(
+ folder_path=output_dir.as_posix(),
+ repo_name=output_model_name,
+ upload_to=upload_to,
+ private=use_private_repo,
+ delete_existing_repo=delete_existing_repo)
+ print(upload_message)
+ message = message + '\n' + upload_message
+
+ if remove_gpu_after_training:
+ space_id = os.getenv('SPACE_ID')
+ if space_id:
+ self.api.request_space_hardware(repo_id=space_id,
+ hardware='cpu-basic')
+
+ return message
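For reference, a small sketch (not part of the repo) of what `pad_image` does before `prepare_dataset` resizes each training image: a non-square image is letterboxed onto a black square canvas so the subsequent resize to `resolution`×`resolution` does not distort the subject. The input size below is hypothetical:

```python
# Hypothetical input: a 640x480 landscape photo.
import PIL.Image
from trainer import pad_image

img = PIL.Image.new('RGB', (640, 480), (255, 255, 255))
square = pad_image(img)            # pasted onto a 640x640 black canvas, centered vertically
print(square.size)                 # (640, 640)
thumb = square.resize((512, 512))  # what prepare_dataset does next before saving as JPEG
```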
uploader.py ADDED
@@ -0,0 +1,42 @@
+ from __future__ import annotations
+
+ from huggingface_hub import HfApi
+
+
+ class Uploader:
+ def __init__(self, hf_token: str | None):
+ self.api = HfApi(token=hf_token)
+
+ def get_username(self) -> str:
+ return self.api.whoami()['name']
+
+ def upload(self,
+ folder_path: str,
+ repo_name: str,
+ organization: str = '',
+ repo_type: str = 'model',
+ private: bool = True,
+ delete_existing_repo: bool = False) -> str:
+ if not folder_path:
+ raise ValueError
+ if not repo_name:
+ raise ValueError
+ if not organization:
+ organization = self.get_username()
+ repo_id = f'{organization}/{repo_name}'
+ if delete_existing_repo:
+ try:
+ self.api.delete_repo(repo_id, repo_type=repo_type)
+ except Exception:
+ pass
+ try:
+ self.api.create_repo(repo_id, repo_type=repo_type, private=private)
+ self.api.upload_folder(repo_id=repo_id,
+ folder_path=folder_path,
+ path_in_repo='.',
+ repo_type=repo_type)
+ url = f'https://huggingface.co/{repo_id}'
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
+ except Exception as e:
+ message = str(e)
+ return message
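A minimal usage sketch for `Uploader.upload` (not part of the committed files; the token, folder, and repo name are placeholders):

```python
from uploader import Uploader

uploader = Uploader(hf_token='hf_xxx')  # a write-scoped Hugging Face token (placeholder)
message = uploader.upload(folder_path='experiments/my-model',
                          repo_name='my-dreambooth-model',
                          private=True)
print(message)  # an HTML link to the new repo on success, or the error message on failure
```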
utils.py ADDED
@@ -0,0 +1,59 @@
+ from __future__ import annotations
+
+ import pathlib
+
+
+ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
+ repo_dir = pathlib.Path(__file__).parent
+ exp_root_dir = repo_dir / 'experiments'
+ if not exp_root_dir.exists():
+ return []
+ exp_dirs = sorted(exp_root_dir.glob('*'))
+ exp_dirs = [
+ exp_dir for exp_dir in exp_dirs
+ if (exp_dir / 'model_index.json').exists()
+ ]
+ if ignore_repo:
+ exp_dirs = [
+ exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
+ ]
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
+
+
+ def save_model_card(
+ save_dir: pathlib.Path,
+ base_model: str,
+ instance_prompt: str,
+ test_prompt: str = '',
+ test_image_dir: str = '',
+ ) -> None:
+ image_str = ''
+ if test_prompt and test_image_dir:
+ image_paths = sorted((save_dir / test_image_dir).glob('*'))
+ if image_paths:
+ image_str = f'Test prompt: {test_prompt}\n'
+ for image_path in image_paths:
+ rel_path = image_path.relative_to(save_dir)
+ image_str += f'![{image_path.stem}]({rel_path})\n'
+
+ model_card = f'''---
+ license: creativeml-openrail-m
+ base_model: {base_model}
+ instance_prompt: {instance_prompt}
+ tags:
+ - stable-diffusion
+ - stable-diffusion-diffusers
+ - text-to-image
+ - diffusers
+ - dreambooth
+ inference: true
+ ---
+ # DreamBooth - {save_dir.name}
+
+ These are DreamBooth weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained on the instance prompt "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images below.
+
+ {image_str}
+ '''
+
+ with open(save_dir / 'README.md', 'w') as f:
+ f.write(model_card)
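To tie the helpers together, a hedged sketch (not part of the repo) of how `save_model_card` and `find_exp_dirs` might be called after a run; the experiment directory, base model, and prompts are placeholders, and the card picks up any images already written to `test_images/` by `train_dreambooth.py`:

```python
import pathlib
from utils import find_exp_dirs, save_model_card

exp_dir = pathlib.Path('experiments/my-model')  # placeholder experiment directory
save_model_card(save_dir=exp_dir,
                base_model='runwayml/stable-diffusion-v1-5',
                instance_prompt='A photo of sks dog',
                test_prompt='A photo of sks dog in a bucket',
                test_image_dir='test_images')
print(find_exp_dirs())  # experiment dirs containing model_index.json, relative to the repo
```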