hysts HF staff commited on
Commit
8b7a3d1
0 Parent(s):

Duplicate from lora-library/LoRA-DreamBooth-Training-UI

Browse files
Files changed (18) hide show
  1. .gitattributes +34 -0
  2. .gitignore +165 -0
  3. .pre-commit-config.yaml +37 -0
  4. .style.yapf +5 -0
  5. LICENSE +21 -0
  6. README.md +15 -0
  7. app.py +76 -0
  8. app_inference.py +176 -0
  9. app_training.py +144 -0
  10. app_upload.py +100 -0
  11. constants.py +6 -0
  12. inference.py +94 -0
  13. requirements.txt +14 -0
  14. style.css +3 -0
  15. train_dreambooth_lora.py +1026 -0
  16. trainer.py +166 -0
  17. uploader.py +42 -0
  18. utils.py +59 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training_data/
2
+ experiments/
3
+ wandb/
4
+
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: train_dreambooth_lora.py
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: LoRA DreamBooth Training UI
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.16.2
8
+ python_version: 3.10.9
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ duplicated_from: lora-library/LoRA-DreamBooth-Training-UI
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+ import torch
9
+
10
+ from app_inference import create_inference_demo
11
+ from app_training import create_training_demo
12
+ from app_upload import create_upload_demo
13
+ from inference import InferencePipeline
14
+ from trainer import Trainer
15
+
16
+ TITLE = '# LoRA DreamBooth Training UI'
17
+
18
+ ORIGINAL_SPACE_ID = 'lora-library/LoRA-DreamBooth-Training-UI'
19
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
20
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
21
+
22
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
23
+ '''
24
+
25
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
26
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
27
+ else:
28
+ SETTINGS = 'Settings'
29
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
30
+ <center>
31
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
32
+ "T4 small" is sufficient to run this demo.
33
+ </center>
34
+ '''
35
+
36
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please specify your Hugging Face token with write permission as the value of it.
37
+ <center>
38
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
39
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
40
+ </center>
41
+ '''
42
+
43
+ HF_TOKEN = os.getenv('HF_TOKEN')
44
+
45
+
46
+ def show_warning(warning_text: str) -> gr.Blocks:
47
+ with gr.Blocks() as demo:
48
+ with gr.Box():
49
+ gr.Markdown(warning_text)
50
+ return demo
51
+
52
+
53
+ pipe = InferencePipeline(HF_TOKEN)
54
+ trainer = Trainer(HF_TOKEN)
55
+
56
+ with gr.Blocks(css='style.css') as demo:
57
+ if os.getenv('IS_SHARED_UI'):
58
+ show_warning(SHARED_UI_WARNING)
59
+ if not torch.cuda.is_available():
60
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
61
+ if not HF_TOKEN:
62
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
63
+
64
+ gr.Markdown(TITLE)
65
+ with gr.Tabs():
66
+ with gr.TabItem('Train'):
67
+ create_training_demo(trainer, pipe)
68
+ with gr.TabItem('Test'):
69
+ create_inference_demo(pipe, HF_TOKEN)
70
+ with gr.TabItem('Upload'):
71
+ gr.Markdown('''
72
+ - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
73
+ ''')
74
+ create_upload_demo(HF_TOKEN)
75
+
76
+ demo.queue(max_size=1).launch(share=False)
app_inference.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from inference import InferencePipeline
11
+ from utils import find_exp_dirs
12
+
13
+ SAMPLE_MODEL_IDS = [
14
+ 'patrickvonplaten/lora_dreambooth_dog_example',
15
+ 'sayakpaul/sd-model-finetuned-lora-t4',
16
+ ]
17
+
18
+
19
+ class ModelSource(enum.Enum):
20
+ SAMPLE = 'Sample'
21
+ HUB_LIB = 'Hub (lora-library)'
22
+ LOCAL = 'Local'
23
+
24
+
25
+ class InferenceUtil:
26
+ def __init__(self, hf_token: str | None):
27
+ self.hf_token = hf_token
28
+
29
+ @staticmethod
30
+ def load_sample_lora_model_list():
31
+ return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])
32
+
33
+ def load_hub_lora_model_list(self) -> dict:
34
+ api = HfApi(token=self.hf_token)
35
+ choices = [
36
+ info.modelId for info in api.list_models(author='lora-library')
37
+ ]
38
+ return gr.update(choices=choices,
39
+ value=choices[0] if choices else None)
40
+
41
+ @staticmethod
42
+ def load_local_lora_model_list() -> dict:
43
+ choices = find_exp_dirs()
44
+ return gr.update(choices=choices,
45
+ value=choices[0] if choices else None)
46
+
47
+ def reload_lora_model_list(self, model_source: str) -> dict:
48
+ if model_source == ModelSource.SAMPLE.value:
49
+ return self.load_sample_lora_model_list()
50
+ elif model_source == ModelSource.HUB_LIB.value:
51
+ return self.load_hub_lora_model_list()
52
+ elif model_source == ModelSource.LOCAL.value:
53
+ return self.load_local_lora_model_list()
54
+ else:
55
+ raise ValueError
56
+
57
+ def load_model_info(self, lora_model_id: str) -> tuple[str, str]:
58
+ try:
59
+ card = InferencePipeline.get_model_card(lora_model_id,
60
+ self.hf_token)
61
+ except Exception:
62
+ return '', ''
63
+ base_model = getattr(card.data, 'base_model', '')
64
+ instance_prompt = getattr(card.data, 'instance_prompt', '')
65
+ return base_model, instance_prompt
66
+
67
+ def reload_lora_model_list_and_update_model_info(
68
+ self, model_source: str) -> tuple[dict, str, str]:
69
+ model_list_update = self.reload_lora_model_list(model_source)
70
+ model_list = model_list_update['choices']
71
+ model_info = self.load_model_info(model_list[0] if model_list else '')
72
+ return model_list_update, *model_info
73
+
74
+
75
+ def create_inference_demo(pipe: InferencePipeline,
76
+ hf_token: str | None = None) -> gr.Blocks:
77
+ app = InferenceUtil(hf_token)
78
+
79
+ with gr.Blocks() as demo:
80
+ with gr.Row():
81
+ with gr.Column():
82
+ with gr.Box():
83
+ model_source = gr.Radio(
84
+ label='Model Source',
85
+ choices=[_.value for _ in ModelSource],
86
+ value=ModelSource.SAMPLE.value)
87
+ reload_button = gr.Button('Reload Model List')
88
+ lora_model_id = gr.Dropdown(label='LoRA Model ID',
89
+ choices=SAMPLE_MODEL_IDS,
90
+ value=SAMPLE_MODEL_IDS[0])
91
+ with gr.Accordion(
92
+ label=
93
+ 'Model info (Base model and instance prompt used for training)',
94
+ open=False):
95
+ with gr.Row():
96
+ base_model_used_for_training = gr.Text(
97
+ label='Base model', interactive=False)
98
+ instance_prompt_used_for_training = gr.Text(
99
+ label='Instance prompt', interactive=False)
100
+ prompt = gr.Textbox(
101
+ label='Prompt',
102
+ max_lines=1,
103
+ placeholder='Example: "A picture of a sks dog in a bucket"'
104
+ )
105
+ alpha = gr.Slider(label='LoRA alpha',
106
+ minimum=0,
107
+ maximum=2,
108
+ step=0.05,
109
+ value=1)
110
+ seed = gr.Slider(label='Seed',
111
+ minimum=0,
112
+ maximum=100000,
113
+ step=1,
114
+ value=0)
115
+ with gr.Accordion('Other Parameters', open=False):
116
+ num_steps = gr.Slider(label='Number of Steps',
117
+ minimum=0,
118
+ maximum=100,
119
+ step=1,
120
+ value=25)
121
+ guidance_scale = gr.Slider(label='CFG Scale',
122
+ minimum=0,
123
+ maximum=50,
124
+ step=0.1,
125
+ value=7.5)
126
+
127
+ run_button = gr.Button('Generate')
128
+
129
+ gr.Markdown('''
130
+ - After training, you can press "Reload Model List" button to load your trained model names.
131
+ ''')
132
+ with gr.Column():
133
+ result = gr.Image(label='Result')
134
+
135
+ model_source.change(
136
+ fn=app.reload_lora_model_list_and_update_model_info,
137
+ inputs=model_source,
138
+ outputs=[
139
+ lora_model_id,
140
+ base_model_used_for_training,
141
+ instance_prompt_used_for_training,
142
+ ])
143
+ reload_button.click(
144
+ fn=app.reload_lora_model_list_and_update_model_info,
145
+ inputs=model_source,
146
+ outputs=[
147
+ lora_model_id,
148
+ base_model_used_for_training,
149
+ instance_prompt_used_for_training,
150
+ ])
151
+ lora_model_id.change(fn=app.load_model_info,
152
+ inputs=lora_model_id,
153
+ outputs=[
154
+ base_model_used_for_training,
155
+ instance_prompt_used_for_training,
156
+ ])
157
+ inputs = [
158
+ lora_model_id,
159
+ prompt,
160
+ alpha,
161
+ seed,
162
+ num_steps,
163
+ guidance_scale,
164
+ ]
165
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
166
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
167
+ return demo
168
+
169
+
170
+ if __name__ == '__main__':
171
+ import os
172
+
173
+ hf_token = os.getenv('HF_TOKEN')
174
+ pipe = InferencePipeline(hf_token)
175
+ demo = create_inference_demo(pipe, hf_token)
176
+ demo.queue(max_size=10).launch(share=False)
app_training.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ with gr.Box():
20
+ gr.Markdown('Training Data')
21
+ instance_images = gr.Files(label='Instance images')
22
+ instance_prompt = gr.Textbox(label='Instance prompt',
23
+ max_lines=1)
24
+ gr.Markdown('''
25
+ - Upload images of the style you are planning on training on.
26
+ - For an instance prompt, use a unique, made up word to avoid collisions.
27
+ ''')
28
+ with gr.Box():
29
+ gr.Markdown('Output Model')
30
+ output_model_name = gr.Text(label='Name of your model',
31
+ max_lines=1)
32
+ delete_existing_model = gr.Checkbox(
33
+ label='Delete existing model of the same name',
34
+ value=False)
35
+ validation_prompt = gr.Text(label='Validation Prompt')
36
+ with gr.Box():
37
+ gr.Markdown('Upload Settings')
38
+ with gr.Row():
39
+ upload_to_hub = gr.Checkbox(
40
+ label='Upload model to Hub', value=True)
41
+ use_private_repo = gr.Checkbox(label='Private',
42
+ value=True)
43
+ delete_existing_repo = gr.Checkbox(
44
+ label='Delete existing repo of the same name',
45
+ value=False)
46
+ upload_to = gr.Radio(
47
+ label='Upload to',
48
+ choices=[_.value for _ in UploadTarget],
49
+ value=UploadTarget.LORA_LIBRARY.value)
50
+ gr.Markdown('''
51
+ - By default, trained models will be uploaded to [LoRA Library](https://huggingface.co/lora-library) (see [this example model](https://huggingface.co/lora-library/lora-dreambooth-sample-dog)).
52
+ - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
53
+ ''')
54
+
55
+ with gr.Box():
56
+ gr.Markdown('Training Parameters')
57
+ with gr.Row():
58
+ base_model = gr.Text(
59
+ label='Base Model',
60
+ value='stabilityai/stable-diffusion-2-1-base',
61
+ max_lines=1)
62
+ resolution = gr.Dropdown(choices=['512', '768'],
63
+ value='512',
64
+ label='Resolution')
65
+ num_training_steps = gr.Number(
66
+ label='Number of Training Steps', value=1000, precision=0)
67
+ learning_rate = gr.Number(label='Learning Rate', value=0.0001)
68
+ gradient_accumulation = gr.Number(
69
+ label='Number of Gradient Accumulation',
70
+ value=1,
71
+ precision=0)
72
+ seed = gr.Slider(label='Seed',
73
+ minimum=0,
74
+ maximum=100000,
75
+ step=1,
76
+ value=0)
77
+ fp16 = gr.Checkbox(label='FP16', value=True)
78
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
79
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
80
+ value=100,
81
+ precision=0)
82
+ use_wandb = gr.Checkbox(label='Use W&B',
83
+ value=False,
84
+ interactive=bool(
85
+ os.getenv('WANDB_API_KEY')))
86
+ validation_epochs = gr.Number(label='Validation Epochs',
87
+ value=100,
88
+ precision=0)
89
+ gr.Markdown('''
90
+ - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
91
+ - It takes a few minutes to download the base model first.
92
+ - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
93
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
94
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
95
+ - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
96
+ - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
97
+ ''')
98
+
99
+ remove_gpu_after_training = gr.Checkbox(
100
+ label='Remove GPU after training',
101
+ value=False,
102
+ interactive=bool(os.getenv('SPACE_ID')),
103
+ visible=False)
104
+ run_button = gr.Button('Start Training')
105
+
106
+ with gr.Box():
107
+ gr.Markdown('Output message')
108
+ output_message = gr.Markdown()
109
+
110
+ if pipe is not None:
111
+ run_button.click(fn=pipe.clear)
112
+ run_button.click(fn=trainer.run,
113
+ inputs=[
114
+ instance_images,
115
+ instance_prompt,
116
+ output_model_name,
117
+ delete_existing_model,
118
+ validation_prompt,
119
+ base_model,
120
+ resolution,
121
+ num_training_steps,
122
+ learning_rate,
123
+ gradient_accumulation,
124
+ seed,
125
+ fp16,
126
+ use_8bit_adam,
127
+ checkpointing_steps,
128
+ use_wandb,
129
+ validation_epochs,
130
+ upload_to_hub,
131
+ use_private_repo,
132
+ delete_existing_repo,
133
+ upload_to,
134
+ remove_gpu_after_training,
135
+ ],
136
+ outputs=output_message)
137
+ return demo
138
+
139
+
140
+ if __name__ == '__main__':
141
+ hf_token = os.getenv('HF_TOKEN')
142
+ trainer = Trainer(hf_token)
143
+ demo = create_training_demo(trainer)
144
+ demo.queue(max_size=1).launch(share=False)
app_upload.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class LoRAModelUploader(Uploader):
16
+ def upload_lora_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ ) -> str:
24
+ if not folder_path:
25
+ raise ValueError
26
+ if not repo_name:
27
+ repo_name = pathlib.Path(folder_path).name
28
+ repo_name = slugify.slugify(repo_name)
29
+
30
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
+ organization = ''
32
+ elif upload_to == UploadTarget.LORA_LIBRARY.value:
33
+ organization = 'lora-library'
34
+ else:
35
+ raise ValueError
36
+
37
+ return self.upload(folder_path,
38
+ repo_name,
39
+ organization=organization,
40
+ private=private,
41
+ delete_existing_repo=delete_existing_repo)
42
+
43
+
44
+ def load_local_lora_model_list() -> dict:
45
+ choices = find_exp_dirs(ignore_repo=True)
46
+ return gr.update(choices=choices, value=choices[0] if choices else None)
47
+
48
+
49
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
+ uploader = LoRAModelUploader(hf_token)
51
+ model_dirs = find_exp_dirs(ignore_repo=True)
52
+
53
+ with gr.Blocks() as demo:
54
+ with gr.Box():
55
+ gr.Markdown('Local Models')
56
+ reload_button = gr.Button('Reload Model List')
57
+ model_dir = gr.Dropdown(
58
+ label='Model names',
59
+ choices=model_dirs,
60
+ value=model_dirs[0] if model_dirs else None)
61
+ with gr.Box():
62
+ gr.Markdown('Upload Settings')
63
+ with gr.Row():
64
+ use_private_repo = gr.Checkbox(label='Private', value=True)
65
+ delete_existing_repo = gr.Checkbox(
66
+ label='Delete existing repo of the same name', value=False)
67
+ upload_to = gr.Radio(label='Upload to',
68
+ choices=[_.value for _ in UploadTarget],
69
+ value=UploadTarget.LORA_LIBRARY.value)
70
+ model_name = gr.Textbox(label='Model Name')
71
+ upload_button = gr.Button('Upload')
72
+ gr.Markdown('''
73
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{your_username}/{model_name}) or to the public [LoRA Concepts Library](https://huggingface.co/lora-library) (i.e. https://huggingface.co/lora-library/{model_name}).
74
+ ''')
75
+ with gr.Box():
76
+ gr.Markdown('Output message')
77
+ output_message = gr.Markdown()
78
+
79
+ reload_button.click(fn=load_local_lora_model_list,
80
+ inputs=None,
81
+ outputs=model_dir)
82
+ upload_button.click(fn=uploader.upload_lora_model,
83
+ inputs=[
84
+ model_dir,
85
+ model_name,
86
+ upload_to,
87
+ use_private_repo,
88
+ delete_existing_repo,
89
+ ],
90
+ outputs=output_message)
91
+
92
+ return demo
93
+
94
+
95
+ if __name__ == '__main__':
96
+ import os
97
+
98
+ hf_token = os.getenv('HF_TOKEN')
99
+ demo = create_upload_demo(hf_token)
100
+ demo.queue(max_size=1).launch(share=False)
constants.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ LORA_LIBRARY = 'LoRA Library'
inference.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+
6
+ import gradio as gr
7
+ import PIL.Image
8
+ import torch
9
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
+ from huggingface_hub import ModelCard
11
+
12
+
13
+ class InferencePipeline:
14
+ def __init__(self, hf_token: str | None = None):
15
+ self.hf_token = hf_token
16
+ self.pipe = None
17
+ self.device = torch.device(
18
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
19
+ self.lora_model_id = None
20
+ self.base_model_id = None
21
+
22
+ def clear(self) -> None:
23
+ self.lora_model_id = None
24
+ self.base_model_id = None
25
+ del self.pipe
26
+ self.pipe = None
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+ @staticmethod
31
+ def check_if_model_is_local(lora_model_id: str) -> bool:
32
+ return pathlib.Path(lora_model_id).exists()
33
+
34
+ @staticmethod
35
+ def get_model_card(model_id: str,
36
+ hf_token: str | None = None) -> ModelCard:
37
+ if InferencePipeline.check_if_model_is_local(model_id):
38
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
39
+ else:
40
+ card_path = model_id
41
+ return ModelCard.load(card_path, token=hf_token)
42
+
43
+ @staticmethod
44
+ def get_base_model_info(lora_model_id: str,
45
+ hf_token: str | None = None) -> str:
46
+ card = InferencePipeline.get_model_card(lora_model_id, hf_token)
47
+ return card.data.base_model
48
+
49
+ def load_pipe(self, lora_model_id: str) -> None:
50
+ if lora_model_id == self.lora_model_id:
51
+ return
52
+ base_model_id = self.get_base_model_info(lora_model_id, self.hf_token)
53
+ if base_model_id != self.base_model_id:
54
+ if self.device.type == 'cpu':
55
+ pipe = DiffusionPipeline.from_pretrained(
56
+ base_model_id, use_auth_token=self.hf_token)
57
+ else:
58
+ pipe = DiffusionPipeline.from_pretrained(
59
+ base_model_id,
60
+ torch_dtype=torch.float16,
61
+ use_auth_token=self.hf_token)
62
+ pipe = pipe.to(self.device)
63
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
64
+ pipe.scheduler.config)
65
+ self.pipe = pipe
66
+ self.pipe.unet.load_attn_procs( # type: ignore
67
+ lora_model_id, use_auth_token=self.hf_token)
68
+
69
+ self.lora_model_id = lora_model_id # type: ignore
70
+ self.base_model_id = base_model_id # type: ignore
71
+
72
+ def run(
73
+ self,
74
+ lora_model_id: str,
75
+ prompt: str,
76
+ lora_scale: float,
77
+ seed: int,
78
+ n_steps: int,
79
+ guidance_scale: float,
80
+ ) -> PIL.Image.Image:
81
+ if not torch.cuda.is_available():
82
+ raise gr.Error('CUDA is not available.')
83
+
84
+ self.load_pipe(lora_model_id)
85
+
86
+ generator = torch.Generator(device=self.device).manual_seed(seed)
87
+ out = self.pipe(
88
+ prompt,
89
+ num_inference_steps=n_steps,
90
+ guidance_scale=guidance_scale,
91
+ generator=generator,
92
+ cross_attention_kwargs={'scale': lora_scale},
93
+ ) # type: ignore
94
+ return out.images[0]
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ bitsandbytes==0.36.0.post2
3
+ datasets==2.8.0
4
+ git+https://github.com/huggingface/diffusers@31be42209ddfdb69d9640a777b32e9b5c6259bf0#egg=diffusers
5
+ ftfy==6.1.1
6
+ gradio==3.16.2
7
+ huggingface-hub==0.12.0
8
+ Pillow==9.4.0
9
+ python-slugify==7.0.0
10
+ tensorboard==2.11.2
11
+ torch==1.13.1
12
+ torchvision==0.14.1
13
+ transformers==4.26.0
14
+ wandb==0.13.9
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
train_dreambooth_lora.py ADDED
@@ -0,0 +1,1026 @@