realantonvoronov committed
Commit 55ca09f • 1 Parent(s): 385f11a

init commit
README.md CHANGED
@@ -1,13 +1,12 @@
  ---
  title: Switti
- emoji: 🖼
- colorFrom: purple
- colorTo: red
  sdk: gradio
- sdk_version: 5.0.1
- app_file: app.py
- pinned: false
+ emoji: 🚀
+ colorFrom: red
+ colorTo: red
+ pinned: true
  short_description: Generate images with Switti
+ preload_from_hub:
+ - yresearch/Switti
+ - yresearch/VQVAE-Switti
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -2,59 +2,67 @@ import gradio as gr
2
  import numpy as np
3
  import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
  prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
 
 
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
  image = pipe(
42
  prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
 
 
50
 
51
  return image, seed
52
 
53
 
54
  examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
  ]
59
 
60
  css = """
@@ -66,8 +74,8 @@ css = """
66
 
67
  with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
  with gr.Row():
72
  prompt = gr.Text(
73
  label="Prompt",
@@ -81,59 +89,66 @@ with gr.Blocks(css=css) as demo:
81
 
82
  result = gr.Image(label="Result", show_label=False)
83
84
  with gr.Accordion("Advanced Settings", open=False):
85
  negative_prompt = gr.Text(
86
  label="Negative prompt",
87
  max_lines=1,
88
  placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
  )
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
  with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
  )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
  )
118
-
119
  with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
 
126
  )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
  step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
 
137
  gr.on(
138
  triggers=[run_button.click, prompt.submit],
139
  fn=infer,
@@ -142,10 +157,12 @@ with gr.Blocks(css=css) as demo:
142
  negative_prompt,
143
  seed,
144
  randomize_seed,
145
- width,
146
- height,
147
  guidance_scale,
148
- num_inference_steps,
 
 
 
 
149
  ],
150
  outputs=[result, seed],
151
  )
 
2
  import numpy as np
3
  import random
4
 
5
+ import spaces
6
+ from models import SwittiPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ model_repo_id = "yresearch/Switti"
11
 
12
 
13
+ pipe = SwittiPipeline.from_pretrained(model_repo_id, device=device)
 
14
 
15
  MAX_SEED = np.iinfo(np.int32).max
 
16
 
17
 
18
+ @spaces.GPU(duration=65)
19
  def infer(
20
  prompt,
21
+ negative_prompt="",
22
+ seed=42,
23
+ randomize_seed=False,
24
+ guidance_scale=4.0,
25
+ top_k=400,
26
+ top_p=0.95,
27
+ more_smooth=True,
28
+ smooth_start_si=2,
29
+ turn_off_cfg_start_si=10,
30
  progress=gr.Progress(track_tqdm=True),
31
  ):
32
  if randomize_seed:
33
  seed = random.randint(0, MAX_SEED)
34
 
 
 
35
  image = pipe(
36
  prompt=prompt,
37
+ null_prompt=negative_prompt,
38
+ cfg=guidance_scale,
39
+ top_p=top_p,
40
+ top_k=top_k,
41
+ more_smooth=more_smooth,
42
+ smooth_start_si=smooth_start_si,
43
+ turn_off_cfg_start_si=turn_off_cfg_start_si,
44
+ seed=seed,
45
+ )[0]
46
 
47
  return image, seed
48
 
49
 
50
  examples = [
51
+ "Cute winter dragon baby, kawaii, Pixar, ultra detailed, glacial background, extremely realistic.",
52
+ "Cat as a wizard",
53
+ ("An ancient ruined archway on the moon, fantasy, ruins of an alien civilization, "
54
+ "concept art, blue sky, reflectionin water pool, large white planet rising behind it"),
55
+ ("A lizard that looks very much like a man, with developed muscles, leather armor "
56
+ "with metal elements, in the hands of a large trident decorated with ancient runes,"
57
+ " against the background of a small lake, everything is well drawn in the style of fantasy"),
58
+ ("The Mandalorian by masamune shirow, fighting stance, in the snow, "
59
+ "cinematic lighting, intricate detail, character design"),
60
+ "Phoenix woman brown skin asian eyes silver scales, full body, high detail",
61
+ ("Portrait of an alien family from the 1970’s, futuristic clothes, "
62
+ "absurd alien helmet, straight line, surreal, strange, absurd, photorealistic, "
63
+ "Hasselblad, Kodak, portra 800, 35mm lens, F 2.8, photo studio."),
64
+ ("32 – bit pixelated future Hiphop producer in glowing power street ware, "
65
+ "noriyoshi ohrai, in the style of minecraft tomer hanuka."),
66
  ]
67
 
68
  css = """
 
74
 
75
  with gr.Blocks(css=css) as demo:
76
  with gr.Column(elem_id="col-container"):
77
+ gr.Markdown(" # [Switti](https://yandex-research.github.io/switti)")
78
+ gr.Markdown("[Learn more](https://yandex-research.github.io/switti) about Switti.")
79
  with gr.Row():
80
  prompt = gr.Text(
81
  label="Prompt",
 
89
 
90
  result = gr.Image(label="Result", show_label=False)
91
 
92
+ seed = gr.Number(
93
+ label="Seed",
94
+ minimum=0,
95
+ maximum=MAX_SEED,
96
+ value=0,
97
+ )
98
+
99
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
100
+
101
+ guidance_scale = gr.Slider(
102
+ label="Guidance scale",
103
+ minimum=0.0,
104
+ maximum=10.,
105
+ step=0.5,
106
+ value=4.,
107
+ )
108
+
109
  with gr.Accordion("Advanced Settings", open=False):
110
  negative_prompt = gr.Text(
111
  label="Negative prompt",
112
  max_lines=1,
113
  placeholder="Enter a negative prompt",
114
+ visible=True,
115
  )
116
 
 
 
117
  with gr.Row():
118
+ top_k = gr.Slider(
119
+ label="Sampling top k",
120
+ minimum=10,
121
+ maximum=1000,
122
+ step=10,
123
+ value=400,
124
  )
125
+ top_p = gr.Slider(
126
+ label="Sampling top p",
127
+ minimum=0.0,
128
+ maximum=1.,
129
+ step=0.01,
130
+ value=0.95,
 
131
  )
132
+
133
  with gr.Row():
134
+ more_smooth = gr.Checkbox(label="Smoothing with Gumbel softmax sampling", value=True)
135
+ smooth_start_si = gr.Slider(
136
+ label="Smoothing starting scale",
137
+ minimum=0,
138
+ maximum=10,
139
+ step=1,
140
+ value=2,
141
  )
142
+ turn_off_cfg_start_si = gr.Slider(
143
+ label="Disable CFG from scale",
144
+ minimum=0,
145
+ maximum=10,
 
146
  step=1,
147
+ value=8,
148
  )
149
 
150
+
151
+ gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)  # alternatively: cache_mode="lazy"
152
  gr.on(
153
  triggers=[run_button.click, prompt.submit],
154
  fn=infer,
 
157
  negative_prompt,
158
  seed,
159
  randomize_seed,
 
 
160
  guidance_scale,
161
+ top_k,
162
+ top_p,
163
+ more_smooth,
164
+ smooth_start_si,
165
+ turn_off_cfg_start_si,
166
  ],
167
  outputs=[result, seed],
168
  )
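Note: the new infer() above simply forwards its arguments to SwittiPipeline. Below is a minimal sketch of calling it directly with the same defaults, outside the Gradio UI; it assumes app.py is importable, the weights can be downloaded, and a GPU is available (the spaces.GPU decorator is expected to act as a pass-through outside a ZeroGPU Space).

# Minimal sketch: invoking the Space's infer() without the Gradio frontend.
from app import infer  # importing app.py builds the SwittiPipeline at module level

image, used_seed = infer(
    "Cat as a wizard",
    negative_prompt="",
    seed=42,
    randomize_seed=False,       # keep the run reproducible
    guidance_scale=4.0,
    top_k=400,
    top_p=0.95,
    more_smooth=True,
    smooth_start_si=2,
    turn_off_cfg_start_si=10,
)
image.save("switti_sample.png")  # infer() returns (PIL image, seed)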
models/__init__.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch.nn as nn
2
+
3
+ from .clip import FrozenCLIPEmbedder
4
+ from .switti import Switti
5
+ from .vqvae import VQVAE
6
+ from .pipeline import SwittiPipeline
7
+
8
+
9
+ def build_models(
10
+ # Shared args
11
+ device,
12
+ patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
13
+ # VQVAE args
14
+ V=4096,
15
+ Cvae=32,
16
+ ch=160,
17
+ share_quant_resi=4,
18
+ # Switti args
19
+ depth=16,
20
+ rope=True,
21
+ rope_theta=10000,
22
+ rope_size=128,
23
+ use_swiglu_ffn=True,
24
+ use_ar=False,
25
+ use_crop_cond=True,
26
+ attn_l2_norm=True,
27
+ init_adaln=0.5,
28
+ init_adaln_gamma=1e-5,
29
+ init_head=0.02,
30
+ init_std=-1, # init_std < 0: automated
31
+ drop_rate=0.0,
32
+ attn_drop_rate=0.0,
33
+ dpr=0,
34
+ norm_eps=1e-6,
35
+ # pipeline args
36
+ text_encoder_path="openai/clip-vit-large-patch14",
37
+ text_encoder_2_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
38
+ ) -> tuple[VQVAE, Switti, SwittiPipeline]:
39
+ heads = depth
40
+ width = depth * 64
41
+ if dpr > 0:
42
+ dpr = dpr * depth / 24
43
+
44
+ # disable built-in initialization for speed
45
+ for clz in (
46
+ nn.Linear,
47
+ nn.LayerNorm,
48
+ nn.BatchNorm2d,
49
+ nn.SyncBatchNorm,
50
+ nn.Conv1d,
51
+ nn.Conv2d,
52
+ nn.ConvTranspose1d,
53
+ nn.ConvTranspose2d,
54
+ ):
55
+ setattr(clz, "reset_parameters", lambda self: None)
56
+
57
+ # build models
58
+ vae_local = VQVAE(
59
+ vocab_size=V,
60
+ z_channels=Cvae,
61
+ ch=ch,
62
+ test_mode=True,
63
+ share_quant_resi=share_quant_resi,
64
+ v_patch_nums=patch_nums,
65
+ ).to(device)
66
+
67
+ switti_wo_ddp = Switti(
68
+ depth=depth,
69
+ embed_dim=width,
70
+ num_heads=heads,
71
+ drop_rate=drop_rate,
72
+ attn_drop_rate=attn_drop_rate,
73
+ drop_path_rate=dpr,
74
+ norm_eps=norm_eps,
75
+ attn_l2_norm=attn_l2_norm,
76
+ patch_nums=patch_nums,
77
+ rope=rope,
78
+ rope_theta=rope_theta,
79
+ rope_size=rope_size,
80
+ use_swiglu_ffn=use_swiglu_ffn,
81
+ use_ar=use_ar,
82
+ use_crop_cond=use_crop_cond,
83
+ ).to(device)
84
+
85
+ switti_wo_ddp.init_weights(
86
+ init_adaln=init_adaln,
87
+ init_adaln_gamma=init_adaln_gamma,
88
+ init_head=init_head,
89
+ init_std=init_std,
90
+ )
91
+ text_encoder = FrozenCLIPEmbedder(text_encoder_path)
92
+ text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
93
+ pipe = SwittiPipeline(switti_wo_ddp, vae_local, text_encoder, text_encoder_2, device)
94
+
95
+ return vae_local, switti_wo_ddp, pipe
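For orientation, build_models wires everything together and returns the VQVAE, the Switti transformer, and a ready SwittiPipeline. The snippet below is a minimal sketch of building the default configuration with randomly initialized (untrained) Switti/VQVAE weights; only the two CLIP text encoders listed above are downloaded, and a CUDA device is assumed because the text encoders inside build_models default to cuda.

# Minimal sketch: constructing the default (untrained) models on a CUDA device.
from models import build_models

vae, switti, pipe = build_models(device="cuda", depth=16)  # depth=16 is the default
print(f"Switti params: {sum(p.numel() for p in switti.parameters()) / 1e6:.1f}M")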
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.9 kB)
models/__pycache__/basic_switti.cpython-311.pyc ADDED
Binary file (23.3 kB)
models/__pycache__/basic_vae.cpython-311.pyc ADDED
Binary file (15.8 kB)
models/__pycache__/clip.cpython-311.pyc ADDED
Binary file (3.01 kB)
models/__pycache__/helpers.cpython-311.pyc ADDED
Binary file (5.29 kB)
models/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (12.6 kB)
models/__pycache__/quant.cpython-311.pyc ADDED
Binary file (24.6 kB)
models/__pycache__/rope.cpython-311.pyc ADDED
Binary file (4.45 kB)
models/__pycache__/switti.cpython-311.pyc ADDED
Binary file (23.3 kB)
models/__pycache__/vqvae.cpython-311.pyc ADDED
Binary file (10.2 kB)
models/basic_switti.py ADDED
@@ -0,0 +1,461 @@
1
+ import math
2
+ import warnings
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from torch import nn
8
+ from torch.nn.functional import scaled_dot_product_attention # q, k, v: BHLc
9
+
10
+ from models.helpers import DropPath
11
+ from models.rope import apply_rotary_emb
12
+
13
+ try:
14
+ from flash_attn.ops.fused_dense import fused_mlp_func
15
+ except ImportError:
16
+ fused_mlp_func = None
17
+
18
+ # this file only provides the blocks used in Switti transformer
19
+ __all__ = ["FFN", "SwiGLUFFN", "RMSNorm", "AdaLNSelfCrossAttn", "AdaLNBeforeHead"]
20
+
21
+
22
+ try:
23
+ from apex.normalization import FusedRMSNorm as RMSNorm
24
+ except ImportError:
25
+ warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
26
+
27
+ class RMSNorm(torch.nn.Module):
28
+ def __init__(self, dim: int, eps: float = 1e-6):
29
+ """
30
+ Initialize the RMSNorm normalization layer.
31
+
32
+ Args:
33
+ dim (int): The dimension of the input tensor.
34
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
35
+
36
+ Attributes:
37
+ eps (float): A small value added to the denominator for numerical stability.
38
+ weight (nn.Parameter): Learnable scaling parameter.
39
+
40
+ """
41
+ super().__init__()
42
+ self.eps = eps
43
+ self.weight = nn.Parameter(torch.ones(dim))
44
+
45
+ def _norm(self, x):
46
+ """
47
+ Apply the RMSNorm normalization to the input tensor.
48
+
49
+ Args:
50
+ x (torch.Tensor): The input tensor.
51
+
52
+ Returns:
53
+ torch.Tensor: The normalized tensor.
54
+
55
+ """
56
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
57
+
58
+ def forward(self, x):
59
+ """
60
+ Forward pass through the RMSNorm layer.
61
+
62
+ Args:
63
+ x (torch.Tensor): The input tensor.
64
+
65
+ Returns:
66
+ torch.Tensor: The output tensor after applying RMSNorm.
67
+
68
+ """
69
+ output = self._norm(x.float()).type_as(x)
70
+ return output * self.weight
71
+
72
+
73
+ class FFN(nn.Module):
74
+ def __init__(
75
+ self,
76
+ in_features,
77
+ hidden_features=None,
78
+ out_features=None,
79
+ drop=0.0,
80
+ fused_if_available=True,
81
+ ):
82
+ super().__init__()
83
+ self.fused_mlp_func = fused_mlp_func if fused_if_available else None
84
+ out_features = out_features or in_features
85
+ hidden_features = hidden_features or in_features
86
+ self.fc1 = nn.Linear(in_features, hidden_features)
87
+ self.act = nn.GELU(approximate="tanh")
88
+ self.fc2 = nn.Linear(hidden_features, out_features)
89
+ self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
90
+
91
+ def forward(self, x):
92
+ if self.fused_mlp_func is not None:
93
+ return self.drop(
94
+ self.fused_mlp_func(
95
+ x=x,
96
+ weight1=self.fc1.weight,
97
+ weight2=self.fc2.weight,
98
+ bias1=self.fc1.bias,
99
+ bias2=self.fc2.bias,
100
+ activation="gelu_approx",
101
+ save_pre_act=self.training,
102
+ return_residual=False,
103
+ checkpoint_lvl=0,
104
+ heuristic=0,
105
+ process_group=None,
106
+ )
107
+ )
108
+ else:
109
+ return self.drop(self.fc2(self.act(self.fc1(x))))
110
+
111
+ def extra_repr(self) -> str:
112
+ return f"fused_mlp_func={self.fused_mlp_func is not None}"
113
+
114
+
115
+ class SwiGLUFFN(nn.Module):
116
+ def __init__(
117
+ self,
118
+ dim: int,
119
+ ff_mult: float = 8 / 3,
120
+ ):
121
+ """
122
+ Initialize the FeedForward module.
123
+
124
+ Args:
125
+ dim (int): Input dimension.
126
+ ff_mult (float, optional): Custom multiplier for hidden dimension. Defaults to 4.
127
+ """
128
+ super().__init__()
129
+ hidden_dim = int(dim * ff_mult)
130
+
131
+ self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
132
+ self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
133
+ self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
134
+ self.fused_mlp_func = None
135
+ self._init()
136
+
137
+ def _init(self):
138
+ for module in self.modules():
139
+ if isinstance(module, nn.Linear):
140
+ nn.init.xavier_uniform_(module.weight)
141
+ if module.bias is not None:
142
+ nn.init.zeros_(module.bias)
143
+
144
+ # @torch.compile
145
+ def _forward_silu_gating(self, x_gate: torch.Tensor, x_up: torch.Tensor):
146
+ return F.silu(x_gate) * x_up
147
+
148
+ def forward(self, x: torch.Tensor):
149
+ return self.down_proj(
150
+ self._forward_silu_gating(self.gate_proj(x), self.up_proj(x))
151
+ )
152
+
153
+ def extra_repr(self) -> str:
154
+ return f"fused_mlp_func={self.fused_mlp_func is not None}"
155
+
156
+
157
+ class CrossAttention(nn.Module):
158
+ def __init__(
159
+ self,
160
+ embed_dim: int = 768,
161
+ context_dim: int = 2048,
162
+ num_heads: int = 12,
163
+ attn_drop: float = 0.0,
164
+ proj_drop: float = 0.0,
165
+ qk_norm: bool = False,
166
+ ):
167
+ super().__init__()
168
+ assert embed_dim % num_heads == 0
169
+ assert attn_drop == 0.0
170
+
171
+ self.num_heads, self.head_dim = (
172
+ num_heads,
173
+ embed_dim // num_heads,
174
+ )
175
+ self.qk_norm = qk_norm
176
+ self.scale = 1 / math.sqrt(self.head_dim)
177
+
178
+ self.q_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False)
179
+ self.k_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False)
180
+
181
+ self.to_q = nn.Linear(embed_dim, embed_dim, bias=True)
182
+ self.to_kv = nn.Linear(context_dim, embed_dim * 2, bias=True)
183
+
184
+ self.proj = nn.Linear(embed_dim, embed_dim)
185
+ self.proj_drop = (
186
+ nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
187
+ )
188
+ self.attn_drop = attn_drop
189
+
190
+ # only used during inference
191
+ self.caching, self.cached_k, self.cached_v = False, None, None
192
+
193
+ def kv_caching(self, enable: bool):
194
+ self.caching, self.cached_k, self.cached_v = enable, None, None
195
+
196
+ def forward(self, x, context, context_attn_bias=None, freqs_cis=None):
197
+ B, L, C = x.shape
198
+ context_B, context_L, context_C = context.shape
199
+ assert B == context_B
200
+
201
+ q = self.to_q(x).view(B, L, -1)  # BLD
202
+ if self.qk_norm:
203
+ q = self.q_norm(q)
204
+
205
+ q = q.view(B, L, self.num_heads, self.head_dim)
206
+ q = q.permute(0, 2, 1, 3) # BHLc
207
+
208
+ if self.cached_k is None:
209
+ # not using caches or first scale inference
210
+ kv = self.to_kv(context).view(B, context_L, 2, -1) # qkv: BL3D
211
+ k, v = kv.permute(2, 0, 1, 3).unbind(dim=0) # q or k or v: BLHD
212
+
213
+ if self.qk_norm:
214
+ k = self.k_norm(k)
215
+
216
+ k = k.view(B, context_L, self.num_heads, self.head_dim)
217
+ k = k.permute(0, 2, 1, 3) # BHLc
218
+
219
+ v = v.view(B, context_L, self.num_heads, self.head_dim)
220
+ v = v.permute(0, 2, 1, 3) # BHLc
221
+
222
+ if self.caching:
223
+ self.cached_k = k
224
+ self.cached_v = v
225
+ else:
226
+ k = self.cached_k
227
+ v = self.cached_v
228
+
229
+ if context_attn_bias is not None:
230
+ context_attn_bias = rearrange(context_attn_bias, "b j -> b 1 1 j")
231
+
232
+ dropout_p = self.attn_drop if self.training else 0.0
233
+ out = (
234
+ scaled_dot_product_attention(
235
+ query=q,
236
+ key=k,
237
+ value=v,
238
+ scale=self.scale,
239
+ attn_mask=context_attn_bias,
240
+ dropout_p=dropout_p,
241
+ )
242
+ .transpose(1, 2)
243
+ .reshape(B, L, C)
244
+ )
245
+
246
+ return self.proj_drop(self.proj(out))
247
+
248
+
249
+ class SelfAttention(nn.Module):
250
+ def __init__(
251
+ self,
252
+ block_idx: int,
253
+ embed_dim: int = 768,
254
+ num_heads: int = 12,
255
+ attn_drop: float = 0.0,
256
+ proj_drop: float = 0.0,
257
+ qk_norm: bool = False,
258
+ ):
259
+ super().__init__()
260
+ assert embed_dim % num_heads == 0
261
+ self.block_idx, self.num_heads, self.head_dim = (
262
+ block_idx,
263
+ num_heads,
264
+ embed_dim // num_heads,
265
+ )
266
+ self.qk_norm = qk_norm
267
+ self.scale = 1 / math.sqrt(self.head_dim)
268
+
269
+ self.q_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False)
270
+ self.k_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False)
271
+
272
+ self.to_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
273
+ self.proj = nn.Linear(embed_dim, embed_dim)
274
+ self.proj_drop = (
275
+ nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
276
+ )
277
+ self.attn_drop = attn_drop
278
+
279
+ # only used during inference
280
+ self.caching, self.cached_k, self.cached_v = False, None, None
281
+
282
+ def kv_caching(self, enable: bool):
283
+ self.caching, self.cached_k, self.cached_v = enable, None, None
284
+
285
+ # NOTE: attn_bias is None during inference because kv cache is enabled
286
+ def forward(self, x, attn_bias, freqs_cis: torch.Tensor = None):
287
+ B, L, C = x.shape
288
+
289
+ qkv = self.to_qkv(x).view(B, L, 3, -1)
290
+ q, k, v = qkv.permute(2, 0, 1, 3).unbind(dim=0) # q or k or v: BLD
291
+
292
+ if self.qk_norm:
293
+ q = self.q_norm(q)
294
+ k = self.k_norm(k)
295
+
296
+ q = q.view(B, L, self.num_heads, self.head_dim)
297
+ q = q.permute(0, 2, 1, 3) # BHLc
298
+ k = k.view(B, L, self.num_heads, self.head_dim)
299
+ k = k.permute(0, 2, 1, 3) # BHLc
300
+ v = v.view(B, L, self.num_heads, self.head_dim)
301
+ v = v.permute(0, 2, 1, 3) # BHLc
302
+ dim_cat = 2
303
+
304
+ if freqs_cis is not None:
305
+ q = apply_rotary_emb(q, freqs_cis=freqs_cis)
306
+ k = apply_rotary_emb(k, freqs_cis=freqs_cis)
307
+
308
+ if self.caching:
309
+ if self.cached_k is None:
310
+ self.cached_k = k
311
+ self.cached_v = v
312
+ else:
313
+ k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat)
314
+ v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
315
+
316
+ dropout_p = self.attn_drop if self.training else 0.0
317
+ out = (
318
+ scaled_dot_product_attention(
319
+ query=q,
320
+ key=k,
321
+ value=v,
322
+ scale=self.scale,
323
+ attn_mask=attn_bias,
324
+ dropout_p=dropout_p,
325
+ )
326
+ .transpose(1, 2)
327
+ .reshape(B, L, C)
328
+ )
329
+
330
+ return self.proj_drop(self.proj(out))
331
+
332
+ def extra_repr(self) -> str:
333
+ return f"attn_l2_norm={self.qk_norm}"
334
+
335
+
336
+ class AdaLNSelfCrossAttn(nn.Module):
337
+ def __init__(
338
+ self,
339
+ block_idx,
340
+ last_drop_p,
341
+ embed_dim,
342
+ cond_dim,
343
+ num_heads,
344
+ mlp_ratio=4.0,
345
+ drop=0.0,
346
+ attn_drop=0.0,
347
+ drop_path=0.0,
348
+ qk_norm=False,
349
+ context_dim=None,
350
+ use_swiglu_ffn=False,
351
+ norm_eps=1e-6,
352
+ use_crop_cond=False,
353
+ ):
354
+ super().__init__()
355
+ assert attn_drop == 0.0
356
+ assert qk_norm
357
+
358
+ self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
359
+ self.C, self.D = embed_dim, cond_dim
360
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
361
+ self.attn = SelfAttention(
362
+ block_idx=block_idx,
363
+ embed_dim=embed_dim,
364
+ num_heads=num_heads,
365
+ attn_drop=attn_drop,
366
+ proj_drop=drop,
367
+ qk_norm=qk_norm,
368
+ )
369
+
370
+ if context_dim:
371
+ self.cross_attn = CrossAttention(
372
+ embed_dim=embed_dim,
373
+ context_dim=context_dim,
374
+ num_heads=num_heads,
375
+ attn_drop=attn_drop,
376
+ proj_drop=drop,
377
+ qk_norm=qk_norm,
378
+ )
379
+ else:
380
+ self.cross_attn = None
381
+
382
+ if use_swiglu_ffn:
383
+ self.ffn = SwiGLUFFN(dim=embed_dim)
384
+ else:
385
+ self.ffn = FFN(
386
+ in_features=embed_dim,
387
+ hidden_features=round(embed_dim * mlp_ratio),
388
+ drop=drop,
389
+ )
390
+
391
+ self.self_attention_norm1 = RMSNorm(embed_dim, eps=norm_eps)
392
+ self.self_attention_norm2 = RMSNorm(embed_dim, eps=norm_eps)
393
+ self.cross_attention_norm1 = RMSNorm(embed_dim, eps=norm_eps)
394
+ self.cross_attention_norm2 = RMSNorm(embed_dim, eps=norm_eps)
395
+
396
+ self.ffn_norm1 = RMSNorm(embed_dim, eps=norm_eps)
397
+ self.ffn_norm2 = RMSNorm(embed_dim, eps=norm_eps)
398
+
399
+ self.attention_y_norm = RMSNorm(context_dim, eps=norm_eps)
400
+
401
+ # AdaLN
402
+ lin = nn.Linear(cond_dim, 6 * embed_dim)
403
+ self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
404
+
405
+ self.fused_add_norm_fn = None
406
+
407
+ self.use_crop_cond = use_crop_cond
408
+ if use_crop_cond:
409
+ self.crop_cond_scales = nn.Parameter(torch.zeros(1, cond_dim))
410
+
411
+ # NOTE: attn_bias is None during inference because kv cache is enabled
412
+ def forward(
413
+ self,
414
+ x,
415
+ cond_BD,
416
+ attn_bias,
417
+ crop_cond=None,
418
+ context=None,
419
+ context_attn_bias=None,
420
+ freqs_cis=None,
421
+ ): # C: embed_dim, D: cond_dim
422
+
423
+ if self.use_crop_cond:
424
+ assert crop_cond is not None
425
+ cond_BD = cond_BD + self.crop_cond_scales * crop_cond
426
+
427
+ gamma1, gamma2, scale1, scale2, shift1, shift2 = (
428
+ self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
429
+ )
430
+ x = x + self.self_attention_norm2(
431
+ self.attn(
432
+ self.self_attention_norm1(x).mul(scale1.add(1)).add(shift1),
433
+ attn_bias=attn_bias,
434
+ freqs_cis=freqs_cis,
435
+ )
436
+ ).mul(gamma1)
437
+ if context is not None:
438
+ x = x + self.cross_attention_norm2(
439
+ self.cross_attn(
440
+ self.cross_attention_norm1(x),
441
+ self.attention_y_norm(context),
442
+ context_attn_bias=context_attn_bias,
443
+ freqs_cis=freqs_cis,
444
+ )
445
+ )
446
+ x = x + self.ffn_norm2(
447
+ self.ffn(self.ffn_norm1(x).mul(scale2.add(1)).add(shift2))
448
+ ).mul(gamma2)
449
+ return x
450
+
451
+
452
+ class AdaLNBeforeHead(nn.Module):
453
+ def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim
454
+ super().__init__()
455
+ self.C, self.D = C, D
456
+ self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
457
+ self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2 * C))
458
+
459
+ def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
460
+ scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
461
+ return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
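When apex is unavailable, the vanilla RMSNorm fallback above normalizes each token by its root mean square; the following is a quick sanity check of that formula (a sketch, assuming apex is not installed so the fallback class is the one imported).

# Sketch: the fallback RMSNorm computes x * rsqrt(mean(x**2) + eps) * weight.
import torch
from models.basic_switti import RMSNorm

x = torch.randn(2, 5, 1024)                  # (batch, tokens, embed_dim), arbitrary
norm = RMSNorm(1024, eps=1e-6)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(norm(x), ref, atol=1e-5)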
models/basic_vae.py ADDED
@@ -0,0 +1,289 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # this file only provides the 2 modules used in VQVAE
6
+ __all__ = [ "Encoder", "Decoder"]
7
+
8
+
9
+ """
10
+ References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
11
+ """
12
+
13
+
14
+ # swish
15
+ def nonlinearity(x):
16
+ return x * torch.sigmoid(x)
17
+
18
+
19
+ def Normalize(in_channels, num_groups=32):
20
+ return torch.nn.GroupNorm(
21
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
22
+ )
23
+
24
+
25
+ class Upsample2x(nn.Module):
26
+ def __init__(self, in_channels):
27
+ super().__init__()
28
+ self.conv = torch.nn.Conv2d(
29
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
30
+ )
31
+
32
+ def forward(self, x):
33
+ return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))
34
+
35
+
36
+ class Downsample2x(nn.Module):
37
+ def __init__(self, in_channels):
38
+ super().__init__()
39
+ self.conv = torch.nn.Conv2d(
40
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
41
+ )
42
+
43
+ def forward(self, x):
44
+ return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode="constant", value=0))
45
+
46
+
47
+ class ResnetBlock(nn.Module):
48
+ def __init__(
49
+ self, *, in_channels, out_channels=None, dropout
50
+ ): # conv_shortcut=False, # conv_shortcut: always False in VAE
51
+ super().__init__()
52
+ self.in_channels = in_channels
53
+ out_channels = in_channels if out_channels is None else out_channels
54
+ self.out_channels = out_channels
55
+
56
+ self.norm1 = Normalize(in_channels)
57
+ self.conv1 = torch.nn.Conv2d(
58
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
59
+ )
60
+ self.norm2 = Normalize(out_channels)
61
+ self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
62
+ self.conv2 = torch.nn.Conv2d(
63
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
64
+ )
65
+ if self.in_channels != self.out_channels:
66
+ self.nin_shortcut = torch.nn.Conv2d(
67
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
68
+ )
69
+ else:
70
+ self.nin_shortcut = nn.Identity()
71
+
72
+ def forward(self, x):
73
+ h = self.conv1(F.silu(self.norm1(x), inplace=True))
74
+ h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
75
+ return self.nin_shortcut(x) + h
76
+
77
+
78
+ class AttnBlock(nn.Module):
79
+ def __init__(self, in_channels):
80
+ super().__init__()
81
+ self.C = in_channels
82
+
83
+ self.norm = Normalize(in_channels)
84
+ self.qkv = torch.nn.Conv2d(
85
+ in_channels, 3 * in_channels, kernel_size=1, stride=1, padding=0
86
+ )
87
+ self.w_ratio = int(in_channels) ** (-0.5)
88
+ self.proj_out = torch.nn.Conv2d(
89
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
90
+ )
91
+
92
+ def forward(self, x):
93
+ qkv = self.qkv(self.norm(x))
94
+ B, _, H, W = qkv.shape # should be B,3C,H,W
95
+ C = self.C
96
+ q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)
97
+
98
+ # compute attention
99
+ q = q.view(B, C, H * W).contiguous()
100
+ q = q.permute(0, 2, 1).contiguous() # B,HW,C
101
+ k = k.view(B, C, H * W).contiguous() # B,C,HW
102
+ w = torch.bmm(q, k).mul_(self.w_ratio) # B,HW,HW
103
+ # w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
104
+ w = F.softmax(w, dim=2)
105
+
106
+ # attend to values
107
+ v = v.view(B, C, H * W).contiguous()
108
+ w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q)
109
+ h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
110
+ h = h.view(B, C, H, W).contiguous()
111
+
112
+ return x + self.proj_out(h)
113
+
114
+
115
+ def make_attn(in_channels, using_sa=True):
116
+ return AttnBlock(in_channels) if using_sa else nn.Identity()
117
+
118
+
119
+ class Encoder(nn.Module):
120
+ def __init__(
121
+ self,
122
+ *,
123
+ ch=128,
124
+ ch_mult=(1, 2, 4, 8),
125
+ num_res_blocks=2,
126
+ dropout=0.0,
127
+ in_channels=3,
128
+ z_channels,
129
+ double_z=False,
130
+ using_sa=True,
131
+ using_mid_sa=True,
132
+ ):
133
+ super().__init__()
134
+ self.ch = ch
135
+ self.num_resolutions = len(ch_mult)
136
+ self.downsample_ratio = 2 ** (self.num_resolutions - 1)
137
+ self.num_res_blocks = num_res_blocks
138
+ self.in_channels = in_channels
139
+
140
+ # downsampling
141
+ self.conv_in = torch.nn.Conv2d(
142
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
143
+ )
144
+
145
+ in_ch_mult = (1,) + tuple(ch_mult)
146
+ self.down = nn.ModuleList()
147
+ for i_level in range(self.num_resolutions):
148
+ block = nn.ModuleList()
149
+ attn = nn.ModuleList()
150
+ block_in = ch * in_ch_mult[i_level]
151
+ block_out = ch * ch_mult[i_level]
152
+ for i_block in range(self.num_res_blocks):
153
+ block.append(
154
+ ResnetBlock(
155
+ in_channels=block_in, out_channels=block_out, dropout=dropout
156
+ )
157
+ )
158
+ block_in = block_out
159
+ if i_level == self.num_resolutions - 1 and using_sa:
160
+ attn.append(make_attn(block_in, using_sa=True))
161
+ down = nn.Module()
162
+ down.block = block
163
+ down.attn = attn
164
+ if i_level != self.num_resolutions - 1:
165
+ down.downsample = Downsample2x(block_in)
166
+ self.down.append(down)
167
+
168
+ # middle
169
+ self.mid = nn.Module()
170
+ self.mid.block_1 = ResnetBlock(
171
+ in_channels=block_in, out_channels=block_in, dropout=dropout
172
+ )
173
+ self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
174
+ self.mid.block_2 = ResnetBlock(
175
+ in_channels=block_in, out_channels=block_in, dropout=dropout
176
+ )
177
+
178
+ # end
179
+ self.norm_out = Normalize(block_in)
180
+ self.conv_out = torch.nn.Conv2d(
181
+ block_in,
182
+ (2 * z_channels if double_z else z_channels),
183
+ kernel_size=3,
184
+ stride=1,
185
+ padding=1,
186
+ )
187
+
188
+ def forward(self, x):
189
+ # downsampling
190
+ h = self.conv_in(x)
191
+ for i_level in range(self.num_resolutions):
192
+ for i_block in range(self.num_res_blocks):
193
+ h = self.down[i_level].block[i_block](h)
194
+ if len(self.down[i_level].attn) > 0:
195
+ h = self.down[i_level].attn[i_block](h)
196
+ if i_level != self.num_resolutions - 1:
197
+ h = self.down[i_level].downsample(h)
198
+
199
+ # middle
200
+ h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
201
+
202
+ # end
203
+ h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
204
+ return h
205
+
206
+
207
+ class Decoder(nn.Module):
208
+ def __init__(
209
+ self,
210
+ *,
211
+ ch=128,
212
+ ch_mult=(1, 2, 4, 8),
213
+ num_res_blocks=2,
214
+ dropout=0.0,
215
+ in_channels=3, # in_channels: raw img channels
216
+ z_channels,
217
+ using_sa=True,
218
+ using_mid_sa=True,
219
+ ):
220
+ super().__init__()
221
+ self.ch = ch
222
+ self.num_resolutions = len(ch_mult)
223
+ self.num_res_blocks = num_res_blocks
224
+ self.in_channels = in_channels
225
+
226
+ # compute in_ch_mult, block_in and curr_res at lowest res
227
+ in_ch_mult = (1,) + tuple(ch_mult)
228
+ block_in = ch * ch_mult[self.num_resolutions - 1]
229
+
230
+ # z to block_in
231
+ self.conv_in = torch.nn.Conv2d(
232
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
233
+ )
234
+
235
+ # middle
236
+ self.mid = nn.Module()
237
+ self.mid.block_1 = ResnetBlock(
238
+ in_channels=block_in, out_channels=block_in, dropout=dropout
239
+ )
240
+ self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
241
+ self.mid.block_2 = ResnetBlock(
242
+ in_channels=block_in, out_channels=block_in, dropout=dropout
243
+ )
244
+
245
+ # upsampling
246
+ self.up = nn.ModuleList()
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ block = nn.ModuleList()
249
+ attn = nn.ModuleList()
250
+ block_out = ch * ch_mult[i_level]
251
+ for i_block in range(self.num_res_blocks + 1):
252
+ block.append(
253
+ ResnetBlock(
254
+ in_channels=block_in, out_channels=block_out, dropout=dropout
255
+ )
256
+ )
257
+ block_in = block_out
258
+ if i_level == self.num_resolutions - 1 and using_sa:
259
+ attn.append(make_attn(block_in, using_sa=True))
260
+ up = nn.Module()
261
+ up.block = block
262
+ up.attn = attn
263
+ if i_level != 0:
264
+ up.upsample = Upsample2x(block_in)
265
+ self.up.insert(0, up) # prepend to get consistent order
266
+
267
+ # end
268
+ self.norm_out = Normalize(block_in)
269
+ self.conv_out = torch.nn.Conv2d(
270
+ block_in, in_channels, kernel_size=3, stride=1, padding=1
271
+ )
272
+
273
+ def forward(self, z):
274
+ # z to block_in
275
+ # middle
276
+ h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))
277
+
278
+ # upsampling
279
+ for i_level in reversed(range(self.num_resolutions)):
280
+ for i_block in range(self.num_res_blocks + 1):
281
+ h = self.up[i_level].block[i_block](h)
282
+ if len(self.up[i_level].attn) > 0:
283
+ h = self.up[i_level].attn[i_block](h)
284
+ if i_level != 0:
285
+ h = self.up[i_level].upsample(h)
286
+
287
+ # end
288
+ h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
289
+ return h
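The Encoder halves the spatial resolution len(ch_mult) - 1 times (its downsample_ratio), so with the default ch_mult=(1, 2, 4, 8) a 256x256 image maps to a 32x32 latent grid and the Decoder mirrors it back. A minimal shape check follows (a sketch; ch=64 and z_channels=32 are illustrative values).

# Sketch: round-tripping shapes through Encoder/Decoder.
import torch
from models.basic_vae import Encoder, Decoder

enc = Encoder(ch=64, ch_mult=(1, 2, 4, 8), z_channels=32, double_z=False)
dec = Decoder(ch=64, ch_mult=(1, 2, 4, 8), z_channels=32)
x = torch.randn(1, 3, 256, 256)
z = enc(x)                       # -> (1, 32, 32, 32): spatial size / 2**(4 - 1)
print(z.shape, dec(z).shape)     # decoder returns (1, 3, 256, 256)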
models/clip.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel, CLIPTokenizer
4
+
5
+
6
+ class FrozenCLIPEmbedder(nn.Module):
7
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
8
+
9
+ def __init__(
10
+ self,
11
+ version="openai/clip-vit-large-patch14",
12
+ device="cuda",
13
+ max_length=77,
14
+ freeze=True,
15
+ ):
16
+ super().__init__()
17
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
18
+ self.transformer = CLIPTextModel.from_pretrained(version).to(device)
19
+ self.device = device
20
+ self.hidden_size = self.transformer.config.hidden_size
21
+ self.max_length = max_length
22
+ if freeze:
23
+ self.freeze()
24
+
25
+ def freeze(self):
26
+ self.transformer = self.transformer.eval()
27
+ for param in self.parameters():
28
+ param.requires_grad = False
29
+
30
+ def forward(self, text):
31
+ batch_encoding = self.tokenizer(
32
+ text,
33
+ truncation=True,
34
+ max_length=self.max_length,
35
+ return_overflowing_tokens=False,
36
+ padding="max_length",
37
+ return_tensors="pt",
38
+ ).to(self.device)
39
+
40
+ outputs = self.transformer(**batch_encoding)
41
+
42
+ attn_bias = batch_encoding["attention_mask"].to(outputs["last_hidden_state"].dtype)
43
+ attn_bias[attn_bias == 0] = -float("inf")
44
+ attn_bias[attn_bias == 1] = 0.0
45
+ outputs["attn_bias"] = attn_bias
46
+ return outputs
47
+
48
+ @torch.no_grad()
49
+ def encode(self, text):
50
+ return self(text)
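FrozenCLIPEmbedder.encode returns the standard CLIPTextModel output plus an additive attn_bias (0 for real tokens, -inf for padding) that the cross-attention layers consume. A minimal sketch of inspecting those outputs (run on CPU for illustration):

# Sketch: encoding a prompt with the frozen CLIP text encoder.
from models.clip import FrozenCLIPEmbedder

enc = FrozenCLIPEmbedder("openai/clip-vit-large-patch14", device="cpu")
out = enc.encode(["Cat as a wizard"])
print(out.last_hidden_state.shape)  # (1, 77, 768) token hidden states
print(out.pooler_output.shape)      # (1, 768) pooled embedding
print(out.attn_bias[0, :6])         # 0. for real tokens, -inf where padded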
models/helpers.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def sample_with_top_k_top_p_(
7
+ logits_BlV: torch.Tensor,
8
+ top_k: int = 0,
9
+ top_p: float = 0.0,
10
+ rng=None,
11
+ num_samples=1,
12
+ ) -> torch.Tensor: # return idx, shaped (B, l)
13
+ B, l, V = logits_BlV.shape
14
+ if top_k > 0:
15
+ idx_to_remove = logits_BlV < logits_BlV.topk(
16
+ top_k, largest=True, sorted=False, dim=-1
17
+ )[0].amin(dim=-1, keepdim=True)
18
+ logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
19
+ if top_p > 0:
20
+ sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
21
+ sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
22
+ sorted_idx_to_remove[..., -1:] = False
23
+ logits_BlV.masked_fill_(
24
+ sorted_idx_to_remove.scatter(
25
+ sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove
26
+ ),
27
+ -torch.inf,
28
+ )
29
+ # sample (flatten to 2D because torch.multinomial only accepts 2D tensors)
30
+ replacement = num_samples >= 0
31
+ num_samples = abs(num_samples)
32
+ return torch.multinomial(
33
+ logits_BlV.softmax(dim=-1).view(-1, V),
34
+ num_samples=num_samples,
35
+ replacement=replacement,
36
+ generator=rng,
37
+ ).view(B, l, num_samples)
38
+
39
+
40
+ def gumbel_softmax_with_rng(
41
+ logits: torch.Tensor,
42
+ tau: float = 1,
43
+ hard: bool = False,
44
+ eps: float = 1e-10,
45
+ dim: int = -1,
46
+ rng: torch.Generator | None = None,
47
+ ) -> torch.Tensor:
48
+ if rng is None:
49
+ return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)
50
+
51
+ gumbels = (
52
+ -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
53
+ .exponential_(generator=rng)
54
+ .log()
55
+ )
56
+ gumbels = (logits + gumbels) / tau
57
+ y_soft = gumbels.softmax(dim)
58
+
59
+ if hard:
60
+ index = y_soft.max(dim, keepdim=True)[1]
61
+ y_hard = torch.zeros_like(
62
+ logits, memory_format=torch.legacy_contiguous_format
63
+ ).scatter_(dim, index, 1.0)
64
+ ret = y_hard - y_soft.detach() + y_soft
65
+ else:
66
+ ret = y_soft
67
+ return ret
68
+
69
+
70
+ def drop_path(
71
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
72
+ ): # taken from timm
73
+ if drop_prob == 0.0 or not training:
74
+ return x
75
+ keep_prob = 1 - drop_prob
76
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
77
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
78
+ if keep_prob > 0.0 and scale_by_keep:
79
+ random_tensor.div_(keep_prob)
80
+ return x * random_tensor
81
+
82
+
83
+ class DropPath(nn.Module): # taken from timm
84
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
85
+ super(DropPath, self).__init__()
86
+ self.drop_prob = drop_prob
87
+ self.scale_by_keep = scale_by_keep
88
+
89
+ def forward(self, x):
90
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
91
+
92
+ def extra_repr(self):
93
+ return f"(drop_prob=...)"
models/pipeline.py ADDED
@@ -0,0 +1,227 @@
1
+ import torch
2
+ from torchvision.transforms import ToPILImage
3
+ from PIL.Image import Image as PILImage
4
+
5
+ from models.vqvae import VQVAEHF
6
+ from models.clip import FrozenCLIPEmbedder
7
+ from models.switti import SwittiHF, get_crop_condition
8
+ from models.helpers import sample_with_top_k_top_p_, gumbel_softmax_with_rng
9
+
10
+
11
+ class SwittiPipeline:
12
+ vae_path = "yresearch/VQVAE-Switti"
13
+ text_encoder_path = "openai/clip-vit-large-patch14"
14
+ text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
15
+
16
+ def __init__(self, switti, vae, text_encoder, text_encoder_2, device):
17
+ self.switti = switti
18
+ self.vae = vae
19
+ self.text_encoder = text_encoder
20
+ self.text_encoder_2 = text_encoder_2
21
+
22
+ self.switti.eval()
23
+ self.vae.eval()
24
+
25
+ self.device = device
26
+
27
+ @classmethod
28
+ def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
29
+ switti = SwittiHF.from_pretrained(pretrained_model_name_or_path).to(device)
30
+ vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
31
+ text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
32
+ text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
33
+
34
+ return cls(switti, vae, text_encoder, text_encoder_2, device)
35
+
36
+ @staticmethod
37
+ def to_image(tensor):
38
+ return [ToPILImage()(
39
+ (255 * img.cpu().detach()).to(torch.uint8))
40
+ for img in tensor]
41
+
42
+ def _encode_prompt(self, prompt: str | list[str]):
43
+ prompt = [prompt] if isinstance(prompt, str) else prompt
44
+ encodings = [
45
+ self.text_encoder.encode(prompt),
46
+ self.text_encoder_2.encode(prompt),
47
+ ]
48
+ prompt_embeds = torch.concat(
49
+ [encoding.last_hidden_state for encoding in encodings], dim=-1
50
+ )
51
+ pooled_prompt_embeds = encodings[-1].pooler_output
52
+ attn_bias = encodings[-1].attn_bias
53
+
54
+ return prompt_embeds, pooled_prompt_embeds, attn_bias
55
+
56
+ def encode_prompt(
57
+ self,
58
+ prompt: str | list[str],
59
+ null_prompt: str = "",
60
+ encode_null: bool = True,
61
+ ):
62
+ prompt_embeds, pooled_prompt_embeds, attn_bias = self._encode_prompt(prompt)
63
+ if encode_null:
64
+ B, L, hidden_dim = prompt_embeds.shape
65
+ pooled_dim = pooled_prompt_embeds.shape[1]
66
+
67
+ null_embeds, null_pooled_embeds, null_attn_bias = self._encode_prompt(null_prompt)
68
+
69
+ null_embeds = null_embeds[:, :L].expand(B, L, hidden_dim).to(prompt_embeds.device)
70
+ null_pooled_embeds = null_pooled_embeds.expand(B, pooled_dim).to(pooled_prompt_embeds.device)
71
+ null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device)
72
+
73
+ prompt_embeds = torch.cat([prompt_embeds, null_embeds], dim=0)
74
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, null_pooled_embeds], dim=0)
75
+ attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0)
76
+
77
+ return prompt_embeds, pooled_prompt_embeds, attn_bias
78
+
79
+ @torch.inference_mode()
80
+ def __call__(
81
+ self,
82
+ prompt: str | list[str],
83
+ null_prompt: str = "",
84
+ seed: int | None = None,
85
+ cfg: float = 4.0,
86
+ top_k: int = 400,
87
+ top_p: float = 0.95,
88
+ more_smooth: bool = False,
89
+ return_pil: bool = True,
90
+ smooth_start_si: int = 0,
91
+ turn_off_cfg_start_si: int = 10,
92
+ image_size: tuple[int, int] = (512, 512),
93
+ ) -> torch.Tensor | list[PILImage]:
94
+ """
95
+ only used for inference, on autoregressive mode
96
+ :param prompt: text prompt to generate an image
97
+ :param null_prompt: negative prompt for CFG
98
+ :param seed: random seed
99
+ :param cfg: classifier-free guidance ratio
100
+ :param top_k: top-k sampling
101
+ :param top_p: top-p sampling
102
+ :param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
103
+ :return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
104
+ """
105
+ assert not self.switti.training
106
+ switti = self.switti
107
+ vae = self.vae
108
+ vae_quant = self.vae.quantize
109
+ if seed is None:
110
+ rng = None
111
+ else:
112
+ switti.rng.manual_seed(seed)
113
+ rng = switti.rng
114
+
115
+ context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)
116
+
117
+ B = context.shape[0] // 2
118
+
119
+ cond_vector = switti.text_pooler(cond_vector)
120
+
121
+ if switti.use_crop_cond:
122
+ crop_coords = get_crop_condition(2 * B * [image_size[0]],
123
+ 2 * B * [image_size[1]],
124
+ ).to(cond_vector.device)
125
+ crop_embed = switti.crop_embed(crop_coords.view(-1)).reshape(2 * B, switti.D)
126
+ crop_cond = switti.crop_proj(crop_embed)
127
+ else:
128
+ crop_cond = None
129
+
130
+ sos = cond_BD = cond_vector
131
+
132
+ lvl_pos = switti.lvl_embed(switti.lvl_1L)
133
+ if not switti.rope:
134
+ lvl_pos += switti.pos_1LC
135
+ next_token_map = (
136
+ sos.unsqueeze(1)
137
+ + switti.pos_start.expand(2 * B, switti.first_l, -1)
138
+ + lvl_pos[:, : switti.first_l]
139
+ )
140
+ cur_L = 0
141
+ f_hat = sos.new_zeros(B, switti.Cvae, switti.patch_nums[-1], switti.patch_nums[-1])
142
+
143
+ for b in switti.blocks:
144
+ b.attn.kv_caching(switti.use_ar) # Use KV caching if switti is in the AR mode
145
+ b.cross_attn.kv_caching(True)
146
+
147
+ for si, pn in enumerate(switti.patch_nums): # si: i-th segment
148
+ ratio = si / switti.num_stages_minus_1
149
+ x_BLC = next_token_map
150
+
151
+ if switti.rope:
152
+ freqs_cis = switti.freqs_cis[:, cur_L : cur_L + pn * pn]
153
+ else:
154
+ freqs_cis = switti.freqs_cis
155
+
156
+ if si >= turn_off_cfg_start_si:
157
+ x_BLC = x_BLC[:B]
158
+ context = context[:B]
159
+ context_attn_bias = context_attn_bias[:B]
160
+ freqs_cis = freqs_cis[:B]
161
+ cond_BD = cond_BD[:B]
162
+ if crop_cond is not None:
163
+ crop_cond = crop_cond[:B]
164
+ for b in switti.blocks:
165
+ if b.attn.caching:
166
+ b.attn.cached_k = b.attn.cached_k[:B]
167
+ b.attn.cached_v = b.attn.cached_v[:B]
168
+ if b.cross_attn.caching:
169
+ b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
170
+ b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
171
+
172
+ for block in switti.blocks:
173
+ x_BLC = block(
174
+ x=x_BLC,
175
+ cond_BD=cond_BD,
176
+ attn_bias=None,
177
+ context=context,
178
+ context_attn_bias=context_attn_bias,
179
+ freqs_cis=freqs_cis,
180
+ crop_cond=crop_cond,
181
+ )
182
+ cur_L += pn * pn
183
+
184
+ logits_BlV = switti.get_logits(x_BLC, cond_BD)
185
+
186
+ # Guidance
187
+ if si < turn_off_cfg_start_si:
188
+ t = cfg * ratio
189
+ logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
190
+
191
+ if more_smooth and si >= smooth_start_si:
192
+ # not used when evaluating FID/IS/Precision/Recall
193
+ gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
194
+ idx_Bl = gumbel_softmax_with_rng(
195
+ logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng,
196
+ )
197
+ h_BChw = idx_Bl @ vae_quant.embedding.weight.unsqueeze(0)
198
+ else:
199
+ # default nucleus sampling
200
+ idx_Bl = sample_with_top_k_top_p_(
201
+ logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1,
202
+ )[:, :, 0]
203
+ h_BChw = vae_quant.embedding(idx_Bl)
204
+
205
+ h_BChw = h_BChw.transpose_(1, 2).reshape(B, switti.Cvae, pn, pn)
206
+ f_hat, next_token_map = vae_quant.get_next_autoregressive_input(
207
+ si, len(switti.patch_nums), f_hat, h_BChw,
208
+ )
209
+ if si != switti.num_stages_minus_1: # prepare for next stage
210
+ next_token_map = next_token_map.view(B, switti.Cvae, -1).transpose(1, 2)
211
+ next_token_map = (
212
+ switti.word_embed(next_token_map)
213
+ + lvl_pos[:, cur_L : cur_L + switti.patch_nums[si + 1] ** 2]
214
+ )
215
+ # double the batch sizes due to CFG
216
+ next_token_map = next_token_map.repeat(2, 1, 1)
217
+
218
+ for b in switti.blocks:
219
+ b.attn.kv_caching(False)
220
+ b.cross_attn.kv_caching(False)
221
+
222
+ # de-normalize, from [-1, 1] to [0, 1]
223
+ img = vae.fhat_to_img(f_hat).add(1).mul(0.5)
224
+ if return_pil:
225
+ img = self.to_image(img)
226
+
227
+ return img
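SwittiPipeline.__call__ is what app.py ultimately invokes; the sketch below shows direct use outside Gradio (it assumes a CUDA device with enough memory for the transformer, the VQVAE and both text encoders).

# Sketch: direct text-to-image generation with the pipeline.
from models import SwittiPipeline

pipe = SwittiPipeline.from_pretrained("yresearch/Switti", device="cuda")
images = pipe(
    ["Cat as a wizard", "Phoenix woman brown skin asian eyes silver scales, full body, high detail"],
    null_prompt="",             # negative prompt for classifier-free guidance
    cfg=4.0,
    top_k=400,
    top_p=0.95,
    seed=42,
    more_smooth=True,
    smooth_start_si=2,
    turn_off_cfg_start_si=10,
)                               # list of PIL images (return_pil=True by default)
images[0].save("switti_0.png")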
models/quant.py ADDED
@@ -0,0 +1,398 @@
1
+ import math
2
+ from typing import List, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import distributed as tdist
7
+ from torch import nn as nn
8
+ from torch.nn import functional as F
9
+
10
+ # this file only provides the VectorQuantizer2 used in VQVAE
11
+ __all__ = ["VectorQuantizer2"]
12
+
13
+
14
+ class VectorQuantizer2(nn.Module):
15
+ # VQGAN originally uses beta=1.0 and never tried 0.25; SD seems to use 0.25
16
+ def __init__(
17
+ self,
18
+ vocab_size,
19
+ Cvae,
20
+ using_znorm,
21
+ beta: float = 0.25,
22
+ default_qresi_counts=0,
23
+ v_patch_nums=None,
24
+ quant_resi=0.5,
25
+ share_quant_resi=4, # share_quant_resi: args.qsr
26
+ ):
27
+ super().__init__()
28
+ self.vocab_size: int = vocab_size
29
+ self.Cvae: int = Cvae
30
+ self.using_znorm: bool = using_znorm
31
+ self.v_patch_nums: Tuple[int] = v_patch_nums
32
+
33
+ self.quant_resi_ratio = quant_resi
34
+ if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales
35
+ self.quant_resi = PhiNonShared(
36
+ [
37
+ (Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
38
+ for _ in range(default_qresi_counts or len(self.v_patch_nums))
39
+ ]
40
+ )
41
+ elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
42
+ self.quant_resi = PhiShared(
43
+ Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()
44
+ )
45
+ else: # partially shared: \phi_{1 to share_quant_resi} for K scales
46
+ self.quant_resi = PhiPartiallyShared(
47
+ nn.ModuleList([(
48
+ Phi(Cvae, quant_resi)
49
+ if abs(quant_resi) > 1e-6
50
+ else nn.Identity()
51
+ ) for _ in range(share_quant_resi)])
52
+ )
53
+
54
+ self.register_buffer(
55
+ "ema_vocab_hit_SV",
56
+ torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0),
57
+ )
58
+ self.record_hit = 0
59
+
60
+ self.beta: float = beta
61
+ self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
62
+
63
+ def eini(self, eini):
64
+ if eini > 0:
65
+ nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
66
+ elif eini < 0:
67
+ self.embedding.weight.data.uniform_(
68
+ -abs(eini) / self.vocab_size, abs(eini) / self.vocab_size
69
+ )
70
+
71
+ def extra_repr(self) -> str:
72
+ return f"{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}"
73
+
74
+ # ===================== `forward` is only used in VAE training =====================
75
+ def forward(
76
+ self, f_BChw: torch.Tensor, ret_usages=False
77
+ ) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
78
+ dtype = f_BChw.dtype
79
+ if dtype != torch.float32:
80
+ f_BChw = f_BChw.float()
81
+ B, C, H, W = f_BChw.shape
82
+ f_no_grad = f_BChw.detach()
83
+
84
+ f_rest = f_no_grad.clone()
85
+ f_hat = torch.zeros_like(f_rest)
86
+
87
+ with torch.cuda.amp.autocast(enabled=False):
88
+ mean_vq_loss: torch.Tensor = 0.0
89
+ vocab_hit_V = torch.zeros(
90
+ self.vocab_size, dtype=torch.float, device=f_BChw.device
91
+ )
92
+ SN = len(self.v_patch_nums)
93
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
94
+ # find the nearest embedding
95
+ if self.using_znorm:
96
+ rest_NC = (
97
+ F.interpolate(f_rest, size=(pn, pn), mode="area")
98
+ .permute(0, 2, 3, 1)
99
+ .reshape(-1, C)
100
+ if (si != SN - 1)
101
+ else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
102
+ )
103
+ rest_NC = F.normalize(rest_NC, dim=-1)
104
+ idx_N = torch.argmax(
105
+ rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0),
106
+ dim=1,
107
+ )
108
+ else:
109
+ rest_NC = (
110
+ F.interpolate(f_rest, size=(pn, pn), mode="area")
111
+ .permute(0, 2, 3, 1)
112
+ .reshape(-1, C)
113
+ if (si != SN - 1)
114
+ else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
115
+ )
116
+ d_no_grad = torch.sum(
117
+ rest_NC.square(), dim=1, keepdim=True
118
+ ) + torch.sum(
119
+ self.embedding.weight.data.square(), dim=1, keepdim=False
120
+ )
121
+ d_no_grad.addmm_(
122
+ rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1
123
+ ) # (B*h*w, vocab_size)
124
+ idx_N = torch.argmin(d_no_grad, dim=1)
125
+
126
+ hit_V = idx_N.bincount(minlength=self.vocab_size).float()
127
+ if self.training:
128
+ # if dist.initialized():
129
+ handler = tdist.all_reduce(hit_V, async_op=True)
130
+
131
+ # calc loss
132
+ idx_Bhw = idx_N.view(B, pn, pn)
133
+ h_BChw = (
134
+ F.interpolate(
135
+ self.embedding(idx_Bhw).permute(0, 3, 1, 2),
136
+ size=(H, W),
137
+ mode="bicubic",
138
+ ).contiguous()
139
+ if (si != SN - 1)
140
+ else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
141
+ )
142
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
143
+ f_hat = f_hat + h_BChw
144
+ f_rest -= h_BChw
145
+
146
+ if self.training: # and dist.initialized():
147
+ handler.wait()
148
+ if self.record_hit == 0:
149
+ self.ema_vocab_hit_SV[si].copy_(hit_V)
150
+ elif self.record_hit < 100:
151
+ self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
152
+ else:
153
+ self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
154
+ self.record_hit += 1
155
+ vocab_hit_V.add_(hit_V)
156
+ mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
157
+
158
+ mean_vq_loss *= 1.0 / SN
159
+ f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
160
+
161
+ margin = (
162
+ tdist.get_world_size()
163
+ * (f_BChw.numel() / f_BChw.shape[1])
164
+ / self.vocab_size
165
+ * 0.08
166
+ )
167
+ # margin = pn*pn / 100
168
+ if ret_usages:
169
+ usages = [
170
+ (self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100
171
+ for si, pn in enumerate(self.v_patch_nums)
172
+ ]
173
+ else:
174
+ usages = None
175
+ return f_hat, usages, mean_vq_loss
176
+
177
+ # ===================== `forward` is only used in VAE training =====================
178
+
179
+ def embed_to_fhat(
180
+ self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False
181
+ ) -> Union[List[torch.Tensor], torch.Tensor]:
182
+ ls_f_hat_BChw = []
183
+ B = ms_h_BChw[0].shape[0]
184
+ H = W = self.v_patch_nums[-1]
185
+ SN = len(self.v_patch_nums)
186
+ if all_to_max_scale:
187
+ f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
188
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
189
+ h_BChw = ms_h_BChw[si]
190
+ if si < len(self.v_patch_nums) - 1:
191
+ h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic")
192
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
193
+ f_hat.add_(h_BChw)
194
+ if last_one:
195
+ ls_f_hat_BChw = f_hat
196
+ else:
197
+ ls_f_hat_BChw.append(f_hat.clone())
198
+ else:
199
+ # WARNING: this is not what happens in VQ-VAE training or inference (there, every token map is interpolated to the max H, W, as above)
200
+ # WARNING: this branch should only be used for experimental purposes
201
+ f_hat = ms_h_BChw[0].new_zeros(
202
+ B,
203
+ self.Cvae,
204
+ self.v_patch_nums[0],
205
+ self.v_patch_nums[0],
206
+ dtype=torch.float32,
207
+ )
208
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
209
+ f_hat = F.interpolate(f_hat, size=(pn, pn), mode="bicubic")
210
+ h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
211
+ f_hat.add_(h_BChw)
212
+ if last_one:
213
+ ls_f_hat_BChw = f_hat
214
+ else:
215
+ ls_f_hat_BChw.append(f_hat)
216
+
217
+ return ls_f_hat_BChw
218
+
219
+ def f_to_idxBl_or_fhat(
220
+ self,
221
+ f_BChw: torch.Tensor,
222
+ to_fhat: bool,
223
+ v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
224
+ noise_std: Optional[float] = None,
225
+ ) -> List[Union[torch.Tensor, torch.LongTensor]]: # f_BChw is the feature map obtained from inp_img_no_grad
226
+ B, C, H, W = f_BChw.shape
227
+ f_no_grad = f_BChw.detach()
228
+ f_rest = f_no_grad.clone()
229
+ f_hat = torch.zeros_like(f_rest)
230
+
231
+ f_hat_or_idx_Bl: List[torch.Tensor] = []
232
+
233
+ patch_hws = [
234
+ (pn, pn) if isinstance(pn, int) else (pn[0], pn[1])
235
+ for pn in (v_patch_nums or self.v_patch_nums)
236
+ ] # from small to large
237
+ assert (
238
+ patch_hws[-1][0] == H and patch_hws[-1][1] == W
239
+ ), f"{patch_hws[-1]=} != ({H=}, {W=})"
240
+
241
+ SN = len(patch_hws)
242
+ for si, (ph, pw) in enumerate(patch_hws): # from small to large
243
+ # find the nearest embedding
244
+ z_NC = (
245
+ F.interpolate(f_rest, size=(ph, pw), mode="area")
246
+ .permute(0, 2, 3, 1)
247
+ .reshape(-1, C)
248
+ if (si != SN - 1)
249
+ else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
250
+ )
251
+ if noise_std is not None:
252
+ z_NC = math.sqrt(1 - noise_std ** 2) * z_NC + torch.randn_like(z_NC) * noise_std
253
+
254
+ if self.using_znorm:
255
+ z_NC = F.normalize(z_NC, dim=-1)
256
+ idx_N = torch.argmax(
257
+ z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1
258
+ )
259
+ else:
260
+ d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(
261
+ self.embedding.weight.data.square(), dim=1, keepdim=False
262
+ )
263
+ d_no_grad.addmm_(
264
+ z_NC, self.embedding.weight.data.T, alpha=-2, beta=1
265
+ ) # (B*h*w, vocab_size)
266
+ idx_N = torch.argmin(d_no_grad, dim=1)
267
+
268
+ idx_Bhw = idx_N.view(B, ph, pw)
269
+ h_BChw = (
270
+ F.interpolate(
271
+ self.embedding(idx_Bhw).permute(0, 3, 1, 2),
272
+ size=(H, W),
273
+ mode="bicubic",
274
+ ).contiguous()
275
+ if (si != SN - 1)
276
+ else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
277
+ )
278
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
279
+ f_hat.add_(h_BChw)
280
+ f_rest.sub_(h_BChw)
281
+ f_hat_or_idx_Bl.append(
282
+ f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw)
283
+ )
284
+
285
+ return f_hat_or_idx_Bl
286
+
287
+ # ===================== idxBl_to_switti_input: only used in Switti training, for getting teacher-forcing input =====================
288
+ def idxBl_to_switti_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
289
+ next_scales = []
290
+ B = gt_ms_idx_Bl[0].shape[0]
291
+ C = self.Cvae
292
+ H = W = self.v_patch_nums[-1]
293
+ SN = len(self.v_patch_nums)
294
+
295
+ f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
296
+ pn_next: int = self.v_patch_nums[0]
297
+ for si in range(SN - 1):
298
+ h_BChw = F.interpolate(
299
+ self.embedding(gt_ms_idx_Bl[si])
300
+ .transpose_(1, 2)
301
+ .view(B, C, pn_next, pn_next),
302
+ size=(H, W),
303
+ mode="bicubic",
304
+ )
305
+ f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
306
+ pn_next = self.v_patch_nums[si + 1]
307
+ next_scales.append(
308
+ F.interpolate(f_hat, size=(pn_next, pn_next), mode="area")
309
+ .view(B, C, -1)
310
+ .transpose(1, 2)
311
+ )
312
+ # cat BlCs to BLC, this should be float32
313
+ return torch.cat(next_scales, dim=1) if len(next_scales) else None
314
+
315
+ # ===================== get_next_autoregressive_input: only used in Switti inference, for getting next step's input =====================
316
+ def get_next_autoregressive_input(
317
+ self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor
318
+ ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in Switti inference
319
+ HW = self.v_patch_nums[-1]
320
+ if si != SN - 1:
321
+ h = self.quant_resi[si / (SN - 1)](
322
+ F.interpolate(h_BChw, size=(HW, HW), mode="bicubic")
323
+ ) # conv after upsample
324
+ f_hat.add_(h)
325
+ return f_hat, F.interpolate(
326
+ f_hat,
327
+ size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
328
+ mode="area",
329
+ )
330
+ else:
331
+ h = self.quant_resi[si / (SN - 1)](h_BChw)
332
+ f_hat.add_(h)
333
+ return f_hat, f_hat
334
+
335
+
336
+ class Phi(nn.Conv2d):
337
+ def __init__(self, embed_dim, quant_resi):
338
+ ks = 3
339
+ super().__init__(
340
+ in_channels=embed_dim,
341
+ out_channels=embed_dim,
342
+ kernel_size=ks,
343
+ stride=1,
344
+ padding=ks // 2,
345
+ )
346
+ self.resi_ratio = abs(quant_resi)
347
+
348
+ def forward(self, h_BChw):
349
+ return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(
350
+ self.resi_ratio
351
+ )
352
+
353
+
354
+ class PhiShared(nn.Module):
355
+ def __init__(self, qresi: Phi):
356
+ super().__init__()
357
+ self.qresi: Phi = qresi
358
+
359
+ def __getitem__(self, _) -> Phi:
360
+ return self.qresi
361
+
362
+
363
+ class PhiPartiallyShared(nn.Module):
364
+ def __init__(self, qresi_ls: nn.ModuleList):
365
+ super().__init__()
366
+ self.qresi_ls = qresi_ls
367
+ K = len(qresi_ls)
368
+ self.ticks = (
369
+ np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)
370
+ if K == 4
371
+ else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
372
+ )
373
+
374
+ def __getitem__(self, at_from_0_to_1: float) -> Phi:
375
+ return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]
376
+
377
+ def extra_repr(self) -> str:
378
+ return f"ticks={self.ticks}"
379
+
380
+
381
+ class PhiNonShared(nn.ModuleList):
382
+ def __init__(self, qresi: List):
383
+ super().__init__(qresi)
384
+ # self.qresi = qresi
385
+ K = len(qresi)
386
+ self.ticks = (
387
+ np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)
388
+ if K == 4
389
+ else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
390
+ )
391
+
392
+ def __getitem__(self, at_from_0_to_1: float) -> Phi:
393
+ return super().__getitem__(
394
+ np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()
395
+ )
396
+
397
+ def extra_repr(self) -> str:
398
+ return f"ticks={self.ticks}"
models/rope.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+
3
+
4
+ def init_t_xy(end_x: int, end_y: int):
5
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
6
+ t_x = (t % end_x).float()
7
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
8
+ return t_x, t_y
9
+
10
+
11
+ def compute_axial_cis(
12
+ dim: int, end_x: int, end_y: int, theta: float = 100.0, norm_coeff: int = 1
13
+ ):
14
+ freqs_x = (
15
+ 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
16
+ * norm_coeff
17
+ )
18
+ freqs_y = (
19
+ 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
20
+ * norm_coeff
21
+ )
22
+
23
+ t_x, t_y = init_t_xy(end_x, end_y)
24
+ freqs_x = torch.outer(t_x, freqs_x)
25
+ freqs_y = torch.outer(t_y, freqs_y)
26
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
27
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
28
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
29
+
30
+
31
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
32
+ ndim = x.ndim
33
+ assert 0 <= 1 < ndim
34
+ freqs_cis = freqs_cis[:, x.shape[1], ...]
35
+ if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
36
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
37
+ elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
38
+ shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)]
39
+ return freqs_cis.view(*shape)
40
+
41
+
42
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor):
43
+ with torch.cuda.amp.autocast(enabled=False):
44
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
45
+ # freqs_cis = reshape_for_broadcast(freqs_cis, x).to(x_in.device)
46
+ freqs_cis = freqs_cis[None, :, : x.shape[2], ...].to(x_in.device)
47
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
48
+ return x_out.type_as(x_in)
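
The rotary-embedding helpers above are shape-sensitive; a small self-contained sketch of the expected shapes (sizes are illustrative and not tied to any checkpoint):

import torch
from models.rope import compute_axial_cis, apply_rotary_emb

head_dim, pn = 64, 4                                             # per-head dim must be divisible by 4
freqs_cis = compute_axial_cis(dim=head_dim, end_x=pn, end_y=pn)  # (pn*pn, head_dim//2), complex
freqs_cis = freqs_cis[None]                                      # leading dim expected by apply_rotary_emb

q = torch.randn(2, 8, pn * pn, head_dim)                         # (B, heads, L, head_dim)
q_rot = apply_rotary_emb(q, freqs_cis)                           # rotated q, same shape and dtype as q
assert q_rot.shape == q.shape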
models/switti.py ADDED
@@ -0,0 +1,409 @@
1
+ import math
2
+ from functools import partial
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+ from diffusers.models.embeddings import GaussianFourierProjection
9
+
10
+ from models.basic_switti import AdaLNBeforeHead, AdaLNSelfCrossAttn
11
+ from models.rope import compute_axial_cis
12
+
13
+
14
+ def get_crop_condition(
15
+ heights: list,
16
+ widths: list,
17
+ base_size=512
18
+ ):
19
+ if isinstance(heights[0], str) and isinstance(widths[0], str):
20
+ heights = [int(h) for h in heights]
21
+ widths = [int(w) for w in widths]
22
+ h = torch.tensor(heights, dtype=torch.int).unsqueeze(1)
23
+ w = torch.tensor(widths, dtype=torch.int).unsqueeze(1)
24
+ hw = torch.cat([h, w], dim=1)
25
+
26
+ ratio = base_size / hw.min(-1)[0]
27
+ orig_size = (hw * ratio[:, None]).to(torch.int)
28
+ crop_coords = ((orig_size - base_size) // 2).clamp(min=0)
29
+ crop_cond = torch.cat([orig_size, crop_coords], dim=1)
30
+
31
+ return crop_cond
32
+
33
+
34
+ class Switti(nn.Module):
35
+ def __init__(
36
+ self,
37
+ Cvae=32,
38
+ V=4096,
39
+ rope=True,
40
+ rope_theta=10000,
41
+ rope_size=128,
42
+ depth=16,
43
+ embed_dim=1024,
44
+ num_heads=16,
45
+ mlp_ratio=4.0,
46
+ drop_rate=0.0,
47
+ attn_drop_rate=0.0,
48
+ drop_path_rate=0.0,
49
+ norm_eps=1e-6,
50
+ attn_l2_norm=True,
51
+ patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
52
+ fused_if_available=True,
53
+ use_swiglu_ffn=True,
54
+ use_ar=False,
55
+ use_crop_cond=True,
56
+ ):
57
+ super().__init__()
58
+ # 0. hyperparameters
59
+ assert embed_dim % num_heads == 0
60
+ self.depth, self.C, self.D, self.num_heads = (
61
+ depth,
62
+ embed_dim,
63
+ embed_dim,
64
+ num_heads,
65
+ )
66
+ self.Cvae, self.V = Cvae, V
67
+
68
+ self.patch_nums: Tuple[int] = patch_nums
69
+ self.L = sum(pn**2 for pn in self.patch_nums)
70
+ self.first_l = self.patch_nums[0] ** 2
71
+ self.rope = rope
72
+
73
+ self.num_stages_minus_1 = len(self.patch_nums) - 1
74
+ self.rng = torch.Generator(device="cuda")
75
+
76
+ # 1. input (word) embedding
77
+ self.word_embed = nn.Linear(self.Cvae, self.C)
78
+
79
+ # 2. text embedding
80
+ self.pooled_embed_size = 1280
81
+ self.context_dim = 1280 + 768
82
+ self.text_pooler = nn.Linear(self.pooled_embed_size, self.D)
83
+
84
+ init_std = math.sqrt(1 / self.C / 3)
85
+ self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
86
+ nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
87
+
88
+ # 3. position embedding
89
+ if not self.rope:
90
+ # absolute position embedding
91
+ pos_1LC = []
92
+ for i, pn in enumerate(self.patch_nums):
93
+ pe = torch.empty(1, pn * pn, self.C)
94
+ nn.init.trunc_normal_(pe, mean=0, std=init_std)
95
+ pos_1LC.append(pe)
96
+ pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
97
+ assert tuple(pos_1LC.shape) == (1, self.L, self.C)
98
+ self.pos_1LC = nn.Parameter(pos_1LC)
99
+ self.freqs_cis = None
100
+ else:
101
+ # RoPE position embedding
102
+ assert (
103
+ self.C // self.num_heads
104
+ ) % 4 == 0, "2d rope needs head dim to be divisible by 4"
105
+ patch_nums_m1 = tuple(pn - 1 if pn > 1 else 1 for pn in self.patch_nums)
106
+ self.compute_cis = partial(compute_axial_cis, dim=self.C // self.num_heads)
107
+ freqs_cis = []
108
+ for i, pn in enumerate(self.patch_nums):
109
+ norm_coeff = rope_size / patch_nums_m1[i]
110
+ cur_freqs = self.compute_cis(
111
+ end_x=pn, end_y=pn, theta=rope_theta, norm_coeff=norm_coeff
112
+ )
113
+ freqs_cis.append(cur_freqs[None, ...])
114
+ self.freqs_cis = torch.cat(freqs_cis, dim=1) # 1, L, C // 2 -- complex
115
+
116
+ # level embedding (similar to GPT's segment embedding,
117
+ # used to distinguish different levels of token pyramid)
118
+ self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
119
+ nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
120
+
121
+ # 4. backbone blocks
122
+ self.drop_path_rate = drop_path_rate
123
+ # stochastic depth decay rule (linearly increasing)
124
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
125
+ self.blocks = nn.ModuleList([])
126
+ for block_idx in range(depth):
127
+ self.blocks.append(
128
+ AdaLNSelfCrossAttn(
129
+ cond_dim=self.D,
130
+ block_idx=block_idx,
131
+ embed_dim=self.C,
132
+ num_heads=num_heads,
133
+ mlp_ratio=mlp_ratio,
134
+ drop=drop_rate,
135
+ attn_drop=attn_drop_rate,
136
+ drop_path=dpr[block_idx],
137
+ last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
138
+ qk_norm=attn_l2_norm,
139
+ context_dim=self.context_dim,
140
+ use_swiglu_ffn=use_swiglu_ffn,
141
+ norm_eps=norm_eps,
142
+ use_crop_cond=use_crop_cond,
143
+ )
144
+ )
145
+
146
+ fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
147
+ self.using_fused_add_norm_fn = any(fused_add_norm_fns)
148
+ print(
149
+ f"\n[constructor] ==== fused_if_available={fused_if_available} "
150
+ f"(fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, "
151
+ f"fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n"
152
+ f" [Switti config ] embed_dim={embed_dim}, num_heads={num_heads}, "
153
+ f"depth={depth}, mlp_ratio={mlp_ratio}\n"
154
+ f" [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, "
155
+ f"drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})",
156
+ end="\n\n",
157
+ flush=True,
158
+ )
159
+
160
+ # Prepare crop condition embedder
161
+ self.use_crop_cond = use_crop_cond
162
+ if use_crop_cond:
163
+ # crop condition is repredsented with 4 int values. each is embeded to self.D // 4 dim
164
+ assert self.D % 8 == 0
165
+ self.crop_embed = GaussianFourierProjection(
166
+ self.D // 2 // 4, set_W_to_weight=False, log=False, flip_sin_to_cos=False
167
+ )
168
+ self.crop_proj = nn.Linear(self.D, self.D)
169
+
170
+ # 5. attention mask used in training (for masking out the future)
171
+ # it won't be used in inference, since kv cache is enabled
172
+ self.use_ar = use_ar
173
+ d: torch.Tensor = torch.cat(
174
+ [torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]
175
+ ).view(1, self.L, 1)
176
+ dT = d.transpose(1, 2) # dT: 11L
177
+ lvl_1L = dT[:, 0].contiguous()
178
+ self.register_buffer("lvl_1L", lvl_1L)
179
+
180
+ if self.use_ar:
181
+ attn_bias_for_masking = torch.where(d >= dT, 0.0, -torch.inf)
182
+ else:
183
+ attn_bias_for_masking = torch.where(d == dT, 0.0, -torch.inf)
184
+
185
+ attn_bias_for_masking = attn_bias_for_masking.reshape(1, 1, self.L, self.L)
186
+ self.register_buffer(
187
+ "attn_bias_for_masking", attn_bias_for_masking.contiguous()
188
+ )
189
+
190
+ # 6. classifier head
191
+ norm_layer = partial(nn.LayerNorm, eps=norm_eps)
192
+ self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
193
+ self.head = nn.Linear(self.C, self.V)
194
+
195
+ # By default disable gradient checkpointing
196
+ self.use_gradient_checkpointing = False
197
+
198
+ def enable_gradient_checkpointing(self):
199
+ self.use_gradient_checkpointing = True
200
+
201
+ def disable_gradient_checkpointing(self):
202
+ self.use_gradient_checkpointing = False
203
+
204
+ def get_logits(
205
+ self,
206
+ h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
207
+ cond_BD: Optional[torch.Tensor],
208
+ ):
209
+ if not isinstance(h_or_h_and_residual, torch.Tensor):
210
+ h, resi = h_or_h_and_residual # fused_add_norm must be used
211
+ h = resi + self.blocks[-1].drop_path(h)
212
+ else: # fused_add_norm is not used
213
+ h = h_or_h_and_residual
214
+ return self.head(self.head_nm(h, cond_BD))
215
+
216
+
217
+ def forward(
218
+ self,
219
+ x_BLCv_wo_first_l: torch.Tensor,
220
+ prompt_embeds: torch.Tensor,
221
+ pooled_prompt_embeds: torch.Tensor,
222
+ prompt_attn_bias: torch.Tensor,
223
+ batch_height: list[int] | None = None,
224
+ batch_width: list[int] | None = None,
225
+ ) -> torch.Tensor: # returns logits_BLV
226
+ """
227
+ :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
228
+ :param prompt_embeds (B, context_len, self.context_dim):
229
+ text features from pipe.text_encoder and pipe.text_encoder_2,
230
+ concatenated along dim=-1, padded to longest along dim=1
231
+ :param pooled_prompt_embeds (B, self.pooled_embed_size):
232
+ pooled text features from pipe.text_encoder_2
233
+ :param prompt_attn_bias (B, context_len):
234
+ boolean mask to specify which tokens are not padding
235
+ :param batch_height (B,): original height of images in a batch.
236
+ :param batch_width (B,): original width of images in a batch.
237
+ Only used when self.use_crop_cond = True
238
+ :return: logits BLV, V is vocab_size
239
+ """
240
+ bg, ed = 0, self.L
241
+ B = x_BLCv_wo_first_l.shape[0]
242
+ with torch.amp.autocast('cuda', enabled=False):
243
+ pooled_prompt_embeds = self.text_pooler(pooled_prompt_embeds)
244
+
245
+ sos = cond_BD = pooled_prompt_embeds
246
+ sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(
247
+ B, self.first_l, -1
248
+ )
249
+
250
+ x_BLC = torch.cat(
251
+ (sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1
252
+ )
253
+ x_BLC += self.lvl_embed(
254
+ self.lvl_1L[:, :ed].expand(B, -1)
255
+ ) # lvl: BLC; pos: 1LC
256
+ if not self.rope:
257
+ x_BLC += self.pos_1LC[:, :ed]
258
+ attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
259
+
260
+ if self.use_crop_cond:
261
+ crop_coords = get_crop_condition(batch_height, batch_width).to(cond_BD.device)
262
+ crop_embed = self.crop_embed(crop_coords.view(-1)).reshape(B, self.D)
263
+ crop_cond = self.crop_proj(crop_embed)
264
+ else:
265
+ crop_cond = None
266
+
267
+ # hack: get the dtype if mixed precision is used
268
+ temp = x_BLC.new_ones(8, 8)
269
+ main_type = torch.matmul(temp, temp).dtype
270
+
271
+ x_BLC = x_BLC.to(dtype=main_type)
272
+ cond_BD = cond_BD.to(dtype=main_type)
273
+ attn_bias = attn_bias.to(dtype=main_type)
274
+
275
+ for block in self.blocks:
276
+ if self.use_gradient_checkpointing:
277
+ x_BLC = torch.utils.checkpoint.checkpoint(
278
+ block,
279
+ x=x_BLC,
280
+ cond_BD=cond_BD,
281
+ attn_bias=attn_bias,
282
+ context=prompt_embeds,
283
+ freqs_cis=self.freqs_cis,
284
+ context_attn_bias=prompt_attn_bias,
285
+ crop_cond=crop_cond,
286
+ use_reentrant=False,
287
+ )
288
+ else:
289
+ x_BLC = block(
290
+ x=x_BLC,
291
+ cond_BD=cond_BD,
292
+ attn_bias=attn_bias,
293
+ context=prompt_embeds,
294
+ freqs_cis=self.freqs_cis,
295
+ context_attn_bias=prompt_attn_bias,
296
+ crop_cond=crop_cond,
297
+ )
298
+
299
+ with torch.amp.autocast('cuda', enabled=not self.training):
300
+ x_BLC = self.get_logits(x_BLC, cond_BD.float())
301
+
302
+ return x_BLC # logits BLV, V is vocab_size
303
+
304
+ def init_weights(
305
+ self,
306
+ init_adaln=0.5,
307
+ init_adaln_gamma=1e-5,
308
+ init_head=0.02,
309
+ init_std=0.02,
310
+ ):
311
+ if init_std < 0:
312
+ init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
313
+
314
+ print(f"[init_weights] {type(self).__name__} with {init_std=:g}")
315
+ for m in self.modules():
316
+ with_weight = hasattr(m, "weight") and m.weight is not None
317
+ with_bias = hasattr(m, "bias") and m.bias is not None
318
+ if isinstance(m, nn.Linear):
319
+ nn.init.trunc_normal_(m.weight.data, std=init_std)
320
+ if with_bias:
321
+ m.bias.data.zero_()
322
+ elif isinstance(m, nn.Embedding):
323
+ nn.init.trunc_normal_(m.weight.data, std=init_std)
324
+ if m.padding_idx is not None:
325
+ m.weight.data[m.padding_idx].zero_()
326
+ elif isinstance(
327
+ m,
328
+ (
329
+ nn.LayerNorm,
330
+ nn.BatchNorm1d,
331
+ nn.BatchNorm2d,
332
+ nn.BatchNorm3d,
333
+ nn.SyncBatchNorm,
334
+ nn.GroupNorm,
335
+ nn.InstanceNorm1d,
336
+ nn.InstanceNorm2d,
337
+ nn.InstanceNorm3d,
338
+ ),
339
+ ):
340
+ if with_weight:
341
+ m.weight.data.fill_(1.0)
342
+ if with_bias:
343
+ m.bias.data.zero_()
344
+
345
+ if init_head >= 0:
346
+ if isinstance(self.head, nn.Linear):
347
+ self.head.weight.data.mul_(init_head)
348
+ self.head.bias.data.zero_()
349
+ elif isinstance(self.head, nn.Sequential):
350
+ self.head[-1].weight.data.mul_(init_head)
351
+ self.head[-1].bias.data.zero_()
352
+
353
+ if isinstance(self.head_nm, AdaLNBeforeHead):
354
+ self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
355
+ if (
356
+ hasattr(self.head_nm.ada_lin[-1], "bias")
357
+ and self.head_nm.ada_lin[-1].bias is not None
358
+ ):
359
+ self.head_nm.ada_lin[-1].bias.data.zero_()
360
+
361
+ depth = len(self.blocks)
362
+ for block in self.blocks:
363
+ block.attn.proj.weight.data.div_(math.sqrt(2 * depth))
364
+ block.cross_attn.proj.weight.data.div_(math.sqrt(2 * depth))
365
+ if hasattr(block.ffn, "fc2"):
366
+ block.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
367
+
368
+ if hasattr(block, "ada_lin"):
369
+ block.ada_lin[-1].weight.data[2 * self.C :].mul_(init_adaln)
370
+ block.ada_lin[-1].weight.data[: 2 * self.C].mul_(init_adaln_gamma)
371
+ if (
372
+ hasattr(block.ada_lin[-1], "bias")
373
+ and block.ada_lin[-1].bias is not None
374
+ ):
375
+ block.ada_lin[-1].bias.data.zero_()
376
+ elif hasattr(block, "ada_gss"):
377
+ block.ada_gss.data[:, :, 2:].mul_(init_adaln)
378
+ block.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)
379
+
380
+ def extra_repr(self):
381
+ return f"drop_path_rate={self.drop_path_rate:g}"
382
+
383
+
384
+ class SwittiHF(Switti, PyTorchModelHubMixin):
385
+ # tags=["image-generation"]):
386
+ def __init__(
387
+ self,
388
+ depth=30,
389
+ rope=True,
390
+ rope_theta=10000,
391
+ rope_size=128,
392
+ use_swiglu_ffn=True,
393
+ use_ar=False,
394
+ use_crop_cond=True,
395
+ ):
396
+ heads = depth
397
+ width = depth * 64
398
+ super().__init__(
399
+ depth=depth,
400
+ embed_dim=width,
401
+ num_heads=heads,
402
+ patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
403
+ rope=rope,
404
+ rope_theta=rope_theta,
405
+ rope_size=rope_size,
406
+ use_swiglu_ffn=use_swiglu_ffn,
407
+ use_ar=use_ar,
408
+ use_crop_cond=use_crop_cond,
409
+ )
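
For reference, `get_crop_condition` above rescales each (height, width) pair so that the shorter side matches `base_size` and records the implied center-crop offsets; the resulting four integers are what `crop_embed` and `crop_proj` consume. A worked example (the input size is an arbitrary illustration):

from models.switti import get_crop_condition

# one 512x1024 (H x W) image with the default base_size of 512
print(get_crop_condition(heights=[512], widths=[1024]))
# -> tensor([[512, 1024, 0, 256]]): (scaled_h, scaled_w, crop_top, crop_left)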
models/vqvae.py ADDED
@@ -0,0 +1,184 @@
1
+ """
2
+ References:
3
+ - VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
4
+ - GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
5
+ - VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+
14
+ from .basic_vae import Decoder, Encoder
15
+ from .quant import VectorQuantizer2
16
+
17
+
18
+
19
+ class VQVAE(nn.Module):
20
+ def __init__(
21
+ self,
22
+ vocab_size=4096,
23
+ z_channels=32,
24
+ ch=160,
25
+ dropout=0.0,
26
+ beta=0.25, # commitment loss weight
27
+ using_znorm=False, # whether to normalize when computing the nearest neighbors
28
+ quant_conv_ks=3, # quant conv kernel size
29
+ quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
30
+ share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi
31
+ default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
32
+ # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
33
+ v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
34
+ test_mode=True,
35
+ ):
36
+ super().__init__()
37
+ self.test_mode = test_mode
38
+ self.V, self.Cvae = vocab_size, z_channels
39
+ # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
40
+ ddconfig = dict(
41
+ dropout=dropout,
42
+ ch=ch,
43
+ z_channels=z_channels,
44
+ in_channels=3,
45
+ ch_mult=(1, 1, 2, 2, 4),
46
+ num_res_blocks=2, # from vq-f16/config.yaml above
47
+ using_sa=True,
48
+ using_mid_sa=True, # from vq-f16/config.yaml above
49
+ # resamp_with_conv=True, # always True, removed.
50
+ )
51
+ ddconfig.pop("double_z", None) # only KL-VAE should use double_z=True
52
+ self.encoder = Encoder(double_z=False, **ddconfig)
53
+ self.decoder = Decoder(**ddconfig)
54
+
55
+ self.vocab_size = vocab_size
56
+ self.downsample = 2 ** (len(ddconfig["ch_mult"]) - 1)
57
+ self.quantize: VectorQuantizer2 = VectorQuantizer2(
58
+ vocab_size=vocab_size,
59
+ Cvae=self.Cvae,
60
+ using_znorm=using_znorm,
61
+ beta=beta,
62
+ default_qresi_counts=default_qresi_counts,
63
+ v_patch_nums=v_patch_nums,
64
+ quant_resi=quant_resi,
65
+ share_quant_resi=share_quant_resi,
66
+ )
67
+ self.quant_conv = torch.nn.Conv2d(
68
+ self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2
69
+ )
70
+ self.post_quant_conv = torch.nn.Conv2d(
71
+ self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2
72
+ )
73
+
74
+ if self.test_mode:
75
+ self.eval()
76
+ [p.requires_grad_(False) for p in self.parameters()]
77
+
78
+ # ===================== `forward` is only used in VAE training =====================
79
+ def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss
80
+ VectorQuantizer2.forward # no-op reference, kept only as a pointer to the quantizer's forward
81
+ f_hat, usages, vq_loss = self.quantize(
82
+ self.quant_conv(self.encoder(inp)), ret_usages=ret_usages
83
+ )
84
+ return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
85
+
86
+ # ===================== `forward` is only used in VAE training =====================
87
+
88
+ def fhat_to_img(self, f_hat: torch.Tensor):
89
+ return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
90
+
91
+ def img_to_idxBl(
92
+ self,
93
+ inp_img_no_grad: torch.Tensor,
94
+ v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
95
+ noise_std: Optional[float] = None,
96
+ ) -> List[torch.LongTensor]: # return List[Bl]
97
+ f = self.quant_conv(self.encoder(inp_img_no_grad))
98
+ return self.quantize.f_to_idxBl_or_fhat(
99
+ f, to_fhat=False, v_patch_nums=v_patch_nums, noise_std=noise_std,
100
+ )
101
+
102
+ def idxBl_to_img(
103
+ self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False
104
+ ) -> Union[List[torch.Tensor], torch.Tensor]:
105
+ B = ms_idx_Bl[0].shape[0]
106
+ ms_h_BChw = []
107
+ for idx_Bl in ms_idx_Bl:
108
+ l = idx_Bl.shape[1]
109
+ pn = round(l**0.5)
110
+ ms_h_BChw.append(
111
+ self.quantize.embedding(idx_Bl)
112
+ .transpose(1, 2)
113
+ .view(B, self.Cvae, pn, pn)
114
+ )
115
+ return self.embed_to_img(
116
+ ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one
117
+ )
118
+
119
+ def embed_to_img(
120
+ self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False
121
+ ) -> Union[List[torch.Tensor], torch.Tensor]:
122
+ if last_one:
123
+ return self.decoder(
124
+ self.post_quant_conv(
125
+ self.quantize.embed_to_fhat(
126
+ ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True
127
+ )
128
+ )
129
+ ).clamp_(-1, 1)
130
+ else:
131
+ return [
132
+ self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
133
+ for f_hat in self.quantize.embed_to_fhat(
134
+ ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False
135
+ )
136
+ ]
137
+
138
+ def img_to_reconstructed_img(
139
+ self,
140
+ x,
141
+ v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
142
+ last_one=False,
143
+ ) -> List[torch.Tensor]:
144
+ f = self.quant_conv(self.encoder(x))
145
+ ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(
146
+ f, to_fhat=True, v_patch_nums=v_patch_nums
147
+ )
148
+ if last_one:
149
+ return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
150
+ else:
151
+ return [
152
+ self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
153
+ for f_hat in ls_f_hat_BChw
154
+ ]
155
+
156
+ def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
157
+ if (
158
+ "quantize.ema_vocab_hit_SV" in state_dict
159
+ and state_dict["quantize.ema_vocab_hit_SV"].shape[0]
160
+ != self.quantize.ema_vocab_hit_SV.shape[0]
161
+ ):
162
+ state_dict["quantize.ema_vocab_hit_SV"] = self.quantize.ema_vocab_hit_SV
163
+ return super().load_state_dict(
164
+ state_dict=state_dict, strict=strict, assign=assign
165
+ )
166
+
167
+ class VQVAEHF(VQVAE, PyTorchModelHubMixin):
168
+ def __init__(
169
+ self,
170
+ vocab_size=4096,
171
+ z_channels=32,
172
+ ch=160,
173
+ test_mode=True,
174
+ share_quant_resi=4,
175
+ v_patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
176
+ ):
177
+ super().__init__(
178
+ vocab_size=vocab_size,
179
+ z_channels=z_channels,
180
+ ch=ch,
181
+ test_mode=test_mode,
182
+ share_quant_resi=share_quant_resi,
183
+ v_patch_nums=v_patch_nums,
184
+ )
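
Both model classes mix in `PyTorchModelHubMixin`, so they can be instantiated directly from the Hub. Below is a minimal reconstruction round-trip sketch with the VQ-VAE; the Hub repository ids and the 512x512 input (which matches the final 32x32 token grid at downsample factor 16) are assumptions for illustration:

import torch
from models.vqvae import VQVAEHF

vae = VQVAEHF.from_pretrained("yresearch/VQVAE-Switti")   # test_mode=True: frozen and in eval mode

x = torch.rand(1, 3, 512, 512) * 2 - 1                    # dummy image in [-1, 1]
with torch.no_grad():
    idx_Bl = vae.img_to_idxBl(x)                          # list of (B, pn*pn) LongTensors, one per scale
    rec = vae.idxBl_to_img(idx_Bl, same_shape=True, last_one=True)  # (B, 3, 512, 512) in [-1, 1]

# SwittiHF.from_pretrained("yresearch/Switti") is assumed to work the same way
# (its constructor creates a CUDA generator, so it expects a GPU).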
requirements.txt CHANGED
@@ -1,6 +1,16 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
1
+ huggingface_hub==0.26.2
2
+ transformers==4.45.2
3
+ diffusers==0.31.0
4
+ einops==0.8.0
5
+ pytz==2024.2
6
+ wandb==0.18.7
7
+ torch==2.4.1
8
+ decord==0.6.0
9
+ numpy==2.1.2
10
+ Pillow==11.0.0
11
+ pytz==2024.2
12
+ scipy==1.14.1
13
+ torchvision==0.19.1
14
+ tqdm==4.66.5
15
+ gradio==5.7.1
16
+ spaces==0.30.4