mkshing committed on
Commit
18e8e3f
0 Parent(s):

initial commit

Files changed (4):
  1. .gitattributes +35 -0
  2. README.md +99 -0
  3. evosdxl_jp_v1.py +204 -0
  4. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,99 @@
+ ---
+ library_name: diffusers
+ license: apache-2.0
+ language:
+ - ja
+ pipeline_tag: text-to-image
+ tags:
+ - stable-diffusion
+ ---
+ # 🐟 EvoSDXL-JP-v1
+
+ 🤗 [Models](https://huggingface.co/SakanaAI) | 📚 [Paper](https://arxiv.org/abs/2403.13187) | 📝 [Blog](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
+
+
+ **EvoSDXL-JP-v1** is an experimental, education-purpose Japanese SDXL-Lightning model.
+ It was created with the Evolutionary Model Merge method; please refer to our [report](https://arxiv.org/abs/2403.13187) and [blog](https://sakana.ai/evosdxl-jp/) for details.
+ The model was produced by merging the following source models, whose developers we gratefully acknowledge (a sketch of the merging operation follows the list):
+
+ - [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ - [Juggernaut-XL-v9](https://huggingface.co/RunDiffusion/Juggernaut-XL-v9)
+ - [SDXL-DPO](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)
+ - [JSDXL](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl)
+ - [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
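+
+ Conceptually, the merge is a weighted average of the source models' UNet parameters; `evosdxl_jp_v1.py` in this repository implements it. Below is a minimal, illustrative sketch of the core operation (`merge_linear`, `state_dicts`, and `coeffs` are placeholder names for this sketch, not part of the repository API):
+
+ ```python
+ import torch
+
+ def merge_linear(state_dicts, coeffs):
+     # Weighted sum of the same parameter across all source models,
+     # mirroring linear()/merge_models() in evosdxl_jp_v1.py.
+     merged = {}
+     for key in state_dicts[0]:
+         merged[key] = sum(c * sd[key] for c, sd in zip(coeffs, state_dicts))
+     return merged
+ ```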
+
+
+ ## Usage
+
+ Use the code below to get started with the model.
+
+
+ <details>
+ <summary> Click to expand </summary>
+
+ 1. Clone this model repository
+ ```
+ git clone https://huggingface.co/SakanaAI/EvoSDXL-JP-v1
+ ```
+ 2. Install the required packages
+ ```
+ cd EvoSDXL-JP-v1
+ pip install -r requirements.txt
+ ```
+ 3. Run
+ ```python
+ from evosdxl_jp_v1 import load_evosdxl_jp
+
+ prompt = "柴犬"  # "Shiba Inu"
+ pipe = load_evosdxl_jp(device="cuda")
+ images = pipe(prompt, num_inference_steps=4, guidance_scale=0).images
+ images[0].save("image.png")
+ ```
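+
+ Because SDXL-Lightning is distilled for few-step sampling without classifier-free guidance, the example runs with `num_inference_steps=4` and `guidance_scale=0`.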
+
+ </details>
+
+
+
+ ## Model Details
+
+ - **Developed by:** [Sakana AI](https://sakana.ai/)
+ - **Model type:** Diffusion-based text-to-image generative model
+ - **Language(s):** Japanese
+ - **Repository:** [SakanaAI/evolutionary-model-merge](https://github.com/SakanaAI/evolutionary-model-merge)
+ - **Paper:** https://arxiv.org/abs/2403.13187
+ - **Blog:** https://sakana.ai/evosdxl-jp/
+
+
+ ## License
+ The Python script included in this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
+ Please note that the model/pipeline produced by this script inherits the licenses of the source models.
+
+ ## Uses
+ This model is provided for research and development purposes only and should be considered an experimental prototype.
+ It is not intended for commercial use or deployment in mission-critical environments.
+ Use of this model is at the user's own risk, and its performance and outcomes are not guaranteed.
+ Sakana AI shall not be liable for any direct, indirect, special, incidental, or consequential damages, or any loss arising from the use of this model, regardless of the results obtained.
+ Users must fully understand the risks associated with the use of this model and use it at their own discretion.
+
+
+ ## Acknowledgement
+
+ We would like to thank the developers of the source models for their contributions and for making their work available.
+
+
+ ## Citation
+
+ ```bibtex
+ @misc{akiba2024evomodelmerge,
+       title = {Evolutionary Optimization of Model Merging Recipes},
+       author = {Takuya Akiba and Makoto Shing and Yujin Tang and Qi Sun and David Ha},
+       year = {2024},
+       eprint = {2403.13187},
+       archivePrefix = {arXiv},
+       primaryClass = {cs.NE}
+ }
+ ```
evosdxl_jp_v1.py ADDED
@@ -0,0 +1,204 @@
+ import os
+ from typing import List, Dict, Union
+ from tqdm import tqdm
+ import torch
+ import safetensors.torch
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoTokenizer, CLIPTextModelWithProjection
+ from diffusers import (
+     StableDiffusionXLPipeline,
+     UNet2DConditionModel,
+     EulerDiscreteScheduler,
+ )
+ from diffusers.loaders import LoraLoaderMixin
+
+ SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
+ JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
+ L_REPO = "ByteDance/SDXL-Lightning"
+
+
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
+     # Load either a safetensors checkpoint or a regular torch pickle.
+     file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+     if file_extension == "safetensors":
+         return safetensors.torch.load_file(checkpoint_file, device=device)
+     else:
+         return torch.load(checkpoint_file, map_location=device)
+
+
+ def load_from_pretrained(
+     repo_id,
+     filename="diffusion_pytorch_model.fp16.safetensors",
+     subfolder="unet",
+     device="cuda",
+ ) -> Dict[str, torch.Tensor]:
+     return load_state_dict(
+         hf_hub_download(
+             repo_id=repo_id,
+             filename=filename,
+             subfolder=subfolder,
+         ),
+         device=device,
+     )
+
+
+ def reshape_weight_task_tensors(task_tensors, weights):
+     """
+     Reshapes `weights` to match the shape of `task_tensors` by unsqueezing in the remaining dimensions.
+
+     Args:
+         task_tensors (`torch.Tensor`): The tensors that will be used to reshape `weights`.
+         weights (`torch.Tensor`): The tensor to be reshaped.
+
+     Returns:
+         `torch.Tensor`: The reshaped tensor.
+     """
+     new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim())
+     weights = weights.view(new_shape)
+     return weights
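+ # Example (illustrative shapes, not repository API): for weights of shape (4,)
+ # and stacked task_tensors of shape (4, 320, 4, 3, 3), the view becomes
+ # (4, 1, 1, 1, 1), so each model's weight broadcasts over its whole tensor.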
+
+
+ def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
+     """
+     Merge the task tensors using a `linear` (weighted-sum) combination.
+
+     Args:
+         task_tensors (`List[torch.Tensor]`): The task tensors to merge.
+         weights (`torch.Tensor`): The weights of the task tensors.
+
+     Returns:
+         `torch.Tensor`: The merged tensor.
+     """
+     task_tensors = torch.stack(task_tensors, dim=0)
+     # Weight each model's tensor, then sum across the model dimension.
+     weights = reshape_weight_task_tensors(task_tensors, weights)
+     weighted_task_tensors = task_tensors * weights
+     mixed_task_tensors = weighted_task_tensors.sum(dim=0)
+     return mixed_task_tensors
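+ # Example (illustrative values): linear([t_a, t_b], torch.tensor([0.3, 0.7]))
+ # returns 0.3 * t_a + 0.7 * t_b element-wise.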
+
+
+ def merge_models(
+     task_tensors,
+     weights,
+ ):
+     # Merge a list of state dicts key by key with the given linear weights.
+     keys = list(task_tensors[0].keys())
+     weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
+     state_dict = {}
+     for key in tqdm(keys, desc="Merging"):
+         # pop() frees each source tensor as soon as it has been merged.
+         w_list = [sd.pop(key) for sd in task_tensors]
+         state_dict[key] = linear(task_tensors=w_list, weights=weights)
+     return state_dict
+
+
+ def split_conv_attn(weights):
+     # Partition a state dict into attention projections and everything else,
+     # so the two groups can be merged with separate coefficients.
+     attn_tensors = {}
+     conv_tensors = {}
+     for key in list(weights.keys()):
+         if any(k in key for k in ["to_k", "to_q", "to_v", "to_out.0"]):
+             attn_tensors[key] = weights.pop(key)
+         else:
+             conv_tensors[key] = weights.pop(key)
+     return {"conv": conv_tensors, "attn": attn_tensors}
+
+
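+ # The recipe below merges in two stages: (1) the UNets of SDXL, SDXL-DPO,
+ # Juggernaut-XL-v9 and JSDXL are merged, with separate coefficients for the
+ # attention and conv/other groups, and the SDXL-Lightning 4-step LoRA is
+ # fused in; (2) the result is merged again with the SDXL-Lightning 4-step
+ # UNet and Juggernaut-XL-Lightning. The coefficients come from the
+ # evolutionary search described in the paper.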
+ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
+     # Stage 1: merge the four base UNets.
+     sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
+     dpo_weights = split_conv_attn(
+         load_from_pretrained(
+             "mhdang/dpo-sdxl-text2image-v1",
+             "diffusion_pytorch_model.safetensors",
+             device=device,
+         )
+     )
+     jn_weights = split_conv_attn(
+         load_from_pretrained("RunDiffusion/Juggernaut-XL-v9", device=device)
+     )
+     jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
+     tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
+     new_conv = merge_models(
+         [sd["conv"] for sd in tensors],
+         [
+             0.15928833971605916,
+             0.1032449268871776,
+             0.6503217149752791,
+             0.08714501842148402,
+         ],
+     )
+     new_attn = merge_models(
+         [sd["attn"] for sd in tensors],
+         [
+             0.1877279276437178,
+             0.20014114603909822,
+             0.3922685507065275,
+             0.2198623756106564,
+         ],
+     )
+     del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
+     torch.cuda.empty_cache()
+     unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
+     unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
+     unet.load_state_dict({**new_conv, **new_attn})
+     # Fuse the SDXL-Lightning 4-step LoRA into the merged UNet.
+     state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
+         L_REPO, weight_name="sdxl_lightning_4step_lora.safetensors"
+     )
+     LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
+     unet.fuse_lora(lora_scale=3.224682864579401)
+     # Stage 2: merge with the Lightning UNet and Juggernaut-XL-Lightning.
+     new_weights = split_conv_attn(unet.state_dict())
+     l_weights = split_conv_attn(
+         load_from_pretrained(
+             L_REPO,
+             "sdxl_lightning_4step_unet.safetensors",
+             subfolder=None,
+             device=device,
+         )
+     )
+     jnl_weights = split_conv_attn(
+         load_from_pretrained(
+             "RunDiffusion/Juggernaut-XL-Lightning",
+             "diffusion_pytorch_model.bin",
+             device=device,
+         )
+     )
+     tensors = [l_weights, jnl_weights, new_weights]
+     new_conv = merge_models(
+         [sd["conv"] for sd in tensors],
+         [0.47222002022088533, 0.48419531030361584, 0.04358466947549889],
+     )
+     new_attn = merge_models(
+         [sd["attn"] for sd in tensors],
+         [0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
+     )
+     new_weights = {**new_conv, **new_attn}
+     unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
+     unet.load_state_dict(new_weights)
+
+     # Use the Japanese text encoder and tokenizer from JSDXL.
+     text_encoder = CLIPTextModelWithProjection.from_pretrained(
+         JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         JSDXL_REPO, subfolder="tokenizer", use_fast=False
+     )
+
+     pipe = StableDiffusionXLPipeline.from_pretrained(
+         SDXL_REPO,
+         unet=unet,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         torch_dtype=torch.float16,
+         variant="fp16",
+     )
+     # Ensure the sampler uses "trailing" timesteps, as SDXL-Lightning requires.
+     pipe.scheduler = EulerDiscreteScheduler.from_config(
+         pipe.scheduler.config, timestep_spacing="trailing"
+     )
+     pipe = pipe.to(device, dtype=torch.float16)
+     return pipe
+
+
+ if __name__ == "__main__":
+     pipe: StableDiffusionXLPipeline = load_evosdxl_jp()
+     images = pipe("犬", num_inference_steps=4, guidance_scale=0).images  # "dog"
+     images[0].save("out.png")
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ diffusers==0.26.0
+ sentencepiece
+ transformers
+ accelerate