gvecchio commited on
Commit
5b8131f
·
1 Parent(s): 73c6e44

Add model

Browse files
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,98 @@
1
  ---
2
  license: mit
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ library_name: diffusers
4
+ pipeline_tag: text-to-image
5
+ language:
6
+ - en
7
+ tags:
8
+ - diffusers
9
+ - matfuse
10
+ - pbr
11
+ - material-generation
12
+ - svbrdf
13
+ - text-to-image
14
  ---
15
+
16
+
17
+ # MatFuse — Controllable Material Generation with Diffusion Models
18
+
19
+ MatFuse generates tileable PBR material maps (diffuse, normal, roughness,
20
+ specular) from text, reference images, sketches, and/or color palettes.
21
+
22
+ > **Paper:** [MatFuse: Controllable Material Generation with Diffusion Models](https://arxiv.org/abs/2308.11408) — CVPR 2024
23
+ > **Project page:** <https://gvecchio.com/matfuse/>
24
+
25
+ ## Quick Start
26
+
27
+ ```python
28
+ import torch
29
+ from diffusers import DiffusionPipeline
30
+
31
+ pipe = DiffusionPipeline.from_pretrained(
32
+ "gvecchio/MatFuse",
33
+ trust_remote_code=True,
34
+ torch_dtype=torch.float16,
35
+ )
36
+ pipe = pipe.to("cuda")
37
+
38
+ result = pipe(
39
+ text="red brick wall",
40
+ num_inference_steps=50,
41
+ guidance_scale=4.0,
42
+ generator=torch.Generator("cuda").manual_seed(42),
43
+ )
44
+
45
+ result["diffuse"][0].save("diffuse.png")
46
+ result["normal"][0].save("normal.png")
47
+ result["roughness"][0].save("roughness.png")
48
+ result["specular"][0].save("specular.png")
49
+ ```
50
+
51
+ ## Conditioning Inputs
52
+
53
+ All conditions are **optional** and freely composable:
54
+
55
+ | Input | Type | Description |
56
+ |-------|------|-------------|
57
+ | `text` | `str` | Text description of the material |
58
+ | `image` | `PIL.Image` | Reference image for style/appearance |
59
+ | `sketch` | `PIL.Image` (grayscale) | Binary edge map for structure |
60
+ | `palette` | `list[tuple]` | Up to 5 RGB colour tuples (0–255) |
61
+
62
+ ```python
63
+ from PIL import Image
64
+
65
+ result = pipe(
66
+ image=Image.open("reference.png"),
67
+ text="rough stone texture",
68
+ palette=[(120, 80, 60), (90, 60, 40), (150, 110, 80), (70, 50, 30), (180, 140, 100)],
69
+ num_inference_steps=50,
70
+ guidance_scale=4.0,
71
+ )
72
+ ```
73
+
74
+ ## Architecture
75
+
76
+ | Component | Class | Key parameters |
77
+ |-----------|-------|----------------|
78
+ | **UNet** | `UNet2DConditionModel` | in=16, out=12, blocks=[256,512,1024], cross_attn=512 |
79
+ | **VAE** | `MatFuseVQModel` (custom) | 4 encoders + 4 VQ codebooks (4096×3), shared decoder, f=8 |
80
+ | **Scheduler** | `DDIMScheduler` | β 0.0015–0.0195, scaled_linear, ε-prediction |
81
+ | **Conditioning** | `MultiConditionEncoder` (custom) | CLIP ViT-B/16 · sentence-transformers · palette MLP · sketch CNN |
82
+
83
+ ## 📜 Citation
84
+
85
+ ```bibtex
86
+ @inproceedings{vecchio2024matfuse,
87
+ author = {Vecchio, Giuseppe and Sortino, Renato and Palazzo, Simone and Spampinato, Concetto},
88
+ title = {MatFuse: Controllable Material Generation with Diffusion Models},
89
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
90
+ month = {June},
91
+ year = {2024},
92
+ pages = {4429-4438}
93
+ }
94
+ ```
95
+
96
+ ## License
97
+
98
+ This project is licensed under the MIT License.
checkpoints/matfuse-full.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0f9ce7058f0a122dcbca8d4870d98df656c8114f20914e7d126406222b5df98
3
+ size 4641175367
checkpoints/matfuse-pruned.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27597775fd7f25a98523e7171bf31305bc3f274f7300fbd406654952e608a036
3
+ size 3060674264
condition_encoder/condition_encoders.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MatFuse Condition Encoders for diffusers.
3
+
4
+ These encoders handle the multi-modal conditioning:
5
+ - Image embedding (CLIP image encoder)
6
+ - Text embedding (CLIP text encoder)
7
+ - Sketch encoder (CNN)
8
+ - Palette encoder (MLP)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from typing import Optional, Dict, Union, List
15
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+
18
+
19
+ class SketchEncoder(ModelMixin, ConfigMixin):
20
+ """
21
+ CNN encoder for binary sketch/edge maps.
22
+
23
+ Takes a single-channel binary image and encodes it to a spatial feature map
24
+ that will be concatenated with the latent for hybrid conditioning.
25
+ """
26
+
27
+ @register_to_config
28
+ def __init__(
29
+ self,
30
+ in_channels: int = 1,
31
+ out_channels: int = 4,
32
+ ):
33
+ super().__init__()
34
+
35
+ self.net = nn.Sequential(
36
+ nn.Conv2d(in_channels, 32, 7, 1, 1),
37
+ nn.BatchNorm2d(32),
38
+ nn.GELU(),
39
+ nn.Conv2d(32, 64, 3, 2, 1),
40
+ nn.BatchNorm2d(64),
41
+ nn.GELU(),
42
+ nn.Conv2d(64, 128, 3, 2, 1),
43
+ nn.BatchNorm2d(128),
44
+ nn.GELU(),
45
+ nn.Conv2d(128, 256, 3, 2, 1),
46
+ nn.BatchNorm2d(256),
47
+ nn.GELU(),
48
+ nn.Conv2d(256, out_channels, 1, 1, 0),
49
+ nn.BatchNorm2d(out_channels),
50
+ nn.GELU(),
51
+ )
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ """
55
+ Encode sketch input.
56
+
57
+ Args:
58
+ x: Input tensor of shape (B, 1, H, W) with values in [0, 1].
59
+
60
+ Returns:
61
+ Encoded features of shape (B, out_channels, H/8, W/8).
62
+ """
63
+ return self.net(x)
64
+
65
+
66
+ class PaletteEncoder(ModelMixin, ConfigMixin):
67
+ """
68
+ MLP encoder for color palettes.
69
+
70
+ Takes a color palette (N colors, RGB) and encodes it to a single embedding
71
+ for cross-attention conditioning.
72
+ """
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ in_channels: int = 3,
78
+ hidden_channels: int = 64,
79
+ out_channels: int = 512,
80
+ n_colors: int = 5,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.net = nn.Sequential(
85
+ nn.Linear(in_channels, hidden_channels),
86
+ nn.GELU(),
87
+ nn.Flatten(),
88
+ nn.Linear(hidden_channels * n_colors, out_channels),
89
+ nn.GELU(),
90
+ )
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ """
94
+ Encode color palette.
95
+
96
+ Args:
97
+ x: Input tensor of shape (B, n_colors, 3) with RGB values in [0, 1].
98
+
99
+ Returns:
100
+ Encoded embedding of shape (B, out_channels).
101
+ """
102
+ return self.net(x)
103
+
104
+
105
+ class CLIPImageEncoder(ModelMixin, ConfigMixin):
106
+ """
107
+ Wrapper for CLIP image encoder using the OpenAI CLIP library.
108
+
109
+ Generates image embeddings for cross-attention conditioning.
110
+ """
111
+
112
+ @register_to_config
113
+ def __init__(
114
+ self,
115
+ model_name: str = "ViT-B/16",
116
+ normalize: bool = True,
117
+ ):
118
+ super().__init__()
119
+
120
+ self.model_name = model_name
121
+ self.normalize = normalize
122
+ self.model = None # Lazy loading
123
+
124
+ # Register normalization buffers
125
+ self.register_buffer(
126
+ "mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
127
+ )
128
+ self.register_buffer(
129
+ "std", torch.tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
130
+ )
131
+
132
+ def _load_model(self):
133
+ """Lazy load the CLIP model."""
134
+ if self.model is None:
135
+ import clip
136
+
137
+ self.model, _ = clip.load(self.model_name, device="cpu", jit=False)
138
+ self.model = self.model.visual
139
+
140
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
141
+ """Preprocess images for CLIP."""
142
+ # Resize to 224x224
143
+ x = F.interpolate(
144
+ x, size=(224, 224), mode="bicubic", align_corners=True, antialias=True
145
+ )
146
+ # Normalize from [-1, 1] to [0, 1]
147
+ x = (x + 1.0) / 2.0
148
+ # Normalize according to CLIP - move mean/std to device if needed
149
+ mean = self.mean.to(x.device).view(1, 3, 1, 1)
150
+ std = self.std.to(x.device).view(1, 3, 1, 1)
151
+ x = (x - mean) / std
152
+ return x
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ """
156
+ Encode image using CLIP.
157
+
158
+ Args:
159
+ x: Input tensor of shape (B, 3, H, W) with values in [-1, 1].
160
+
161
+ Returns:
162
+ Image embedding of shape (B, 1, 512).
163
+ """
164
+ self._load_model()
165
+
166
+ # Move model to same device as input
167
+ device = x.device
168
+ self.model = self.model.to(device)
169
+
170
+ x = self.preprocess(x)
171
+ z = self.model(x).float().unsqueeze(1) # (B, 1, 512)
172
+
173
+ if self.normalize:
174
+ z = z / torch.linalg.norm(z, dim=2, keepdim=True)
175
+
176
+ return z
177
+
178
+
179
+ class CLIPTextEncoder(ModelMixin, ConfigMixin):
180
+ """
181
+ Wrapper for CLIP sentence encoder using sentence-transformers.
182
+
183
+ Generates text embeddings for cross-attention conditioning.
184
+ """
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ model_name: str = "sentence-transformers/clip-ViT-B-16",
190
+ ):
191
+ super().__init__()
192
+
193
+ self.model_name = model_name
194
+ self.model = None # Lazy loading
195
+
196
+ def _load_model(self):
197
+ """Lazy load the sentence transformer model."""
198
+ if self.model is None:
199
+ from sentence_transformers import SentenceTransformer
200
+
201
+ self.model = SentenceTransformer(self.model_name)
202
+ self.model.eval()
203
+
204
+ def forward(self, text: Union[str, List[str]]) -> torch.Tensor:
205
+ """
206
+ Encode text using CLIP sentence transformer.
207
+
208
+ Args:
209
+ text: Input text or list of texts.
210
+
211
+ Returns:
212
+ Text embedding of shape (B, 512).
213
+ """
214
+ self._load_model()
215
+
216
+ if isinstance(text, str):
217
+ text = [text]
218
+
219
+ embeddings = self.model.encode(text, convert_to_tensor=True)
220
+ return embeddings
221
+
222
+
223
+ class MultiConditionEncoder(ModelMixin, ConfigMixin):
224
+ """
225
+ Multi-condition encoder that combines all conditioning modalities.
226
+
227
+ This encoder takes multiple condition inputs and produces:
228
+ - c_crossattn: Features for cross-attention (image, text, palette embeddings)
229
+ - c_concat: Features for concatenation (sketch encoding)
230
+ """
231
+
232
+ @register_to_config
233
+ def __init__(
234
+ self,
235
+ sketch_in_channels: int = 1,
236
+ sketch_out_channels: int = 4,
237
+ palette_in_channels: int = 3,
238
+ palette_hidden_channels: int = 64,
239
+ palette_out_channels: int = 512,
240
+ n_colors: int = 5,
241
+ clip_image_model: str = "ViT-B/16",
242
+ clip_text_model: str = "sentence-transformers/clip-ViT-B-16",
243
+ ):
244
+ super().__init__()
245
+
246
+ self.sketch_encoder = SketchEncoder(
247
+ in_channels=sketch_in_channels,
248
+ out_channels=sketch_out_channels,
249
+ )
250
+
251
+ self.palette_encoder = PaletteEncoder(
252
+ in_channels=palette_in_channels,
253
+ hidden_channels=palette_hidden_channels,
254
+ out_channels=palette_out_channels,
255
+ n_colors=n_colors,
256
+ )
257
+
258
+ # CLIP encoders are lazy-loaded
259
+ self.clip_image_encoder = None
260
+ self.clip_text_encoder = None
261
+ self._clip_image_model = clip_image_model
262
+ self._clip_text_model = clip_text_model
263
+
264
+ def _load_clip_encoders(self):
265
+ """Lazy load CLIP encoders."""
266
+ if self.clip_image_encoder is None:
267
+ self.clip_image_encoder = CLIPImageEncoder(
268
+ model_name=self._clip_image_model
269
+ )
270
+ if self.clip_text_encoder is None:
271
+ self.clip_text_encoder = CLIPTextEncoder(model_name=self._clip_text_model)
272
+
273
+ def encode_image(self, image: torch.Tensor) -> torch.Tensor:
274
+ """Encode image using CLIP."""
275
+ self._load_clip_encoders()
276
+ return self.clip_image_encoder(image)
277
+
278
+ def encode_text(self, text: Union[str, List[str]]) -> torch.Tensor:
279
+ """Encode text using CLIP."""
280
+ self._load_clip_encoders()
281
+ return self.clip_text_encoder(text)
282
+
283
+ def encode_sketch(self, sketch: torch.Tensor) -> torch.Tensor:
284
+ """Encode sketch/edge map."""
285
+ return self.sketch_encoder(sketch)
286
+
287
+ def encode_palette(self, palette: torch.Tensor) -> torch.Tensor:
288
+ """Encode color palette."""
289
+ return self.palette_encoder(palette)
290
+
291
+ def get_unconditional_conditioning(
292
+ self,
293
+ batch_size: int = 1,
294
+ image_size: int = 256,
295
+ device: Optional[torch.device] = None,
296
+ ) -> Dict[str, torch.Tensor]:
297
+ """
298
+ Get unconditional conditioning for classifier-free guidance.
299
+
300
+ IMPORTANT: The original model was trained to drop conditions by replacing them
301
+ with encoded placeholders (zero/gray image through CLIP, empty string through
302
+ sentence-transformers, zero palette through PaletteEncoder, zero sketch through
303
+ SketchEncoder) — NOT with zero tensors. This method produces the correct
304
+ unconditional embeddings.
305
+
306
+ Args:
307
+ batch_size: Batch size.
308
+ image_size: Image resolution (for sketch spatial dims).
309
+ device: Device to place tensors on.
310
+
311
+ Returns:
312
+ Dictionary with c_crossattn and c_concat for unconditional guidance.
313
+ """
314
+ return self.forward(
315
+ image_embed=None,
316
+ text=None,
317
+ sketch=None,
318
+ palette=None,
319
+ batch_size=batch_size,
320
+ image_size=image_size,
321
+ device=device,
322
+ )
323
+
324
+ def forward(
325
+ self,
326
+ image_embed: Optional[torch.Tensor] = None,
327
+ text: Optional[Union[str, List[str]]] = None,
328
+ sketch: Optional[torch.Tensor] = None,
329
+ palette: Optional[torch.Tensor] = None,
330
+ batch_size: int = 1,
331
+ image_size: int = 256,
332
+ device: Optional[torch.device] = None,
333
+ ) -> Dict[str, torch.Tensor]:
334
+ """
335
+ Encode all conditions.
336
+
337
+ When a condition is not provided, the model encodes a placeholder input
338
+ through the actual encoder (matching training behavior) rather than using
339
+ zero tensors. This is critical because the model was trained with:
340
+ - Image drop → CLIP encoding of a gray/zero image (0.0 in [-1,1])
341
+ - Text drop → sentence-transformer encoding of ""
342
+ - Palette drop → PaletteEncoder(zeros)
343
+ - Sketch drop → SketchEncoder(zeros)
344
+
345
+ Args:
346
+ image_embed: Reference image of shape (B, 3, H, W) in [-1, 1].
347
+ text: Text description(s).
348
+ sketch: Binary sketch of shape (B, 1, H, W) in [0, 1].
349
+ palette: Color palette of shape (B, n_colors, 3) in [0, 1].
350
+ batch_size: Batch size (used when no inputs are provided).
351
+ image_size: Image resolution (used to create placeholder sketch).
352
+ device: Device to place tensors on.
353
+
354
+ Returns:
355
+ Dictionary with:
356
+ - c_crossattn: Cross-attention context of shape (B, 3, 512) - always 3 tokens.
357
+ - c_concat: Concatenation features of shape (B, 4, H/8, W/8).
358
+ """
359
+ self._load_clip_encoders()
360
+
361
+ # Determine batch size and device from any available input
362
+ if image_embed is not None:
363
+ batch_size = image_embed.shape[0]
364
+ device = device or image_embed.device
365
+ image_size = image_embed.shape[-1]
366
+ elif sketch is not None:
367
+ batch_size = sketch.shape[0]
368
+ device = device or sketch.device
369
+ image_size = sketch.shape[-1]
370
+ elif palette is not None:
371
+ batch_size = palette.shape[0]
372
+ device = device or palette.device
373
+
374
+ device = device or torch.device("cpu")
375
+ # Infer dtype from model weights for placeholder tensors (e.g. float16)
376
+ dtype = next(self.palette_encoder.parameters()).dtype
377
+
378
+ # --- Image embedding (token 0) ---
379
+ # When not provided, encode a zero (gray) image through CLIP, matching training ucg_training val=0.0
380
+ if image_embed is not None:
381
+ img_emb = self.clip_image_encoder(image_embed) # (B, 1, 512)
382
+ else:
383
+ placeholder_img = torch.zeros(
384
+ batch_size, 3, image_size, image_size, device=device, dtype=dtype
385
+ )
386
+ img_emb = self.clip_image_encoder(placeholder_img) # (B, 1, 512)
387
+
388
+ # --- Text embedding (token 1) ---
389
+ # When not provided, encode empty string through sentence-transformers, matching training ucg_training val=""
390
+ if text is not None:
391
+ text_emb = self.clip_text_encoder(text) # (B, 512)
392
+ if device is not None:
393
+ text_emb = text_emb.to(device)
394
+ text_emb = text_emb.unsqueeze(1) # (B, 1, 512)
395
+ else:
396
+ text_emb = self.clip_text_encoder([""] * batch_size) # (B, 512)
397
+ text_emb = text_emb.to(device).unsqueeze(1) # (B, 1, 512)
398
+
399
+ # --- Palette embedding (token 2) ---
400
+ # When not provided, encode zero palette through PaletteEncoder, matching training ucg_training val=0.0
401
+ if palette is not None:
402
+ palette_emb = self.palette_encoder(palette) # (B, 512)
403
+ palette_emb = palette_emb.unsqueeze(1) # (B, 1, 512)
404
+ else:
405
+ n_colors = self.config.get("n_colors", 5)
406
+ placeholder_palette = torch.zeros(batch_size, n_colors, 3, device=device, dtype=dtype)
407
+ palette_emb = self.palette_encoder(placeholder_palette) # (B, 512)
408
+ palette_emb = palette_emb.unsqueeze(1) # (B, 1, 512)
409
+
410
+ # Combine cross-attention embeddings - always (B, 3, 512)
411
+ c_crossattn = torch.cat([img_emb, text_emb, palette_emb], dim=1)
412
+
413
+ # --- Sketch encoding for concatenation ---
414
+ # When not provided, encode zero sketch through SketchEncoder, matching training ucg_training val=0.0
415
+ if sketch is not None:
416
+ c_concat = self.sketch_encoder(sketch) # (B, 4, H/8, W/8)
417
+ else:
418
+ placeholder_sketch = torch.zeros(
419
+ batch_size, 1, image_size, image_size, device=device, dtype=dtype
420
+ )
421
+ c_concat = self.sketch_encoder(placeholder_sketch) # (B, 4, H/8, W/8)
422
+
423
+ return {
424
+ "c_crossattn": c_crossattn,
425
+ "c_concat": c_concat,
426
+ }
condition_encoder/config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MultiConditionEncoder",
3
+ "_diffusers_version": "0.35.2",
4
+ "clip_image_model": "ViT-B/16",
5
+ "clip_text_model": "sentence-transformers/clip-ViT-B-16",
6
+ "n_colors": 5,
7
+ "palette_hidden_channels": 64,
8
+ "palette_in_channels": 3,
9
+ "palette_out_channels": 512,
10
+ "sketch_in_channels": 1,
11
+ "sketch_out_channels": 4
12
+ }
condition_encoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1bf932008656551c46a064549570965c19c48539c9e6d5a54c2a40a6414518cb
3
+ size 2230464
condition_encoders.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MatFuse Condition Encoders for diffusers.
3
+
4
+ These encoders handle the multi-modal conditioning:
5
+ - Image embedding (CLIP image encoder)
6
+ - Text embedding (CLIP text encoder)
7
+ - Sketch encoder (CNN)
8
+ - Palette encoder (MLP)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from typing import Optional, Dict, Union, List
15
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+
18
+
19
+ class SketchEncoder(ModelMixin, ConfigMixin):
20
+ """
21
+ CNN encoder for binary sketch/edge maps.
22
+
23
+ Takes a single-channel binary image and encodes it to a spatial feature map
24
+ that will be concatenated with the latent for hybrid conditioning.
25
+ """
26
+
27
+ @register_to_config
28
+ def __init__(
29
+ self,
30
+ in_channels: int = 1,
31
+ out_channels: int = 4,
32
+ ):
33
+ super().__init__()
34
+
35
+ self.net = nn.Sequential(
36
+ nn.Conv2d(in_channels, 32, 7, 1, 1),
37
+ nn.BatchNorm2d(32),
38
+ nn.GELU(),
39
+ nn.Conv2d(32, 64, 3, 2, 1),
40
+ nn.BatchNorm2d(64),
41
+ nn.GELU(),
42
+ nn.Conv2d(64, 128, 3, 2, 1),
43
+ nn.BatchNorm2d(128),
44
+ nn.GELU(),
45
+ nn.Conv2d(128, 256, 3, 2, 1),
46
+ nn.BatchNorm2d(256),
47
+ nn.GELU(),
48
+ nn.Conv2d(256, out_channels, 1, 1, 0),
49
+ nn.BatchNorm2d(out_channels),
50
+ nn.GELU(),
51
+ )
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ """
55
+ Encode sketch input.
56
+
57
+ Args:
58
+ x: Input tensor of shape (B, 1, H, W) with values in [0, 1].
59
+
60
+ Returns:
61
+ Encoded features of shape (B, out_channels, H/8, W/8).
62
+ """
63
+ return self.net(x)
64
+
65
+
66
+ class PaletteEncoder(ModelMixin, ConfigMixin):
67
+ """
68
+ MLP encoder for color palettes.
69
+
70
+ Takes a color palette (N colors, RGB) and encodes it to a single embedding
71
+ for cross-attention conditioning.
72
+ """
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ in_channels: int = 3,
78
+ hidden_channels: int = 64,
79
+ out_channels: int = 512,
80
+ n_colors: int = 5,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.net = nn.Sequential(
85
+ nn.Linear(in_channels, hidden_channels),
86
+ nn.GELU(),
87
+ nn.Flatten(),
88
+ nn.Linear(hidden_channels * n_colors, out_channels),
89
+ nn.GELU(),
90
+ )
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ """
94
+ Encode color palette.
95
+
96
+ Args:
97
+ x: Input tensor of shape (B, n_colors, 3) with RGB values in [0, 1].
98
+
99
+ Returns:
100
+ Encoded embedding of shape (B, out_channels).
101
+ """
102
+ return self.net(x)
103
+
104
+
105
+ class CLIPImageEncoder(ModelMixin, ConfigMixin):
106
+ """
107
+ Wrapper for CLIP image encoder using the OpenAI CLIP library.
108
+
109
+ Generates image embeddings for cross-attention conditioning.
110
+ """
111
+
112
+ @register_to_config
113
+ def __init__(
114
+ self,
115
+ model_name: str = "ViT-B/16",
116
+ normalize: bool = True,
117
+ ):
118
+ super().__init__()
119
+
120
+ self.model_name = model_name
121
+ self.normalize = normalize
122
+ self.model = None # Lazy loading
123
+
124
+ # Register normalization buffers
125
+ self.register_buffer(
126
+ "mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
127
+ )
128
+ self.register_buffer(
129
+ "std", torch.tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
130
+ )
131
+
132
+ def _load_model(self):
133
+ """Lazy load the CLIP model."""
134
+ if self.model is None:
135
+ import clip
136
+
137
+ self.model, _ = clip.load(self.model_name, device="cpu", jit=False)
138
+ self.model = self.model.visual
139
+
140
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
141
+ """Preprocess images for CLIP."""
142
+ # Resize to 224x224
143
+ x = F.interpolate(
144
+ x, size=(224, 224), mode="bicubic", align_corners=True, antialias=True
145
+ )
146
+ # Normalize from [-1, 1] to [0, 1]
147
+ x = (x + 1.0) / 2.0
148
+ # Normalize according to CLIP - move mean/std to device if needed
149
+ mean = self.mean.to(x.device).view(1, 3, 1, 1)
150
+ std = self.std.to(x.device).view(1, 3, 1, 1)
151
+ x = (x - mean) / std
152
+ return x
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ """
156
+ Encode image using CLIP.
157
+
158
+ Args:
159
+ x: Input tensor of shape (B, 3, H, W) with values in [-1, 1].
160
+
161
+ Returns:
162
+ Image embedding of shape (B, 1, 512).
163
+ """
164
+ self._load_model()
165
+
166
+ # Move model to same device as input
167
+ device = x.device
168
+ self.model = self.model.to(device)
169
+
170
+ x = self.preprocess(x)
171
+ z = self.model(x).float().unsqueeze(1) # (B, 1, 512)
172
+
173
+ if self.normalize:
174
+ z = z / torch.linalg.norm(z, dim=2, keepdim=True)
175
+
176
+ return z
177
+
178
+
179
+ class CLIPTextEncoder(ModelMixin, ConfigMixin):
180
+ """
181
+ Wrapper for CLIP sentence encoder using sentence-transformers.
182
+
183
+ Generates text embeddings for cross-attention conditioning.
184
+ """
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ model_name: str = "sentence-transformers/clip-ViT-B-16",
190
+ ):
191
+ super().__init__()
192
+
193
+ self.model_name = model_name
194
+ self.model = None # Lazy loading
195
+
196
+ def _load_model(self):
197
+ """Lazy load the sentence transformer model."""
198
+ if self.model is None:
199
+ from sentence_transformers import SentenceTransformer
200
+
201
+ self.model = SentenceTransformer(self.model_name)
202
+ self.model.eval()
203
+
204
+ def forward(self, text: Union[str, List[str]]) -> torch.Tensor:
205
+ """
206
+ Encode text using CLIP sentence transformer.
207
+
208
+ Args:
209
+ text: Input text or list of texts.
210
+
211
+ Returns:
212
+ Text embedding of shape (B, 512).
213
+ """
214
+ self._load_model()
215
+
216
+ if isinstance(text, str):
217
+ text = [text]
218
+
219
+ embeddings = self.model.encode(text, convert_to_tensor=True)
220
+ return embeddings
221
+
222
+
223
+ class MultiConditionEncoder(ModelMixin, ConfigMixin):
224
+ """
225
+ Multi-condition encoder that combines all conditioning modalities.
226
+
227
+ This encoder takes multiple condition inputs and produces:
228
+ - c_crossattn: Features for cross-attention (image, text, palette embeddings)
229
+ - c_concat: Features for concatenation (sketch encoding)
230
+ """
231
+
232
+ @register_to_config
233
+ def __init__(
234
+ self,
235
+ sketch_in_channels: int = 1,
236
+ sketch_out_channels: int = 4,
237
+ palette_in_channels: int = 3,
238
+ palette_hidden_channels: int = 64,
239
+ palette_out_channels: int = 512,
240
+ n_colors: int = 5,
241
+ clip_image_model: str = "ViT-B/16",
242
+ clip_text_model: str = "sentence-transformers/clip-ViT-B-16",
243
+ ):
244
+ super().__init__()
245
+
246
+ self.sketch_encoder = SketchEncoder(
247
+ in_channels=sketch_in_channels,
248
+ out_channels=sketch_out_channels,
249
+ )
250
+
251
+ self.palette_encoder = PaletteEncoder(
252
+ in_channels=palette_in_channels,
253
+ hidden_channels=palette_hidden_channels,
254
+ out_channels=palette_out_channels,
255
+ n_colors=n_colors,
256
+ )
257
+
258
+ # CLIP encoders are lazy-loaded
259
+ self.clip_image_encoder = None
260
+ self.clip_text_encoder = None
261
+ self._clip_image_model = clip_image_model
262
+ self._clip_text_model = clip_text_model
263
+
264
+ def _load_clip_encoders(self):
265
+ """Lazy load CLIP encoders."""
266
+ if self.clip_image_encoder is None:
267
+ self.clip_image_encoder = CLIPImageEncoder(
268
+ model_name=self._clip_image_model
269
+ )
270
+ if self.clip_text_encoder is None:
271
+ self.clip_text_encoder = CLIPTextEncoder(model_name=self._clip_text_model)
272
+
273
+ def encode_image(self, image: torch.Tensor) -> torch.Tensor:
274
+ """Encode image using CLIP."""
275
+ self._load_clip_encoders()
276
+ return self.clip_image_encoder(image)
277
+
278
+ def encode_text(self, text: Union[str, List[str]]) -> torch.Tensor:
279
+ """Encode text using CLIP."""
280
+ self._load_clip_encoders()
281
+ return self.clip_text_encoder(text)
282
+
283
+ def encode_sketch(self, sketch: torch.Tensor) -> torch.Tensor:
284
+ """Encode sketch/edge map."""
285
+ return self.sketch_encoder(sketch)
286
+
287
+ def encode_palette(self, palette: torch.Tensor) -> torch.Tensor:
288
+ """Encode color palette."""
289
+ return self.palette_encoder(palette)
290
+
291
+ def get_unconditional_conditioning(
292
+ self,
293
+ batch_size: int = 1,
294
+ image_size: int = 256,
295
+ device: Optional[torch.device] = None,
296
+ ) -> Dict[str, torch.Tensor]:
297
+ """
298
+ Get unconditional conditioning for classifier-free guidance.
299
+
300
+ IMPORTANT: The original model was trained to drop conditions by replacing them
301
+ with encoded placeholders (zero/gray image through CLIP, empty string through
302
+ sentence-transformers, zero palette through PaletteEncoder, zero sketch through
303
+ SketchEncoder) — NOT with zero tensors. This method produces the correct
304
+ unconditional embeddings.
305
+
306
+ Args:
307
+ batch_size: Batch size.
308
+ image_size: Image resolution (for sketch spatial dims).
309
+ device: Device to place tensors on.
310
+
311
+ Returns:
312
+ Dictionary with c_crossattn and c_concat for unconditional guidance.
313
+ """
314
+ return self.forward(
315
+ image_embed=None,
316
+ text=None,
317
+ sketch=None,
318
+ palette=None,
319
+ batch_size=batch_size,
320
+ image_size=image_size,
321
+ device=device,
322
+ )
323
+
324
+ def forward(
325
+ self,
326
+ image_embed: Optional[torch.Tensor] = None,
327
+ text: Optional[Union[str, List[str]]] = None,
328
+ sketch: Optional[torch.Tensor] = None,
329
+ palette: Optional[torch.Tensor] = None,
330
+ batch_size: int = 1,
331
+ image_size: int = 256,
332
+ device: Optional[torch.device] = None,
333
+ ) -> Dict[str, torch.Tensor]:
334
+ """
335
+ Encode all conditions.
336
+
337
+ When a condition is not provided, the model encodes a placeholder input
338
+ through the actual encoder (matching training behavior) rather than using
339
+ zero tensors. This is critical because the model was trained with:
340
+ - Image drop → CLIP encoding of a gray/zero image (0.0 in [-1,1])
341
+ - Text drop → sentence-transformer encoding of ""
342
+ - Palette drop → PaletteEncoder(zeros)
343
+ - Sketch drop → SketchEncoder(zeros)
344
+
345
+ Args:
346
+ image_embed: Reference image of shape (B, 3, H, W) in [-1, 1].
347
+ text: Text description(s).
348
+ sketch: Binary sketch of shape (B, 1, H, W) in [0, 1].
349
+ palette: Color palette of shape (B, n_colors, 3) in [0, 1].
350
+ batch_size: Batch size (used when no inputs are provided).
351
+ image_size: Image resolution (used to create placeholder sketch).
352
+ device: Device to place tensors on.
353
+
354
+ Returns:
355
+ Dictionary with:
356
+ - c_crossattn: Cross-attention context of shape (B, 3, 512) - always 3 tokens.
357
+ - c_concat: Concatenation features of shape (B, 4, H/8, W/8).
358
+ """
359
+ self._load_clip_encoders()
360
+
361
+ # Determine batch size and device from any available input
362
+ if image_embed is not None:
363
+ batch_size = image_embed.shape[0]
364
+ device = device or image_embed.device
365
+ image_size = image_embed.shape[-1]
366
+ elif sketch is not None:
367
+ batch_size = sketch.shape[0]
368
+ device = device or sketch.device
369
+ image_size = sketch.shape[-1]
370
+ elif palette is not None:
371
+ batch_size = palette.shape[0]
372
+ device = device or palette.device
373
+
374
+ device = device or torch.device("cpu")
375
+ # Infer dtype from model weights for placeholder tensors (e.g. float16)
376
+ dtype = next(self.palette_encoder.parameters()).dtype
377
+
378
+ # --- Image embedding (token 0) ---
379
+ # When not provided, encode a zero (gray) image through CLIP, matching training ucg_training val=0.0
380
+ if image_embed is not None:
381
+ img_emb = self.clip_image_encoder(image_embed) # (B, 1, 512)
382
+ else:
383
+ placeholder_img = torch.zeros(
384
+ batch_size, 3, image_size, image_size, device=device, dtype=dtype
385
+ )
386
+ img_emb = self.clip_image_encoder(placeholder_img) # (B, 1, 512)
387
+
388
+ # --- Text embedding (token 1) ---
389
+ # When not provided, encode empty string through sentence-transformers, matching training ucg_training val=""
390
+ if text is not None:
391
+ text_emb = self.clip_text_encoder(text) # (B, 512)
392
+ if device is not None:
393
+ text_emb = text_emb.to(device)
394
+ text_emb = text_emb.unsqueeze(1) # (B, 1, 512)
395
+ else:
396
+ text_emb = self.clip_text_encoder([""] * batch_size) # (B, 512)
397
+ text_emb = text_emb.to(device).unsqueeze(1) # (B, 1, 512)
398
+
399
+ # --- Palette embedding (token 2) ---
400
+ # When not provided, encode zero palette through PaletteEncoder, matching training ucg_training val=0.0
401
+ if palette is not None:
402
+ palette_emb = self.palette_encoder(palette) # (B, 512)
403
+ palette_emb = palette_emb.unsqueeze(1) # (B, 1, 512)
404
+ else:
405
+ n_colors = self.config.get("n_colors", 5)
406
+ placeholder_palette = torch.zeros(batch_size, n_colors, 3, device=device, dtype=dtype)
407
+ palette_emb = self.palette_encoder(placeholder_palette) # (B, 512)
408
+ palette_emb = palette_emb.unsqueeze(1) # (B, 1, 512)
409
+
410
+ # Combine cross-attention embeddings - always (B, 3, 512)
411
+ c_crossattn = torch.cat([img_emb, text_emb, palette_emb], dim=1)
412
+
413
+ # --- Sketch encoding for concatenation ---
414
+ # When not provided, encode zero sketch through SketchEncoder, matching training ucg_training val=0.0
415
+ if sketch is not None:
416
+ c_concat = self.sketch_encoder(sketch) # (B, 4, H/8, W/8)
417
+ else:
418
+ placeholder_sketch = torch.zeros(
419
+ batch_size, 1, image_size, image_size, device=device, dtype=dtype
420
+ )
421
+ c_concat = self.sketch_encoder(placeholder_sketch) # (B, 4, H/8, W/8)
422
+
423
+ return {
424
+ "c_crossattn": c_crossattn,
425
+ "c_concat": c_concat,
426
+ }
model_index.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline_matfuse",
4
+ "MatFusePipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "unet": [
8
+ "diffusers",
9
+ "UNet2DConditionModel"
10
+ ],
11
+ "vae": [
12
+ "vae_matfuse",
13
+ "MatFuseVQModel"
14
+ ],
15
+ "condition_encoder": [
16
+ "condition_encoders",
17
+ "MultiConditionEncoder"
18
+ ],
19
+ "scheduler": [
20
+ "diffusers",
21
+ "DDIMScheduler"
22
+ ]
23
+ }
pipeline_matfuse.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MatFuse Pipeline for diffusers.
3
+
4
+ A custom diffusers pipeline for generating PBR material maps using the MatFuse model.
5
+
6
+ Note: This pipeline uses:
7
+ - Standard UNet2DConditionModel from diffusers (with custom in/out channels config)
8
+ - Custom MatFuseVQModel (required because MatFuse uses 4 separate encoders/quantizers)
9
+ """
10
+
11
+ import os
12
+ import inspect
13
+ from typing import Optional, Union, List, Callable, Dict, Any, Tuple
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from PIL import Image
18
+ import numpy as np
19
+
20
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
21
+ from diffusers.models import UNet2DConditionModel
22
+ from diffusers.schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler
23
+
24
+ try:
25
+ from vae_matfuse import MatFuseVQModel
26
+ except ImportError:
27
+ from diffusers.models.modeling_utils import ModelMixin as MatFuseVQModel
28
+ try:
29
+ from condition_encoders import MultiConditionEncoder
30
+ except ImportError:
31
+ from diffusers.models.modeling_utils import ModelMixin as MultiConditionEncoder
32
+
33
+
34
+ class MatFusePipeline(DiffusionPipeline):
35
+ """
36
+ Pipeline for generating PBR material maps using MatFuse.
37
+
38
+ This pipeline generates 4 material maps (diffuse, normal, roughness, specular)
39
+ from various conditioning inputs like reference images, text, sketches, and color palettes.
40
+
41
+ Args:
42
+ vae: MatFuseVQModel for encoding/decoding material maps (custom, required).
43
+ unet: UNet2DConditionModel for denoising (standard diffusers model).
44
+ scheduler: Diffusion scheduler.
45
+ condition_encoder: Multi-condition encoder for processing inputs.
46
+
47
+ Note:
48
+ The VQ-VAE must be the custom MatFuseVQModel because MatFuse uses 4 separate
49
+ encoders and quantizers (one per material map type). The UNet can be the
50
+ standard diffusers UNet2DConditionModel configured with:
51
+ - in_channels=16 (12 latent + 4 sketch concat)
52
+ - out_channels=12 (4 maps × 3 channels)
53
+ - cross_attention_dim=512
54
+ """
55
+
56
+ model_cpu_offload_seq = "condition_encoder->unet->vae"
57
+ _optional_components = ["condition_encoder"]
58
+
59
+ def __init__(
60
+ self,
61
+ vae: MatFuseVQModel,
62
+ unet: UNet2DConditionModel,
63
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler],
64
+ condition_encoder: Optional[MultiConditionEncoder] = None,
65
+ ):
66
+ super().__init__()
67
+
68
+ self.register_modules(
69
+ vae=vae,
70
+ unet=unet,
71
+ scheduler=scheduler,
72
+ condition_encoder=condition_encoder,
73
+ )
74
+
75
+ self.vae_scale_factor = 8 # Downsampling factor of VQ-VAE
76
+
77
+ @classmethod
78
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
79
+ """
80
+ Load the MatFuse pipeline from a local directory.
81
+
82
+ Loads each component (UNet, VAE, scheduler, condition_encoder) individually
83
+ from their respective subdirectories.
84
+
85
+ Args:
86
+ pretrained_model_name_or_path: Path to the directory containing the model components.
87
+ **kwargs: Additional keyword arguments (e.g., torch_dtype).
88
+ """
89
+ model_dir = pretrained_model_name_or_path
90
+ torch_dtype = kwargs.get("torch_dtype", None)
91
+
92
+ # Load UNet (standard diffusers)
93
+ unet = UNet2DConditionModel.from_pretrained(
94
+ os.path.join(model_dir, "unet"),
95
+ torch_dtype=torch_dtype,
96
+ )
97
+
98
+ # Load VAE (custom)
99
+ vae = MatFuseVQModel.from_pretrained(
100
+ os.path.join(model_dir, "vae"),
101
+ torch_dtype=torch_dtype,
102
+ )
103
+
104
+ # Load scheduler
105
+ scheduler = DDIMScheduler.from_pretrained(
106
+ os.path.join(model_dir, "scheduler"),
107
+ )
108
+
109
+ # Load condition encoder (custom) if it exists
110
+ cond_dir = os.path.join(model_dir, "condition_encoder")
111
+ condition_encoder = None
112
+ if os.path.isdir(cond_dir):
113
+ condition_encoder = MultiConditionEncoder.from_pretrained(
114
+ cond_dir,
115
+ torch_dtype=torch_dtype,
116
+ )
117
+
118
+ return cls(
119
+ vae=vae,
120
+ unet=unet,
121
+ scheduler=scheduler,
122
+ condition_encoder=condition_encoder,
123
+ )
124
+
125
+ @property
126
+ def _execution_device(self):
127
+ if self.device != torch.device("meta"):
128
+ return self.device
129
+ for name, model in self.components.items():
130
+ if isinstance(model, torch.nn.Module):
131
+ return next(model.parameters()).device
132
+ # Also check condition_encoder (may not be in components dict)
133
+ if self.condition_encoder is not None:
134
+ return next(self.condition_encoder.parameters()).device
135
+ return torch.device("cpu")
136
+
137
+ def to(self, *args, **kwargs):
138
+ """Override to() to also move condition_encoder (not auto-tracked by diffusers)."""
139
+ result = super().to(*args, **kwargs)
140
+ if self.condition_encoder is not None:
141
+ self.condition_encoder = self.condition_encoder.to(*args, **kwargs)
142
+ return result
143
+
144
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
145
+ """Decode latents to material maps."""
146
+ # Add circular padding for seamless textures
147
+ latents = F.pad(latents, (7, 7, 7, 7), mode="circular")
148
+
149
+ # Upcast to float32 for VAE decoding to avoid NaN from float16 precision
150
+ needs_upcast = latents.dtype == torch.float16
151
+ if needs_upcast:
152
+ self.vae.to(dtype=torch.float32)
153
+ latents = latents.float()
154
+
155
+ # Decode
156
+ materials = self.vae.decode(latents)
157
+
158
+ if needs_upcast:
159
+ self.vae.to(dtype=torch.float16)
160
+ materials = materials.half()
161
+
162
+ # Center crop to remove padding
163
+ _, _, h, w = materials.shape
164
+ target_h = (h - 14 * self.vae_scale_factor)
165
+ target_w = (w - 14 * self.vae_scale_factor)
166
+ start_h = (h - target_h) // 2
167
+ start_w = (w - target_w) // 2
168
+ materials = materials[:, :, start_h:start_h + target_h, start_w:start_w + target_w]
169
+
170
+ return materials
171
+
172
+ def prepare_latents(
173
+ self,
174
+ batch_size: int,
175
+ num_channels_latents: int,
176
+ height: int,
177
+ width: int,
178
+ dtype: torch.dtype,
179
+ device: torch.device,
180
+ generator: Optional[torch.Generator] = None,
181
+ latents: Optional[torch.Tensor] = None,
182
+ ) -> torch.Tensor:
183
+ """Prepare initial noise latents."""
184
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
185
+
186
+ if latents is None:
187
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
188
+ else:
189
+ if latents.shape != shape:
190
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
191
+ latents = latents.to(device)
192
+
193
+ # Scale by scheduler
194
+ latents = latents * self.scheduler.init_noise_sigma
195
+
196
+ return latents
197
+
198
+ def prepare_extra_step_kwargs(self, generator: Optional[torch.Generator], eta: float) -> Dict[str, Any]:
199
+ """Prepare extra kwargs for the scheduler step."""
200
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
201
+ extra_step_kwargs = {}
202
+ if accepts_eta:
203
+ extra_step_kwargs["eta"] = eta
204
+
205
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
206
+ if accepts_generator:
207
+ extra_step_kwargs["generator"] = generator
208
+
209
+ return extra_step_kwargs
210
+
211
+ def _encode_conditions(
212
+ self,
213
+ image: Optional[torch.Tensor] = None,
214
+ text: Optional[Union[str, List[str]]] = None,
215
+ sketch: Optional[torch.Tensor] = None,
216
+ palette: Optional[torch.Tensor] = None,
217
+ batch_size: int = 1,
218
+ image_size: int = 256,
219
+ device: torch.device = None,
220
+ dtype: torch.dtype = None,
221
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
222
+ """Encode all condition inputs through their respective encoders.
223
+
224
+ When a condition is not provided, the encoder creates a placeholder
225
+ and encodes it (matching training behavior), rather than using zero tensors.
226
+ """
227
+ device = device or self._execution_device
228
+
229
+ if self.condition_encoder is not None:
230
+ cond = self.condition_encoder(
231
+ image_embed=image,
232
+ text=text,
233
+ sketch=sketch,
234
+ palette=palette,
235
+ batch_size=batch_size,
236
+ image_size=image_size,
237
+ device=device,
238
+ )
239
+ c_crossattn = cond["c_crossattn"]
240
+ c_concat = cond["c_concat"]
241
+ else:
242
+ c_crossattn = None
243
+ c_concat = None
244
+
245
+ # Ensure proper dtype
246
+ if c_crossattn is not None:
247
+ c_crossattn = c_crossattn.to(dtype=dtype, device=device)
248
+ if c_concat is not None:
249
+ c_concat = c_concat.to(dtype=dtype, device=device)
250
+
251
+ return c_crossattn, c_concat
252
+
253
+ def _get_uncond_embeddings(
254
+ self,
255
+ batch_size: int,
256
+ image_size: int,
257
+ device: torch.device,
258
+ dtype: torch.dtype,
259
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
260
+ """Get unconditional embeddings for classifier-free guidance.
261
+
262
+ Creates proper unconditional embeddings by encoding placeholder inputs
263
+ through the actual encoders (gray image → CLIP, empty string → SentenceTransformer,
264
+ zero palette → PaletteEncoder, zero sketch → SketchEncoder).
265
+
266
+ This matches the original training behavior where ucg_training drops conditions
267
+ by setting them to val=0.0 (images/palette/sketch) or val="" (text), and then
268
+ encoding those placeholder values through the encoders.
269
+ """
270
+ if self.condition_encoder is not None:
271
+ uc = self.condition_encoder.get_unconditional_conditioning(
272
+ batch_size=batch_size,
273
+ image_size=image_size,
274
+ device=device,
275
+ )
276
+ uc_crossattn = uc["c_crossattn"].to(dtype=dtype, device=device)
277
+ uc_concat = uc["c_concat"].to(dtype=dtype, device=device)
278
+ else:
279
+ uc_crossattn = None
280
+ uc_concat = None
281
+
282
+ return uc_crossattn, uc_concat
283
+
284
+ @torch.no_grad()
285
+ def __call__(
286
+ self,
287
+ image: Optional[Union[torch.Tensor, Image.Image]] = None,
288
+ text: Optional[Union[str, List[str]]] = None,
289
+ sketch: Optional[Union[torch.Tensor, Image.Image]] = None,
290
+ palette: Optional[Union[torch.Tensor, np.ndarray, List[Tuple[int, int, int]]]] = None,
291
+ height: int = 256,
292
+ width: int = 256,
293
+ num_inference_steps: int = 50,
294
+ guidance_scale: float = 7.5,
295
+ num_images_per_prompt: int = 1,
296
+ eta: float = 0.0,
297
+ generator: Optional[torch.Generator] = None,
298
+ latents: Optional[torch.Tensor] = None,
299
+ output_type: str = "pil",
300
+ return_dict: bool = True,
301
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
302
+ callback_steps: int = 1,
303
+ ) -> Dict[str, Any]:
304
+ """
305
+ Generate PBR material maps.
306
+
307
+ Args:
308
+ image: Reference image for style/appearance guidance.
309
+ text: Text description of the material.
310
+ sketch: Binary edge/sketch map for structure guidance.
311
+ palette: Color palette (5 colors) for color guidance.
312
+ height: Output image height.
313
+ width: Output image width.
314
+ num_inference_steps: Number of denoising steps.
315
+ guidance_scale: Classifier-free guidance scale.
316
+ num_images_per_prompt: Number of images to generate per prompt.
317
+ eta: DDIM eta parameter.
318
+ generator: Random number generator for reproducibility.
319
+ latents: Pre-generated noise latents.
320
+ output_type: Output format ("pil", "tensor", "np").
321
+ return_dict: Whether to return a dict.
322
+ callback: Callback function called every `callback_steps` steps.
323
+ callback_steps: Frequency of callback calls.
324
+
325
+ Returns:
326
+ Dictionary containing:
327
+ - images: List of generated images (4 maps per generation).
328
+ - diffuse: Diffuse/albedo maps.
329
+ - normal: Normal maps.
330
+ - roughness: Roughness maps.
331
+ - specular: Specular maps.
332
+ """
333
+ device = self._execution_device
334
+ dtype = self.unet.dtype if hasattr(self.unet, 'dtype') else torch.float32
335
+
336
+ # Determine batch size
337
+ if text is not None and isinstance(text, str):
338
+ batch_size = 1
339
+ elif text is not None:
340
+ batch_size = len(text)
341
+ else:
342
+ batch_size = 1
343
+
344
+ batch_size = batch_size * num_images_per_prompt
345
+
346
+ # Preprocess inputs
347
+ if image is not None and isinstance(image, Image.Image):
348
+ image = self._preprocess_image(image, device, dtype)
349
+
350
+ if sketch is not None and isinstance(sketch, Image.Image):
351
+ sketch = self._preprocess_sketch(sketch, height, width, device, dtype)
352
+
353
+ if palette is not None and not isinstance(palette, torch.Tensor):
354
+ palette = self._preprocess_palette(palette, device, dtype)
355
+
356
+ # Encode conditions
357
+ # The encoder handles None conditions by encoding placeholder inputs
358
+ # (matching the original model's UCG training behavior)
359
+ c_crossattn, c_concat = self._encode_conditions(
360
+ image=image,
361
+ text=text,
362
+ sketch=sketch,
363
+ palette=palette,
364
+ batch_size=batch_size,
365
+ image_size=height,
366
+ device=device,
367
+ dtype=dtype,
368
+ )
369
+
370
+ # Get unconditional embeddings for CFG
371
+ # These are encoded placeholders, NOT zero tensors
372
+ do_classifier_free_guidance = guidance_scale > 1.0
373
+ if do_classifier_free_guidance:
374
+ uc_crossattn, uc_concat = self._get_uncond_embeddings(
375
+ batch_size, height, device, dtype
376
+ )
377
+
378
+ # Prepare timesteps
379
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
380
+ timesteps = self.scheduler.timesteps
381
+
382
+ # Prepare latent variables
383
+ num_channels_latents = 12 # 4 maps * 3 channels per quantizer
384
+ latents = self.prepare_latents(
385
+ batch_size,
386
+ num_channels_latents,
387
+ height,
388
+ width,
389
+ dtype,
390
+ device,
391
+ generator,
392
+ latents,
393
+ )
394
+
395
+ # Prepare extra step kwargs
396
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
397
+
398
+ # Denoising loop
399
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
400
+
401
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
402
+ for i, t in enumerate(timesteps):
403
+ # Prepare latent input with sketch conditioning
404
+ if do_classifier_free_guidance:
405
+ # For CFG: unconditional uses uc_concat, conditional uses c_concat
406
+ latent_uncond = torch.cat([latents, uc_concat], dim=1)
407
+ latent_cond = torch.cat([latents, c_concat], dim=1)
408
+ latent_model_input = torch.cat([latent_uncond, latent_cond])
409
+ if c_crossattn is not None:
410
+ encoder_hidden_states = torch.cat([uc_crossattn, c_crossattn])
411
+ else:
412
+ encoder_hidden_states = None
413
+ else:
414
+ latent_model_input = torch.cat([latents, c_concat], dim=1)
415
+ encoder_hidden_states = c_crossattn
416
+
417
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
418
+
419
+ # Predict noise
420
+ noise_pred = self.unet(
421
+ latent_model_input,
422
+ t,
423
+ encoder_hidden_states=encoder_hidden_states,
424
+ return_dict=False,
425
+ )
426
+ # return_dict=False returns tuple, first element is sample
427
+ if isinstance(noise_pred, tuple):
428
+ noise_pred = noise_pred[0]
429
+ elif isinstance(noise_pred, dict):
430
+ noise_pred = noise_pred["sample"]
431
+
432
+ # Classifier-free guidance
433
+ if do_classifier_free_guidance:
434
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
435
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
436
+
437
+ # Compute previous noisy sample
438
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
439
+
440
+ # Callback
441
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
442
+ progress_bar.update()
443
+ if callback is not None and i % callback_steps == 0:
444
+ callback(i, t, latents)
445
+
446
+ # Decode latents
447
+ materials = self.decode_latents(latents)
448
+
449
+ # Split into individual maps
450
+ diffuse = materials[:, 0:3]
451
+ normal = materials[:, 3:6]
452
+ roughness = materials[:, 6:9]
453
+ specular = materials[:, 9:12]
454
+
455
+ # Post-process outputs
456
+ if output_type == "pil":
457
+ diffuse = self._tensor_to_pil(diffuse)
458
+ normal = self._tensor_to_pil(normal)
459
+ roughness = self._tensor_to_pil(roughness)
460
+ specular = self._tensor_to_pil(specular)
461
+ elif output_type == "np":
462
+ diffuse = self._tensor_to_numpy(diffuse)
463
+ normal = self._tensor_to_numpy(normal)
464
+ roughness = self._tensor_to_numpy(roughness)
465
+ specular = self._tensor_to_numpy(specular)
466
+
467
+ if return_dict:
468
+ return {
469
+ "diffuse": diffuse,
470
+ "normal": normal,
471
+ "roughness": roughness,
472
+ "specular": specular,
473
+ }
474
+
475
+ return (diffuse, normal, roughness, specular)
476
+
477
+ def _preprocess_image(self, image: Image.Image, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
478
+ """Preprocess PIL image to tensor."""
479
+ image = image.convert("RGB")
480
+ image = np.array(image).astype(np.float32) / 255.0
481
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
482
+ image = image * 2.0 - 1.0 # Scale to [-1, 1]
483
+ return image.to(device=device, dtype=dtype)
484
+
485
+ def _preprocess_sketch(
486
+ self,
487
+ sketch: Image.Image,
488
+ height: int,
489
+ width: int,
490
+ device: torch.device,
491
+ dtype: torch.dtype,
492
+ ) -> torch.Tensor:
493
+ """Preprocess sketch image to tensor."""
494
+ sketch = sketch.convert("L")
495
+ sketch = sketch.resize((width, height), Image.BILINEAR)
496
+ sketch = np.array(sketch).astype(np.float32) / 255.0
497
+ sketch = torch.from_numpy(sketch).unsqueeze(0).unsqueeze(0)
498
+ return sketch.to(device=device, dtype=dtype)
499
+
500
+ def _preprocess_palette(
501
+ self,
502
+ palette: Union[np.ndarray, List[Tuple[int, int, int]]],
503
+ device: torch.device,
504
+ dtype: torch.dtype,
505
+ ) -> torch.Tensor:
506
+ """Preprocess color palette to tensor."""
507
+ if isinstance(palette, list):
508
+ palette = np.array(palette, dtype=np.float32) / 255.0
509
+ elif isinstance(palette, np.ndarray):
510
+ if palette.max() > 1.0:
511
+ palette = palette.astype(np.float32) / 255.0
512
+ else:
513
+ palette = palette.astype(np.float32)
514
+
515
+ # Ensure 5 colors
516
+ while len(palette) < 5:
517
+ palette = np.concatenate([palette, palette[-1:]], axis=0)
518
+ palette = palette[:5]
519
+
520
+ palette = torch.from_numpy(palette).unsqueeze(0)
521
+ return palette.to(device=device, dtype=dtype)
522
+
523
+ def _tensor_to_pil(self, tensor: torch.Tensor) -> List[Image.Image]:
524
+ """Convert tensor to list of PIL images."""
525
+ tensor = (tensor + 1.0) / 2.0
526
+ tensor = tensor.clamp(0, 1)
527
+ tensor = tensor.cpu().permute(0, 2, 3, 1).numpy()
528
+ tensor = (tensor * 255).astype(np.uint8)
529
+ return [Image.fromarray(img) for img in tensor]
530
+
531
+ def _tensor_to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
532
+ """Convert tensor to numpy array."""
533
+ tensor = (tensor + 1.0) / 2.0
534
+ tensor = tensor.clamp(0, 1)
535
+ return tensor.cpu().permute(0, 2, 3, 1).numpy()
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.35.2",
4
+ "beta_end": 0.0195,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.0015,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 0,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
unet/config.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.35.2",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": [8, 16, 32],
9
+ "attention_type": "default",
10
+ "block_out_channels": [
11
+ 256,
12
+ 512,
13
+ 1024
14
+ ],
15
+ "center_input_sample": false,
16
+ "class_embed_type": null,
17
+ "class_embeddings_concat": false,
18
+ "conv_in_kernel": 3,
19
+ "conv_out_kernel": 3,
20
+ "cross_attention_dim": 512,
21
+ "cross_attention_norm": null,
22
+ "down_block_types": [
23
+ "CrossAttnDownBlock2D",
24
+ "CrossAttnDownBlock2D",
25
+ "CrossAttnDownBlock2D"
26
+ ],
27
+ "downsample_padding": 1,
28
+ "dropout": 0.0,
29
+ "dual_cross_attention": false,
30
+ "encoder_hid_dim": null,
31
+ "encoder_hid_dim_type": null,
32
+ "flip_sin_to_cos": true,
33
+ "freq_shift": 0,
34
+ "in_channels": 16,
35
+ "layers_per_block": 2,
36
+ "mid_block_only_cross_attention": null,
37
+ "mid_block_scale_factor": 1,
38
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
39
+ "norm_eps": 1e-05,
40
+ "norm_num_groups": 32,
41
+ "num_attention_heads": null,
42
+ "num_class_embeds": null,
43
+ "only_cross_attention": false,
44
+ "out_channels": 12,
45
+ "projection_class_embeddings_input_dim": null,
46
+ "resnet_out_scale_factor": 1.0,
47
+ "resnet_skip_time_act": false,
48
+ "resnet_time_scale_shift": "default",
49
+ "reverse_transformer_layers_per_block": null,
50
+ "sample_size": 32,
51
+ "time_cond_proj_dim": null,
52
+ "time_embedding_act_fn": null,
53
+ "time_embedding_dim": null,
54
+ "time_embedding_type": "positional",
55
+ "timestep_post_act": null,
56
+ "transformer_layers_per_block": 1,
57
+ "up_block_types": [
58
+ "CrossAttnUpBlock2D",
59
+ "CrossAttnUpBlock2D",
60
+ "CrossAttnUpBlock2D"
61
+ ],
62
+ "upcast_attention": false,
63
+ "use_linear_projection": false
64
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de0fd0026e9621a105c46aea51a671dae90b5e1233ddf0c4e26a54eb0097e3d1
3
+ size 1580197816
vae/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MatFuseVQModel",
3
+ "_diffusers_version": "0.35.2",
4
+ "attn_resolutions": [],
5
+ "ch": 128,
6
+ "ch_mult": [
7
+ 1,
8
+ 1,
9
+ 2,
10
+ 4
11
+ ],
12
+ "dropout": 0.0,
13
+ "embed_dim": 3,
14
+ "in_channels": 3,
15
+ "n_embed": 4096,
16
+ "num_res_blocks": 2,
17
+ "out_channels": 12,
18
+ "resolution": 256,
19
+ "scaling_factor": 1.0,
20
+ "z_channels": 256
21
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad0962b5d9ad812c2ee8fe146e28ee73bba4382be388cc6703e3b4bb6ce6d06b
3
+ size 528848272
vae/vae_matfuse.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MatFuse VQ-VAE Model for diffusers.
3
+
4
+ This is a custom VQ-VAE that has 4 separate encoders (one for each material map)
5
+ and 4 separate quantizers, with a single shared decoder.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+
15
+
16
+ def Normalize(in_channels: int, num_groups: int = 32) -> nn.GroupNorm:
17
+ """Group normalization."""
18
+ return nn.GroupNorm(
19
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
20
+ )
21
+
22
+
23
+ def nonlinearity(x: torch.Tensor) -> torch.Tensor:
24
+ """Swish activation."""
25
+ return x * torch.sigmoid(x)
26
+
27
+
28
+ class Upsample(nn.Module):
29
+ """Upsampling layer with optional convolution."""
30
+
31
+ def __init__(self, in_channels: int, with_conv: bool = True):
32
+ super().__init__()
33
+ self.with_conv = with_conv
34
+ if self.with_conv:
35
+ self.conv = nn.Conv2d(
36
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
37
+ )
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
41
+ if self.with_conv:
42
+ x = self.conv(x)
43
+ return x
44
+
45
+
46
+ class Downsample(nn.Module):
47
+ """Downsampling layer with optional convolution."""
48
+
49
+ def __init__(self, in_channels: int, with_conv: bool = True):
50
+ super().__init__()
51
+ self.with_conv = with_conv
52
+ if self.with_conv:
53
+ self.conv = nn.Conv2d(
54
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
55
+ )
56
+
57
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
58
+ if self.with_conv:
59
+ pad = (0, 1, 0, 1)
60
+ x = F.pad(x, pad, mode="constant", value=0)
61
+ x = self.conv(x)
62
+ else:
63
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
64
+ return x
65
+
66
+
67
+ class ResnetBlock(nn.Module):
68
+ """Residual block with optional time embedding."""
69
+
70
+ def __init__(
71
+ self,
72
+ in_channels: int,
73
+ out_channels: Optional[int] = None,
74
+ conv_shortcut: bool = False,
75
+ dropout: float = 0.0,
76
+ temb_channels: int = 0,
77
+ ):
78
+ super().__init__()
79
+ self.in_channels = in_channels
80
+ out_channels = in_channels if out_channels is None else out_channels
81
+ self.out_channels = out_channels
82
+ self.use_conv_shortcut = conv_shortcut
83
+
84
+ self.norm1 = Normalize(in_channels)
85
+ self.conv1 = nn.Conv2d(
86
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
87
+ )
88
+
89
+ if temb_channels > 0:
90
+ self.temb_proj = nn.Linear(temb_channels, out_channels)
91
+
92
+ self.norm2 = Normalize(out_channels)
93
+ self.dropout = nn.Dropout(dropout)
94
+ self.conv2 = nn.Conv2d(
95
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
96
+ )
97
+
98
+ if self.in_channels != self.out_channels:
99
+ if self.use_conv_shortcut:
100
+ self.conv_shortcut = nn.Conv2d(
101
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
102
+ )
103
+ else:
104
+ self.nin_shortcut = nn.Conv2d(
105
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
106
+ )
107
+
108
+ def forward(
109
+ self, x: torch.Tensor, temb: Optional[torch.Tensor] = None
110
+ ) -> torch.Tensor:
111
+ h = x
112
+ h = self.norm1(h)
113
+ h = nonlinearity(h)
114
+ h = self.conv1(h)
115
+
116
+ if temb is not None and hasattr(self, "temb_proj"):
117
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
118
+
119
+ h = self.norm2(h)
120
+ h = nonlinearity(h)
121
+ h = self.dropout(h)
122
+ h = self.conv2(h)
123
+
124
+ if self.in_channels != self.out_channels:
125
+ if self.use_conv_shortcut:
126
+ x = self.conv_shortcut(x)
127
+ else:
128
+ x = self.nin_shortcut(x)
129
+
130
+ return x + h
131
+
132
+
133
+ class AttnBlock(nn.Module):
134
+ """Self-attention block."""
135
+
136
+ def __init__(self, in_channels: int):
137
+ super().__init__()
138
+ self.in_channels = in_channels
139
+
140
+ self.norm = Normalize(in_channels)
141
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
142
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
143
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
144
+ self.proj_out = nn.Conv2d(
145
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
146
+ )
147
+
148
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
149
+ h_ = x
150
+ h_ = self.norm(h_)
151
+ q = self.q(h_)
152
+ k = self.k(h_)
153
+ v = self.v(h_)
154
+
155
+ # Compute attention
156
+ b, c, h, w = q.shape
157
+ q = q.reshape(b, c, h * w)
158
+ q = q.permute(0, 2, 1) # b, hw, c
159
+ k = k.reshape(b, c, h * w) # b, c, hw
160
+ w_ = torch.bmm(q, k) # b, hw, hw
161
+ w_ = w_ * (int(c) ** (-0.5))
162
+ w_ = F.softmax(w_, dim=2)
163
+
164
+ # Attend to values
165
+ v = v.reshape(b, c, h * w)
166
+ w_ = w_.permute(0, 2, 1) # b, hw, hw
167
+ h_ = torch.bmm(v, w_) # b, c, hw
168
+ h_ = h_.reshape(b, c, h, w)
169
+
170
+ h_ = self.proj_out(h_)
171
+
172
+ return x + h_
173
+
174
+
175
+ class Encoder(nn.Module):
176
+ """Encoder module for VQ-VAE."""
177
+
178
+ def __init__(
179
+ self,
180
+ ch: int = 128,
181
+ ch_mult: Tuple[int, ...] = (1, 1, 2, 4),
182
+ num_res_blocks: int = 2,
183
+ attn_resolutions: Tuple[int, ...] = (),
184
+ dropout: float = 0.0,
185
+ in_channels: int = 3,
186
+ resolution: int = 256,
187
+ z_channels: int = 256,
188
+ double_z: bool = False,
189
+ **ignore_kwargs,
190
+ ):
191
+ super().__init__()
192
+ self.ch = ch
193
+ self.temb_ch = 0
194
+ self.num_resolutions = len(ch_mult)
195
+ self.num_res_blocks = num_res_blocks
196
+ self.resolution = resolution
197
+ self.in_channels = in_channels
198
+
199
+ # Downsampling
200
+ self.conv_in = nn.Conv2d(
201
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
202
+ )
203
+
204
+ curr_res = resolution
205
+ in_ch_mult = (1,) + tuple(ch_mult)
206
+ self.down = nn.ModuleList()
207
+
208
+ for i_level in range(self.num_resolutions):
209
+ block = nn.ModuleList()
210
+ attn = nn.ModuleList()
211
+ block_in = ch * in_ch_mult[i_level]
212
+ block_out = ch * ch_mult[i_level]
213
+
214
+ for i_block in range(self.num_res_blocks):
215
+ block.append(
216
+ ResnetBlock(
217
+ in_channels=block_in,
218
+ out_channels=block_out,
219
+ temb_channels=self.temb_ch,
220
+ dropout=dropout,
221
+ )
222
+ )
223
+ block_in = block_out
224
+ if curr_res in attn_resolutions:
225
+ attn.append(AttnBlock(block_in))
226
+
227
+ down = nn.Module()
228
+ down.block = block
229
+ down.attn = attn
230
+
231
+ if i_level != self.num_resolutions - 1:
232
+ down.downsample = Downsample(block_in, with_conv=True)
233
+ curr_res = curr_res // 2
234
+
235
+ self.down.append(down)
236
+
237
+ # Middle
238
+ self.mid = nn.Module()
239
+ self.mid.block_1 = ResnetBlock(
240
+ in_channels=block_in,
241
+ out_channels=block_in,
242
+ temb_channels=self.temb_ch,
243
+ dropout=dropout,
244
+ )
245
+ self.mid.attn_1 = AttnBlock(block_in)
246
+ self.mid.block_2 = ResnetBlock(
247
+ in_channels=block_in,
248
+ out_channels=block_in,
249
+ temb_channels=self.temb_ch,
250
+ dropout=dropout,
251
+ )
252
+
253
+ # End
254
+ self.norm_out = Normalize(block_in)
255
+ out_channels = 2 * z_channels if double_z else z_channels
256
+ self.conv_out = nn.Conv2d(
257
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
258
+ )
259
+
260
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
261
+ # Downsampling
262
+ h = self.conv_in(x)
263
+
264
+ for i_level in range(self.num_resolutions):
265
+ for i_block in range(self.num_res_blocks):
266
+ h = self.down[i_level].block[i_block](h, None)
267
+ if len(self.down[i_level].attn) > 0:
268
+ h = self.down[i_level].attn[i_block](h)
269
+ if hasattr(self.down[i_level], "downsample"):
270
+ h = self.down[i_level].downsample(h)
271
+
272
+ # Middle
273
+ h = self.mid.block_1(h, None)
274
+ h = self.mid.attn_1(h)
275
+ h = self.mid.block_2(h, None)
276
+
277
+ # End
278
+ h = self.norm_out(h)
279
+ h = nonlinearity(h)
280
+ h = self.conv_out(h)
281
+
282
+ return h
283
+
284
+
285
+ class Decoder(nn.Module):
286
+ """Decoder module for VQ-VAE."""
287
+
288
+ def __init__(
289
+ self,
290
+ ch: int = 128,
291
+ out_ch: int = 12,
292
+ ch_mult: Tuple[int, ...] = (1, 1, 2, 4),
293
+ num_res_blocks: int = 2,
294
+ attn_resolutions: Tuple[int, ...] = (),
295
+ dropout: float = 0.0,
296
+ in_channels: int = 3,
297
+ resolution: int = 256,
298
+ z_channels: int = 256,
299
+ give_pre_end: bool = False,
300
+ **ignore_kwargs,
301
+ ):
302
+ super().__init__()
303
+ self.ch = ch
304
+ self.temb_ch = 0
305
+ self.num_resolutions = len(ch_mult)
306
+ self.num_res_blocks = num_res_blocks
307
+ self.resolution = resolution
308
+ self.in_channels = in_channels
309
+ self.give_pre_end = give_pre_end
310
+
311
+ # Compute in_ch_mult and block_in
312
+ block_in = ch * ch_mult[self.num_resolutions - 1]
313
+ curr_res = resolution // (2 ** (self.num_resolutions - 1))
314
+
315
+ # z to block_in
316
+ self.conv_in = nn.Conv2d(
317
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
318
+ )
319
+
320
+ # Middle
321
+ self.mid = nn.Module()
322
+ self.mid.block_1 = ResnetBlock(
323
+ in_channels=block_in,
324
+ out_channels=block_in,
325
+ temb_channels=self.temb_ch,
326
+ dropout=dropout,
327
+ )
328
+ self.mid.attn_1 = AttnBlock(block_in)
329
+ self.mid.block_2 = ResnetBlock(
330
+ in_channels=block_in,
331
+ out_channels=block_in,
332
+ temb_channels=self.temb_ch,
333
+ dropout=dropout,
334
+ )
335
+
336
+ # Upsampling
337
+ self.up = nn.ModuleList()
338
+ for i_level in reversed(range(self.num_resolutions)):
339
+ block = nn.ModuleList()
340
+ attn = nn.ModuleList()
341
+ block_out = ch * ch_mult[i_level]
342
+
343
+ for i_block in range(self.num_res_blocks + 1):
344
+ block.append(
345
+ ResnetBlock(
346
+ in_channels=block_in,
347
+ out_channels=block_out,
348
+ temb_channels=self.temb_ch,
349
+ dropout=dropout,
350
+ )
351
+ )
352
+ block_in = block_out
353
+ if curr_res in attn_resolutions:
354
+ attn.append(AttnBlock(block_in))
355
+
356
+ up = nn.Module()
357
+ up.block = block
358
+ up.attn = attn
359
+
360
+ if i_level != 0:
361
+ up.upsample = Upsample(block_in, with_conv=True)
362
+ curr_res = curr_res * 2
363
+
364
+ self.up.insert(0, up)
365
+
366
+ # End
367
+ self.norm_out = Normalize(block_in)
368
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
369
+
370
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
371
+ # z to block_in
372
+ h = self.conv_in(z)
373
+
374
+ # Middle
375
+ h = self.mid.block_1(h, None)
376
+ h = self.mid.attn_1(h)
377
+ h = self.mid.block_2(h, None)
378
+
379
+ # Upsampling
380
+ for i_level in reversed(range(self.num_resolutions)):
381
+ for i_block in range(self.num_res_blocks + 1):
382
+ h = self.up[i_level].block[i_block](h, None)
383
+ if len(self.up[i_level].attn) > 0:
384
+ h = self.up[i_level].attn[i_block](h)
385
+ if hasattr(self.up[i_level], "upsample"):
386
+ h = self.up[i_level].upsample(h)
387
+
388
+ # End
389
+ if self.give_pre_end:
390
+ return h
391
+
392
+ h = self.norm_out(h)
393
+ h = nonlinearity(h)
394
+ h = self.conv_out(h)
395
+
396
+ return h
397
+
398
+
399
+ class VectorQuantizer(nn.Module):
400
+ """
401
+ Vector Quantizer module.
402
+
403
+ Discretizes the input vectors using a learned codebook.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ n_embed: int,
409
+ embed_dim: int,
410
+ beta: float = 0.25,
411
+ ):
412
+ super().__init__()
413
+ self.n_embed = n_embed
414
+ self.embed_dim = embed_dim
415
+ self.beta = beta
416
+
417
+ self.embedding = nn.Embedding(self.n_embed, self.embed_dim)
418
+ self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)
419
+
420
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
421
+ # Reshape z -> (batch, height, width, channel) and flatten
422
+ z = z.permute(0, 2, 3, 1).contiguous()
423
+ z_flattened = z.view(-1, self.embed_dim)
424
+
425
+ # Distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
426
+ d = (
427
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
428
+ + torch.sum(self.embedding.weight**2, dim=1)
429
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
430
+ )
431
+
432
+ min_encoding_indices = torch.argmin(d, dim=1)
433
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
434
+
435
+ # Compute loss for embedding
436
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
437
+ (z_q - z.detach()) ** 2
438
+ )
439
+
440
+ # Preserve gradients
441
+ z_q = z + (z_q - z).detach()
442
+
443
+ # Reshape back to match original input shape
444
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
445
+
446
+ return z_q, loss, (None, None, min_encoding_indices)
447
+
448
+ def get_codebook_entry(
449
+ self, indices: torch.Tensor, shape: Optional[Tuple] = None
450
+ ) -> torch.Tensor:
451
+ # Get quantized latent vectors
452
+ z_q = self.embedding(indices)
453
+
454
+ if shape is not None:
455
+ z_q = z_q.view(shape)
456
+ # Reshape back to match original input shape
457
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
458
+
459
+ return z_q
460
+
461
+
462
+ class MatFuseVQModel(ModelMixin, ConfigMixin):
463
+ """
464
+ MatFuse VQ-VAE Model.
465
+
466
+ This model has 4 separate encoders for each material map (diffuse, normal, roughness, specular)
467
+ and 4 separate VQ quantizers, with a single shared decoder that outputs 12 channels.
468
+ """
469
+
470
+ @register_to_config
471
+ def __init__(
472
+ self,
473
+ ch: int = 128,
474
+ ch_mult: Tuple[int, ...] = (1, 1, 2, 4),
475
+ num_res_blocks: int = 2,
476
+ attn_resolutions: Tuple[int, ...] = (),
477
+ dropout: float = 0.0,
478
+ in_channels: int = 3,
479
+ out_channels: int = 12,
480
+ resolution: int = 256,
481
+ z_channels: int = 256,
482
+ n_embed: int = 4096,
483
+ embed_dim: int = 3,
484
+ scaling_factor: float = 1.0,
485
+ ):
486
+ super().__init__()
487
+
488
+ self.scaling_factor = scaling_factor
489
+ self.embed_dim = embed_dim
490
+
491
+ ddconfig = dict(
492
+ ch=ch,
493
+ ch_mult=ch_mult,
494
+ num_res_blocks=num_res_blocks,
495
+ attn_resolutions=attn_resolutions,
496
+ dropout=dropout,
497
+ in_channels=in_channels,
498
+ resolution=resolution,
499
+ z_channels=z_channels,
500
+ double_z=False,
501
+ )
502
+
503
+ # 4 separate encoders for each material map
504
+ self.encoder_0 = Encoder(**ddconfig)
505
+ self.encoder_1 = Encoder(**ddconfig)
506
+ self.encoder_2 = Encoder(**ddconfig)
507
+ self.encoder_3 = Encoder(**ddconfig)
508
+
509
+ # Single decoder
510
+ decoder_config = dict(
511
+ ch=ch,
512
+ out_ch=out_channels,
513
+ ch_mult=ch_mult,
514
+ num_res_blocks=num_res_blocks,
515
+ attn_resolutions=attn_resolutions,
516
+ dropout=dropout,
517
+ in_channels=in_channels,
518
+ resolution=resolution,
519
+ z_channels=z_channels,
520
+ )
521
+ self.decoder = Decoder(**decoder_config)
522
+
523
+ # 4 separate quantizers
524
+ self.quantize_0 = VectorQuantizer(n_embed, embed_dim)
525
+ self.quantize_1 = VectorQuantizer(n_embed, embed_dim)
526
+ self.quantize_2 = VectorQuantizer(n_embed, embed_dim)
527
+ self.quantize_3 = VectorQuantizer(n_embed, embed_dim)
528
+
529
+ # Quant convolutions
530
+ self.quant_conv_0 = nn.Conv2d(z_channels, embed_dim, 1)
531
+ self.quant_conv_1 = nn.Conv2d(z_channels, embed_dim, 1)
532
+ self.quant_conv_2 = nn.Conv2d(z_channels, embed_dim, 1)
533
+ self.quant_conv_3 = nn.Conv2d(z_channels, embed_dim, 1)
534
+
535
+ # Post quant convolution (takes 4 * embed_dim channels)
536
+ self.post_quant_conv = nn.Conv2d(embed_dim * 4, z_channels, 1)
537
+
538
+ def encode_to_prequant(self, x: torch.Tensor) -> torch.Tensor:
539
+ """Encode input to pre-quantized latent space."""
540
+ h_0 = self.encoder_0(x[:, :3])
541
+ h_1 = self.encoder_1(x[:, 3:6])
542
+ h_2 = self.encoder_2(x[:, 6:9])
543
+ h_3 = self.encoder_3(x[:, 9:12])
544
+
545
+ h_0 = self.quant_conv_0(h_0)
546
+ h_1 = self.quant_conv_1(h_1)
547
+ h_2 = self.quant_conv_2(h_2)
548
+ h_3 = self.quant_conv_3(h_3)
549
+
550
+ h = torch.cat((h_0, h_1, h_2, h_3), dim=1)
551
+ return h
552
+
553
+ def quantize_latent(
554
+ self, h: torch.Tensor
555
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
556
+ """Quantize the latent space."""
557
+ quant_0, emb_loss_0, info_0 = self.quantize_0(h[:, : self.embed_dim])
558
+ quant_1, emb_loss_1, info_1 = self.quantize_1(
559
+ h[:, self.embed_dim : 2 * self.embed_dim]
560
+ )
561
+ quant_2, emb_loss_2, info_2 = self.quantize_2(
562
+ h[:, 2 * self.embed_dim : 3 * self.embed_dim]
563
+ )
564
+ quant_3, emb_loss_3, info_3 = self.quantize_3(h[:, 3 * self.embed_dim :])
565
+
566
+ quant = torch.cat((quant_0, quant_1, quant_2, quant_3), dim=1)
567
+ emb_loss = emb_loss_0 + emb_loss_1 + emb_loss_2 + emb_loss_3
568
+ info = torch.stack([info_0[-1], info_1[-1], info_2[-1], info_3[-1]], dim=0)
569
+
570
+ return quant, emb_loss, info
571
+
572
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
573
+ """Encode input to quantized latent space."""
574
+ h = self.encode_to_prequant(x)
575
+ quant, _, _ = self.quantize_latent(h)
576
+ return quant * self.scaling_factor
577
+
578
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
579
+ """Decode from latent space to image."""
580
+ z = z / self.scaling_factor
581
+ z = self.post_quant_conv(z)
582
+ dec = self.decoder(z)
583
+ return dec
584
+
585
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
586
+ """Forward pass through the VQ-VAE."""
587
+ h = self.encode_to_prequant(x)
588
+ quant, diff, _ = self.quantize_latent(h)
589
+ dec = self.decode(quant * self.scaling_factor)
590
+ return dec, diff
vae_matfuse.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MatFuse VQ-VAE Model for diffusers.
3
+
4
+ This is a custom VQ-VAE that has 4 separate encoders (one for each material map)
5
+ and 4 separate quantizers, with a single shared decoder.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+
15
+
16
+ def Normalize(in_channels: int, num_groups: int = 32) -> nn.GroupNorm:
17
+ """Group normalization."""
18
+ return nn.GroupNorm(
19
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
20
+ )
21
+
22
+
23
+ def nonlinearity(x: torch.Tensor) -> torch.Tensor:
24
+ """Swish activation."""
25
+ return x * torch.sigmoid(x)
26
+
27
+
28
+ class Upsample(nn.Module):
29
+ """Upsampling layer with optional convolution."""
30
+
31
+ def __init__(self, in_channels: int, with_conv: bool = True):
32
+ super().__init__()
33
+ self.with_conv = with_conv
34
+ if self.with_conv:
35
+ self.conv = nn.Conv2d(
36
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
37
+ )
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
41
+ if self.with_conv:
42
+ x = self.conv(x)
43
+ return x
44
+
45
+
46
+ class Downsample(nn.Module):
47
+ """Downsampling layer with optional convolution."""
48
+
49
+ def __init__(self, in_channels: int, with_conv: bool = True):
50
+ super().__init__()
51
+ self.with_conv = with_conv
52
+ if self.with_conv:
53
+ self.conv = nn.Conv2d(
54
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
55
+ )
56
+
57
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
58
+ if self.with_conv:
59
+ pad = (0, 1, 0, 1)
60
+ x = F.pad(x, pad, mode="constant", value=0)
61
+ x = self.conv(x)
62
+ else:
63
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
64
+ return x
65
+
66
+
67
+ class ResnetBlock(nn.Module):
68
+ """Residual block with optional time embedding."""
69
+
70
+ def __init__(
71
+ self,
72
+ in_channels: int,
73
+ out_channels: Optional[int] = None,
74
+ conv_shortcut: bool = False,
75
+ dropout: float = 0.0,
76
+ temb_channels: int = 0,
77
+ ):
78
+ super().__init__()
79
+ self.in_channels = in_channels
80
+ out_channels = in_channels if out_channels is None else out_channels
81
+ self.out_channels = out_channels
82
+ self.use_conv_shortcut = conv_shortcut
83
+
84
+ self.norm1 = Normalize(in_channels)
85
+ self.conv1 = nn.Conv2d(
86
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
87
+ )
88
+
89
+ if temb_channels > 0:
90
+ self.temb_proj = nn.Linear(temb_channels, out_channels)
91
+
92
+ self.norm2 = Normalize(out_channels)
93
+ self.dropout = nn.Dropout(dropout)
94
+ self.conv2 = nn.Conv2d(
95
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
96
+ )
97
+
98
+ if self.in_channels != self.out_channels:
99
+ if self.use_conv_shortcut:
100
+ self.conv_shortcut = nn.Conv2d(
101
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
102
+ )
103
+ else:
104
+ self.nin_shortcut = nn.Conv2d(
105
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
106
+ )
107
+
108
+ def forward(
109
+ self, x: torch.Tensor, temb: Optional[torch.Tensor] = None
110
+ ) -> torch.Tensor:
111
+ h = x
112
+ h = self.norm1(h)
113
+ h = nonlinearity(h)
114
+ h = self.conv1(h)
115
+
116
+ if temb is not None and hasattr(self, "temb_proj"):
117
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
118
+
119
+ h = self.norm2(h)
120
+ h = nonlinearity(h)
121
+ h = self.dropout(h)
122
+ h = self.conv2(h)
123
+
124
+ if self.in_channels != self.out_channels:
125
+ if self.use_conv_shortcut:
126
+ x = self.conv_shortcut(x)
127
+ else:
128
+ x = self.nin_shortcut(x)
129
+
130
+ return x + h
131
+
132
+
133
+ class AttnBlock(nn.Module):
134
+ """Self-attention block."""
135
+
136
+ def __init__(self, in_channels: int):
137
+ super().__init__()
138
+ self.in_channels = in_channels
139
+
140
+ self.norm = Normalize(in_channels)
141
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
142
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
143
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
144
+ self.proj_out = nn.Conv2d(
145
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
146
+ )
147
+
148
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
149
+ h_ = x
150
+ h_ = self.norm(h_)
151
+ q = self.q(h_)
152
+ k = self.k(h_)
153
+ v = self.v(h_)
154
+
155
+ # Compute attention
156
+ b, c, h, w = q.shape
157
+ q = q.reshape(b, c, h * w)
158
+ q = q.permute(0, 2, 1) # b, hw, c
159
+ k = k.reshape(b, c, h * w) # b, c, hw
160
+ w_ = torch.bmm(q, k) # b, hw, hw
161
+ w_ = w_ * (int(c) ** (-0.5))
162
+ w_ = F.softmax(w_, dim=2)
163
+
164
+ # Attend to values
165
+ v = v.reshape(b, c, h * w)
166
+ w_ = w_.permute(0, 2, 1) # b, hw, hw
167
+ h_ = torch.bmm(v, w_) # b, c, hw
168
+ h_ = h_.reshape(b, c, h, w)
169
+
170
+ h_ = self.proj_out(h_)
171
+
172
+ return x + h_
173
+
174
+
175
+ class Encoder(nn.Module):
176
+ """Encoder module for VQ-VAE."""
177
+
178
+ def __init__(
179
+ self,
180
+ ch: int = 128,
181
+ ch_mult: Tuple[int, ...] = (1, 1, 2, 4),
182
+ num_res_blocks: int = 2,
183
+ attn_resolutions: Tuple[int, ...] = (),
184
+ dropout: float = 0.0,
185
+ in_channels: int = 3,
186
+ resolution: int = 256,
187
+ z_channels: int = 256,
188
+ double_z: bool = False,
189
+ **ignore_kwargs,
190
+ ):
191
+ super().__init__()
192
+ self.ch = ch
193
+ self.temb_ch = 0
194
+ self.num_resolutions = len(ch_mult)
195
+ self.num_res_blocks = num_res_blocks
196
+ self.resolution = resolution
197
+ self.in_channels = in_channels
198
+
199
+ # Downsampling
200
+ self.conv_in = nn.Conv2d(
201
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
202
+ )
203
+
204
+ curr_res = resolution
205
+ in_ch_mult = (1,) + tuple(ch_mult)
206
+ self.down = nn.ModuleList()
207
+
208
+ for i_level in range(self.num_resolutions):
209
+ block = nn.ModuleList()
210
+ attn = nn.ModuleList()
211
+ block_in = ch * in_ch_mult[i_level]
212
+ block_out = ch * ch_mult[i_level]
213
+
214
+ for i_block in range(self.num_res_blocks):
215
+ block.append(
216
+ ResnetBlock(
217
+ in_channels=block_in,
218
+ out_channels=block_out,
219
+ temb_channels=self.temb_ch,
220
+ dropout=dropout,
221
+ )
222
+ )
223
+ block_in = block_out
224
+ if curr_res in attn_resolutions:
225
+ attn.append(AttnBlock(block_in))
226
+
227
+ down = nn.Module()
228
+ down.block = block
229
+ down.attn = attn
230
+
231
+ if i_level != self.num_resolutions - 1:
232
+ down.downsample = Downsample(block_in, with_conv=True)
233
+ curr_res = curr_res // 2
234
+
235
+ self.down.append(down)
236
+
237
+ # Middle
238
+ self.mid = nn.Module()
239
+ self.mid.block_1 = ResnetBlock(
240
+ in_channels=block_in,
241
+ out_channels=block_in,
242
+ temb_channels=self.temb_ch,
243
+ dropout=dropout,
244
+ )
245
+ self.mid.attn_1 = AttnBlock(block_in)
246
+ self.mid.block_2 = ResnetBlock(
247
+ in_channels=block_in,
248
+ out_channels=block_in,
249
+ temb_channels=self.temb_ch,
250
+ dropout=dropout,
251
+ )
252
+
253
+ # End
254
+ self.norm_out = Normalize(block_in)
255
+ out_channels = 2 * z_channels if double_z else z_channels
256
+ self.conv_out = nn.Conv2d(
257
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
258
+ )
259
+
260
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
261
+ # Downsampling
262
+ h = self.conv_in(x)
263
+
264
+ for i_level in range(self.num_resolutions):
265
+ for i_block in range(self.num_res_blocks):
266
+ h = self.down[i_level].block[i_block](h, None)
267
+ if len(self.down[i_level].attn) > 0:
268
+ h = self.down[i_level].attn[i_block](h)
269
+ if hasattr(self.down[i_level], "downsample"):
270
+ h = self.down[i_level].downsample(h)
271
+
272
+ # Middle
273
+ h = self.mid.block_1(h, None)
274
+ h = self.mid.attn_1(h)
275
+ h = self.mid.block_2(h, None)
276
+
277
+ # End
278
+ h = self.norm_out(h)
279
+ h = nonlinearity(h)
280
+ h = self.conv_out(h)
281
+
282
+ return h
283
+
284
+
285
+ class Decoder(nn.Module):
286
+ """Decoder module for VQ-VAE."""
287
+
288
+ def __init__(
289
+ self,
290
+ ch: int = 128,
291
+ out_ch: int = 12,
292
+ ch_mult: Tuple[int, ...] = (1, 1, 2, 4),
293
+ num_res_blocks: int = 2,
294
+ attn_resolutions: Tuple[int, ...] = (),
295
+ dropout: float = 0.0,
296
+ in_channels: int = 3,
297
+ resolution: int = 256,
298
+ z_channels: int = 256,
299
+ give_pre_end: bool = False,
300
+ **ignore_kwargs,
301
+ ):
302
+ super().__init__()
303
+ self.ch = ch
304
+ self.temb_ch = 0
305
+ self.num_resolutions = len(ch_mult)
306
+ self.num_res_blocks = num_res_blocks
307
+ self.resolution = resolution
308
+ self.in_channels = in_channels
309
+ self.give_pre_end = give_pre_end
310
+
311
+ # Compute in_ch_mult and block_in
312
+ block_in = ch * ch_mult[self.num_resolutions - 1]
313
+ curr_res = resolution // (2 ** (self.num_resolutions - 1))
314
+
315
+ # z to block_in
316
+ self.conv_in = nn.Conv2d(
317
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
318
+ )
319
+
320
+ # Middle
321
+ self.mid = nn.Module()
322
+ self.mid.block_1 = ResnetBlock(
323
+ in_channels=block_in,
324
+ out_channels=block_in,
325
+ temb_channels=self.temb_ch,
326
+ dropout=dropout,
327
+ )
328
+ self.mid.attn_1 = AttnBlock(block_in)
329
+ self.mid.block_2 = ResnetBlock(
330
+ in_channels=block_in,
331
+ out_channels=block_in,
332
+ temb_channels=self.temb_ch,
333
+ dropout=dropout,
334
+ )
335
+
336
+ # Upsampling
337
+ self.up = nn.ModuleList()
338
+ for i_level in reversed(range(self.num_resolutions)):
339
+ block = nn.ModuleList()
340
+ attn = nn.ModuleList()
341
+ block_out = ch * ch_mult[i_level]
342
+
343
+ for i_block in range(self.num_res_blocks + 1):
344
+ block.append(
345
+ ResnetBlock(
346
+ in_channels=block_in,
347
+ out_channels=block_out,
348
+ temb_channels=self.temb_ch,
349
+ dropout=dropout,
350
+ )
351
+ )
352
+ block_in = block_out
353
+ if curr_res in attn_resolutions:
354
+ attn.append(AttnBlock(block_in))
355
+
356
+ up = nn.Module()
357
+ up.block = block
358
+ up.attn = attn
359
+
360
+ if i_level != 0:
361
+ up.upsample = Upsample(block_in, with_conv=True)
362
+ curr_res = curr_res * 2
363
+
364
+ self.up.insert(0, up)
365
+
366
+ # End
367
+ self.norm_out = Normalize(block_in)
368
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
369
+
370
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
371
+ # z to block_in
372
+ h = self.conv_in(z)
373
+
374
+ # Middle
375
+ h = self.mid.block_1(h, None)
376
+ h = self.mid.attn_1(h)
377
+ h = self.mid.block_2(h, None)
378
+
379
+ # Upsampling
380
+ for i_level in reversed(range(self.num_resolutions)):
381
+ for i_block in range(self.num_res_blocks + 1):
382
+ h = self.up[i_level].block[i_block](h, None)
383
+ if len(self.up[i_level].attn) > 0:
384
+ h = self.up[i_level].attn[i_block](h)
385
+ if hasattr(self.up[i_level], "upsample"):
386
+ h = self.up[i_level].upsample(h)
387
+
388
+ # End
389
+ if self.give_pre_end:
390
+ return h
391
+
392
+ h = self.norm_out(h)
393
+ h = nonlinearity(h)
394
+ h = self.conv_out(h)
395
+
396
+ return h
397
+
398
+
399
+ class VectorQuantizer(nn.Module):
400
+ """
401
+ Vector Quantizer module.
402
+
403
+ Discretizes the input vectors using a learned codebook.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ n_embed: int,
409
+ embed_dim: int,
410
+ beta: float = 0.25,
411
+ ):
412
+ super().__init__()
413
+ self.n_embed = n_embed
414
+ self.embed_dim = embed_dim
415
+ self.beta = beta
416
+
417
+ self.embedding = nn.Embedding(self.n_embed, self.embed_dim)
418
+ self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)
419
+
420
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
421
+ # Reshape z -> (batch, height, width, channel) and flatten
422
+ z = z.permute(0, 2, 3, 1).contiguous()
423
+ z_flattened = z.view(-1, self.embed_dim)
424
+
425
+ # Distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
426
+ d = (
427
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
428
+ + torch.sum(self.embedding.weight**2, dim=1)
429
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
430
+ )
431
+
432
+ min_encoding_indices = torch.argmin(d, dim=1)
433
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
434
+
435
+ # Compute loss for embedding
436
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
437
+ (z_q - z.detach()) ** 2
438
+ )
439
+
440
+ # Preserve gradients
441
+ z_q = z + (z_q - z).detach()
442
+
443
+ # Reshape back to match original input shape
444
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
445
+
446
+ return z_q, loss, (None, None, min_encoding_indices)
447
+
448
+ def get_codebook_entry(
449
+ self, indices: torch.Tensor, shape: Optional[Tuple] = None
450
+ ) -> torch.Tensor:
451
+ # Get quantized latent vectors
452
+ z_q = self.embedding(indices)
453
+
454
+ if shape is not None:
455
+ z_q = z_q.view(shape)
456
+ # Reshape back to match original input shape
457
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
458
+
459
+ return z_q
460
+
461
+
462
+ class MatFuseVQModel(ModelMixin, ConfigMixin):
463
+ """
464
+ MatFuse VQ-VAE Model.
465
+
466
+ This model has 4 separate encoders for each material map (diffuse, normal, roughness, specular)
467
+ and 4 separate VQ quantizers, with a single shared decoder that outputs 12 channels.
468
+ """
469
+
470
+ @register_to_config
471
+ def __init__(
472
+ self,
473
+ ch: int = 128,
474
+ ch_mult: Tuple[int, ...] = (1, 1, 2, 4),
475
+ num_res_blocks: int = 2,
476
+ attn_resolutions: Tuple[int, ...] = (),
477
+ dropout: float = 0.0,
478
+ in_channels: int = 3,
479
+ out_channels: int = 12,
480
+ resolution: int = 256,
481
+ z_channels: int = 256,
482
+ n_embed: int = 4096,
483
+ embed_dim: int = 3,
484
+ scaling_factor: float = 1.0,
485
+ ):
486
+ super().__init__()
487
+
488
+ self.scaling_factor = scaling_factor
489
+ self.embed_dim = embed_dim
490
+
491
+ ddconfig = dict(
492
+ ch=ch,
493
+ ch_mult=ch_mult,
494
+ num_res_blocks=num_res_blocks,
495
+ attn_resolutions=attn_resolutions,
496
+ dropout=dropout,
497
+ in_channels=in_channels,
498
+ resolution=resolution,
499
+ z_channels=z_channels,
500
+ double_z=False,
501
+ )
502
+
503
+ # 4 separate encoders for each material map
504
+ self.encoder_0 = Encoder(**ddconfig)
505
+ self.encoder_1 = Encoder(**ddconfig)
506
+ self.encoder_2 = Encoder(**ddconfig)
507
+ self.encoder_3 = Encoder(**ddconfig)
508
+
509
+ # Single decoder
510
+ decoder_config = dict(
511
+ ch=ch,
512
+ out_ch=out_channels,
513
+ ch_mult=ch_mult,
514
+ num_res_blocks=num_res_blocks,
515
+ attn_resolutions=attn_resolutions,
516
+ dropout=dropout,
517
+ in_channels=in_channels,
518
+ resolution=resolution,
519
+ z_channels=z_channels,
520
+ )
521
+ self.decoder = Decoder(**decoder_config)
522
+
523
+ # 4 separate quantizers
524
+ self.quantize_0 = VectorQuantizer(n_embed, embed_dim)
525
+ self.quantize_1 = VectorQuantizer(n_embed, embed_dim)
526
+ self.quantize_2 = VectorQuantizer(n_embed, embed_dim)
527
+ self.quantize_3 = VectorQuantizer(n_embed, embed_dim)
528
+
529
+ # Quant convolutions
530
+ self.quant_conv_0 = nn.Conv2d(z_channels, embed_dim, 1)
531
+ self.quant_conv_1 = nn.Conv2d(z_channels, embed_dim, 1)
532
+ self.quant_conv_2 = nn.Conv2d(z_channels, embed_dim, 1)
533
+ self.quant_conv_3 = nn.Conv2d(z_channels, embed_dim, 1)
534
+
535
+ # Post quant convolution (takes 4 * embed_dim channels)
536
+ self.post_quant_conv = nn.Conv2d(embed_dim * 4, z_channels, 1)
537
+
538
+ def encode_to_prequant(self, x: torch.Tensor) -> torch.Tensor:
539
+ """Encode input to pre-quantized latent space."""
540
+ h_0 = self.encoder_0(x[:, :3])
541
+ h_1 = self.encoder_1(x[:, 3:6])
542
+ h_2 = self.encoder_2(x[:, 6:9])
543
+ h_3 = self.encoder_3(x[:, 9:12])
544
+
545
+ h_0 = self.quant_conv_0(h_0)
546
+ h_1 = self.quant_conv_1(h_1)
547
+ h_2 = self.quant_conv_2(h_2)
548
+ h_3 = self.quant_conv_3(h_3)
549
+
550
+ h = torch.cat((h_0, h_1, h_2, h_3), dim=1)
551
+ return h
552
+
553
+ def quantize_latent(
554
+ self, h: torch.Tensor
555
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
556
+ """Quantize the latent space."""
557
+ quant_0, emb_loss_0, info_0 = self.quantize_0(h[:, : self.embed_dim])
558
+ quant_1, emb_loss_1, info_1 = self.quantize_1(
559
+ h[:, self.embed_dim : 2 * self.embed_dim]
560
+ )
561
+ quant_2, emb_loss_2, info_2 = self.quantize_2(
562
+ h[:, 2 * self.embed_dim : 3 * self.embed_dim]
563
+ )
564
+ quant_3, emb_loss_3, info_3 = self.quantize_3(h[:, 3 * self.embed_dim :])
565
+
566
+ quant = torch.cat((quant_0, quant_1, quant_2, quant_3), dim=1)
567
+ emb_loss = emb_loss_0 + emb_loss_1 + emb_loss_2 + emb_loss_3
568
+ info = torch.stack([info_0[-1], info_1[-1], info_2[-1], info_3[-1]], dim=0)
569
+
570
+ return quant, emb_loss, info
571
+
572
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
573
+ """Encode input to quantized latent space."""
574
+ h = self.encode_to_prequant(x)
575
+ quant, _, _ = self.quantize_latent(h)
576
+ return quant * self.scaling_factor
577
+
578
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
579
+ """Decode from latent space to image."""
580
+ z = z / self.scaling_factor
581
+ z = self.post_quant_conv(z)
582
+ dec = self.decoder(z)
583
+ return dec
584
+
585
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
586
+ """Forward pass through the VQ-VAE."""
587
+ h = self.encode_to_prequant(x)
588
+ quant, diff, _ = self.quantize_latent(h)
589
+ dec = self.decode(quant * self.scaling_factor)
590
+ return dec, diff