BiliSakura committed on
Commit 4c42d10 · verified · 1 parent: e42eaac

Add files using upload-large-folder tool

Files changed (50)
  1. README.md +132 -0
  2. SiT-B-2-256-diffusers/README.md +42 -0
  3. SiT-B-2-256-diffusers/model_index.json +19 -0
  4. SiT-B-2-256-diffusers/pipeline.py +82 -0
  5. SiT-B-2-256-diffusers/scheduler/scheduler_config.json +9 -0
  6. SiT-B-2-256-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
  7. SiT-B-2-256-diffusers/transformer/config.json +14 -0
  8. SiT-B-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors +3 -0
  9. SiT-B-2-256-diffusers/transformer/transformer_sit.py +224 -0
  10. SiT-B-2-256-diffusers/vae/config.json +38 -0
  11. SiT-B-2-256-diffusers/vae/diffusion_pytorch_model.safetensors +3 -0
  12. SiT-L-2-256-diffusers/README.md +42 -0
  13. SiT-L-2-256-diffusers/model_index.json +19 -0
  14. SiT-L-2-256-diffusers/pipeline.py +82 -0
  15. SiT-L-2-256-diffusers/scheduler/scheduler_config.json +9 -0
  16. SiT-L-2-256-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
  17. SiT-L-2-256-diffusers/transformer/config.json +14 -0
  18. SiT-L-2-256-diffusers/transformer/transformer_sit.py +224 -0
  19. SiT-L-2-256-diffusers/vae/config.json +38 -0
  20. SiT-L-2-256-diffusers/vae/diffusion_pytorch_model.safetensors +3 -0
  21. SiT-S-2-256-diffusers/README.md +42 -0
  22. SiT-S-2-256-diffusers/model_index.json +19 -0
  23. SiT-S-2-256-diffusers/pipeline.py +82 -0
  24. SiT-S-2-256-diffusers/scheduler/scheduler_config.json +9 -0
  25. SiT-S-2-256-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
  26. SiT-S-2-256-diffusers/transformer/config.json +14 -0
  27. SiT-S-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors +3 -0
  28. SiT-S-2-256-diffusers/transformer/transformer_sit.py +224 -0
  29. SiT-S-2-256-diffusers/vae/config.json +38 -0
  30. SiT-S-2-256-diffusers/vae/diffusion_pytorch_model.safetensors +3 -0
  31. SiT-XL-2-256-diffusers/README.md +42 -0
  32. SiT-XL-2-256-diffusers/demo_50steps.png +0 -0
  33. SiT-XL-2-256-diffusers/model_index.json +19 -0
  34. SiT-XL-2-256-diffusers/pipeline.py +82 -0
  35. SiT-XL-2-256-diffusers/scheduler/scheduler_config.json +9 -0
  36. SiT-XL-2-256-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
  37. SiT-XL-2-256-diffusers/transformer/config.json +14 -0
  38. SiT-XL-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors +3 -0
  39. SiT-XL-2-256-diffusers/transformer/transformer_sit.py +224 -0
  40. SiT-XL-2-256-diffusers/vae/config.json +38 -0
  41. SiT-XL-2-256-diffusers/vae/diffusion_pytorch_model.safetensors +3 -0
  42. SiT-XL-2-512-diffusers/README.md +42 -0
  43. SiT-XL-2-512-diffusers/model_index.json +19 -0
  44. SiT-XL-2-512-diffusers/pipeline.py +82 -0
  45. SiT-XL-2-512-diffusers/scheduler/scheduler_config.json +9 -0
  46. SiT-XL-2-512-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
  47. SiT-XL-2-512-diffusers/transformer/config.json +14 -0
  48. SiT-XL-2-512-diffusers/transformer/diffusion_pytorch_model.safetensors +3 -0
  49. SiT-XL-2-512-diffusers/transformer/transformer_sit.py +224 -0
  50. SiT-XL-2-512-diffusers/vae/config.json +38 -0
README.md ADDED
@@ -0,0 +1,132 @@
+ ---
+ library_name: diffusers
+ pipeline_tag: unconditional-image-generation
+ tags:
+ - diffusers
+ - sit
+ - image-generation
+ - class-conditional
+ - imagenet
+ license: mit
+ inference: true
+ ---
+
+ # SiT-diffusers
+
+ Diffusers-ready checkpoints for **Scalable Interpolant Transformers (SiT)**, converted for local/offline use.
+
+ This root folder is a model collection that contains:
+
+ - `SiT-S-2-256-diffusers`
+ - `SiT-B-2-256-diffusers`
+ - `SiT-L-2-256-diffusers`
+ - `SiT-XL-2-256-diffusers`
+ - `SiT-XL-2-512-diffusers`
+
+ Each subfolder is a self-contained Diffusers model repo with:
+
+ - `pipeline.py`
+ - `transformer/transformer_sit.py`
+ - `scheduler/scheduling_flow_match_sit.py`
+ - `transformer/diffusion_pytorch_model.safetensors`
+ - `vae/diffusion_pytorch_model.safetensors`
+
+ ## Model Paths
+
+ Use paths relative to this root README:
+
+ | Model | Resolution | Local path |
+ |---|---:|---|
+ | SiT-S/2 | 256x256 | `./SiT-S-2-256-diffusers` |
+ | SiT-B/2 | 256x256 | `./SiT-B-2-256-diffusers` |
+ | SiT-L/2 | 256x256 | `./SiT-L-2-256-diffusers` |
+ | SiT-XL/2 | 256x256 | `./SiT-XL-2-256-diffusers` |
+ | SiT-XL/2 | 512x512 | `./SiT-XL-2-512-diffusers` |
+
+ ## Inference Demo (Diffusers)
+
+ ### 1) Load a local subfolder checkpoint
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ model_path = "./SiT-XL-2-512-diffusers"  # change to any path in the table above
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ pipe = DiffusionPipeline.from_pretrained(
+     model_path,
+     trust_remote_code=True,
+ ).to(device)
+
+ generator = torch.Generator(device=device).manual_seed(0)
+
+ # ImageNet class example: 207 = golden retriever
+ result = pipe(
+     class_labels=207,
+     height=512,
+     width=512,
+     num_inference_steps=250,  # official SiT comparisons commonly use 250 steps
+     guidance_scale=4.0,
+     generator=generator,
+ )
+
+ image = result.images[0]
+ image.save("sit_xl_512_demo.png")
+ ```
+
+ ### 2) Quick variant switch (256 models)
+
+ ```python
+ model_path = "./SiT-S-2-256-diffusers"
+ # model_path = "./SiT-B-2-256-diffusers"
+ # model_path = "./SiT-L-2-256-diffusers"
+ # model_path = "./SiT-XL-2-256-diffusers"
+
+ pipe = DiffusionPipeline.from_pretrained(model_path, trust_remote_code=True).to(device)
+ image = pipe(
+     class_labels=207,
+     height=256,
+     width=256,
+     num_inference_steps=250,
+     guidance_scale=4.0,
+     generator=generator,
+ ).images[0]
+ image.save("sit_256_demo.png")
+ ```
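+
+ ### 3) Batch generation (multiple class labels)
+
+ `class_labels` also accepts a list (see `pipeline.py` in each repo), so several classes can be sampled in one call. A minimal sketch, reusing `device` and `generator` from above:
+
+ ```python
+ pipe = DiffusionPipeline.from_pretrained("./SiT-XL-2-256-diffusers", trust_remote_code=True).to(device)
+
+ # ImageNet-1k indices: 207 = golden retriever, 360 = otter, 387 = lesser panda
+ images = pipe(
+     class_labels=[207, 360, 387],
+     height=256,
+     width=256,
+     num_inference_steps=50,  # quick preview; use 250 to match the official setting
+     guidance_scale=4.0,
+     generator=generator,
+ ).images
+ for i, img in enumerate(images):
+     img.save(f"sit_batch_{i}.png")
+ ```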
+
+ ## FID Reference (from Official SiT Results)
+
+ The table below summarizes widely cited SiT numbers from the official project materials for class-conditional ImageNet generation.
+
+ | Model / setting | Resolution | FID-50K (lower is better) |
+ |---|---:|---:|
+ | SiT-S (400K steps) | 256x256 | 57.6 |
+ | SiT-B (400K steps) | 256x256 | 33.5 |
+ | SiT-L (400K steps) | 256x256 | 17.2 |
+ | SiT-XL (400K steps) | 256x256 | 8.6 |
+ | SiT-XL (cfg=1.5, ODE) | 256x256 | 2.15 |
+ | SiT-XL (cfg=1.5, SDE, `w(t)=sigma_t`) | 256x256 | 2.06 |
+ | SiT-XL (sample showcase) | 512x512 | not reported in the same benchmark table |
+
+ > Note: FID depends on the training recipe, sampler choice (ODE/SDE), guidance scale, and evaluation protocol. Treat this table as a pointer to the official SiT reports, not as a guarantee that every conversion/export reproduces them.
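+
+ For an FID-style sweep, a minimal sampling loop might look like the sketch below. The 50-samples-per-class split and output layout are illustrative assumptions, not the official evaluation protocol; reduce the per-call batch if it exceeds GPU memory.
+
+ ```python
+ import os
+
+ out_dir = "samples_fid"
+ os.makedirs(out_dir, exist_ok=True)
+ for cls in range(1000):  # 50 images x 1000 classes = 50K samples
+     images = pipe(
+         class_labels=[cls] * 50,
+         height=256,
+         width=256,
+         num_inference_steps=250,
+         guidance_scale=1.0,  # guidance disabled; the cfg=1.5 rows above use guidance
+         generator=generator,
+     ).images
+     for i, img in enumerate(images):
+         img.save(os.path.join(out_dir, f"{cls:04d}_{i:02d}.png"))
+ ```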
+
+ ## Source and Paper
+
+ - Official SiT code: [willisma/SiT](https://github.com/willisma/SiT)
+ - Project page: [scalable-interpolant.github.io](https://scalable-interpolant.github.io/)
+ - Paper (arXiv): [SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers](https://arxiv.org/abs/2401.08740)
+
+ ## Citation
+
+ If you use SiT in your work, please cite:
+
+ ```bibtex
+ @inproceedings{ma2024sit,
+   title={SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers},
+   author={Ma, Nanye and Goldstein, Mark and Albergo, Michael S. and Boffi, Nicholas M. and Vanden-Eijnden, Eric and Xie, Saining},
+   booktitle={European Conference on Computer Vision (ECCV)},
+   year={2024}
+ }
+ ```
SiT-B-2-256-diffusers/README.md ADDED
@@ -0,0 +1,42 @@
+ ---
+ library_name: diffusers
+ pipeline_tag: unconditional-image-generation
+ tags:
+ - diffusers
+ - sit
+ - image-generation
+ - class-conditional
+ inference: true
+ ---
+
+ # SiT-B-2-256-diffusers
+
+ Self-contained Diffusers checkpoint repo for SiT.
+
+ ## Usage
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ # trust_remote_code is required: the pipeline, transformer, and scheduler
+ # classes are defined in this repo, not in the diffusers library.
+ pipe = DiffusionPipeline.from_pretrained("./", trust_remote_code=True)
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+ generator = torch.Generator(device=pipe.device).manual_seed(0)
+
+ image = pipe(
+     class_labels=207,
+     height=256,
+     width=256,
+     num_inference_steps=250,
+     guidance_scale=4.0,
+     generator=generator,
+ ).images[0]
+ image.save("demo.png")
+ ```
+
+ ## Components
+
+ - `pipeline.py`
+ - `transformer/transformer_sit.py`
+ - `scheduler/scheduling_flow_match_sit.py`
+ - `transformer/diffusion_pytorch_model.safetensors`
+ - `vae/diffusion_pytorch_model.safetensors`
SiT-B-2-256-diffusers/model_index.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_class_name": [
+     "pipeline",
+     "SiTPipeline"
+   ],
+   "_diffusers_version": "0.36.0",
+   "scheduler": [
+     "scheduling_flow_match_sit",
+     "SiTFlowMatchScheduler"
+   ],
+   "transformer": [
+     "transformer_sit",
+     "SiTTransformer2DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
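
Each two-element entry above is a `[module, class]` pair: `"diffusers"` points at a library class (`AutoencoderKL`), while bare module names (`pipeline`, `transformer_sit`, `scheduling_flow_match_sit`) point at the `.py` files shipped in this repo, which is why loading requires `trust_remote_code=True`. Components can also be loaded individually; a sketch (the `sys.path` insertion is just one way to make the local module importable):

```python
import sys

sys.path.insert(0, "./SiT-B-2-256-diffusers/transformer")
from transformer_sit import SiTTransformer2DModel

transformer = SiTTransformer2DModel.from_pretrained(
    "./SiT-B-2-256-diffusers", subfolder="transformer"
)
print(transformer.config.hidden_size, transformer.config.depth)  # 768 12
```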
SiT-B-2-256-diffusers/pipeline.py ADDED
@@ -0,0 +1,82 @@
+ from typing import List, Optional, Union
+
+ import torch
+
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+ class SiTPipeline(DiffusionPipeline):
+     model_cpu_offload_seq = "transformer->vae"
+
+     def __init__(self, transformer, scheduler, vae):
+         super().__init__()
+         self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
+         self.vae_scale_factor = 8
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         class_labels: Union[int, List[int]] = 207,
+         height: int = 256,
+         width: int = 256,
+         num_inference_steps: int = 250,
+         guidance_scale: float = 4.0,
+         generator: Optional[torch.Generator] = None,
+         output_type: str = "pil",
+         return_dict: bool = True,
+     ):
+         device = self._execution_device
+         if isinstance(class_labels, int):
+             class_labels = [class_labels]
+         batch_size = len(class_labels)
+
+         latent_h = height // self.vae_scale_factor
+         latent_w = width // self.vae_scale_factor
+         latents = randn_tensor(
+             (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
+             generator=generator,
+             device=device,
+             dtype=self.transformer.dtype,
+         )
+
+         labels = torch.tensor(class_labels, device=device, dtype=torch.long)
+         do_cfg = guidance_scale is not None and guidance_scale > 1.0
+         if do_cfg:
+             null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
+             labels = torch.cat([labels, null_label], dim=0)
+
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         for t in self.progress_bar(timesteps):
+             t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
+             model_input = latents
+             if do_cfg:
+                 model_input = torch.cat([latents, latents], dim=0)
+                 t_batch = torch.cat([t_batch, t_batch], dim=0)
+
+             model_pred = self.transformer(
+                 hidden_states=model_input,
+                 timestep=t_batch,
+                 class_labels=labels,
+             ).sample
+
+             if do_cfg:
+                 cond, uncond = model_pred.chunk(2, dim=0)
+                 model_pred = uncond + guidance_scale * (cond - uncond)
+
+             latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample
+
+         image = self.vae.decode(latents / 0.18215).sample
+         # "pt" outputs stay in the raw VAE range [-1, 1] to match the original SiT scripts.
+         if output_type != "pt":
+             image = self.image_processor.postprocess(image, output_type=output_type)
+
+         if not return_dict:
+             return (image,)
+         return ImagePipelineOutput(images=image)
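
`SiTPipeline` implements classifier-free guidance by doubling the batch: real labels in the first half, the null class index (`num_classes`) in the second, combined as `uncond + guidance_scale * (cond - uncond)`. At `guidance_scale=1.0` that expression reduces to the conditional prediction, so the pipeline skips the doubled forward pass entirely:

```python
# do_cfg is False here, so each step runs a single (not 2x) transformer batch.
image = pipe(class_labels=207, guidance_scale=1.0, generator=generator).images[0]
```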
SiT-B-2-256-diffusers/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_class_name": "SiTFlowMatchScheduler",
+   "_diffusers_version": "0.36.0",
+   "diffusion_form": "sigma",
+   "diffusion_norm": 1.0,
+   "mode": "ode",
+   "num_train_timesteps": 1000,
+   "shift": 1.0
+ }
SiT-B-2-256-diffusers/scheduler/scheduling_flow_match_sit.py ADDED
@@ -0,0 +1,98 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ import torch
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class SiTFlowMatchSchedulerOutput(BaseOutput):
+     prev_sample: torch.Tensor
+
+
+ class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
+     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         mode: str = "ode",
+         num_train_timesteps: int = 1000,
+         shift: float = 1.0,
+         diffusion_form: str = "sigma",
+         diffusion_norm: float = 1.0,
+     ):
+         self.timesteps = None
+         self.sigmas = None
+         self._step_index = None
+
+     def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
+         # Flow matching integrates from noise (t=0) to data (t=1).
+         ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
+         self.timesteps = ts[:-1]
+         self.sigmas = 1.0 - self.timesteps
+         self._step_index = 0
+         return self.timesteps
+
+     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
+         return sample
+
+     def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
+         form = self.config.diffusion_form
+         norm = self.config.diffusion_norm
+         if form == "constant":
+             return torch.full_like(t, norm)
+         if form == "sigma":
+             return norm * (1.0 - t)
+         if form == "linear":
+             return norm * (1.0 - t)
+         if form == "decreasing":
+             return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
+         if form == "increasing-decreasing":
+             return norm * torch.sin(torch.pi * t) ** 2
+         # "SBDM" approximated with a sigma-based schedule for compatibility.
+         return norm * (1.0 - t)
+
+     def step(
+         self,
+         model_output: torch.Tensor,
+         timestep: Union[float, torch.Tensor],
+         sample: torch.Tensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
+         if self.timesteps is None:
+             raise ValueError("Call `set_timesteps` before `step`.")
+         if self._step_index is None:
+             self._step_index = 0
+
+         step_index = min(self._step_index, len(self.timesteps) - 1)
+         t = self.timesteps[step_index].to(sample.device)
+         next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
+         dt = next_t - t
+
+         prev_sample = sample + model_output * dt
+         if self.config.mode.lower() == "sde":
+             diffusion = self._diffusion(torch.full((sample.shape[0],), t, device=sample.device, dtype=sample.dtype))
+             while diffusion.dim() < sample.dim():
+                 diffusion = diffusion.unsqueeze(-1)
+             noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
+             prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise
+
+         self._step_index += 1
+         if not return_dict:
+             return (prev_sample,)
+         return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
+
+     def add_noise(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: torch.Tensor,
+     ) -> torch.Tensor:
+         sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
+         return (1 - sigma) * original_samples + sigma * noise
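
In `"ode"` mode the `step` method is a plain first-order Euler update; with `mode="sde"` it adds Euler-Maruyama noise scaled by the configured diffusion form. A sketch of switching a loaded pipeline to SDE sampling by overriding the scheduler config (whether this matches the official SiT SDE sampler numerically is not guaranteed):

```python
# Rebuild the scheduler from its own config, overriding mode and diffusion form.
pipe.scheduler = type(pipe.scheduler).from_config(
    pipe.scheduler.config,
    mode="sde",
    diffusion_form="sigma",  # the w(t) = sigma_t setting quoted in the FID table
)
```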
SiT-B-2-256-diffusers/transformer/config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "_class_name": "SiTTransformer2DModel",
+   "_diffusers_version": "0.36.0",
+   "class_dropout_prob": 0.1,
+   "depth": 12,
+   "hidden_size": 768,
+   "in_channels": 4,
+   "input_size": 32,
+   "learn_sigma": true,
+   "mlp_ratio": 4.0,
+   "num_classes": 1000,
+   "num_heads": 12,
+   "patch_size": 2
+ }
SiT-B-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be5318b2de818389f53e49ce39495394a94e20b0449041e0f9a6ced1ccc64f6c
+ size 522062536
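
A quick sanity check on the pointer: 522,062,536 bytes of fp32 weights corresponds to roughly the expected SiT-B parameter count (the small safetensors header makes this approximate).

```python
print(522_062_536 / 4 / 1e6)  # ≈ 130.5M fp32 parameters, the SiT-B scale
```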
SiT-B-2-256-diffusers/transformer/transformer_sit.py ADDED
@@ -0,0 +1,224 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ @dataclass
+ class SiTTransformer2DModelOutput(BaseOutput):
+     sample: torch.Tensor
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+             device=t.device
+         )
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t: torch.Tensor) -> torch.Tensor:
+         return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
+
+
+ class LabelEmbedder(nn.Module):
+     def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(
+         self,
+         labels: torch.Tensor,
+         train: bool,
+         force_drop_ids: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         return self.embedding_table(labels)
+
+
+ class SiTBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         return self.linear(x)
+
+
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         input_size: int = 32,
+         patch_size: int = 2,
+         in_channels: int = 4,
+         hidden_size: int = 1152,
+         depth: int = 28,
+         num_heads: int = 16,
+         mlp_ratio: float = 4.0,
+         class_dropout_prob: float = 0.1,
+         num_classes: int = 1000,
+         learn_sigma: bool = True,
+     ):
+         super().__init__()
+         self.learn_sigma = learn_sigma
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+
+         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self) -> None:
+         def _basic_init(module: nn.Module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         for block in self.blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0]
+         h = w = int(x.shape[1] ** 0.5)
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+         x = torch.einsum("nhwpqc->nchpwq", x)
+         return x.reshape(shape=(x.shape[0], c, h * p, h * p))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         class_labels: torch.Tensor,
+         force_drop_ids: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> SiTTransformer2DModelOutput:
+         x = self.x_embedder(hidden_states) + self.pos_embed
+         t = self.t_embedder(timestep)
+         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
+         c = t + y
+         for block in self.blocks:
+             x = block(x, c)
+         x = self.final_layer(x, c)
+         x = self.unpatchify(x)
+         if self.learn_sigma:
+             x, _ = x.chunk(2, dim=1)
+         if not return_dict:
+             return (x,)
+         return SiTTransformer2DModelOutput(sample=x)
+
+
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)
+     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
+     assert embed_dim % 2 == 0
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+     return np.concatenate([emb_h, emb_w], axis=1)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega
+     pos = pos.reshape(-1)
+     out = np.einsum("m,d->md", pos, omega)
+     emb_sin = np.sin(out)
+     emb_cos = np.cos(out)
+     return np.concatenate([emb_sin, emb_cos], axis=1)
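
A quick shape check for the SiT-B/2 configuration above (`depth=12`, `hidden_size=768`, `num_heads=12`); this is a sketch with random weights, and it assumes `timm` is installed since the model reuses its `PatchEmbed`/`Attention`/`Mlp` blocks:

```python
import torch

model = SiTTransformer2DModel(depth=12, hidden_size=768, num_heads=12)
model.eval()  # disable label dropout, which is only active in training mode
x = torch.randn(2, 4, 32, 32)     # latent batch (B, C, 256/8, 256/8)
t = torch.rand(2)                 # flow-matching time in [0, 1]
y = torch.randint(0, 1000, (2,))  # ImageNet class labels
out = model(hidden_states=x, timestep=t, class_labels=y).sample
print(out.shape)                  # torch.Size([2, 4, 32, 32]); learned-sigma half dropped
```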
SiT-B-2-256-diffusers/vae/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.36.0",
+   "_name_or_path": "stabilityai/sd-vae-ft-mse",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "mid_block_add_attention": true,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 256,
+   "scaling_factor": 0.18215,
+   "shift_factor": null,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ],
+   "use_post_quant_conv": true,
+   "use_quant_conv": true
+ }
SiT-B-2-256-diffusers/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
+ size 334643268
SiT-L-2-256-diffusers/README.md ADDED
@@ -0,0 +1,42 @@
+ ---
+ library_name: diffusers
+ pipeline_tag: unconditional-image-generation
+ tags:
+ - diffusers
+ - sit
+ - image-generation
+ - class-conditional
+ inference: true
+ ---
+
+ # SiT-L-2-256-diffusers
+
+ Self-contained Diffusers checkpoint repo for SiT.
+
+ ## Usage
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ # trust_remote_code is required: the pipeline, transformer, and scheduler
+ # classes are defined in this repo, not in the diffusers library.
+ pipe = DiffusionPipeline.from_pretrained("./", trust_remote_code=True)
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+ generator = torch.Generator(device=pipe.device).manual_seed(0)
+
+ image = pipe(
+     class_labels=207,
+     height=256,
+     width=256,
+     num_inference_steps=250,
+     guidance_scale=4.0,
+     generator=generator,
+ ).images[0]
+ image.save("demo.png")
+ ```
+
+ ## Components
+
+ - `pipeline.py`
+ - `transformer/transformer_sit.py`
+ - `scheduler/scheduling_flow_match_sit.py`
+ - `transformer/diffusion_pytorch_model.safetensors`
+ - `vae/diffusion_pytorch_model.safetensors`
SiT-L-2-256-diffusers/model_index.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_class_name": [
+     "pipeline",
+     "SiTPipeline"
+   ],
+   "_diffusers_version": "0.36.0",
+   "scheduler": [
+     "scheduling_flow_match_sit",
+     "SiTFlowMatchScheduler"
+   ],
+   "transformer": [
+     "transformer_sit",
+     "SiTTransformer2DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
SiT-L-2-256-diffusers/pipeline.py ADDED
@@ -0,0 +1,82 @@
*(file content identical to `SiT-B-2-256-diffusers/pipeline.py`, shown above)*
SiT-L-2-256-diffusers/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_class_name": "SiTFlowMatchScheduler",
+   "_diffusers_version": "0.36.0",
+   "diffusion_form": "sigma",
+   "diffusion_norm": 1.0,
+   "mode": "ode",
+   "num_train_timesteps": 1000,
+   "shift": 1.0
+ }
SiT-L-2-256-diffusers/scheduler/scheduling_flow_match_sit.py ADDED
@@ -0,0 +1,98 @@
*(file content identical to `SiT-B-2-256-diffusers/scheduler/scheduling_flow_match_sit.py`, shown above)*
SiT-L-2-256-diffusers/transformer/config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "_class_name": "SiTTransformer2DModel",
+   "_diffusers_version": "0.36.0",
+   "class_dropout_prob": 0.1,
+   "depth": 24,
+   "hidden_size": 1024,
+   "in_channels": 4,
+   "input_size": 32,
+   "learn_sigma": true,
+   "mlp_ratio": 4.0,
+   "num_classes": 1000,
+   "num_heads": 16,
+   "patch_size": 2
+ }
SiT-L-2-256-diffusers/transformer/transformer_sit.py ADDED
@@ -0,0 +1,224 @@
*(file content identical to `SiT-B-2-256-diffusers/transformer/transformer_sit.py`, shown above)*
SiT-L-2-256-diffusers/vae/config.json ADDED
@@ -0,0 +1,38 @@
*(file content identical to `SiT-B-2-256-diffusers/vae/config.json`, shown above)*
SiT-L-2-256-diffusers/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
+ size 334643268
SiT-S-2-256-diffusers/README.md ADDED
@@ -0,0 +1,42 @@
+ ---
+ library_name: diffusers
+ pipeline_tag: unconditional-image-generation
+ tags:
+ - diffusers
+ - sit
+ - image-generation
+ - class-conditional
+ inference: true
+ ---
+
+ # SiT-S-2-256-diffusers
+
+ Self-contained Diffusers checkpoint repo for SiT.
+
+ ## Usage
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ # trust_remote_code is required: the pipeline, transformer, and scheduler
+ # classes are defined in this repo, not in the diffusers library.
+ pipe = DiffusionPipeline.from_pretrained("./", trust_remote_code=True)
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+ generator = torch.Generator(device=pipe.device).manual_seed(0)
+
+ image = pipe(
+     class_labels=207,
+     height=256,
+     width=256,
+     num_inference_steps=250,
+     guidance_scale=4.0,
+     generator=generator,
+ ).images[0]
+ image.save("demo.png")
+ ```
+
+ ## Components
+
+ - `pipeline.py`
+ - `transformer/transformer_sit.py`
+ - `scheduler/scheduling_flow_match_sit.py`
+ - `transformer/diffusion_pytorch_model.safetensors`
+ - `vae/diffusion_pytorch_model.safetensors`
SiT-S-2-256-diffusers/model_index.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_class_name": [
+     "pipeline",
+     "SiTPipeline"
+   ],
+   "_diffusers_version": "0.36.0",
+   "scheduler": [
+     "scheduling_flow_match_sit",
+     "SiTFlowMatchScheduler"
+   ],
+   "transformer": [
+     "transformer_sit",
+     "SiTTransformer2DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
SiT-S-2-256-diffusers/pipeline.py ADDED
@@ -0,0 +1,82 @@
*(file content identical to `SiT-B-2-256-diffusers/pipeline.py`, shown above)*
SiT-S-2-256-diffusers/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_class_name": "SiTFlowMatchScheduler",
+   "_diffusers_version": "0.36.0",
+   "diffusion_form": "sigma",
+   "diffusion_norm": 1.0,
+   "mode": "ode",
+   "num_train_timesteps": 1000,
+   "shift": 1.0
+ }
SiT-S-2-256-diffusers/scheduler/scheduling_flow_match_sit.py ADDED
@@ -0,0 +1,98 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
8
+ from diffusers.utils import BaseOutput
9
+
10
+
11
+ @dataclass
12
+ class SiTFlowMatchSchedulerOutput(BaseOutput):
13
+ prev_sample: torch.Tensor
14
+
15
+
16
+ class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
17
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
18
+ order = 1
19
+
20
+ @register_to_config
21
+ def __init__(
22
+ self,
23
+ mode: str = "ode",
24
+ num_train_timesteps: int = 1000,
25
+ shift: float = 1.0,
26
+ diffusion_form: str = "sigma",
27
+ diffusion_norm: float = 1.0,
28
+ ):
29
+ self.timesteps = None
30
+ self.sigmas = None
31
+ self._step_index = None
32
+
33
+ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
34
+ # Flow matching integrates from noise (t=0) to data (t=1).
35
+ ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
36
+ self.timesteps = ts[:-1]
37
+ self.sigmas = 1.0 - self.timesteps
38
+ self._step_index = 0
39
+ return self.timesteps
40
+
41
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
42
+ return sample
43
+
44
+ def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
45
+ form = self.config.diffusion_form
46
+ norm = self.config.diffusion_norm
47
+ if form == "constant":
48
+ return torch.full_like(t, norm)
49
+ if form == "sigma":
50
+ return norm * (1.0 - t)
51
+ if form == "linear":
52
+ return norm * (1.0 - t)
53
+ if form == "decreasing":
54
+ return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
55
+ if form == "increasing-decreasing":
56
+ return norm * torch.sin(torch.pi * t) ** 2
57
+ # "SBDM" approximated with sigma-based schedule for compatibility.
58
+ return norm * (1.0 - t)
59
+
60
+ def step(
61
+ self,
62
+ model_output: torch.Tensor,
63
+ timestep: Union[float, torch.Tensor],
64
+ sample: torch.Tensor,
65
+ generator: Optional[torch.Generator] = None,
66
+ return_dict: bool = True,
67
+ ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
68
+ if self.timesteps is None:
69
+ raise ValueError("Call `set_timesteps` before `step`.")
70
+ if self._step_index is None:
71
+ self._step_index = 0
72
+
73
+ step_index = min(self._step_index, len(self.timesteps) - 1)
74
+ t = self.timesteps[step_index].to(sample.device)
75
+ next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
76
+ dt = next_t - t
77
+
78
+ prev_sample = sample + model_output * dt
79
+ if self.config.mode.lower() == "sde":
80
+ diffusion = self._diffusion(torch.full((sample.shape[0],), t, device=sample.device, dtype=sample.dtype))
81
+ while diffusion.dim() < sample.dim():
82
+ diffusion = diffusion.unsqueeze(-1)
83
+ noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
84
+ prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise
85
+
86
+ self._step_index += 1
87
+ if not return_dict:
88
+ return (prev_sample,)
89
+ return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
90
+
91
+ def add_noise(
92
+ self,
93
+ original_samples: torch.Tensor,
94
+ noise: torch.Tensor,
95
+ timesteps: torch.Tensor,
96
+ ) -> torch.Tensor:
97
+ sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
98
+ return (1 - sigma) * original_samples + sigma * noise
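+
+ # `step` advances an internal counter, so the `timestep` argument is accepted
+ # only for API compatibility with other diffusers schedulers.
+ # Minimal sketch (assuming a pipeline `pipe` built from this repo): enable the
+ # stochastic SDE sampler by rebuilding the scheduler with `mode` overridden:
+ #   pipe.scheduler = SiTFlowMatchScheduler.from_config(pipe.scheduler.config, mode="sde")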
SiT-S-2-256-diffusers/transformer/config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "_class_name": "SiTTransformer2DModel",
+   "_diffusers_version": "0.36.0",
+   "class_dropout_prob": 0.1,
+   "depth": 12,
+   "hidden_size": 384,
+   "in_channels": 4,
+   "input_size": 32,
+   "learn_sigma": true,
+   "mlp_ratio": 4.0,
+   "num_classes": 1000,
+   "num_heads": 6,
+   "patch_size": 2
+ }
SiT-S-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b0754c57c2b6e2e4e74b181d1730a8bd824a30eafeaae9eaf8bc4015e8e4f39
+ size 131866144
SiT-S-2-256-diffusers/transformer/transformer_sit.py ADDED
@@ -0,0 +1,224 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ @dataclass
+ class SiTTransformer2DModelOutput(BaseOutput):
+     sample: torch.Tensor
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+             device=t.device
+         )
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t: torch.Tensor) -> torch.Tensor:
+         return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
+
+
+ class LabelEmbedder(nn.Module):
+     def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(
+         self,
+         labels: torch.Tensor,
+         train: bool,
+         force_drop_ids: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         return self.embedding_table(labels)
+
+
+ class SiTBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         return self.linear(x)
+
+
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         input_size: int = 32,
+         patch_size: int = 2,
+         in_channels: int = 4,
+         hidden_size: int = 1152,
+         depth: int = 28,
+         num_heads: int = 16,
+         mlp_ratio: float = 4.0,
+         class_dropout_prob: float = 0.1,
+         num_classes: int = 1000,
+         learn_sigma: bool = True,
+     ):
+         super().__init__()
+         self.learn_sigma = learn_sigma
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+
+         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self) -> None:
+         def _basic_init(module: nn.Module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         for block in self.blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0]
+         h = w = int(x.shape[1] ** 0.5)
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+         x = torch.einsum("nhwpqc->nchpwq", x)
+         return x.reshape(shape=(x.shape[0], c, h * p, w * p))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         class_labels: torch.Tensor,
+         force_drop_ids: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> SiTTransformer2DModelOutput:
+         x = self.x_embedder(hidden_states) + self.pos_embed
+         t = self.t_embedder(timestep)
+         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
+         c = t + y
+         for block in self.blocks:
+             x = block(x, c)
+         x = self.final_layer(x, c)
+         x = self.unpatchify(x)
+         if self.learn_sigma:
+             x, _ = x.chunk(2, dim=1)
+         if not return_dict:
+             return (x,)
+         return SiTTransformer2DModelOutput(sample=x)
+
+
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)
+     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
+     assert embed_dim % 2 == 0
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+     return np.concatenate([emb_h, emb_w], axis=1)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega
+     pos = pos.reshape(-1)
+     out = np.einsum("m,d->md", pos, omega)
+     emb_sin = np.sin(out)
+     emb_cos = np.cos(out)
+     return np.concatenate([emb_sin, emb_cos], axis=1)
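+
+ # Minimal sketch: the SiT-S/2 backbone stored in this directory corresponds to
+ #   SiTTransformer2DModel(depth=12, hidden_size=384, num_heads=6, patch_size=2)
+ # (values from transformer/config.json); `from_pretrained` applies them automatically.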
SiT-S-2-256-diffusers/vae/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.36.0",
+   "_name_or_path": "stabilityai/sd-vae-ft-mse",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "mid_block_add_attention": true,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 256,
+   "scaling_factor": 0.18215,
+   "shift_factor": null,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ],
+   "use_post_quant_conv": true,
+   "use_quant_conv": true
+ }
SiT-S-2-256-diffusers/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
+ size 334643268
SiT-XL-2-256-diffusers/README.md ADDED
@@ -0,0 +1,42 @@
+ ---
+ library_name: diffusers
+ pipeline_tag: unconditional-image-generation
+ tags:
+ - diffusers
+ - sit
+ - image-generation
+ - class-conditional
+ inference: true
+ ---
+
+ # SiT-XL-2-256-diffusers
+
+ A self-contained Diffusers checkpoint repository for SiT (Scalable Interpolant Transformers).
+
+ ## Usage
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("./").to("cuda" if torch.cuda.is_available() else "cpu")
+ generator = torch.Generator(device=pipe.device).manual_seed(0)
+
+ image = pipe(
+     class_labels=207,
+     height=256,
+     width=256,
+     num_inference_steps=250,
+     guidance_scale=4.0,
+     generator=generator,
+ ).images[0]
+ image.save("demo.png")
+ ```
+
+ ## Components
+
+ - `pipeline.py`
+ - `transformer/transformer_sit.py`
+ - `scheduler/scheduling_flow_match_sit.py`
+ - `transformer/diffusion_pytorch_model.safetensors`
+ - `vae/diffusion_pytorch_model.safetensors`
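+
+ ## Notes
+
+ `class_labels` also accepts a list of ImageNet class ids, generating one image per
+ label in a single batch:
+
+ ```python
+ images = pipe(class_labels=[207, 360, 387], num_inference_steps=50, guidance_scale=4.0).images
+ ```
+
+ Sampling is deterministic by default (`mode: "ode"` in `scheduler/scheduler_config.json`).
+ A minimal sketch for stochastic (SDE) sampling, rebuilding the scheduler with the
+ `mode` entry overridden:
+
+ ```python
+ pipe.scheduler = type(pipe.scheduler).from_config(pipe.scheduler.config, mode="sde")
+ ```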
SiT-XL-2-256-diffusers/demo_50steps.png ADDED
SiT-XL-2-256-diffusers/model_index.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_class_name": [
+     "pipeline",
+     "SiTPipeline"
+   ],
+   "_diffusers_version": "0.36.0",
+   "scheduler": [
+     "scheduling_flow_match_sit",
+     "SiTFlowMatchScheduler"
+   ],
+   "transformer": [
+     "transformer_sit",
+     "SiTTransformer2DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
SiT-XL-2-256-diffusers/pipeline.py ADDED
@@ -0,0 +1,82 @@
+ from typing import List, Optional, Union
+
+ import torch
+
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+ class SiTPipeline(DiffusionPipeline):
+     model_cpu_offload_seq = "transformer->vae"
+
+     def __init__(self, transformer, scheduler, vae):
+         super().__init__()
+         self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
+         self.vae_scale_factor = 8
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         class_labels: Union[int, List[int]] = 207,
+         height: int = 256,
+         width: int = 256,
+         num_inference_steps: int = 250,
+         guidance_scale: float = 4.0,
+         generator: Optional[torch.Generator] = None,
+         output_type: str = "pil",
+         return_dict: bool = True,
+     ):
+         device = self._execution_device
+         if isinstance(class_labels, int):
+             class_labels = [class_labels]
+         batch_size = len(class_labels)
+
+         latent_h = height // self.vae_scale_factor
+         latent_w = width // self.vae_scale_factor
+         latents = randn_tensor(
+             (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
+             generator=generator,
+             device=device,
+             dtype=self.transformer.dtype,
+         )
+
+         labels = torch.tensor(class_labels, device=device, dtype=torch.long)
+         do_cfg = guidance_scale is not None and guidance_scale > 1.0
+         if do_cfg:
+             null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
+             labels = torch.cat([labels, null_label], dim=0)
+
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         for t in self.progress_bar(timesteps):
+             t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
+             model_input = latents
+             if do_cfg:
+                 model_input = torch.cat([latents, latents], dim=0)
+                 t_batch = torch.cat([t_batch, t_batch], dim=0)
+
+             model_pred = self.transformer(
+                 hidden_states=model_input,
+                 timestep=t_batch,
+                 class_labels=labels,
+             ).sample
+
+             if do_cfg:
+                 cond, uncond = model_pred.chunk(2, dim=0)
+                 model_pred = uncond + guidance_scale * (cond - uncond)
+
+             latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample
+
+         image = self.vae.decode(latents / 0.18215).sample
+         # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
+         if output_type != "pt":
+             image = self.image_processor.postprocess(image, output_type=output_type)
+
+         if not return_dict:
+             return (image,)
+         return ImagePipelineOutput(images=image)
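+
+ # Note on classifier-free guidance: the unconditional branch reuses the extra
+ # embedding row at index `num_classes` as a learned null label (see
+ # LabelEmbedder in transformer_sit.py). Passing guidance_scale <= 1.0 skips
+ # CFG entirely and runs a single conditional batch.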
SiT-XL-2-256-diffusers/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_class_name": "SiTFlowMatchScheduler",
+   "_diffusers_version": "0.36.0",
+   "diffusion_form": "sigma",
+   "diffusion_norm": 1.0,
+   "mode": "ode",
+   "num_train_timesteps": 1000,
+   "shift": 1.0
+ }
SiT-XL-2-256-diffusers/scheduler/scheduling_flow_match_sit.py ADDED
@@ -0,0 +1,98 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ import torch
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class SiTFlowMatchSchedulerOutput(BaseOutput):
+     prev_sample: torch.Tensor
+
+
+ class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
+     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         mode: str = "ode",
+         num_train_timesteps: int = 1000,
+         shift: float = 1.0,
+         diffusion_form: str = "sigma",
+         diffusion_norm: float = 1.0,
+     ):
+         self.timesteps = None
+         self.sigmas = None
+         self._step_index = None
+
+     def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
+         # Flow matching integrates from noise (t=0) to data (t=1).
+         ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
+         self.timesteps = ts[:-1]
+         self.sigmas = 1.0 - self.timesteps
+         self._step_index = 0
+         return self.timesteps
+
+     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
+         return sample
+
+     def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
+         form = self.config.diffusion_form
+         norm = self.config.diffusion_norm
+         if form == "constant":
+             return torch.full_like(t, norm)
+         if form == "sigma":
+             return norm * (1.0 - t)
+         if form == "linear":
+             return norm * (1.0 - t)
+         if form == "decreasing":
+             return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
+         if form == "increasing-decreasing":
+             return norm * torch.sin(torch.pi * t) ** 2
+         # "SBDM" approximated with sigma-based schedule for compatibility.
+         return norm * (1.0 - t)
+
+     def step(
+         self,
+         model_output: torch.Tensor,
+         timestep: Union[float, torch.Tensor],
+         sample: torch.Tensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
+         if self.timesteps is None:
+             raise ValueError("Call `set_timesteps` before `step`.")
+         if self._step_index is None:
+             self._step_index = 0
+
+         step_index = min(self._step_index, len(self.timesteps) - 1)
+         t = self.timesteps[step_index].to(sample.device)
+         next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
+         dt = next_t - t
+
+         prev_sample = sample + model_output * dt
+         if self.config.mode.lower() == "sde":
+             diffusion = self._diffusion(torch.full((sample.shape[0],), t, device=sample.device, dtype=sample.dtype))
+             while diffusion.dim() < sample.dim():
+                 diffusion = diffusion.unsqueeze(-1)
+             noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
+             prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise
+
+         self._step_index += 1
+         if not return_dict:
+             return (prev_sample,)
+         return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
+
+     def add_noise(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: torch.Tensor,
+     ) -> torch.Tensor:
+         sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
+         return (1 - sigma) * original_samples + sigma * noise
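+
+ # `step` advances an internal counter, so the `timestep` argument is accepted
+ # only for API compatibility with other diffusers schedulers.
+ # Minimal sketch (assuming a pipeline `pipe` built from this repo): enable the
+ # stochastic SDE sampler by rebuilding the scheduler with `mode` overridden:
+ #   pipe.scheduler = SiTFlowMatchScheduler.from_config(pipe.scheduler.config, mode="sde")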
SiT-XL-2-256-diffusers/transformer/config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "_class_name": "SiTTransformer2DModel",
+   "_diffusers_version": "0.36.0",
+   "class_dropout_prob": 0.1,
+   "depth": 28,
+   "hidden_size": 1152,
+   "in_channels": 4,
+   "input_size": 32,
+   "learn_sigma": true,
+   "mlp_ratio": 4.0,
+   "num_classes": 1000,
+   "num_heads": 16,
+   "patch_size": 2
+ }
SiT-XL-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6fc19454ff52c2f741194974b567a8679709dac950a2f74b766eaaeae22bdc09
+ size 2700547792
SiT-XL-2-256-diffusers/transformer/transformer_sit.py ADDED
@@ -0,0 +1,224 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ @dataclass
+ class SiTTransformer2DModelOutput(BaseOutput):
+     sample: torch.Tensor
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+             device=t.device
+         )
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t: torch.Tensor) -> torch.Tensor:
+         return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
+
+
+ class LabelEmbedder(nn.Module):
+     def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(
+         self,
+         labels: torch.Tensor,
+         train: bool,
+         force_drop_ids: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         return self.embedding_table(labels)
+
+
+ class SiTBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         return self.linear(x)
+
+
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         input_size: int = 32,
+         patch_size: int = 2,
+         in_channels: int = 4,
+         hidden_size: int = 1152,
+         depth: int = 28,
+         num_heads: int = 16,
+         mlp_ratio: float = 4.0,
+         class_dropout_prob: float = 0.1,
+         num_classes: int = 1000,
+         learn_sigma: bool = True,
+     ):
+         super().__init__()
+         self.learn_sigma = learn_sigma
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+
+         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self) -> None:
+         def _basic_init(module: nn.Module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         for block in self.blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0]
+         h = w = int(x.shape[1] ** 0.5)
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+         x = torch.einsum("nhwpqc->nchpwq", x)
+         return x.reshape(shape=(x.shape[0], c, h * p, w * p))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         class_labels: torch.Tensor,
+         force_drop_ids: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> SiTTransformer2DModelOutput:
+         x = self.x_embedder(hidden_states) + self.pos_embed
+         t = self.t_embedder(timestep)
+         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
+         c = t + y
+         for block in self.blocks:
+             x = block(x, c)
+         x = self.final_layer(x, c)
+         x = self.unpatchify(x)
+         if self.learn_sigma:
+             x, _ = x.chunk(2, dim=1)
+         if not return_dict:
+             return (x,)
+         return SiTTransformer2DModelOutput(sample=x)
+
+
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)
+     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
+     assert embed_dim % 2 == 0
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+     return np.concatenate([emb_h, emb_w], axis=1)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega
+     pos = pos.reshape(-1)
+     out = np.einsum("m,d->md", pos, omega)
+     emb_sin = np.sin(out)
+     emb_cos = np.cos(out)
+     return np.concatenate([emb_sin, emb_cos], axis=1)
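+
+ # Note: the constructor defaults (input_size=32, patch_size=2, hidden_size=1152,
+ # depth=28, num_heads=16) already match this repo's transformer/config.json, so
+ # `SiTTransformer2DModel()` builds the SiT-XL/2 backbone for 256x256 images.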
SiT-XL-2-256-diffusers/vae/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.36.0",
+   "_name_or_path": "stabilityai/sd-vae-ft-mse",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "mid_block_add_attention": true,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 256,
+   "scaling_factor": 0.18215,
+   "shift_factor": null,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ],
+   "use_post_quant_conv": true,
+   "use_quant_conv": true
+ }
SiT-XL-2-256-diffusers/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
+ size 334643268
SiT-XL-2-512-diffusers/README.md ADDED
@@ -0,0 +1,42 @@
+ ---
+ library_name: diffusers
+ pipeline_tag: unconditional-image-generation
+ tags:
+ - diffusers
+ - sit
+ - image-generation
+ - class-conditional
+ inference: true
+ ---
+
+ # SiT-XL-2-512-diffusers
+
+ A self-contained Diffusers checkpoint repository for SiT (Scalable Interpolant Transformers).
+
+ ## Usage
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("./").to("cuda" if torch.cuda.is_available() else "cpu")
+ generator = torch.Generator(device=pipe.device).manual_seed(0)
+
+ image = pipe(
+     class_labels=207,
+     height=512,
+     width=512,
+     num_inference_steps=250,
+     guidance_scale=4.0,
+     generator=generator,
+ ).images[0]
+ image.save("demo.png")
+ ```
+
+ ## Components
+
+ - `pipeline.py`
+ - `transformer/transformer_sit.py`
+ - `scheduler/scheduling_flow_match_sit.py`
+ - `transformer/diffusion_pytorch_model.safetensors`
+ - `vae/diffusion_pytorch_model.safetensors`
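+
+ ## Notes
+
+ The transformer was exported with a fixed 64x64 latent grid (`input_size: 64` in
+ `transformer/config.json`), so `height` and `width` must stay at 512; other sizes
+ do not match the fixed sinusoidal positional embedding.
+
+ `class_labels` also accepts a list of ImageNet class ids, generating one image per
+ label in a single batch:
+
+ ```python
+ images = pipe(class_labels=[207, 360, 387], height=512, width=512, guidance_scale=4.0).images
+ ```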
SiT-XL-2-512-diffusers/model_index.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_class_name": [
+     "pipeline",
+     "SiTPipeline"
+   ],
+   "_diffusers_version": "0.36.0",
+   "scheduler": [
+     "scheduling_flow_match_sit",
+     "SiTFlowMatchScheduler"
+   ],
+   "transformer": [
+     "transformer_sit",
+     "SiTTransformer2DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
SiT-XL-2-512-diffusers/pipeline.py ADDED
@@ -0,0 +1,82 @@
+ from typing import List, Optional, Union
+
+ import torch
+
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+ class SiTPipeline(DiffusionPipeline):
+     model_cpu_offload_seq = "transformer->vae"
+
+     def __init__(self, transformer, scheduler, vae):
+         super().__init__()
+         self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
+         self.vae_scale_factor = 8
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         class_labels: Union[int, List[int]] = 207,
+         # Defaults match the 512x512 checkpoint (64x64 latent grid); smaller
+         # sizes would not fit the transformer's fixed positional embedding.
+         height: int = 512,
+         width: int = 512,
+         num_inference_steps: int = 250,
+         guidance_scale: float = 4.0,
+         generator: Optional[torch.Generator] = None,
+         output_type: str = "pil",
+         return_dict: bool = True,
+     ):
+         device = self._execution_device
+         if isinstance(class_labels, int):
+             class_labels = [class_labels]
+         batch_size = len(class_labels)
+
+         latent_h = height // self.vae_scale_factor
+         latent_w = width // self.vae_scale_factor
+         latents = randn_tensor(
+             (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
+             generator=generator,
+             device=device,
+             dtype=self.transformer.dtype,
+         )
+
+         labels = torch.tensor(class_labels, device=device, dtype=torch.long)
+         do_cfg = guidance_scale is not None and guidance_scale > 1.0
+         if do_cfg:
+             null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
+             labels = torch.cat([labels, null_label], dim=0)
+
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         for t in self.progress_bar(timesteps):
+             t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
+             model_input = latents
+             if do_cfg:
+                 model_input = torch.cat([latents, latents], dim=0)
+                 t_batch = torch.cat([t_batch, t_batch], dim=0)
+
+             model_pred = self.transformer(
+                 hidden_states=model_input,
+                 timestep=t_batch,
+                 class_labels=labels,
+             ).sample
+
+             if do_cfg:
+                 cond, uncond = model_pred.chunk(2, dim=0)
+                 model_pred = uncond + guidance_scale * (cond - uncond)
+
+             latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample
+
+         image = self.vae.decode(latents / 0.18215).sample
+         # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
+         if output_type != "pt":
+             image = self.image_processor.postprocess(image, output_type=output_type)
+
+         if not return_dict:
+             return (image,)
+         return ImagePipelineOutput(images=image)
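+
+ # Note on classifier-free guidance: the unconditional branch reuses the extra
+ # embedding row at index `num_classes` as a learned null label (see
+ # LabelEmbedder in transformer_sit.py). Passing guidance_scale <= 1.0 skips
+ # CFG entirely and runs a single conditional batch.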
SiT-XL-2-512-diffusers/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_class_name": "SiTFlowMatchScheduler",
+   "_diffusers_version": "0.36.0",
+   "diffusion_form": "sigma",
+   "diffusion_norm": 1.0,
+   "mode": "ode",
+   "num_train_timesteps": 1000,
+   "shift": 1.0
+ }
SiT-XL-2-512-diffusers/scheduler/scheduling_flow_match_sit.py ADDED
@@ -0,0 +1,98 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ import torch
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class SiTFlowMatchSchedulerOutput(BaseOutput):
+     prev_sample: torch.Tensor
+
+
+ class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
+     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         mode: str = "ode",
+         num_train_timesteps: int = 1000,
+         shift: float = 1.0,
+         diffusion_form: str = "sigma",
+         diffusion_norm: float = 1.0,
+     ):
+         self.timesteps = None
+         self.sigmas = None
+         self._step_index = None
+
+     def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
+         # Flow matching integrates from noise (t=0) to data (t=1).
+         ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
+         self.timesteps = ts[:-1]
+         self.sigmas = 1.0 - self.timesteps
+         self._step_index = 0
+         return self.timesteps
+
+     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
+         return sample
+
+     def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
+         form = self.config.diffusion_form
+         norm = self.config.diffusion_norm
+         if form == "constant":
+             return torch.full_like(t, norm)
+         if form == "sigma":
+             return norm * (1.0 - t)
+         if form == "linear":
+             return norm * (1.0 - t)
+         if form == "decreasing":
+             return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
+         if form == "increasing-decreasing":
+             return norm * torch.sin(torch.pi * t) ** 2
+         # "SBDM" approximated with sigma-based schedule for compatibility.
+         return norm * (1.0 - t)
+
+     def step(
+         self,
+         model_output: torch.Tensor,
+         timestep: Union[float, torch.Tensor],
+         sample: torch.Tensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
+         if self.timesteps is None:
+             raise ValueError("Call `set_timesteps` before `step`.")
+         if self._step_index is None:
+             self._step_index = 0
+
+         step_index = min(self._step_index, len(self.timesteps) - 1)
+         t = self.timesteps[step_index].to(sample.device)
+         next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
+         dt = next_t - t
+
+         prev_sample = sample + model_output * dt
+         if self.config.mode.lower() == "sde":
+             diffusion = self._diffusion(torch.full((sample.shape[0],), t, device=sample.device, dtype=sample.dtype))
+             while diffusion.dim() < sample.dim():
+                 diffusion = diffusion.unsqueeze(-1)
+             noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
+             prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise
+
+         self._step_index += 1
+         if not return_dict:
+             return (prev_sample,)
+         return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
+
+     def add_noise(
+         self,
+         original_samples: torch.Tensor,
+         noise: torch.Tensor,
+         timesteps: torch.Tensor,
+     ) -> torch.Tensor:
+         sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
+         return (1 - sigma) * original_samples + sigma * noise
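+
+ # `step` advances an internal counter, so the `timestep` argument is accepted
+ # only for API compatibility with other diffusers schedulers.
+ # Minimal sketch (assuming a pipeline `pipe` built from this repo): enable the
+ # stochastic SDE sampler by rebuilding the scheduler with `mode` overridden:
+ #   pipe.scheduler = SiTFlowMatchScheduler.from_config(pipe.scheduler.config, mode="sde")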
SiT-XL-2-512-diffusers/transformer/config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "_class_name": "SiTTransformer2DModel",
+   "_diffusers_version": "0.36.0",
+   "class_dropout_prob": 0.1,
+   "depth": 28,
+   "hidden_size": 1152,
+   "in_channels": 4,
+   "input_size": 64,
+   "learn_sigma": false,
+   "mlp_ratio": 4.0,
+   "num_classes": 1000,
+   "num_heads": 16,
+   "patch_size": 2
+ }
SiT-XL-2-512-diffusers/transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d830587730010dae36882e38da966a784396e4e1b8f7f0997685f23fd63063f
+ size 2704012944
SiT-XL-2-512-diffusers/transformer/transformer_sit.py ADDED
@@ -0,0 +1,224 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import BaseOutput
+
+
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+ @dataclass
+ class SiTTransformer2DModelOutput(BaseOutput):
+     sample: torch.Tensor
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+             device=t.device
+         )
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t: torch.Tensor) -> torch.Tensor:
+         return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
+
+
+ class LabelEmbedder(nn.Module):
+     def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(
+         self,
+         labels: torch.Tensor,
+         train: bool,
+         force_drop_ids: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         return self.embedding_table(labels)
+
+
+ class SiTBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class FinalLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x = modulate(self.norm_final(x), shift, scale)
+         return self.linear(x)
+
+
+ class SiTTransformer2DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         input_size: int = 32,
+         patch_size: int = 2,
+         in_channels: int = 4,
+         hidden_size: int = 1152,
+         depth: int = 28,
+         num_heads: int = 16,
+         mlp_ratio: float = 4.0,
+         class_dropout_prob: float = 0.1,
+         num_classes: int = 1000,
+         learn_sigma: bool = True,
+     ):
+         super().__init__()
+         self.learn_sigma = learn_sigma
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+
+         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
+         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+         self.initialize_weights()
+
+     def initialize_weights(self) -> None:
+         def _basic_init(module: nn.Module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         self.apply(_basic_init)
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+         for block in self.blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+
+     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0]
+         h = w = int(x.shape[1] ** 0.5)
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+         x = torch.einsum("nhwpqc->nchpwq", x)
+         return x.reshape(shape=(x.shape[0], c, h * p, w * p))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         class_labels: torch.Tensor,
+         force_drop_ids: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> SiTTransformer2DModelOutput:
+         x = self.x_embedder(hidden_states) + self.pos_embed
+         t = self.t_embedder(timestep)
+         y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
+         c = t + y
+         for block in self.blocks:
+             x = block(x, c)
+         x = self.final_layer(x, c)
+         x = self.unpatchify(x)
+         if self.learn_sigma:
+             x, _ = x.chunk(2, dim=1)
+         if not return_dict:
+             return (x,)
+         return SiTTransformer2DModelOutput(sample=x)
+
+
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)
+     grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
+     assert embed_dim % 2 == 0
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
+     return np.concatenate([emb_h, emb_w], axis=1)
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega
+     pos = pos.reshape(-1)
+     out = np.einsum("m,d->md", pos, omega)
+     emb_sin = np.sin(out)
+     emb_cos = np.cos(out)
+     return np.concatenate([emb_sin, emb_cos], axis=1)
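+
+ # Note: this directory's transformer/config.json overrides two constructor
+ # defaults: input_size=64 (512x512 images with the 8x VAE) and learn_sigma=false,
+ # i.e. the checkpoint is equivalent to
+ #   SiTTransformer2DModel(input_size=64, learn_sigma=False)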
SiT-XL-2-512-diffusers/vae/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.36.0",
+   "_name_or_path": "stabilityai/sd-vae-ft-mse",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "mid_block_add_attention": true,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 256,
+   "scaling_factor": 0.18215,
+   "shift_factor": null,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ],
+   "use_post_quant_conv": true,
+   "use_quant_conv": true
+ }