Thuchk commited on
Commit
98a48b9
1 Parent(s): f87ebcf

deploy to huggingface spaces

Browse files
Files changed (33) hide show
  1. README.md +125 -13
  2. env.yaml +4 -0
  3. gradio_app.py +39 -0
  4. infer.py +9 -0
  5. metrics.py +29 -0
  6. requirements.txt +10 -0
  7. sd-pokemon-model/feature_extractor/preprocessor_config.json +28 -0
  8. sd-pokemon-model/logs/text2image-fine-tune/1683946307.9457252/events.out.tfevents.1683946307.haca1003.18301.1 +3 -0
  9. sd-pokemon-model/logs/text2image-fine-tune/1683946307.9529555/hparams.yml +49 -0
  10. sd-pokemon-model/logs/text2image-fine-tune/1683946765.030428/events.out.tfevents.1683946765.haca1003.20427.1 +3 -0
  11. sd-pokemon-model/logs/text2image-fine-tune/1683946765.0349936/hparams.yml +49 -0
  12. sd-pokemon-model/logs/text2image-fine-tune/1683947054.8602302/events.out.tfevents.1683947054.haca1003.32635.1 +3 -0
  13. sd-pokemon-model/logs/text2image-fine-tune/1683947054.8646092/hparams.yml +49 -0
  14. sd-pokemon-model/logs/text2image-fine-tune/events.out.tfevents.1683946307.haca1003.18301.0 +3 -0
  15. sd-pokemon-model/logs/text2image-fine-tune/events.out.tfevents.1683946765.haca1003.20427.0 +3 -0
  16. sd-pokemon-model/logs/text2image-fine-tune/events.out.tfevents.1683947054.haca1003.32635.0 +3 -0
  17. sd-pokemon-model/model_index.json +33 -0
  18. sd-pokemon-model/safety_checker/config.json +168 -0
  19. sd-pokemon-model/safety_checker/pytorch_model.bin +3 -0
  20. sd-pokemon-model/scheduler/scheduler_config.json +14 -0
  21. sd-pokemon-model/text_encoder/config.json +25 -0
  22. sd-pokemon-model/text_encoder/pytorch_model.bin +3 -0
  23. sd-pokemon-model/tokenizer/merges.txt +0 -0
  24. sd-pokemon-model/tokenizer/special_tokens_map.json +24 -0
  25. sd-pokemon-model/tokenizer/tokenizer_config.json +33 -0
  26. sd-pokemon-model/tokenizer/vocab.json +0 -0
  27. sd-pokemon-model/unet/config.json +61 -0
  28. sd-pokemon-model/unet/diffusion_pytorch_model.bin +3 -0
  29. sd-pokemon-model/vae/config.json +31 -0
  30. sd-pokemon-model/vae/diffusion_pytorch_model.bin +3 -0
  31. train.sh +16 -0
  32. train_text_to_image.py +961 -0
  33. yoda-pokemon.png +0 -0
README.md CHANGED
@@ -1,13 +1,125 @@
1
- ---
2
- title: Stable Diffusion V1.4
3
- emoji: 📈
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 3.29.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Task 1: Choosing model
2
+
3
+ # Chosen model: Stable Diffusion text-to-image fine-tuning
4
+
5
+ The `train_text_to_image.py` script shows how to fine-tune stable diffusion model on your own dataset.
6
+
7
+ ### How to install the code requirements.
8
+
9
+ First, clone the repo and then create a conda env from the env.yaml file and activate the env
10
+ ```bash
11
+ git clone https://github.com/hoangkimthuc/diffusers.git
12
+ cd diffusers/examples/text_to_image
13
+ conda env create -f env.yaml
14
+ conda activate stable_diffusion
15
+ ```
16
+
17
+ Before running the scripts, make sure to install the library's training dependencies:
18
+
19
+ **Important**
20
+
21
+ To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
22
+ ```bash
23
+ cd diffusers
24
+ pip install .
25
+ ```
26
+
27
+ Then cd in the diffusers/examples/text_to_image folder and run
28
+ ```bash
29
+ pip install -r requirements.txt
30
+ ```
31
+
32
+ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
33
+
34
+ ```bash
35
+ accelerate config
36
+ ```
37
+
38
+ ### Steps to run the training.
39
+
40
+ You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
41
+
42
+ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
43
+
44
+ Run the following command to authenticate your token
45
+
46
+ ```bash
47
+ huggingface-cli login
48
+ ```
49
+
50
+ If you have already cloned the repo, then you won't need to go through these steps.
51
+
52
+ <br>
53
+
54
+ #### Hardware
55
+ With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
56
+
57
+ **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
58
+
59
+ ```bash
60
+ bash train.sh
61
+ ```
62
+
63
+ ### Sample input/output after training
64
+
65
+ Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
66
+
67
+
68
+ ```python
69
+ from diffusers import StableDiffusionPipeline
70
+
71
+ model_path = "sd-pokemon-model"
72
+ pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
73
+ pipe.to("cuda")
74
+
75
+ image = pipe(prompt="yoda").images[0]
76
+ image.save("yoda-pokemon.png")
77
+ ```
78
+ The output with the prompt "yoda" is saved in the `yoda-pokemon.png` image file.
79
+
80
+ ### Name and link to the training dataset.
81
+
82
+ Dataset name: pokemon-blip-captions
83
+
84
+ Dataset link: https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions
85
+
86
+ ### The number of model parameters to determine the model’s complexity.
87
+
88
+ Note: CLIPTextModel (text conditioning model) and AutoencoderKL (image generating model) are frozen, only the Unet (the diffusion model) is trained.
89
+ The number of trainable parameters in the script: 859_520_964
90
+ To get this number, you can put a breakpoint by calling `breakpoint()` at line 813 of the `train_text_to_image.py` file and then run `train.sh`. Once the pbd session stops at that line, you can check the model's parameters by `p unet.num_parameters()`.
91
+
92
+ ### The model evaluation metric (CLIP score)
93
+ CLIP score is a measure of how well the generated images match the prompts.
94
+
95
+ Validation prompts to calculate the CLIP scores:
96
+ ```python
97
+ prompts = [
98
+ "a photo of an astronaut riding a horse on mars",
99
+ "A high tech solarpunk utopia in the Amazon rainforest",
100
+ "A pikachu fine dining with a view to the Eiffel Tower",
101
+ "A mecha robot in a favela in expressionist style",
102
+ "an insect robot preparing a delicious meal",
103
+ "A small cabin on top of a snowy mountain in the style of Disney, artstation",
104
+ ]
105
+ ```
106
+ To calculate the CLIP score for the above prompts, run:
107
+ ```bash
108
+ python metrics.py
109
+ ```
110
+
111
+ ### Link to the trained model
112
+
113
+ https://drive.google.com/file/d/1xzVUO0nZn-0oaJgHOWjrYKHmGUlsoJ1g/view?usp=sharing
114
+
115
+ ### Modifications made to the original code
116
+ - Add metrics and gradio_app scripts
117
+ - Remove redundunt code
118
+ - Add training bash script
119
+ - Improve readme
120
+ - Add conda env.yaml file and add more dependencies for the web app
121
+
122
+ # Task 2: Using the model in a web application
123
+
124
+ To create
125
+
env.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: stable_diffusion
2
+ dependencies:
3
+ - pip
4
+ - python=3.9
gradio_app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline
2
+ import torch
3
+ from uuid import uuid4
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ model_path = "sd-pokemon-model"
8
+ pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
9
+ pipe.to("cuda")
10
+
11
+ def predict(prompt):
12
+ image = pipe(prompt=prompt).images[0]
13
+ tmp_filename = f"/tmp/{uuid4()}.png"
14
+ image.save(tmp_filename)
15
+ img = Image.open(tmp_filename)
16
+ return img
17
+
18
+ title = "Stable Diffusion Pokemon Generator"
19
+ description = "Generate Pokemon from text prompts using Stable Diffusion v1.4"
20
+ article="<p style='text-align: center'><a href='https://github.com/hoangkimthuc/diffusers' target='_blank'>Click here to see the original repo of this app</a></p>"
21
+ examples = ["yoda", "pikachu", "charmander"]
22
+ interpretation='default'
23
+ enable_queue=True
24
+
25
+
26
+ text_to_image_app = gr.Interface(fn=predict,
27
+ inputs="text",
28
+ outputs="image",
29
+ title=title,
30
+ description=description,
31
+ article=article,
32
+ examples=examples,
33
+ interpretation=interpretation,
34
+ enable_queue=enable_queue
35
+ )
36
+ text_to_image_app.launch(share=True)
37
+
38
+
39
+
infer.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline
2
+ import torch
3
+
4
+ model_path = "sd-pokemon-model"
5
+ pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
6
+ pipe.to("cuda")
7
+
8
+ image = pipe(prompt="yoda").images[0]
9
+ image.save("yoda-pokemon.png")
metrics.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline
2
+ import torch
3
+ from torchmetrics.functional.multimodal import clip_score
4
+ from functools import partial
5
+
6
+ model_ckpt = "sd-pokemon-model"
7
+ sd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to("cuda")
8
+
9
+ prompts = [
10
+ "a photo of an astronaut riding a horse on mars",
11
+ "A high tech solarpunk utopia in the Amazon rainforest",
12
+ "A pikachu fine dining with a view to the Eiffel Tower",
13
+ "A mecha robot in a favela in expressionist style",
14
+ "an insect robot preparing a delicious meal",
15
+ "A small cabin on top of a snowy mountain in the style of Disney, artstation",
16
+ ]
17
+
18
+ images = sd_pipeline(prompts, num_images_per_prompt=1, output_type="numpy").images
19
+ clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
20
+
21
+
22
+ def calculate_clip_score(images, prompts):
23
+ images_int = (images * 255).astype("uint8")
24
+ clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()
25
+ return round(float(clip_score), 4)
26
+
27
+
28
+ sd_clip_score = calculate_clip_score(images, prompts)
29
+ print(f"CLIP score: {sd_clip_score}")
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate>=0.16.0
2
+ torchvision
3
+ transformers>=4.25.1
4
+ datasets
5
+ ftfy
6
+ tensorboard
7
+ Jinja2
8
+ torchmetrics==0.11.4
9
+ gradio
10
+ Pillow==9.5.0
sd-pokemon-model/feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPFeatureExtractor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
sd-pokemon-model/logs/text2image-fine-tune/1683946307.9457252/events.out.tfevents.1683946307.haca1003.18301.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ee0bced002f58847f9190296c72d3d3f14e8630cfb2e5be289dddb909b13ca5
3
+ size 2212
sd-pokemon-model/logs/text2image-fine-tune/1683946307.9529555/hparams.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ cache_dir: null
7
+ caption_column: text
8
+ center_crop: true
9
+ checkpointing_steps: 500
10
+ checkpoints_total_limit: null
11
+ dataloader_num_workers: 0
12
+ dataset_config_name: null
13
+ dataset_name: lambdalabs/pokemon-blip-captions
14
+ enable_xformers_memory_efficient_attention: false
15
+ gradient_accumulation_steps: 4
16
+ gradient_checkpointing: true
17
+ hub_model_id: null
18
+ hub_token: null
19
+ image_column: image
20
+ input_pertubation: 0
21
+ learning_rate: 1.0e-05
22
+ local_rank: -1
23
+ logging_dir: logs
24
+ lr_scheduler: constant
25
+ lr_warmup_steps: 0
26
+ max_grad_norm: 1.0
27
+ max_train_samples: null
28
+ max_train_steps: 209
29
+ mixed_precision: null
30
+ noise_offset: 0
31
+ non_ema_revision: null
32
+ num_train_epochs: 1
33
+ output_dir: sd-pokemon-model
34
+ pretrained_model_name_or_path: CompVis/stable-diffusion-v1-4
35
+ push_to_hub: false
36
+ random_flip: true
37
+ report_to: tensorboard
38
+ resolution: 512
39
+ resume_from_checkpoint: null
40
+ revision: null
41
+ scale_lr: false
42
+ seed: null
43
+ snr_gamma: null
44
+ tracker_project_name: text2image-fine-tune
45
+ train_batch_size: 1
46
+ train_data_dir: null
47
+ use_8bit_adam: false
48
+ use_ema: true
49
+ validation_epochs: 5
sd-pokemon-model/logs/text2image-fine-tune/1683946765.030428/events.out.tfevents.1683946765.haca1003.20427.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2473ac0919c737a30b79f602307790765f17a31124065004e63712f736aa40ce
3
+ size 2212
sd-pokemon-model/logs/text2image-fine-tune/1683946765.0349936/hparams.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ cache_dir: null
7
+ caption_column: text
8
+ center_crop: true
9
+ checkpointing_steps: 500
10
+ checkpoints_total_limit: null
11
+ dataloader_num_workers: 0
12
+ dataset_config_name: null
13
+ dataset_name: lambdalabs/pokemon-blip-captions
14
+ enable_xformers_memory_efficient_attention: false
15
+ gradient_accumulation_steps: 4
16
+ gradient_checkpointing: true
17
+ hub_model_id: null
18
+ hub_token: null
19
+ image_column: image
20
+ input_pertubation: 0
21
+ learning_rate: 1.0e-05
22
+ local_rank: -1
23
+ logging_dir: logs
24
+ lr_scheduler: constant
25
+ lr_warmup_steps: 0
26
+ max_grad_norm: 1.0
27
+ max_train_samples: null
28
+ max_train_steps: 10
29
+ mixed_precision: null
30
+ noise_offset: 0
31
+ non_ema_revision: null
32
+ num_train_epochs: 1
33
+ output_dir: sd-pokemon-model
34
+ pretrained_model_name_or_path: CompVis/stable-diffusion-v1-4
35
+ push_to_hub: false
36
+ random_flip: true
37
+ report_to: tensorboard
38
+ resolution: 512
39
+ resume_from_checkpoint: null
40
+ revision: null
41
+ scale_lr: false
42
+ seed: null
43
+ snr_gamma: null
44
+ tracker_project_name: text2image-fine-tune
45
+ train_batch_size: 1
46
+ train_data_dir: null
47
+ use_8bit_adam: false
48
+ use_ema: true
49
+ validation_epochs: 5
sd-pokemon-model/logs/text2image-fine-tune/1683947054.8602302/events.out.tfevents.1683947054.haca1003.32635.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f48e7f64b65414ef16bdafdd167236365b92294f8f878461936b2f263a129cf0
3
+ size 2212
sd-pokemon-model/logs/text2image-fine-tune/1683947054.8646092/hparams.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.01
5
+ allow_tf32: false
6
+ cache_dir: null
7
+ caption_column: text
8
+ center_crop: true
9
+ checkpointing_steps: 500
10
+ checkpoints_total_limit: null
11
+ dataloader_num_workers: 0
12
+ dataset_config_name: null
13
+ dataset_name: lambdalabs/pokemon-blip-captions
14
+ enable_xformers_memory_efficient_attention: false
15
+ gradient_accumulation_steps: 4
16
+ gradient_checkpointing: true
17
+ hub_model_id: null
18
+ hub_token: null
19
+ image_column: image
20
+ input_pertubation: 0
21
+ learning_rate: 1.0e-05
22
+ local_rank: -1
23
+ logging_dir: logs
24
+ lr_scheduler: constant
25
+ lr_warmup_steps: 0
26
+ max_grad_norm: 1.0
27
+ max_train_samples: null
28
+ max_train_steps: 10
29
+ mixed_precision: null
30
+ noise_offset: 0
31
+ non_ema_revision: null
32
+ num_train_epochs: 1
33
+ output_dir: sd-pokemon-model
34
+ pretrained_model_name_or_path: CompVis/stable-diffusion-v1-4
35
+ push_to_hub: false
36
+ random_flip: true
37
+ report_to: tensorboard
38
+ resolution: 512
39
+ resume_from_checkpoint: null
40
+ revision: null
41
+ scale_lr: false
42
+ seed: null
43
+ snr_gamma: null
44
+ tracker_project_name: text2image-fine-tune
45
+ train_batch_size: 1
46
+ train_data_dir: null
47
+ use_8bit_adam: false
48
+ use_ema: true
49
+ validation_epochs: 5
sd-pokemon-model/logs/text2image-fine-tune/events.out.tfevents.1683946307.haca1003.18301.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb3c73b2d6c2c8de1ba99393dff8194e3ed6df24ce1a9c864edbbe5395c5232e
3
+ size 10202
sd-pokemon-model/logs/text2image-fine-tune/events.out.tfevents.1683946765.haca1003.20427.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1e63c0925bcee35408823e264b183f399289af129f53f9db30f36e3566b955
3
+ size 467088
sd-pokemon-model/logs/text2image-fine-tune/events.out.tfevents.1683947054.haca1003.32635.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a40baf737775076a113a52d54b2fd3a0807f6e52c0329b9f09c1a34675a34aeb
3
+ size 568
sd-pokemon-model/model_index.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "StableDiffusionPipeline",
3
+ "_diffusers_version": "0.17.0.dev0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPFeatureExtractor"
7
+ ],
8
+ "requires_safety_checker": true,
9
+ "safety_checker": [
10
+ "stable_diffusion",
11
+ "StableDiffusionSafetyChecker"
12
+ ],
13
+ "scheduler": [
14
+ "diffusers",
15
+ "PNDMScheduler"
16
+ ],
17
+ "text_encoder": [
18
+ "transformers",
19
+ "CLIPTextModel"
20
+ ],
21
+ "tokenizer": [
22
+ "transformers",
23
+ "CLIPTokenizer"
24
+ ],
25
+ "unet": [
26
+ "diffusers",
27
+ "UNet2DConditionModel"
28
+ ],
29
+ "vae": [
30
+ "diffusers",
31
+ "AutoencoderKL"
32
+ ]
33
+ }
sd-pokemon-model/safety_checker/config.json ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": "249dd2d739844dea6a0bc7fc27b3c1d014720b28",
3
+ "_name_or_path": "/home/ubuntu/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/249dd2d739844dea6a0bc7fc27b3c1d014720b28/safety_checker",
4
+ "architectures": [
5
+ "StableDiffusionSafetyChecker"
6
+ ],
7
+ "initializer_factor": 1.0,
8
+ "logit_scale_init_value": 2.6592,
9
+ "model_type": "clip",
10
+ "projection_dim": 768,
11
+ "text_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "begin_suppress_tokens": null,
18
+ "bos_token_id": 0,
19
+ "chunk_size_feed_forward": 0,
20
+ "cross_attention_hidden_size": null,
21
+ "decoder_start_token_id": null,
22
+ "diversity_penalty": 0.0,
23
+ "do_sample": false,
24
+ "dropout": 0.0,
25
+ "early_stopping": false,
26
+ "encoder_no_repeat_ngram_size": 0,
27
+ "eos_token_id": 2,
28
+ "exponential_decay_length_penalty": null,
29
+ "finetuning_task": null,
30
+ "forced_bos_token_id": null,
31
+ "forced_eos_token_id": null,
32
+ "hidden_act": "quick_gelu",
33
+ "hidden_size": 768,
34
+ "id2label": {
35
+ "0": "LABEL_0",
36
+ "1": "LABEL_1"
37
+ },
38
+ "initializer_factor": 1.0,
39
+ "initializer_range": 0.02,
40
+ "intermediate_size": 3072,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "layer_norm_eps": 1e-05,
48
+ "length_penalty": 1.0,
49
+ "max_length": 20,
50
+ "max_position_embeddings": 77,
51
+ "min_length": 0,
52
+ "model_type": "clip_text_model",
53
+ "no_repeat_ngram_size": 0,
54
+ "num_attention_heads": 12,
55
+ "num_beam_groups": 1,
56
+ "num_beams": 1,
57
+ "num_hidden_layers": 12,
58
+ "num_return_sequences": 1,
59
+ "output_attentions": false,
60
+ "output_hidden_states": false,
61
+ "output_scores": false,
62
+ "pad_token_id": 1,
63
+ "prefix": null,
64
+ "problem_type": null,
65
+ "projection_dim": 512,
66
+ "pruned_heads": {},
67
+ "remove_invalid_values": false,
68
+ "repetition_penalty": 1.0,
69
+ "return_dict": true,
70
+ "return_dict_in_generate": false,
71
+ "sep_token_id": null,
72
+ "suppress_tokens": null,
73
+ "task_specific_params": null,
74
+ "temperature": 1.0,
75
+ "tf_legacy_loss": false,
76
+ "tie_encoder_decoder": false,
77
+ "tie_word_embeddings": true,
78
+ "tokenizer_class": null,
79
+ "top_k": 50,
80
+ "top_p": 1.0,
81
+ "torch_dtype": null,
82
+ "torchscript": false,
83
+ "transformers_version": "4.29.1",
84
+ "typical_p": 1.0,
85
+ "use_bfloat16": false,
86
+ "vocab_size": 49408
87
+ },
88
+ "torch_dtype": "float32",
89
+ "transformers_version": null,
90
+ "vision_config": {
91
+ "_name_or_path": "",
92
+ "add_cross_attention": false,
93
+ "architectures": null,
94
+ "attention_dropout": 0.0,
95
+ "bad_words_ids": null,
96
+ "begin_suppress_tokens": null,
97
+ "bos_token_id": null,
98
+ "chunk_size_feed_forward": 0,
99
+ "cross_attention_hidden_size": null,
100
+ "decoder_start_token_id": null,
101
+ "diversity_penalty": 0.0,
102
+ "do_sample": false,
103
+ "dropout": 0.0,
104
+ "early_stopping": false,
105
+ "encoder_no_repeat_ngram_size": 0,
106
+ "eos_token_id": null,
107
+ "exponential_decay_length_penalty": null,
108
+ "finetuning_task": null,
109
+ "forced_bos_token_id": null,
110
+ "forced_eos_token_id": null,
111
+ "hidden_act": "quick_gelu",
112
+ "hidden_size": 1024,
113
+ "id2label": {
114
+ "0": "LABEL_0",
115
+ "1": "LABEL_1"
116
+ },
117
+ "image_size": 224,
118
+ "initializer_factor": 1.0,
119
+ "initializer_range": 0.02,
120
+ "intermediate_size": 4096,
121
+ "is_decoder": false,
122
+ "is_encoder_decoder": false,
123
+ "label2id": {
124
+ "LABEL_0": 0,
125
+ "LABEL_1": 1
126
+ },
127
+ "layer_norm_eps": 1e-05,
128
+ "length_penalty": 1.0,
129
+ "max_length": 20,
130
+ "min_length": 0,
131
+ "model_type": "clip_vision_model",
132
+ "no_repeat_ngram_size": 0,
133
+ "num_attention_heads": 16,
134
+ "num_beam_groups": 1,
135
+ "num_beams": 1,
136
+ "num_channels": 3,
137
+ "num_hidden_layers": 24,
138
+ "num_return_sequences": 1,
139
+ "output_attentions": false,
140
+ "output_hidden_states": false,
141
+ "output_scores": false,
142
+ "pad_token_id": null,
143
+ "patch_size": 14,
144
+ "prefix": null,
145
+ "problem_type": null,
146
+ "projection_dim": 512,
147
+ "pruned_heads": {},
148
+ "remove_invalid_values": false,
149
+ "repetition_penalty": 1.0,
150
+ "return_dict": true,
151
+ "return_dict_in_generate": false,
152
+ "sep_token_id": null,
153
+ "suppress_tokens": null,
154
+ "task_specific_params": null,
155
+ "temperature": 1.0,
156
+ "tf_legacy_loss": false,
157
+ "tie_encoder_decoder": false,
158
+ "tie_word_embeddings": true,
159
+ "tokenizer_class": null,
160
+ "top_k": 50,
161
+ "top_p": 1.0,
162
+ "torch_dtype": null,
163
+ "torchscript": false,
164
+ "transformers_version": "4.29.1",
165
+ "typical_p": 1.0,
166
+ "use_bfloat16": false
167
+ }
168
+ }
sd-pokemon-model/safety_checker/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16d28f2b37109f222cdc33620fdd262102ac32112be0352a7f77e9614b35a394
3
+ size 1216064769
sd-pokemon-model/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.17.0.dev0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "epsilon",
10
+ "set_alpha_to_one": false,
11
+ "skip_prk_steps": true,
12
+ "steps_offset": 1,
13
+ "trained_betas": null
14
+ }
sd-pokemon-model/text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "CompVis/stable-diffusion-v1-4",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.29.1",
24
+ "vocab_size": 49408
25
+ }
sd-pokemon-model/text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a34f30098988d85dc0fb0fc272a842ebcf552e2ebc6ce4adbcf3695d08e8a90
3
+ size 246188833
sd-pokemon-model/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
sd-pokemon-model/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
sd-pokemon-model/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "clean_up_tokenization_spaces": true,
12
+ "do_lower_case": true,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|endoftext|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "errors": "replace",
22
+ "model_max_length": 77,
23
+ "pad_token": "<|endoftext|>",
24
+ "tokenizer_class": "CLIPTokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<|endoftext|>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
sd-pokemon-model/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
sd-pokemon-model/unet/config.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.17.0.dev0",
4
+ "_name_or_path": "CompVis/stable-diffusion-v1-4",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "attention_head_dim": 8,
9
+ "block_out_channels": [
10
+ 320,
11
+ 640,
12
+ 1280,
13
+ 1280
14
+ ],
15
+ "center_input_sample": false,
16
+ "class_embed_type": null,
17
+ "class_embeddings_concat": false,
18
+ "conv_in_kernel": 3,
19
+ "conv_out_kernel": 3,
20
+ "cross_attention_dim": 768,
21
+ "cross_attention_norm": null,
22
+ "down_block_types": [
23
+ "CrossAttnDownBlock2D",
24
+ "CrossAttnDownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "DownBlock2D"
27
+ ],
28
+ "downsample_padding": 1,
29
+ "dual_cross_attention": false,
30
+ "encoder_hid_dim": null,
31
+ "flip_sin_to_cos": true,
32
+ "freq_shift": 0,
33
+ "in_channels": 4,
34
+ "layers_per_block": 2,
35
+ "mid_block_only_cross_attention": null,
36
+ "mid_block_scale_factor": 1,
37
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
38
+ "norm_eps": 1e-05,
39
+ "norm_num_groups": 32,
40
+ "num_class_embeds": null,
41
+ "only_cross_attention": false,
42
+ "out_channels": 4,
43
+ "projection_class_embeddings_input_dim": null,
44
+ "resnet_out_scale_factor": 1.0,
45
+ "resnet_skip_time_act": false,
46
+ "resnet_time_scale_shift": "default",
47
+ "sample_size": 64,
48
+ "time_cond_proj_dim": null,
49
+ "time_embedding_act_fn": null,
50
+ "time_embedding_dim": null,
51
+ "time_embedding_type": "positional",
52
+ "timestep_post_act": null,
53
+ "up_block_types": [
54
+ "UpBlock2D",
55
+ "CrossAttnUpBlock2D",
56
+ "CrossAttnUpBlock2D",
57
+ "CrossAttnUpBlock2D"
58
+ ],
59
+ "upcast_attention": false,
60
+ "use_linear_projection": false
61
+ }
sd-pokemon-model/unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bcfb4bd5f949d230f06eb91c99d0231cc1cfeb162ff31724c96455fafb19b4e
3
+ size 3438375973
sd-pokemon-model/vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.17.0.dev0",
4
+ "_name_or_path": "CompVis/stable-diffusion-v1-4",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 512,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
sd-pokemon-model/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa8c1b74b3e2781e4347b9b350203597674d8860a4338b46431de760c3a5dd22
3
+ size 167407857
train.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
+ export dataset_name="lambdalabs/pokemon-blip-captions"
3
+
4
+ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
5
+ --pretrained_model_name_or_path=$MODEL_NAME \
6
+ --dataset_name=$dataset_name \
7
+ --use_ema \
8
+ --resolution=512 --center_crop --random_flip \
9
+ --train_batch_size=1 \
10
+ --gradient_accumulation_steps=4 \
11
+ --gradient_checkpointing \
12
+ --learning_rate=1e-05 \
13
+ --max_grad_norm=1 \
14
+ --num_train_epochs=1 \
15
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
16
+ --output_dir="sd-pokemon-model"
train_text_to_image.py ADDED
@@ -0,0 +1,961 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import logging
18
+ import math
19
+ import os
20
+ import random
21
+ from pathlib import Path
22
+
23
+ import accelerate
24
+ import datasets
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ import transformers
30
+ from accelerate import Accelerator
31
+ from accelerate.logging import get_logger
32
+ from accelerate.state import AcceleratorState
33
+ from accelerate.utils import ProjectConfiguration, set_seed
34
+ from datasets import load_dataset
35
+ from huggingface_hub import create_repo, upload_folder
36
+ from packaging import version
37
+ from torchvision import transforms
38
+ from tqdm.auto import tqdm
39
+ from transformers import CLIPTextModel, CLIPTokenizer
40
+ from transformers.utils import ContextManagers
41
+
42
+ import diffusers
43
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
44
+ from diffusers.optimization import get_scheduler
45
+ from diffusers.training_utils import EMAModel
46
+ from diffusers.utils import check_min_version, deprecate, is_wandb_available
47
+ from diffusers.utils.import_utils import is_xformers_available
48
+
49
+
50
+ if is_wandb_available():
51
+ import wandb
52
+
53
+
54
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
55
+ check_min_version("0.17.0.dev0")
56
+
57
+ logger = get_logger(__name__, log_level="INFO")
58
+
59
+ DATASET_NAME_MAPPING = {
60
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
61
+ }
62
+
63
+
64
+ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
65
+ logger.info("Running validation... ")
66
+
67
+ pipeline = StableDiffusionPipeline.from_pretrained(
68
+ args.pretrained_model_name_or_path,
69
+ vae=accelerator.unwrap_model(vae),
70
+ text_encoder=accelerator.unwrap_model(text_encoder),
71
+ tokenizer=tokenizer,
72
+ unet=accelerator.unwrap_model(unet),
73
+ safety_checker=None,
74
+ revision=args.revision,
75
+ torch_dtype=weight_dtype,
76
+ )
77
+ pipeline = pipeline.to(accelerator.device)
78
+ pipeline.set_progress_bar_config(disable=True)
79
+
80
+ if args.enable_xformers_memory_efficient_attention:
81
+ pipeline.enable_xformers_memory_efficient_attention()
82
+
83
+ if args.seed is None:
84
+ generator = None
85
+ else:
86
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
87
+
88
+ images = []
89
+ for i in range(len(args.validation_prompts)):
90
+ with torch.autocast("cuda"):
91
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
92
+
93
+ images.append(image)
94
+
95
+ for tracker in accelerator.trackers:
96
+ breakpoint()
97
+ if tracker.name == "tensorboard":
98
+ np_images = np.stack([np.asarray(img) for img in images])
99
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
100
+ elif tracker.name == "wandb":
101
+ tracker.log(
102
+ {
103
+ "validation": [
104
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
105
+ for i, image in enumerate(images)
106
+ ]
107
+ }
108
+ )
109
+ else:
110
+ logger.warn(f"image logging not implemented for {tracker.name}")
111
+
112
+ del pipeline
113
+ torch.cuda.empty_cache()
114
+
115
+
116
+ def parse_args():
117
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
118
+ parser.add_argument(
119
+ "--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
120
+ )
121
+ parser.add_argument(
122
+ "--pretrained_model_name_or_path",
123
+ type=str,
124
+ default=None,
125
+ required=True,
126
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
127
+ )
128
+ parser.add_argument(
129
+ "--revision",
130
+ type=str,
131
+ default=None,
132
+ required=False,
133
+ help="Revision of pretrained model identifier from huggingface.co/models.",
134
+ )
135
+ parser.add_argument(
136
+ "--dataset_name",
137
+ type=str,
138
+ default=None,
139
+ help=(
140
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
141
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
142
+ " or to a folder containing files that 🤗 Datasets can understand."
143
+ ),
144
+ )
145
+ parser.add_argument(
146
+ "--dataset_config_name",
147
+ type=str,
148
+ default=None,
149
+ help="The config of the Dataset, leave as None if there's only one config.",
150
+ )
151
+ parser.add_argument(
152
+ "--train_data_dir",
153
+ type=str,
154
+ default=None,
155
+ help=(
156
+ "A folder containing the training data. Folder contents must follow the structure described in"
157
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
158
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
159
+ ),
160
+ )
161
+ parser.add_argument(
162
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
163
+ )
164
+ parser.add_argument(
165
+ "--caption_column",
166
+ type=str,
167
+ default="text",
168
+ help="The column of the dataset containing a caption or a list of captions.",
169
+ )
170
+ parser.add_argument(
171
+ "--max_train_samples",
172
+ type=int,
173
+ default=None,
174
+ help=(
175
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
176
+ "value if set."
177
+ ),
178
+ )
179
+ parser.add_argument(
180
+ "--validation_prompts",
181
+ type=str,
182
+ default=None,
183
+ nargs="+",
184
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
185
+ )
186
+ parser.add_argument(
187
+ "--output_dir",
188
+ type=str,
189
+ default="sd-model-finetuned",
190
+ help="The output directory where the model predictions and checkpoints will be written.",
191
+ )
192
+ parser.add_argument(
193
+ "--cache_dir",
194
+ type=str,
195
+ default=None,
196
+ help="The directory where the downloaded models and datasets will be stored.",
197
+ )
198
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
199
+ parser.add_argument(
200
+ "--resolution",
201
+ type=int,
202
+ default=512,
203
+ help=(
204
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
205
+ " resolution"
206
+ ),
207
+ )
208
+ parser.add_argument(
209
+ "--center_crop",
210
+ default=False,
211
+ action="store_true",
212
+ help=(
213
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
214
+ " cropped. The images will be resized to the resolution first before cropping."
215
+ ),
216
+ )
217
+ parser.add_argument(
218
+ "--random_flip",
219
+ action="store_true",
220
+ help="whether to randomly flip images horizontally",
221
+ )
222
+ parser.add_argument(
223
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
224
+ )
225
+ parser.add_argument("--num_train_epochs", type=int, default=100)
226
+ parser.add_argument(
227
+ "--max_train_steps",
228
+ type=int,
229
+ default=None,
230
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
231
+ )
232
+ parser.add_argument(
233
+ "--gradient_accumulation_steps",
234
+ type=int,
235
+ default=1,
236
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
237
+ )
238
+ parser.add_argument(
239
+ "--gradient_checkpointing",
240
+ action="store_true",
241
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
242
+ )
243
+ parser.add_argument(
244
+ "--learning_rate",
245
+ type=float,
246
+ default=1e-4,
247
+ help="Initial learning rate (after the potential warmup period) to use.",
248
+ )
249
+ parser.add_argument(
250
+ "--scale_lr",
251
+ action="store_true",
252
+ default=False,
253
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
254
+ )
255
+ parser.add_argument(
256
+ "--lr_scheduler",
257
+ type=str,
258
+ default="constant",
259
+ help=(
260
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
261
+ ' "constant", "constant_with_warmup"]'
262
+ ),
263
+ )
264
+ parser.add_argument(
265
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
266
+ )
267
+ parser.add_argument(
268
+ "--snr_gamma",
269
+ type=float,
270
+ default=None,
271
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
272
+ "More details here: https://arxiv.org/abs/2303.09556.",
273
+ )
274
+ parser.add_argument(
275
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
276
+ )
277
+ parser.add_argument(
278
+ "--allow_tf32",
279
+ action="store_true",
280
+ help=(
281
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
282
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
283
+ ),
284
+ )
285
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
286
+ parser.add_argument(
287
+ "--non_ema_revision",
288
+ type=str,
289
+ default=None,
290
+ required=False,
291
+ help=(
292
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
293
+ " remote repository specified with --pretrained_model_name_or_path."
294
+ ),
295
+ )
296
+ parser.add_argument(
297
+ "--dataloader_num_workers",
298
+ type=int,
299
+ default=0,
300
+ help=(
301
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
302
+ ),
303
+ )
304
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
305
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
306
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
307
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
308
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
309
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
310
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
311
+ parser.add_argument(
312
+ "--hub_model_id",
313
+ type=str,
314
+ default=None,
315
+ help="The name of the repository to keep in sync with the local `output_dir`.",
316
+ )
317
+ parser.add_argument(
318
+ "--logging_dir",
319
+ type=str,
320
+ default="logs",
321
+ help=(
322
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
323
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
324
+ ),
325
+ )
326
+ parser.add_argument(
327
+ "--mixed_precision",
328
+ type=str,
329
+ default=None,
330
+ choices=["no", "fp16", "bf16"],
331
+ help=(
332
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
333
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
334
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
335
+ ),
336
+ )
337
+ parser.add_argument(
338
+ "--report_to",
339
+ type=str,
340
+ default="tensorboard",
341
+ help=(
342
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
343
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
344
+ ),
345
+ )
346
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
347
+ parser.add_argument(
348
+ "--checkpointing_steps",
349
+ type=int,
350
+ default=500,
351
+ help=(
352
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
353
+ " training using `--resume_from_checkpoint`."
354
+ ),
355
+ )
356
+ parser.add_argument(
357
+ "--checkpoints_total_limit",
358
+ type=int,
359
+ default=None,
360
+ help=(
361
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
362
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
363
+ " for more docs"
364
+ ),
365
+ )
366
+ parser.add_argument(
367
+ "--resume_from_checkpoint",
368
+ type=str,
369
+ default=None,
370
+ help=(
371
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
372
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
373
+ ),
374
+ )
375
+ parser.add_argument(
376
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
377
+ )
378
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
379
+ parser.add_argument(
380
+ "--validation_epochs",
381
+ type=int,
382
+ default=5,
383
+ help="Run validation every X epochs.",
384
+ )
385
+ parser.add_argument(
386
+ "--tracker_project_name",
387
+ type=str,
388
+ default="text2image-fine-tune",
389
+ help=(
390
+ "The `project_name` argument passed to Accelerator.init_trackers for"
391
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
392
+ ),
393
+ )
394
+
395
+ args = parser.parse_args()
396
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
397
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
398
+ args.local_rank = env_local_rank
399
+
400
+ # Sanity checks
401
+ if args.dataset_name is None and args.train_data_dir is None:
402
+ raise ValueError("Need either a dataset name or a training folder.")
403
+
404
+ # default to using the same revision for the non-ema model if not specified
405
+ if args.non_ema_revision is None:
406
+ args.non_ema_revision = args.revision
407
+
408
+ return args
409
+
410
+
411
+ def main():
412
+ args = parse_args()
413
+
414
+ if args.non_ema_revision is not None:
415
+ deprecate(
416
+ "non_ema_revision!=None",
417
+ "0.15.0",
418
+ message=(
419
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
420
+ " use `--variant=non_ema` instead."
421
+ ),
422
+ )
423
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
424
+
425
+ accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
426
+
427
+ accelerator = Accelerator(
428
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
429
+ mixed_precision=args.mixed_precision,
430
+ log_with=args.report_to,
431
+ logging_dir=logging_dir,
432
+ project_config=accelerator_project_config,
433
+ )
434
+
435
+ # Make one log on every process with the configuration for debugging.
436
+ logging.basicConfig(
437
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
438
+ datefmt="%m/%d/%Y %H:%M:%S",
439
+ level=logging.INFO,
440
+ )
441
+ logger.info(accelerator.state, main_process_only=False)
442
+ if accelerator.is_local_main_process:
443
+ datasets.utils.logging.set_verbosity_warning()
444
+ transformers.utils.logging.set_verbosity_warning()
445
+ diffusers.utils.logging.set_verbosity_info()
446
+ else:
447
+ datasets.utils.logging.set_verbosity_error()
448
+ transformers.utils.logging.set_verbosity_error()
449
+ diffusers.utils.logging.set_verbosity_error()
450
+
451
+ # If passed along, set the training seed now.
452
+ if args.seed is not None:
453
+ set_seed(args.seed)
454
+
455
+ # Handle the repository creation
456
+ if accelerator.is_main_process:
457
+ if args.output_dir is not None:
458
+ os.makedirs(args.output_dir, exist_ok=True)
459
+
460
+ if args.push_to_hub:
461
+ repo_id = create_repo(
462
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
463
+ ).repo_id
464
+
465
+ # Load scheduler, tokenizer and models.
466
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
467
+ tokenizer = CLIPTokenizer.from_pretrained(
468
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
469
+ )
470
+
471
+ def deepspeed_zero_init_disabled_context_manager():
472
+ """
473
+ returns either a context list that includes one that will disable zero.Init or an empty context list
474
+ """
475
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
476
+ if deepspeed_plugin is None:
477
+ return []
478
+
479
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
480
+
481
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
482
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
483
+ # will try to assign the same optimizer with the same weights to all models during
484
+ # `deepspeed.initialize`, which of course doesn't work.
485
+ #
486
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
487
+ # frozen models from being partitioned during `zero.Init` which gets called during
488
+ # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
489
+ # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
490
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
491
+ text_encoder = CLIPTextModel.from_pretrained(
492
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
493
+ )
494
+ vae = AutoencoderKL.from_pretrained(
495
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
496
+ )
497
+
498
+ unet = UNet2DConditionModel.from_pretrained(
499
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
500
+ )
501
+
502
+ # Freeze vae and text_encoder
503
+ vae.requires_grad_(False)
504
+ text_encoder.requires_grad_(False)
505
+
506
+ # Create EMA for the unet.
507
+ if args.use_ema:
508
+ ema_unet = UNet2DConditionModel.from_pretrained(
509
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
510
+ )
511
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
512
+
513
+ if args.enable_xformers_memory_efficient_attention:
514
+ if is_xformers_available():
515
+ import xformers
516
+
517
+ xformers_version = version.parse(xformers.__version__)
518
+ if xformers_version == version.parse("0.0.16"):
519
+ logger.warn(
520
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
521
+ )
522
+ unet.enable_xformers_memory_efficient_attention()
523
+ else:
524
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
525
+
526
+ def compute_snr(timesteps):
527
+ """
528
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
529
+ """
530
+ alphas_cumprod = noise_scheduler.alphas_cumprod
531
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
532
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
533
+
534
+ # Expand the tensors.
535
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
536
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
537
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
538
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
539
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
540
+
541
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
542
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
543
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
544
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
545
+
546
+ # Compute SNR.
547
+ snr = (alpha / sigma) ** 2
548
+ return snr
549
+
550
+ # `accelerate` 0.16.0 will have better support for customized saving
551
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
552
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
553
+ def save_model_hook(models, weights, output_dir):
554
+ if args.use_ema:
555
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
556
+
557
+ for i, model in enumerate(models):
558
+ model.save_pretrained(os.path.join(output_dir, "unet"))
559
+
560
+ # make sure to pop weight so that corresponding model is not saved again
561
+ weights.pop()
562
+
563
+ def load_model_hook(models, input_dir):
564
+ if args.use_ema:
565
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
566
+ ema_unet.load_state_dict(load_model.state_dict())
567
+ ema_unet.to(accelerator.device)
568
+ del load_model
569
+
570
+ for i in range(len(models)):
571
+ # pop models so that they are not loaded again
572
+ model = models.pop()
573
+
574
+ # load diffusers style into model
575
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
576
+ model.register_to_config(**load_model.config)
577
+
578
+ model.load_state_dict(load_model.state_dict())
579
+ del load_model
580
+
581
+ accelerator.register_save_state_pre_hook(save_model_hook)
582
+ accelerator.register_load_state_pre_hook(load_model_hook)
583
+
584
+ if args.gradient_checkpointing:
585
+ unet.enable_gradient_checkpointing()
586
+
587
+ # Enable TF32 for faster training on Ampere GPUs,
588
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
589
+ if args.allow_tf32:
590
+ torch.backends.cuda.matmul.allow_tf32 = True
591
+
592
+ if args.scale_lr:
593
+ args.learning_rate = (
594
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
595
+ )
596
+
597
+ # Initialize the optimizer
598
+ if args.use_8bit_adam:
599
+ try:
600
+ import bitsandbytes as bnb
601
+ except ImportError:
602
+ raise ImportError(
603
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
604
+ )
605
+
606
+ optimizer_cls = bnb.optim.AdamW8bit
607
+ else:
608
+ optimizer_cls = torch.optim.AdamW
609
+
610
+ optimizer = optimizer_cls(
611
+ unet.parameters(),
612
+ lr=args.learning_rate,
613
+ betas=(args.adam_beta1, args.adam_beta2),
614
+ weight_decay=args.adam_weight_decay,
615
+ eps=args.adam_epsilon,
616
+ )
617
+
618
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
619
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
620
+
621
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
622
+ # download the dataset.
623
+ if args.dataset_name is not None:
624
+ # Downloading and loading a dataset from the hub.
625
+ dataset = load_dataset(
626
+ args.dataset_name,
627
+ args.dataset_config_name,
628
+ cache_dir=args.cache_dir,
629
+ )
630
+ else:
631
+ data_files = {}
632
+ if args.train_data_dir is not None:
633
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
634
+ dataset = load_dataset(
635
+ "imagefolder",
636
+ data_files=data_files,
637
+ cache_dir=args.cache_dir,
638
+ )
639
+ # See more about loading custom images at
640
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
641
+
642
+ # Preprocessing the datasets.
643
+ # We need to tokenize inputs and targets.
644
+ column_names = dataset["train"].column_names
645
+
646
+ # 6. Get the column names for input/target.
647
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
648
+ if args.image_column is None:
649
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
650
+ else:
651
+ image_column = args.image_column
652
+ if image_column not in column_names:
653
+ raise ValueError(
654
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
655
+ )
656
+ if args.caption_column is None:
657
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
658
+ else:
659
+ caption_column = args.caption_column
660
+ if caption_column not in column_names:
661
+ raise ValueError(
662
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
663
+ )
664
+
665
+ # Preprocessing the datasets.
666
+ # We need to tokenize input captions and transform the images.
667
+ def tokenize_captions(examples, is_train=True):
668
+ captions = []
669
+ for caption in examples[caption_column]:
670
+ if isinstance(caption, str):
671
+ captions.append(caption)
672
+ elif isinstance(caption, (list, np.ndarray)):
673
+ # take a random caption if there are multiple
674
+ captions.append(random.choice(caption) if is_train else caption[0])
675
+ else:
676
+ raise ValueError(
677
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
678
+ )
679
+ inputs = tokenizer(
680
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
681
+ )
682
+ return inputs.input_ids
683
+
684
+ # Preprocessing the datasets.
685
+ train_transforms = transforms.Compose(
686
+ [
687
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
688
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
689
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
690
+ transforms.ToTensor(),
691
+ transforms.Normalize([0.5], [0.5]),
692
+ ]
693
+ )
694
+
695
+ def preprocess_train(examples):
696
+ images = [image.convert("RGB") for image in examples[image_column]]
697
+ examples["pixel_values"] = [train_transforms(image) for image in images]
698
+ examples["input_ids"] = tokenize_captions(examples)
699
+ return examples
700
+
701
+ with accelerator.main_process_first():
702
+ if args.max_train_samples is not None:
703
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
704
+ # Set the training transforms
705
+ train_dataset = dataset["train"].with_transform(preprocess_train)
706
+
707
+ def collate_fn(examples):
708
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
709
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
710
+ input_ids = torch.stack([example["input_ids"] for example in examples])
711
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
712
+
713
+ # DataLoaders creation:
714
+ train_dataloader = torch.utils.data.DataLoader(
715
+ train_dataset,
716
+ shuffle=True,
717
+ collate_fn=collate_fn,
718
+ batch_size=args.train_batch_size,
719
+ num_workers=args.dataloader_num_workers,
720
+ )
721
+
722
+ # Scheduler and math around the number of training steps.
723
+ overrode_max_train_steps = False
724
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
725
+ if args.max_train_steps is None:
726
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
727
+ overrode_max_train_steps = True
728
+
729
+ lr_scheduler = get_scheduler(
730
+ args.lr_scheduler,
731
+ optimizer=optimizer,
732
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
733
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
734
+ )
735
+
736
+ # Prepare everything with our `accelerator`.
737
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
738
+ unet, optimizer, train_dataloader, lr_scheduler
739
+ )
740
+
741
+ if args.use_ema:
742
+ ema_unet.to(accelerator.device)
743
+
744
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
745
+ # as these models are only used for inference, keeping weights in full precision is not required.
746
+ weight_dtype = torch.float32
747
+ if accelerator.mixed_precision == "fp16":
748
+ weight_dtype = torch.float16
749
+ elif accelerator.mixed_precision == "bf16":
750
+ weight_dtype = torch.bfloat16
751
+
752
+ # Move text_encode and vae to gpu and cast to weight_dtype
753
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
754
+ vae.to(accelerator.device, dtype=weight_dtype)
755
+
756
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
757
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
758
+ if overrode_max_train_steps:
759
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
760
+ # Afterwards we recalculate our number of training epochs
761
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
762
+
763
+ # We need to initialize the trackers we use, and also store our configuration.
764
+ # The trackers initializes automatically on the main process.
765
+ if accelerator.is_main_process:
766
+ tracker_config = dict(vars(args))
767
+ tracker_config.pop("validation_prompts")
768
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
769
+
770
+ # Train!
771
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
772
+
773
+ logger.info("***** Running training *****")
774
+ logger.info(f" Num examples = {len(train_dataset)}")
775
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
776
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
777
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
778
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
779
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
780
+ global_step = 0
781
+ first_epoch = 0
782
+
783
+ # Potentially load in the weights and states from a previous save
784
+ if args.resume_from_checkpoint:
785
+ if args.resume_from_checkpoint != "latest":
786
+ path = os.path.basename(args.resume_from_checkpoint)
787
+ else:
788
+ # Get the most recent checkpoint
789
+ dirs = os.listdir(args.output_dir)
790
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
791
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
792
+ path = dirs[-1] if len(dirs) > 0 else None
793
+
794
+ if path is None:
795
+ accelerator.print(
796
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
797
+ )
798
+ args.resume_from_checkpoint = None
799
+ else:
800
+ accelerator.print(f"Resuming from checkpoint {path}")
801
+ accelerator.load_state(os.path.join(args.output_dir, path))
802
+ global_step = int(path.split("-")[1])
803
+
804
+ resume_global_step = global_step * args.gradient_accumulation_steps
805
+ first_epoch = global_step // num_update_steps_per_epoch
806
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
807
+
808
+ # Only show the progress bar once on each machine.
809
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
810
+ progress_bar.set_description("Steps")
811
+
812
+ for epoch in range(first_epoch, args.num_train_epochs):
813
+ breakpoint()
814
+ unet.train()
815
+ train_loss = 0.0
816
+ for step, batch in enumerate(train_dataloader):
817
+ # Skip steps until we reach the resumed step
818
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
819
+ if step % args.gradient_accumulation_steps == 0:
820
+ progress_bar.update(1)
821
+ continue
822
+
823
+ with accelerator.accumulate(unet):
824
+ # Convert images to latent space
825
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
826
+ latents = latents * vae.config.scaling_factor
827
+
828
+ # Sample noise that we'll add to the latents
829
+ noise = torch.randn_like(latents)
830
+ if args.noise_offset:
831
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
832
+ noise += args.noise_offset * torch.randn(
833
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
834
+ )
835
+ if args.input_pertubation:
836
+ new_noise = noise + args.input_pertubation * torch.randn_like(noise)
837
+ bsz = latents.shape[0]
838
+ # Sample a random timestep for each image
839
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
840
+ timesteps = timesteps.long()
841
+
842
+ # Add noise to the latents according to the noise magnitude at each timestep
843
+ # (this is the forward diffusion process)
844
+ if args.input_pertubation:
845
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
846
+ else:
847
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
848
+
849
+ # Get the text embedding for conditioning
850
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
851
+
852
+ # Get the target for loss depending on the prediction type
853
+ if noise_scheduler.config.prediction_type == "epsilon":
854
+ target = noise
855
+ elif noise_scheduler.config.prediction_type == "v_prediction":
856
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
857
+ else:
858
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
859
+
860
+ # Predict the noise residual and compute loss
861
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
862
+
863
+ if args.snr_gamma is None:
864
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
865
+ else:
866
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
867
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
868
+ # This is discussed in Section 4.2 of the same paper.
869
+ snr = compute_snr(timesteps)
870
+ mse_loss_weights = (
871
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
872
+ )
873
+ # We first calculate the original loss. Then we mean over the non-batch dimensions and
874
+ # rebalance the sample-wise losses with their respective loss weights.
875
+ # Finally, we take the mean of the rebalanced loss.
876
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
877
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
878
+ loss = loss.mean()
879
+
880
+ # Gather the losses across all processes for logging (if we use distributed training).
881
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
882
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
883
+
884
+ # Backpropagate
885
+ accelerator.backward(loss)
886
+ if accelerator.sync_gradients:
887
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
888
+ optimizer.step()
889
+ lr_scheduler.step()
890
+ optimizer.zero_grad()
891
+
892
+ # Checks if the accelerator has performed an optimization step behind the scenes
893
+ if accelerator.sync_gradients:
894
+ if args.use_ema:
895
+ ema_unet.step(unet.parameters())
896
+ progress_bar.update(1)
897
+ global_step += 1
898
+ accelerator.log({"train_loss": train_loss}, step=global_step)
899
+ train_loss = 0.0
900
+
901
+ if global_step % args.checkpointing_steps == 0:
902
+ if accelerator.is_main_process:
903
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
904
+ accelerator.save_state(save_path)
905
+ logger.info(f"Saved state to {save_path}")
906
+
907
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
908
+ progress_bar.set_postfix(**logs)
909
+
910
+ if global_step >= args.max_train_steps:
911
+ break
912
+
913
+ if accelerator.is_main_process:
914
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
915
+ if args.use_ema:
916
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
917
+ ema_unet.store(unet.parameters())
918
+ ema_unet.copy_to(unet.parameters())
919
+ log_validation(
920
+ vae,
921
+ text_encoder,
922
+ tokenizer,
923
+ unet,
924
+ args,
925
+ accelerator,
926
+ weight_dtype,
927
+ global_step,
928
+ )
929
+ if args.use_ema:
930
+ # Switch back to the original UNet parameters.
931
+ ema_unet.restore(unet.parameters())
932
+
933
+ # Create the pipeline using the trained modules and save it.
934
+ accelerator.wait_for_everyone()
935
+ if accelerator.is_main_process:
936
+ unet = accelerator.unwrap_model(unet)
937
+ if args.use_ema:
938
+ ema_unet.copy_to(unet.parameters())
939
+
940
+ pipeline = StableDiffusionPipeline.from_pretrained(
941
+ args.pretrained_model_name_or_path,
942
+ text_encoder=text_encoder,
943
+ vae=vae,
944
+ unet=unet,
945
+ revision=args.revision,
946
+ )
947
+ pipeline.save_pretrained(args.output_dir)
948
+
949
+ if args.push_to_hub:
950
+ upload_folder(
951
+ repo_id=repo_id,
952
+ folder_path=args.output_dir,
953
+ commit_message="End of training",
954
+ ignore_patterns=["step_*", "epoch_*"],
955
+ )
956
+
957
+ accelerator.end_training()
958
+
959
+
960
+ if __name__ == "__main__":
961
+ main()
yoda-pokemon.png ADDED