ameerazam08 commited on
Commit
6a6edcb
1 Parent(s): 5d25eca

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. .gitignore +12 -0
  3. LICENSE +21 -0
  4. README.md +4 -4
  5. WEIGHTS_LICENSE +44 -0
  6. __init__.py +0 -0
  7. app.py +35 -0
  8. configs/inference/controlnet_c_3b_canny.yaml +14 -0
  9. configs/inference/controlnet_c_3b_identity.yaml +17 -0
  10. configs/inference/controlnet_c_3b_inpainting.yaml +15 -0
  11. configs/inference/controlnet_c_3b_sr.yaml +15 -0
  12. configs/inference/lora_c_3b.yaml +15 -0
  13. configs/inference/stage_b_3b.yaml +13 -0
  14. configs/inference/stage_c_3b.yaml +7 -0
  15. configs/training/controlnet_c_3b_canny.yaml +45 -0
  16. configs/training/controlnet_c_3b_identity.yaml +48 -0
  17. configs/training/controlnet_c_3b_inpainting.yaml +46 -0
  18. configs/training/controlnet_c_3b_sr.yaml +46 -0
  19. configs/training/finetune_b_3b.yaml +36 -0
  20. configs/training/finetune_b_700m.yaml +36 -0
  21. configs/training/finetune_c_1b.yaml +35 -0
  22. configs/training/finetune_c_3b.yaml +35 -0
  23. configs/training/finetune_c_3b_lora.yaml +44 -0
  24. configs/training/finetune_c_3b_lowres.yaml +41 -0
  25. configs/training/finetune_c_3b_v.yaml +36 -0
  26. core/__init__.py +371 -0
  27. core/data/__init__.py +69 -0
  28. core/data/bucketeer.py +72 -0
  29. core/scripts/__init__.py +0 -0
  30. core/scripts/cli.py +41 -0
  31. core/templates/__init__.py +1 -0
  32. core/templates/diffusion.py +236 -0
  33. core/utils/__init__.py +9 -0
  34. core/utils/base_dto.py +56 -0
  35. core/utils/save_and_load.py +59 -0
  36. figures/collage_1.jpg +3 -0
  37. figures/collage_2.jpg +0 -0
  38. figures/collage_3.jpg +3 -0
  39. figures/collage_4.jpg +0 -0
  40. figures/comparison-inference-speed.jpg +0 -0
  41. figures/comparison.png +0 -0
  42. figures/controlnet-canny.jpg +0 -0
  43. figures/controlnet-face.jpg +0 -0
  44. figures/controlnet-paint.jpg +0 -0
  45. figures/controlnet-sr.jpg +3 -0
  46. figures/fernando.jpg +0 -0
  47. figures/fernando_original.jpg +0 -0
  48. figures/image-to-image-example-rodent.jpg +0 -0
  49. figures/image-variations-example-headset.jpg +0 -0
  50. figures/model-overview.jpg +0 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/collage_1.jpg filter=lfs diff=lfs merge=lfs -text
37
+ figures/collage_3.jpg filter=lfs diff=lfs merge=lfs -text
38
+ figures/controlnet-sr.jpg filter=lfs diff=lfs merge=lfs -text
39
+ inference/controlnet.ipynb filter=lfs diff=lfs merge=lfs -text
40
+ inference/reconstruct_images.ipynb filter=lfs diff=lfs merge=lfs -text
41
+ inference/text_to_image.ipynb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.yml
2
+ *.out
3
+ dist_file_*
4
+ __pycache__/*
5
+ */__pycache__/*
6
+ */**/__pycache__/*
7
+ *_latest_output.jpg
8
+ *_sample.jpg
9
+ jobs/*.sh
10
+ .ipynb_checkpoints
11
+ *.safetensors
12
+ *_test.yaml
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Stability AI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Stable Cascade SR
3
- emoji: 🌖
4
- colorFrom: red
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.19.1
8
  app_file: app.py
 
1
  ---
2
+ title: Stable Cascade Upscale
3
+ emoji: 🏃
4
+ colorFrom: pink
5
+ colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.19.1
8
  app_file: app.py
WEIGHTS_LICENSE ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## THIS LICENSE IS FOR THE MODEL WEIGHTS ONLY
2
+
3
+ STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT
4
+ Dated: November 28, 2023
5
+
6
+ By using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to be bound by this Agreement.
7
+
8
+ "Agreement" means this Stable Non-Commercial Research Community License Agreement.
9
+
10
+ “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
11
+
12
+ "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
13
+
14
+ “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
15
+
16
+ "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
17
+
18
+ “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
19
+
20
+ “Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.
21
+
22
+ "Stability AI" or "we" means Stability AI Ltd. and its affiliates.
23
+
24
+
25
+ "Software" means Stability AI’s proprietary software made available under this Agreement.
26
+
27
+ “Software Products” means the Models, Software and Documentation, individually or in any combination.
28
+
29
+
30
+
31
+ 1. License Rights and Redistribution.
32
+ a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only.
33
+ b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.
34
+ c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
35
+ 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
36
+ 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
37
+ 4. Intellectual Property.
38
+ a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.
39
+ b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works
40
+ c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.
41
+ 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.
42
+
43
+ 6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law
44
+ principles.
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from main import Upscale_CaseCade
4
+ import spaces
5
+
6
+ upscale_class=Upscale_CaseCade()
7
+ # scale_factor=7
8
+ # url = "https://cdn.discordapp.com/attachments/1121232062708457508/1205110687538479145/A_photograph_of_a_sunflower_with_sunglasses_on_in__3.jpg?ex=65d72dc9&is=65c4b8c9&hm=72172e774ce6cda618503b3778b844de05cd1208b61e185d8418db512fb2858a&"
9
+ # image_pil=Image.open("/home/rnd/Documents/Ameer/StableCascade/poster.png").convert("RGB")
10
+ @spaces.GPU
11
+ def scale_image(image_pil,scale_factor):
12
+ og,ups=upscale_class.upscale_image(image_pil=image_pil.convert("RGB"),scale_fator=scale_factor)
13
+ return [ups]
14
+ DESCRIPTION = "# Stable Cascade -> Super Resolution"
15
+ DESCRIPTION += "\n<p style=\"text-align: center\">Unofficial demo for Cascade-Super Resolution <a href='https://huggingface.co/stabilityai/stable-cascade' target='_blank'>Stable Upscale Cascade</a>, a new high resolution image-to-image model by Stability AI, - <a href='https://huggingface.co/stabilityai/stable-cascade/blob/main/LICENSE' target='_blank'>non-commercial research license</a></p>"
16
+ # block = gr.Blocks(css="footer {visibility: hidden}", theme='freddyaboulton/dracula_revamped').queue()
17
+ block = gr.Blocks(css="footer {visibility: hidden}", theme='freddyaboulton/dark').queue()
18
+
19
+ with block:
20
+ with gr.Row():
21
+ gr.Markdown(DESCRIPTION)
22
+ with gr.Tabs():
23
+ with gr.Row():
24
+ with gr.Column():
25
+ image_pil = gr.Image(label="Describe the Image", type='pil')
26
+ scale_factor = gr.Slider(minimum=1,maximum=10,value=1, step=1, label="Scale Factor")
27
+ generate_button = gr.Button("Upscale Image")
28
+ with gr.Column():
29
+ generated_image = gr.Gallery(label="Generated Image",)
30
+
31
+ generate_button.click(fn=scale_image, inputs=[image_pil,scale_factor], outputs=[generated_image])
32
+
33
+ block.launch(show_api=False, server_port=8888, share=False, show_error=True, max_threads=1)
34
+
35
+ # pip install gradio==4.16.0 gradio_client==0.8.1
configs/inference/controlnet_c_3b_canny.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # ControlNet specific
6
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
7
+ controlnet_filter: CannyFilter
8
+ controlnet_filter_params:
9
+ resize: 224
10
+
11
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
12
+ previewer_checkpoint_path: models/previewer.safetensors
13
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
14
+ controlnet_checkpoint_path: models/canny.safetensors
configs/inference/controlnet_c_3b_identity.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # ControlNet specific
6
+ controlnet_bottleneck_mode: 'simple'
7
+ controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
8
+ controlnet_filter: IdentityFilter
9
+ controlnet_filter_params:
10
+ max_faces: 4
11
+ p_drop: 0.00
12
+ p_full: 0.0
13
+
14
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
15
+ previewer_checkpoint_path: models/previewer.safetensors
16
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
17
+ controlnet_checkpoint_path:
configs/inference/controlnet_c_3b_inpainting.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # ControlNet specific
6
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
7
+ controlnet_filter: InpaintFilter
8
+ controlnet_filter_params:
9
+ thresold: [0.04, 0.4]
10
+ p_outpaint: 0.4
11
+
12
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
13
+ previewer_checkpoint_path: models/previewer.safetensors
14
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
15
+ controlnet_checkpoint_path: models/inpainting.safetensors
configs/inference/controlnet_c_3b_sr.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # ControlNet specific
6
+ controlnet_bottleneck_mode: 'large'
7
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
8
+ controlnet_filter: SREffnetFilter
9
+ controlnet_filter_params:
10
+ scale_factor: 0.5
11
+
12
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
13
+ previewer_checkpoint_path: models/previewer.safetensors
14
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
15
+ controlnet_checkpoint_path: models/super_resolution.safetensors
configs/inference/lora_c_3b.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # LoRA specific
6
+ module_filters: ['.attn']
7
+ rank: 4
8
+ train_tokens:
9
+ # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
10
+ - ['[fernando]', '^dog</w>'] # custom token [fernando], initialized as the embedding of "dog"
11
+
12
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
13
+ previewer_checkpoint_path: models/previewer.safetensors
14
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
15
+ lora_checkpoint_path: models/lora_fernando_10k.safetensors
configs/inference/stage_b_3b.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3B
3
+ dtype: bfloat16
4
+
5
+ # For demonstration purposes in reconstruct_images.ipynb
6
+ webdataset_path: file:inference/imagenet_1024.tar
7
+ batch_size: 4
8
+ image_size: 1024
9
+ grad_accum_steps: 1
10
+
11
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
12
+ stage_a_checkpoint_path: models/stage_a.safetensors
13
+ generator_checkpoint_path: models/stage_b_bf16.safetensors
configs/inference/stage_c_3b.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
6
+ previewer_checkpoint_path: models/previewer.safetensors
7
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/controlnet_c_3b_canny.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_3b_controlnet_canny
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3.6B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 256
14
+ image_size: 768
15
+ # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 1
17
+ updates: 10000
18
+ backup_every: 2000
19
+ save_every: 1000
20
+ warmup_updates: 1
21
+ use_fsdp: True
22
+
23
+ # ControlNet specific
24
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
25
+ controlnet_filter: CannyFilter
26
+ controlnet_filter_params:
27
+ resize: 224
28
+ # offset_noise: 0.1
29
+
30
+ # CUSTOM CAPTIONS GETTER & FILTERS
31
+ captions_getter: ['txt', identity]
32
+ dataset_filters:
33
+ - ['width', 'lambda w: w >= 768']
34
+ - ['height', 'lambda h: h >= 768']
35
+
36
+ # ema_start_iters: 5000
37
+ # ema_iters: 100
38
+ # ema_beta: 0.9
39
+
40
+ webdataset_path:
41
+ - s3://path/to/your/first/dataset/on/s3
42
+ - s3://path/to/your/second/dataset/on/s3
43
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
44
+ previewer_checkpoint_path: models/previewer.safetensors
45
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/controlnet_c_3b_identity.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_3b_controlnet_identity
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3.6B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 256
14
+ image_size: 768
15
+ # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 1
17
+ updates: 200000
18
+ backup_every: 2000
19
+ save_every: 1000
20
+ warmup_updates: 1
21
+ use_fsdp: True
22
+
23
+ # ControlNet specific
24
+ controlnet_bottleneck_mode: 'simple'
25
+ controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
26
+ controlnet_filter: IdentityFilter
27
+ controlnet_filter_params:
28
+ max_faces: 4
29
+ p_drop: 0.05
30
+ p_full: 0.3
31
+ # offset_noise: 0.1
32
+
33
+ # CUSTOM CAPTIONS GETTER & FILTERS
34
+ captions_getter: ['txt', identity]
35
+ dataset_filters:
36
+ - ['width', 'lambda w: w >= 768']
37
+ - ['height', 'lambda h: h >= 768']
38
+
39
+ # ema_start_iters: 5000
40
+ # ema_iters: 100
41
+ # ema_beta: 0.9
42
+
43
+ webdataset_path:
44
+ - s3://path/to/your/first/dataset/on/s3
45
+ - s3://path/to/your/second/dataset/on/s3
46
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
47
+ previewer_checkpoint_path: models/previewer.safetensors
48
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/controlnet_c_3b_inpainting.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_3b_controlnet_inpainting
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3.6B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 256
14
+ image_size: 768
15
+ # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 1
17
+ updates: 10000
18
+ backup_every: 2000
19
+ save_every: 1000
20
+ warmup_updates: 1
21
+ use_fsdp: True
22
+
23
+ # ControlNet specific
24
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
25
+ controlnet_filter: InpaintFilter
26
+ controlnet_filter_params:
27
+ thresold: [0.04, 0.4]
28
+ p_outpaint: 0.4
29
+ offset_noise: 0.1
30
+
31
+ # CUSTOM CAPTIONS GETTER & FILTERS
32
+ captions_getter: ['txt', identity]
33
+ dataset_filters:
34
+ - ['width', 'lambda w: w >= 768']
35
+ - ['height', 'lambda h: h >= 768']
36
+
37
+ # ema_start_iters: 5000
38
+ # ema_iters: 100
39
+ # ema_beta: 0.9
40
+
41
+ webdataset_path:
42
+ - s3://path/to/your/first/dataset/on/s3
43
+ - s3://path/to/your/second/dataset/on/s3
44
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
45
+ previewer_checkpoint_path: models/previewer.safetensors
46
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/controlnet_c_3b_sr.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_3b_controlnet_sr
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3.6B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 256
14
+ image_size: 768
15
+ # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 1
17
+ updates: 30000
18
+ backup_every: 5000
19
+ save_every: 1000
20
+ warmup_updates: 1
21
+ use_fsdp: True
22
+
23
+ # ControlNet specific
24
+ controlnet_bottleneck_mode: 'large'
25
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
26
+ controlnet_filter: SREffnetFilter
27
+ controlnet_filter_params:
28
+ scale_factor: 0.5
29
+ offset_noise: 0.1
30
+
31
+ # CUSTOM CAPTIONS GETTER & FILTERS
32
+ captions_getter: ['txt', identity]
33
+ dataset_filters:
34
+ - ['width', 'lambda w: w >= 768']
35
+ - ['height', 'lambda h: h >= 768']
36
+
37
+ # ema_start_iters: 5000
38
+ # ema_iters: 100
39
+ # ema_beta: 0.9
40
+
41
+ webdataset_path:
42
+ - s3://path/to/your/first/dataset/on/s3
43
+ - s3://path/to/your/second/dataset/on/s3
44
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
45
+ previewer_checkpoint_path: models/previewer.safetensors
46
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/finetune_b_3b.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_b_3b_finetuning
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 256
14
+ image_size: 1024
15
+ # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ shift: 4
17
+ grad_accum_steps: 1
18
+ updates: 100000
19
+ backup_every: 20000
20
+ save_every: 1000
21
+ warmup_updates: 1
22
+ use_fsdp: True
23
+
24
+ # GDF
25
+ adaptive_loss_weight: True
26
+
27
+ # ema_start_iters: 5000
28
+ # ema_iters: 100
29
+ # ema_beta: 0.9
30
+
31
+ webdataset_path:
32
+ - s3://path/to/your/first/dataset/on/s3
33
+ - s3://path/to/your/second/dataset/on/s3
34
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
35
+ stage_a_checkpoint_path: models/stage_a.safetensors
36
+ generator_checkpoint_path: models/stage_b_bf16.safetensors
configs/training/finetune_b_700m.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_b_700m_finetuning
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 700M
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 512
14
+ image_size: 1024
15
+ # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ shift: 4
17
+ grad_accum_steps: 1
18
+ updates: 10000
19
+ backup_every: 20000
20
+ save_every: 2000
21
+ warmup_updates: 1
22
+ use_fsdp: True
23
+
24
+ # GDF
25
+ adaptive_loss_weight: True
26
+
27
+ # ema_start_iters: 5000
28
+ # ema_iters: 100
29
+ # ema_beta: 0.9
30
+
31
+ webdataset_path:
32
+ - s3://path/to/your/first/dataset/on/s3
33
+ - s3://path/to/your/second/dataset/on/s3
34
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
35
+ stage_a_checkpoint_path: models/stage_a.safetensors
36
+ generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
configs/training/finetune_c_1b.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_1b_finetuning
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 1B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 1024
14
+ image_size: 768
15
+ # multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 1
17
+ updates: 10000
18
+ backup_every: 20000
19
+ save_every: 2000
20
+ warmup_updates: 1
21
+ use_fsdp: True
22
+
23
+ # GDF
24
+ # adaptive_loss_weight: True
25
+
26
+ # ema_start_iters: 5000
27
+ # ema_iters: 100
28
+ # ema_beta: 0.9
29
+
30
+ webdataset_path:
31
+ - s3://path/to/your/first/dataset/on/s3
32
+ - s3://path/to/your/second/dataset/on/s3
33
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
34
+ previewer_checkpoint_path: models/previewer.safetensors
35
+ generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
configs/training/finetune_c_3b.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_3b_finetuning
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3.6B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 512
14
+ image_size: 768
15
+ multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 1
17
+ updates: 100000
18
+ backup_every: 20000
19
+ save_every: 2000
20
+ warmup_updates: 1
21
+ use_fsdp: True
22
+
23
+ # GDF
24
+ adaptive_loss_weight: True
25
+
26
+ # ema_start_iters: 5000
27
+ # ema_iters: 100
28
+ # ema_beta: 0.9
29
+
30
+ webdataset_path:
31
+ - s3://path/to/your/first/dataset/on/s3
32
+ - s3://path/to/your/second/dataset/on/s3
33
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
34
+ previewer_checkpoint_path: models/previewer.safetensors
35
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/finetune_c_3b_lora.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_3b_lora
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3.6B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 32
14
+ image_size: 768
15
+ multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 4
17
+ updates: 10000
18
+ backup_every: 1000
19
+ save_every: 100
20
+ warmup_updates: 1
21
+ # use_fsdp: True -> FSDP doesn't work at the moment for LoRA
22
+ use_fsdp: False
23
+
24
+ # GDF
25
+ # adaptive_loss_weight: True
26
+
27
+ # LoRA specific
28
+ module_filters: ['.attn']
29
+ rank: 4
30
+ train_tokens:
31
+ # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
32
+ - ['[fernando]', '^dog</w>'] # custom token [fernando], initialized as the embedding of "dog"
33
+
34
+
35
+ # ema_start_iters: 5000
36
+ # ema_iters: 100
37
+ # ema_beta: 0.9
38
+
39
+ webdataset_path:
40
+ - s3://path/to/your/first/dataset/on/s3
41
+ - s3://path/to/your/second/dataset/on/s3
42
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
43
+ previewer_checkpoint_path: models/previewer.safetensors
44
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/finetune_c_3b_lowres.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_3b_finetuning
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3.6B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 1024
14
+ image_size: 384
15
+ multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 1
17
+ updates: 100000
18
+ backup_every: 20000
19
+ save_every: 2000
20
+ warmup_updates: 1
21
+ use_fsdp: True
22
+
23
+ # GDF
24
+ adaptive_loss_weight: True
25
+
26
+ # CUSTOM CAPTIONS GETTER & FILTERS
27
+ # captions_getter: ['json', captions_getter]
28
+ # dataset_filters:
29
+ # - ['normalized_score', 'lambda s: s > 9.0']
30
+ # - ['pgen_normalized_score', 'lambda s: s > 3.0']
31
+
32
+ # ema_start_iters: 5000
33
+ # ema_iters: 100
34
+ # ema_beta: 0.9
35
+
36
+ webdataset_path:
37
+ - s3://path/to/your/first/dataset/on/s3
38
+ - s3://path/to/your/second/dataset/on/s3
39
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
40
+ previewer_checkpoint_path: models/previewer.safetensors
41
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/finetune_c_3b_v.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: stage_c_3b_finetuning
3
+ checkpoint_path: /path/to/checkpoint
4
+ output_path: /path/to/output
5
+ model_version: 3.6B
6
+
7
+ # WandB
8
+ wandb_project: StableCascade
9
+ wandb_entity: wandb_username
10
+
11
+ # TRAINING PARAMS
12
+ lr: 1.0e-4
13
+ batch_size: 512
14
+ image_size: 768
15
+ multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
16
+ grad_accum_steps: 1
17
+ updates: 100000
18
+ backup_every: 20000
19
+ save_every: 2000
20
+ warmup_updates: 1
21
+ use_fsdp: True
22
+
23
+ # GDF
24
+ adaptive_loss_weight: True
25
+ edm_objective: True
26
+
27
+ # ema_start_iters: 5000
28
+ # ema_iters: 100
29
+ # ema_beta: 0.9
30
+
31
+ webdataset_path:
32
+ - s3://path/to/your/first/dataset/on/s3
33
+ - s3://path/to/your/second/dataset/on/s3
34
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
35
+ previewer_checkpoint_path: models/previewer.safetensors
36
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
core/__init__.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from torch import nn
5
+ import wandb
6
+ import json
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from torch.utils.data import Dataset, DataLoader
10
+
11
+ from torch.distributed import init_process_group, destroy_process_group, barrier
12
+ from torch.distributed.fsdp import (
13
+ FullyShardedDataParallel as FSDP,
14
+ FullStateDictConfig,
15
+ MixedPrecision,
16
+ ShardingStrategy,
17
+ StateDictType
18
+ )
19
+
20
+ from .utils import Base, EXPECTED, EXPECTED_TRAIN
21
+ from .utils import create_folder_if_necessary, safe_save, load_or_fail
22
+
23
+ # pylint: disable=unused-argument
24
+ class WarpCore(ABC):
25
+ @dataclass(frozen=True)
26
+ class Config(Base):
27
+ experiment_id: str = EXPECTED_TRAIN
28
+ checkpoint_path: str = EXPECTED_TRAIN
29
+ output_path: str = EXPECTED_TRAIN
30
+ checkpoint_extension: str = "safetensors"
31
+ dist_file_subfolder: str = ""
32
+ allow_tf32: bool = True
33
+
34
+ wandb_project: str = None
35
+ wandb_entity: str = None
36
+
37
+ @dataclass() # not frozen, means that fields are mutable
38
+ class Info(): # not inheriting from Base, because we don't want to enforce the default fields
39
+ wandb_run_id: str = None
40
+ total_steps: int = 0
41
+ iter: int = 0
42
+
43
+ @dataclass(frozen=True)
44
+ class Data(Base):
45
+ dataset: Dataset = EXPECTED
46
+ dataloader: DataLoader = EXPECTED
47
+ iterator: any = EXPECTED
48
+
49
+ @dataclass(frozen=True)
50
+ class Models(Base):
51
+ pass
52
+
53
+ @dataclass(frozen=True)
54
+ class Optimizers(Base):
55
+ pass
56
+
57
+ @dataclass(frozen=True)
58
+ class Schedulers(Base):
59
+ pass
60
+
61
+ @dataclass(frozen=True)
62
+ class Extras(Base):
63
+ pass
64
+ # ---------------------------------------
65
+ info: Info
66
+ config: Config
67
+
68
+ # FSDP stuff
69
+ fsdp_defaults = {
70
+ "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
71
+ "cpu_offload": None,
72
+ "mixed_precision": MixedPrecision(
73
+ param_dtype=torch.bfloat16,
74
+ reduce_dtype=torch.bfloat16,
75
+ buffer_dtype=torch.bfloat16,
76
+ ),
77
+ "limit_all_gathers": True,
78
+ }
79
+ fsdp_fullstate_save_policy = FullStateDictConfig(
80
+ offload_to_cpu=True, rank0_only=True
81
+ )
82
+ # ------------
83
+
84
+ # OVERRIDEABLE METHODS
85
+
86
+ # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
87
+ def setup_extras_pre(self) -> Extras:
88
+ return self.Extras()
89
+
90
+ # setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator
91
+ @abstractmethod
92
+ def setup_data(self, extras: Extras) -> Data:
93
+ raise NotImplementedError("This method needs to be overriden")
94
+
95
+ # return a dict with all models that are going to be used in the training
96
+ @abstractmethod
97
+ def setup_models(self, extras: Extras) -> Models:
98
+ raise NotImplementedError("This method needs to be overriden")
99
+
100
+ # return a dict with all optimizers that are going to be used in the training
101
+ @abstractmethod
102
+ def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
103
+ raise NotImplementedError("This method needs to be overriden")
104
+
105
+ # [optionally] return a dict with all schedulers that are going to be used in the training
106
+ def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
107
+ return self.Schedulers()
108
+
109
+ # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
110
+ def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
111
+ return self.Extras.from_dict(extras.to_dict())
112
+
113
+ # perform the training here
114
+ @abstractmethod
115
+ def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
116
+ raise NotImplementedError("This method needs to be overriden")
117
+ # ------------
118
+
119
    def setup_info(self, full_path=None) -> Info:
        """Load (or initialize) the mutable training-state Info DTO.

        Defaults to `<checkpoint_path>/<experiment_id>/info.json`. A missing or
        unreadable file (load_or_fail returns None) yields a fresh Info with
        default fields, i.e. a run starting from step 0.
        """
        if full_path is None:
            full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json")
        # load_or_fail returns None on failure; fall back to an empty dict
        info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
        info_dto = self.Info(**info_dict)
        if info_dto.total_steps > 0 and self.is_main_node:
            print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
        return info_dto
127
+
128
+ def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
129
+ if config_file_path is not None:
130
+ if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
131
+ with open(config_file_path, "r", encoding="utf-8") as file:
132
+ loaded_config = yaml.safe_load(file)
133
+ elif config_file_path.endswith(".json"):
134
+ with open(config_file_path, "r", encoding="utf-8") as file:
135
+ loaded_config = json.load(file)
136
+ else:
137
+ raise ValueError("Config file must be either a .yml|.yaml or .json file")
138
+ return self.Config.from_dict({**loaded_config, 'training': training})
139
+ if config_dict is not None:
140
+ return self.Config.from_dict({**config_dict, 'training': training})
141
+ return self.Config(training=training)
142
+
143
    def setup_ddp(self, experiment_id, single_gpu=False):
        """Initialize torch.distributed (NCCL backend, file:// rendezvous).

        Sets self.process_id / is_main_node / device / world_size from the
        launcher environment. With single_gpu=True the single-process defaults
        from __init__ are kept and no process group is created.

        NOTE(review): assumes a SLURM launcher (SLURM_LOCALID / SLURM_PROCID /
        SLURM_NNODES) and that every node exposes the same GPU count — confirm
        before using with torchrun or other launchers.
        """
        if not single_gpu:
            local_rank = int(os.environ.get("SLURM_LOCALID"))
            process_id = int(os.environ.get("SLURM_PROCID"))
            world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()

            self.process_id = process_id
            self.is_main_node = process_id == 0
            self.device = torch.device(local_rank)
            self.world_size = world_size

            # Rendezvous file: all ranks must resolve the same path
            dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
            # if os.path.exists(dist_file_path) and self.is_main_node:
            #     os.remove(dist_file_path)

            torch.cuda.set_device(local_rank)
            init_process_group(
                backend="nccl",
                rank=process_id,
                world_size=world_size,
                init_method=f"file://{dist_file_path}",
            )
            print(f"[GPU {process_id}] READY")
        else:
            print("Running in single thread, DDP not enabled.")
168
+
169
    def setup_wandb(self):
        """Initialize (or resume) the W&B run on the main node; no-op elsewhere.

        The run id is persisted in self.info so a resumed training reattaches
        to the same W&B run; an alert is sent on start/resume.
        """
        if self.is_main_node and self.config.wandb_project is not None:
            # Reuse the stored run id so resumed trainings keep the same W&B run
            self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
            wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())

            if self.info.total_steps > 0:
                wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
            else:
                wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")
178
+
179
+ # LOAD UTILITIES ----------
180
+ def load_model(self, model, model_id=None, full_path=None, strict=True):
181
+ if model_id is not None and full_path is None:
182
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
183
+ elif full_path is None and model_id is None:
184
+ raise ValueError(
185
+ "This method expects either 'model_id' or 'full_path' to be defined"
186
+ )
187
+
188
+ checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
189
+ if checkpoint is not None:
190
+ model.load_state_dict(checkpoint, strict=strict)
191
+ del checkpoint
192
+
193
+ return model
194
+
195
    def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
        """Load an optimizer state, optionally re-sharding it for FSDP.

        With `fsdp_model` set, the full (rank-0) optimizer state is scattered
        across ranks via FSDP.scatter_full_optim_state_dict. Loading failures
        are logged and skipped so training can proceed with a fresh optimizer.

        Raises:
            ValueError: if neither `optim_id` nor `full_path` is given.
        """
        if optim_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
        elif full_path is None and optim_id is None:
            raise ValueError(
                "This method expects either 'optim_id' or 'full_path' to be defined"
            )

        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
        if checkpoint is not None:
            try:
                if fsdp_model is not None:
                    # Only rank 0 (or every rank under NO_SHARD) feeds the full
                    # state dict; other ranks pass None and receive their shard.
                    sharded_optimizer_state_dict = (
                        FSDP.scatter_full_optim_state_dict(  # <---- FSDP
                            checkpoint
                            if (
                                self.is_main_node
                                or self.fsdp_defaults["sharding_strategy"]
                                == ShardingStrategy.NO_SHARD
                            )
                            else None,
                            fsdp_model,
                        )
                    )
                    optim.load_state_dict(sharded_optimizer_state_dict)
                    del checkpoint, sharded_optimizer_state_dict
                else:
                    optim.load_state_dict(checkpoint)
            # pylint: disable=broad-except
            except Exception as e:
                # Deliberate best-effort: a stale/incompatible optimizer state
                # should not abort the run.
                print("!!! Failed loading optimizer, skipping... Exception:", e)

        return optim
228
+
229
+ # SAVE UTILITIES ----------
230
    def save_info(self, info, suffix=""):
        """Persist the Info DTO as JSON next to the checkpoints (main node only).

        NOTE(review): the `info` argument is ignored — vars(self.info) is saved;
        confirm whether callers rely on passing a different Info.
        """
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
        create_folder_if_necessary(full_path)
        if self.is_main_node:
            safe_save(vars(self.info), full_path)
235
+
236
    def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
        """Save a model state dict (rank 0 writes; FSDP gathers a full state).

        Raises:
            ValueError: if neither `model_id` nor `full_path` is given.
        """
        if model_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
        elif full_path is None and model_id is None:
            raise ValueError(
                "This method expects either 'model_id' or 'full_path' to be defined"
            )
        create_folder_if_necessary(full_path)
        if is_fsdp:
            # NOTE(review): summon_full_params with an empty body appears to be
            # used only for its collective side effect before gathering the
            # state dict — confirm intent.
            with FSDP.summon_full_params(model):
                pass
            # Gather a full, CPU-offloaded state dict on rank 0 only
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
            ):
                checkpoint = model.state_dict()
                if self.is_main_node:
                    safe_save(checkpoint, full_path)
                del checkpoint
        else:
            if self.is_main_node:
                checkpoint = model.state_dict()
                safe_save(checkpoint, full_path)
                del checkpoint
259
+
260
    def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
        """Save an optimizer state (rank 0 writes; FSDP consolidates first).

        Raises:
            ValueError: if neither `optim_id` nor `full_path` is given.
        """
        if optim_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
        elif full_path is None and optim_id is None:
            raise ValueError(
                "This method expects either 'optim_id' or 'full_path' to be defined"
            )
        create_folder_if_necessary(full_path)
        if fsdp_model is not None:
            # Consolidate the sharded optimizer state into a full state dict
            optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
            if self.is_main_node:
                safe_save(optim_statedict, full_path)
            del optim_statedict
        else:
            if self.is_main_node:
                checkpoint = optim.state_dict()
                safe_save(checkpoint, full_path)
                del checkpoint
    # -----
278
+ # -----
279
+
280
    def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
        """Create the core: resolve config and training info; DDP comes later.

        Args:
            config_file_path: optional YAML/JSON config file.
            config_dict: optional config dict (used when no file is given).
            device: initial device; overridden by setup_ddp() in __call__.
            training: whether EXPECTED_TRAIN config fields are mandatory.
        """
        # Temporary setup, will be overriden by setup_ddp if required
        self.device = device
        self.process_id = 0
        self.is_main_node = True
        self.world_size = 1
        # ----

        self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
        self.info: self.Info = self.setup_info()
290
+
291
+ def __call__(self, single_gpu=False):
292
+ self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank
293
+ self.setup_wandb()
294
+ if self.config.allow_tf32:
295
+ torch.backends.cuda.matmul.allow_tf32 = True
296
+ torch.backends.cudnn.allow_tf32 = True
297
+
298
+ if self.is_main_node:
299
+ print()
300
+ print("**STARTIG JOB WITH CONFIG:**")
301
+ print(yaml.dump(self.config.to_dict(), default_flow_style=False))
302
+ print("------------------------------------")
303
+ print()
304
+ print("**INFO:**")
305
+ print(yaml.dump(vars(self.info), default_flow_style=False))
306
+ print("------------------------------------")
307
+ print()
308
+
309
+ # SETUP STUFF
310
+ extras = self.setup_extras_pre()
311
+ assert extras is not None, "setup_extras_pre() must return a DTO"
312
+
313
+ data = self.setup_data(extras)
314
+ assert data is not None, "setup_data() must return a DTO"
315
+ if self.is_main_node:
316
+ print("**DATA:**")
317
+ print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
318
+ print("------------------------------------")
319
+ print()
320
+
321
+ models = self.setup_models(extras)
322
+ assert models is not None, "setup_models() must return a DTO"
323
+ if self.is_main_node:
324
+ print("**MODELS:**")
325
+ print(yaml.dump({
326
+ k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
327
+ }, default_flow_style=False))
328
+ print("------------------------------------")
329
+ print()
330
+
331
+ optimizers = self.setup_optimizers(extras, models)
332
+ assert optimizers is not None, "setup_optimizers() must return a DTO"
333
+ if self.is_main_node:
334
+ print("**OPTIMIZERS:**")
335
+ print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
336
+ print("------------------------------------")
337
+ print()
338
+
339
+ schedulers = self.setup_schedulers(extras, models, optimizers)
340
+ assert schedulers is not None, "setup_schedulers() must return a DTO"
341
+ if self.is_main_node:
342
+ print("**SCHEDULERS:**")
343
+ print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
344
+ print("------------------------------------")
345
+ print()
346
+
347
+ post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
348
+ assert post_extras is not None, "setup_extras_post() must return a DTO"
349
+ extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
350
+ if self.is_main_node:
351
+ print("**EXTRAS:**")
352
+ print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
353
+ print("------------------------------------")
354
+ print()
355
+ # -------
356
+
357
+ # TRAIN
358
+ if self.is_main_node:
359
+ print("**TRAINING STARTING...**")
360
+ self.train(data, extras, models, optimizers, schedulers)
361
+
362
+ if single_gpu is False:
363
+ barrier()
364
+ destroy_process_group()
365
+ if self.is_main_node:
366
+ print()
367
+ print("------------------------------------")
368
+ print()
369
+ print("**TRAINING COMPLETE**")
370
+ if self.config.wandb_project is not None:
371
+ wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
core/data/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import subprocess
3
+ import yaml
4
+ import os
5
+ from .bucketeer import Bucketeer
6
+
7
class MultiFilter():
    """Callable sample filter driven by rules over a sample's 'json' metadata.

    `rules` maps a json key (or a tuple of keys) to a predicate over the
    corresponding value(s); the sample passes only if every predicate is truthy.
    `default` is returned for samples that cannot be evaluated (missing keys,
    undecodable json, predicate errors).
    """
    def __init__(self, rules, default=False):
        self.rules = rules
        self.default = default

    def __call__(self, x):
        try:
            x_json = x['json']
            if isinstance(x_json, bytes):
                x_json = json.loads(x_json)
            validations = []
            for k, r in self.rules.items():
                if isinstance(k, tuple):
                    v = r(*[x_json[kv] for kv in k])
                else:
                    v = r(x_json[k])
                validations.append(v)
            return all(validations)
        except Exception:
            # BUGFIX: previously returned a hard-coded False, silently ignoring
            # the configurable `default` fallback for unevaluable samples.
            return self.default
27
+
28
class MultiGetter():
    """Callable that derives values from json metadata via rules.

    `rules` maps a key (or a tuple of keys) to a function of the corresponding
    value(s). Returns a single value when there is exactly one rule, otherwise
    a list in rule order.
    """
    def __init__(self, rules):
        self.rules = rules

    def __call__(self, x_json):
        if isinstance(x_json, bytes):
            x_json = json.loads(x_json)
        results = [
            rule(*(x_json[key] for key in keys)) if isinstance(keys, tuple) else rule(x_json[keys])
            for keys, rule in self.rules.items()
        ]
        return results[0] if len(results) == 1 else results
45
+
46
def setup_webdataset_path(paths, cache_path=None):
    """Resolve dataset path(s) into a webdataset "pipe:aws s3 cp ..." command.

    Args:
        paths: str or list of str; each entry is either a direct .tar path or
            an s3 prefix that is recursively listed for .tar shards.
        cache_path: optional YAML file caching the resolved tar list; read if
            it exists, written after listing when provided.

    Returns:
        A command string of the form "pipe:aws s3 cp { p1,p2,... } -".
    """
    if cache_path is not None and os.path.exists(cache_path):
        import yaml  # local import: only needed for the cache file
        with open(cache_path, 'r', encoding='utf-8') as file:
            tar_paths = yaml.safe_load(file)
    else:
        tar_paths = []
        if isinstance(paths, str):
            paths = [paths]
        for path in paths:
            if path.strip().endswith(".tar"):
                # Avoid looking up s3 if we already have a tar file
                tar_paths.append(path)
                continue
            # bucket = "s3://<bucket-name>" (first three '/'-separated parts)
            bucket = "/".join(path.split("/")[:3])
            result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
            files = result.stdout.decode('utf-8').split()
            files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
            tar_paths += files

        # BUGFIX: only write the cache when a cache path was provided;
        # previously open(None, 'w') raised TypeError for cache_path=None.
        if cache_path is not None:
            import yaml
            with open(cache_path, 'w', encoding='utf-8') as outfile:
                yaml.dump(tar_paths, outfile, default_flow_style=False)

    tar_paths_str = ",".join([f"{p}" for p in tar_paths])
    return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
core/data/bucketeer.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+ from torchtools.transforms import SmartCrop
5
+ import math
6
+
7
class Bucketeer():
    """Aspect-ratio bucketing wrapper around a dataloader.

    Groups samples by their closest target aspect ratio at (roughly) constant
    pixel `density`, resizes + crops each image to the bucket's size, and
    yields homogeneous batches via __next__.

    NOTE(review): assumes each element from the wrapped dataloader is a dict
    with an 'images' tensor laid out (..., H, W) — confirm against the caller.
    """
    def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
        assert crop_mode in ['center', 'random', 'smart']
        self.crop_mode = crop_mode
        self.ratios = ratios
        if reverse_list:
            # Add the inverse of each ratio (portrait <-> landscape); note this
            # appends to the same list object, so `ratios` below includes them.
            for r in list(ratios):
                if 1/r not in self.ratios:
                    self.ratios.append(1/r)
        # Bucket sizes: (h, w) rounded down to multiples of `factor`, chosen so
        # h*w stays close to `density` for each ratio.
        self.sizes = [(int(((density/r)**0.5//factor)*factor), int(((density*r)**0.5//factor)*factor)) for r in ratios]
        self.batch_size = dataloader.batch_size
        self.iterator = iter(dataloader)
        self.buckets = {s: [] for s in self.sizes}
        self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
        self.p_random_ratio = p_random_ratio
        self.interpolate_nearest = interpolate_nearest

    def get_available_batch(self):
        """Pop and return the first bucket holding a full batch, else None."""
        for b in self.buckets:
            if len(self.buckets[b]) >= self.batch_size:
                batch = self.buckets[b][:self.batch_size]
                self.buckets[b] = self.buckets[b][self.batch_size:]
                return batch
        return None

    def get_closest_size(self, x):
        """Return the bucket size whose aspect ratio best matches image `x`.

        With probability `p_random_ratio`, a random bucket is picked instead
        (data augmentation / bucket balancing).
        """
        if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
            best_size_idx = np.random.randint(len(self.ratios))
        else:
            w, h = x.size(-1), x.size(-2)
            best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
        return self.sizes[best_size_idx]

    def get_resize_size(self, orig_size, tgt_size):
        """Compute the shorter-edge resize length so the image covers tgt_size.

        The first branch handles the case where original and target have the
        same orientation (both portrait or both landscape); the second handles
        mismatched orientations.
        """
        if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
            alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
            resize_size = max(alt_min, min(tgt_size))
        else:
            alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
            resize_size = max(alt_max, max(tgt_size))
        return resize_size

    def __next__(self):
        """Fill buckets from the wrapped iterator until one batch is complete,
        then return it as a dict of stacked tensors (non-tensors stay lists).
        """
        batch = self.get_available_batch()
        while batch is None:
            elements = next(self.iterator)
            for dct in elements:
                img = dct['images']
                size = self.get_closest_size(img)
                resize_size = self.get_resize_size(img.shape[-2:], size)
                if self.interpolate_nearest:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
                else:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
                if self.crop_mode == 'center':
                    img = torchvision.transforms.functional.center_crop(img, size)
                elif self.crop_mode == 'random':
                    img = torchvision.transforms.RandomCrop(size)(img)
                elif self.crop_mode == 'smart':
                    self.smartcrop.output_size = size
                    img = self.smartcrop(img)
                self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
            batch = self.get_available_batch()

        out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
core/scripts/__init__.py ADDED
File without changes
core/scripts/cli.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ from .. import WarpCore
4
+ from .. import templates
5
+
6
+
7
def template_init(args):
    """Return the source-code template for a new project (currently empty).

    `args` is accepted for interface symmetry with the other CLI handlers but
    is not used.
    """
    # BUGFIX: the literal previously opened with four quotes (''''), which put
    # a stray apostrophe at the start of the emitted template.
    return '''


'''.strip()
12
+
13
+
14
def init_template(args):
    """Handle the `init` subcommand: resolve the template class and print it.

    Resolution order: the built-in WarpCore, then an importable module of that
    name, then an attribute of the bundled `templates` package.
    """
    parser = argparse.ArgumentParser(description='WarpCore template init tool')
    parser.add_argument('-t', '--template', type=str, default='WarpCore')
    parsed = parser.parse_args(args)

    if parsed.template == 'WarpCore':
        template_cls = WarpCore
    else:
        try:
            template_cls = __import__(parsed.template)
        except ModuleNotFoundError:
            template_cls = getattr(templates, parsed.template)
    print(template_cls)
27
+
28
+
29
def main():
    """CLI entry point: dispatch `core <command>`; exit 1 on bad usage."""
    argv = sys.argv
    if len(argv) < 2:
        print('Usage: core <command>')
        sys.exit(1)
    command = argv[1]
    if command == 'init':
        init_template(argv[2:])
    else:
        print('Unknown command')
        sys.exit(1)
38
+
39
+
40
+ if __name__ == '__main__':
41
+ main()
core/templates/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .diffusion import DiffusionCore
core/templates/diffusion.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .. import WarpCore
2
+ from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
3
+ from abc import abstractmethod
4
+ from dataclasses import dataclass
5
+ import torch
6
+ from torch import nn
7
+ from torch.utils.data import DataLoader
8
+ from gdf import GDF
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ import wandb
12
+
13
+ import webdataset as wds
14
+ from webdataset.handlers import warn_and_continue
15
+ from torch.distributed import barrier
16
+ from enum import Enum
17
+
18
class TargetReparametrization(Enum):
    """Which quantity the diffusion loss target is reparametrized to."""
    EPSILON = 'epsilon'  # loss against the added noise
    X0 = 'x0'            # loss against the clean latents
21
+
22
+ class DiffusionCore(WarpCore):
23
+ @dataclass(frozen=True)
24
+ class Config(WarpCore.Config):
25
+ # TRAINING PARAMS
26
+ lr: float = EXPECTED_TRAIN
27
+ grad_accum_steps: int = EXPECTED_TRAIN
28
+ batch_size: int = EXPECTED_TRAIN
29
+ updates: int = EXPECTED_TRAIN
30
+ warmup_updates: int = EXPECTED_TRAIN
31
+ save_every: int = 500
32
+ backup_every: int = 20000
33
+ use_fsdp: bool = True
34
+
35
+ # EMA UPDATE
36
+ ema_start_iters: int = None
37
+ ema_iters: int = None
38
+ ema_beta: float = None
39
+
40
+ # GDF setting
41
+ gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
42
+
43
+ @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
44
+ class Info(WarpCore.Info):
45
+ ema_loss: float = None
46
+
47
+ @dataclass(frozen=True)
48
+ class Models(WarpCore.Models):
49
+ generator : nn.Module = EXPECTED
50
+ generator_ema : nn.Module = None # optional
51
+
52
+ @dataclass(frozen=True)
53
+ class Optimizers(WarpCore.Optimizers):
54
+ generator : any = EXPECTED
55
+
56
+ @dataclass(frozen=True)
57
+ class Schedulers(WarpCore.Schedulers):
58
+ generator: any = None
59
+
60
+ @dataclass(frozen=True)
61
+ class Extras(WarpCore.Extras):
62
+ gdf: GDF = EXPECTED
63
+ sampling_configs: dict = EXPECTED
64
+
65
+ # --------------------------------------------
66
+ info: Info
67
+ config: Config
68
+
69
+ @abstractmethod
70
+ def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
71
+ raise NotImplementedError("This method needs to be overriden")
72
+
73
+ @abstractmethod
74
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
75
+ raise NotImplementedError("This method needs to be overriden")
76
+
77
+ @abstractmethod
78
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
79
+ raise NotImplementedError("This method needs to be overriden")
80
+
81
+ @abstractmethod
82
+ def webdataset_path(self, extras: Extras):
83
+ raise NotImplementedError("This method needs to be overriden")
84
+
85
+ @abstractmethod
86
+ def webdataset_filters(self, extras: Extras):
87
+ raise NotImplementedError("This method needs to be overriden")
88
+
89
+ @abstractmethod
90
+ def webdataset_preprocessors(self, extras: Extras):
91
+ raise NotImplementedError("This method needs to be overriden")
92
+
93
+ @abstractmethod
94
+ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
95
+ raise NotImplementedError("This method needs to be overriden")
96
+ # -------------
97
+
98
    def setup_data(self, extras: Extras) -> WarpCore.Data:
        """Build the webdataset pipeline + dataloader from the subclass hooks.

        Each preprocessor is expected to be a (source_key, transform,
        output_key) triple. The per-process batch size is the global batch
        size divided by world_size * grad_accum_steps.
        """
        # SETUP DATASET
        dataset_path = self.webdataset_path(extras)
        preprocessors = self.webdataset_preprocessors(extras)
        filters = self.webdataset_filters(extras)

        # warn_and_continue: skip corrupt samples instead of aborting the epoch
        handler = warn_and_continue # None
        # handler = None
        dataset = wds.WebDataset(
            dataset_path, resampled=True, handler=handler
        ).select(filters).shuffle(690, handler=handler).decode(
            "pilrgb", handler=handler
        ).to_tuple(
            *[p[0] for p in preprocessors], handler=handler
        ).map_tuple(
            *[p[1] for p in preprocessors], handler=handler
        ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})

        # SETUP DATALOADER
        real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
        )

        return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
123
+
124
    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
        """One forward pass: fetch a batch, diffuse latents, predict, weight the MSE.

        Returns:
            (loss, loss_adjusted): per-sample losses, and the scalar loss
            scaled by the GDF loss weight and divided by grad_accum_steps.
        """
        batch = next(data.iterator)

        # Conditioning and latent encoding are not trained here
        with torch.no_grad():
            conditions = self.get_conditions(batch, models, extras)
            latents = self.encode_latents(batch, models, extras)
            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)

        # FORWARD PASS
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            pred = models.generator(noised, noise_cond, **conditions)
            if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
                target = noise
            elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
                target = latents
            # NOTE(review): mean over dims [1, 2, 3] assumes 4D (B, C, H, W)
            # latents — confirm against encode_latents implementations.
            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps

        return loss, loss_adjusted
145
+
146
+ def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
147
+ start_iter = self.info.iter+1
148
+ max_iters = self.config.updates * self.config.grad_accum_steps
149
+ if self.is_main_node:
150
+ print(f"STARTING AT STEP: {start_iter}/{max_iters}")
151
+
152
+ pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
153
+ models.generator.train()
154
+ for i in pbar:
155
+ # FORWARD PASS
156
+ loss, loss_adjusted = self.forward_pass(data, extras, models)
157
+
158
+ # BACKWARD PASS
159
+ if i % self.config.grad_accum_steps == 0 or i == max_iters:
160
+ loss_adjusted.backward()
161
+ grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
162
+ optimizers_dict = optimizers.to_dict()
163
+ for k in optimizers_dict:
164
+ optimizers_dict[k].step()
165
+ schedulers_dict = schedulers.to_dict()
166
+ for k in schedulers_dict:
167
+ schedulers_dict[k].step()
168
+ models.generator.zero_grad(set_to_none=True)
169
+ self.info.total_steps += 1
170
+ else:
171
+ with models.generator.no_sync():
172
+ loss_adjusted.backward()
173
+ self.info.iter = i
174
+
175
+ # UPDATE EMA
176
+ if models.generator_ema is not None and i % self.config.ema_iters == 0:
177
+ update_weights_ema(
178
+ models.generator_ema, models.generator,
179
+ beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
180
+ )
181
+
182
+ # UPDATE LOSS METRICS
183
+ self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
184
+
185
+ if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
186
+ wandb.alert(
187
+ title=f"NaN value encountered in training run {self.info.wandb_run_id}",
188
+ text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
189
+ wait_duration=60*30
190
+ )
191
+
192
+ if self.is_main_node:
193
+ logs = {
194
+ 'loss': self.info.ema_loss,
195
+ 'raw_loss': loss.mean().item(),
196
+ 'grad_norm': grad_norm.item(),
197
+ 'lr': optimizers.generator.param_groups[0]['lr'],
198
+ 'total_steps': self.info.total_steps,
199
+ }
200
+
201
+ pbar.set_postfix(logs)
202
+ if self.config.wandb_project is not None:
203
+ wandb.log(logs)
204
+
205
+ if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
206
+ # SAVE AND CHECKPOINT STUFF
207
+ if np.isnan(loss.mean().item()):
208
+ if self.is_main_node and self.config.wandb_project is not None:
209
+ tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
210
+ wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
211
+ else:
212
+ self.save_checkpoints(models, optimizers)
213
+ if self.is_main_node:
214
+ create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
215
+ self.sample(models, data, extras)
216
+
217
+ def models_to_save(self):
218
+ return ['generator', 'generator_ema']
219
+
220
+ def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
221
+ barrier()
222
+ suffix = '' if suffix is None else suffix
223
+ self.save_info(self.info, suffix=suffix)
224
+ models_dict = models.to_dict()
225
+ optimizers_dict = optimizers.to_dict()
226
+ for key in self.models_to_save():
227
+ model = models_dict[key]
228
+ if model is not None:
229
+ self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
230
+ for key in optimizers_dict:
231
+ optimizer = optimizers_dict[key]
232
+ if optimizer is not None:
233
+ self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
234
+ if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
235
+ self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
236
+ torch.cuda.empty_cache()
core/utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
2
+ from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
3
+
4
+ # TODO: move this function somewhere more appropriate
5
def update_weights_ema(tgt_model, src_model, beta=0.999):
    """In-place EMA update: tgt <- beta * tgt + (1 - beta) * src.

    Blends both parameters and buffers; each source tensor is copied to the
    corresponding target tensor's device before mixing.
    """
    blend = 1 - beta
    for tgt, src in zip(tgt_model.parameters(), src_model.parameters()):
        tgt.data = tgt.data * beta + src.data.clone().to(tgt.device) * blend
    for tgt, src in zip(tgt_model.buffers(), src_model.buffers()):
        tgt.data = tgt.data * beta + src.data.clone().to(tgt.device) * blend
core/utils/base_dto.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from dataclasses import dataclass, _MISSING_TYPE
3
+ from munch import Munch
4
+
5
+ EXPECTED = "___REQUIRED___"
6
+ EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
7
+
8
+ # pylint: disable=invalid-field-call
9
def nested_dto(x, raw=False):
    """Wrap a default value for a dataclass field holding a nested config.

    Returns a dataclasses field whose default_factory yields `x` unchanged
    when `raw` is True, otherwise converted to a Munch (attribute-style dict).
    """
    def _factory():
        return x if raw else Munch.fromDict(x)
    return dataclasses.field(default_factory=_factory)
11
+
12
+ @dataclass(frozen=True)
13
+ class Base:
14
+ training: bool = None
15
+ def __new__(cls, **kwargs):
16
+ training = kwargs.get('training', True)
17
+ setteable_fields = cls.setteable_fields(**kwargs)
18
+ mandatory_fields = cls.mandatory_fields(**kwargs)
19
+ invalid_kwargs = [
20
+ {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
21
+ ]
22
+ print(mandatory_fields)
23
+ assert (
24
+ len(invalid_kwargs) == 0
25
+ ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
26
+ missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
27
+ assert (
28
+ len(missing_kwargs) == 0
29
+ ), f"Required fields missing initializing this DTO: {missing_kwargs}."
30
+ return object.__new__(cls)
31
+
32
+
33
+ @classmethod
34
+ def setteable_fields(cls, **kwargs):
35
+ return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]
36
+
37
+ @classmethod
38
+ def mandatory_fields(cls, **kwargs):
39
+ training = kwargs.get('training', True)
40
+ return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]
41
+
42
+ @classmethod
43
+ def from_dict(cls, kwargs):
44
+ for k in kwargs:
45
+ if isinstance(kwargs[k], (dict, list, tuple)):
46
+ kwargs[k] = Munch.fromDict(kwargs[k])
47
+ return cls(**kwargs)
48
+
49
+ def to_dict(self):
50
+ # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
51
+ selfdict = {}
52
+ for k in dataclasses.fields(self):
53
+ selfdict[k.name] = getattr(self, k.name)
54
+ if isinstance(selfdict[k.name], Munch):
55
+ selfdict[k.name] = selfdict[k.name].toDict()
56
+ return selfdict
core/utils/save_and_load.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+ from pathlib import Path
5
+ import safetensors
6
+ import wandb
7
+
8
+
9
def create_folder_if_necessary(path):
    """Ensure the directory containing `path` exists.

    The last path component is treated as a file name and dropped (so a
    trailing "/" makes the full path itself be created, matching how callers
    pass directory paths). Creation is recursive and idempotent.
    Uses os.path.dirname instead of manual "/"-splitting so OS-native
    separators are handled correctly as well.
    """
    Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True)
12
+
13
+
14
def safe_save(ckpt, path):
    """Write a checkpoint, keeping the previous file at `<path>.bak`.

    The old backup is deleted and the current file (if present) is rotated to
    `.bak` before writing, so a crash mid-write still leaves a usable copy.
    The serializer is chosen from the file extension (.pt/.ckpt, .json or
    .safetensors); anything else raises ValueError.
    """
    backup = f"{path}.bak"
    try:
        os.remove(backup)
    except OSError:
        pass  # no previous backup
    try:
        os.rename(path, backup)
    except OSError:
        pass  # no previous checkpoint to rotate
    if path.endswith((".pt", ".ckpt")):
        torch.save(ckpt, path)
    elif path.endswith(".json"):
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(ckpt, handle, indent=4)
    elif path.endswith(".safetensors"):
        # NOTE(review): relies on `safetensors.torch` being importable as an
        # attribute of `safetensors` — confirm the submodule is loaded.
        safetensors.torch.save_file(ckpt, path)
    else:
        raise ValueError(f"File extension not supported: {path}")
32
+
33
+
34
def load_or_fail(path, wandb_run_id=None):
    """Load a checkpoint (.pt/.ckpt, .json or .safetensors) from `path`.

    Returns None when the file does not exist. Any failure is re-raised,
    optionally after sending a wandb alert when `wandb_run_id` is given.
    """
    accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
    try:
        assert any(
            path.endswith(ext) for ext in accepted_extensions
        ), f"Automatic loading not supported for this extension: {path}"
        if not os.path.exists(path):
            checkpoint = None
        elif path.endswith((".pt", ".ckpt")):
            checkpoint = torch.load(path, map_location="cpu")
        elif path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as handle:
                checkpoint = json.load(handle)
        else:
            # .safetensors — guaranteed by the extension assert above
            checkpoint = {}
            with safetensors.safe_open(path, framework="pt", device="cpu") as handle:
                for key in handle.keys():
                    checkpoint[key] = handle.get_tensor(key)
        return checkpoint
    except Exception as e:
        if wandb_run_id is not None:
            wandb.alert(
                title=f"Corrupt checkpoint for run {wandb_run_id}",
                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
            )
        raise e
figures/collage_1.jpg ADDED

Git LFS Details

  • SHA256: ec5fbc465bd5fa24755689283aca45478ce546a20af8ebcc068962b72a341e0b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
figures/collage_2.jpg ADDED
figures/collage_3.jpg ADDED

Git LFS Details

  • SHA256: 6ad3b1481eb89e4f73dbfdb83589509048e4356d14f900b5351195057736bb32
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
figures/collage_4.jpg ADDED
figures/comparison-inference-speed.jpg ADDED
figures/comparison.png ADDED
figures/controlnet-canny.jpg ADDED
figures/controlnet-face.jpg ADDED
figures/controlnet-paint.jpg ADDED
figures/controlnet-sr.jpg ADDED

Git LFS Details

  • SHA256: f3e8060eebe3a26d7ee49cf553a5892180889868a85257511588de7e94937ee1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
figures/fernando.jpg ADDED
figures/fernando_original.jpg ADDED
figures/image-to-image-example-rodent.jpg ADDED
figures/image-variations-example-headset.jpg ADDED
figures/model-overview.jpg ADDED