BiliSakura committed on
Commit
f0eba3b
·
verified ·
1 Parent(s): 94572ed

Add files using upload-large-folder tool

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo_images/demo_sde250_class207_seed42.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,146 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ library_name: diffusers
6
+ tags:
7
+ - diffusers
8
+ - image-generation
9
+ - class-conditional
10
+ - nit
11
+ pipeline_tag: unconditional-image-generation
12
+ widget:
13
+ - src: demo_images/demo_sde250_class207_seed42.png
14
+ example_title: NiT-XL Class 207
15
+ ---
16
+
17
+ # NiT-XL Diffusers (Class-Conditional)
18
+
19
+ Native-resolution Image Transformer (NiT-XL) checkpoint packaged as a Diffusers-style repository with vendored custom code.
20
+
21
+ ## What is included
22
+
23
+ - `transformer/`: `NiTTransformer2DModel` weights + config
24
+ - `scheduler/`: `NiTFlowMatchScheduler` config
25
+ - `vae/`: `AutoencoderDC` weights + config
26
+ - `custom_pipeline/`: local, self-contained implementation for:
27
+ - `NiTPipeline`
28
+ - `NiTTransformer2DModel`
29
+ - `NiTFlowMatchScheduler`
30
+ - `test_inference.py`: standalone sampling script
31
+
32
+ This repository does **not** depend on an external `NiT-diffusers` checkout during inference.
33
+ It includes a root `pipeline.py` custom entrypoint for Diffusers dynamic loading.
34
+
35
+ ## Quickstart
36
+
37
+ ### 1) Environment
38
+
39
+ Install dependencies (example):
40
+
41
+ ```bash
42
+ pip install torch diffusers safetensors
43
+ ```
44
+
45
+ If you are using this project's conda environment:
46
+
47
+ ```bash
48
+ conda activate rsgen
49
+ ```
50
+
51
+ ### 2) Generate a demo image
52
+
53
+ Run from this repository root:
54
+
55
+ ```bash
56
+ python test_inference.py \
57
+ --class-label 207 \
58
+ --height 512 \
59
+ --width 512 \
60
+ --steps 250 \
61
+ --mode sde \
62
+ --guidance-scale 2.05 \
63
+ --guidance-low 0.0 \
64
+ --guidance-high 0.7 \
65
+ --output demo_images/demo_sde250_class207_seed42.png
66
+ ```
67
+
68
+ ## Python usage
69
+
70
+ ```python
71
+ from pathlib import Path
72
+ import torch
73
+ from diffusers import DiffusionPipeline
74
+
75
+ model_dir = Path(".").resolve()
76
+ device = "cuda" if torch.cuda.is_available() else "cpu"
77
+ dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
78
+
79
+ pipe = DiffusionPipeline.from_pretrained(
80
+ model_dir,
81
+ custom_pipeline=str(model_dir / "pipeline.py"),
82
+ local_files_only=True,
83
+ ).to(device)
84
+ if device == "cuda":
85
+ pipe.transformer.to(dtype=dtype)
86
+ pipe.vae.to(dtype=dtype)
87
+
88
+ gen = torch.Generator(device=device).manual_seed(42)
89
+ result = pipe(
90
+ class_labels=[207],
91
+ height=512,
92
+ width=512,
93
+ num_inference_steps=250,
94
+ mode="sde",
95
+ guidance_scale=2.05,
96
+ guidance_interval=(0.0, 0.7),
97
+ generator=gen,
98
+ )
99
+ result.images[0].save("demo_images/sample.png")
100
+ ```
101
+
102
+ For remote Hub loading:
103
+
104
+ ```python
105
+ from diffusers import DiffusionPipeline
106
+
107
+ pipe = DiffusionPipeline.from_pretrained(
108
+ "BiliSakura/NiT-XL-diffusers",
109
+ custom_pipeline="pipeline",
110
+ )
111
+ ```
112
+
113
+ ## Recommended inference settings
114
+
115
+ - Resolution: `512x512`
116
+ - Mode: `sde`
117
+ - Steps: `250`
118
+ - Guidance scale: `2.05`
119
+ - Guidance interval: `(0.0, 0.7)`
120
+
121
+ Running with very few steps (for example `2`) is useful only as a smoke test and will produce low-quality images.
122
+
123
+ ## Demo
124
+
125
+ ![NiT-XL demo image](demo_images/demo_sde250_class207_seed42.png)
126
+
127
+ ## Citation
128
+
129
+ If you use this model or the NiT method in your work, please cite:
130
+
131
+ ```bibtex
132
+ @article{wang2025native,
133
+ title={Native-Resolution Image Synthesis},
134
+ author={Wang, Zidong and Bai, Lei and Yue, Xiangyu and Ouyang, Wanli and Zhang, Yiyuan},
135
+ year={2025},
136
+ eprint={2506.03131},
137
+ archivePrefix={arXiv},
138
+ primaryClass={cs.CV}
139
+ }
140
+ ```
141
+
142
+ ## Notes
143
+
144
+ - This is a class-conditional generator (ImageNet label ids), not a text-to-image model.
145
+ - For reproducibility, set `--seed` on the CLI or pass a seeded `torch.Generator` in Python; see the sketch below.
146
+ - The vendored custom pipeline keeps inference behavior consistent without external code dependencies.
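
As referenced in the notes above, here is a minimal batched sketch. It assumes you run it from this repository root, as in the Python usage section; the class ids `207`, `250`, and `285` are arbitrary ImageNet labels chosen only for illustration, and dtype casting is omitted for brevity.

```python
import torch
from diffusers import DiffusionPipeline

# Load the vendored pipeline exactly as in the Python usage section above.
pipe = DiffusionPipeline.from_pretrained(
    ".",
    custom_pipeline="./pipeline.py",
    local_files_only=True,
).to("cuda" if torch.cuda.is_available() else "cpu")

# One seeded generator makes the whole batch reproducible.
generator = torch.Generator(device=pipe.device.type).manual_seed(42)

# `class_labels` accepts a list of ImageNet label ids; one image is produced per id.
result = pipe(
    class_labels=[207, 250, 285],
    height=512,
    width=512,
    num_inference_steps=250,
    mode="sde",
    guidance_scale=2.05,
    guidance_interval=(0.0, 0.7),
    generator=generator,
)
for label, image in zip([207, 250, 285], result.images):
    image.save(f"demo_images/class_{label}.png")
```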
__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (728 Bytes).
 
custom_pipeline/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ from .pipeline_nit import NiTPipeline, NiTPipelineOutput
2
+ from .transformer_nit import NiTTransformer2DModel, NiTTransformer2DModelOutput
3
+ from .scheduling_flow_match_nit import NiTFlowMatchScheduler, NiTFlowMatchSchedulerOutput
4
+
5
+
6
+ def _register_with_diffusers():
7
+ """
8
+ Expose NiT classes on the `diffusers` namespace so pipeline/component loading
9
+ via `from_pretrained()` can resolve entries declared in model_index.json.
10
+ """
11
+ try:
12
+ import diffusers
13
+ except Exception:
14
+ return
15
+
16
+ setattr(diffusers, "NiTPipeline", NiTPipeline)
17
+ setattr(diffusers, "NiTTransformer2DModel", NiTTransformer2DModel)
18
+ setattr(diffusers, "NiTFlowMatchScheduler", NiTFlowMatchScheduler)
19
+
20
+
21
+ _register_with_diffusers()
22
+
23
+ __all__ = [
24
+ "NiTPipeline",
25
+ "NiTPipelineOutput",
26
+ "NiTTransformer2DModel",
27
+ "NiTTransformer2DModelOutput",
28
+ "NiTFlowMatchScheduler",
29
+ "NiTFlowMatchSchedulerOutput",
30
+ ]
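
The docstring above explains why the NiT classes are attached to the `diffusers` namespace: `model_index.json` lists each component as `["diffusers", "<ClassName>"]`, and Diffusers essentially resolves such entries by looking the class name up on the `diffusers` module. A minimal sketch of that resolution path (illustrative only; it assumes this repository root is on `sys.path` so `custom_pipeline` is importable):

```python
import sys

sys.path.insert(0, ".")  # hypothetical: make this repository root importable

import diffusers
import custom_pipeline  # importing the package runs _register_with_diffusers()

# Entries declared in model_index.json, e.g. ["diffusers", "NiTTransformer2DModel"],
# can now be found as attributes of the diffusers module:
transformer_cls = getattr(diffusers, "NiTTransformer2DModel")
scheduler_cls = getattr(diffusers, "NiTFlowMatchScheduler")
print(transformer_cls.__module__, scheduler_cls.__module__)
```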
custom_pipeline/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.11 kB).
 
custom_pipeline/__pycache__/pipeline_nit.cpython-312.pyc ADDED
Binary file (12.3 kB).
 
custom_pipeline/__pycache__/scheduling_flow_match_nit.cpython-312.pyc ADDED
Binary file (11.3 kB).
 
custom_pipeline/__pycache__/transformer_nit.cpython-312.pyc ADDED
Binary file (31.4 kB).
 
custom_pipeline/pipeline_nit.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+
11
+ try:
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
14
+ from diffusers.utils import BaseOutput
15
+ except Exception: # pragma: no cover - importable without a full diffusers install.
16
+ class BaseOutput(dict):
17
+ def __post_init__(self):
18
+ self.update(self.__dict__)
19
+
20
+ class DiffusionPipeline:
21
+ def register_modules(self, **kwargs):
22
+ for name, module in kwargs.items():
23
+ setattr(self, name, module)
24
+
25
+ @property
26
+ def _execution_device(self):
27
+ return torch.device("cpu")
28
+
29
+ def maybe_free_model_hooks(self):
30
+ pass
31
+
32
+ class VaeImageProcessor:
33
+ def postprocess(self, image, output_type="pil"):
34
+ return image
35
+
36
+
37
+ @dataclass
38
+ class NiTPipelineOutput(BaseOutput):
39
+ images: Union[torch.FloatTensor, List]
40
+
41
+
42
+ class NiTPipeline(DiffusionPipeline):
43
+ r"""
44
+ Native-resolution Image Synthesis pipeline using a class-conditional NiT transformer.
45
+
46
+ This pipeline follows Diffusers conventions: transformer, scheduler, and VAE are
47
+ saved as separate subfolders and restored with `DiffusionPipeline.from_pretrained`.
48
+ The transformer predicts flow-matching velocity in latent space.
49
+ """
50
+
51
+ model_cpu_offload_seq = "transformer->vae"
52
+ _optional_components = ["vae"]
53
+
54
+ def __init__(self, transformer, scheduler, vae=None):
55
+ super().__init__()
56
+ self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
57
+ self.image_processor = VaeImageProcessor()
58
+
59
+ def _prepare_latents(
60
+ self,
61
+ batch_size: int,
62
+ height: int,
63
+ width: int,
64
+ dtype: torch.dtype,
65
+ device: torch.device,
66
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
67
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
68
+ if self.vae is None:
69
+ spatial_downsample = 1
70
+ elif self.vae.__class__.__name__ == "AutoencoderDC" or "dc-ae" in getattr(self.vae.config, "_name_or_path", ""):
71
+ spatial_downsample = 32
72
+ else:
73
+ spatial_downsample = getattr(self.vae.config, "block_out_channels", [0, 0, 0, 0])
74
+ spatial_downsample = 2 ** (len(spatial_downsample) - 1)
75
+
76
+ if height % spatial_downsample != 0 or width % spatial_downsample != 0:
77
+ raise ValueError(f"height and width must be divisible by the VAE downsample factor {spatial_downsample}.")
78
+
79
+ latent_height = height // spatial_downsample
80
+ latent_width = width // spatial_downsample
81
+ patch_size = int(self.transformer.config.patch_size)
82
+ if latent_height % patch_size != 0 or latent_width % patch_size != 0:
83
+ raise ValueError("Latent height and width must be divisible by transformer's patch_size.")
84
+
85
+ token_height = latent_height // patch_size
86
+ token_width = latent_width // patch_size
87
+ image_sizes = torch.tensor([[token_height, token_width]] * batch_size, device=device, dtype=torch.long)
88
+
89
+ # Match native NiT sampler initialization exactly: sample directly in packed-token space.
90
+ packed_shape = (
91
+ batch_size * token_height * token_width,
92
+ self.transformer.config.in_channels,
93
+ patch_size,
94
+ patch_size,
95
+ )
96
+ packed_latents = torch.randn(packed_shape, generator=generator, device=device, dtype=dtype)
97
+ return packed_latents, image_sizes
98
+
99
+ def _apply_classifier_free_guidance(
100
+ self,
101
+ model_output: torch.Tensor,
102
+ guidance_scale: float,
103
+ guidance_active: bool,
104
+ ) -> torch.Tensor:
105
+ if guidance_scale <= 1.0 or not guidance_active:
106
+ return model_output
107
+ model_output_cond, model_output_uncond = model_output.chunk(2)
108
+ return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
109
+
110
+ def _get_vae_dtype(self, latents: torch.Tensor) -> torch.dtype:
111
+ vae_dtype = getattr(self.vae, "dtype", None)
112
+ if vae_dtype is not None:
113
+ return vae_dtype
114
+ vae_params = next(self.vae.parameters(), None)
115
+ return vae_params.dtype if vae_params is not None else latents.dtype
116
+
117
+ def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
118
+ if self.vae is None:
119
+ return latents
120
+ vae_dtype = self._get_vae_dtype(latents)
121
+ latents = latents.to(dtype=vae_dtype)
122
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
123
+ latents = latents / scaling_factor
124
+ if self.vae.__class__.__name__ == "AutoencoderDC":
125
+ image = self.vae._decode(latents)
126
+ else:
127
+ image = self.vae.decode(latents)
128
+ image = image.sample if hasattr(image, "sample") else image
129
+ return image
130
+
131
+ @torch.no_grad()
132
+ def __call__(
133
+ self,
134
+ class_labels: Union[int, List[int], torch.LongTensor],
135
+ height: int = 256,
136
+ width: int = 256,
137
+ num_inference_steps: int = 50,
138
+ guidance_scale: float = 1.0,
139
+ guidance_interval: Tuple[float, float] = (0.0, 1.0),
140
+ mode: str = "ode",
141
+ heun: bool = False,
142
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
143
+ output_type: str = "pil",
144
+ return_dict: bool = True,
145
+ ) -> Union[NiTPipelineOutput, Tuple]:
146
+ device = self._execution_device
147
+ model_dtype = next(self.transformer.parameters()).dtype
148
+
149
+ if isinstance(class_labels, int):
150
+ class_labels = [class_labels]
151
+ if not torch.is_tensor(class_labels):
152
+ class_labels = torch.tensor(class_labels, device=device, dtype=torch.long)
153
+ else:
154
+ class_labels = class_labels.to(device=device, dtype=torch.long)
155
+ batch_size = class_labels.numel()
156
+
157
+ packed_latents, image_sizes = self._prepare_latents(batch_size, height, width, model_dtype, device, generator)
158
+ packed_latents = packed_latents.to(dtype=torch.float64)
159
+ timesteps = self.scheduler.set_timesteps(num_inference_steps, device=device, mode=mode)
160
+
161
+ null_labels = torch.full_like(class_labels, self.transformer.config.num_classes)
162
+ for index, timestep in enumerate(timesteps[:-1]):
163
+ next_timestep = timesteps[index + 1]
164
+ guidance_active = guidance_interval[0] <= float(timestep) <= guidance_interval[1]
165
+ if guidance_scale > 1.0 and guidance_active:
166
+ model_input = torch.cat([packed_latents, packed_latents], dim=0)
167
+ labels = torch.cat([class_labels, null_labels], dim=0)
168
+ model_image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
169
+ else:
170
+ model_input = packed_latents
171
+ labels = class_labels
172
+ model_image_sizes = image_sizes
173
+
174
+ timestep_batch = torch.full((labels.numel(),), float(timestep), device=device, dtype=model_dtype)
175
+ model_output = self.transformer(
176
+ model_input.to(dtype=model_dtype),
177
+ timestep_batch,
178
+ labels,
179
+ image_sizes=model_image_sizes,
180
+ return_dict=True,
181
+ ).sample
182
+ model_output = self._apply_classifier_free_guidance(model_output, guidance_scale, guidance_active)
183
+
184
+ if heun and mode == "ode" and index < len(timesteps) - 2:
185
+ provisional = self.scheduler.step(
186
+ model_output,
187
+ timestep[None],
188
+ packed_latents,
189
+ next_timestep[None],
190
+ image_sizes=image_sizes,
191
+ ).prev_sample
192
+ if guidance_scale > 1.0 and guidance_active:
193
+ prime_input = torch.cat([provisional, provisional], dim=0)
194
+ labels = torch.cat([class_labels, null_labels], dim=0)
195
+ model_image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
196
+ else:
197
+ prime_input = provisional
198
+ labels = class_labels
199
+ model_image_sizes = image_sizes
200
+ next_timestep_batch = torch.full((labels.numel(),), float(next_timestep), device=device, dtype=model_dtype)
201
+ next_model_output = self.transformer(
202
+ prime_input.to(dtype=model_dtype),
203
+ next_timestep_batch,
204
+ labels,
205
+ image_sizes=model_image_sizes,
206
+ return_dict=True,
207
+ ).sample
208
+ next_model_output = self._apply_classifier_free_guidance(
209
+ next_model_output, guidance_scale, guidance_active
210
+ )
211
+ packed_latents = self.scheduler.step_heun(
212
+ model_output, next_model_output, timestep[None], packed_latents, next_timestep[None]
213
+ ).prev_sample
214
+ else:
215
+ packed_latents = self.scheduler.step(
216
+ model_output,
217
+ timestep[None],
218
+ packed_latents,
219
+ next_timestep[None],
220
+ image_sizes=image_sizes,
221
+ generator=generator,
222
+ ).prev_sample
223
+
224
+ latents = self.transformer._unpack_latents(packed_latents, image_sizes)
225
+ image = self._decode_latents(latents)
226
+ if self.vae is not None:
227
+ image = (image / 2 + 0.5).clamp(0, 1)
228
+ image = self.image_processor.postprocess(
229
+ image,
230
+ output_type=output_type,
231
+ do_denormalize=[False] * image.shape[0],
232
+ )
233
+
234
+ self.maybe_free_model_hooks()
235
+ if not return_dict:
236
+ return (image,)
237
+ return NiTPipelineOutput(images=image)
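
As a summary of `_apply_classifier_free_guidance` and the guidance-interval check in `__call__` above: with guidance scale $s$ and interval $[t_{\text{low}}, t_{\text{high}}]$, the conditional and unconditional branches are batched together and combined only while the current timestep lies inside the interval (and $s > 1$); otherwise the conditional prediction is used unchanged.

```latex
v_{\text{guided}}(x_t, t, y) =
\begin{cases}
v_\theta(x_t, t, \varnothing) + s\,\bigl(v_\theta(x_t, t, y) - v_\theta(x_t, t, \varnothing)\bigr),
  & s > 1 \ \text{and}\ t_{\text{low}} \le t \le t_{\text{high}}, \\[4pt]
v_\theta(x_t, t, y), & \text{otherwise,}
\end{cases}
```

where $\varnothing$ denotes the null class (label id `num_classes`, as built from `null_labels` above).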
custom_pipeline/scheduling_flow_match_nit.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ try:
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
15
+ from diffusers.utils import BaseOutput
16
+ except Exception: # pragma: no cover - importable without an installed diffusers checkout.
17
+ class BaseOutput(dict):
18
+ def __post_init__(self):
19
+ self.update(self.__dict__)
20
+
21
+ class ConfigMixin:
22
+ config_name = "scheduler_config.json"
23
+
24
+ class SchedulerMixin:
25
+ pass
26
+
27
+ def register_to_config(init):
28
+ return init
29
+
30
+
31
+ @dataclass
32
+ class NiTFlowMatchSchedulerOutput(BaseOutput):
33
+ prev_sample: torch.FloatTensor
34
+
35
+
36
+ class NiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
37
+ """
38
+ Flow-matching ODE/SDE scheduler used by Native-resolution Image Synthesis (NiT).
39
+
40
+ The model predicts velocity with a linear path by default. Timesteps run from 1 to 0,
41
+ matching the original sampler while exposing the standard Diffusers `set_timesteps`
42
+ and `step` API.
43
+ """
44
+
45
+ config_name = "scheduler_config.json"
46
+ order = 1
47
+
48
+ @register_to_config
49
+ def __init__(
50
+ self,
51
+ mode: str = "ode",
52
+ path_type: str = "linear",
53
+ num_train_timesteps: int = 1000,
54
+ ):
55
+ if mode not in {"ode", "sde"}:
56
+ raise ValueError("mode must be either 'ode' or 'sde'.")
57
+ if path_type not in {"linear", "cosine"}:
58
+ raise ValueError("path_type must be either 'linear' or 'cosine'.")
59
+ self.mode = mode
60
+ self.path_type = path_type
61
+ self.num_train_timesteps = num_train_timesteps
62
+ # Native NiT integrates in float64 for better numerical stability.
63
+ self.timesteps = torch.from_numpy(np.linspace(1.0, 0.0, num_train_timesteps + 1)).to(dtype=torch.float64)
64
+
65
+ def set_timesteps(
66
+ self,
67
+ num_inference_steps: int,
68
+ device: Optional[torch.device] = None,
69
+ mode: Optional[str] = None,
70
+ ):
71
+ mode = mode or self.mode
72
+ dtype = self.timesteps.dtype
73
+ if mode == "sde":
74
+ timesteps = torch.linspace(1.0, 0.04, num_inference_steps, dtype=dtype)
75
+ timesteps = torch.cat([timesteps, torch.zeros(1, dtype=dtype)])
76
+ elif mode == "ode":
77
+ timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, dtype=dtype)
78
+ else:
79
+ raise ValueError("mode must be either 'ode' or 'sde'.")
80
+ self.mode = mode
81
+ self.timesteps = timesteps.to(device=device)
82
+ return self.timesteps
83
+
84
+ @staticmethod
85
+ def _expand_t_like_sample(timestep: torch.Tensor, sample: torch.Tensor, image_sizes: torch.LongTensor):
86
+ dims = [1] * (sample.ndim - 1)
87
+ seqlens = image_sizes[:, 0] * image_sizes[:, 1]
88
+ if timestep.numel() == 1:
89
+ timestep = timestep.repeat(image_sizes.shape[0])
90
+ return torch.cat(
91
+ [timestep[i].reshape(1, *dims).repeat(int(seqlens[i]), *dims) for i in range(image_sizes.shape[0])]
92
+ )
93
+
94
+ def _get_score_from_velocity(
95
+ self,
96
+ model_output: torch.Tensor,
97
+ sample: torch.Tensor,
98
+ timestep: torch.Tensor,
99
+ image_sizes: torch.LongTensor,
100
+ ):
101
+ timestep = self._expand_t_like_sample(timestep, sample, image_sizes)
102
+ if self.path_type == "linear":
103
+ alpha_t, d_alpha_t = 1 - timestep, torch.ones_like(timestep) * -1
104
+ sigma_t, d_sigma_t = timestep, torch.ones_like(timestep)
105
+ elif self.path_type == "cosine":
106
+ alpha_t = torch.cos(timestep * np.pi / 2)
107
+ sigma_t = torch.sin(timestep * np.pi / 2)
108
+ d_alpha_t = -np.pi / 2 * torch.sin(timestep * np.pi / 2)
109
+ d_sigma_t = np.pi / 2 * torch.cos(timestep * np.pi / 2)
110
+ else:
111
+ raise ValueError(f"Unsupported path_type: {self.path_type}")
112
+ reverse_alpha_ratio = alpha_t / d_alpha_t
113
+ variance = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
114
+ return (reverse_alpha_ratio * model_output - sample) / variance
115
+
116
+ @staticmethod
117
+ def _compute_diffusion(timestep: torch.Tensor):
118
+ return 2 * timestep
119
+
120
+ @staticmethod
121
+ def _promote_dtypes(*tensors: torch.Tensor) -> torch.dtype:
122
+ dtype = None
123
+ for tensor in tensors:
124
+ if tensor.is_floating_point() or tensor.is_complex():
125
+ dtype = tensor.dtype if dtype is None else torch.promote_types(dtype, tensor.dtype)
126
+ return dtype if dtype is not None else torch.get_default_dtype()
127
+
128
+ def step(
129
+ self,
130
+ model_output: torch.Tensor,
131
+ timestep: torch.Tensor,
132
+ sample: torch.Tensor,
133
+ next_timestep: torch.Tensor,
134
+ image_sizes: Optional[torch.LongTensor] = None,
135
+ generator: Optional[torch.Generator] = None,
136
+ return_dict: bool = True,
137
+ ) -> NiTFlowMatchSchedulerOutput:
138
+ compute_dtype = torch.float64
139
+ sample = sample.to(dtype=compute_dtype)
140
+ model_output = model_output.to(dtype=compute_dtype)
141
+ timestep = timestep.to(device=sample.device, dtype=compute_dtype).flatten()
142
+ next_timestep = next_timestep.to(device=sample.device, dtype=compute_dtype).flatten()
143
+
144
+ if self.mode == "ode":
145
+ prev_sample = sample + (next_timestep[0] - timestep[0]) * model_output
146
+ else:
147
+ if image_sizes is None:
148
+ raise ValueError("image_sizes are required for SDE sampling.")
149
+ image_sizes = image_sizes.to(device=sample.device, dtype=torch.long)
150
+ diffusion = self._compute_diffusion(timestep[0])
151
+ score = self._get_score_from_velocity(model_output, sample, timestep, image_sizes)
152
+ drift = model_output - 0.5 * diffusion * score
153
+ dt = next_timestep[0] - timestep[0]
154
+ if torch.allclose(next_timestep[0], torch.zeros_like(next_timestep[0])):
155
+ prev_sample = sample + drift * dt
156
+ else:
157
+ if generator is not None:
158
+ noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
159
+ else:
160
+ noise = torch.randn_like(sample)
161
+ prev_sample = sample + drift * dt + torch.sqrt(diffusion) * noise * torch.sqrt(torch.abs(dt))
162
+
163
+ if not return_dict:
164
+ return (prev_sample,)
165
+ return NiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
166
+
167
+ def step_heun(
168
+ self,
169
+ model_output: torch.Tensor,
170
+ next_model_output: torch.Tensor,
171
+ timestep: torch.Tensor,
172
+ sample: torch.Tensor,
173
+ next_timestep: torch.Tensor,
174
+ return_dict: bool = True,
175
+ ) -> NiTFlowMatchSchedulerOutput:
176
+ if self.mode != "ode":
177
+ raise ValueError("Heun correction is only defined for ODE sampling.")
178
+ compute_dtype = torch.float64
179
+ sample = sample.to(dtype=compute_dtype)
180
+ model_output = model_output.to(dtype=compute_dtype)
181
+ next_model_output = next_model_output.to(dtype=compute_dtype)
182
+ timestep = timestep.to(device=sample.device, dtype=compute_dtype).flatten()
183
+ next_timestep = next_timestep.to(device=sample.device, dtype=compute_dtype).flatten()
184
+ prev_sample = sample + (next_timestep[0] - timestep[0]) * (0.5 * model_output + 0.5 * next_model_output)
185
+ if not return_dict:
186
+ return (prev_sample,)
187
+ return NiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
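
For reference, these are the update rules implemented in `step` above, written for a step from $t$ to $t'$ with $\Delta t = t' - t < 0$, velocity prediction $v_\theta$, and the default linear path $\alpha_t = 1 - t$, $\sigma_t = t$:

```latex
\text{ODE (Euler):}\quad x_{t'} = x_t + \Delta t\, v_\theta(x_t, t)

\text{SDE:}\quad
s_\theta(x_t, t) = \frac{\tfrac{\alpha_t}{\dot\alpha_t}\, v_\theta(x_t, t) - x_t}
                        {\sigma_t^2 - \tfrac{\alpha_t}{\dot\alpha_t}\,\dot\sigma_t\,\sigma_t},
\qquad g(t) = 2t,

x_{t'} = x_t + \Bigl(v_\theta(x_t, t) - \tfrac{1}{2}\, g(t)\, s_\theta(x_t, t)\Bigr)\,\Delta t
        + \sqrt{g(t)}\,\sqrt{|\Delta t|}\;\xi, \qquad \xi \sim \mathcal{N}(0, I),
```

with the noise term dropped on the final step to $t' = 0$, and `step_heun` averaging the two velocity evaluations for the ODE case: $x_{t'} = x_t + \tfrac{\Delta t}{2}\bigl(v_\theta(x_t, t) + v_\theta(\tilde{x}_{t'}, t')\bigr)$.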
custom_pipeline/transformer_nit.py ADDED
@@ -0,0 +1,471 @@
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+
6
+ from dataclasses import dataclass
7
+ import math
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers.utils import BaseOutput
18
+ except Exception: # pragma: no cover - lets this subtree be tested outside diffusers.
19
+ class BaseOutput(dict):
20
+ def __post_init__(self):
21
+ self.update(self.__dict__)
22
+
23
+ class _Config(dict):
24
+ def __getattr__(self, key):
25
+ try:
26
+ return self[key]
27
+ except KeyError as error:
28
+ raise AttributeError(key) from error
29
+
30
+ class ConfigMixin:
31
+ config_name = "config.json"
32
+
33
+ class ModelMixin(nn.Module):
34
+ pass
35
+
36
+ def register_to_config(init):
37
+ def wrapper(self, *args, **kwargs):
38
+ import inspect
39
+
40
+ signature = inspect.signature(init)
41
+ bound = signature.bind(self, *args, **kwargs)
42
+ bound.apply_defaults()
43
+ self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"})
44
+ init(self, *args, **kwargs)
45
+
46
+ return wrapper
47
+
48
+
49
+ try:
50
+ from flash_attn import flash_attn_varlen_func
51
+ except Exception: # pragma: no cover - optional acceleration.
52
+ flash_attn_varlen_func = None
53
+
54
+
55
+ @dataclass
56
+ class NiTTransformer2DModelOutput(BaseOutput):
57
+ sample: torch.FloatTensor
58
+ projection_states: Optional[Tuple[torch.FloatTensor, ...]] = None
59
+
60
+
61
+ def _modulate(hidden_states: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
62
+ return hidden_states * (1 + scale) + shift
63
+
64
+
65
+ def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
66
+ hidden_states = hidden_states.reshape(*hidden_states.shape[:-1], -1, 2)
67
+ hidden_states_1, hidden_states_2 = hidden_states.unbind(dim=-1)
68
+ return torch.stack((-hidden_states_2, hidden_states_1), dim=-1).flatten(-2)
69
+
70
+
71
+ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.dtype:
72
+ if tensor is not None and tensor.is_floating_point():
73
+ return tensor.dtype
74
+ return torch.get_default_dtype()
75
+
76
+
77
+ class NiTPatchEmbed(nn.Module):
78
+ def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
79
+ super().__init__()
80
+ self.patch_size = (patch_size, patch_size)
81
+ self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=True)
82
+
83
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
84
+ hidden_states = self.proj(hidden_states)
85
+ return hidden_states.flatten(2).transpose(1, 2)
86
+
87
+
88
+ class NiTTimestepEmbedder(nn.Module):
89
+ def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
90
+ super().__init__()
91
+ self.frequency_embedding_size = frequency_embedding_size
92
+ self.mlp = nn.Sequential(
93
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
94
+ nn.SiLU(),
95
+ nn.Linear(hidden_size, hidden_size, bias=True),
96
+ )
97
+
98
+ @staticmethod
99
+ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000):
100
+ half = embedding_dim // 2
101
+ # Keep sinusoid construction in fp32 to mirror the native NiT implementation.
102
+ exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=timesteps.device) / half
103
+ freqs = torch.exp(exponent)
104
+ args = timesteps.float()[:, None] * freqs[None]
105
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
106
+ if embedding_dim % 2:
107
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
108
+ return embedding
109
+
110
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
111
+ timestep_freq = self.get_timestep_embedding(timesteps, self.frequency_embedding_size).to(timesteps.dtype)
112
+ return self.mlp(timestep_freq)
113
+
114
+
115
+ class NiTLabelEmbedder(nn.Module):
116
+ def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
117
+ super().__init__()
118
+ use_cfg_embedding = dropout_prob > 0
119
+ self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size)
120
+ self.num_classes = num_classes
121
+ self.dropout_prob = dropout_prob
122
+
123
+ def forward(self, class_labels: torch.LongTensor) -> torch.Tensor:
124
+ return self.embedding_table(class_labels)
125
+
126
+
127
+ class NiTRotaryEmbedding(nn.Module):
128
+ def __init__(
129
+ self,
130
+ head_dim: int,
131
+ custom_freqs: str = "normal",
132
+ theta: int = 10000,
133
+ max_cached_len: int = 1024,
134
+ max_pe_len_h: Optional[int] = None,
135
+ max_pe_len_w: Optional[int] = None,
136
+ decouple: bool = False,
137
+ ori_max_pe_len: Optional[int] = None,
138
+ ):
139
+ super().__init__()
140
+ del max_pe_len_h, max_pe_len_w, decouple, ori_max_pe_len
141
+ if custom_freqs not in {"normal", "scale1", "scale2"}:
142
+ raise ValueError(
143
+ "This Diffusers implementation supports the trained RoPE frequencies directly. "
144
+ "Checkpoint conversion preserves weights; extrapolation variants should be handled "
145
+ "by changing the model config before loading."
146
+ )
147
+ dim = head_dim // 2
148
+ if dim % 2 != 0:
149
+ raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
150
+ default_dtype = _get_float_dtype_or_default()
151
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
152
+ self.register_buffer("freqs_h", freqs, persistent=False)
153
+ self.register_buffer("freqs_w", freqs.clone(), persistent=False)
154
+ positions = torch.arange(max_cached_len, dtype=default_dtype)
155
+ freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
156
+ freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
157
+ self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
158
+ self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
159
+
160
+ def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
161
+ grids = []
162
+ for height, width in image_sizes.tolist():
163
+ # Use the same meshgrid ordering as native NiT to preserve RoPE-token alignment.
164
+ grid_h = torch.arange(height, device=device)
165
+ grid_w = torch.arange(width, device=device)
166
+ grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
167
+ grids.append(torch.stack(grid, dim=0).reshape(2, -1))
168
+ grid = torch.cat(grids, dim=1)
169
+ freqs_h = self.freqs_h_cached.to(device)[grid[0]]
170
+ freqs_w = self.freqs_w_cached.to(device)[grid[1]]
171
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1)
172
+ return freqs.cos().unsqueeze(1), freqs.sin().unsqueeze(1)
173
+
174
+
175
+ class NiTAttention(nn.Module):
176
+ def __init__(self, hidden_size: int, num_heads: int, qk_norm: bool = False):
177
+ super().__init__()
178
+ if hidden_size % num_heads != 0:
179
+ raise ValueError("hidden_size must be divisible by num_heads")
180
+ self.num_heads = num_heads
181
+ self.head_dim = hidden_size // num_heads
182
+ self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
183
+ self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
184
+ self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
185
+ self.proj = nn.Linear(hidden_size, hidden_size)
186
+ self.proj_drop = nn.Dropout(0.0)
187
+
188
+ def forward(
189
+ self,
190
+ hidden_states: torch.Tensor,
191
+ cu_seqlens: torch.IntTensor,
192
+ freqs_cos: torch.Tensor,
193
+ freqs_sin: torch.Tensor,
194
+ ) -> torch.Tensor:
195
+ qkv = self.qkv(hidden_states).reshape(hidden_states.shape[0], 3, self.num_heads, self.head_dim)
196
+ query, key, value = qkv.unbind(dim=1)
197
+ original_dtype = qkv.dtype
198
+ query = self.q_norm(query)
199
+ key = self.k_norm(key)
200
+ query = query * freqs_cos + _rotate_half(query) * freqs_sin
201
+ key = key * freqs_cos + _rotate_half(key) * freqs_sin
202
+ query = query.to(dtype=original_dtype)
203
+ key = key.to(dtype=original_dtype)
204
+
205
+ if flash_attn_varlen_func is not None and query.is_cuda:
206
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
207
+ hidden_states = flash_attn_varlen_func(
208
+ query, key, value, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
209
+ ).reshape(hidden_states.shape[0], -1)
210
+ else:
211
+ segments = []
212
+ for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
213
+ q = query[start:end].transpose(0, 1).unsqueeze(0)
214
+ k = key[start:end].transpose(0, 1).unsqueeze(0)
215
+ v = value[start:end].transpose(0, 1).unsqueeze(0)
216
+ segments.append(F.scaled_dot_product_attention(q, k, v).squeeze(0).transpose(0, 1))
217
+ hidden_states = torch.cat(segments, dim=0).reshape(hidden_states.shape[0], -1)
218
+
219
+ hidden_states = self.proj(hidden_states)
220
+ return self.proj_drop(hidden_states)
221
+
222
+
223
+ class NiTMLP(nn.Module):
224
+ def __init__(self, hidden_size: int, mlp_hidden_dim: int):
225
+ super().__init__()
226
+ self.fc1 = nn.Linear(hidden_size, mlp_hidden_dim)
227
+ self.act = nn.GELU(approximate="tanh")
228
+ self.drop1 = nn.Dropout(0.0)
229
+ self.norm = nn.Identity()
230
+ self.fc2 = nn.Linear(mlp_hidden_dim, hidden_size)
231
+ self.drop2 = nn.Dropout(0.0)
232
+
233
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
234
+ hidden_states = self.fc1(hidden_states)
235
+ hidden_states = self.act(hidden_states)
236
+ hidden_states = self.drop1(hidden_states)
237
+ hidden_states = self.norm(hidden_states)
238
+ hidden_states = self.fc2(hidden_states)
239
+ return self.drop2(hidden_states)
240
+
241
+
242
+ class NiTBlock(nn.Module):
243
+ def __init__(
244
+ self,
245
+ hidden_size: int,
246
+ num_heads: int,
247
+ mlp_ratio: float = 4.0,
248
+ qk_norm: bool = False,
249
+ use_adaln_lora: bool = False,
250
+ adaln_lora_dim: int = 512,
251
+ ):
252
+ super().__init__()
253
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
254
+ self.attn = NiTAttention(hidden_size, num_heads=num_heads, qk_norm=qk_norm)
255
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
256
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
257
+ self.mlp = NiTMLP(hidden_size, mlp_hidden_dim)
258
+ if use_adaln_lora:
259
+ self.adaLN_modulation = nn.Sequential(
260
+ nn.SiLU(),
261
+ nn.Linear(hidden_size, adaln_lora_dim, bias=True),
262
+ nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=True),
263
+ )
264
+ else:
265
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
266
+
267
+ def forward(self, hidden_states, conditioning, cu_seqlens, freqs_cos, freqs_sin):
268
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(conditioning).chunk(
269
+ 6, dim=-1
270
+ )
271
+ hidden_states = hidden_states + gate_msa * self.attn(
272
+ _modulate(self.norm1(hidden_states), shift_msa, scale_msa), cu_seqlens, freqs_cos, freqs_sin
273
+ )
274
+ hidden_states = hidden_states + gate_mlp * self.mlp(
275
+ _modulate(self.norm2(hidden_states), shift_mlp, scale_mlp)
276
+ )
277
+ return hidden_states
278
+
279
+
280
+ class NiTFinalLayer(nn.Module):
281
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
282
+ super().__init__()
283
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
284
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
285
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
286
+
287
+ def forward(self, hidden_states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
288
+ shift, scale = self.adaLN_modulation(conditioning).chunk(2, dim=-1)
289
+ hidden_states = _modulate(self.norm_final(hidden_states), shift, scale)
290
+ return self.linear(hidden_states)
291
+
292
+
293
+ def _build_mlp(hidden_size: int, projector_dim: int, z_dim: int) -> nn.Sequential:
294
+ return nn.Sequential(
295
+ nn.Linear(hidden_size, projector_dim),
296
+ nn.SiLU(),
297
+ nn.Linear(projector_dim, projector_dim),
298
+ nn.SiLU(),
299
+ nn.Linear(projector_dim, z_dim),
300
+ )
301
+
302
+
303
+ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
304
+ config_name = "config.json"
305
+
306
+ @register_to_config
307
+ def __init__(
308
+ self,
309
+ input_size: int = 32,
310
+ patch_size: int = 1,
311
+ in_channels: int = 32,
312
+ hidden_size: int = 1152,
313
+ depth: int = 28,
314
+ num_heads: int = 16,
315
+ mlp_ratio: float = 4.0,
316
+ class_dropout_prob: float = 0.1,
317
+ num_classes: int = 1000,
318
+ encoder_depth: int = 8,
319
+ projector_dim: int = 2048,
320
+ z_dim: int = 1280,
321
+ use_checkpoint: bool = False,
322
+ custom_freqs: str = "normal",
323
+ theta: int = 10000,
324
+ max_pe_len_h: Optional[int] = None,
325
+ max_pe_len_w: Optional[int] = None,
326
+ decouple: bool = False,
327
+ ori_max_pe_len: Optional[int] = None,
328
+ qk_norm: bool = True,
329
+ use_adaln_lora: bool = False,
330
+ adaln_lora_dim: int = 512,
331
+ ):
332
+ super().__init__()
333
+ del input_size
334
+ self.in_channels = in_channels
335
+ self.out_channels = in_channels
336
+ self.patch_size = patch_size
337
+ self.num_heads = num_heads
338
+ self.num_classes = num_classes
339
+ self.encoder_depth = encoder_depth
340
+ self.use_checkpoint = use_checkpoint
341
+
342
+ self.x_embedder = NiTPatchEmbed(patch_size, in_channels, hidden_size)
343
+ self.t_embedder = NiTTimestepEmbedder(hidden_size)
344
+ self.y_embedder = NiTLabelEmbedder(num_classes, hidden_size, class_dropout_prob)
345
+ self.rope = NiTRotaryEmbedding(
346
+ hidden_size // num_heads,
347
+ custom_freqs=custom_freqs,
348
+ theta=theta,
349
+ max_pe_len_h=max_pe_len_h,
350
+ max_pe_len_w=max_pe_len_w,
351
+ decouple=decouple,
352
+ ori_max_pe_len=ori_max_pe_len,
353
+ )
354
+ self.projector = _build_mlp(hidden_size, projector_dim, z_dim)
355
+ self.blocks = nn.ModuleList(
356
+ [
357
+ NiTBlock(
358
+ hidden_size,
359
+ num_heads,
360
+ mlp_ratio=mlp_ratio,
361
+ qk_norm=qk_norm,
362
+ use_adaln_lora=use_adaln_lora,
363
+ adaln_lora_dim=adaln_lora_dim,
364
+ )
365
+ for _ in range(depth)
366
+ ]
367
+ )
368
+ self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
369
+
370
+ def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
371
+ batch_size, channels, height, width = hidden_states.shape
372
+ if channels != self.in_channels:
373
+ raise ValueError(f"Expected {self.in_channels} latent channels, got {channels}.")
374
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
375
+ raise ValueError("Latent height and width must be divisible by patch_size.")
376
+ latent_h = height // self.patch_size
377
+ latent_w = width // self.patch_size
378
+ hidden_states = hidden_states.reshape(batch_size, channels, latent_h, self.patch_size, latent_w, self.patch_size)
379
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).reshape(
380
+ batch_size * latent_h * latent_w, channels, self.patch_size, self.patch_size
381
+ )
382
+ image_sizes = torch.tensor([[latent_h, latent_w]] * batch_size, device=hidden_states.device, dtype=torch.long)
383
+ return hidden_states, image_sizes, (height, width)
384
+
385
+ def _unpack_latents(self, hidden_states: torch.Tensor, image_sizes: torch.LongTensor) -> torch.Tensor:
386
+ if image_sizes.shape[0] == 1:
387
+ height, width = image_sizes[0].tolist()
388
+ hidden_states = hidden_states.reshape(height, width, self.out_channels, self.patch_size, self.patch_size)
389
+ return hidden_states.permute(2, 0, 3, 1, 4).reshape(
390
+ 1, self.out_channels, height * self.patch_size, width * self.patch_size
391
+ )
392
+
393
+ samples = []
394
+ cursor = 0
395
+ for height, width in image_sizes.tolist():
396
+ length = height * width
397
+ sample = hidden_states[cursor : cursor + length].reshape(
398
+ height, width, self.out_channels, self.patch_size, self.patch_size
399
+ )
400
+ samples.append(
401
+ sample.permute(2, 0, 3, 1, 4).reshape(
402
+ self.out_channels, height * self.patch_size, width * self.patch_size
403
+ )
404
+ )
405
+ cursor += length
406
+ if len({tuple(sample.shape) for sample in samples}) != 1:
407
+ return hidden_states
408
+ return torch.stack(samples, dim=0)
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.Tensor,
413
+ timestep: Union[torch.Tensor, float],
414
+ class_labels: torch.LongTensor,
415
+ image_sizes: Optional[Union[torch.LongTensor, List[Tuple[int, int]]]] = None,
416
+ return_dict: bool = True,
417
+ output_projection_states: bool = False,
418
+ ) -> Union[NiTTransformer2DModelOutput, Tuple[torch.Tensor, ...]]:
419
+ input_was_image = hidden_states.dim() == 4 and image_sizes is None
420
+ if input_was_image:
421
+ hidden_states, image_sizes, _ = self._pack_latents(hidden_states)
422
+ elif image_sizes is None:
423
+ raise ValueError("image_sizes must be provided when hidden_states are already packed.")
424
+ elif not torch.is_tensor(image_sizes):
425
+ image_sizes = torch.tensor(image_sizes, device=hidden_states.device, dtype=torch.long)
426
+ else:
427
+ image_sizes = image_sizes.to(device=hidden_states.device, dtype=torch.long)
428
+
429
+ if not torch.is_tensor(timestep):
430
+ timestep = torch.tensor([timestep], device=hidden_states.device, dtype=hidden_states.dtype)
431
+ timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype).flatten()
432
+ if timestep.numel() == 1:
433
+ timestep = timestep.repeat(image_sizes.shape[0])
434
+ class_labels = class_labels.to(device=hidden_states.device, dtype=torch.long).flatten()
435
+
436
+ hidden_states = self.x_embedder(hidden_states).squeeze(1)
437
+ freqs_cos, freqs_sin = self.rope(image_sizes, hidden_states.device)
438
+
439
+ seqlens = image_sizes[:, 0] * image_sizes[:, 1]
440
+ cu_seqlens = torch.cat(
441
+ [torch.zeros(1, device=hidden_states.device, dtype=torch.int32), torch.cumsum(seqlens, dim=0).int()]
442
+ )
443
+
444
+ conditioning = self.t_embedder(timestep) + self.y_embedder(class_labels)
445
+ conditioning = torch.cat([conditioning[i].repeat(int(seqlens[i]), 1) for i in range(image_sizes.shape[0])], dim=0)
446
+
447
+ projection_states = []
448
+ for index, block in enumerate(self.blocks):
449
+ if self.use_checkpoint and self.training:
450
+ hidden_states = torch.utils.checkpoint.checkpoint(
451
+ block, hidden_states, conditioning, cu_seqlens, freqs_cos, freqs_sin, use_reentrant=False
452
+ )
453
+ else:
454
+ hidden_states = block(hidden_states, conditioning, cu_seqlens, freqs_cos, freqs_sin)
455
+ if output_projection_states and (index + 1) == self.encoder_depth:
456
+ projection_states.append(self.projector(hidden_states))
457
+
458
+ hidden_states = self.final_layer(hidden_states, conditioning)
459
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], self.out_channels, self.patch_size, self.patch_size)
460
+ if input_was_image:
461
+ hidden_states = self._unpack_latents(hidden_states, image_sizes)
462
+
463
+ if not return_dict:
464
+ output = (hidden_states,)
465
+ if output_projection_states:
466
+ output = output + (tuple(projection_states),)
467
+ return output
468
+ return NiTTransformer2DModelOutput(
469
+ sample=hidden_states,
470
+ projection_states=tuple(projection_states) if output_projection_states else None,
471
+ )
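
The transformer operates on packed patch tokens rather than padded image batches: `_pack_latents` flattens each latent into `h·w` patches of shape `(C, p, p)` and records the per-sample token grid in `image_sizes`, and `_unpack_latents` reverses the layout. A minimal, self-contained sketch of that shape arithmetic (the sizes are hypothetical; `patch_size` is 1 in this checkpoint's config):

```python
import torch

# Hypothetical shapes: 2 latents of 16x16 tokens, 32 channels, patch_size 1.
B, C, H, W, p = 2, 32, 16, 16, 1
latents = torch.randn(B, C, H, W)

h, w = H // p, W // p
# Same reshape/permute as NiTTransformer2DModel._pack_latents:
packed = (
    latents.reshape(B, C, h, p, w, p)
    .permute(0, 2, 4, 1, 3, 5)
    .reshape(B * h * w, C, p, p)
)
image_sizes = torch.tensor([[h, w]] * B, dtype=torch.long)  # token grid per sample

print(packed.shape)          # torch.Size([512, 32, 1, 1])
print(image_sizes.tolist())  # [[16, 16], [16, 16]]
```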
demo_images/demo_sde250_class207_seed42.png ADDED

Git LFS Details

  • SHA256: eb6fd6d24d517744a597a8d5f3277f1b7a4a91834dbefba0607c69d04ceecd3f
  • Pointer size: 131 Bytes
  • Size of remote file: 453 kB
model_index.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "_class_name": "NiTPipeline",
3
+ "_diffusers_version": "0.30.1",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "NiTFlowMatchScheduler"
7
+ ],
8
+ "transformer": [
9
+ "diffusers",
10
+ "NiTTransformer2DModel"
11
+ ],
12
+ "vae": [
13
+ "diffusers",
14
+ "AutoencoderDC"
15
+ ]
16
+ }
pipeline.py ADDED
@@ -0,0 +1,21 @@
1
+ """Custom pipeline entrypoint for Diffusers dynamic loading."""
2
+
3
+ from .custom_pipeline.pipeline_nit import NiTPipeline
4
+ from .custom_pipeline.scheduling_flow_match_nit import NiTFlowMatchScheduler
5
+ from .custom_pipeline.transformer_nit import NiTTransformer2DModel
6
+
7
+ try:
8
+ import diffusers
9
+
10
+ setattr(diffusers, "NiTPipeline", NiTPipeline)
11
+ setattr(diffusers, "NiTTransformer2DModel", NiTTransformer2DModel)
12
+ setattr(diffusers, "NiTFlowMatchScheduler", NiTFlowMatchScheduler)
13
+ except Exception:
14
+ pass
15
+
16
+
17
+ __all__ = [
18
+ "NiTPipeline",
19
+ "NiTTransformer2DModel",
20
+ "NiTFlowMatchScheduler",
21
+ ]
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_class_name": "NiTFlowMatchScheduler",
3
+ "mode": "ode",
4
+ "num_train_timesteps": 1000,
5
+ "path_type": "linear"
6
+ }
test_inference.py ADDED
@@ -0,0 +1,96 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone inference script for the NiT-XL Diffusers checkpoint.
4
+
5
+ This script only uses code vendored in this model repository:
6
+ `custom_pipeline/` for NiT pipeline, transformer, and scheduler classes.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ from pathlib import Path
13
+
14
+ import torch
15
+ from diffusers import DiffusionPipeline
16
+
17
+
18
+ def parse_args() -> argparse.Namespace:
19
+ parser = argparse.ArgumentParser(description="Run class-conditional NiT-XL inference.")
20
+ parser.add_argument(
21
+ "--model-dir",
22
+ type=Path,
23
+ default=Path(__file__).resolve().parent,
24
+ help="Path to model repository root.",
25
+ )
26
+ parser.add_argument("--class-label", type=int, default=207, help="ImageNet class label to sample.")
27
+ parser.add_argument("--height", type=int, default=512, help="Output image height.")
28
+ parser.add_argument("--width", type=int, default=512, help="Output image width.")
29
+ parser.add_argument("--steps", type=int, default=250, help="Number of inference steps.")
30
+ parser.add_argument("--mode", choices=["ode", "sde"], default="sde", help="Sampling mode.")
31
+ parser.add_argument("--guidance-scale", type=float, default=2.05, help="Classifier-free guidance scale.")
32
+ parser.add_argument("--guidance-low", type=float, default=0.0, help="Guidance start timestep fraction.")
33
+ parser.add_argument("--guidance-high", type=float, default=0.7, help="Guidance end timestep fraction.")
34
+ parser.add_argument("--heun", action="store_true", help="Enable Heun correction for ODE mode.")
35
+ parser.add_argument("--seed", type=int, default=42, help="Random seed.")
36
+ parser.add_argument(
37
+ "--output",
38
+ type=Path,
39
+ default=Path("demo_images/demo_sde250_class207_seed42.png"),
40
+ help="Output image path relative to model dir, or absolute path.",
41
+ )
42
+ return parser.parse_args()
43
+
44
+
45
+ def resolve_output_path(model_dir: Path, output: Path) -> Path:
46
+ if output.is_absolute():
47
+ return output
48
+ return model_dir / output
49
+
50
+
51
+ def main() -> None:
52
+ args = parse_args()
53
+ model_dir = args.model_dir.resolve()
54
+ custom_dir = model_dir / "custom_pipeline"
55
+ if not custom_dir.exists():
56
+ raise FileNotFoundError(f"Missing custom pipeline dir: {custom_dir}")
57
+ if not (model_dir / "pipeline.py").exists():
58
+ raise FileNotFoundError(f"Missing custom entrypoint: {model_dir / 'pipeline.py'}")
59
+
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ torch_dtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
62
+ generator_device = device.type if device.type != "cpu" else "cpu"
63
+ generator = torch.Generator(device=generator_device).manual_seed(args.seed)
64
+
65
+ pipe = DiffusionPipeline.from_pretrained(
66
+ model_dir,
67
+ custom_pipeline=str(model_dir / "pipeline.py"),
68
+ local_files_only=True,
69
+ ).to(device=device)
70
+ if device.type == "cuda":
71
+ pipe.transformer.to(dtype=torch_dtype)
72
+ pipe.vae.to(dtype=torch_dtype)
73
+
74
+ output = pipe(
75
+ class_labels=[args.class_label],
76
+ height=args.height,
77
+ width=args.width,
78
+ num_inference_steps=args.steps,
79
+ mode=args.mode,
80
+ guidance_scale=args.guidance_scale,
81
+ guidance_interval=(args.guidance_low, args.guidance_high),
82
+ heun=args.heun,
83
+ generator=generator,
84
+ output_type="pil",
85
+ )
86
+
87
+ output_path = resolve_output_path(model_dir, args.output)
88
+ output_path.parent.mkdir(parents=True, exist_ok=True)
89
+ output.images[0].save(output_path)
90
+
91
+ print(f"Saved image to: {output_path}")
92
+ print(f"Device: {device} | dtype: {torch_dtype}")
93
+
94
+
95
+ if __name__ == "__main__":
96
+ main()
transformer/config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "_class_name": "NiTTransformer2DModel",
3
+ "class_dropout_prob": 0.1,
4
+ "depth": 28,
5
+ "encoder_depth": 8,
6
+ "hidden_size": 1152,
7
+ "in_channels": 32,
8
+ "input_size": 32,
9
+ "num_classes": 1000,
10
+ "num_heads": 16,
11
+ "patch_size": 1,
12
+ "qk_norm": true,
13
+ "z_dim": 1280
14
+ }
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68cf19eb16e2231d1493dbb2c1bc7922fdfb23cc1e4b209aca6b6282238aa83b
3
+ size 2736207096
vae/config.json ADDED
@@ -0,0 +1,88 @@
1
+ {
2
+ "_class_name": "AutoencoderDC",
3
+ "_diffusers_version": "0.32.2",
4
+ "attention_head_dim": 32,
5
+ "decoder_act_fns": "silu",
6
+ "decoder_block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512,
11
+ 1024,
12
+ 1024
13
+ ],
14
+ "decoder_block_types": [
15
+ "ResBlock",
16
+ "ResBlock",
17
+ "ResBlock",
18
+ "EfficientViTBlock",
19
+ "EfficientViTBlock",
20
+ "EfficientViTBlock"
21
+ ],
22
+ "decoder_layers_per_block": [
23
+ 3,
24
+ 3,
25
+ 3,
26
+ 3,
27
+ 3,
28
+ 3
29
+ ],
30
+ "decoder_norm_types": "rms_norm",
31
+ "decoder_qkv_multiscales": [
32
+ [],
33
+ [],
34
+ [],
35
+ [
36
+ 5
37
+ ],
38
+ [
39
+ 5
40
+ ],
41
+ [
42
+ 5
43
+ ]
44
+ ],
45
+ "downsample_block_type": "Conv",
46
+ "encoder_block_out_channels": [
47
+ 128,
48
+ 256,
49
+ 512,
50
+ 512,
51
+ 1024,
52
+ 1024
53
+ ],
54
+ "encoder_block_types": [
55
+ "ResBlock",
56
+ "ResBlock",
57
+ "ResBlock",
58
+ "EfficientViTBlock",
59
+ "EfficientViTBlock",
60
+ "EfficientViTBlock"
61
+ ],
62
+ "encoder_layers_per_block": [
63
+ 2,
64
+ 2,
65
+ 2,
66
+ 3,
67
+ 3,
68
+ 3
69
+ ],
70
+ "encoder_qkv_multiscales": [
71
+ [],
72
+ [],
73
+ [],
74
+ [
75
+ 5
76
+ ],
77
+ [
78
+ 5
79
+ ],
80
+ [
81
+ 5
82
+ ]
83
+ ],
84
+ "in_channels": 3,
85
+ "latent_channels": 32,
86
+ "scaling_factor": 0.41407,
87
+ "upsample_block_type": "interpolate"
88
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfd991d1b54ffabf22745c5885589d8f2a7bc59930d95d92bd741c4fc64454bb
3
+ size 1249044836