kelseye committed
Commit 82d3188 · verified · 1 parent: cc4cf34

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/ball.png filter=lfs diff=lfs merge=lfs -text
+assets/cat_ControlNet_magic.jpg filter=lfs diff=lfs merge=lfs -text
+assets/cat_ControlNet_sunshine.jpg filter=lfs diff=lfs merge=lfs -text
+assets/fox.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,189 @@
---
license: apache-2.0
---
# Templates - Structural Control (FLUX.2-klein-base-4B)

This model is part of the open-source Diffusion Templates series from [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). It is a ControlNet-style control model that precisely guides the spatial structure, object outlines, and perspective of generated images through an input reference image.

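
The condition image is a structural map; the examples below use depth maps. This card does not specify a preprocessor, so as one hedged option, a depth condition can be derived with the `transformers` depth-estimation pipeline (the pipeline task is real, but the `Intel/dpt-large` checkpoint and file names here are illustrative assumptions, not part of this repo):

```python
# Sketch: derive a depth condition image from an arbitrary photo.
# "depth-estimation" is a standard transformers pipeline task; the
# Intel/dpt-large checkpoint and file names are illustrative assumptions.
from transformers import pipeline
from PIL import Image

depth_estimator = pipeline("depth-estimation", model="Intel/dpt-large")
source = Image.open("photo.jpg")
depth = depth_estimator(source)["depth"]  # PIL image of per-pixel depth
depth.convert("RGB").save("image_depth.jpg")
```
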
## Result Examples

|Condition|Prompt: A cat is sitting on a stone, bathed in bright sunshine.|Prompt: A cat is sitting on a stone, surrounded by colorful magical particles.|
|-|-|-|
|![](./assets/cat_image_depth.jpg)|![](./assets/cat_ControlNet_sunshine.jpg)|![](./assets/cat_ControlNet_magic.jpg)|

|Condition|Prompt: A lovely fox wearing a casual green shirt, sitting in a cafe bar, smiling gently, peaceful anime aesthetic.|Prompt: A cute 3D rendered anthropomorphic fox character wearing a bright green shirt, sitting in a cozy magical tavern, smiling happily.|
|-|-|-|
|![](./assets/fox.png)|![](./assets/fox_ControlNet_sunshine.jpg)|![](./assets/fox_ControlNet_magic.jpg)|

|Condition|Prompt: A photorealistic glass crystal ball containing a tiny, dreamy scene of a castle, a large tree, and a girl, soft warm lighting, detailed texture.|Prompt: A cute 3D Pixar style scene inside a crystal ball, featuring a girl standing by a large tree with a castle in the background.|
|-|-|-|
|![](./assets/ball.png)|![](./assets/ball_ControlNet_sunshine.jpg)|![](./assets/ball_ControlNet_magic.jpg)|

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference (requires 40 GB of GPU memory)

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image
```

```python
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ControlNet")],
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone, bathed in bright sunshine.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "A cat is sitting on a stone, bathed in bright sunshine.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "",
    }],
)
image.save("image_ControlNet_sunshine.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone, surrounded by colorful magical particles.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "A cat is sitting on a stone, surrounded by colorful magical particles.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "",
    }],
)
image.save("image_ControlNet_magic.jpg")
```
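
The two `template(...)` calls above differ only in the prompt. When sweeping several prompts over one condition image, a small loop using exactly the same API avoids the duplication:

```python
# Same pipeline objects as above; only the prompt varies per output file.
condition = Image.open("data/examples/templates/image_depth.jpg")
prompts = {
    "image_ControlNet_sunshine.jpg": "A cat is sitting on a stone, bathed in bright sunshine.",
    "image_ControlNet_magic.jpg": "A cat is sitting on a stone, surrounded by colorful magical particles.",
}
for filename, prompt in prompts.items():
    image = template(
        pipe,
        prompt=prompt,
        seed=0, cfg_scale=4, num_inference_steps=50,
        template_inputs=[{"image": condition, "prompt": prompt}],
        negative_template_inputs=[{"image": condition, "prompt": ""}],
    )
    image.save(filename)
```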

* Enable lazy loading and memory management (requires 24 GB of GPU memory)

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image
```

```python
# Offloading configuration: idle weights rest on disk, are cached on the CPU
# in FP8 (float8_e4m3fn), and are upcast to bfloat16 on the GPU for computation.
vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ControlNet")],
    lazy_loading=True,
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone, bathed in bright sunshine.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "A cat is sitting on a stone, bathed in bright sunshine.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "",
    }],
)
image.save("image_ControlNet_sunshine.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone, surrounded by colorful magical particles.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "A cat is sitting on a stone, surrounded by colorful magical particles.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "",
    }],
)
image.save("image_ControlNet_magic.jpg")
```
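
`torch.cuda.mem_get_info` returns free and total device memory in bytes, so the `vram_limit` above budgets the card's total capacity minus a 0.5 GB safety margin; passing a smaller constant (in GB) pins a tighter offloading budget instead.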

## Training Code

After installing DiffSynth-Studio, use the following script to start training. For more information, refer to the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-ControlNet/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ControlNet \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ControlNet/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-ControlNet:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-ControlNet_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
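
The metadata schema is defined by the example dataset downloaded above; the record below is a hypothetical illustration of how `--extra_inputs "template_inputs"` pairs a target image and prompt with a structural condition. Field names are assumptions — inspect the downloaded `metadata.jsonl` for the authoritative format.

```python
# Hypothetical metadata.jsonl record; field names are assumptions.
# Check data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ControlNet/
# metadata.jsonl for the real schema.
import json

record = {
    "image": "images/0001.jpg",                # training target image
    "prompt": "A cat is sitting on a stone.",  # target caption
    "template_inputs": [{                      # enabled via --extra_inputs
        "image": "conditions/0001_depth.jpg",  # structural condition image
        "prompt": "A cat is sitting on a stone.",
    }],
}
with open("metadata.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")
```
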
README_from_modelscope.md ADDED
@@ -0,0 +1,191 @@
---
frameworks:
- Pytorch
license: Apache License 2.0
tags: []
tasks:
- text-to-image-synthesis
---

# Templates - Structural Control (FLUX.2-klein-base-4B)

This model is part of the open-source Diffusion Templates series from [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). It is a ControlNet-style control model that precisely guides the spatial structure, object outlines, and perspective of generated images through an input reference image.

## Result Examples

|Condition|Prompt: A cat is sitting on a stone, bathed in bright sunshine.|Prompt: A cat is sitting on a stone, surrounded by colorful magical particles.|
|-|-|-|
|![](./assets/cat_image_depth.jpg)|![](./assets/cat_ControlNet_sunshine.jpg)|![](./assets/cat_ControlNet_magic.jpg)|

|Condition|Prompt: A lovely fox wearing a casual green shirt, sitting in a cafe bar, smiling gently, peaceful anime aesthetic.|Prompt: A cute 3D rendered anthropomorphic fox character wearing a bright green shirt, sitting in a cozy magical tavern, smiling happily.|
|-|-|-|
|![](./assets/fox.png)|![](./assets/fox_ControlNet_sunshine.jpg)|![](./assets/fox_ControlNet_magic.jpg)|

|Condition|Prompt: A photorealistic glass crystal ball containing a tiny, dreamy scene of a castle, a large tree, and a girl, soft warm lighting, detailed texture.|Prompt: A cute 3D Pixar style scene inside a crystal ball, featuring a girl standing by a large tree with a castle in the background.|
|-|-|-|
|![](./assets/ball.png)|![](./assets/ball_ControlNet_sunshine.jpg)|![](./assets/ball_ControlNet_magic.jpg)|

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference (requires 40 GB of GPU memory)

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image

pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ControlNet")],
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone, bathed in bright sunshine.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "A cat is sitting on a stone, bathed in bright sunshine.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "",
    }],
)
image.save("image_ControlNet_sunshine.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone, surrounded by colorful magical particles.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "A cat is sitting on a stone, surrounded by colorful magical particles.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "",
    }],
)
image.save("image_ControlNet_magic.jpg")
```

* Enable lazy loading and memory management (requires 24 GB of GPU memory)

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ControlNet")],
    lazy_loading=True,
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="A cat is sitting on a stone, bathed in bright sunshine.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "A cat is sitting on a stone, bathed in bright sunshine.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "",
    }],
)
image.save("image_ControlNet_sunshine.jpg")
image = template(
    pipe,
    prompt="A cat is sitting on a stone, surrounded by colorful magical particles.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "A cat is sitting on a stone, surrounded by colorful magical particles.",
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_depth.jpg"),
        "prompt": "",
    }],
)
image.save("image_ControlNet_magic.jpg")
```

## Training Code

After installing DiffSynth-Studio, use the following script to start training. For more information, refer to the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-ControlNet/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ControlNet \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-ControlNet/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-ControlNet:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-ControlNet_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
assets/ball.png ADDED

Git LFS Details

  • SHA256: 495f1887423481b8bfc506fd0a150312da38ec27b2d1a07242005240b168213c
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
assets/ball_ControlNet_magic.jpg ADDED
assets/ball_ControlNet_sunshine.jpg ADDED
assets/cat_ControlNet_magic.jpg ADDED

Git LFS Details

  • SHA256: 132bd0bb5d9c11c532d4c0bc32315e00a29d1c3b179e6f5864e8ecd29a00167a
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
assets/cat_ControlNet_sunshine.jpg ADDED

Git LFS Details

  • SHA256: 084ea6b6a8da0a33a8bc14aa4a777e237013e4b0292d0c72d5e483fdf498d2d2
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
assets/cat_image_depth.jpg ADDED
assets/fox.png ADDED

Git LFS Details

  • SHA256: 83df64c08840e0da177c91194cbcb408fa761ce1023368113b1f79894180eb77
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB
assets/fox_ControlNet_magic.jpg ADDED
assets/fox_ControlNet_sunshine.jpg ADDED
configuration.json ADDED
@@ -0,0 +1 @@
{"framework":"Pytorch","task":"text-to-image-synthesis"}
model.py ADDED
@@ -0,0 +1,333 @@
```python
from typing import Any, Dict, List, Optional, Tuple, Union
import torch, math
import torch.nn as nn
from einops import rearrange
from diffsynth.core.attention import attention_forward
from diffsynth.core.gradient import gradient_checkpoint_forward
from diffsynth.models.flux2_dit import apply_rotary_emb, Flux2PosEmbed
from diffsynth.models.general_modules import get_timestep_embedding


class AdaLayerNormContinuous(nn.Module):
    def __init__(self, dim_in, dim_out, eps=1e-6):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out * 2, bias=False)
        self.norm = nn.LayerNorm(dim_in, eps=eps, elementwise_affine=False, bias=False)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(torch.nn.functional.silu(conditioning_embedding)).chunk(2, dim=1)
        x = self.norm(x) * (1 + scale) + shift
        return x


class Flux2FeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear_in = nn.Linear(dim, dim * 3 * 2, bias=False)
        self.linear_out = nn.Linear(dim * 3, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.linear_in(x).chunk(2, dim=-1)
        x = torch.nn.functional.silu(x1) * x2
        x = self.linear_out(x)
        return x


class Flux2TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, eps=1e-6):
        super().__init__()
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_ff = Flux2FeedForward(dim)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_ff = Flux2FeedForward(dim)

        self.num_heads = num_heads
        self.img_to_qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
        self.img_norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.img_norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.img_to_out = torch.nn.Linear(dim, dim, bias=False)
        self.txt_to_qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
        self.txt_norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.txt_norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.txt_to_out = torch.nn.Linear(dim, dim, bias=False)

    def attention(self, img: torch.Tensor, txt: torch.Tensor, image_rotary_emb: torch.Tensor, **kwargs):
        img_q, img_k, img_v = self.img_to_qkv(img).chunk(3, dim=-1)
        txt_q, txt_k, txt_v = self.txt_to_qkv(txt).chunk(3, dim=-1)
        img_q, img_k, img_v, txt_q, txt_k, txt_v = tuple(map(lambda x: x.unflatten(-1, (self.num_heads, -1)), (img_q, img_k, img_v, txt_q, txt_k, txt_v)))
        img_q = self.img_norm_q(img_q)
        img_k = self.img_norm_k(img_k)
        txt_q = self.txt_norm_q(txt_q)
        txt_k = self.txt_norm_k(txt_k)

        q = torch.cat([txt_q, img_q], dim=1)
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)
        q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
        k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        img = attention_forward(q, k, v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s (n d)")
        txt, img = img.split_with_sizes([txt.shape[1], img.shape[1] - txt.shape[1]], dim=1)
        txt = self.txt_to_out(txt)
        img = self.img_to_out(img)
        return img, txt, (k, v)

    def forward(self, img, txt, temb_mod_params_img, temb_mod_params_txt, image_rotary_emb):
        (img_shift_msa, img_scale_msa, img_gate_msa), (img_shift_mlp, img_scale_mlp, img_gate_mlp) = temb_mod_params_img
        (txt_shift_msa, txt_scale_msa, txt_gate_msa), (txt_shift_mlp, txt_scale_mlp, txt_gate_mlp) = temb_mod_params_txt

        norm_img = (1 + img_scale_msa) * self.img_norm1(img) + img_shift_msa
        norm_txt = (1 + txt_scale_msa) * self.txt_norm1(txt) + txt_shift_msa
        img_attn_out, txt_attn_out, kv_cache = self.attention(norm_img, norm_txt, image_rotary_emb)

        img = img + img_gate_msa * img_attn_out
        norm_img = self.img_norm2(img) * (1 + img_scale_mlp) + img_shift_mlp
        img = img + img_gate_mlp * self.img_ff(norm_img)

        txt = txt + txt_gate_msa * txt_attn_out
        norm_txt = self.txt_norm2(txt) * (1 + txt_scale_mlp) + txt_shift_mlp
        txt = txt + txt_gate_mlp * self.txt_ff(norm_txt)
        return txt, img, kv_cache

+
96
+ class Flux2SingleTransformerBlock(nn.Module):
97
+ def __init__(self, dim, num_heads, eps: float = 1e-6):
98
+ super().__init__()
99
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
100
+ self.dim = dim
101
+ self.num_heads = num_heads
102
+ self.norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps, elementwise_affine=True)
103
+ self.norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps, elementwise_affine=True)
104
+ self.to_qkv_mlp_proj = torch.nn.Linear(dim, dim * 3 + dim * 3 * 2, bias=False)
105
+ self.to_out = torch.nn.Linear(dim + dim * 3, dim, bias=False)
106
+
107
+ def attention(self, x: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
108
+ x = self.to_qkv_mlp_proj(x)
109
+ qkv, mlp_x = torch.split(x, [3 * self.dim, self.dim * 3 * 2], dim=-1)
110
+ q, k, v = tuple(map(lambda x: x.unflatten(-1, (self.num_heads, -1)), qkv.chunk(3, dim=-1)))
111
+
112
+ q = self.norm_q(q)
113
+ k = self.norm_k(k)
114
+ q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
115
+ k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)
116
+ x = attention_forward(q, k, v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s (n d)")
117
+
118
+ x1, x2 = mlp_x.chunk(2, dim=-1)
119
+ x = torch.cat([x, torch.nn.functional.silu(x1) * x2], dim=-1)
120
+ x = self.to_out(x)
121
+ return x, (k, v)
122
+
123
+ def forward(self, x, temb_mod_params, image_rotary_emb):
124
+ mod_shift, mod_scale, mod_gate = temb_mod_params
125
+ norm_x = (1 + mod_scale) * self.norm(x) + mod_shift
126
+ attn_output, kv_cache = self.attention(x=norm_x, image_rotary_emb=image_rotary_emb,)
127
+ x = x + mod_gate * attn_output
128
+ return x, kv_cache
129
+
130
+
131
+ class Flux2TimestepGuidanceEmbeddings(nn.Module):
132
+ def __init__(self, dim_in, dim_out):
133
+ super().__init__()
134
+ self.dim_in = dim_in
135
+ self.timestep_embedder = torch.nn.Sequential(nn.Linear(dim_in, dim_out, bias=False), nn.SiLU(), nn.Linear(dim_out, dim_out, bias=False))
136
+
137
+ def forward(self, timestep: torch.Tensor) -> torch.Tensor:
138
+ timesteps_proj = get_timestep_embedding(timestep, self.dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
139
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))
140
+ return timesteps_emb
141
+
142
+
143
+ class Flux2Modulation(nn.Module):
144
+ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
145
+ super().__init__()
146
+ self.mod_param_sets = mod_param_sets
147
+ self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
148
+
149
+ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
150
+ mod = torch.nn.functional.silu(temb)
151
+ mod = self.linear(mod)
152
+ mod = mod.unsqueeze(1)
153
+ mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
154
+ return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
155
+
156
+
157
+ class Flux2DiTVariantModel(torch.nn.Module):
158
+ def __init__(
159
+ self,
160
+ patch_size: int = 1,
161
+ in_channels: int = 128,
162
+ out_channels: Optional[int] = None,
163
+ num_layers: int = 5,
164
+ num_single_layers: int = 20,
165
+ attention_head_dim: int = 128,
166
+ num_attention_heads: int = 24,
167
+ joint_attention_dim: int = 7680,
168
+ timestep_guidance_channels: int = 256,
169
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
170
+ rope_theta: int = 2000,
171
+ ):
172
+ super().__init__()
173
+ self.out_channels = out_channels or in_channels
174
+ self.inner_dim = num_attention_heads * attention_head_dim
175
+
176
+ # 1. Sinusoidal positional embedding for RoPE on image and text tokens
177
+ self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
178
+
179
+ # 2. Combined timestep + guidance embedding
180
+ self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
181
+ dim_in=timestep_guidance_channels,
182
+ dim_out=self.inner_dim,
183
+ )
184
+
185
+ # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
186
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
187
+ self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
188
+ self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
189
+ # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
190
+ self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
191
+
192
+ # 4. Input projections
193
+ self.img_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
194
+ self.txt_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
195
+
196
+ # 5. Double Stream Transformer Blocks
197
+ self.transformer_blocks = nn.ModuleList([Flux2TransformerBlock(dim=self.inner_dim, num_heads=num_attention_heads) for _ in range(num_layers)])
198
+
199
+ # 6. Single Stream Transformer Blocks
200
+ self.single_transformer_blocks = nn.ModuleList([Flux2SingleTransformerBlock(dim=self.inner_dim, num_heads=num_attention_heads) for _ in range(num_single_layers)])
201
+
202
+ # 7. Output layers
203
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim)
204
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
205
+
206
+ def prepare_static_parameters(self, img, txt):
207
+ timestep = torch.zeros((1,), dtype=txt.dtype, device=txt.device)
208
+ img_ids = []
209
+ for latent_id, latent in enumerate(img):
210
+ _, _, height, width = latent.shape
211
+ x_ids = torch.cartesian_prod(torch.tensor([(latent_id + 1) * 10]), torch.arange(height), torch.arange(width), torch.arange(1))
212
+ img_ids.append(x_ids)
213
+ img_ids = torch.cat(img_ids, dim=0).to(txt.device)
214
+ txt_ids = torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(txt.shape[1])).to(txt.device)
215
+ return timestep, img_ids, txt_ids
216
+
217
+ def patchify(self, img):
218
+ img_ = []
219
+ for latent in img:
220
+ latent = rearrange(latent, "B C H W -> B (H W) C")
221
+ img_.append(latent)
222
+ img_ = torch.concat(img_, dim=1)
223
+ return img_
224
+
225
+ @torch.no_grad()
226
+ def process_inputs(
227
+ self,
228
+ pipe, image, prompt,
229
+ **kwargs
230
+ ):
231
+ images = image
232
+ if not isinstance(images, list):
233
+ images = [images]
234
+ pipe.load_models_to_device(["vae"])
235
+ kv_cache_input_latents = [pipe.vae.encode(pipe.preprocess_image(image)) for image in images]
236
+ prompt_emb_unit = [unit for unit in pipe.units if unit.__class__.__name__ == "Flux2Unit_Qwen3PromptEmbedder"][0]
237
+ kv_cache_prompt_emb = prompt_emb_unit.process(pipe, prompt)["prompt_embeds"]
238
+ pipe.load_models_to_device([])
239
+ return {
240
+ "kv_cache_input_latents": kv_cache_input_latents,
241
+ "kv_cache_prompt_emb": kv_cache_prompt_emb,
242
+ }
243
+
244
+ def forward(
245
+ self,
246
+ kv_cache_input_latents,
247
+ kv_cache_prompt_emb,
248
+ use_gradient_checkpointing=False,
249
+ use_gradient_checkpointing_offload=False,
250
+ **kwargs,
251
+ ):
252
+ img = kv_cache_input_latents
253
+ txt = kv_cache_prompt_emb
254
+ num_txt_tokens = txt.shape[1]
255
+
256
+ # 1. Calculate timestep embedding and modulation parameters
257
+ timestep, img_ids, txt_ids = self.prepare_static_parameters(img, txt)
258
+ img = self.patchify(img)
259
+
260
+ temb = self.time_guidance_embed(timestep)
261
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
262
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
263
+ single_stream_mod = self.single_stream_modulation(temb)[0]
264
+
265
+ # 2. Input projection for image (img) and conditioning text (txt)
266
+ img = self.img_embedder(img)
267
+ txt = self.txt_embedder(txt)
268
+
269
+ # 3. Calculate RoPE embeddings from image and text tokens
270
+ image_rotary_emb = self.pos_embed(img_ids)
271
+ text_rotary_emb = self.pos_embed(txt_ids)
272
+ concat_rotary_emb = (
273
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
274
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
275
+ )
276
+
277
+ # 4. Double Stream Transformer Blocks
278
+ kv_cache = {}
279
+ for block_id, block in enumerate(self.transformer_blocks):
280
+ txt, img, kv_cache_ = gradient_checkpoint_forward(
281
+ block,
282
+ use_gradient_checkpointing=use_gradient_checkpointing,
283
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
284
+ img=img,
285
+ txt=txt,
286
+ temb_mod_params_img=double_stream_mod_img,
287
+ temb_mod_params_txt=double_stream_mod_txt,
288
+ image_rotary_emb=concat_rotary_emb,
289
+ )
290
+ kv_cache[f"double_{block_id}"] = kv_cache_
291
+ # Concatenate text and image streams for single-block inference
292
+ img = torch.cat([txt, img], dim=1)
293
+
294
+ # 5. Single Stream Transformer Blocks
295
+ for block_id, block in enumerate(self.single_transformer_blocks):
296
+ img, kv_cache_ = gradient_checkpoint_forward(
297
+ block,
298
+ use_gradient_checkpointing=use_gradient_checkpointing,
299
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
300
+ x=img,
301
+ temb_mod_params=single_stream_mod,
302
+ image_rotary_emb=concat_rotary_emb,
303
+ )
304
+ kv_cache[f"single_{block_id}"] = kv_cache_
305
+ # # Remove text tokens from concatenated stream
306
+ # img = img[:, num_txt_tokens:, ...]
307
+
308
+ # # 6. Output layers
309
+ # img = self.norm_out(img, temb)
310
+ # output = self.proj_out(img)
311
+
312
+ return {"kv_cache": kv_cache}
313
+
314
+
315
+ class TrainDataProcessor:
316
+ def __init__(self):
317
+ from diffsynth.core import UnifiedDataset
318
+ self.image_oparator = UnifiedDataset.default_image_operator(
319
+ base_path="", # If your dataset contains relative paths, please specify the root path here.
320
+ max_pixels=1024*1024,
321
+ height_division_factor=16,
322
+ width_division_factor=16,
323
+ )
324
+
325
+ def __call__(self, image, prompt, **kwargs):
326
+ return {
327
+ "image": self.image_oparator(image),
328
+ "prompt": prompt,
329
+ }
330
+
331
+ TEMPLATE_MODEL = Flux2DiTVariantModel
332
+ TEMPLATE_MODEL_PATH = "model.safetensors"
333
+ TEMPLATE_DATA_PROCESSOR = TrainDataProcessor
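
As a quick orientation to the shapes involved, a minimal smoke test of `Flux2DiTVariantModel` might look like the following (a sketch, assuming DiffSynth-Studio is installed; the layer counts are shrunk and the tensors are random, so it only exercises the forward pass that fills the KV cache):

```python
# Illustrative only: random tensors, shrunken depth; shapes follow the
# constructor defaults (in_channels=128, joint_attention_dim=7680).
import torch

model = Flux2DiTVariantModel(num_layers=1, num_single_layers=1)
latents = [torch.randn(1, 128, 8, 8)]   # one latent feature map, B C H W
prompt_emb = torch.randn(1, 16, 7680)   # text embeddings, B S D
out = model(kv_cache_input_latents=latents, kv_cache_prompt_emb=prompt_emb)
print(sorted(out["kv_cache"]))          # ['double_0', 'single_0']
```
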
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f25e7220d421bee0bed9cae5a572fdeeff253bd3be617fffeaae39aeab4902c4
size 7751106808