ameerazam08
commited on
Commit
•
6a6edcb
1
Parent(s):
5d25eca
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- .gitignore +12 -0
- LICENSE +21 -0
- README.md +4 -4
- WEIGHTS_LICENSE +44 -0
- __init__.py +0 -0
- app.py +35 -0
- configs/inference/controlnet_c_3b_canny.yaml +14 -0
- configs/inference/controlnet_c_3b_identity.yaml +17 -0
- configs/inference/controlnet_c_3b_inpainting.yaml +15 -0
- configs/inference/controlnet_c_3b_sr.yaml +15 -0
- configs/inference/lora_c_3b.yaml +15 -0
- configs/inference/stage_b_3b.yaml +13 -0
- configs/inference/stage_c_3b.yaml +7 -0
- configs/training/controlnet_c_3b_canny.yaml +45 -0
- configs/training/controlnet_c_3b_identity.yaml +48 -0
- configs/training/controlnet_c_3b_inpainting.yaml +46 -0
- configs/training/controlnet_c_3b_sr.yaml +46 -0
- configs/training/finetune_b_3b.yaml +36 -0
- configs/training/finetune_b_700m.yaml +36 -0
- configs/training/finetune_c_1b.yaml +35 -0
- configs/training/finetune_c_3b.yaml +35 -0
- configs/training/finetune_c_3b_lora.yaml +44 -0
- configs/training/finetune_c_3b_lowres.yaml +41 -0
- configs/training/finetune_c_3b_v.yaml +36 -0
- core/__init__.py +371 -0
- core/data/__init__.py +69 -0
- core/data/bucketeer.py +72 -0
- core/scripts/__init__.py +0 -0
- core/scripts/cli.py +41 -0
- core/templates/__init__.py +1 -0
- core/templates/diffusion.py +236 -0
- core/utils/__init__.py +9 -0
- core/utils/base_dto.py +56 -0
- core/utils/save_and_load.py +59 -0
- figures/collage_1.jpg +3 -0
- figures/collage_2.jpg +0 -0
- figures/collage_3.jpg +3 -0
- figures/collage_4.jpg +0 -0
- figures/comparison-inference-speed.jpg +0 -0
- figures/comparison.png +0 -0
- figures/controlnet-canny.jpg +0 -0
- figures/controlnet-face.jpg +0 -0
- figures/controlnet-paint.jpg +0 -0
- figures/controlnet-sr.jpg +3 -0
- figures/fernando.jpg +0 -0
- figures/fernando_original.jpg +0 -0
- figures/image-to-image-example-rodent.jpg +0 -0
- figures/image-variations-example-headset.jpg +0 -0
- figures/model-overview.jpg +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
figures/collage_1.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
figures/collage_3.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
figures/controlnet-sr.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
inference/controlnet.ipynb filter=lfs diff=lfs merge=lfs -text
|
40 |
+
inference/reconstruct_images.ipynb filter=lfs diff=lfs merge=lfs -text
|
41 |
+
inference/text_to_image.ipynb filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.yml
|
2 |
+
*.out
|
3 |
+
dist_file_*
|
4 |
+
__pycache__/*
|
5 |
+
*/__pycache__/*
|
6 |
+
*/**/__pycache__/*
|
7 |
+
*_latest_output.jpg
|
8 |
+
*_sample.jpg
|
9 |
+
jobs/*.sh
|
10 |
+
.ipynb_checkpoints
|
11 |
+
*.safetensors
|
12 |
+
*_test.yaml
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Stability AI
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
-
title: Stable Cascade
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.19.1
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
+
title: Stable Cascade Upscale
|
3 |
+
emoji: 🏃
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.19.1
|
8 |
app_file: app.py
|
WEIGHTS_LICENSE
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## THIS LICENSE IS FOR THE MODEL WEIGHTS ONLY
|
2 |
+
|
3 |
+
STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT
|
4 |
+
Dated: November 28, 2023
|
5 |
+
|
6 |
+
By using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to be bound by this Agreement.
|
7 |
+
|
8 |
+
"Agreement" means this Stable Non-Commercial Research Community License Agreement.
|
9 |
+
|
10 |
+
“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
|
11 |
+
|
12 |
+
"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
|
13 |
+
|
14 |
+
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
|
15 |
+
|
16 |
+
"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
17 |
+
|
18 |
+
“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
|
19 |
+
|
20 |
+
“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.
|
21 |
+
|
22 |
+
"Stability AI" or "we" means Stability AI Ltd. and its affiliates.
|
23 |
+
|
24 |
+
|
25 |
+
"Software" means Stability AI’s proprietary software made available under this Agreement.
|
26 |
+
|
27 |
+
“Software Products” means the Models, Software and Documentation, individually or in any combination.
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
1. License Rights and Redistribution.
|
32 |
+
a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only.
|
33 |
+
b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.
|
34 |
+
c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
|
35 |
+
2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
|
36 |
+
3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
37 |
+
4. Intellectual Property.
|
38 |
+
a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.
|
39 |
+
b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works
|
40 |
+
c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.
|
41 |
+
5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.
|
42 |
+
|
43 |
+
6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law
|
44 |
+
principles.
|
__init__.py
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
from main import Upscale_CaseCade
|
4 |
+
import spaces
|
5 |
+
|
6 |
+
upscale_class=Upscale_CaseCade()
|
7 |
+
# scale_fator=7
|
8 |
+
# url = "https://cdn.discordapp.com/attachments/1121232062708457508/1205110687538479145/A_photograph_of_a_sunflower_with_sunglasses_on_in__3.jpg?ex=65d72dc9&is=65c4b8c9&hm=72172e774ce6cda618503b3778b844de05cd1208b61e185d8418db512fb2858a&"
|
9 |
+
# image_pil=Image.open("/home/rnd/Documents/Ameer/StableCascade/poster.png").convert("RGB")
|
10 |
+
@spaces.GPU
|
11 |
+
def scale_image(image_pil,scale_factor):
|
12 |
+
og,ups=upscale_class.upscale_image(image_pil=image_pil.convert("RGB"),scale_fator=scale_factor)
|
13 |
+
return [ups]
|
14 |
+
DESCRIPTION = "# Stable Cascade -> Super Resolution"
|
15 |
+
DESCRIPTION += "\n<p style=\"text-align: center\">Unofficial demo for Cascade-Super Resolution <a href='https://huggingface.co/stabilityai/stable-cascade' target='_blank'>Stable Upscale Cascade</a>, a new high resolution image-to-image model by Stability AI, - <a href='https://huggingface.co/stabilityai/stable-cascade/blob/main/LICENSE' target='_blank'>non-commercial research license</a></p>"
|
16 |
+
# block = gr.Blocks(css="footer {visibility: hidden}", theme='freddyaboulton/dracula_revamped').queue()
|
17 |
+
block = gr.Blocks(css="footer {visibility: hidden}", theme='freddyaboulton/dark').queue()
|
18 |
+
|
19 |
+
with block:
|
20 |
+
with gr.Row():
|
21 |
+
gr.Markdown(DESCRIPTION)
|
22 |
+
with gr.Tabs():
|
23 |
+
with gr.Row():
|
24 |
+
with gr.Column():
|
25 |
+
image_pil = gr.Image(label="Describe the Image", type='pil')
|
26 |
+
scale_factor = gr.Slider(minimum=1,maximum=10,value=1, step=1, label="Scale Factor")
|
27 |
+
generate_button = gr.Button("Upscale Image")
|
28 |
+
with gr.Column():
|
29 |
+
generated_image = gr.Gallery(label="Generated Image",)
|
30 |
+
|
31 |
+
generate_button.click(fn=scale_image, inputs=[image_pil,scale_factor], outputs=[generated_image])
|
32 |
+
|
33 |
+
block.launch(show_api=False, server_port=8888, share=False, show_error=True, max_threads=1)
|
34 |
+
|
35 |
+
# pip install gradio==4.16.0 gradio_client==0.8.1
|
configs/inference/controlnet_c_3b_canny.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# ControlNet specific
|
6 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
7 |
+
controlnet_filter: CannyFilter
|
8 |
+
controlnet_filter_params:
|
9 |
+
resize: 224
|
10 |
+
|
11 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
12 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
13 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
14 |
+
controlnet_checkpoint_path: models/canny.safetensors
|
configs/inference/controlnet_c_3b_identity.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# ControlNet specific
|
6 |
+
controlnet_bottleneck_mode: 'simple'
|
7 |
+
controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
|
8 |
+
controlnet_filter: IdentityFilter
|
9 |
+
controlnet_filter_params:
|
10 |
+
max_faces: 4
|
11 |
+
p_drop: 0.00
|
12 |
+
p_full: 0.0
|
13 |
+
|
14 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
15 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
16 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
17 |
+
controlnet_checkpoint_path:
|
configs/inference/controlnet_c_3b_inpainting.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# ControlNet specific
|
6 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
7 |
+
controlnet_filter: InpaintFilter
|
8 |
+
controlnet_filter_params:
|
9 |
+
thresold: [0.04, 0.4]
|
10 |
+
p_outpaint: 0.4
|
11 |
+
|
12 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
13 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
14 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
15 |
+
controlnet_checkpoint_path: models/inpainting.safetensors
|
configs/inference/controlnet_c_3b_sr.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# ControlNet specific
|
6 |
+
controlnet_bottleneck_mode: 'large'
|
7 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
8 |
+
controlnet_filter: SREffnetFilter
|
9 |
+
controlnet_filter_params:
|
10 |
+
scale_factor: 0.5
|
11 |
+
|
12 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
13 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
14 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
15 |
+
controlnet_checkpoint_path: models/super_resolution.safetensors
|
configs/inference/lora_c_3b.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# LoRA specific
|
6 |
+
module_filters: ['.attn']
|
7 |
+
rank: 4
|
8 |
+
train_tokens:
|
9 |
+
# - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
|
10 |
+
- ['[fernando]', '^dog</w>'] # custom token [snail], initialize as avg of snail & snails
|
11 |
+
|
12 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
13 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
14 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
15 |
+
lora_checkpoint_path: models/lora_fernando_10k.safetensors
|
configs/inference/stage_b_3b.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
# For demonstration purposes in reconstruct_images.ipynb
|
6 |
+
webdataset_path: file:inference/imagenet_1024.tar
|
7 |
+
batch_size: 4
|
8 |
+
image_size: 1024
|
9 |
+
grad_accum_steps: 1
|
10 |
+
|
11 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
12 |
+
stage_a_checkpoint_path: models/stage_a.safetensors
|
13 |
+
generator_checkpoint_path: models/stage_b_bf16.safetensors
|
configs/inference/stage_c_3b.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
model_version: 3.6B
|
3 |
+
dtype: bfloat16
|
4 |
+
|
5 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
6 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
7 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/controlnet_c_3b_canny.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_3b_controlnet_canny
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3.6B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 256
|
14 |
+
image_size: 768
|
15 |
+
# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 1
|
17 |
+
updates: 10000
|
18 |
+
backup_every: 2000
|
19 |
+
save_every: 1000
|
20 |
+
warmup_updates: 1
|
21 |
+
use_fsdp: True
|
22 |
+
|
23 |
+
# ControlNet specific
|
24 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
25 |
+
controlnet_filter: CannyFilter
|
26 |
+
controlnet_filter_params:
|
27 |
+
resize: 224
|
28 |
+
# offset_noise: 0.1
|
29 |
+
|
30 |
+
# CUSTOM CAPTIONS GETTER & FILTERS
|
31 |
+
captions_getter: ['txt', identity]
|
32 |
+
dataset_filters:
|
33 |
+
- ['width', 'lambda w: w >= 768']
|
34 |
+
- ['height', 'lambda h: h >= 768']
|
35 |
+
|
36 |
+
# ema_start_iters: 5000
|
37 |
+
# ema_iters: 100
|
38 |
+
# ema_beta: 0.9
|
39 |
+
|
40 |
+
webdataset_path:
|
41 |
+
- s3://path/to/your/first/dataset/on/s3
|
42 |
+
- s3://path/to/your/second/dataset/on/s3
|
43 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
44 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
45 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/controlnet_c_3b_identity.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_3b_controlnet_identity
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3.6B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 256
|
14 |
+
image_size: 768
|
15 |
+
# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 1
|
17 |
+
updates: 200000
|
18 |
+
backup_every: 2000
|
19 |
+
save_every: 1000
|
20 |
+
warmup_updates: 1
|
21 |
+
use_fsdp: True
|
22 |
+
|
23 |
+
# ControlNet specific
|
24 |
+
controlnet_bottleneck_mode: 'simple'
|
25 |
+
controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
|
26 |
+
controlnet_filter: IdentityFilter
|
27 |
+
controlnet_filter_params:
|
28 |
+
max_faces: 4
|
29 |
+
p_drop: 0.05
|
30 |
+
p_full: 0.3
|
31 |
+
# offset_noise: 0.1
|
32 |
+
|
33 |
+
# CUSTOM CAPTIONS GETTER & FILTERS
|
34 |
+
captions_getter: ['txt', identity]
|
35 |
+
dataset_filters:
|
36 |
+
- ['width', 'lambda w: w >= 768']
|
37 |
+
- ['height', 'lambda h: h >= 768']
|
38 |
+
|
39 |
+
# ema_start_iters: 5000
|
40 |
+
# ema_iters: 100
|
41 |
+
# ema_beta: 0.9
|
42 |
+
|
43 |
+
webdataset_path:
|
44 |
+
- s3://path/to/your/first/dataset/on/s3
|
45 |
+
- s3://path/to/your/second/dataset/on/s3
|
46 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
47 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
48 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/controlnet_c_3b_inpainting.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_3b_controlnet_inpainting
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3.6B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 256
|
14 |
+
image_size: 768
|
15 |
+
# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 1
|
17 |
+
updates: 10000
|
18 |
+
backup_every: 2000
|
19 |
+
save_every: 1000
|
20 |
+
warmup_updates: 1
|
21 |
+
use_fsdp: True
|
22 |
+
|
23 |
+
# ControlNet specific
|
24 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
25 |
+
controlnet_filter: InpaintFilter
|
26 |
+
controlnet_filter_params:
|
27 |
+
thresold: [0.04, 0.4]
|
28 |
+
p_outpaint: 0.4
|
29 |
+
offset_noise: 0.1
|
30 |
+
|
31 |
+
# CUSTOM CAPTIONS GETTER & FILTERS
|
32 |
+
captions_getter: ['txt', identity]
|
33 |
+
dataset_filters:
|
34 |
+
- ['width', 'lambda w: w >= 768']
|
35 |
+
- ['height', 'lambda h: h >= 768']
|
36 |
+
|
37 |
+
# ema_start_iters: 5000
|
38 |
+
# ema_iters: 100
|
39 |
+
# ema_beta: 0.9
|
40 |
+
|
41 |
+
webdataset_path:
|
42 |
+
- s3://path/to/your/first/dataset/on/s3
|
43 |
+
- s3://path/to/your/second/dataset/on/s3
|
44 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
45 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
46 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/controlnet_c_3b_sr.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_3b_controlnet_sr
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3.6B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 256
|
14 |
+
image_size: 768
|
15 |
+
# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 1
|
17 |
+
updates: 30000
|
18 |
+
backup_every: 5000
|
19 |
+
save_every: 1000
|
20 |
+
warmup_updates: 1
|
21 |
+
use_fsdp: True
|
22 |
+
|
23 |
+
# ControlNet specific
|
24 |
+
controlnet_bottleneck_mode: 'large'
|
25 |
+
controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
|
26 |
+
controlnet_filter: SREffnetFilter
|
27 |
+
controlnet_filter_params:
|
28 |
+
scale_factor: 0.5
|
29 |
+
offset_noise: 0.1
|
30 |
+
|
31 |
+
# CUSTOM CAPTIONS GETTER & FILTERS
|
32 |
+
captions_getter: ['txt', identity]
|
33 |
+
dataset_filters:
|
34 |
+
- ['width', 'lambda w: w >= 768']
|
35 |
+
- ['height', 'lambda h: h >= 768']
|
36 |
+
|
37 |
+
# ema_start_iters: 5000
|
38 |
+
# ema_iters: 100
|
39 |
+
# ema_beta: 0.9
|
40 |
+
|
41 |
+
webdataset_path:
|
42 |
+
- s3://path/to/your/first/dataset/on/s3
|
43 |
+
- s3://path/to/your/second/dataset/on/s3
|
44 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
45 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
46 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/finetune_b_3b.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_b_3b_finetuning
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 256
|
14 |
+
image_size: 1024
|
15 |
+
# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
shift: 4
|
17 |
+
grad_accum_steps: 1
|
18 |
+
updates: 100000
|
19 |
+
backup_every: 20000
|
20 |
+
save_every: 1000
|
21 |
+
warmup_updates: 1
|
22 |
+
use_fsdp: True
|
23 |
+
|
24 |
+
# GDF
|
25 |
+
adaptive_loss_weight: True
|
26 |
+
|
27 |
+
# ema_start_iters: 5000
|
28 |
+
# ema_iters: 100
|
29 |
+
# ema_beta: 0.9
|
30 |
+
|
31 |
+
webdataset_path:
|
32 |
+
- s3://path/to/your/first/dataset/on/s3
|
33 |
+
- s3://path/to/your/second/dataset/on/s3
|
34 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
35 |
+
stage_a_checkpoint_path: models/stage_a.safetensors
|
36 |
+
generator_checkpoint_path: models/stage_b_bf16.safetensors
|
configs/training/finetune_b_700m.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_b_700m_finetuning
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 700M
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 512
|
14 |
+
image_size: 1024
|
15 |
+
# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
shift: 4
|
17 |
+
grad_accum_steps: 1
|
18 |
+
updates: 10000
|
19 |
+
backup_every: 20000
|
20 |
+
save_every: 2000
|
21 |
+
warmup_updates: 1
|
22 |
+
use_fsdp: True
|
23 |
+
|
24 |
+
# GDF
|
25 |
+
adaptive_loss_weight: True
|
26 |
+
|
27 |
+
# ema_start_iters: 5000
|
28 |
+
# ema_iters: 100
|
29 |
+
# ema_beta: 0.9
|
30 |
+
|
31 |
+
webdataset_path:
|
32 |
+
- s3://path/to/your/first/dataset/on/s3
|
33 |
+
- s3://path/to/your/second/dataset/on/s3
|
34 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
35 |
+
stage_a_checkpoint_path: models/stage_a.safetensors
|
36 |
+
generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
|
configs/training/finetune_c_1b.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_1b_finetuning
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 1B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 1024
|
14 |
+
image_size: 768
|
15 |
+
# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 1
|
17 |
+
updates: 10000
|
18 |
+
backup_every: 20000
|
19 |
+
save_every: 2000
|
20 |
+
warmup_updates: 1
|
21 |
+
use_fsdp: True
|
22 |
+
|
23 |
+
# GDF
|
24 |
+
# adaptive_loss_weight: True
|
25 |
+
|
26 |
+
# ema_start_iters: 5000
|
27 |
+
# ema_iters: 100
|
28 |
+
# ema_beta: 0.9
|
29 |
+
|
30 |
+
webdataset_path:
|
31 |
+
- s3://path/to/your/first/dataset/on/s3
|
32 |
+
- s3://path/to/your/second/dataset/on/s3
|
33 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
34 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
35 |
+
generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
|
configs/training/finetune_c_3b.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_3b_finetuning
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3.6B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 512
|
14 |
+
image_size: 768
|
15 |
+
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 1
|
17 |
+
updates: 100000
|
18 |
+
backup_every: 20000
|
19 |
+
save_every: 2000
|
20 |
+
warmup_updates: 1
|
21 |
+
use_fsdp: True
|
22 |
+
|
23 |
+
# GDF
|
24 |
+
adaptive_loss_weight: True
|
25 |
+
|
26 |
+
# ema_start_iters: 5000
|
27 |
+
# ema_iters: 100
|
28 |
+
# ema_beta: 0.9
|
29 |
+
|
30 |
+
webdataset_path:
|
31 |
+
- s3://path/to/your/first/dataset/on/s3
|
32 |
+
- s3://path/to/your/second/dataset/on/s3
|
33 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
34 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
35 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/finetune_c_3b_lora.yaml
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_3b_lora
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3.6B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 32
|
14 |
+
image_size: 768
|
15 |
+
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 4
|
17 |
+
updates: 10000
|
18 |
+
backup_every: 1000
|
19 |
+
save_every: 100
|
20 |
+
warmup_updates: 1
|
21 |
+
# use_fsdp: True -> FSDP doesn't work at the moment for LoRA
|
22 |
+
use_fsdp: False
|
23 |
+
|
24 |
+
# GDF
|
25 |
+
# adaptive_loss_weight: True
|
26 |
+
|
27 |
+
# LoRA specific
|
28 |
+
module_filters: ['.attn']
|
29 |
+
rank: 4
|
30 |
+
train_tokens:
|
31 |
+
# - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
|
32 |
+
- ['[fernando]', '^dog</w>'] # custom token [snail], initialize as avg of snail & snails
|
33 |
+
|
34 |
+
|
35 |
+
# ema_start_iters: 5000
|
36 |
+
# ema_iters: 100
|
37 |
+
# ema_beta: 0.9
|
38 |
+
|
39 |
+
webdataset_path:
|
40 |
+
- s3://path/to/your/first/dataset/on/s3
|
41 |
+
- s3://path/to/your/second/dataset/on/s3
|
42 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
43 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
44 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/finetune_c_3b_lowres.yaml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_3b_finetuning
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3.6B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 1024
|
14 |
+
image_size: 384
|
15 |
+
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 1
|
17 |
+
updates: 100000
|
18 |
+
backup_every: 20000
|
19 |
+
save_every: 2000
|
20 |
+
warmup_updates: 1
|
21 |
+
use_fsdp: True
|
22 |
+
|
23 |
+
# GDF
|
24 |
+
adaptive_loss_weight: True
|
25 |
+
|
26 |
+
# CUSTOM CAPTIONS GETTER & FILTERS
|
27 |
+
# captions_getter: ['json', captions_getter]
|
28 |
+
# dataset_filters:
|
29 |
+
# - ['normalized_score', 'lambda s: s > 9.0']
|
30 |
+
# - ['pgen_normalized_score', 'lambda s: s > 3.0']
|
31 |
+
|
32 |
+
# ema_start_iters: 5000
|
33 |
+
# ema_iters: 100
|
34 |
+
# ema_beta: 0.9
|
35 |
+
|
36 |
+
webdataset_path:
|
37 |
+
- s3://path/to/your/first/dataset/on/s3
|
38 |
+
- s3://path/to/your/second/dataset/on/s3
|
39 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
40 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
41 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
configs/training/finetune_c_3b_v.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GLOBAL STUFF
|
2 |
+
experiment_id: stage_c_3b_finetuning
|
3 |
+
checkpoint_path: /path/to/checkpoint
|
4 |
+
output_path: /path/to/output
|
5 |
+
model_version: 3.6B
|
6 |
+
|
7 |
+
# WandB
|
8 |
+
wandb_project: StableCascade
|
9 |
+
wandb_entity: wandb_username
|
10 |
+
|
11 |
+
# TRAINING PARAMS
|
12 |
+
lr: 1.0e-4
|
13 |
+
batch_size: 512
|
14 |
+
image_size: 768
|
15 |
+
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
|
16 |
+
grad_accum_steps: 1
|
17 |
+
updates: 100000
|
18 |
+
backup_every: 20000
|
19 |
+
save_every: 2000
|
20 |
+
warmup_updates: 1
|
21 |
+
use_fsdp: True
|
22 |
+
|
23 |
+
# GDF
|
24 |
+
adaptive_loss_weight: True
|
25 |
+
edm_objective: True
|
26 |
+
|
27 |
+
# ema_start_iters: 5000
|
28 |
+
# ema_iters: 100
|
29 |
+
# ema_beta: 0.9
|
30 |
+
|
31 |
+
webdataset_path:
|
32 |
+
- s3://path/to/your/first/dataset/on/s3
|
33 |
+
- s3://path/to/your/second/dataset/on/s3
|
34 |
+
effnet_checkpoint_path: models/effnet_encoder.safetensors
|
35 |
+
previewer_checkpoint_path: models/previewer.safetensors
|
36 |
+
generator_checkpoint_path: models/stage_c_bf16.safetensors
|
core/__init__.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import wandb
|
6 |
+
import json
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from torch.utils.data import Dataset, DataLoader
|
10 |
+
|
11 |
+
from torch.distributed import init_process_group, destroy_process_group, barrier
|
12 |
+
from torch.distributed.fsdp import (
|
13 |
+
FullyShardedDataParallel as FSDP,
|
14 |
+
FullStateDictConfig,
|
15 |
+
MixedPrecision,
|
16 |
+
ShardingStrategy,
|
17 |
+
StateDictType
|
18 |
+
)
|
19 |
+
|
20 |
+
from .utils import Base, EXPECTED, EXPECTED_TRAIN
|
21 |
+
from .utils import create_folder_if_necessary, safe_save, load_or_fail
|
22 |
+
|
23 |
+
# pylint: disable=unused-argument
|
24 |
+
class WarpCore(ABC):
|
25 |
+
@dataclass(frozen=True)
|
26 |
+
class Config(Base):
|
27 |
+
experiment_id: str = EXPECTED_TRAIN
|
28 |
+
checkpoint_path: str = EXPECTED_TRAIN
|
29 |
+
output_path: str = EXPECTED_TRAIN
|
30 |
+
checkpoint_extension: str = "safetensors"
|
31 |
+
dist_file_subfolder: str = ""
|
32 |
+
allow_tf32: bool = True
|
33 |
+
|
34 |
+
wandb_project: str = None
|
35 |
+
wandb_entity: str = None
|
36 |
+
|
37 |
+
@dataclass() # not frozen, means that fields are mutable
|
38 |
+
class Info(): # not inheriting from Base, because we don't want to enforce the default fields
|
39 |
+
wandb_run_id: str = None
|
40 |
+
total_steps: int = 0
|
41 |
+
iter: int = 0
|
42 |
+
|
43 |
+
@dataclass(frozen=True)
|
44 |
+
class Data(Base):
|
45 |
+
dataset: Dataset = EXPECTED
|
46 |
+
dataloader: DataLoader = EXPECTED
|
47 |
+
iterator: any = EXPECTED
|
48 |
+
|
49 |
+
@dataclass(frozen=True)
|
50 |
+
class Models(Base):
|
51 |
+
pass
|
52 |
+
|
53 |
+
@dataclass(frozen=True)
|
54 |
+
class Optimizers(Base):
|
55 |
+
pass
|
56 |
+
|
57 |
+
@dataclass(frozen=True)
|
58 |
+
class Schedulers(Base):
|
59 |
+
pass
|
60 |
+
|
61 |
+
@dataclass(frozen=True)
|
62 |
+
class Extras(Base):
|
63 |
+
pass
|
64 |
+
# ---------------------------------------
|
65 |
+
info: Info
|
66 |
+
config: Config
|
67 |
+
|
68 |
+
# FSDP stuff
|
69 |
+
fsdp_defaults = {
|
70 |
+
"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
|
71 |
+
"cpu_offload": None,
|
72 |
+
"mixed_precision": MixedPrecision(
|
73 |
+
param_dtype=torch.bfloat16,
|
74 |
+
reduce_dtype=torch.bfloat16,
|
75 |
+
buffer_dtype=torch.bfloat16,
|
76 |
+
),
|
77 |
+
"limit_all_gathers": True,
|
78 |
+
}
|
79 |
+
fsdp_fullstate_save_policy = FullStateDictConfig(
|
80 |
+
offload_to_cpu=True, rank0_only=True
|
81 |
+
)
|
82 |
+
# ------------
|
83 |
+
|
84 |
+
# OVERRIDEABLE METHODS
|
85 |
+
|
86 |
+
# [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
|
87 |
+
def setup_extras_pre(self) -> Extras:
|
88 |
+
return self.Extras()
|
89 |
+
|
90 |
+
# setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator
|
91 |
+
@abstractmethod
|
92 |
+
def setup_data(self, extras: Extras) -> Data:
|
93 |
+
raise NotImplementedError("This method needs to be overriden")
|
94 |
+
|
95 |
+
# return a dict with all models that are going to be used in the training
|
96 |
+
@abstractmethod
|
97 |
+
def setup_models(self, extras: Extras) -> Models:
|
98 |
+
raise NotImplementedError("This method needs to be overriden")
|
99 |
+
|
100 |
+
# return a dict with all optimizers that are going to be used in the training
|
101 |
+
@abstractmethod
|
102 |
+
def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
|
103 |
+
raise NotImplementedError("This method needs to be overriden")
|
104 |
+
|
105 |
+
# [optionally] return a dict with all schedulers that are going to be used in the training
|
106 |
+
def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
|
107 |
+
return self.Schedulers()
|
108 |
+
|
109 |
+
# [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
|
110 |
+
def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
|
111 |
+
return self.Extras.from_dict(extras.to_dict())
|
112 |
+
|
113 |
+
# perform the training here
|
114 |
+
@abstractmethod
|
115 |
+
def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
|
116 |
+
raise NotImplementedError("This method needs to be overriden")
|
117 |
+
# ------------
|
118 |
+
|
119 |
+
def setup_info(self, full_path=None) -> Info:
|
120 |
+
if full_path is None:
|
121 |
+
full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json")
|
122 |
+
info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
|
123 |
+
info_dto = self.Info(**info_dict)
|
124 |
+
if info_dto.total_steps > 0 and self.is_main_node:
|
125 |
+
print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
|
126 |
+
return info_dto
|
127 |
+
|
128 |
+
def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
|
129 |
+
if config_file_path is not None:
|
130 |
+
if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
|
131 |
+
with open(config_file_path, "r", encoding="utf-8") as file:
|
132 |
+
loaded_config = yaml.safe_load(file)
|
133 |
+
elif config_file_path.endswith(".json"):
|
134 |
+
with open(config_file_path, "r", encoding="utf-8") as file:
|
135 |
+
loaded_config = json.load(file)
|
136 |
+
else:
|
137 |
+
raise ValueError("Config file must be either a .yml|.yaml or .json file")
|
138 |
+
return self.Config.from_dict({**loaded_config, 'training': training})
|
139 |
+
if config_dict is not None:
|
140 |
+
return self.Config.from_dict({**config_dict, 'training': training})
|
141 |
+
return self.Config(training=training)
|
142 |
+
|
143 |
+
def setup_ddp(self, experiment_id, single_gpu=False):
|
144 |
+
if not single_gpu:
|
145 |
+
local_rank = int(os.environ.get("SLURM_LOCALID"))
|
146 |
+
process_id = int(os.environ.get("SLURM_PROCID"))
|
147 |
+
world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()
|
148 |
+
|
149 |
+
self.process_id = process_id
|
150 |
+
self.is_main_node = process_id == 0
|
151 |
+
self.device = torch.device(local_rank)
|
152 |
+
self.world_size = world_size
|
153 |
+
|
154 |
+
dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
|
155 |
+
# if os.path.exists(dist_file_path) and self.is_main_node:
|
156 |
+
# os.remove(dist_file_path)
|
157 |
+
|
158 |
+
torch.cuda.set_device(local_rank)
|
159 |
+
init_process_group(
|
160 |
+
backend="nccl",
|
161 |
+
rank=process_id,
|
162 |
+
world_size=world_size,
|
163 |
+
init_method=f"file://{dist_file_path}",
|
164 |
+
)
|
165 |
+
print(f"[GPU {process_id}] READY")
|
166 |
+
else:
|
167 |
+
print("Running in single thread, DDP not enabled.")
|
168 |
+
|
169 |
+
def setup_wandb(self):
|
170 |
+
if self.is_main_node and self.config.wandb_project is not None:
|
171 |
+
self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
|
172 |
+
wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())
|
173 |
+
|
174 |
+
if self.info.total_steps > 0:
|
175 |
+
wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
|
176 |
+
else:
|
177 |
+
wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")
|
178 |
+
|
179 |
+
# LOAD UTILITIES ----------
|
180 |
+
def load_model(self, model, model_id=None, full_path=None, strict=True):
|
181 |
+
if model_id is not None and full_path is None:
|
182 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
|
183 |
+
elif full_path is None and model_id is None:
|
184 |
+
raise ValueError(
|
185 |
+
"This method expects either 'model_id' or 'full_path' to be defined"
|
186 |
+
)
|
187 |
+
|
188 |
+
checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
|
189 |
+
if checkpoint is not None:
|
190 |
+
model.load_state_dict(checkpoint, strict=strict)
|
191 |
+
del checkpoint
|
192 |
+
|
193 |
+
return model
|
194 |
+
|
195 |
+
def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
|
196 |
+
if optim_id is not None and full_path is None:
|
197 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
|
198 |
+
elif full_path is None and optim_id is None:
|
199 |
+
raise ValueError(
|
200 |
+
"This method expects either 'optim_id' or 'full_path' to be defined"
|
201 |
+
)
|
202 |
+
|
203 |
+
checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
|
204 |
+
if checkpoint is not None:
|
205 |
+
try:
|
206 |
+
if fsdp_model is not None:
|
207 |
+
sharded_optimizer_state_dict = (
|
208 |
+
FSDP.scatter_full_optim_state_dict( # <---- FSDP
|
209 |
+
checkpoint
|
210 |
+
if (
|
211 |
+
self.is_main_node
|
212 |
+
or self.fsdp_defaults["sharding_strategy"]
|
213 |
+
== ShardingStrategy.NO_SHARD
|
214 |
+
)
|
215 |
+
else None,
|
216 |
+
fsdp_model,
|
217 |
+
)
|
218 |
+
)
|
219 |
+
optim.load_state_dict(sharded_optimizer_state_dict)
|
220 |
+
del checkpoint, sharded_optimizer_state_dict
|
221 |
+
else:
|
222 |
+
optim.load_state_dict(checkpoint)
|
223 |
+
# pylint: disable=broad-except
|
224 |
+
except Exception as e:
|
225 |
+
print("!!! Failed loading optimizer, skipping... Exception:", e)
|
226 |
+
|
227 |
+
return optim
|
228 |
+
|
229 |
+
# SAVE UTILITIES ----------
|
230 |
+
def save_info(self, info, suffix=""):
|
231 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
|
232 |
+
create_folder_if_necessary(full_path)
|
233 |
+
if self.is_main_node:
|
234 |
+
safe_save(vars(self.info), full_path)
|
235 |
+
|
236 |
+
def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
|
237 |
+
if model_id is not None and full_path is None:
|
238 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
|
239 |
+
elif full_path is None and model_id is None:
|
240 |
+
raise ValueError(
|
241 |
+
"This method expects either 'model_id' or 'full_path' to be defined"
|
242 |
+
)
|
243 |
+
create_folder_if_necessary(full_path)
|
244 |
+
if is_fsdp:
|
245 |
+
with FSDP.summon_full_params(model):
|
246 |
+
pass
|
247 |
+
with FSDP.state_dict_type(
|
248 |
+
model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
|
249 |
+
):
|
250 |
+
checkpoint = model.state_dict()
|
251 |
+
if self.is_main_node:
|
252 |
+
safe_save(checkpoint, full_path)
|
253 |
+
del checkpoint
|
254 |
+
else:
|
255 |
+
if self.is_main_node:
|
256 |
+
checkpoint = model.state_dict()
|
257 |
+
safe_save(checkpoint, full_path)
|
258 |
+
del checkpoint
|
259 |
+
|
260 |
+
def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
|
261 |
+
if optim_id is not None and full_path is None:
|
262 |
+
full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
|
263 |
+
elif full_path is None and optim_id is None:
|
264 |
+
raise ValueError(
|
265 |
+
"This method expects either 'optim_id' or 'full_path' to be defined"
|
266 |
+
)
|
267 |
+
create_folder_if_necessary(full_path)
|
268 |
+
if fsdp_model is not None:
|
269 |
+
optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
|
270 |
+
if self.is_main_node:
|
271 |
+
safe_save(optim_statedict, full_path)
|
272 |
+
del optim_statedict
|
273 |
+
else:
|
274 |
+
if self.is_main_node:
|
275 |
+
checkpoint = optim.state_dict()
|
276 |
+
safe_save(checkpoint, full_path)
|
277 |
+
del checkpoint
|
278 |
+
# -----
|
279 |
+
|
280 |
+
def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
|
281 |
+
# Temporary setup, will be overriden by setup_ddp if required
|
282 |
+
self.device = device
|
283 |
+
self.process_id = 0
|
284 |
+
self.is_main_node = True
|
285 |
+
self.world_size = 1
|
286 |
+
# ----
|
287 |
+
|
288 |
+
self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
|
289 |
+
self.info: self.Info = self.setup_info()
|
290 |
+
|
291 |
+
def __call__(self, single_gpu=False):
|
292 |
+
self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank
|
293 |
+
self.setup_wandb()
|
294 |
+
if self.config.allow_tf32:
|
295 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
296 |
+
torch.backends.cudnn.allow_tf32 = True
|
297 |
+
|
298 |
+
if self.is_main_node:
|
299 |
+
print()
|
300 |
+
print("**STARTIG JOB WITH CONFIG:**")
|
301 |
+
print(yaml.dump(self.config.to_dict(), default_flow_style=False))
|
302 |
+
print("------------------------------------")
|
303 |
+
print()
|
304 |
+
print("**INFO:**")
|
305 |
+
print(yaml.dump(vars(self.info), default_flow_style=False))
|
306 |
+
print("------------------------------------")
|
307 |
+
print()
|
308 |
+
|
309 |
+
# SETUP STUFF
|
310 |
+
extras = self.setup_extras_pre()
|
311 |
+
assert extras is not None, "setup_extras_pre() must return a DTO"
|
312 |
+
|
313 |
+
data = self.setup_data(extras)
|
314 |
+
assert data is not None, "setup_data() must return a DTO"
|
315 |
+
if self.is_main_node:
|
316 |
+
print("**DATA:**")
|
317 |
+
print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
|
318 |
+
print("------------------------------------")
|
319 |
+
print()
|
320 |
+
|
321 |
+
models = self.setup_models(extras)
|
322 |
+
assert models is not None, "setup_models() must return a DTO"
|
323 |
+
if self.is_main_node:
|
324 |
+
print("**MODELS:**")
|
325 |
+
print(yaml.dump({
|
326 |
+
k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
|
327 |
+
}, default_flow_style=False))
|
328 |
+
print("------------------------------------")
|
329 |
+
print()
|
330 |
+
|
331 |
+
optimizers = self.setup_optimizers(extras, models)
|
332 |
+
assert optimizers is not None, "setup_optimizers() must return a DTO"
|
333 |
+
if self.is_main_node:
|
334 |
+
print("**OPTIMIZERS:**")
|
335 |
+
print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
|
336 |
+
print("------------------------------------")
|
337 |
+
print()
|
338 |
+
|
339 |
+
schedulers = self.setup_schedulers(extras, models, optimizers)
|
340 |
+
assert schedulers is not None, "setup_schedulers() must return a DTO"
|
341 |
+
if self.is_main_node:
|
342 |
+
print("**SCHEDULERS:**")
|
343 |
+
print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
|
344 |
+
print("------------------------------------")
|
345 |
+
print()
|
346 |
+
|
347 |
+
post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
|
348 |
+
assert post_extras is not None, "setup_extras_post() must return a DTO"
|
349 |
+
extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
|
350 |
+
if self.is_main_node:
|
351 |
+
print("**EXTRAS:**")
|
352 |
+
print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
|
353 |
+
print("------------------------------------")
|
354 |
+
print()
|
355 |
+
# -------
|
356 |
+
|
357 |
+
# TRAIN
|
358 |
+
if self.is_main_node:
|
359 |
+
print("**TRAINING STARTING...**")
|
360 |
+
self.train(data, extras, models, optimizers, schedulers)
|
361 |
+
|
362 |
+
if single_gpu is False:
|
363 |
+
barrier()
|
364 |
+
destroy_process_group()
|
365 |
+
if self.is_main_node:
|
366 |
+
print()
|
367 |
+
print("------------------------------------")
|
368 |
+
print()
|
369 |
+
print("**TRAINING COMPLETE**")
|
370 |
+
if self.config.wandb_project is not None:
|
371 |
+
wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
|
core/data/__init__.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import subprocess
|
3 |
+
import yaml
|
4 |
+
import os
|
5 |
+
from .bucketeer import Bucketeer
|
6 |
+
|
7 |
+
class MultiFilter():
|
8 |
+
def __init__(self, rules, default=False):
|
9 |
+
self.rules = rules
|
10 |
+
self.default = default
|
11 |
+
|
12 |
+
def __call__(self, x):
|
13 |
+
try:
|
14 |
+
x_json = x['json']
|
15 |
+
if isinstance(x_json, bytes):
|
16 |
+
x_json = json.loads(x_json)
|
17 |
+
validations = []
|
18 |
+
for k, r in self.rules.items():
|
19 |
+
if isinstance(k, tuple):
|
20 |
+
v = r(*[x_json[kv] for kv in k])
|
21 |
+
else:
|
22 |
+
v = r(x_json[k])
|
23 |
+
validations.append(v)
|
24 |
+
return all(validations)
|
25 |
+
except Exception:
|
26 |
+
return False
|
27 |
+
|
28 |
+
class MultiGetter():
|
29 |
+
def __init__(self, rules):
|
30 |
+
self.rules = rules
|
31 |
+
|
32 |
+
def __call__(self, x_json):
|
33 |
+
if isinstance(x_json, bytes):
|
34 |
+
x_json = json.loads(x_json)
|
35 |
+
outputs = []
|
36 |
+
for k, r in self.rules.items():
|
37 |
+
if isinstance(k, tuple):
|
38 |
+
v = r(*[x_json[kv] for kv in k])
|
39 |
+
else:
|
40 |
+
v = r(x_json[k])
|
41 |
+
outputs.append(v)
|
42 |
+
if len(outputs) == 1:
|
43 |
+
outputs = outputs[0]
|
44 |
+
return outputs
|
45 |
+
|
46 |
+
def setup_webdataset_path(paths, cache_path=None):
|
47 |
+
if cache_path is None or not os.path.exists(cache_path):
|
48 |
+
tar_paths = []
|
49 |
+
if isinstance(paths, str):
|
50 |
+
paths = [paths]
|
51 |
+
for path in paths:
|
52 |
+
if path.strip().endswith(".tar"):
|
53 |
+
# Avoid looking up s3 if we already have a tar file
|
54 |
+
tar_paths.append(path)
|
55 |
+
continue
|
56 |
+
bucket = "/".join(path.split("/")[:3])
|
57 |
+
result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
|
58 |
+
files = result.stdout.decode('utf-8').split()
|
59 |
+
files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
|
60 |
+
tar_paths += files
|
61 |
+
|
62 |
+
with open(cache_path, 'w', encoding='utf-8') as outfile:
|
63 |
+
yaml.dump(tar_paths, outfile, default_flow_style=False)
|
64 |
+
else:
|
65 |
+
with open(cache_path, 'r', encoding='utf-8') as file:
|
66 |
+
tar_paths = yaml.safe_load(file)
|
67 |
+
|
68 |
+
tar_paths_str = ",".join([f"{p}" for p in tar_paths])
|
69 |
+
return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
|
core/data/bucketeer.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import numpy as np
|
4 |
+
from torchtools.transforms import SmartCrop
|
5 |
+
import math
|
6 |
+
|
7 |
+
class Bucketeer():
|
8 |
+
def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
|
9 |
+
assert crop_mode in ['center', 'random', 'smart']
|
10 |
+
self.crop_mode = crop_mode
|
11 |
+
self.ratios = ratios
|
12 |
+
if reverse_list:
|
13 |
+
for r in list(ratios):
|
14 |
+
if 1/r not in self.ratios:
|
15 |
+
self.ratios.append(1/r)
|
16 |
+
self.sizes = [(int(((density/r)**0.5//factor)*factor), int(((density*r)**0.5//factor)*factor)) for r in ratios]
|
17 |
+
self.batch_size = dataloader.batch_size
|
18 |
+
self.iterator = iter(dataloader)
|
19 |
+
self.buckets = {s: [] for s in self.sizes}
|
20 |
+
self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
|
21 |
+
self.p_random_ratio = p_random_ratio
|
22 |
+
self.interpolate_nearest = interpolate_nearest
|
23 |
+
|
24 |
+
def get_available_batch(self):
|
25 |
+
for b in self.buckets:
|
26 |
+
if len(self.buckets[b]) >= self.batch_size:
|
27 |
+
batch = self.buckets[b][:self.batch_size]
|
28 |
+
self.buckets[b] = self.buckets[b][self.batch_size:]
|
29 |
+
return batch
|
30 |
+
return None
|
31 |
+
|
32 |
+
def get_closest_size(self, x):
|
33 |
+
if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
|
34 |
+
best_size_idx = np.random.randint(len(self.ratios))
|
35 |
+
else:
|
36 |
+
w, h = x.size(-1), x.size(-2)
|
37 |
+
best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
|
38 |
+
return self.sizes[best_size_idx]
|
39 |
+
|
40 |
+
def get_resize_size(self, orig_size, tgt_size):
|
41 |
+
if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
|
42 |
+
alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
|
43 |
+
resize_size = max(alt_min, min(tgt_size))
|
44 |
+
else:
|
45 |
+
alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
|
46 |
+
resize_size = max(alt_max, max(tgt_size))
|
47 |
+
return resize_size
|
48 |
+
|
49 |
+
def __next__(self):
|
50 |
+
batch = self.get_available_batch()
|
51 |
+
while batch is None:
|
52 |
+
elements = next(self.iterator)
|
53 |
+
for dct in elements:
|
54 |
+
img = dct['images']
|
55 |
+
size = self.get_closest_size(img)
|
56 |
+
resize_size = self.get_resize_size(img.shape[-2:], size)
|
57 |
+
if self.interpolate_nearest:
|
58 |
+
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
|
59 |
+
else:
|
60 |
+
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
|
61 |
+
if self.crop_mode == 'center':
|
62 |
+
img = torchvision.transforms.functional.center_crop(img, size)
|
63 |
+
elif self.crop_mode == 'random':
|
64 |
+
img = torchvision.transforms.RandomCrop(size)(img)
|
65 |
+
elif self.crop_mode == 'smart':
|
66 |
+
self.smartcrop.output_size = size
|
67 |
+
img = self.smartcrop(img)
|
68 |
+
self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
|
69 |
+
batch = self.get_available_batch()
|
70 |
+
|
71 |
+
out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
|
72 |
+
return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
|
core/scripts/__init__.py
ADDED
File without changes
|
core/scripts/cli.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
from .. import WarpCore
|
4 |
+
from .. import templates
|
5 |
+
|
6 |
+
|
7 |
+
def template_init(args):
|
8 |
+
return ''''
|
9 |
+
|
10 |
+
|
11 |
+
'''.strip()
|
12 |
+
|
13 |
+
|
14 |
+
def init_template(args):
|
15 |
+
parser = argparse.ArgumentParser(description='WarpCore template init tool')
|
16 |
+
parser.add_argument('-t', '--template', type=str, default='WarpCore')
|
17 |
+
args = parser.parse_args(args)
|
18 |
+
|
19 |
+
if args.template == 'WarpCore':
|
20 |
+
template_cls = WarpCore
|
21 |
+
else:
|
22 |
+
try:
|
23 |
+
template_cls = __import__(args.template)
|
24 |
+
except ModuleNotFoundError:
|
25 |
+
template_cls = getattr(templates, args.template)
|
26 |
+
print(template_cls)
|
27 |
+
|
28 |
+
|
29 |
+
def main():
|
30 |
+
if len(sys.argv) < 2:
|
31 |
+
print('Usage: core <command>')
|
32 |
+
sys.exit(1)
|
33 |
+
if sys.argv[1] == 'init':
|
34 |
+
init_template(sys.argv[2:])
|
35 |
+
else:
|
36 |
+
print('Unknown command')
|
37 |
+
sys.exit(1)
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == '__main__':
|
41 |
+
main()
|
core/templates/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .diffusion import DiffusionCore
|
core/templates/diffusion.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .. import WarpCore
|
2 |
+
from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
|
3 |
+
from abc import abstractmethod
|
4 |
+
from dataclasses import dataclass
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from gdf import GDF
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
import wandb
|
12 |
+
|
13 |
+
import webdataset as wds
|
14 |
+
from webdataset.handlers import warn_and_continue
|
15 |
+
from torch.distributed import barrier
|
16 |
+
from enum import Enum
|
17 |
+
|
18 |
+
class TargetReparametrization(Enum):
|
19 |
+
EPSILON = 'epsilon'
|
20 |
+
X0 = 'x0'
|
21 |
+
|
22 |
+
class DiffusionCore(WarpCore):
|
23 |
+
@dataclass(frozen=True)
|
24 |
+
class Config(WarpCore.Config):
|
25 |
+
# TRAINING PARAMS
|
26 |
+
lr: float = EXPECTED_TRAIN
|
27 |
+
grad_accum_steps: int = EXPECTED_TRAIN
|
28 |
+
batch_size: int = EXPECTED_TRAIN
|
29 |
+
updates: int = EXPECTED_TRAIN
|
30 |
+
warmup_updates: int = EXPECTED_TRAIN
|
31 |
+
save_every: int = 500
|
32 |
+
backup_every: int = 20000
|
33 |
+
use_fsdp: bool = True
|
34 |
+
|
35 |
+
# EMA UPDATE
|
36 |
+
ema_start_iters: int = None
|
37 |
+
ema_iters: int = None
|
38 |
+
ema_beta: float = None
|
39 |
+
|
40 |
+
# GDF setting
|
41 |
+
gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
|
42 |
+
|
43 |
+
@dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
|
44 |
+
class Info(WarpCore.Info):
|
45 |
+
ema_loss: float = None
|
46 |
+
|
47 |
+
@dataclass(frozen=True)
|
48 |
+
class Models(WarpCore.Models):
|
49 |
+
generator : nn.Module = EXPECTED
|
50 |
+
generator_ema : nn.Module = None # optional
|
51 |
+
|
52 |
+
@dataclass(frozen=True)
|
53 |
+
class Optimizers(WarpCore.Optimizers):
|
54 |
+
generator : any = EXPECTED
|
55 |
+
|
56 |
+
@dataclass(frozen=True)
|
57 |
+
class Schedulers(WarpCore.Schedulers):
|
58 |
+
generator: any = None
|
59 |
+
|
60 |
+
@dataclass(frozen=True)
|
61 |
+
class Extras(WarpCore.Extras):
|
62 |
+
gdf: GDF = EXPECTED
|
63 |
+
sampling_configs: dict = EXPECTED
|
64 |
+
|
65 |
+
# --------------------------------------------
|
66 |
+
info: Info
|
67 |
+
config: Config
|
68 |
+
|
69 |
+
@abstractmethod
|
70 |
+
def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
|
71 |
+
raise NotImplementedError("This method needs to be overriden")
|
72 |
+
|
73 |
+
@abstractmethod
|
74 |
+
def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
|
75 |
+
raise NotImplementedError("This method needs to be overriden")
|
76 |
+
|
77 |
+
@abstractmethod
|
78 |
+
def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
|
79 |
+
raise NotImplementedError("This method needs to be overriden")
|
80 |
+
|
81 |
+
@abstractmethod
|
82 |
+
def webdataset_path(self, extras: Extras):
|
83 |
+
raise NotImplementedError("This method needs to be overriden")
|
84 |
+
|
85 |
+
@abstractmethod
|
86 |
+
def webdataset_filters(self, extras: Extras):
|
87 |
+
raise NotImplementedError("This method needs to be overriden")
|
88 |
+
|
89 |
+
@abstractmethod
|
90 |
+
def webdataset_preprocessors(self, extras: Extras):
|
91 |
+
raise NotImplementedError("This method needs to be overriden")
|
92 |
+
|
93 |
+
@abstractmethod
|
94 |
+
def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
|
95 |
+
raise NotImplementedError("This method needs to be overriden")
|
96 |
+
# -------------
|
97 |
+
|
98 |
+
def setup_data(self, extras: Extras) -> WarpCore.Data:
|
99 |
+
# SETUP DATASET
|
100 |
+
dataset_path = self.webdataset_path(extras)
|
101 |
+
preprocessors = self.webdataset_preprocessors(extras)
|
102 |
+
filters = self.webdataset_filters(extras)
|
103 |
+
|
104 |
+
handler = warn_and_continue # None
|
105 |
+
# handler = None
|
106 |
+
dataset = wds.WebDataset(
|
107 |
+
dataset_path, resampled=True, handler=handler
|
108 |
+
).select(filters).shuffle(690, handler=handler).decode(
|
109 |
+
"pilrgb", handler=handler
|
110 |
+
).to_tuple(
|
111 |
+
*[p[0] for p in preprocessors], handler=handler
|
112 |
+
).map_tuple(
|
113 |
+
*[p[1] for p in preprocessors], handler=handler
|
114 |
+
).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})
|
115 |
+
|
116 |
+
# SETUP DATALOADER
|
117 |
+
real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
|
118 |
+
dataloader = DataLoader(
|
119 |
+
dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
|
120 |
+
)
|
121 |
+
|
122 |
+
return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
|
123 |
+
|
124 |
+
def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
|
125 |
+
batch = next(data.iterator)
|
126 |
+
|
127 |
+
with torch.no_grad():
|
128 |
+
conditions = self.get_conditions(batch, models, extras)
|
129 |
+
latents = self.encode_latents(batch, models, extras)
|
130 |
+
noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
|
131 |
+
|
132 |
+
# FORWARD PASS
|
133 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
134 |
+
pred = models.generator(noised, noise_cond, **conditions)
|
135 |
+
if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
|
136 |
+
pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
|
137 |
+
target = noise
|
138 |
+
elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
|
139 |
+
pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
|
140 |
+
target = latents
|
141 |
+
loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
|
142 |
+
loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
|
143 |
+
|
144 |
+
return loss, loss_adjusted
|
145 |
+
|
146 |
+
def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
|
147 |
+
start_iter = self.info.iter+1
|
148 |
+
max_iters = self.config.updates * self.config.grad_accum_steps
|
149 |
+
if self.is_main_node:
|
150 |
+
print(f"STARTING AT STEP: {start_iter}/{max_iters}")
|
151 |
+
|
152 |
+
pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
|
153 |
+
models.generator.train()
|
154 |
+
for i in pbar:
|
155 |
+
# FORWARD PASS
|
156 |
+
loss, loss_adjusted = self.forward_pass(data, extras, models)
|
157 |
+
|
158 |
+
# BACKWARD PASS
|
159 |
+
if i % self.config.grad_accum_steps == 0 or i == max_iters:
|
160 |
+
loss_adjusted.backward()
|
161 |
+
grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
|
162 |
+
optimizers_dict = optimizers.to_dict()
|
163 |
+
for k in optimizers_dict:
|
164 |
+
optimizers_dict[k].step()
|
165 |
+
schedulers_dict = schedulers.to_dict()
|
166 |
+
for k in schedulers_dict:
|
167 |
+
schedulers_dict[k].step()
|
168 |
+
models.generator.zero_grad(set_to_none=True)
|
169 |
+
self.info.total_steps += 1
|
170 |
+
else:
|
171 |
+
with models.generator.no_sync():
|
172 |
+
loss_adjusted.backward()
|
173 |
+
self.info.iter = i
|
174 |
+
|
175 |
+
# UPDATE EMA
|
176 |
+
if models.generator_ema is not None and i % self.config.ema_iters == 0:
|
177 |
+
update_weights_ema(
|
178 |
+
models.generator_ema, models.generator,
|
179 |
+
beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
|
180 |
+
)
|
181 |
+
|
182 |
+
# UPDATE LOSS METRICS
|
183 |
+
self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
|
184 |
+
|
185 |
+
if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
|
186 |
+
wandb.alert(
|
187 |
+
title=f"NaN value encountered in training run {self.info.wandb_run_id}",
|
188 |
+
text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
|
189 |
+
wait_duration=60*30
|
190 |
+
)
|
191 |
+
|
192 |
+
if self.is_main_node:
|
193 |
+
logs = {
|
194 |
+
'loss': self.info.ema_loss,
|
195 |
+
'raw_loss': loss.mean().item(),
|
196 |
+
'grad_norm': grad_norm.item(),
|
197 |
+
'lr': optimizers.generator.param_groups[0]['lr'],
|
198 |
+
'total_steps': self.info.total_steps,
|
199 |
+
}
|
200 |
+
|
201 |
+
pbar.set_postfix(logs)
|
202 |
+
if self.config.wandb_project is not None:
|
203 |
+
wandb.log(logs)
|
204 |
+
|
205 |
+
if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
|
206 |
+
# SAVE AND CHECKPOINT STUFF
|
207 |
+
if np.isnan(loss.mean().item()):
|
208 |
+
if self.is_main_node and self.config.wandb_project is not None:
|
209 |
+
tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
|
210 |
+
wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
|
211 |
+
else:
|
212 |
+
self.save_checkpoints(models, optimizers)
|
213 |
+
if self.is_main_node:
|
214 |
+
create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
|
215 |
+
self.sample(models, data, extras)
|
216 |
+
|
217 |
+
def models_to_save(self):
|
218 |
+
return ['generator', 'generator_ema']
|
219 |
+
|
220 |
+
def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
|
221 |
+
barrier()
|
222 |
+
suffix = '' if suffix is None else suffix
|
223 |
+
self.save_info(self.info, suffix=suffix)
|
224 |
+
models_dict = models.to_dict()
|
225 |
+
optimizers_dict = optimizers.to_dict()
|
226 |
+
for key in self.models_to_save():
|
227 |
+
model = models_dict[key]
|
228 |
+
if model is not None:
|
229 |
+
self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
|
230 |
+
for key in optimizers_dict:
|
231 |
+
optimizer = optimizers_dict[key]
|
232 |
+
if optimizer is not None:
|
233 |
+
self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
|
234 |
+
if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
|
235 |
+
self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
|
236 |
+
torch.cuda.empty_cache()
|
core/utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
|
2 |
+
from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
|
3 |
+
|
4 |
+
# MOVE IT SOMERWHERE ELSE
|
5 |
+
def update_weights_ema(tgt_model, src_model, beta=0.999):
|
6 |
+
for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
|
7 |
+
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta)
|
8 |
+
for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()):
|
9 |
+
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta)
|
core/utils/base_dto.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from dataclasses import dataclass, _MISSING_TYPE
|
3 |
+
from munch import Munch
|
4 |
+
|
5 |
+
EXPECTED = "___REQUIRED___"
|
6 |
+
EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
|
7 |
+
|
8 |
+
# pylint: disable=invalid-field-call
|
9 |
+
def nested_dto(x, raw=False):
|
10 |
+
return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x))
|
11 |
+
|
12 |
+
@dataclass(frozen=True)
|
13 |
+
class Base:
|
14 |
+
training: bool = None
|
15 |
+
def __new__(cls, **kwargs):
|
16 |
+
training = kwargs.get('training', True)
|
17 |
+
setteable_fields = cls.setteable_fields(**kwargs)
|
18 |
+
mandatory_fields = cls.mandatory_fields(**kwargs)
|
19 |
+
invalid_kwargs = [
|
20 |
+
{k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
|
21 |
+
]
|
22 |
+
print(mandatory_fields)
|
23 |
+
assert (
|
24 |
+
len(invalid_kwargs) == 0
|
25 |
+
), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
|
26 |
+
missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
|
27 |
+
assert (
|
28 |
+
len(missing_kwargs) == 0
|
29 |
+
), f"Required fields missing initializing this DTO: {missing_kwargs}."
|
30 |
+
return object.__new__(cls)
|
31 |
+
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def setteable_fields(cls, **kwargs):
|
35 |
+
return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def mandatory_fields(cls, **kwargs):
|
39 |
+
training = kwargs.get('training', True)
|
40 |
+
return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]
|
41 |
+
|
42 |
+
@classmethod
|
43 |
+
def from_dict(cls, kwargs):
|
44 |
+
for k in kwargs:
|
45 |
+
if isinstance(kwargs[k], (dict, list, tuple)):
|
46 |
+
kwargs[k] = Munch.fromDict(kwargs[k])
|
47 |
+
return cls(**kwargs)
|
48 |
+
|
49 |
+
def to_dict(self):
|
50 |
+
# selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
|
51 |
+
selfdict = {}
|
52 |
+
for k in dataclasses.fields(self):
|
53 |
+
selfdict[k.name] = getattr(self, k.name)
|
54 |
+
if isinstance(selfdict[k.name], Munch):
|
55 |
+
selfdict[k.name] = selfdict[k.name].toDict()
|
56 |
+
return selfdict
|
core/utils/save_and_load.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from pathlib import Path
|
5 |
+
import safetensors
|
6 |
+
import wandb
|
7 |
+
|
8 |
+
|
9 |
+
def create_folder_if_necessary(path):
|
10 |
+
path = "/".join(path.split("/")[:-1])
|
11 |
+
Path(path).mkdir(parents=True, exist_ok=True)
|
12 |
+
|
13 |
+
|
14 |
+
def safe_save(ckpt, path):
|
15 |
+
try:
|
16 |
+
os.remove(f"{path}.bak")
|
17 |
+
except OSError:
|
18 |
+
pass
|
19 |
+
try:
|
20 |
+
os.rename(path, f"{path}.bak")
|
21 |
+
except OSError:
|
22 |
+
pass
|
23 |
+
if path.endswith(".pt") or path.endswith(".ckpt"):
|
24 |
+
torch.save(ckpt, path)
|
25 |
+
elif path.endswith(".json"):
|
26 |
+
with open(path, "w", encoding="utf-8") as f:
|
27 |
+
json.dump(ckpt, f, indent=4)
|
28 |
+
elif path.endswith(".safetensors"):
|
29 |
+
safetensors.torch.save_file(ckpt, path)
|
30 |
+
else:
|
31 |
+
raise ValueError(f"File extension not supported: {path}")
|
32 |
+
|
33 |
+
|
34 |
+
def load_or_fail(path, wandb_run_id=None):
|
35 |
+
accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
|
36 |
+
try:
|
37 |
+
assert any(
|
38 |
+
[path.endswith(ext) for ext in accepted_extensions]
|
39 |
+
), f"Automatic loading not supported for this extension: {path}"
|
40 |
+
if not os.path.exists(path):
|
41 |
+
checkpoint = None
|
42 |
+
elif path.endswith(".pt") or path.endswith(".ckpt"):
|
43 |
+
checkpoint = torch.load(path, map_location="cpu")
|
44 |
+
elif path.endswith(".json"):
|
45 |
+
with open(path, "r", encoding="utf-8") as f:
|
46 |
+
checkpoint = json.load(f)
|
47 |
+
elif path.endswith(".safetensors"):
|
48 |
+
checkpoint = {}
|
49 |
+
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
|
50 |
+
for key in f.keys():
|
51 |
+
checkpoint[key] = f.get_tensor(key)
|
52 |
+
return checkpoint
|
53 |
+
except Exception as e:
|
54 |
+
if wandb_run_id is not None:
|
55 |
+
wandb.alert(
|
56 |
+
title=f"Corrupt checkpoint for run {wandb_run_id}",
|
57 |
+
text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
|
58 |
+
)
|
59 |
+
raise e
|
figures/collage_1.jpg
ADDED
Git LFS Details
|
figures/collage_2.jpg
ADDED
figures/collage_3.jpg
ADDED
Git LFS Details
|
figures/collage_4.jpg
ADDED
figures/comparison-inference-speed.jpg
ADDED
figures/comparison.png
ADDED
figures/controlnet-canny.jpg
ADDED
figures/controlnet-face.jpg
ADDED
figures/controlnet-paint.jpg
ADDED
figures/controlnet-sr.jpg
ADDED
Git LFS Details
|
figures/fernando.jpg
ADDED
figures/fernando_original.jpg
ADDED
figures/image-to-image-example-rodent.jpg
ADDED
figures/image-variations-example-headset.jpg
ADDED
figures/model-overview.jpg
ADDED