kelseye committed · verified
Commit 6322336 · 1 Parent(s): 07de83f

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/cat_ContentRef_2.jpg filter=lfs diff=lfs merge=lfs -text
+assets/cat_style_2.jpg filter=lfs diff=lfs merge=lfs -text
+assets/girl_ContentRef_1.jpg filter=lfs diff=lfs merge=lfs -text
+assets/girl_ContentRef_2.jpg filter=lfs diff=lfs merge=lfs -text
+assets/girl_style_1.jpg filter=lfs diff=lfs merge=lfs -text
+assets/girl_style_2.jpg filter=lfs diff=lfs merge=lfs -text
+assets/house_ContentRef_1.jpg filter=lfs diff=lfs merge=lfs -text
+assets/house_style_1.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,195 @@
---
license: apache-2.0
---
# Templates - Content Reference (FLUX.2-klein-base-4B)

This model is one of the Diffusion Templates series models open-sourced by [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). It encodes a reference image with a vision encoder, converts the resulting features into LoRA weights that are hot-loaded into the base model, and thereby fuses the reference content into generation guided by a natural-language prompt.

## Results

> **Prompt:** A cat is sitting on a stone.

| Template | Generated | Template | Generated |
|:---:|:---:|:---:|:---:|
| ![](./assets/cat_style_1.jpg) | ![](./assets/cat_ContentRef_1.jpg) | ![](./assets/cat_style_2.jpg) | ![](./assets/cat_ContentRef_2.jpg) |

---

> **Prompt:** A cozy wooden cottage in a lush green valley, white fluffy clouds in the sky, peaceful atmosphere.

| Template | Generated | Template | Generated |
|:---:|:---:|:---:|:---:|
| ![](./assets/house_style_1.jpg) | ![](./assets/house_ContentRef_1.jpg) | ![](./assets/house_style_2.jpg) | ![](./assets/house_ContentRef_2.jpg) |

---

> **Prompt:** A beautiful girl on an outdoor adventure.

| Template | Generated | Template | Generated |
|:---:|:---:|:---:|:---:|
| ![](./assets/girl_style_1.jpg) | ![](./assets/girl_ContentRef_1.jpg) | ![](./assets/girl_style_2.jpg) | ![](./assets/girl_ContentRef_2.jpg) |

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference (requires 40G GPU memory):

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image
import numpy as np
```

```python
# Load the FLUX.2 base pipeline (transformer, text encoder, VAE, tokenizer).
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
pipe.dit = pipe.enable_lora_hot_loading(pipe.dit)  # Important! The template model injects image-derived LoRA weights at runtime.
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ContentRef")],
)
# Download the example template images.
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_style_1.jpg"),
    }],
    negative_template_inputs=[{
        # A uniform gray image serves as the "empty" reference for classifier-free guidance.
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_1.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_style_2.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_2.jpg")
```
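
The template model itself (see `model.py` in this commit) accepts either a single reference image or a list of images, plus a `scale` factor that is folded into the generated LoRA weights. Assuming `TemplatePipeline` passes the keys of each `template_inputs` entry through to the template model's `process_inputs` (the `scale` key and the list-valued `image` below are inferred from `model.py`, not taken from an official example), blending two references might look like:

```python
# Hedged sketch: "scale" and the list-valued "image" follow the signatures of
# FLUX2Image2LoRAModel.forward / process_inputs in this commit's model.py.
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": [  # multiple references are averaged in LoRA space
            Image.open("data/examples/templates/image_style_1.jpg"),
            Image.open("data/examples/templates/image_style_2.jpg"),
        ],
        "scale": 0.8,  # <1 weakens the references' influence on generation
    }],
    negative_template_inputs=[{
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_blend.jpg")
```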

* Enable lazy loading and memory management (requires 24G GPU memory):

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image
import numpy as np
```

```python
# Stage weights between disk, CPU, and GPU: parameters are offloaded to disk
# when idle, held in float8_e4m3fn while onloaded, and upcast to bfloat16
# only for computation.
vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ContentRef")],
    lazy_loading=True,
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_style_1.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_1.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_style_2.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_2.jpg")
```
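
The `vram_limit` argument (in GiB) bounds how much model data stays resident on the GPU. `torch.cuda.mem_get_info` returns free and total device memory in bytes, so the expression above amounts to "total VRAM in GiB minus ~0.5 GiB of headroom":

```python
import torch

free_bytes, total_bytes = torch.cuda.mem_get_info("cuda")
vram_limit_gib = total_bytes / (1024 ** 3) - 0.5  # leave ~0.5 GiB for activations and overhead
print(f"vram_limit = {vram_limit_gib:.1f} GiB")
```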

## Training Code

After installing DiffSynth-Studio, use the following script to start training. For more information, please refer to the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-ContentRef/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ContentRef \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ContentRef/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-ContentRef:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-ContentRef_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters \
  --enable_lora_hot_loading
```
README_from_modelscope.md ADDED
@@ -0,0 +1,197 @@
---
frameworks:
- Pytorch
license: Apache License 2.0
tags: []
tasks:
- text-to-image-synthesis
---

# Templates - Content Reference (FLUX.2-klein-base-4B)

This model is one of the Diffusion Templates series models open-sourced by [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). It encodes a reference image with a vision encoder, converts the resulting features into LoRA weights that are hot-loaded into the base model, and thereby fuses the reference content into generation guided by a natural-language prompt.

## Results

> **Prompt:** A cat is sitting on a stone.

| Template | Generated | Template | Generated |
|:---:|:---:|:---:|:---:|
| ![](./assets/cat_style_1.jpg) | ![](./assets/cat_ContentRef_1.jpg) | ![](./assets/cat_style_2.jpg) | ![](./assets/cat_ContentRef_2.jpg) |

---

> **Prompt:** A cozy wooden cottage in a lush green valley, white fluffy clouds in the sky, peaceful atmosphere.

| Template | Generated | Template | Generated |
|:---:|:---:|:---:|:---:|
| ![](./assets/house_style_1.jpg) | ![](./assets/house_ContentRef_1.jpg) | ![](./assets/house_style_2.jpg) | ![](./assets/house_ContentRef_2.jpg) |

---

> **Prompt:** A beautiful girl on an outdoor adventure.

| Template | Generated | Template | Generated |
|:---:|:---:|:---:|:---:|
| ![](./assets/girl_style_1.jpg) | ![](./assets/girl_ContentRef_1.jpg) | ![](./assets/girl_style_2.jpg) | ![](./assets/girl_ContentRef_2.jpg) |

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference (requires 40G GPU memory):

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image
import numpy as np

pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
pipe.dit = pipe.enable_lora_hot_loading(pipe.dit)  # Important!
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ContentRef")],
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_style_1.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_1.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_style_2.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_2.jpg")
```

* Enable lazy loading and memory management (requires 24G GPU memory):

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image
import numpy as np

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ContentRef")],
    lazy_loading=True,
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_style_1.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_1.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_style_2.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8) + 128),
    }],
)
image.save("image_ContentRef_2.jpg")
```

## Training Code

After installing DiffSynth-Studio, use the following script to start training. For more information, please refer to the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-ContentRef/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ContentRef \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ContentRef/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-ContentRef:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-ContentRef_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters \
  --enable_lora_hot_loading
```
assets/cat_ContentRef_1.jpg ADDED
assets/cat_ContentRef_2.jpg ADDED
Git LFS Details
  • SHA256: a5ffbfa561bdf44e344fc2d24d5219549b75c949448133f623ab6190e56a3615
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
assets/cat_style_1.jpg ADDED
assets/cat_style_2.jpg ADDED
Git LFS Details
  • SHA256: ab1f138570b5df2ced584c373e570ca9b2b08a55ba6e5f42e77a0e6d09e6e6ac
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
assets/girl_ContentRef_1.jpg ADDED
Git LFS Details
  • SHA256: 36ec31d3d3689290199fa201cb761994d1462b6a0a935d4472d27be873a45756
  • Pointer size: 131 Bytes
  • Size of remote file: 216 kB
assets/girl_ContentRef_2.jpg ADDED
Git LFS Details
  • SHA256: bff566792cde66cf43b5ad70cf7eb35ee7589b4fdb7978953ab7b1d218441f96
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
assets/girl_style_1.jpg ADDED
Git LFS Details
  • SHA256: 57c60a71fbd4f771dfb03fa811089cfa4f96a3b9825ffd1ad4cba603d94e77b7
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
assets/girl_style_2.jpg ADDED
Git LFS Details
  • SHA256: 4b5e068a8b3709ec1e804e23a15ab5a7937c9ecd4acc9c0e2ca1b716883e8c3c
  • Pointer size: 131 Bytes
  • Size of remote file: 228 kB
assets/house_ContentRef_1.jpg ADDED
Git LFS Details
  • SHA256: df6b7b01f0b3ab961d7124cb2be04399b65decf5afe83534f8b1fc52c77c6c6e
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
assets/house_ContentRef_2.jpg ADDED
assets/house_style_1.jpg ADDED
Git LFS Details
  • SHA256: 615cdb1fa5f2b200e9b8e60af85f040a09c180932df13118b07236776f0cfc95
  • Pointer size: 131 Bytes
  • Size of remote file: 150 kB
assets/house_style_2.jpg ADDED
configuration.json ADDED
@@ -0,0 +1 @@
+{"framework":"Pytorch","task":"text-to-image-synthesis"}
model.py ADDED
@@ -0,0 +1,228 @@
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor
from PIL import Image
import torch


def merge_lora_weight(tensors_A, tensors_B):
    # Stacking A along dim 0 and B along dim 1 turns k rank-r LoRAs
    # into a single rank-(k*r) LoRA whose delta weight is their sum.
    lora_A = torch.concat(tensors_A, dim=0)
    lora_B = torch.concat(tensors_B, dim=1)
    return lora_A, lora_B


def merge_lora(loras, alpha=1):
    lora_merged = {}
    keys = [i for i in loras[0].keys() if ".lora_A." in i]
    for key in keys:
        tensors_A = [lora[key] for lora in loras]
        tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras]
        lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
        lora_merged[key] = lora_A * alpha
        lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B
    return lora_merged


class Siglip2ImageEncoder(SiglipVisionTransformer):
    # SigLIP2-G/384 vision tower used to embed the reference image.
    def __init__(self):
        config = SiglipVisionConfig(
            attention_dropout=0.0,
            dtype="float32",
            hidden_act="gelu_pytorch_tanh",
            hidden_size=1536,
            image_size=384,
            intermediate_size=6144,
            layer_norm_eps=1e-06,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_channels=3,
            num_hidden_layers=40,
            patch_size=16,
            transformers_version="4.56.1",
            _attn_implementation="sdpa",
        )
        # For compatibility with transformers
        import sys
        sys.modules["template_model"] = None

        super().__init__(config)
        self.processor = SiglipImageProcessor(
            do_convert_rgb=None,
            do_normalize=True,
            do_rescale=True,
            do_resize=True,
            image_mean=[0.5, 0.5, 0.5],
            image_processor_type="SiglipImageProcessor",
            image_std=[0.5, 0.5, 0.5],
            processor_class="SiglipProcessor",
            resample=2,
            rescale_factor=0.00392156862745098,
            size={"height": 384, "width": 384},
        )

    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda", query_embs=None):
        pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
        output_attentions = False
        output_hidden_states = False
        interpolate_pos_encoding = False

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)

        if query_embs is None:
            pooler_output = self.head(last_hidden_state)
        else:
            # Attention pooling with learned queries: one query per target LoRA layer.
            hidden_state = self.head.attention(query_embs, last_hidden_state, last_hidden_state)[0]
            residual = hidden_state
            hidden_state = self.head.layernorm(hidden_state)
            pooler_output = residual + self.head.mlp(hidden_state)
        return pooler_output


class CompressedMLP(torch.nn.Module):
    # Two linear layers through a low-dimensional bottleneck (no activation in between).
    def __init__(self, in_dim, mid_dim, out_dim, bias=False):
        super().__init__()
        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
        self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.proj_out(x)
        return x


class ImageEmbeddingToLoraMatrix(torch.nn.Module):
    # Projects one pooled image embedding to the (A, B) matrices of a single LoRA layer.
    def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
        super().__init__()
        self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
        self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
        self.lora_a_dim = lora_a_dim
        self.lora_b_dim = lora_b_dim
        self.rank = rank

    def forward(self, x):
        lora_a = self.proj_a(x).view(self.rank, self.lora_a_dim)
        lora_b = self.proj_b(x).view(self.lora_b_dim, self.rank)
        return lora_a, lora_b


class FLUX2Image2LoRAQueries(torch.nn.Module):
    # Learned query embeddings for the attention-pooling head, one per LoRA layer.
    def __init__(self, length, dim):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.randn((1, length, dim)))

    def forward(self):
        return self.weights


class FLUX2Image2LoRAModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # LoRA is injected into the fused QKV/MLP input projection and the output
        # projection of each of the 20 single transformer blocks.
        self.lora_patterns = [
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_qkv_mlp_proj",
                "num_blocks": 20,
                "dim_in": 3072,
                "dim_out": 27648,
            },
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_out",
                "num_blocks": 20,
                "dim_in": 12288,
                "dim_out": 3072,
            },
        ]
        self.image_encoder = Siglip2ImageEncoder()
        self.parse_lora_layers(
            self.lora_patterns,
            dim_image=1536,
            compress_dim=256,
            rank=4,
        )
        self.query_embs = FLUX2Image2LoRAQueries(len(self.layers), 1536)

    def parse_lora_layers(self, lora_patterns, dim_image, compress_dim, rank):
        names = []
        layers = []
        for lora_pattern in lora_patterns:
            for block_id in range(lora_pattern["num_blocks"]):
                name = lora_pattern["name"].format(block_id=block_id)
                layer = ImageEmbeddingToLoraMatrix(dim_image, compress_dim, lora_pattern["dim_in"], lora_pattern["dim_out"], rank)
                names.append(name)
                layers.append(layer)
        self.names = names
        self.layers = torch.nn.ModuleList(layers)

    @torch.no_grad()
    def process_inputs(self, image, scale=1, **kwargs):
        return {"image": image, "scale": scale}

    def forward_single_image(self, image):
        # One pooled embedding per LoRA layer, then project each to (A, B).
        embs = self.image_encoder(image, query_embs=self.query_embs.weights, device=self.query_embs.weights.device)
        embs = embs.chunk(len(self.layers), dim=1)
        lora = {}
        for emb, name, layer in zip(embs, self.names, self.layers):
            lora_a, lora_b = layer(emb)
            lora[f"{name}.lora_A.default.weight"] = lora_a
            lora[f"{name}.lora_B.default.weight"] = lora_b
        return {"lora": lora}

    def forward(self, image, scale=1, **kwargs):
        if not isinstance(image, list):
            image = [image]
        # Multiple reference images are averaged in LoRA space and scaled.
        loras = [self.forward_single_image(i)["lora"] for i in image]
        lora = merge_lora(loras, alpha=1 / len(loras) * scale)
        return {"lora": lora}


class DataAnnotator:
    def __init__(self):
        from diffsynth.core import UnifiedDataset
        self.image_operator = UnifiedDataset.default_image_operator(
            base_path="",  # If your dataset contains relative paths, please specify the root path here.
            max_pixels=1024*1024,
            height_division_factor=16,
            width_division_factor=16,
        )

    def __call__(self, image, **kwargs):
        image = self.image_operator(image)
        return {"image": image}


def initialize_model_weights():
    from diffsynth import ModelConfig, load_state_dict
    from safetensors.torch import save_file
    import os
    # Initialize the image encoder from the pretrained SigLIP2-G384 checkpoint.
    config = ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
    config.download_if_necessary()
    state_dict = load_state_dict(config.path, torch_dtype=torch.bfloat16, device="cuda")
    model = FLUX2Image2LoRAModel().to(dtype=torch.bfloat16, device="cuda")
    model.image_encoder.load_state_dict(state_dict)
    # Replicate the pretrained pooling probe as the initial query for every LoRA layer.
    query_embs = {"weights": torch.concat([state_dict["head.probe"]] * len(model.layers), dim=1)}
    model.query_embs.load_state_dict(query_embs, strict=False)
    # Zero the B-matrix generators so the generated LoRAs start as a no-op,
    # and damp the remaining generator weights.
    lora_weights = {}
    for name, param in model.named_parameters():
        if ".proj_b.proj_out." in name:
            lora_weights[name] = param * 0
        elif ".proj_b." in name or ".proj_a." in name:
            lora_weights[name] = param * 0.3
    model.load_state_dict(lora_weights, strict=False)
    print(sum(p.numel() for p in model.parameters()))  # total parameter count
    save_file(model.state_dict(), os.path.join(os.path.dirname(__file__), "model.safetensors"))


TEMPLATE_MODEL = FLUX2Image2LoRAModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator
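
For intuition about `merge_lora` above: concatenating k rank-r factor pairs produces a single rank-k·r LoRA whose delta weight is the scaled sum of the individual deltas. Below is a standalone sanity check with dummy tensors; it is an illustrative sketch, not part of the repository, and it assumes this `model.py` is importable (small dimensions stand in for the real 3072/27648 layer sizes):

```python
import torch
from model import merge_lora  # this commit's model.py

in_dim, out_dim, rank = 64, 128, 4  # real layers are e.g. 3072 -> 27648

def random_lora():
    return {
        "layer.lora_A.default.weight": torch.randn(rank, in_dim),
        "layer.lora_B.default.weight": torch.randn(out_dim, rank),
    }

lora1, lora2 = random_lora(), random_lora()
merged = merge_lora([lora1, lora2], alpha=0.5)  # alpha = 1/len(loras) * scale, as in forward()

# Two rank-4 LoRAs become one rank-8 LoRA...
assert merged["layer.lora_A.default.weight"].shape == (2 * rank, in_dim)
assert merged["layer.lora_B.default.weight"].shape == (out_dim, 2 * rank)

# ...whose delta weight is the scaled sum of the individual deltas.
def delta(lora):
    return lora["layer.lora_B.default.weight"] @ lora["layer.lora_A.default.weight"]

assert torch.allclose(delta(merged), 0.5 * (delta(lora1) + delta(lora2)), atol=1e-5)
```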
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7f6cc09ae8693e2083ae7f7f328abf473801a1625d7cb3f10ec1a941d26deef
+size 4277893528