kelseye committed on
Commit
9e62268
·
verified ·
1 Parent(s): 3ff3e66

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/cat_rgb_cold.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/cat_rgb_normal.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/cat_rgb_warm.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/room_rgb_cold.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/room_rgb_warm.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # Templates - Color Tone Adjustment (FLUX.2-klein-base-4B)
5
+
6
+ This model is part of the open-source Diffusion Templates series by [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). It is a color tone adjustment model that allows users to globally control the image's color tendency and color temperature atmosphere by directly inputting normalized numerical values for the `R`, `G`, and `B` channels.
7
+
8
+ ## Results
9
+
10
+ > **Prompt:** A cat is sitting on a stone.
11
+
12
+ | cold | normal | warm |
13
+ |:---:|:---:|:---:|
14
+ | ![](./assets/cat_rgb_cold.jpg) | ![](./assets/cat_rgb_normal.jpg) | ![](./assets/cat_rgb_warm.jpg) |
15
+
16
+ ---
17
+
18
+ > **Prompt:** A cinematic portrait of a beautiful woman looking out a rainy window.
19
+
20
+ | cold | normal | warm |
21
+ |:---:|:---:|:---:|
22
+ | ![](./assets/girl_rgb_cold.jpg) | ![](./assets/girl_rgb_normal.jpg) | ![](./assets/girl_rgb_warm.jpg) |
23
+
24
+ ---
25
+
26
+ > **Prompt:** A modern minimalist living room with furniture.
27
+
28
+ | cold | normal | warm |
29
+ |:---:|:---:|:---:|
30
+ | ![](./assets/room_rgb_cold.jpg) | ![](./assets/room_rgb_normal.jpg) | ![](./assets/room_rgb_warm.jpg) |
31
+
32
+ ## Inference Code
33
+
34
+ * Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
35
+
36
+ ```
37
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
38
+ cd DiffSynth-Studio
39
+ pip install -e .
40
+ ```
41
+
42
+ * Direct inference, requires 40G GPU memory
43
+
44
+ ```python
45
+ from diffsynth.diffusion.template import TemplatePipeline
46
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
47
+ import torch
48
+
49
+
50
+ pipe = Flux2ImagePipeline.from_pretrained(
51
+ torch_dtype=torch.bfloat16,
52
+ device="cuda",
53
+ model_configs=[
54
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
55
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
56
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
57
+ ],
58
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
59
+ )
60
+ template = TemplatePipeline.from_pretrained(
61
+ torch_dtype=torch.bfloat16,
62
+ device="cuda",
63
+ model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-SoftRGB")],
64
+ )
65
+ image = template(
66
+ pipe,
67
+ prompt="A cat is sitting on a stone.",
68
+ seed=0, cfg_scale=4, num_inference_steps=50,
69
+ template_inputs = [{
70
+ "R": 128/255,
71
+ "G": 128/255,
72
+ "B": 128/255
73
+ }],
74
+ )
75
+ image.save("image_rgb_normal.jpg")
76
+ image = template(
77
+ pipe,
78
+ prompt="A cat is sitting on a stone.",
79
+ seed=0, cfg_scale=4, num_inference_steps=50,
80
+ template_inputs = [{
81
+ "R": 208/255,
82
+ "G": 185/255,
83
+ "B": 138/255
84
+ }],
85
+ )
86
+ image.save("image_rgb_warm.jpg")
87
+ image = template(
88
+ pipe,
89
+ prompt="A cat is sitting on a stone.",
90
+ seed=0, cfg_scale=4, num_inference_steps=50,
91
+ template_inputs = [{
92
+ "R": 94/255,
93
+ "G": 163/255,
94
+ "B": 174/255
95
+ }],
96
+ )
97
+ image.save("image_rgb_cold.jpg")
98
+ ```
99
+
100
+ * Enable lazy loading and memory management, requires 24G GPU memory
101
+
102
+ ```python
103
+ from diffsynth.diffusion.template import TemplatePipeline
104
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
105
+ import torch
106
+
107
+
108
+ vram_config = {
109
+ "offload_dtype": "disk",
110
+ "offload_device": "disk",
111
+ "onload_dtype": torch.float8_e4m3fn,
112
+ "onload_device": "cpu",
113
+ "preparing_dtype": torch.float8_e4m3fn,
114
+ "preparing_device": "cuda",
115
+ "computation_dtype": torch.bfloat16,
116
+ "computation_device": "cuda",
117
+ }
118
+ pipe = Flux2ImagePipeline.from_pretrained(
119
+ torch_dtype=torch.bfloat16,
120
+ device="cuda",
121
+ model_configs=[
122
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
123
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
124
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
125
+ ],
126
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
127
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
128
+ )
129
+ template = TemplatePipeline.from_pretrained(
130
+ torch_dtype=torch.bfloat16,
131
+ device="cuda",
132
+ model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-SoftRGB")],
133
+ lazy_loading=True,
134
+ )
135
+ image = template(
136
+ pipe,
137
+ prompt="A cat is sitting on a stone.",
138
+ seed=0, cfg_scale=4, num_inference_steps=50,
139
+ template_inputs = [{
140
+ "R": 128/255,
141
+ "G": 128/255,
142
+ "B": 128/255
143
+ }],
144
+ )
145
+ image.save("image_rgb_normal.jpg")
146
+ image = template(
147
+ pipe,
148
+ prompt="A cat is sitting on a stone.",
149
+ seed=0, cfg_scale=4, num_inference_steps=50,
150
+ template_inputs = [{
151
+ "R": 208/255,
152
+ "G": 185/255,
153
+ "B": 138/255
154
+ }],
155
+ )
156
+ image.save("image_rgb_warm.jpg")
157
+ image = template(
158
+ pipe,
159
+ prompt="A cat is sitting on a stone.",
160
+ seed=0, cfg_scale=4, num_inference_steps=50,
161
+ template_inputs = [{
162
+ "R": 94/255,
163
+ "G": 163/255,
164
+ "B": 174/255
165
+ }],
166
+ )
167
+ image.save("image_rgb_cold.jpg")
168
+ ```
169
+
170
+ ## Training Code
171
+
172
+ After installing DiffSynth-Studio, use the following script to start training. For more information, please refer to the [DiffSynth-Studio Documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).
173
+
174
+ ```shell
175
+ modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-SoftRGB/*" --local_dir ./data/diffsynth_example_dataset
176
+
177
+ accelerate launch examples/flux2/model_training/train.py \
178
+ --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-SoftRGB \
179
+ --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-SoftRGB/metadata.jsonl \
180
+ --extra_inputs "template_inputs" \
181
+ --max_pixels 1048576 \
182
+ --dataset_repeat 50 \
183
+ --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
184
+ --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-SoftRGB:" \
185
+ --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
186
+ --learning_rate 1e-4 \
187
+ --num_epochs 2 \
188
+ --remove_prefix_in_ckpt "pipe.template_model." \
189
+ --output_path "./models/train/Template-KleinBase4B-SoftRGB_full" \
190
+ --trainable_models "template_model" \
191
+ --use_gradient_checkpointing \
192
+ --find_unused_parameters
193
+ ```
README_from_modelscope.md ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ frameworks:
3
+ - Pytorch
4
+ license: Apache License 2.0
5
+ tags: []
6
+ tasks:
7
+ - text-to-image-synthesis
8
+ ---
9
+
10
+ # Templates-色调调节(FLUX.2-klein-base-4B)
11
+
12
+ 本模型是 [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 开源的 Diffusion Templates 系列模型之一。该模型为色调调节模型,允许用户通过直接输入 `R`、`G`、`B` 三个通道的归一化数值,全局调控画面的色彩倾向与色温氛围。
13
+
14
+ ## 效果展示
15
+
16
+ > **Prompt:** A cat is sitting on a stone.
17
+
18
+ | cold | normal | warm |
19
+ |:---:|:---:|:---:|
20
+ | ![](./assets/cat_rgb_cold.jpg) | ![](./assets/cat_rgb_normal.jpg) | ![](./assets/cat_rgb_warm.jpg) |
21
+
22
+ ---
23
+
24
+ > **Prompt:** A cinematic portrait of a beautiful woman looking out a rainy window.
25
+
26
+ | cold | normal | warm |
27
+ |:---:|:---:|:---:|
28
+ | ![](./assets/girl_rgb_cold.jpg) | ![](./assets/girl_rgb_normal.jpg) | ![](./assets/girl_rgb_warm.jpg) |
29
+
30
+ ---
31
+
32
+ > **Prompt:** A modern minimalist living room with furniture.
33
+
34
+ | cold | normal | warm |
35
+ |:---:|:---:|:---:|
36
+ | ![](./assets/room_rgb_cold.jpg) | ![](./assets/room_rgb_normal.jpg) | ![](./assets/room_rgb_warm.jpg) |
37
+
38
+ ## 推理代码
39
+
40
+ * 安装 [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
41
+
42
+ ```
43
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
44
+ cd DiffSynth-Studio
45
+ pip install -e .
46
+ ```
47
+
48
+ * 直接推理,需 40G 显存
49
+
50
+ ```python
51
+ from diffsynth.diffusion.template import TemplatePipeline
52
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
53
+ import torch
54
+
55
+ pipe = Flux2ImagePipeline.from_pretrained(
56
+ torch_dtype=torch.bfloat16,
57
+ device="cuda",
58
+ model_configs=[
59
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
60
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
61
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
62
+ ],
63
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
64
+ )
65
+ template = TemplatePipeline.from_pretrained(
66
+ torch_dtype=torch.bfloat16,
67
+ device="cuda",
68
+ model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-SoftRGB")],
69
+ )
70
+ image = template(
71
+ pipe,
72
+ prompt="A cat is sitting on a stone.",
73
+ seed=0, cfg_scale=4, num_inference_steps=50,
74
+ template_inputs = [{
75
+ "R": 128/255,
76
+ "G": 128/255,
77
+ "B": 128/255
78
+ }],
79
+ )
80
+ image.save("image_rgb_normal.jpg")
81
+ image = template(
82
+ pipe,
83
+ prompt="A cat is sitting on a stone.",
84
+ seed=0, cfg_scale=4, num_inference_steps=50,
85
+ template_inputs = [{
86
+ "R": 208/255,
87
+ "G": 185/255,
88
+ "B": 138/255
89
+ }],
90
+ )
91
+ image.save("image_rgb_warm.jpg")
92
+ image = template(
93
+ pipe,
94
+ prompt="A cat is sitting on a stone.",
95
+ seed=0, cfg_scale=4, num_inference_steps=50,
96
+ template_inputs = [{
97
+ "R": 94/255,
98
+ "G": 163/255,
99
+ "B": 174/255
100
+ }],
101
+ )
102
+ image.save("image_rgb_cold.jpg")
103
+ ```
104
+
105
+ * 开启惰性加载和显存管理,需 24G 显存
106
+
107
+ ```python
108
+ from diffsynth.diffusion.template import TemplatePipeline
109
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
110
+ import torch
111
+
112
+ vram_config = {
113
+ "offload_dtype": "disk",
114
+ "offload_device": "disk",
115
+ "onload_dtype": torch.float8_e4m3fn,
116
+ "onload_device": "cpu",
117
+ "preparing_dtype": torch.float8_e4m3fn,
118
+ "preparing_device": "cuda",
119
+ "computation_dtype": torch.bfloat16,
120
+ "computation_device": "cuda",
121
+ }
122
+ pipe = Flux2ImagePipeline.from_pretrained(
123
+ torch_dtype=torch.bfloat16,
124
+ device="cuda",
125
+ model_configs=[
126
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
127
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
128
+ ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
129
+ ],
130
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
131
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
132
+ )
133
+ template = TemplatePipeline.from_pretrained(
134
+ torch_dtype=torch.bfloat16,
135
+ device="cuda",
136
+ model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-SoftRGB")],
137
+ lazy_loading=True,
138
+ )
139
+ image = template(
140
+ pipe,
141
+ prompt="A cat is sitting on a stone.",
142
+ seed=0, cfg_scale=4, num_inference_steps=50,
143
+ template_inputs = [{
144
+ "R": 128/255,
145
+ "G": 128/255,
146
+ "B": 128/255
147
+ }],
148
+ )
149
+ image.save("image_rgb_normal.jpg")
150
+ image = template(
151
+ pipe,
152
+ prompt="A cat is sitting on a stone.",
153
+ seed=0, cfg_scale=4, num_inference_steps=50,
154
+ template_inputs = [{
155
+ "R": 208/255,
156
+ "G": 185/255,
157
+ "B": 138/255
158
+ }],
159
+ )
160
+ image.save("image_rgb_warm.jpg")
161
+ image = template(
162
+ pipe,
163
+ prompt="A cat is sitting on a stone.",
164
+ seed=0, cfg_scale=4, num_inference_steps=50,
165
+ template_inputs = [{
166
+ "R": 94/255,
167
+ "G": 163/255,
168
+ "B": 174/255
169
+ }],
170
+ )
171
+ image.save("image_rgb_cold.jpg")
172
+ ```
173
+
174
+ ## 训练代码
175
+
176
+ 安装 DiffSynth-Studio 后,使用以下脚本可开启训练,更多信息请参考 [DiffSynth-Studio 文档](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/)。
177
+
178
+ ```shell
179
+ modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-SoftRGB/*" --local_dir ./data/diffsynth_example_dataset
180
+
181
+ accelerate launch examples/flux2/model_training/train.py \
182
+ --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-SoftRGB \
183
+ --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-SoftRGB/metadata.jsonl \
184
+ --extra_inputs "template_inputs" \
185
+ --max_pixels 1048576 \
186
+ --dataset_repeat 50 \
187
+ --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
188
+ --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-SoftRGB:" \
189
+ --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
190
+ --learning_rate 1e-4 \
191
+ --num_epochs 2 \
192
+ --remove_prefix_in_ckpt "pipe.template_model." \
193
+ --output_path "./models/train/Template-KleinBase4B-SoftRGB_full" \
194
+ --trainable_models "template_model" \
195
+ --use_gradient_checkpointing \
196
+ --find_unused_parameters
197
+ ```
assets/cat_rgb_cold.jpg ADDED

Git LFS Details

  • SHA256: 7ab91bc1892c1228e2c220128cb0245ef6e89bdb1edddfd1735b2722c86a6e74
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
assets/cat_rgb_normal.jpg ADDED

Git LFS Details

  • SHA256: 629aceaf20b898f3e4379d9419912b7d1f551740733ab146656723502a2f21a1
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
assets/cat_rgb_warm.jpg ADDED

Git LFS Details

  • SHA256: aa63c1645f0428b6df1ed4ed0a76c46a5ad19f5ebd014b5aa553424af92ac56c
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
assets/girl_rgb_cold.jpg ADDED
assets/girl_rgb_normal.jpg ADDED
assets/girl_rgb_warm.jpg ADDED
assets/room_rgb_cold.jpg ADDED

Git LFS Details

  • SHA256: 7145c51ecb46a224e4e2fdaeede63f70ea3c76cfed7b9a0b41be78b44a091dba
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
assets/room_rgb_normal.jpg ADDED
assets/room_rgb_warm.jpg ADDED

Git LFS Details

  • SHA256: c9a54d37bea82f4bf4b4ccf624201326e82be2d8b92b901bed44ce5d1083ecb3
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
configuration.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"framework":"Pytorch","task":"text-to-image-synthesis"}
model.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, math
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+
6
class MultiValueEncoder(torch.nn.Module):
    """Encode a small vector of scalar control values into a sequence of embeddings.

    Each scalar is expanded into a sinusoidal (timestep-style) embedding of
    width ``dim_in``; the concatenated embeddings are passed through a
    two-layer MLP and broadcast over ``length`` positions, each offset by a
    learned positional embedding.
    """

    def __init__(self, dim_in=256, dim_out=4096, length=32, num_values=3):
        super().__init__()
        # Keep dim_in so forward() stays in sync with the MLP's input width.
        self.dim_in = dim_in
        self.length = length
        self.prefer_value_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in * num_values, dim_out),
            torch.nn.SiLU(),
            torch.nn.Linear(dim_out, dim_out),
        )
        self.positional_embedding = torch.nn.Parameter(torch.randn(self.length, dim_out))

    def get_timestep_embedding(self, timesteps, embedding_dim, max_period=10000):
        """Return a sinusoidal embedding of shape (len(timesteps), embedding_dim).

        Cosine terms occupy the first half of the last dimension, sine terms
        the second half.
        """
        half_dim = embedding_dim // 2
        exponent = -math.log(max_period) * torch.arange(0, half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
        emb = timesteps[:, None].float() * torch.exp(exponent)[None, :]
        emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
        return emb

    def forward(self, value, dtype):
        """Map ``value`` (1-D tensor of ``num_values`` scalars, nominally in [0, 1]) to (length, dim_out) embeddings."""
        # Scale values to ~[0, 1000], the range timestep embeddings are tuned for.
        # Fix: use self.dim_in (previously hard-coded 256) so that a
        # non-default dim_in no longer breaks the MLP's expected input size.
        emb = self.get_timestep_embedding(value * 1000, self.dim_in).to(dtype)
        emb = emb.view(1, -1)  # concatenate the per-value embeddings into one row
        emb = self.prefer_value_embedder(emb).squeeze(0)
        base_embeddings = emb.expand(self.length, -1)
        positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
        learned_embeddings = base_embeddings + positional_embedding
        return learned_embeddings
28
+
29
+
30
class ValueFormatModel(torch.nn.Module):
    """Template model that turns normalized R/G/B control values into per-block KV caches.

    For every transformer block name it owns a pair of ``MultiValueEncoder``
    heads; their outputs are reshaped into extra attention key/value tokens
    keyed by block name.
    """

    def __init__(self, num_double_blocks=5, num_single_blocks=20, dim=3072, num_heads=24, length=512):
        super().__init__()
        # One named entry per attention block the host pipeline exposes.
        self.block_names = [f"double_{i}" for i in range(num_double_blocks)] + [f"single_{i}" for i in range(num_single_blocks)]
        self.proj_k = torch.nn.ModuleDict({block_name: MultiValueEncoder(dim_out=dim, length=length) for block_name in self.block_names})
        self.proj_v = torch.nn.ModuleDict({block_name: MultiValueEncoder(dim_out=dim, length=length) for block_name in self.block_names})
        self.num_heads = num_heads
        self.length = length

    @torch.no_grad()
    def process_inputs(self, pipe, R, G, B, **kwargs):
        """Pack normalized R/G/B scalars into a conditioning tensor on the pipeline's dtype/device."""
        # torch.tensor (not the legacy torch.Tensor constructor) takes
        # dtype/device directly and avoids an intermediate float32 copy.
        return {"value": torch.tensor([R, G, B], dtype=pipe.torch_dtype, device=pipe.device)}

    def forward(self, value, **kwargs):
        """Return {"kv_cache": {block_name: (k, v)}}; k and v have shape (1, length, num_heads, dim // num_heads)."""
        kv_cache = {}
        for block_name in self.block_names:
            k = self.proj_k[block_name](value, value.dtype)
            k = k.view(1, self.length, self.num_heads, -1)
            v = self.proj_v[block_name](value, value.dtype)
            v = v.view(1, self.length, self.num_heads, -1)
            kv_cache[block_name] = (k, v)
        return {"kv_cache": kv_cache}
52
+
53
+
54
class DataAnnotator:
    """Annotate a training image with its normalized per-channel mean color."""

    def __call__(self, image, **kwargs):
        # Load as RGB and work in float32 so channel means are not truncated.
        rgb = np.asarray(Image.open(image).convert("RGB"), dtype=np.float32)
        # Per-channel means over all pixels, normalized to [0, 1].
        means = rgb.reshape(-1, 3).mean(axis=0) / 255
        return {"R": means[0], "G": means[1], "B": means[2]}
60
+
61
+
62
# Module-level entry points — presumably read by DiffSynth-Studio's
# TemplatePipeline loader (TODO confirm against DiffSynth-Studio source).
TEMPLATE_MODEL = ValueFormatModel          # template network class to instantiate
TEMPLATE_MODEL_PATH = "model.safetensors"  # weight file within this repository
TEMPLATE_DATA_PROCESSOR = DataAnnotator    # annotator producing R/G/B training targets
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49fe9aa7fc27f1ac3ebe6d99013d251ef8312cfa218e813bcf1cc0cbdaeffbca
3
+ size 1337578464