Add files using upload-large-folder tool
Browse files- README.md +132 -0
- SiT-B-2-256-diffusers/README.md +42 -0
- SiT-B-2-256-diffusers/model_index.json +19 -0
- SiT-B-2-256-diffusers/pipeline.py +82 -0
- SiT-B-2-256-diffusers/scheduler/scheduler_config.json +9 -0
- SiT-B-2-256-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
- SiT-B-2-256-diffusers/transformer/config.json +14 -0
- SiT-B-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors +3 -0
- SiT-B-2-256-diffusers/transformer/transformer_sit.py +224 -0
- SiT-B-2-256-diffusers/vae/config.json +38 -0
- SiT-B-2-256-diffusers/vae/diffusion_pytorch_model.safetensors +3 -0
- SiT-L-2-256-diffusers/README.md +42 -0
- SiT-L-2-256-diffusers/model_index.json +19 -0
- SiT-L-2-256-diffusers/pipeline.py +82 -0
- SiT-L-2-256-diffusers/scheduler/scheduler_config.json +9 -0
- SiT-L-2-256-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
- SiT-L-2-256-diffusers/transformer/config.json +14 -0
- SiT-L-2-256-diffusers/transformer/transformer_sit.py +224 -0
- SiT-L-2-256-diffusers/vae/config.json +38 -0
- SiT-L-2-256-diffusers/vae/diffusion_pytorch_model.safetensors +3 -0
- SiT-S-2-256-diffusers/README.md +42 -0
- SiT-S-2-256-diffusers/model_index.json +19 -0
- SiT-S-2-256-diffusers/pipeline.py +82 -0
- SiT-S-2-256-diffusers/scheduler/scheduler_config.json +9 -0
- SiT-S-2-256-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
- SiT-S-2-256-diffusers/transformer/config.json +14 -0
- SiT-S-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors +3 -0
- SiT-S-2-256-diffusers/transformer/transformer_sit.py +224 -0
- SiT-S-2-256-diffusers/vae/config.json +38 -0
- SiT-S-2-256-diffusers/vae/diffusion_pytorch_model.safetensors +3 -0
- SiT-XL-2-256-diffusers/README.md +42 -0
- SiT-XL-2-256-diffusers/demo_50steps.png +0 -0
- SiT-XL-2-256-diffusers/model_index.json +19 -0
- SiT-XL-2-256-diffusers/pipeline.py +82 -0
- SiT-XL-2-256-diffusers/scheduler/scheduler_config.json +9 -0
- SiT-XL-2-256-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
- SiT-XL-2-256-diffusers/transformer/config.json +14 -0
- SiT-XL-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors +3 -0
- SiT-XL-2-256-diffusers/transformer/transformer_sit.py +224 -0
- SiT-XL-2-256-diffusers/vae/config.json +38 -0
- SiT-XL-2-256-diffusers/vae/diffusion_pytorch_model.safetensors +3 -0
- SiT-XL-2-512-diffusers/README.md +42 -0
- SiT-XL-2-512-diffusers/model_index.json +19 -0
- SiT-XL-2-512-diffusers/pipeline.py +82 -0
- SiT-XL-2-512-diffusers/scheduler/scheduler_config.json +9 -0
- SiT-XL-2-512-diffusers/scheduler/scheduling_flow_match_sit.py +98 -0
- SiT-XL-2-512-diffusers/transformer/config.json +14 -0
- SiT-XL-2-512-diffusers/transformer/diffusion_pytorch_model.safetensors +3 -0
- SiT-XL-2-512-diffusers/transformer/transformer_sit.py +224 -0
- SiT-XL-2-512-diffusers/vae/config.json +38 -0
README.md
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: diffusers
|
| 3 |
+
pipeline_tag: unconditional-image-generation
|
| 4 |
+
tags:
|
| 5 |
+
- diffusers
|
| 6 |
+
- sit
|
| 7 |
+
- image-generation
|
| 8 |
+
- class-conditional
|
| 9 |
+
- imagenet
|
| 10 |
+
license: mit
|
| 11 |
+
inference: true
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# SiT-diffusers
|
| 15 |
+
|
| 16 |
+
Diffusers-ready checkpoints for **Scalable Interpolant Transformers (SiT)**, converted for local/offline use.
|
| 17 |
+
|
| 18 |
+
This root folder is a model collection that contains:
|
| 19 |
+
|
| 20 |
+
- `SiT-S-2-256-diffusers`
|
| 21 |
+
- `SiT-B-2-256-diffusers`
|
| 22 |
+
- `SiT-L-2-256-diffusers`
|
| 23 |
+
- `SiT-XL-2-256-diffusers`
|
| 24 |
+
- `SiT-XL-2-512-diffusers`
|
| 25 |
+
|
| 26 |
+
Each subfolder is a self-contained Diffusers model repo with:
|
| 27 |
+
|
| 28 |
+
- `pipeline.py`
|
| 29 |
+
- `transformer/transformer_sit.py`
|
| 30 |
+
- `scheduler/scheduling_flow_match_sit.py`
|
| 31 |
+
- `transformer/diffusion_pytorch_model.safetensors`
|
| 32 |
+
- `vae/diffusion_pytorch_model.safetensors`
|
| 33 |
+
|
| 34 |
+
## Model Paths
|
| 35 |
+
|
| 36 |
+
Use paths relative to this root README:
|
| 37 |
+
|
| 38 |
+
| Model | Resolution | Local path |
|
| 39 |
+
|---|---:|---|
|
| 40 |
+
| SiT-S/2 | 256x256 | `./SiT-S-2-256-diffusers` |
|
| 41 |
+
| SiT-B/2 | 256x256 | `./SiT-B-2-256-diffusers` |
|
| 42 |
+
| SiT-L/2 | 256x256 | `./SiT-L-2-256-diffusers` |
|
| 43 |
+
| SiT-XL/2 | 256x256 | `./SiT-XL-2-256-diffusers` |
|
| 44 |
+
| SiT-XL/2 | 512x512 | `./SiT-XL-2-512-diffusers` |
|
| 45 |
+
|
| 46 |
+
## Inference Demo (Diffusers)
|
| 47 |
+
|
| 48 |
+
### 1) Load a local subfolder checkpoint
|
| 49 |
+
|
| 50 |
+
```python
|
| 51 |
+
import torch
|
| 52 |
+
from diffusers import DiffusionPipeline
|
| 53 |
+
|
| 54 |
+
model_path = "./SiT-XL-2-512-diffusers" # change to any path in the table above
|
| 55 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 56 |
+
|
| 57 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 58 |
+
model_path,
|
| 59 |
+
trust_remote_code=True,
|
| 60 |
+
).to(device)
|
| 61 |
+
|
| 62 |
+
generator = torch.Generator(device=device).manual_seed(0)
|
| 63 |
+
|
| 64 |
+
# ImageNet class example: 207 = golden retriever
|
| 65 |
+
result = pipe(
|
| 66 |
+
class_labels=207,
|
| 67 |
+
height=512,
|
| 68 |
+
width=512,
|
| 69 |
+
num_inference_steps=250, # official SiT comparisons commonly use 250 steps
|
| 70 |
+
guidance_scale=4.0,
|
| 71 |
+
generator=generator,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
image = result.images[0]
|
| 75 |
+
image.save("sit_xl_512_demo.png")
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### 2) Quick variant switch (256 models)
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
model_path = "./SiT-S-2-256-diffusers"
|
| 82 |
+
# model_path = "./SiT-B-2-256-diffusers"
|
| 83 |
+
# model_path = "./SiT-L-2-256-diffusers"
|
| 84 |
+
# model_path = "./SiT-XL-2-256-diffusers"
|
| 85 |
+
|
| 86 |
+
pipe = DiffusionPipeline.from_pretrained(model_path, trust_remote_code=True).to(device)
|
| 87 |
+
image = pipe(
|
| 88 |
+
class_labels=207,
|
| 89 |
+
height=256,
|
| 90 |
+
width=256,
|
| 91 |
+
num_inference_steps=250,
|
| 92 |
+
guidance_scale=4.0,
|
| 93 |
+
generator=generator,
|
| 94 |
+
).images[0]
|
| 95 |
+
image.save("sit_256_demo.png")
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
## FID Reference (from Official SiT Results)
|
| 99 |
+
|
| 100 |
+
The table below summarizes widely cited SiT numbers from the official project materials for class-conditional ImageNet generation.
|
| 101 |
+
|
| 102 |
+
| Model / setting | Resolution | FID-50K (lower is better) |
|
| 103 |
+
|---|---:|---:|
|
| 104 |
+
| SiT-S (400K steps) | 256x256 | 57.6 |
|
| 105 |
+
| SiT-B (400K steps) | 256x256 | 33.5 |
|
| 106 |
+
| SiT-L (400K steps) | 256x256 | 17.2 |
|
| 107 |
+
| SiT-XL (400K steps) | 256x256 | 8.6 |
|
| 108 |
+
| SiT-XL (cfg=1.5, ODE) | 256x256 | 2.15 |
|
| 109 |
+
| SiT-XL (cfg=1.5, SDE, `w(t)=sigma_t`) | 256x256 | 2.06 |
|
| 110 |
+
| SiT-XL (sample showcase) | 512x512 | Not reported in the same benchmark table |
|
| 111 |
+
|
| 112 |
+
> Note: FID depends on training recipe, sampler choice (ODE/SDE), guidance scale, and evaluation protocol. Treat this table as a reference to official SiT reports, not as guaranteed reproducibility for every conversion/export.
|
| 113 |
+
|
| 114 |
+
## Source and Paper
|
| 115 |
+
|
| 116 |
+
- Official SiT code: [willisma/SiT](https://github.com/willisma/SiT)
|
| 117 |
+
- Project page: [scalable-interpolant.github.io](https://scalable-interpolant.github.io/)
|
| 118 |
+
- Paper (arXiv): [SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers](https://arxiv.org/abs/2401.08740)
|
| 119 |
+
|
| 120 |
+
## Citation
|
| 121 |
+
|
| 122 |
+
If you use SiT in your work, please cite:
|
| 123 |
+
|
| 124 |
+
```bibtex
|
| 125 |
+
@inproceedings{ma2024sit,
|
| 126 |
+
title={SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers},
|
| 127 |
+
author={Ma, Nanye and Goldstein, Mark and Albergo, Michael S. and Boffi, Nicholas M. and Vanden-Eijnden, Eric and Xie, Saining},
|
| 128 |
+
booktitle={European Conference on Computer Vision (ECCV)},
|
| 129 |
+
year={2024},
|
| 130 |
+
note={Accepted to ECCV 2024}
|
| 131 |
+
}
|
| 132 |
+
```
|
SiT-B-2-256-diffusers/README.md
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: diffusers
|
| 3 |
+
pipeline_tag: unconditional-image-generation
|
| 4 |
+
tags:
|
| 5 |
+
- diffusers
|
| 6 |
+
- sit
|
| 7 |
+
- image-generation
|
| 8 |
+
- class-conditional
|
| 9 |
+
inference: true
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# SiT-B-2-256-diffusers
|
| 13 |
+
|
| 14 |
+
Self-contained Diffusers checkpoint repo for SiT.
|
| 15 |
+
|
| 16 |
+
## Usage
|
| 17 |
+
|
| 18 |
+
```python
|
| 19 |
+
import torch
|
| 20 |
+
from diffusers import DiffusionPipeline
|
| 21 |
+
|
| 22 |
+
pipe = DiffusionPipeline.from_pretrained("./").to("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
generator = torch.Generator(device=pipe.device).manual_seed(0)
|
| 24 |
+
|
| 25 |
+
image = pipe(
|
| 26 |
+
class_labels=207,
|
| 27 |
+
height=256,
|
| 28 |
+
width=256,
|
| 29 |
+
num_inference_steps=250,
|
| 30 |
+
guidance_scale=4.0,
|
| 31 |
+
generator=generator,
|
| 32 |
+
).images[0]
|
| 33 |
+
image.save("demo.png")
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Components
|
| 37 |
+
|
| 38 |
+
- `pipeline.py`
|
| 39 |
+
- `transformer/transformer_sit.py`
|
| 40 |
+
- `scheduler/scheduling_flow_match_sit.py`
|
| 41 |
+
- `transformer/diffusion_pytorch_model.safetensors`
|
| 42 |
+
- `vae/diffusion_pytorch_model.safetensors`
|
SiT-B-2-256-diffusers/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"SiTPipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_flow_match_sit",
|
| 9 |
+
"SiTFlowMatchScheduler"
|
| 10 |
+
],
|
| 11 |
+
"transformer": [
|
| 12 |
+
"transformer_sit",
|
| 13 |
+
"SiTTransformer2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
SiT-B-2-256-diffusers/pipeline.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 6 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 7 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SiTPipeline(DiffusionPipeline):
    """Class-conditional image generation pipeline for SiT (Scalable Interpolant
    Transformers).

    Samples latents from Gaussian noise, integrates the transformer's velocity
    field with the flow-match scheduler (optionally with classifier-free
    guidance), and decodes the result with the VAE.
    """

    # Offload order for `enable_model_cpu_offload`: transformer runs first, VAE last.
    model_cpu_offload_seq = "transformer->vae"

    def __init__(self, transformer, scheduler, vae):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        # Hard-coded 8x spatial downsampling of the SD-style VAE
        # (assumes the registered VAE matches — TODO confirm for other VAEs).
        self.vae_scale_factor = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, List[int]] = 207,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 250,
        guidance_scale: float = 4.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate images for the given ImageNet class labels.

        Args:
            class_labels: Single class index or a list (one image per label).
            height / width: Output resolution in pixels; must be divisible by 8.
            num_inference_steps: Number of Euler integration steps.
            guidance_scale: CFG scale; values <= 1.0 disable guidance.
            generator: Optional RNG for reproducible sampling.
            output_type: "pil", "np", or "pt" (raw VAE output in [-1, 1]).
            return_dict: If False, return a plain tuple instead of
                `ImagePipelineOutput`.
        """
        device = self._execution_device
        if isinstance(class_labels, int):
            class_labels = [class_labels]
        batch_size = len(class_labels)

        # Initial latents are pure noise; flow matching integrates noise -> data.
        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor
        latents = randn_tensor(
            (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        labels = torch.tensor(class_labels, device=device, dtype=torch.long)
        do_cfg = guidance_scale is not None and guidance_scale > 1.0
        if do_cfg:
            # Index `num_classes` is the learned null/unconditional embedding
            # (see LabelEmbedder); append one null label per sample.
            null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
            labels = torch.cat([labels, null_label], dim=0)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        for t in self.progress_bar(timesteps):
            t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
            model_input = latents
            if do_cfg:
                # Conditional and unconditional branches share the same latents;
                # the label batch already holds [cond_labels, null_labels].
                model_input = torch.cat([latents, latents], dim=0)
                t_batch = torch.cat([t_batch, t_batch], dim=0)

            model_pred = self.transformer(
                hidden_states=model_input,
                timestep=t_batch,
                class_labels=labels,
            ).sample

            if do_cfg:
                # Standard CFG: extrapolate from uncond toward cond.
                cond, uncond = model_pred.chunk(2, dim=0)
                model_pred = uncond + guidance_scale * (cond - uncond)

            latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample

        # Undo the SD-VAE latent scaling (0.18215) before decoding.
        image = self.vae.decode(latents / 0.18215).sample
        # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
        if output_type == "pt":
            image = image
        else:
            image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
|
SiT-B-2-256-diffusers/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTFlowMatchScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"diffusion_form": "sigma",
|
| 5 |
+
"diffusion_norm": 1.0,
|
| 6 |
+
"mode": "ode",
|
| 7 |
+
"num_train_timesteps": 1000,
|
| 8 |
+
"shift": 1.0
|
| 9 |
+
}
|
SiT-B-2-256-diffusers/scheduler/scheduling_flow_match_sit.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 7 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class SiTFlowMatchSchedulerOutput(BaseOutput):
    """Output of `SiTFlowMatchScheduler.step`."""

    # Sample advanced by one Euler (or Euler–Maruyama) integration step.
    prev_sample: torch.Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
    """First-order flow-matching scheduler for SiT.

    Integrates a velocity-field prediction from noise (t=0) to data (t=1)
    with plain Euler steps ("ode" mode) or Euler–Maruyama steps with a
    configurable diffusion coefficient ("sde" mode).
    """

    # NOTE(review): advertises compatibility with all Karras schedulers, which
    # a flow-matching scheduler does not truly share — kept for API parity.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # First-order (single model evaluation per step).
    order = 1

    @register_to_config
    def __init__(
        self,
        mode: str = "ode",
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        diffusion_form: str = "sigma",
        diffusion_norm: float = 1.0,
    ):
        # All constructor args are captured into `self.config` by
        # `@register_to_config`; only runtime state is stored here.
        self.timesteps = None
        self.sigmas = None
        self._step_index = None

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """Build an evenly spaced time grid on [0, 1) and reset the step counter."""
        # Flow matching integrates from noise (t=0) to data (t=1).
        ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
        self.timesteps = ts[:-1]
        # Noise level decreases linearly as t approaches the data end.
        self.sigmas = 1.0 - self.timesteps
        self._step_index = 0
        return self.timesteps

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """No input scaling is needed for flow matching; identity."""
        return sample

    def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
        """Diffusion coefficient w(t) for SDE sampling, per `config.diffusion_form`."""
        form = self.config.diffusion_form
        norm = self.config.diffusion_norm
        if form == "constant":
            return torch.full_like(t, norm)
        if form == "sigma":
            # w(t) = sigma_t = 1 - t under the linear interpolant.
            return norm * (1.0 - t)
        if form == "linear":
            # Identical to "sigma" here; kept as a separate alias.
            return norm * (1.0 - t)
        if form == "decreasing":
            return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
        if form == "increasing-decreasing":
            return norm * torch.sin(torch.pi * t) ** 2
        # "SBDM" approximated with sigma-based schedule for compatibility.
        return norm * (1.0 - t)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
        """Advance `sample` by one integration step.

        NOTE(review): the `timestep` argument is ignored; the scheduler tracks
        its own `_step_index`, so `step` must be called once per timestep in
        order after `set_timesteps`.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if self._step_index is None:
            self._step_index = 0

        step_index = min(self._step_index, len(self.timesteps) - 1)
        t = self.timesteps[step_index].to(sample.device)
        # The final step integrates up to t=1 (the data endpoint).
        next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
        dt = next_t - t

        # Euler step on the predicted velocity field.
        prev_sample = sample + model_output * dt
        if self.config.mode.lower() == "sde":
            # Euler–Maruyama: add noise scaled by sqrt(2 * w(t) * dt).
            diffusion = self._diffusion(torch.full((sample.shape[0],), t, device=sample.device, dtype=sample.dtype))
            while diffusion.dim() < sample.dim():
                diffusion = diffusion.unsqueeze(-1)
            noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
            prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise

        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Linearly interpolate data and noise: x_t = t * x_0 + (1 - t) * eps."""
        # sigma = 1 - t, broadcast over all trailing dims of the sample.
        sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
        return (1 - sigma) * original_samples + sigma * noise
|
SiT-B-2-256-diffusers/transformer/config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"class_dropout_prob": 0.1,
|
| 5 |
+
"depth": 12,
|
| 6 |
+
"hidden_size": 768,
|
| 7 |
+
"in_channels": 4,
|
| 8 |
+
"input_size": 32,
|
| 9 |
+
"learn_sigma": true,
|
| 10 |
+
"mlp_ratio": 4.0,
|
| 11 |
+
"num_classes": 1000,
|
| 12 |
+
"num_heads": 12,
|
| 13 |
+
"patch_size": 2
|
| 14 |
+
}
|
SiT-B-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be5318b2de818389f53e49ce39495394a94e20b0449041e0f9a6ced1ccc64f6c
|
| 3 |
+
size 522062536
|
SiT-B-2-256-diffusers/transformer/transformer_sit.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
|
| 9 |
+
|
| 10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 12 |
+
from diffusers.utils import BaseOutput
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply adaLN modulation: per-channel scale (offset by 1) then shift,
    broadcast across the token dimension."""
    scaled = x * (scale.unsqueeze(1) + 1)
    return scaled + shift.unsqueeze(1)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class SiTTransformer2DModelOutput(BaseOutput):
    """Output of `SiTTransformer2DModel.forward`."""

    # Predicted velocity field, shape (batch, in_channels, H, W).
    sample: torch.Tensor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps via sinusoidal features followed by a 2-layer MLP."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Sinusoidal features of `t`: [cos | sin] halves, zero-padded if `dim` is odd."""
        half = dim // 2
        exponents = torch.arange(half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        features = torch.cat([angles.cos(), angles.sin()], dim=-1)
        if dim % 2:
            pad = torch.zeros_like(features[:, :1])
            features = torch.cat([features, pad], dim=-1)
        return features

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LabelEmbedder(nn.Module):
    """Embed class labels, with label dropout for classifier-free guidance.

    When ``dropout_prob > 0`` an extra embedding row at index ``num_classes``
    serves as the learned null/unconditional token.
    """

    def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
        super().__init__()
        has_null_token = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + has_null_token, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Replace dropped labels with the null-token index."""
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(
        self,
        labels: torch.Tensor,
        train: bool,
        force_drop_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        wants_random_drop = train and self.dropout_prob > 0
        if wants_random_drop or force_drop_ids is not None:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SiTBlock(nn.Module):
    """SiT transformer block: adaLN-modulated self-attention and MLP, each
    applied through a gated residual connection conditioned on `c`."""

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
        super().__init__()
        # Affine-free norms: scale/shift come from adaLN modulation instead.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0,
        )
        # Produces 6 modulation vectors: shift/scale/gate for attn and MLP.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x + gate_mlp.unsqueeze(1) * mlp_out
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class FinalLayer(nn.Module):
    """Final projection of SiT: adaLN-modulated norm followed by a linear map
    from hidden tokens to flattened patch pixels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SiTTransformer2DModel(ModelMixin, ConfigMixin):
    """SiT backbone: a DiT-style transformer over latent patches, conditioned
    on timestep and class-label embeddings via adaLN modulation.

    Predicts a velocity field over the input latents; when `learn_sigma` is
    True the network outputs 2x channels and the second half is discarded at
    inference (see `forward`).
    """

    @register_to_config
    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1152,
        depth: int = 28,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        class_dropout_prob: float = 0.1,
        num_classes: int = 1000,
        learn_sigma: bool = True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # Double channels when a variance head is learned alongside the mean.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_classes = num_classes

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Fixed (non-trainable) 2D sin/cos positional embedding, filled below.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """DiT-style init: Xavier linears, sin/cos pos-embed, zeroed adaLN and
        final projection so each block starts as (near) identity."""
        def _basic_init(module: nn.Module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        # Assumes a square patch grid (num_patches is a perfect square).
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Patch projection initialized like a linear layer (flattened kernel).
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-init modulation layers so gates start closed (residual identity).
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Reassemble (N, T, p*p*C) patch tokens into (N, C, H, W) images.

        Assumes a square token grid (T must be a perfect square).
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        # NOTE(review): uses `h * p` for both spatial dims; correct only
        # because h == w above.
        return x.reshape(shape=(x.shape[0], c, h * p, h * p))

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        force_drop_ids: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> SiTTransformer2DModelOutput:
        """Run the backbone.

        Args:
            hidden_states: Latents of shape (N, in_channels, H, W).
            timestep: Per-sample scalar timesteps, shape (N,).
            class_labels: Per-sample class indices, shape (N,).
            force_drop_ids: Optional mask forcing label dropout (1 = drop).
            return_dict: If False, return a plain tuple.
        """
        x = self.x_embedder(hidden_states) + self.pos_embed
        t = self.t_embedder(timestep)
        y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
        # Conditioning vector shared by all blocks.
        c = t + y
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x)
        if self.learn_sigma:
            # Keep the mean/velocity half; drop the learned-sigma half.
            x, _ = x.chunk(2, dim=1)
        if not return_dict:
            return (x,)
        return SiTTransformer2DModelOutput(sample=x)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
    """Build a (grid_size**2, embed_dim) 2D sin-cos positional embedding table.

    When `cls_token` is set and `extra_tokens > 0`, that many all-zero rows
    are prepended for the extra tokens.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh_w, mesh_h = np.meshgrid(axis, axis)  # meshgrid(grid_w, grid_h) ordering
    grid = np.stack([mesh_w, mesh_h], axis=0).reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pad = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([pad, pos_embed], axis=0)
    return pos_embed
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
    """Encode each grid axis with half the channels, then concatenate (H half first)."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    parts = [get_1d_sincos_pos_embed_from_grid(half, axis) for axis in (grid[0], grid[1])]
    return np.concatenate(parts, axis=1)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
    """Classic transformer sin-cos encoding of scalar positions.

    Returns an array of shape (len(pos), embed_dim): the first half of the
    channels are sines, the second half cosines, over frequencies
    10000**(-2i/embed_dim).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, half)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
SiT-B-2-256-diffusers/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
SiT-B-2-256-diffusers/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
|
| 3 |
+
size 334643268
|
SiT-L-2-256-diffusers/README.md
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: diffusers
|
| 3 |
+
pipeline_tag: unconditional-image-generation
|
| 4 |
+
tags:
|
| 5 |
+
- diffusers
|
| 6 |
+
- sit
|
| 7 |
+
- image-generation
|
| 8 |
+
- class-conditional
|
| 9 |
+
inference: true
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# SiT-L-2-256-diffusers
|
| 13 |
+
|
| 14 |
+
Self-contained Diffusers checkpoint repository for SiT, providing class-conditional 256×256 image generation with a bundled custom pipeline, transformer, flow-matching scheduler, and VAE.
|
| 15 |
+
|
| 16 |
+
## Usage
|
| 17 |
+
|
| 18 |
+
```python
|
| 19 |
+
import torch
|
| 20 |
+
from diffusers import DiffusionPipeline
|
| 21 |
+
|
| 22 |
+
pipe = DiffusionPipeline.from_pretrained("./").to("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
generator = torch.Generator(device=pipe.device).manual_seed(0)
|
| 24 |
+
|
| 25 |
+
image = pipe(
|
| 26 |
+
class_labels=207,
|
| 27 |
+
height=256,
|
| 28 |
+
width=256,
|
| 29 |
+
num_inference_steps=250,
|
| 30 |
+
guidance_scale=4.0,
|
| 31 |
+
generator=generator,
|
| 32 |
+
).images[0]
|
| 33 |
+
image.save("demo.png")
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Components
|
| 37 |
+
|
| 38 |
+
- `pipeline.py`
|
| 39 |
+
- `transformer/transformer_sit.py`
|
| 40 |
+
- `scheduler/scheduling_flow_match_sit.py`
|
| 41 |
+
- `transformer/diffusion_pytorch_model.safetensors`
|
| 42 |
+
- `vae/diffusion_pytorch_model.safetensors`
|
SiT-L-2-256-diffusers/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"SiTPipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_flow_match_sit",
|
| 9 |
+
"SiTFlowMatchScheduler"
|
| 10 |
+
],
|
| 11 |
+
"transformer": [
|
| 12 |
+
"transformer_sit",
|
| 13 |
+
"SiTTransformer2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
SiT-L-2-256-diffusers/pipeline.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 6 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 7 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SiTPipeline(DiffusionPipeline):
    """Class-conditional latent flow-matching pipeline for SiT.

    Samples Gaussian latents, integrates the model's prediction with the
    flow-matching scheduler (optionally with classifier-free guidance), and
    decodes the result with the VAE.
    """

    model_cpu_offload_seq = "transformer->vae"

    def __init__(self, transformer, scheduler, vae):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        # Derive the spatial down-sampling factor from the VAE config instead
        # of hard-coding 8; fall back to 8 (the SD-VAE value) if unavailable.
        try:
            self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        except (AttributeError, TypeError):
            self.vae_scale_factor = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, List[int]] = 207,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 250,
        guidance_scale: float = 4.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate one image per class label.

        Args:
            class_labels: Single label or list of labels (one image each).
            height / width: Output resolution in pixels (multiples of the VAE
                scale factor).
            num_inference_steps: Number of integration steps.
            guidance_scale: Classifier-free guidance strength; values <= 1
                (or None) disable CFG.
            generator: Optional RNG for reproducible sampling.
            output_type: "pil", "np", or "pt" ("pt" stays in raw VAE range).
            return_dict: Return an `ImagePipelineOutput` instead of a tuple.
        """
        device = self._execution_device
        if isinstance(class_labels, int):
            class_labels = [class_labels]
        batch_size = len(class_labels)

        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor
        latents = randn_tensor(
            (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        labels = torch.tensor(class_labels, device=device, dtype=torch.long)
        do_cfg = guidance_scale is not None and guidance_scale > 1.0
        if do_cfg:
            # Embedding index `num_classes` is the learned null/CFG token.
            null_label = torch.full(
                (batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long
            )
            labels = torch.cat([labels, null_label], dim=0)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        for t in self.progress_bar(timesteps):
            t_batch = torch.full((batch_size,), t, device=device, dtype=latents.dtype)
            model_input = latents
            if do_cfg:
                # Duplicate the batch: first half conditional, second half null.
                model_input = torch.cat([latents, latents], dim=0)
                t_batch = torch.cat([t_batch, t_batch], dim=0)

            model_pred = self.transformer(
                hidden_states=model_input,
                timestep=t_batch,
                class_labels=labels,
            ).sample

            if do_cfg:
                cond, uncond = model_pred.chunk(2, dim=0)
                model_pred = uncond + guidance_scale * (cond - uncond)

            latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample

        # Use the VAE's configured scaling factor rather than the hard-coded
        # SD constant, so non-default VAEs decode correctly. The fallback
        # preserves the previous behavior.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        image = self.vae.decode(latents / scaling_factor).sample
        # Keep "pt" outputs in the raw VAE range [-1, 1] to match the original
        # SiT scripts; other output types are post-processed by the image processor.
        if output_type != "pt":
            image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
|
SiT-L-2-256-diffusers/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTFlowMatchScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"diffusion_form": "sigma",
|
| 5 |
+
"diffusion_norm": 1.0,
|
| 6 |
+
"mode": "ode",
|
| 7 |
+
"num_train_timesteps": 1000,
|
| 8 |
+
"shift": 1.0
|
| 9 |
+
}
|
SiT-L-2-256-diffusers/scheduler/scheduling_flow_match_sit.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 7 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class SiTFlowMatchSchedulerOutput(BaseOutput):
    """Output of `SiTFlowMatchScheduler.step`."""

    # Sample advanced by one integration step (x at t + dt).
    prev_sample: torch.Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
    """Euler integrator for SiT flow matching.

    Integrates the model prediction from noise (t=0) to data (t=1). In "ode"
    mode each step is a deterministic Euler update; in "sde" mode a diffusion
    term (see `_diffusion`) injects Gaussian noise at every step.
    """

    # NOTE(review): this advertises compatibility with every Karras scheduler,
    # but the [0, 1] flow-matching time grid here differs from their
    # conventions — confirm before relying on drop-in swaps.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        mode: str = "ode",
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        diffusion_form: str = "sigma",
        diffusion_norm: float = 1.0,
    ):
        # `num_train_timesteps` and `shift` are registered in the config but
        # are never read by this implementation.
        self.timesteps = None
        self.sigmas = None
        self._step_index = None

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """Create the integration time grid and reset the internal step counter.

        Produces `num_inference_steps` evenly spaced times in [0, 1) (the final
        endpoint t=1 is handled inside `step`).
        """
        # Flow matching integrates from noise (t=0) to data (t=1).
        ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
        self.timesteps = ts[:-1]
        # sigma = 1 - t: noise level decreases as t approaches the data.
        self.sigmas = 1.0 - self.timesteps
        self._step_index = 0
        return self.timesteps

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """No input scaling is needed for flow matching; returns `sample` unchanged."""
        return sample

    def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the SDE diffusion-coefficient schedule at time `t`."""
        form = self.config.diffusion_form
        norm = self.config.diffusion_norm
        if form == "constant":
            return torch.full_like(t, norm)
        if form == "sigma":
            return norm * (1.0 - t)
        if form == "linear":
            # Identical expression to "sigma" in this implementation.
            return norm * (1.0 - t)
        if form == "decreasing":
            return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
        if form == "increasing-decreasing":
            return norm * torch.sin(torch.pi * t) ** 2
        # "SBDM" approximated with sigma-based schedule for compatibility.
        return norm * (1.0 - t)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
        """Advance `sample` by one Euler (ODE) or Euler-Maruyama (SDE) step.

        Note: the `timestep` argument is accepted for API compatibility only;
        the actual time comes from the internal step counter set by
        `set_timesteps`, so steps must be taken in order.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if self._step_index is None:
            self._step_index = 0

        # Clamp so extra calls past the grid reuse the last interval.
        step_index = min(self._step_index, len(self.timesteps) - 1)
        t = self.timesteps[step_index].to(sample.device)
        # Final step integrates up to the data endpoint t = 1.
        next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
        dt = next_t - t

        prev_sample = sample + model_output * dt
        if self.config.mode.lower() == "sde":
            diffusion = self._diffusion(torch.full((sample.shape[0],), t, device=sample.device, dtype=sample.dtype))
            # Broadcast the per-sample coefficient over spatial dimensions.
            while diffusion.dim() < sample.dim():
                diffusion = diffusion.unsqueeze(-1)
            noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
            # Euler-Maruyama noise increment: sqrt(2 * g(t) * dt) * eps.
            prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise

        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Linearly interpolate between data and noise at the given times.

        `timesteps` are flow times in [0, 1]; sigma = 1 - t, so t = 1 returns
        the clean samples and t = 0 returns pure noise.
        """
        sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
        return (1 - sigma) * original_samples + sigma * noise
|
SiT-L-2-256-diffusers/transformer/config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"class_dropout_prob": 0.1,
|
| 5 |
+
"depth": 24,
|
| 6 |
+
"hidden_size": 1024,
|
| 7 |
+
"in_channels": 4,
|
| 8 |
+
"input_size": 32,
|
| 9 |
+
"learn_sigma": true,
|
| 10 |
+
"mlp_ratio": 4.0,
|
| 11 |
+
"num_classes": 1000,
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"patch_size": 2
|
| 14 |
+
}
|
SiT-L-2-256-diffusers/transformer/transformer_sit.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
|
| 9 |
+
|
| 10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 12 |
+
from diffusers.utils import BaseOutput
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply adaLN-style affine modulation along the token dimension.

    `shift` and `scale` are per-sample vectors of shape (N, C); both are
    broadcast over the sequence axis of `x` (N, T, C).
    """
    shift_b = shift.unsqueeze(1)
    scale_b = scale.unsqueeze(1)
    return shift_b + x * (scale_b + 1)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class SiTTransformer2DModelOutput(BaseOutput):
    """Output of `SiTTransformer2DModel.forward`."""

    # Model prediction with the same spatial shape as the input latents
    # (the extra sigma channels are dropped before this is returned when
    # `learn_sigma` is enabled).
    sample: torch.Tensor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TimestepEmbedder(nn.Module):
    """Embed scalar diffusion times into vectors via sinusoidal features + MLP."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Return (N, dim) sinusoidal features: cosine channels first, then sine."""
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Odd dims get one zero-padded column so the width equals `dim`.
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LabelEmbedder(nn.Module):
    """Embed class labels, with optional label dropout for classifier-free guidance.

    When dropout is enabled, one extra embedding row at index `num_classes`
    acts as the learned "null" (unconditional) token.
    """

    def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Replace a subset of labels with the null index `num_classes`.

        If `force_drop_ids` is given, entries equal to 1 are dropped;
        otherwise each label is dropped independently with `dropout_prob`.
        """
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(
        self,
        labels: torch.Tensor,
        train: bool,
        force_drop_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        wants_drop = force_drop_ids is not None or (train and self.dropout_prob > 0)
        if wants_drop:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SiTBlock(nn.Module):
    """Transformer block with adaLN-Zero conditioning (attention + MLP)."""

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden,
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0,
        )
        # Projects the conditioning vector to shift/scale/gate for both sub-layers.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        params = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x + gate_mlp.unsqueeze(1) * mlp_out
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class FinalLayer(nn.Module):
    """Final adaLN-modulated projection from hidden tokens to patch pixels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        # Two chunks: shift and scale (no gate in the final layer).
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SiTTransformer2DModel(ModelMixin, ConfigMixin):
    """Diffusers-compatible SiT backbone.

    A DiT-style patch transformer over VAE latents, conditioned on a timestep
    and a class label via adaLN-Zero modulation.
    """

    @register_to_config
    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1152,
        depth: int = 28,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        class_dropout_prob: float = 0.1,
        num_classes: int = 1000,
        learn_sigma: bool = True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # With learn_sigma the network predicts twice the channels; the sigma
        # half is discarded in forward().
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_classes = num_classes

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Fixed (non-learned) 2D sin-cos positional embedding.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Apply the DiT/SiT initialization scheme (adaLN-Zero)."""
        def _basic_init(module: nn.Module):
            # Xavier for every linear layer; zero biases.
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        # Fill the frozen positional table with the sin-cos embedding
        # (assumes a square patch grid).
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize the patch-projection conv like a flattened linear layer.
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-init every adaLN projection and the output head so each block
        # starts as (near) identity — the "Zero" in adaLN-Zero.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Fold (N, T, p*p*C) patch tokens back into (N, C, H, W) images.

        Assumes a square token grid (T must be a perfect square).
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        # h == w here, so h * p is also the output width.
        return x.reshape(shape=(x.shape[0], c, h * p, h * p))

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        force_drop_ids: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> SiTTransformer2DModelOutput:
        """Predict the model output for latent `hidden_states`.

        Args:
            hidden_states: Latents of shape (N, in_channels, H, W).
            timestep: Per-sample flow times of shape (N,).
            class_labels: Per-sample class indices of shape (N,).
            force_drop_ids: Optional mask (1 = replace with the null label)
                used for classifier-free guidance.
            return_dict: Return a `SiTTransformer2DModelOutput` when True,
                otherwise a 1-tuple.
        """
        x = self.x_embedder(hidden_states) + self.pos_embed
        t = self.t_embedder(timestep)
        y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
        c = t + y  # combined conditioning vector
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x)
        if self.learn_sigma:
            # Drop the predicted-sigma channels; only the sample half is used.
            x, _ = x.chunk(2, dim=1)
        if not return_dict:
            return (x,)
        return SiTTransformer2DModelOutput(sample=x)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
    """Build a (grid_size**2, embed_dim) 2D sin-cos positional embedding table.

    When `cls_token` is set and `extra_tokens > 0`, that many all-zero rows
    are prepended for the extra tokens.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    mesh_w, mesh_h = np.meshgrid(axis, axis)  # meshgrid(grid_w, grid_h) ordering
    grid = np.stack([mesh_w, mesh_h], axis=0).reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pad = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([pad, pos_embed], axis=0)
    return pos_embed
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
    """Encode each grid axis with half the channels, then concatenate (H half first)."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    parts = [get_1d_sincos_pos_embed_from_grid(half, axis) for axis in (grid[0], grid[1])]
    return np.concatenate(parts, axis=1)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
    """Classic transformer sin-cos encoding of scalar positions.

    Returns an array of shape (len(pos), embed_dim): the first half of the
    channels are sines, the second half cosines, over frequencies
    10000**(-2i/embed_dim).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, half)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
SiT-L-2-256-diffusers/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
SiT-L-2-256-diffusers/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
|
| 3 |
+
size 334643268
|
SiT-S-2-256-diffusers/README.md
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: diffusers
|
| 3 |
+
pipeline_tag: unconditional-image-generation
|
| 4 |
+
tags:
|
| 5 |
+
- diffusers
|
| 6 |
+
- sit
|
| 7 |
+
- image-generation
|
| 8 |
+
- class-conditional
|
| 9 |
+
inference: true
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# SiT-S-2-256-diffusers
|
| 13 |
+
|
| 14 |
+
Self-contained Diffusers checkpoint repository for SiT, providing class-conditional 256×256 image generation with a bundled custom pipeline, transformer, flow-matching scheduler, and VAE.
|
| 15 |
+
|
| 16 |
+
## Usage
|
| 17 |
+
|
| 18 |
+
```python
|
| 19 |
+
import torch
|
| 20 |
+
from diffusers import DiffusionPipeline
|
| 21 |
+
|
| 22 |
+
pipe = DiffusionPipeline.from_pretrained("./").to("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
generator = torch.Generator(device=pipe.device).manual_seed(0)
|
| 24 |
+
|
| 25 |
+
image = pipe(
|
| 26 |
+
class_labels=207,
|
| 27 |
+
height=256,
|
| 28 |
+
width=256,
|
| 29 |
+
num_inference_steps=250,
|
| 30 |
+
guidance_scale=4.0,
|
| 31 |
+
generator=generator,
|
| 32 |
+
).images[0]
|
| 33 |
+
image.save("demo.png")
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Components
|
| 37 |
+
|
| 38 |
+
- `pipeline.py`
|
| 39 |
+
- `transformer/transformer_sit.py`
|
| 40 |
+
- `scheduler/scheduling_flow_match_sit.py`
|
| 41 |
+
- `transformer/diffusion_pytorch_model.safetensors`
|
| 42 |
+
- `vae/diffusion_pytorch_model.safetensors`
|
SiT-S-2-256-diffusers/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"SiTPipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_flow_match_sit",
|
| 9 |
+
"SiTFlowMatchScheduler"
|
| 10 |
+
],
|
| 11 |
+
"transformer": [
|
| 12 |
+
"transformer_sit",
|
| 13 |
+
"SiTTransformer2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
SiT-S-2-256-diffusers/pipeline.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 6 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 7 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SiTPipeline(DiffusionPipeline):
    """Class-conditional latent generation pipeline for SiT.

    Draws Gaussian latents, integrates them with the flow-match scheduler
    (optionally under classifier-free guidance), then decodes through the VAE.
    """

    model_cpu_offload_seq = "transformer->vae"

    def __init__(self, transformer, scheduler, vae):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        # SD-VAE downsamples by 8x in each spatial dimension.
        self.vae_scale_factor = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, List[int]] = 207,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 250,
        guidance_scale: float = 4.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate one image per class id.

        Args:
            class_labels: Single class id or list of ids; the null
                (unconditional) id used for CFG is ``num_classes``.
            height, width: Output resolution in pixels (divided by 8 for latents).
            num_inference_steps: Number of scheduler integration steps.
            guidance_scale: CFG scale; guidance is applied only when > 1.0.
            generator: Optional RNG for reproducible sampling.
            output_type: "pil", "np", or "pt" ("pt" keeps raw VAE range).
            return_dict: Return an ``ImagePipelineOutput`` instead of a tuple.
        """
        device = self._execution_device
        if isinstance(class_labels, int):
            class_labels = [class_labels]
        batch_size = len(class_labels)

        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor
        latents = randn_tensor(
            (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        labels = torch.tensor(class_labels, device=device, dtype=torch.long)
        do_cfg = guidance_scale is not None and guidance_scale > 1.0
        if do_cfg:
            # Conditional batch first, unconditional (null-class) batch second.
            null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
            labels = torch.cat([labels, null_label], dim=0)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        for t in self.progress_bar(timesteps):
            # `t` is a 0-dim tensor; expand instead of torch.full, which rejects
            # tensor fill values on older torch releases.
            t_batch = t.to(device=device, dtype=latents.dtype).expand(batch_size)
            model_input = latents
            if do_cfg:
                model_input = torch.cat([latents, latents], dim=0)
                t_batch = torch.cat([t_batch, t_batch], dim=0)

            model_pred = self.transformer(
                hidden_states=model_input,
                timestep=t_batch,
                class_labels=labels,
            ).sample

            if do_cfg:
                cond, uncond = model_pred.chunk(2, dim=0)
                model_pred = uncond + guidance_scale * (cond - uncond)

            latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample

        # Read the latent scaling factor from the VAE config instead of
        # hard-coding it; fall back to the SD-VAE constant when absent.
        scaling_factor = getattr(self.vae.config, "scaling_factor", None) or 0.18215
        image = self.vae.decode(latents / scaling_factor).sample
        # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
        if output_type != "pt":
            image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
|
SiT-S-2-256-diffusers/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTFlowMatchScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"diffusion_form": "sigma",
|
| 5 |
+
"diffusion_norm": 1.0,
|
| 6 |
+
"mode": "ode",
|
| 7 |
+
"num_train_timesteps": 1000,
|
| 8 |
+
"shift": 1.0
|
| 9 |
+
}
|
SiT-S-2-256-diffusers/scheduler/scheduling_flow_match_sit.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 7 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class SiTFlowMatchSchedulerOutput(BaseOutput):
    """Output container returned by `SiTFlowMatchScheduler.step`."""

    # Sample advanced by one integration step (toward t = 1, the data end).
    prev_sample: torch.Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
    """First-order (Euler) flow-matching sampler for SiT.

    Integrates the model's predicted velocity from noise (t = 0) toward data
    (t = 1) on a uniform grid. In ``mode="sde"`` an extra diffusion noise term
    is injected at every step.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        mode: str = "ode",
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        diffusion_form: str = "sigma",
        diffusion_norm: float = 1.0,
    ):
        # Populated by `set_timesteps`; `step` refuses to run before then.
        self.timesteps = None
        self.sigmas = None
        self._step_index = None

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """Build the integration grid. Must be called before `step`."""
        # Flow matching integrates from noise (t=0) to data (t=1).
        ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
        shift = self.config.shift
        if shift != 1.0:
            # Honor the configured timestep shift, which was previously declared
            # in the config but ignored. The warp is applied in sigma space
            # (SD3-style) and is the identity at shift == 1.0, so the default
            # behavior is unchanged. NOTE(review): formula chosen to match the
            # common flow-match shift convention — confirm against training code.
            sigmas = 1.0 - ts
            ts = 1.0 - shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
        self.timesteps = ts[:-1]
        self.sigmas = 1.0 - self.timesteps
        self._step_index = 0
        return self.timesteps

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Flow matching needs no input scaling; returns `sample` unchanged."""
        return sample

    def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
        """Diffusion coefficient w(t) used by the SDE sampler."""
        form = self.config.diffusion_form
        norm = self.config.diffusion_norm
        if form == "constant":
            return torch.full_like(t, norm)
        if form == "sigma":
            return norm * (1.0 - t)
        if form == "linear":
            return norm * (1.0 - t)
        if form == "decreasing":
            return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
        if form == "increasing-decreasing":
            return norm * torch.sin(torch.pi * t) ** 2
        # "SBDM" approximated with sigma-based schedule for compatibility.
        return norm * (1.0 - t)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
        """Advance `sample` one Euler step along the velocity `model_output`.

        NOTE: the integration time comes from the internal step counter; the
        `timestep` argument is accepted only for diffusers API compatibility.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if self._step_index is None:
            self._step_index = 0

        step_index = min(self._step_index, len(self.timesteps) - 1)
        t = self.timesteps[step_index].to(sample.device)
        next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
        dt = next_t - t

        prev_sample = sample + model_output * dt
        if self.config.mode.lower() == "sde":
            # Use expand rather than torch.full: `t` is a 0-dim tensor and
            # torch.full rejects tensor fill values on older torch releases.
            t_vec = t.to(dtype=sample.dtype).expand(sample.shape[0])
            diffusion = self._diffusion(t_vec)
            while diffusion.dim() < sample.dim():
                diffusion = diffusion.unsqueeze(-1)
            noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
            prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise

        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Linearly interpolate data and noise at `timesteps` (sigma = 1 - t)."""
        sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
        return (1 - sigma) * original_samples + sigma * noise
|
SiT-S-2-256-diffusers/transformer/config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"class_dropout_prob": 0.1,
|
| 5 |
+
"depth": 12,
|
| 6 |
+
"hidden_size": 384,
|
| 7 |
+
"in_channels": 4,
|
| 8 |
+
"input_size": 32,
|
| 9 |
+
"learn_sigma": true,
|
| 10 |
+
"mlp_ratio": 4.0,
|
| 11 |
+
"num_classes": 1000,
|
| 12 |
+
"num_heads": 6,
|
| 13 |
+
"patch_size": 2
|
| 14 |
+
}
|
SiT-S-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b0754c57c2b6e2e4e74b181d1730a8bd824a30eafeaae9eaf8bc4015e8e4f39
|
| 3 |
+
size 131866144
|
SiT-S-2-256-diffusers/transformer/transformer_sit.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
|
| 9 |
+
|
| 10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 12 |
+
from diffusers.utils import BaseOutput
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply adaLN modulation: per-feature gain and offset, broadcast over tokens."""
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class SiTTransformer2DModelOutput(BaseOutput):
    """Output container returned by `SiTTransformer2DModel.forward`."""

    # Model prediction of shape (batch, in_channels, H, W); when learn_sigma
    # is enabled the sigma channels have already been stripped off.
    sample: torch.Tensor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TimestepEmbedder(nn.Module):
    """Maps scalar timesteps to vectors: sinusoidal features followed by an MLP."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        # Keep the Sequential layout (mlp.0 / mlp.2) so checkpoint keys match.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Sinusoidal embedding: [cos | sin] over log-spaced frequencies."""
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
        freqs = torch.exp(exponent).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd dims are padded with a single zero column.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LabelEmbedder(nn.Module):
    """Class-label embedding with optional label dropout for CFG training.

    When ``dropout_prob > 0`` an extra row (index ``num_classes``) serves as
    the null/unconditional embedding.
    """

    def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Replace selected labels with the null class id (``num_classes``)."""
        if force_drop_ids is None:
            # Random dropout with probability `dropout_prob`.
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Deterministic dropout: entries equal to 1 are dropped.
            drop_ids = force_drop_ids == 1
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(
        self,
        labels: torch.Tensor,
        train: bool,
        force_drop_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        should_drop = (train and self.dropout_prob > 0) or force_drop_ids is not None
        if should_drop:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SiTBlock(nn.Module):
    """Transformer block with adaLN-Zero conditioning (attention + MLP)."""

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # Produces shift/scale/gate for both the attention and MLP branches.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        params = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x + gate_mlp.unsqueeze(1) * mlp_out
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class FinalLayer(nn.Module):
    """Final adaLN-modulated projection from hidden tokens to patch pixels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SiTTransformer2DModel(ModelMixin, ConfigMixin):
    """SiT: a DiT-style transformer over latent patches with adaLN-Zero blocks.

    Latents are patchified, combined with a fixed 2D sin-cos positional
    embedding, conditioned on (timestep embedding + class embedding), and
    unpatchified back to image-latent shape.
    """

    @register_to_config
    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1152,
        depth: int = 28,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        class_dropout_prob: float = 0.1,
        num_classes: int = 1000,
        learn_sigma: bool = True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # With learn_sigma the network predicts two stacks of channels; the
        # second half is discarded in `forward`.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_classes = num_classes

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Fixed (non-trainable) sin-cos positional embedding, filled in below.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """DiT-style init: xavier linears, sin-cos pos-embed, zeroed adaLN/final."""

        def _basic_init(module: nn.Module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Treat the patch-embedding conv like a linear layer for init purposes.
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # adaLN-Zero: zero-init the modulations so every block starts as identity.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Fold (N, T, p*p*C) tokens back into (N, C, H, W); assumes a square grid."""
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        # Fix: the width axis must use w * p. The original wrote h * p for both
        # axes, which is only correct because h == w here.
        return x.reshape(shape=(x.shape[0], c, h * p, w * p))

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        force_drop_ids: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> SiTTransformer2DModelOutput:
        """Run the transformer on latents conditioned on timestep and class.

        Args:
            hidden_states: Latent batch of shape (N, in_channels, H, W).
            timestep: Per-sample timesteps of shape (N,).
            class_labels: Per-sample class ids of shape (N,).
            force_drop_ids: Optional mask selecting labels to replace with the
                null class (for classifier-free guidance).
            return_dict: Return a dataclass instead of a tuple.
        """
        x = self.x_embedder(hidden_states) + self.pos_embed
        t = self.t_embedder(timestep)
        y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
        c = t + y
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x)
        if self.learn_sigma:
            # Drop the predicted-sigma half of the channels.
            x, _ = x.chunk(2, dim=1)
        if not return_dict:
            return (x,)
        return SiTTransformer2DModelOutput(sample=x)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
    """Build a (grid_size**2, embed_dim) 2D sin-cos positional embedding.

    When `cls_token` is set and `extra_tokens > 0`, that many zero rows are
    prepended for the extra tokens.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid(w, h) ordering, then stack to (2, 1, grid_size, grid_size).
    grid = np.stack(np.meshgrid(axis, axis), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        prefix = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([prefix, pos_embed], axis=0)
    return pos_embed
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
    """Concatenate 1D sin-cos embeddings of the two grid axes (half dim each)."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    return np.concatenate(
        [
            get_1d_sincos_pos_embed_from_grid(half, grid[0]),
            get_1d_sincos_pos_embed_from_grid(half, grid[1]),
        ],
        axis=1,
    )
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
    """Return a (len(pos), embed_dim) sin-cos embedding of 1D positions."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Log-spaced inverse frequencies: 1 / 10000^(k / half) for k in [0, half).
    omega = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / half)
    angles = np.outer(pos.reshape(-1), omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
SiT-S-2-256-diffusers/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
SiT-S-2-256-diffusers/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
|
| 3 |
+
size 334643268
|
SiT-XL-2-256-diffusers/README.md
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: diffusers
|
| 3 |
+
pipeline_tag: unconditional-image-generation
|
| 4 |
+
tags:
|
| 5 |
+
- diffusers
|
| 6 |
+
- sit
|
| 7 |
+
- image-generation
|
| 8 |
+
- class-conditional
|
| 9 |
+
inference: true
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# SiT-XL-2-256-diffusers
|
| 13 |
+
|
| 14 |
+
Self-contained Diffusers checkpoint repository for SiT, bundling a custom pipeline, a SiT transformer, a flow-matching scheduler, and the VAE weights.
|
| 15 |
+
|
| 16 |
+
## Usage
|
| 17 |
+
|
| 18 |
+
```python
|
| 19 |
+
import torch
|
| 20 |
+
from diffusers import DiffusionPipeline
|
| 21 |
+
|
| 22 |
+
pipe = DiffusionPipeline.from_pretrained("./").to("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
generator = torch.Generator(device=pipe.device).manual_seed(0)
|
| 24 |
+
|
| 25 |
+
image = pipe(
|
| 26 |
+
class_labels=207,
|
| 27 |
+
height=256,
|
| 28 |
+
width=256,
|
| 29 |
+
num_inference_steps=250,
|
| 30 |
+
guidance_scale=4.0,
|
| 31 |
+
generator=generator,
|
| 32 |
+
).images[0]
|
| 33 |
+
image.save("demo.png")
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Components
|
| 37 |
+
|
| 38 |
+
- `pipeline.py`
|
| 39 |
+
- `transformer/transformer_sit.py`
|
| 40 |
+
- `scheduler/scheduling_flow_match_sit.py`
|
| 41 |
+
- `transformer/diffusion_pytorch_model.safetensors`
|
| 42 |
+
- `vae/diffusion_pytorch_model.safetensors`
|
SiT-XL-2-256-diffusers/demo_50steps.png
ADDED
|
SiT-XL-2-256-diffusers/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"SiTPipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_flow_match_sit",
|
| 9 |
+
"SiTFlowMatchScheduler"
|
| 10 |
+
],
|
| 11 |
+
"transformer": [
|
| 12 |
+
"transformer_sit",
|
| 13 |
+
"SiTTransformer2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
SiT-XL-2-256-diffusers/pipeline.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 6 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 7 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SiTPipeline(DiffusionPipeline):
    """Class-conditional latent generation pipeline for SiT.

    Draws Gaussian latents, integrates them with the flow-match scheduler
    (optionally under classifier-free guidance), then decodes through the VAE.
    """

    model_cpu_offload_seq = "transformer->vae"

    def __init__(self, transformer, scheduler, vae):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        # SD-VAE downsamples by 8x in each spatial dimension.
        self.vae_scale_factor = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, List[int]] = 207,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 250,
        guidance_scale: float = 4.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate one image per class id.

        Args:
            class_labels: Single class id or list of ids; the null
                (unconditional) id used for CFG is ``num_classes``.
            height, width: Output resolution in pixels (divided by 8 for latents).
            num_inference_steps: Number of scheduler integration steps.
            guidance_scale: CFG scale; guidance is applied only when > 1.0.
            generator: Optional RNG for reproducible sampling.
            output_type: "pil", "np", or "pt" ("pt" keeps raw VAE range).
            return_dict: Return an ``ImagePipelineOutput`` instead of a tuple.
        """
        device = self._execution_device
        if isinstance(class_labels, int):
            class_labels = [class_labels]
        batch_size = len(class_labels)

        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor
        latents = randn_tensor(
            (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        labels = torch.tensor(class_labels, device=device, dtype=torch.long)
        do_cfg = guidance_scale is not None and guidance_scale > 1.0
        if do_cfg:
            # Conditional batch first, unconditional (null-class) batch second.
            null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
            labels = torch.cat([labels, null_label], dim=0)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        for t in self.progress_bar(timesteps):
            # `t` is a 0-dim tensor; expand instead of torch.full, which rejects
            # tensor fill values on older torch releases.
            t_batch = t.to(device=device, dtype=latents.dtype).expand(batch_size)
            model_input = latents
            if do_cfg:
                model_input = torch.cat([latents, latents], dim=0)
                t_batch = torch.cat([t_batch, t_batch], dim=0)

            model_pred = self.transformer(
                hidden_states=model_input,
                timestep=t_batch,
                class_labels=labels,
            ).sample

            if do_cfg:
                cond, uncond = model_pred.chunk(2, dim=0)
                model_pred = uncond + guidance_scale * (cond - uncond)

            latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample

        # Read the latent scaling factor from the VAE config instead of
        # hard-coding it; fall back to the SD-VAE constant when absent.
        scaling_factor = getattr(self.vae.config, "scaling_factor", None) or 0.18215
        image = self.vae.decode(latents / scaling_factor).sample
        # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
        if output_type != "pt":
            image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
|
SiT-XL-2-256-diffusers/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTFlowMatchScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"diffusion_form": "sigma",
|
| 5 |
+
"diffusion_norm": 1.0,
|
| 6 |
+
"mode": "ode",
|
| 7 |
+
"num_train_timesteps": 1000,
|
| 8 |
+
"shift": 1.0
|
| 9 |
+
}
|
SiT-XL-2-256-diffusers/scheduler/scheduling_flow_match_sit.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 7 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class SiTFlowMatchSchedulerOutput(BaseOutput):
    """Output container returned by `SiTFlowMatchScheduler.step`."""

    # Sample advanced by one integration step (toward t = 1, the data end).
    prev_sample: torch.Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
    """Flow-matching Euler sampler for SiT.

    Integrates the model's predicted velocity field from noise (t=0) toward
    data (t=1) on a uniform time grid. In ``mode="sde"`` an additional
    Euler-Maruyama diffusion term, whose schedule is selected by
    ``diffusion_form``, is added at every step.
    """

    # Advertised for drop-in compatibility with diffusers pipelines; this
    # scheduler does not itself use Karras sigmas.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        mode: str = "ode",
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        diffusion_form: str = "sigma",
        diffusion_norm: float = 1.0,
    ):
        # NOTE(review): `num_train_timesteps` and `shift` are kept in the
        # config for checkpoint compatibility but are not read by this sampler.
        self.timesteps = None
        self.sigmas = None
        self._step_index = None

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """Build a uniform grid of `num_inference_steps` times in [0, 1).

        Flow matching integrates from noise (t=0) to data (t=1); the final
        endpoint t=1 is handled inside `step`.
        """
        ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
        self.timesteps = ts[:-1]
        # Convention: sigma = 1 - t (noise level decreases toward the data).
        self.sigmas = 1.0 - self.timesteps
        self._step_index = 0
        return self.timesteps

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Flow matching needs no input scaling; return the sample unchanged."""
        return sample

    def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the SDE diffusion coefficient w(t) for a batch of times `t`."""
        form = self.config.diffusion_form
        norm = self.config.diffusion_norm
        if form == "constant":
            return torch.full_like(t, norm)
        if form == "sigma":
            return norm * (1.0 - t)
        if form == "linear":
            return norm * (1.0 - t)
        if form == "decreasing":
            return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
        if form == "increasing-decreasing":
            return norm * torch.sin(torch.pi * t) ** 2
        # "SBDM" approximated with sigma-based schedule for compatibility.
        return norm * (1.0 - t)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
        """Advance `sample` by one Euler (ODE) or Euler-Maruyama (SDE) step.

        `timestep` is accepted for diffusers API compatibility only; the
        current position is tracked internally via `_step_index`.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if self._step_index is None:
            self._step_index = 0

        step_index = min(self._step_index, len(self.timesteps) - 1)
        t = self.timesteps[step_index].to(sample.device)
        # The last step integrates up to the data endpoint t=1.
        next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
        dt = next_t - t

        # Deterministic Euler step along the predicted velocity.
        prev_sample = sample + model_output * dt
        if self.config.mode.lower() == "sde":
            # `t` is a 0-dim tensor; broadcast it with `expand` instead of
            # passing it to `torch.full`, which expects a Python-scalar
            # fill value in some PyTorch versions.
            t_batch = t.to(dtype=sample.dtype).expand(sample.shape[0])
            diffusion = self._diffusion(t_batch)
            while diffusion.dim() < sample.dim():
                diffusion = diffusion.unsqueeze(-1)
            # NOTE(review): `generator` must live on `sample.device` for
            # `torch.randn` with an explicit device — confirm for CUDA use.
            noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
            prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise

        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Interpolate data and noise at the given times: (1 - sigma) * x + sigma * eps."""
        sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
        return (1 - sigma) * original_samples + sigma * noise
|
SiT-XL-2-256-diffusers/transformer/config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"class_dropout_prob": 0.1,
|
| 5 |
+
"depth": 28,
|
| 6 |
+
"hidden_size": 1152,
|
| 7 |
+
"in_channels": 4,
|
| 8 |
+
"input_size": 32,
|
| 9 |
+
"learn_sigma": true,
|
| 10 |
+
"mlp_ratio": 4.0,
|
| 11 |
+
"num_classes": 1000,
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"patch_size": 2
|
| 14 |
+
}
|
SiT-XL-2-256-diffusers/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6fc19454ff52c2f741194974b567a8679709dac950a2f74b766eaaeae22bdc09
|
| 3 |
+
size 2700547792
|
SiT-XL-2-256-diffusers/transformer/transformer_sit.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
|
| 9 |
+
|
| 10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 12 |
+
from diffusers.utils import BaseOutput
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an adaLN affine modulation, broadcasting over the token dimension."""
    scale_tok = scale.unsqueeze(1)
    shift_tok = shift.unsqueeze(1)
    return torch.addcmul(shift_tok, x, scale_tok + 1)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class SiTTransformer2DModelOutput(BaseOutput):
    """Output of `SiTTransformer2DModel.forward`."""

    # Model prediction in latent space, shape (batch, channels, height, width).
    sample: torch.Tensor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps via sinusoidal features followed by a 2-layer MLP."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Build sinusoidal features of width `dim` for the 1-D timestep batch `t`."""
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        angles = t.float()[:, None] * freqs[None]
        embedding = torch.cat((angles.cos(), angles.sin()), dim=-1)
        if dim % 2:
            # Odd dims get a single zero column so the width is exactly `dim`.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat((embedding, pad), dim=-1)
        return embedding

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LabelEmbedder(nn.Module):
    """Embed class labels, with label dropout for classifier-free guidance."""

    def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
        super().__init__()
        # When dropout is enabled, one extra row (index `num_classes`) serves
        # as the learned null/unconditional token.
        has_null_token = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + has_null_token, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Replace a (random or forced) subset of labels with the null token."""
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(
        self,
        labels: torch.Tensor,
        train: bool,
        force_drop_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        should_drop = (train and self.dropout_prob > 0) or force_drop_ids is not None
        if should_drop:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SiTBlock(nn.Module):
    """Transformer block (attention + MLP) conditioned via adaLN-Zero."""

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=approx_gelu,
            drop=0,
        )
        # Produces shift/scale/gate for both the attention and MLP branches.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        params = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + gate_mlp.unsqueeze(1) * mlp_out
        return x
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class FinalLayer(nn.Module):
    """Project tokens to patch pixels after a final adaLN modulation."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        out_dim = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_dim, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SiTTransformer2DModel(ModelMixin, ConfigMixin):
    """SiT: a DiT-style transformer over latent patches with adaLN-Zero conditioning.

    Embeds latent patches, timestep, and class label, runs `depth` SiT blocks,
    and projects back to the latent grid. With `learn_sigma` the network emits
    twice the channels and only the first half is returned from `forward`.
    """

    @register_to_config
    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1152,
        depth: int = 28,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        class_dropout_prob: float = 0.1,
        num_classes: int = 1000,
        learn_sigma: bool = True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # With learn_sigma the head predicts mean and variance halves.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_classes = num_classes

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Fixed (non-learned) 2D sin-cos positional embedding.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """DiT-style init: xavier for linears, zeroed adaLN and output head.

        Statement order is deliberately preserved — it determines the RNG
        stream and therefore the initial weights.
        """
        def _basic_init(module: nn.Module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize the patch projection like a linear layer (flattened kernel).
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero the adaLN/out projections so every residual branch starts as identity.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Reassemble (N, T, p*p*c) patch tokens into (N, c, H, W) latents.

        Assumes a square token grid (T = h * w with h == w), which holds for
        the square inputs this model is configured with.
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        # Width is w * p (was `h * p`; identical for the square grid, but the
        # corrected form states the intent and survives a non-square refactor).
        return x.reshape(shape=(x.shape[0], c, h * p, w * p))

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        force_drop_ids: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> SiTTransformer2DModelOutput:
        """Predict the model output for `hidden_states` at `timestep`.

        Args:
            hidden_states: latent batch (N, in_channels, H, W).
            timestep: per-sample scalar times, shape (N,).
            class_labels: per-sample class indices, shape (N,).
            force_drop_ids: optional mask forcing label dropout (1 = drop).
            return_dict: return a `SiTTransformer2DModelOutput` instead of a tuple.
        """
        x = self.x_embedder(hidden_states) + self.pos_embed
        t = self.t_embedder(timestep)
        y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
        c = t + y  # combined conditioning vector
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x)
        if self.learn_sigma:
            # Keep only the mean half of the doubled-channel prediction.
            x, _ = x.chunk(2, dim=1)
        if not return_dict:
            return (x,)
        return SiTTransformer2DModelOutput(sample=x)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
    """Return a (grid_size**2, embed_dim) 2D sin-cos positional embedding.

    When `cls_token` is set, `extra_tokens` zero rows are prepended.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    mesh = np.meshgrid(coords, coords)  # (w, h) order, matching the original layout
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        prefix = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([prefix, pos_embed], axis=0)
    return pos_embed
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
    """Concatenate per-axis 1D sin-cos embeddings (height half first, then width)."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    parts = [get_1d_sincos_pos_embed_from_grid(half, axis_grid) for axis_grid in (grid[0], grid[1])]
    return np.concatenate(parts, axis=1)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
    """Return (len(pos), embed_dim) sin-cos features for flat 1D positions."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    frequencies = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / (embed_dim / 2.0))
    angles = np.outer(pos.reshape(-1), frequencies)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
SiT-XL-2-256-diffusers/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|
SiT-XL-2-256-diffusers/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
|
| 3 |
+
size 334643268
|
SiT-XL-2-512-diffusers/README.md
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: diffusers
|
| 3 |
+
pipeline_tag: unconditional-image-generation
|
| 4 |
+
tags:
|
| 5 |
+
- diffusers
|
| 6 |
+
- sit
|
| 7 |
+
- image-generation
|
| 8 |
+
- class-conditional
|
| 9 |
+
inference: true
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# SiT-XL-2-512-diffusers
|
| 13 |
+
|
| 14 |
+
Self-contained Diffusers checkpoint repo for SiT.
|
| 15 |
+
|
| 16 |
+
## Usage
|
| 17 |
+
|
| 18 |
+
```python
|
| 19 |
+
import torch
|
| 20 |
+
from diffusers import DiffusionPipeline
|
| 21 |
+
|
| 22 |
+
pipe = DiffusionPipeline.from_pretrained("./").to("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
generator = torch.Generator(device=pipe.device).manual_seed(0)
|
| 24 |
+
|
| 25 |
+
image = pipe(
|
| 26 |
+
class_labels=207,
|
| 27 |
+
height=512,
|
| 28 |
+
width=512,
|
| 29 |
+
num_inference_steps=250,
|
| 30 |
+
guidance_scale=4.0,
|
| 31 |
+
generator=generator,
|
| 32 |
+
).images[0]
|
| 33 |
+
image.save("demo.png")
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Components
|
| 37 |
+
|
| 38 |
+
- `pipeline.py`
|
| 39 |
+
- `transformer/transformer_sit.py`
|
| 40 |
+
- `scheduler/scheduling_flow_match_sit.py`
|
| 41 |
+
- `transformer/diffusion_pytorch_model.safetensors`
|
| 42 |
+
- `vae/diffusion_pytorch_model.safetensors`
|
SiT-XL-2-512-diffusers/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"SiTPipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.36.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"scheduling_flow_match_sit",
|
| 9 |
+
"SiTFlowMatchScheduler"
|
| 10 |
+
],
|
| 11 |
+
"transformer": [
|
| 12 |
+
"transformer_sit",
|
| 13 |
+
"SiTTransformer2DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"diffusers",
|
| 17 |
+
"AutoencoderKL"
|
| 18 |
+
]
|
| 19 |
+
}
|
SiT-XL-2-512-diffusers/pipeline.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 6 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 7 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SiTPipeline(DiffusionPipeline):
    """Class-conditional latent generation pipeline for SiT.

    Samples latents with the flow-matching scheduler, optionally with
    classifier-free guidance, and decodes them with the VAE.
    """

    model_cpu_offload_seq = "transformer->vae"

    # Fallback latent scaling when the VAE config lacks one (0.18215 is the
    # SD-VAE convention used by the original SiT code).
    _DEFAULT_VAE_SCALING = 0.18215

    def __init__(self, transformer, scheduler, vae):
        super().__init__()
        self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
        self.vae_scale_factor = 8  # SD VAE downsamples spatially by 8x
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, List[int]] = 207,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 250,
        guidance_scale: float = 4.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate images for the given class labels.

        Args:
            class_labels: single class index or a list (one image per label).
            height/width: output resolution in pixels (multiples of 8).
            num_inference_steps: number of Euler integration steps.
            guidance_scale: CFG weight; values <= 1 disable guidance.
            generator: optional RNG for reproducibility.
            output_type: "pil", "np", or "pt" (raw VAE range).
            return_dict: return an `ImagePipelineOutput` instead of a tuple.
        """
        device = self._execution_device
        if isinstance(class_labels, int):
            class_labels = [class_labels]
        batch_size = len(class_labels)

        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor
        latents = randn_tensor(
            (batch_size, self.transformer.config.in_channels, latent_h, latent_w),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )

        labels = torch.tensor(class_labels, device=device, dtype=torch.long)
        do_cfg = guidance_scale is not None and guidance_scale > 1.0
        if do_cfg:
            # Index `num_classes` is the learned null/unconditional token.
            null_label = torch.full((batch_size,), self.transformer.config.num_classes, device=device, dtype=torch.long)
            labels = torch.cat([labels, null_label], dim=0)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        for t in self.progress_bar(timesteps):
            # `t` is a 0-dim tensor; broadcast with `expand` rather than
            # `torch.full`, whose fill value must be a Python scalar in
            # some PyTorch versions.
            t_batch = t.to(device=device, dtype=latents.dtype).expand(batch_size)
            model_input = latents
            if do_cfg:
                model_input = torch.cat([latents, latents], dim=0)
                t_batch = torch.cat([t_batch, t_batch], dim=0)

            model_pred = self.transformer(
                hidden_states=model_input,
                timestep=t_batch,
                class_labels=labels,
            ).sample

            if do_cfg:
                cond, uncond = model_pred.chunk(2, dim=0)
                model_pred = uncond + guidance_scale * (cond - uncond)

            latents = self.scheduler.step(model_pred, t, latents, generator=generator).prev_sample

        # Prefer the VAE's own scaling factor; fall back to the SD default
        # (0.18215 for these checkpoints, so behavior is unchanged).
        scaling = getattr(self.vae.config, "scaling_factor", None) or self._DEFAULT_VAE_SCALING
        image = self.vae.decode(latents / scaling).sample
        # Keep PyTorch outputs in raw VAE range [-1, 1] to match original SiT scripts.
        if output_type != "pt":
            image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
|
SiT-XL-2-512-diffusers/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTFlowMatchScheduler",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"diffusion_form": "sigma",
|
| 5 |
+
"diffusion_norm": 1.0,
|
| 6 |
+
"mode": "ode",
|
| 7 |
+
"num_train_timesteps": 1000,
|
| 8 |
+
"shift": 1.0
|
| 9 |
+
}
|
SiT-XL-2-512-diffusers/scheduler/scheduling_flow_match_sit.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 7 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 8 |
+
from diffusers.utils import BaseOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class SiTFlowMatchSchedulerOutput(BaseOutput):
    """Output of `SiTFlowMatchScheduler.step`."""

    # Sample advanced to the next integration time (same shape as the input sample).
    prev_sample: torch.Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
    """Flow-matching Euler sampler for SiT.

    Integrates the model's predicted velocity field from noise (t=0) toward
    data (t=1) on a uniform time grid. In ``mode="sde"`` an additional
    Euler-Maruyama diffusion term, whose schedule is selected by
    ``diffusion_form``, is added at every step.
    """

    # Advertised for drop-in compatibility with diffusers pipelines; this
    # scheduler does not itself use Karras sigmas.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        mode: str = "ode",
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        diffusion_form: str = "sigma",
        diffusion_norm: float = 1.0,
    ):
        # NOTE(review): `num_train_timesteps` and `shift` are kept in the
        # config for checkpoint compatibility but are not read by this sampler.
        self.timesteps = None
        self.sigmas = None
        self._step_index = None

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """Build a uniform grid of `num_inference_steps` times in [0, 1).

        Flow matching integrates from noise (t=0) to data (t=1); the final
        endpoint t=1 is handled inside `step`.
        """
        ts = torch.linspace(0.0, 1.0, num_inference_steps + 1, device=device, dtype=torch.float32)
        self.timesteps = ts[:-1]
        # Convention: sigma = 1 - t (noise level decreases toward the data).
        self.sigmas = 1.0 - self.timesteps
        self._step_index = 0
        return self.timesteps

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Flow matching needs no input scaling; return the sample unchanged."""
        return sample

    def _diffusion(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the SDE diffusion coefficient w(t) for a batch of times `t`."""
        form = self.config.diffusion_form
        norm = self.config.diffusion_norm
        if form == "constant":
            return torch.full_like(t, norm)
        if form == "sigma":
            return norm * (1.0 - t)
        if form == "linear":
            return norm * (1.0 - t)
        if form == "decreasing":
            return 0.25 * (norm * torch.cos(torch.pi * t) + 1) ** 2
        if form == "increasing-decreasing":
            return norm * torch.sin(torch.pi * t) ** 2
        # "SBDM" approximated with sigma-based schedule for compatibility.
        return norm * (1.0 - t)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SiTFlowMatchSchedulerOutput, Tuple[torch.Tensor]]:
        """Advance `sample` by one Euler (ODE) or Euler-Maruyama (SDE) step.

        `timestep` is accepted for API compatibility only; the current
        position is tracked internally via `_step_index`.
        """
        if self.timesteps is None:
            raise ValueError("Call `set_timesteps` before `step`.")
        if self._step_index is None:
            self._step_index = 0

        step_index = min(self._step_index, len(self.timesteps) - 1)
        t = self.timesteps[step_index].to(sample.device)
        # The last step integrates up to the data endpoint t=1.
        next_t = 1.0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1].to(sample.device)
        dt = next_t - t

        # Deterministic Euler step along the predicted velocity.
        prev_sample = sample + model_output * dt
        if self.config.mode.lower() == "sde":
            # `t` is a 0-dim tensor; broadcast it with `expand` instead of
            # passing it to `torch.full`, which expects a Python-scalar
            # fill value in some PyTorch versions.
            t_batch = t.to(dtype=sample.dtype).expand(sample.shape[0])
            diffusion = self._diffusion(t_batch)
            while diffusion.dim() < sample.dim():
                diffusion = diffusion.unsqueeze(-1)
            # NOTE(review): `generator` must live on `sample.device` for
            # `torch.randn` with an explicit device — confirm for CUDA use.
            noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
            prev_sample = prev_sample + torch.sqrt(torch.clamp(2.0 * diffusion * torch.abs(dt), min=0.0)) * noise

        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return SiTFlowMatchSchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Interpolate data and noise at the given times: (1 - sigma) * x + sigma * eps."""
        sigma = (1.0 - timesteps).view(-1, *([1] * (original_samples.ndim - 1)))
        return (1 - sigma) * original_samples + sigma * noise
|
SiT-XL-2-512-diffusers/transformer/config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SiTTransformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"class_dropout_prob": 0.1,
|
| 5 |
+
"depth": 28,
|
| 6 |
+
"hidden_size": 1152,
|
| 7 |
+
"in_channels": 4,
|
| 8 |
+
"input_size": 64,
|
| 9 |
+
"learn_sigma": false,
|
| 10 |
+
"mlp_ratio": 4.0,
|
| 11 |
+
"num_classes": 1000,
|
| 12 |
+
"num_heads": 16,
|
| 13 |
+
"patch_size": 2
|
| 14 |
+
}
|
SiT-XL-2-512-diffusers/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0d830587730010dae36882e38da966a784396e4e1b8f7f0997685f23fd63063f
|
| 3 |
+
size 2704012944
|
SiT-XL-2-512-diffusers/transformer/transformer_sit.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
|
| 9 |
+
|
| 10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 12 |
+
from diffusers.utils import BaseOutput
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply adaLN modulation to *x*.

    ``shift`` and ``scale`` are per-sample vectors of shape (B, C); each is
    broadcast across the token dimension of ``x`` (B, N, C).
    """
    return x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class SiTTransformer2DModelOutput(BaseOutput):
    """Output container for :meth:`SiTTransformer2DModel.forward`."""

    # Transformer output after unpatchify; when learn_sigma is enabled, only
    # the first half of the channel dimension (the mean prediction) is kept.
    sample: torch.Tensor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TimestepEmbedder(nn.Module):
    """Embed scalar diffusion timesteps via sinusoidal features plus an MLP."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        # Project the fixed sinusoidal features up to the model width.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Return (len(t), dim) sinusoidal embeddings: cos half then sin half.

        Odd ``dim`` is handled by right-padding with a zero column.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        out = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            out = torch.cat([out, torch.zeros_like(out[:, :1])], dim=-1)
        return out

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LabelEmbedder(nn.Module):
    """Embed class labels, optionally dropping labels for classifier-free guidance."""

    def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
        super().__init__()
        # When dropout is active, reserve one extra row (index == num_classes)
        # as the "null" / unconditional embedding.
        has_null_row = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + has_null_row, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels: torch.Tensor, force_drop_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Replace a subset of labels with the null index (``num_classes``).

        With ``force_drop_ids`` the subset is chosen explicitly (entries equal
        to 1 are dropped); otherwise it is sampled i.i.d. at ``dropout_prob``.
        """
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(
        self,
        labels: torch.Tensor,
        train: bool,
        force_drop_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        wants_drop = (train and self.dropout_prob > 0) or (force_drop_ids is not None)
        if wants_drop:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SiTBlock(nn.Module):
    """A SiT transformer block with adaLN-Zero conditioning.

    Both LayerNorms carry no affine parameters; per-sample shift/scale/gate
    values for the attention and MLP branches are regressed from the
    conditioning vector ``c``.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0,
        )
        # Emits (shift, scale, gate) x 2: one triple per residual branch.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        mod = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + gate_mlp.unsqueeze(1) * mlp_out
        return x
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class FinalLayer(nn.Module):
    """Final adaLN modulation plus linear projection from tokens to patch pixels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # Each token becomes one (patch_size x patch_size x out_channels) patch.
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SiTTransformer2DModel(ModelMixin, ConfigMixin):
    """SiT transformer operating on latent patches, conditioned on timestep and class label.

    Latents are patchified, combined with fixed 2D sin-cos positional
    embeddings, processed by a stack of adaLN-Zero ``SiTBlock``s, and
    projected back to latent space via ``FinalLayer`` + ``unpatchify``.
    """

    @register_to_config
    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1152,
        depth: int = 28,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        class_dropout_prob: float = 0.1,
        num_classes: int = 1000,
        learn_sigma: bool = True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # With learn_sigma the model emits twice the channels (mean + sigma);
        # forward() later discards the sigma half.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_classes = num_classes

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Fixed (non-learned) positional embedding, filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Apply the SiT/DiT initialization scheme.

        Xavier-init all linear layers, load the fixed sin-cos positional
        embedding, then zero-init every adaLN modulation head and the final
        projection (adaLN-Zero: blocks start as identity mappings).
        """
        def _basic_init(module: nn.Module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        # Assumes a square patch grid (num_patches is a perfect square).
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Treat the patch-embedding conv like a linear layer for init purposes.
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-init so each block's residual branches start disabled.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Reassemble (N, T, p*p*c) token outputs into (N, c, H, W) images.

        Assumes T is a perfect square (square patch grid).
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(shape=(x.shape[0], c, h * p, h * p))

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        force_drop_ids: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> SiTTransformer2DModelOutput:
        """Run the transformer.

        Args:
            hidden_states: latent images of shape (N, in_channels, H, W).
            timestep: per-sample diffusion timesteps, shape (N,).
            class_labels: per-sample class indices, shape (N,).
            force_drop_ids: optional mask (1 = drop) for CFG label dropping.
            return_dict: when False, returns a plain ``(sample,)`` tuple.
        """
        x = self.x_embedder(hidden_states) + self.pos_embed
        t = self.t_embedder(timestep)
        y = self.y_embedder(class_labels, self.training, force_drop_ids=force_drop_ids)
        # Single conditioning vector shared by every block.
        c = t + y
        for block in self.blocks:
            x = block(x, c)
        x = self.final_layer(x, c)
        x = self.unpatchify(x)
        if self.learn_sigma:
            # Keep only the mean prediction; discard the sigma channels.
            x, _ = x.chunk(2, dim=1)
        if not return_dict:
            return (x,)
        return SiTTransformer2DModelOutput(sample=x)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False, extra_tokens: int = 0):
    """Build a (grid_size**2, embed_dim) 2D sin-cos positional embedding.

    When ``cls_token`` is set and ``extra_tokens`` > 0, that many zero rows
    are prepended for the extra tokens.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid(w, h): the width coordinate varies fastest.
    mesh = np.meshgrid(coords, coords)
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        padding = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([padding, pos_embed], axis=0)
    return pos_embed
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray):
    """Split ``embed_dim`` evenly between the two grid axes and concatenate."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])
    return np.concatenate([emb_h, emb_w], axis=1)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
    """Return (pos.size, embed_dim) sin-cos embeddings: sin half then cos half."""
    assert embed_dim % 2 == 0
    # Geometric frequency ladder from 1 down to 1/10000.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega
    angles = np.einsum("m,d->md", pos.reshape(-1), omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
SiT-XL-2-512-diffusers/vae/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"force_upcast": true,
|
| 19 |
+
"in_channels": 3,
|
| 20 |
+
"latent_channels": 4,
|
| 21 |
+
"latents_mean": null,
|
| 22 |
+
"latents_std": null,
|
| 23 |
+
"layers_per_block": 2,
|
| 24 |
+
"mid_block_add_attention": true,
|
| 25 |
+
"norm_num_groups": 32,
|
| 26 |
+
"out_channels": 3,
|
| 27 |
+
"sample_size": 256,
|
| 28 |
+
"scaling_factor": 0.18215,
|
| 29 |
+
"shift_factor": null,
|
| 30 |
+
"up_block_types": [
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D",
|
| 34 |
+
"UpDecoderBlock2D"
|
| 35 |
+
],
|
| 36 |
+
"use_post_quant_conv": true,
|
| 37 |
+
"use_quant_conv": true
|
| 38 |
+
}
|