Spaces · Running on Zero

realantonvoronov committed
Commit 55ca09f · 1 parent: 385f11a

init commit
Browse files

- README.md +7 -8
- app.py +90 -73
- models/__init__.py +95 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/basic_switti.cpython-311.pyc +0 -0
- models/__pycache__/basic_vae.cpython-311.pyc +0 -0
- models/__pycache__/clip.cpython-311.pyc +0 -0
- models/__pycache__/helpers.cpython-311.pyc +0 -0
- models/__pycache__/pipeline.cpython-311.pyc +0 -0
- models/__pycache__/quant.cpython-311.pyc +0 -0
- models/__pycache__/rope.cpython-311.pyc +0 -0
- models/__pycache__/switti.cpython-311.pyc +0 -0
- models/__pycache__/vqvae.cpython-311.pyc +0 -0
- models/basic_switti.py +461 -0
- models/basic_vae.py +289 -0
- models/clip.py +50 -0
- models/helpers.py +93 -0
- models/pipeline.py +227 -0
- models/quant.py +398 -0
- models/rope.py +48 -0
- models/switti.py +409 -0
- models/vqvae.py +184 -0
- requirements.txt +16 -6
README.md
CHANGED
@@ -1,13 +1,12 @@
 ---
 title: Switti
-emoji: πΌ
-colorFrom: purple
-colorTo: red
 sdk: gradio
-
-
-
+emoji: π
+colorFrom: red
+colorTo: red
+pinned: true
 short_description: Generate images with Switti
+preload_from_hub:
+- yresearch/Switti
+- yresearch/VQVAE-Switti
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED
@@ -2,59 +2,67 @@ import gradio as gr
import numpy as np
import random

-
-from 
+import spaces
+from models import SwittiPipeline
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "
+model_repo_id = "yresearch/Switti"

-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32

-pipe = 
-pipe = pipe.to(device)
+pipe = SwittiPipeline.from_pretrained(model_repo_id, device=device)

MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024


-
+@spaces.GPU(duration=65)
def infer(
    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-
-
-
-
+    negative_prompt="",
+    seed=42,
+    randomize_seed=False,
+    guidance_scale=4.0,
+    top_k=400,
+    top_p=0.95,
+    more_smooth=True,
+    smooth_start_si=2,
+    turn_off_cfg_start_si=10,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

-    generator = torch.Generator().manual_seed(seed)
-
    image = pipe(
        prompt=prompt,
-
-
-
-
-
-
-
+        null_prompt=negative_prompt,
+        cfg=guidance_scale,
+        top_p=top_p,
+        top_k=top_k,
+        more_smooth=more_smooth,
+        smooth_start_si=smooth_start_si,
+        turn_off_cfg_start_si=turn_off_cfg_start_si,
+        seed=seed,
+    )[0]

    return image, seed


examples = [
-
-
-
+    "Cute winter dragon baby, kawaii, Pixar, ultra detailed, glacial background, extremely realistic.",
+    "Cat as a wizard",
+    ("An ancient ruined archway on the moon, fantasy, ruins of an alien civilization, "
+     "concept art, blue sky, reflection in water pool, large white planet rising behind it"),
+    ("A lizard that looks very much like a man, with developed muscles, leather armor "
+     "with metal elements, in the hands of a large trident decorated with ancient runes,"
+     " against the background of a small lake, everything is well drawn in the style of fantasy"),
+    ("The Mandalorian by masamune shirow, fighting stance, in the snow, "
+     "cinematic lighting, intricate detail, character design"),
+    "Phoenix woman brown skin asian eyes silver scales, full body, high detail",
+    ("Portrait of an alien family from the 1970's, futuristic clothes, "
+     "absurd alien helmet, straight line, surreal, strange, absurd, photorealistic, "
+     "Hasselblad, Kodak, portra 800, 35mm lens, F 2.8, photo studio."),
+    ("32 - bit pixelated future Hiphop producer in glowing power street ware, "
+     "noriyoshi ohrai, in the style of minecraft tomer hanuka."),
]

css = """

@@ -66,8 +74,8 @@

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # 
-
+        gr.Markdown(" # [Switti](https://yandex-research.github.io/switti)")
+        gr.Markdown("[Learn more](https://yandex-research.github.io/switti) about Switti.")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",

@@ -81,59 +89,66 @@ with gr.Blocks(css=css) as demo:

        result = gr.Image(label="Result", show_label=False)

+        seed = gr.Number(
+            label="Seed",
+            minimum=0,
+            maximum=MAX_SEED,
+            value=0,
+        )
+
+        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+        guidance_scale = gr.Slider(
+            label="Guidance scale",
+            minimum=0.0,
+            maximum=10.,
+            step=0.5,
+            value=4.,
+        )
+
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
-                visible=
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
+                visible=True,
            )

-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
            with gr.Row():
-
-                    label="
-                    minimum=
-                    maximum=
-                    step=
-                    value=
+                top_k = gr.Slider(
+                    label="Sampling top k",
+                    minimum=10,
+                    maximum=1000,
+                    step=10,
+                    value=400,
                )
-
-
-
-
-
-
-                    value=1024,  # Replace with defaults that work for your model
+                top_p = gr.Slider(
+                    label="Sampling top p",
+                    minimum=0.0,
+                    maximum=1.,
+                    step=0.01,
+                    value=0.95,
                )
-
+
            with gr.Row():
-
-
-
-
-
-
+                more_smooth = gr.Checkbox(label="Smoothing with Gumbel softmax sampling", value=True)
+                smooth_start_si = gr.Slider(
+                    label="Smoothing starting scale",
+                    minimum=0,
+                    maximum=10,
+                    step=1,
+                    value=2,
                )
-
-
-
-
-                    maximum=50,
+                turn_off_cfg_start_si = gr.Slider(
+                    label="Disable CFG from scale",
+                    minimum=0,
+                    maximum=10,
                    step=1,
-                    value=
+                    value=8,
                )

-
+
+        gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,

@@ -142,10 +157,12 @@
            negative_prompt,
            seed,
            randomize_seed,
-            width,
-            height,
            guidance_scale,
-
+            top_k,
+            top_p,
+            more_smooth,
+            smooth_start_si,
+            turn_off_cfg_start_si,
        ],
        outputs=[result, seed],
    )
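The new app.py forwards every UI control to SwittiPipeline.__call__. For reference, a minimal sketch of driving the same pipeline outside Gradio; it assumes this Space's models package is on the import path and a GPU is available, and the argument values simply mirror the UI defaults above:

import torch
from models import SwittiPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = SwittiPipeline.from_pretrained("yresearch/Switti", device=device)

images = pipe(
    prompt="Cat as a wizard",
    null_prompt="",                # negative prompt used for classifier-free guidance
    cfg=4.0,                       # guidance scale
    top_k=400,
    top_p=0.95,
    more_smooth=True,              # Gumbel-softmax smoothing of sampled tokens
    smooth_start_si=2,             # first scale index at which smoothing is applied
    turn_off_cfg_start_si=8,       # scale index from which CFG is disabled
    seed=42,
)
images[0].save("switti_sample.png")  # return_pil=True by default, so these are PIL images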
models/__init__.py
ADDED
@@ -0,0 +1,95 @@
import torch.nn as nn

from .clip import FrozenCLIPEmbedder
from .switti import Switti
from .vqvae import VQVAE
from .pipeline import SwittiPipeline


def build_models(
    # Shared args
    device,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
    # VQVAE args
    V=4096,
    Cvae=32,
    ch=160,
    share_quant_resi=4,
    # Switti args
    depth=16,
    rope=True,
    rope_theta=10000,
    rope_size=128,
    use_swiglu_ffn=True,
    use_ar=False,
    use_crop_cond=True,
    attn_l2_norm=True,
    init_adaln=0.5,
    init_adaln_gamma=1e-5,
    init_head=0.02,
    init_std=-1,  # init_std < 0: automated
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dpr=0,
    norm_eps=1e-6,
    # pipeline args
    text_encoder_path="openai/clip-vit-large-patch14",
    text_encoder_2_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
) -> tuple[VQVAE, Switti, SwittiPipeline]:
    heads = depth
    width = depth * 64
    if dpr > 0:
        dpr = dpr * depth / 24

    # disable built-in initialization for speed
    for clz in (
        nn.Linear,
        nn.LayerNorm,
        nn.BatchNorm2d,
        nn.SyncBatchNorm,
        nn.Conv1d,
        nn.Conv2d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
    ):
        setattr(clz, "reset_parameters", lambda self: None)

    # build models
    vae_local = VQVAE(
        vocab_size=V,
        z_channels=Cvae,
        ch=ch,
        test_mode=True,
        share_quant_resi=share_quant_resi,
        v_patch_nums=patch_nums,
    ).to(device)

    switti_wo_ddp = Switti(
        depth=depth,
        embed_dim=width,
        num_heads=heads,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        drop_path_rate=dpr,
        norm_eps=norm_eps,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        rope=rope,
        rope_theta=rope_theta,
        rope_size=rope_size,
        use_swiglu_ffn=use_swiglu_ffn,
        use_ar=use_ar,
        use_crop_cond=use_crop_cond,
    ).to(device)

    switti_wo_ddp.init_weights(
        init_adaln=init_adaln,
        init_adaln_gamma=init_adaln_gamma,
        init_head=init_head,
        init_std=init_std,
    )
    text_encoder = FrozenCLIPEmbedder(text_encoder_path)
    text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
    pipe = SwittiPipeline(switti_wo_ddp, vae_local, text_encoder, text_encoder_2, device)

    return vae_local, switti_wo_ddp, pipe
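build_models constructs randomly initialized sub-models (the pretrained path is SwittiPipeline.from_pretrained in models/pipeline.py). A hedged usage sketch, assuming a CUDA machine since the text encoders default to the "cuda" device and both CLIP checkpoints are downloaded on first use:

from models import build_models

# Weights of the VQVAE and the Switti transformer are random here.
vae, switti, pipe = build_models(device="cuda", depth=16)
n_params = sum(p.numel() for p in switti.parameters()) / 1e6
print(f"Switti depth 16: {n_params:.0f}M parameters, embed_dim {16 * 64}")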
models/__pycache__/__init__.cpython-311.pyc        ADDED    Binary file (2.9 kB)
models/__pycache__/basic_switti.cpython-311.pyc    ADDED    Binary file (23.3 kB)
models/__pycache__/basic_vae.cpython-311.pyc       ADDED    Binary file (15.8 kB)
models/__pycache__/clip.cpython-311.pyc            ADDED    Binary file (3.01 kB)
models/__pycache__/helpers.cpython-311.pyc         ADDED    Binary file (5.29 kB)
models/__pycache__/pipeline.cpython-311.pyc        ADDED    Binary file (12.6 kB)
models/__pycache__/quant.cpython-311.pyc           ADDED    Binary file (24.6 kB)
models/__pycache__/rope.cpython-311.pyc            ADDED    Binary file (4.45 kB)
models/__pycache__/switti.cpython-311.pyc          ADDED    Binary file (23.3 kB)
models/__pycache__/vqvae.cpython-311.pyc           ADDED    Binary file (10.2 kB)
models/basic_switti.py
ADDED
@@ -0,0 +1,461 @@
import math
import warnings

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn.functional import scaled_dot_product_attention  # q, k, v: BHLc

from models.helpers import DropPath
from models.rope import apply_rotary_emb

try:
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError:
    fused_mlp_func = None

# this file only provides the blocks used in Switti transformer
__all__ = ["FFN", "SwiGLUFFN", "RMSNorm", "AdaLNSelfCrossAttn", "AdaLNBeforeHead"]


try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Initialize the RMSNorm normalization layer.

            Args:
                dim (int): The dimension of the input tensor.
                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

            Attributes:
                eps (float): A small value added to the denominator for numerical stability.
                weight (nn.Parameter): Learnable scaling parameter.
            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Apply the RMSNorm normalization to the input tensor.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The normalized tensor.
            """
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            """
            Forward pass through the RMSNorm layer.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying RMSNorm.
            """
            output = self._norm(x.float()).type_as(x)
            return output * self.weight


class FFN(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        drop=0.0,
        fused_if_available=True,
    ):
        super().__init__()
        self.fused_mlp_func = fused_mlp_func if fused_if_available else None
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()

    def forward(self, x):
        if self.fused_mlp_func is not None:
            return self.drop(
                self.fused_mlp_func(
                    x=x,
                    weight1=self.fc1.weight,
                    weight2=self.fc2.weight,
                    bias1=self.fc1.bias,
                    bias2=self.fc2.bias,
                    activation="gelu_approx",
                    save_pre_act=self.training,
                    return_residual=False,
                    checkpoint_lvl=0,
                    heuristic=0,
                    process_group=None,
                )
            )
        else:
            return self.drop(self.fc2(self.act(self.fc1(x))))

    def extra_repr(self) -> str:
        return f"fused_mlp_func={self.fused_mlp_func is not None}"


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim: int,
        ff_mult: float = 8 / 3,
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            ff_mult (float, optional): Custom multiplier for hidden dimension. Defaults to 8 / 3.
        """
        super().__init__()
        hidden_dim = int(dim * ff_mult)

        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.fused_mlp_func = None
        self._init()

    def _init(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    # @torch.compile
    def _forward_silu_gating(self, x_gate: torch.Tensor, x_up: torch.Tensor):
        return F.silu(x_gate) * x_up

    def forward(self, x: torch.Tensor):
        return self.down_proj(
            self._forward_silu_gating(self.gate_proj(x), self.up_proj(x))
        )

    def extra_repr(self) -> str:
        return f"fused_mlp_func={self.fused_mlp_func is not None}"


class CrossAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        context_dim: int = 2048,
        num_heads: int = 12,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qk_norm: bool = False,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        assert attn_drop == 0.0

        self.num_heads, self.head_dim = (
            num_heads,
            embed_dim // num_heads,
        )
        self.qk_norm = qk_norm
        self.scale = 1 / math.sqrt(self.head_dim)

        self.q_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False)
        self.k_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False)

        self.to_q = nn.Linear(embed_dim, embed_dim, bias=True)
        self.to_kv = nn.Linear(context_dim, embed_dim * 2, bias=True)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = (
            nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        )
        self.attn_drop = attn_drop

        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None

    def kv_caching(self, enable: bool):
        self.caching, self.cached_k, self.cached_v = enable, None, None

    def forward(self, x, context, context_attn_bias=None, freqs_cis=None):
        B, L, C = x.shape
        context_B, context_L, context_C = context.shape
        assert B == context_B

        q = self.to_q(x).view(B, L, -1)  # BLD , self.num_heads, self.head_dim)
        if self.qk_norm:
            q = self.q_norm(q)

        q = q.view(B, L, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # BHLc

        if self.cached_k is None:
            # not using caches or first scale inference
            kv = self.to_kv(context).view(B, context_L, 2, -1)  # qkv: BL3D
            k, v = kv.permute(2, 0, 1, 3).unbind(dim=0)  # q or k or v: BLHD

            if self.qk_norm:
                k = self.k_norm(k)

            k = k.view(B, context_L, self.num_heads, self.head_dim)
            k = k.permute(0, 2, 1, 3)  # BHLc

            v = v.view(B, context_L, self.num_heads, self.head_dim)
            v = v.permute(0, 2, 1, 3)  # BHLc

            if self.caching:
                self.cached_k = k
                self.cached_v = v
        else:
            k = self.cached_k
            v = self.cached_v

        if context_attn_bias is not None:
            context_attn_bias = rearrange(context_attn_bias, "b j -> b 1 1 j")

        dropout_p = self.attn_drop if self.training else 0.0
        out = (
            scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                scale=self.scale,
                attn_mask=context_attn_bias,
                dropout_p=dropout_p,
            )
            .transpose(1, 2)
            .reshape(B, L, C)
        )

        return self.proj_drop(self.proj(out))


class SelfAttention(nn.Module):
    def __init__(
        self,
        block_idx: int,
        embed_dim: int = 768,
        num_heads: int = 12,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qk_norm: bool = False,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx, self.num_heads, self.head_dim = (
            block_idx,
            num_heads,
            embed_dim // num_heads,
        )
        self.qk_norm = qk_norm
        self.scale = 1 / math.sqrt(self.head_dim)

        self.q_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False)
        self.k_norm = nn.LayerNorm(embed_dim, eps=1e-6, elementwise_affine=False)

        self.to_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = (
            nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        )
        self.attn_drop = attn_drop

        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None

    def kv_caching(self, enable: bool):
        self.caching, self.cached_k, self.cached_v = enable, None, None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, attn_bias, freqs_cis: torch.Tensor = None):
        B, L, C = x.shape

        qkv = self.to_qkv(x).view(B, L, 3, -1)
        q, k, v = qkv.permute(2, 0, 1, 3).unbind(dim=0)  # q or k or v: BLD

        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        q = q.view(B, L, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # BHLc
        k = k.view(B, L, self.num_heads, self.head_dim)
        k = k.permute(0, 2, 1, 3)  # BHLc
        v = v.view(B, L, self.num_heads, self.head_dim)
        v = v.permute(0, 2, 1, 3)  # BHLc
        dim_cat = 2

        if freqs_cis is not None:
            q = apply_rotary_emb(q, freqs_cis=freqs_cis)
            k = apply_rotary_emb(k, freqs_cis=freqs_cis)

        if self.caching:
            if self.cached_k is None:
                self.cached_k = k
                self.cached_v = v
            else:
                k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat)
                v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)

        dropout_p = self.attn_drop if self.training else 0.0
        out = (
            scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                scale=self.scale,
                attn_mask=attn_bias,
                dropout_p=dropout_p,
            )
            .transpose(1, 2)
            .reshape(B, L, C)
        )

        return self.proj_drop(self.proj(out))

    def extra_repr(self) -> str:
        return f"attn_l2_norm={self.qk_norm}"


class AdaLNSelfCrossAttn(nn.Module):
    def __init__(
        self,
        block_idx,
        last_drop_p,
        embed_dim,
        cond_dim,
        num_heads,
        mlp_ratio=4.0,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        qk_norm=False,
        context_dim=None,
        use_swiglu_ffn=False,
        norm_eps=1e-6,
        use_crop_cond=False,
    ):
        super().__init__()
        assert attn_drop == 0.0
        assert qk_norm

        self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
        self.C, self.D = embed_dim, cond_dim
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.attn = SelfAttention(
            block_idx=block_idx,
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
        )

        if context_dim:
            self.cross_attn = CrossAttention(
                embed_dim=embed_dim,
                context_dim=context_dim,
                num_heads=num_heads,
                attn_drop=attn_drop,
                proj_drop=drop,
                qk_norm=qk_norm,
            )
        else:
            self.cross_attn = None

        if use_swiglu_ffn:
            self.ffn = SwiGLUFFN(dim=embed_dim)
        else:
            self.ffn = FFN(
                in_features=embed_dim,
                hidden_features=round(embed_dim * mlp_ratio),
                drop=drop,
            )

        self.self_attention_norm1 = RMSNorm(embed_dim, eps=norm_eps)
        self.self_attention_norm2 = RMSNorm(embed_dim, eps=norm_eps)
        self.cross_attention_norm1 = RMSNorm(embed_dim, eps=norm_eps)
        self.cross_attention_norm2 = RMSNorm(embed_dim, eps=norm_eps)

        self.ffn_norm1 = RMSNorm(embed_dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(embed_dim, eps=norm_eps)

        self.attention_y_norm = RMSNorm(context_dim, eps=norm_eps)

        # AdaLN
        lin = nn.Linear(cond_dim, 6 * embed_dim)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)

        self.fused_add_norm_fn = None

        self.use_crop_cond = use_crop_cond
        if use_crop_cond:
            self.crop_cond_scales = nn.Parameter(torch.zeros(1, cond_dim))

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(
        self,
        x,
        cond_BD,
        attn_bias,
        crop_cond=None,
        context=None,
        context_attn_bias=None,
        freqs_cis=None,
    ):  # C: embed_dim, D: cond_dim

        if self.use_crop_cond:
            assert crop_cond is not None
            cond_BD = cond_BD + self.crop_cond_scales * crop_cond

        gamma1, gamma2, scale1, scale2, shift1, shift2 = (
            self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
        )
        x = x + self.self_attention_norm2(
            self.attn(
                self.self_attention_norm1(x).mul(scale1.add(1)).add(shift1),
                attn_bias=attn_bias,
                freqs_cis=freqs_cis,
            )
        ).mul(gamma1)
        if context is not None:
            x = x + self.cross_attention_norm2(
                self.cross_attn(
                    self.cross_attention_norm1(x),
                    self.attention_y_norm(context),
                    context_attn_bias=context_attn_bias,
                    freqs_cis=freqs_cis,
                )
            )
        x = x + self.ffn_norm2(
            self.ffn(self.ffn_norm1(x).mul(scale2.add(1)).add(shift2))
        ).mul(gamma2)
        return x


class AdaLNBeforeHead(nn.Module):
    def __init__(self, C, D, norm_layer):  # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2 * C))

    def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
        scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
        return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
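Each AdaLNSelfCrossAttn block derives six modulation vectors (two gains, two scales, two shifts) from the pooled conditioning vector and applies them around self-attention and the FFN, with cross-attention to the text context in between. A standalone shape sketch with illustrative sizes, not the production config; qk_norm must be True per the assert above, and when apex is missing the vanilla RMSNorm fallback is used:

import torch
from models.basic_switti import AdaLNSelfCrossAttn

block = AdaLNSelfCrossAttn(
    block_idx=0,
    last_drop_p=0.0,
    embed_dim=512,
    cond_dim=512,
    num_heads=8,
    qk_norm=True,
    context_dim=2048,   # in the real model: CLIP-L and CLIP-bigG hidden states concatenated
    use_swiglu_ffn=True,
    use_crop_cond=False,
)

x = torch.randn(2, 16, 512)         # (B, L, C) token map for one scale
cond = torch.randn(2, 512)          # pooled text conditioning
context = torch.randn(2, 77, 2048)  # per-token text context for cross-attention
out = block(x, cond_BD=cond, attn_bias=None, context=context)
print(out.shape)                    # torch.Size([2, 16, 512])

With use_crop_cond=True the block would additionally expect a crop_cond tensor of shape (B, cond_dim).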
models/basic_vae.py
ADDED
@@ -0,0 +1,289 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# this file only provides the 2 modules used in VQVAE
__all__ = ["Encoder", "Decoder"]


"""
References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
"""


# swish
def nonlinearity(x):
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample2x(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))


class Downsample2x(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x):
        return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode="constant", value=0))


class ResnetBlock(nn.Module):
    def __init__(
        self, *, in_channels, out_channels=None, dropout
    ):  # conv_shortcut=False,  # conv_shortcut: always False in VAE
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x), inplace=True))
        h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
        return self.nin_shortcut(x) + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.C = in_channels

        self.norm = Normalize(in_channels)
        self.qkv = torch.nn.Conv2d(
            in_channels, 3 * in_channels, kernel_size=1, stride=1, padding=0
        )
        self.w_ratio = int(in_channels) ** (-0.5)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        qkv = self.qkv(self.norm(x))
        B, _, H, W = qkv.shape  # should be B,3C,H,W
        C = self.C
        q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)

        # compute attention
        q = q.view(B, C, H * W).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # B,HW,C
        k = k.view(B, C, H * W).contiguous()  # B,C,HW
        w = torch.bmm(q, k).mul_(self.w_ratio)  # B,HW,HW
        # w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
        w = F.softmax(w, dim=2)

        # attend to values
        v = v.view(B, C, H * W).contiguous()
        w = w.permute(0, 2, 1).contiguous()  # B,HW,HW (first HW of k, second of q)
        h = torch.bmm(v, w)  # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
        h = h.view(B, C, H, W).contiguous()

        return x + self.proj_out(h)


def make_attn(in_channels, using_sa=True):
    return AttnBlock(in_channels) if using_sa else nn.Identity()


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch=128,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks=2,
        dropout=0.0,
        in_channels=3,
        z_channels,
        double_z=False,
        using_sa=True,
        using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.downsample_ratio = 2 ** (self.num_resolutions - 1)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, dropout=dropout
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample2x(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, dropout=dropout
        )
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, dropout=dropout
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            (2 * z_channels if double_z else z_channels),
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch=128,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks=2,
        dropout=0.0,
        in_channels=3,  # in_channels: raw img channels
        z_channels,
        using_sa=True,
        using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, dropout=dropout
        )
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, dropout=dropout
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, dropout=dropout
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample2x(block_in)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # z to block_in
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
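The Encoder downsamples by a factor of 2 ** (len(ch_mult) - 1) and the Decoder mirrors it. A small shape sketch with an illustrative config (not the pretrained VQVAE-Switti sizes):

import torch
from models.basic_vae import Encoder, Decoder

enc = Encoder(ch=64, ch_mult=(1, 2, 4), num_res_blocks=1, z_channels=32, double_z=False)
dec = Decoder(ch=64, ch_mult=(1, 2, 4), num_res_blocks=1, z_channels=32)

x = torch.randn(1, 3, 64, 64)
z = enc(x)        # (1, 32, 16, 16): spatial size divided by 2 ** (3 - 1)
x_rec = dec(z)    # (1, 3, 64, 64)
print(z.shape, x_rec.shape)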
models/clip.py
ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version).to(device)
        self.device = device
        self.hidden_size = self.transformer.config.hidden_size
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

        outputs = self.transformer(**batch_encoding)

        attn_bias = batch_encoding["attention_mask"].to(outputs["last_hidden_state"].dtype)
        attn_bias[attn_bias == 0] = -float("inf")
        attn_bias[attn_bias == 1] = 0.0
        outputs["attn_bias"] = attn_bias
        return outputs

    @torch.no_grad()
    def encode(self, text):
        return self(text)
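The wrapper returns the regular transformers text-model output with an extra attn_bias field: 0 for real tokens and -inf for padding, ready to be added to attention logits. A usage sketch; the checkpoint is downloaded from the Hub on first use, and device="cpu" keeps it runnable without a GPU:

from models.clip import FrozenCLIPEmbedder

text_encoder = FrozenCLIPEmbedder("openai/clip-vit-large-patch14", device="cpu")
out = text_encoder.encode(["Cat as a wizard", "Phoenix woman"])
print(out.last_hidden_state.shape)  # torch.Size([2, 77, 768])
print(out.attn_bias.shape)          # torch.Size([2, 77]), additive mask of 0 / -inf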
models/helpers.py
ADDED
@@ -0,0 +1,93 @@
import torch
from torch import nn as nn
from torch.nn import functional as F


def sample_with_top_k_top_p_(
    logits_BlV: torch.Tensor,
    top_k: int = 0,
    top_p: float = 0.0,
    rng=None,
    num_samples=1,
) -> torch.Tensor:  # return idx, shaped (B, l)
    B, l, V = logits_BlV.shape
    if top_k > 0:
        idx_to_remove = logits_BlV < logits_BlV.topk(
            top_k, largest=True, sorted=False, dim=-1
        )[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
    if top_p > 0:
        sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        sorted_idx_to_remove[..., -1:] = False
        logits_BlV.masked_fill_(
            sorted_idx_to_remove.scatter(
                sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove
            ),
            -torch.inf,
        )
    # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    return torch.multinomial(
        logits_BlV.softmax(dim=-1).view(-1, V),
        num_samples=num_samples,
        replacement=replacement,
        generator=rng,
    ).view(B, l, num_samples)


def gumbel_softmax_with_rng(
    logits: torch.Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
    rng: torch.Generator | None = None,
) -> torch.Tensor:
    if rng is None:
        return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)

    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
        .exponential_(generator=rng)
        .log()
    )
    gumbels = (logits + gumbels) / tau
    y_soft = gumbels.softmax(dim)

    if hard:
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret


def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):  # taken from timm
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):  # taken from timm
    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"(drop_prob=...)"
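Note the trailing underscore in sample_with_top_k_top_p_: the function masks the logits tensor in place before sampling, so pass a copy if the original logits are still needed. A standalone sketch on random logits:

import torch
from models.helpers import sample_with_top_k_top_p_

logits = torch.randn(2, 5, 100)  # (B, l, V)
idx = sample_with_top_k_top_p_(logits.clone(), top_k=40, top_p=0.9, num_samples=1)
print(idx.shape)                 # torch.Size([2, 5, 1]) sampled token indices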
models/pipeline.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision.transforms import ToPILImage
|
3 |
+
from PIL.Image import Image as PILImage
|
4 |
+
|
5 |
+
from models.vqvae import VQVAEHF
|
6 |
+
from models.clip import FrozenCLIPEmbedder
|
7 |
+
from models.switti import SwittiHF, get_crop_condition
|
8 |
+
from models.helpers import sample_with_top_k_top_p_, gumbel_softmax_with_rng
|
9 |
+
|
10 |
+
|
11 |
+
class SwittiPipeline:
|
12 |
+
vae_path = "yresearch/VQVAE-Switti"
|
13 |
+
text_encoder_path = "openai/clip-vit-large-patch14"
|
14 |
+
text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
15 |
+
|
16 |
+
def __init__(self, switti, vae, text_encoder, text_encoder_2, device):
|
17 |
+
self.switti = switti
|
18 |
+
self.vae = vae
|
19 |
+
self.text_encoder = text_encoder
|
20 |
+
self.text_encoder_2 = text_encoder_2
|
21 |
+
|
22 |
+
self.switti.eval()
|
23 |
+
self.vae.eval()
|
24 |
+
|
25 |
+
self.device = device
|
26 |
+
|
27 |
+
@classmethod
|
28 |
+
def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
|
29 |
+
switti = SwittiHF.from_pretrained(pretrained_model_name_or_path).to(device)
|
30 |
+
vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
|
31 |
+
text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
|
32 |
+
text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
|
33 |
+
|
34 |
+
return cls(switti, vae, text_encoder, text_encoder_2, device)
|
35 |
+
|
36 |
+
@staticmethod
|
37 |
+
def to_image(tensor):
|
38 |
+
return [ToPILImage()(
|
39 |
+
(255 * img.cpu().detach()).to(torch.uint8))
|
40 |
+
for img in tensor]
|
41 |
+
|
42 |
+
def _encode_prompt(self, prompt: str | list[str]):
|
43 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
44 |
+
encodings = [
|
45 |
+
self.text_encoder.encode(prompt),
|
46 |
+
self.text_encoder_2.encode(prompt),
|
47 |
+
]
|
48 |
+
prompt_embeds = torch.concat(
|
49 |
+
[encoding.last_hidden_state for encoding in encodings], dim=-1
|
50 |
+
)
|
51 |
+
pooled_prompt_embeds = encodings[-1].pooler_output
|
52 |
+
attn_bias = encodings[-1].attn_bias
|
53 |
+
|
54 |
+
return prompt_embeds, pooled_prompt_embeds, attn_bias
|
55 |
+
|
56 |
+
def encode_prompt(
|
57 |
+
self,
|
58 |
+
prompt: str | list[str],
|
59 |
+
null_prompt: str = "",
|
60 |
+
encode_null: bool = True,
|
61 |
+
):
|
62 |
+
prompt_embeds, pooled_prompt_embeds, attn_bias = self._encode_prompt(prompt)
|
63 |
+
if encode_null:
|
64 |
+
B, L, hidden_dim = prompt_embeds.shape
|
65 |
+
pooled_dim = pooled_prompt_embeds.shape[1]
|
66 |
+
|
67 |
+
null_embeds, null_pooled_embeds, null_attn_bias = self._encode_prompt(null_prompt)
|
68 |
+
|
69 |
+
null_embeds = null_embeds[:, :L].expand(B, L, hidden_dim).to(prompt_embeds.device)
|
70 |
+
null_pooled_embeds = null_pooled_embeds.expand(B, pooled_dim).to(pooled_prompt_embeds.device)
|
71 |
+
null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device)
|
72 |
+
|
73 |
+
prompt_embeds = torch.cat([prompt_embeds, null_embeds], dim=0)
|
74 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, null_pooled_embeds], dim=0)
|
75 |
+
attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0)
|
76 |
+
|
77 |
+
return prompt_embeds, pooled_prompt_embeds, attn_bias
|
78 |
+
|
79 |
+
@torch.inference_mode()
|
80 |
+
def __call__(
|
81 |
+
self,
|
82 |
+
prompt: str | list[str],
|
83 |
+
null_prompt: str = "",
|
84 |
+
seed: int | None = None,
|
85 |
+
cfg: float = 4.0,
|
86 |
+
top_k: int = 400,
|
87 |
+
top_p: float = 0.95,
|
88 |
+
more_smooth: bool = False,
|
89 |
+
return_pil: bool = True,
|
90 |
+
smooth_start_si: int = 0,
|
91 |
+
turn_off_cfg_start_si: int = 10,
|
92 |
+
image_size: tuple[int, int] = (512, 512),
|
93 |
+
) -> torch.Tensor | list[PILImage]:
|
94 |
+
"""
|
95 |
+
only used for inference, on autoregressive mode
|
96 |
+
:param prompt: text prompt to generate an image
|
97 |
+
:param null_prompt: negative prompt for CFG
|
98 |
+
:param seed: random seed
|
99 |
+
:param cfg: classifier-free guidance ratio
|
100 |
+
:param top_k: top-k sampling
|
101 |
+
:param top_p: top-p sampling
|
102 |
+
:param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
|
103 |
+
        :return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
        """
        assert not self.switti.training
        switti = self.switti
        vae = self.vae
        vae_quant = self.vae.quantize
        if seed is None:
            rng = None
        else:
            switti.rng.manual_seed(seed)
            rng = switti.rng

        context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)

        B = context.shape[0] // 2

        cond_vector = switti.text_pooler(cond_vector)

        if switti.use_crop_cond:
            crop_coords = get_crop_condition(2 * B * [image_size[0]],
                                             2 * B * [image_size[1]],
                                             ).to(cond_vector.device)
            crop_embed = switti.crop_embed(crop_coords.view(-1)).reshape(2 * B, switti.D)
            crop_cond = switti.crop_proj(crop_embed)
        else:
            crop_cond = None

        sos = cond_BD = cond_vector

        lvl_pos = switti.lvl_embed(switti.lvl_1L)
        if not switti.rope:
            lvl_pos += switti.pos_1LC
        next_token_map = (
            sos.unsqueeze(1)
            + switti.pos_start.expand(2 * B, switti.first_l, -1)
            + lvl_pos[:, : switti.first_l]
        )
        cur_L = 0
        f_hat = sos.new_zeros(B, switti.Cvae, switti.patch_nums[-1], switti.patch_nums[-1])

        for b in switti.blocks:
            b.attn.kv_caching(switti.use_ar)  # use KV caching if switti is in the AR mode
            b.cross_attn.kv_caching(True)

        for si, pn in enumerate(switti.patch_nums):  # si: i-th segment
            ratio = si / switti.num_stages_minus_1
            x_BLC = next_token_map

            if switti.rope:
                freqs_cis = switti.freqs_cis[:, cur_L : cur_L + pn * pn]
            else:
                freqs_cis = switti.freqs_cis

            if si >= turn_off_cfg_start_si:
                x_BLC = x_BLC[:B]
                context = context[:B]
                context_attn_bias = context_attn_bias[:B]
                freqs_cis = freqs_cis[:B]
                cond_BD = cond_BD[:B]
                if crop_cond is not None:
                    crop_cond = crop_cond[:B]
                for b in switti.blocks:
                    if b.attn.caching:
                        b.attn.cached_k = b.attn.cached_k[:B]
                        b.attn.cached_v = b.attn.cached_v[:B]
                    if b.cross_attn.caching:
                        b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
                        b.cross_attn.cached_v = b.cross_attn.cached_v[:B]

            for block in switti.blocks:
                x_BLC = block(
                    x=x_BLC,
                    cond_BD=cond_BD,
                    attn_bias=None,
                    context=context,
                    context_attn_bias=context_attn_bias,
                    freqs_cis=freqs_cis,
                    crop_cond=crop_cond,
                )
            cur_L += pn * pn

            logits_BlV = switti.get_logits(x_BLC, cond_BD)

            # guidance
            if si < turn_off_cfg_start_si:
                t = cfg * ratio
                logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

            if more_smooth and si >= smooth_start_si:
                # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)  # refer to mask-git
                idx_Bl = gumbel_softmax_with_rng(
                    logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng,
                )
                h_BChw = idx_Bl @ vae_quant.embedding.weight.unsqueeze(0)
            else:
                # default nucleus sampling
                idx_Bl = sample_with_top_k_top_p_(
                    logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1,
                )[:, :, 0]
                h_BChw = vae_quant.embedding(idx_Bl)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, switti.Cvae, pn, pn)
            f_hat, next_token_map = vae_quant.get_next_autoregressive_input(
                si, len(switti.patch_nums), f_hat, h_BChw,
            )
            if si != switti.num_stages_minus_1:  # prepare for the next stage
                next_token_map = next_token_map.view(B, switti.Cvae, -1).transpose(1, 2)
                next_token_map = (
                    switti.word_embed(next_token_map)
                    + lvl_pos[:, cur_L : cur_L + switti.patch_nums[si + 1] ** 2]
                )
                # double the batch size due to CFG
                next_token_map = next_token_map.repeat(2, 1, 1)

        for b in switti.blocks:
            b.attn.kv_caching(False)
            b.cross_attn.kv_caching(False)

        # de-normalize, from [-1, 1] to [0, 1]
        img = vae.fhat_to_img(f_hat).add(1).mul(0.5)
        if return_pil:
            img = self.to_image(img)

        return img
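The guidance step in the sampling loop above blends the conditional and unconditional halves of the batch with a scale that grows linearly across scales (t = cfg * ratio), and CFG is dropped entirely from scale turn_off_cfg_start_si onward. A minimal torch-only sketch of that blending on dummy tensors (batch size, vocabulary size, and the cfg value are illustrative assumptions, not taken from this commit):

import torch

# first B rows are conditional logits, last B rows unconditional (batch doubled for CFG)
B, l, V = 2, 4, 16
logits = torch.randn(2 * B, l, V)

cfg, ratio = 6.0, 0.5                              # ratio = si / (num_scales - 1)
t = cfg * ratio
guided = (1 + t) * logits[:B] - t * logits[B:]     # same formula as in the loop above
print(guided.shape)                                # torch.Size([2, 4, 16])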
models/quant.py
ADDED
@@ -0,0 +1,398 @@
import math
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import distributed as tdist
from torch import nn as nn
from torch.nn import functional as F

# this file only provides the VectorQuantizer2 used in VQVAE
__all__ = ["VectorQuantizer2"]


class VectorQuantizer2(nn.Module):
    # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
    def __init__(
        self,
        vocab_size,
        Cvae,
        using_znorm,
        beta: float = 0.25,
        default_qresi_counts=0,
        v_patch_nums=None,
        quant_resi=0.5,
        share_quant_resi=4,  # share_quant_resi: args.qsr
    ):
        super().__init__()
        self.vocab_size: int = vocab_size
        self.Cvae: int = Cvae
        self.using_znorm: bool = using_znorm
        self.v_patch_nums: Tuple[int] = v_patch_nums

        self.quant_resi_ratio = quant_resi
        if share_quant_resi == 0:  # non-shared: \phi_{1 to K} for K scales
            self.quant_resi = PhiNonShared(
                [
                    (Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
                    for _ in range(default_qresi_counts or len(self.v_patch_nums))
                ]
            )
        elif share_quant_resi == 1:  # fully shared: only a single \phi for K scales
            self.quant_resi = PhiShared(
                Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()
            )
        else:  # partially shared: \phi_{1 to share_quant_resi} for K scales
            self.quant_resi = PhiPartiallyShared(
                nn.ModuleList([(
                    Phi(Cvae, quant_resi)
                    if abs(quant_resi) > 1e-6
                    else nn.Identity()
                ) for _ in range(share_quant_resi)])
            )

        self.register_buffer(
            "ema_vocab_hit_SV",
            torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0),
        )
        self.record_hit = 0

        self.beta: float = beta
        self.embedding = nn.Embedding(self.vocab_size, self.Cvae)

    def eini(self, eini):
        if eini > 0:
            nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
        elif eini < 0:
            self.embedding.weight.data.uniform_(
                -abs(eini) / self.vocab_size, abs(eini) / self.vocab_size
            )

    def extra_repr(self) -> str:
        return f"{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}"

    # ===================== `forward` is only used in VAE training =====================
    def forward(
        self, f_BChw: torch.Tensor, ret_usages=False
    ) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
        dtype = f_BChw.dtype
        if dtype != torch.float32:
            f_BChw = f_BChw.float()
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()

        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        with torch.cuda.amp.autocast(enabled=False):
            mean_vq_loss: torch.Tensor = 0.0
            vocab_hit_V = torch.zeros(
                self.vocab_size, dtype=torch.float, device=f_BChw.device
            )
            SN = len(self.v_patch_nums)
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                # find the nearest embedding
                if self.using_znorm:
                    rest_NC = (
                        F.interpolate(f_rest, size=(pn, pn), mode="area")
                        .permute(0, 2, 3, 1)
                        .reshape(-1, C)
                        if (si != SN - 1)
                        else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                    )
                    rest_NC = F.normalize(rest_NC, dim=-1)
                    idx_N = torch.argmax(
                        rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0),
                        dim=1,
                    )
                else:
                    rest_NC = (
                        F.interpolate(f_rest, size=(pn, pn), mode="area")
                        .permute(0, 2, 3, 1)
                        .reshape(-1, C)
                        if (si != SN - 1)
                        else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                    )
                    d_no_grad = torch.sum(
                        rest_NC.square(), dim=1, keepdim=True
                    ) + torch.sum(
                        self.embedding.weight.data.square(), dim=1, keepdim=False
                    )
                    d_no_grad.addmm_(
                        rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1
                    )  # (B*h*w, vocab_size)
                    idx_N = torch.argmin(d_no_grad, dim=1)

                hit_V = idx_N.bincount(minlength=self.vocab_size).float()
                if self.training:
                    # if dist.initialized():
                    handler = tdist.all_reduce(hit_V, async_op=True)

                # calc loss
                idx_Bhw = idx_N.view(B, pn, pn)
                h_BChw = (
                    F.interpolate(
                        self.embedding(idx_Bhw).permute(0, 3, 1, 2),
                        size=(H, W),
                        mode="bicubic",
                    ).contiguous()
                    if (si != SN - 1)
                    else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
                )
                h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
                f_hat = f_hat + h_BChw
                f_rest -= h_BChw

                if self.training:  # and dist.initialized():
                    handler.wait()
                    if self.record_hit == 0:
                        self.ema_vocab_hit_SV[si].copy_(hit_V)
                    elif self.record_hit < 100:
                        self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
                    else:
                        self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
                    self.record_hit += 1
                vocab_hit_V.add_(hit_V)
                mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)

            mean_vq_loss *= 1.0 / SN
            f_hat = (f_hat.data - f_no_grad).add_(f_BChw)

        margin = (
            tdist.get_world_size()
            * (f_BChw.numel() / f_BChw.shape[1])
            / self.vocab_size
            * 0.08
        )
        # margin = pn*pn / 100
        if ret_usages:
            usages = [
                (self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100
                for si, pn in enumerate(self.v_patch_nums)
            ]
        else:
            usages = None
        return f_hat, usages, mean_vq_loss

    # ===================== `forward` is only used in VAE training =====================

    def embed_to_fhat(
        self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False
    ) -> Union[List[torch.Tensor], torch.Tensor]:
        ls_f_hat_BChw = []
        B = ms_h_BChw[0].shape[0]
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)
        if all_to_max_scale:
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                h_BChw = ms_h_BChw[si]
                if si < len(self.v_patch_nums) - 1:
                    h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic")
                h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
                f_hat.add_(h_BChw)
                if last_one:
                    ls_f_hat_BChw = f_hat
                else:
                    ls_f_hat_BChw.append(f_hat.clone())
        else:
            # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
            # WARNING: this should only be used for experimental purpose
            f_hat = ms_h_BChw[0].new_zeros(
                B,
                self.Cvae,
                self.v_patch_nums[0],
                self.v_patch_nums[0],
                dtype=torch.float32,
            )
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                f_hat = F.interpolate(f_hat, size=(pn, pn), mode="bicubic")
                h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
                f_hat.add_(h_BChw)
                if last_one:
                    ls_f_hat_BChw = f_hat
                else:
                    ls_f_hat_BChw.append(f_hat)

        return ls_f_hat_BChw

    def f_to_idxBl_or_fhat(
        self,
        f_BChw: torch.Tensor,
        to_fhat: bool,
        v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
        noise_std: Optional[float] = None,
    ) -> List[Union[torch.Tensor, torch.LongTensor]]:  # z_BChw is the feature from inp_img_no_grad
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        f_hat_or_idx_Bl: List[torch.Tensor] = []

        patch_hws = [
            (pn, pn) if isinstance(pn, int) else (pn[0], pn[1])
            for pn in (v_patch_nums or self.v_patch_nums)
        ]  # from small to large
        assert (
            patch_hws[-1][0] == H and patch_hws[-1][1] == W
        ), f"{patch_hws[-1]=} != ({H=}, {W=})"

        SN = len(patch_hws)
        for si, (ph, pw) in enumerate(patch_hws):  # from small to large
            # find the nearest embedding
            z_NC = (
                F.interpolate(f_rest, size=(ph, pw), mode="area")
                .permute(0, 2, 3, 1)
                .reshape(-1, C)
                if (si != SN - 1)
                else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
            )
            if noise_std is not None:
                z_NC = math.sqrt(1 - noise_std ** 2) * z_NC + torch.randn_like(z_NC) * noise_std

            if self.using_znorm:
                z_NC = F.normalize(z_NC, dim=-1)
                idx_N = torch.argmax(
                    z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1
                )
            else:
                d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(
                    self.embedding.weight.data.square(), dim=1, keepdim=False
                )
                d_no_grad.addmm_(
                    z_NC, self.embedding.weight.data.T, alpha=-2, beta=1
                )  # (B*h*w, vocab_size)
                idx_N = torch.argmin(d_no_grad, dim=1)

            idx_Bhw = idx_N.view(B, ph, pw)
            h_BChw = (
                F.interpolate(
                    self.embedding(idx_Bhw).permute(0, 3, 1, 2),
                    size=(H, W),
                    mode="bicubic",
                ).contiguous()
                if (si != SN - 1)
                else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
            )
            h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
            f_hat.add_(h_BChw)
            f_rest.sub_(h_BChw)
            f_hat_or_idx_Bl.append(
                f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw)
            )

        return f_hat_or_idx_Bl

    # ===================== idxBl_to_switti_input: only used in Switti training, for getting teacher-forcing input =====================
    def idxBl_to_switti_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
        next_scales = []
        B = gt_ms_idx_Bl[0].shape[0]
        C = self.Cvae
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)

        f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
        pn_next: int = self.v_patch_nums[0]
        for si in range(SN - 1):
            h_BChw = F.interpolate(
                self.embedding(gt_ms_idx_Bl[si])
                .transpose_(1, 2)
                .view(B, C, pn_next, pn_next),
                size=(H, W),
                mode="bicubic",
            )
            f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
            pn_next = self.v_patch_nums[si + 1]
            next_scales.append(
                F.interpolate(f_hat, size=(pn_next, pn_next), mode="area")
                .view(B, C, -1)
                .transpose(1, 2)
            )
        # cat BlCs to BLC, this should be float32
        return torch.cat(next_scales, dim=1) if len(next_scales) else None

    # ===================== get_next_autoregressive_input: only used in Switti inference, for getting next step's input =====================
    def get_next_autoregressive_input(
        self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:  # only used in Switti inference
        HW = self.v_patch_nums[-1]
        if si != SN - 1:
            h = self.quant_resi[si / (SN - 1)](
                F.interpolate(h_BChw, size=(HW, HW), mode="bicubic")
            )  # conv after upsample
            f_hat.add_(h)
            return f_hat, F.interpolate(
                f_hat,
                size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
                mode="area",
            )
        else:
            h = self.quant_resi[si / (SN - 1)](h_BChw)
            f_hat.add_(h)
            return f_hat, f_hat


class Phi(nn.Conv2d):
    def __init__(self, embed_dim, quant_resi):
        ks = 3
        super().__init__(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=ks,
            stride=1,
            padding=ks // 2,
        )
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(
            self.resi_ratio
        )


class PhiShared(nn.Module):
    def __init__(self, qresi: Phi):
        super().__init__()
        self.qresi: Phi = qresi

    def __getitem__(self, _) -> Phi:
        return self.qresi


class PhiPartiallyShared(nn.Module):
    def __init__(self, qresi_ls: nn.ModuleList):
        super().__init__()
        self.qresi_ls = qresi_ls
        K = len(qresi_ls)
        self.ticks = (
            np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)
            if K == 4
            else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
        )

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]

    def extra_repr(self) -> str:
        return f"ticks={self.ticks}"


class PhiNonShared(nn.ModuleList):
    def __init__(self, qresi: List):
        super().__init__(qresi)
        # self.qresi = qresi
        K = len(qresi)
        self.ticks = (
            np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)
            if K == 4
            else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
        )

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        return super().__getitem__(
            np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()
        )

    def extra_repr(self) -> str:
        return f"ticks={self.ticks}"
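VectorQuantizer2 quantizes a feature map coarse-to-fine: at each scale the remaining residual is downsampled, matched to the nearest codebook entry, upsampled back to full resolution, and subtracted before moving to the next scale. The torch-only sketch below illustrates that loop with a random codebook and made-up sizes; it omits the Phi residual convolution and uses a plain cdist lookup instead of the addmm distance trick, so it is an illustration of the idea rather than the class above:

import torch
import torch.nn.functional as F

C, V, patch_nums = 8, 64, (1, 2, 4)
codebook = torch.randn(V, C)                      # stand-in for the nn.Embedding weights
f = torch.randn(1, C, 4, 4)                       # feature map to quantize
f_rest, f_hat = f.clone(), torch.zeros_like(f)

for pn in patch_nums:
    z = F.interpolate(f_rest, size=(pn, pn), mode="area") if pn != 4 else f_rest
    z_NC = z.permute(0, 2, 3, 1).reshape(-1, C)
    idx = torch.cdist(z_NC, codebook).argmin(dim=1)              # nearest codebook entry
    h = codebook[idx].reshape(1, pn, pn, C).permute(0, 3, 1, 2)
    h = F.interpolate(h, size=(4, 4), mode="bicubic") if pn != 4 else h
    f_hat += h                                                   # accumulate the approximation
    f_rest -= h                                                  # quantize only what is left
print((f - f_hat).abs().mean())                                  # residual error after all scales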
models/rope.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def init_t_xy(end_x: int, end_y: int):
|
5 |
+
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
6 |
+
t_x = (t % end_x).float()
|
7 |
+
t_y = torch.div(t, end_x, rounding_mode="floor").float()
|
8 |
+
return t_x, t_y
|
9 |
+
|
10 |
+
|
11 |
+
def compute_axial_cis(
|
12 |
+
dim: int, end_x: int, end_y: int, theta: float = 100.0, norm_coeff: int = 1
|
13 |
+
):
|
14 |
+
freqs_x = (
|
15 |
+
1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
16 |
+
* norm_coeff
|
17 |
+
)
|
18 |
+
freqs_y = (
|
19 |
+
1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
20 |
+
* norm_coeff
|
21 |
+
)
|
22 |
+
|
23 |
+
t_x, t_y = init_t_xy(end_x, end_y)
|
24 |
+
freqs_x = torch.outer(t_x, freqs_x)
|
25 |
+
freqs_y = torch.outer(t_y, freqs_y)
|
26 |
+
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
|
27 |
+
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
|
28 |
+
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
|
29 |
+
|
30 |
+
|
31 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
32 |
+
ndim = x.ndim
|
33 |
+
assert 0 <= 1 < ndim
|
34 |
+
freqs_cis = freqs_cis[:, x.shape[1], ...]
|
35 |
+
if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
|
36 |
+
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
|
37 |
+
elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
|
38 |
+
shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)]
|
39 |
+
return freqs_cis.view(*shape)
|
40 |
+
|
41 |
+
|
42 |
+
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor):
|
43 |
+
with torch.cuda.amp.autocast(enabled=False):
|
44 |
+
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
45 |
+
# freqs_cis = reshape_for_broadcast(freqs_cis, x).to(x_in.device)
|
46 |
+
freqs_cis = freqs_cis[None, :, : x.shape[2], ...].to(x_in.device)
|
47 |
+
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
48 |
+
return x_out.type_as(x_in)
|
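compute_axial_cis builds one complex rotation per token from its (x, y) grid coordinate, and apply_rotary_emb rotates query/key features by pairing adjacent channels into complex numbers. A quick shape check of how the two functions fit together, assuming the repo root is on PYTHONPATH so models.rope is importable (the head dimension of 16 and the 4x4 grid are only example values; the head dim must be divisible by 4):

import torch
from models.rope import apply_rotary_emb, compute_axial_cis  # functions defined above

head_dim, pn = 16, 4
freqs_cis = compute_axial_cis(dim=head_dim, end_x=pn, end_y=pn)[None]  # (1, pn*pn, head_dim // 2), complex
q = torch.randn(1, 2, pn * pn, head_dim)                               # (batch, heads, tokens, head_dim)
q_rot = apply_rotary_emb(q, freqs_cis)
print(q_rot.shape)                                                     # torch.Size([1, 2, 16, 16])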
models/switti.py
ADDED
@@ -0,0 +1,409 @@
import math
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from diffusers.models.embeddings import GaussianFourierProjection

from models.basic_switti import AdaLNBeforeHead, AdaLNSelfCrossAttn
from models.rope import compute_axial_cis


def get_crop_condition(
    heights: list,
    widths: list,
    base_size=512,
):
    if type(heights[0]) == type(widths[0]) == str:
        heights = [int(h) for h in heights]
        widths = [int(w) for w in widths]
    h = torch.tensor(heights, dtype=torch.int).unsqueeze(1)
    w = torch.tensor(widths, dtype=torch.int).unsqueeze(1)
    hw = torch.cat([h, w], dim=1)

    ratio = base_size / hw.min(-1)[0]
    orig_size = (hw * ratio[:, None]).to(torch.int)
    crop_coords = ((orig_size - base_size) // 2).clamp(min=0)
    crop_cond = torch.cat([orig_size, crop_coords], dim=1)

    return crop_cond


class Switti(nn.Module):
    def __init__(
        self,
        Cvae=32,
        V=4096,
        rope=True,
        rope_theta=10000,
        rope_size=128,
        depth=16,
        embed_dim=1024,
        num_heads=16,
        mlp_ratio=4.0,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_eps=1e-6,
        attn_l2_norm=True,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
        fused_if_available=True,
        use_swiglu_ffn=True,
        use_ar=False,
        use_crop_cond=True,
    ):
        super().__init__()
        # 0. hyperparameters
        assert embed_dim % num_heads == 0
        self.depth, self.C, self.D, self.num_heads = (
            depth,
            embed_dim,
            embed_dim,
            num_heads,
        )
        self.Cvae, self.V = Cvae, V

        self.patch_nums: Tuple[int] = patch_nums
        self.L = sum(pn**2 for pn in self.patch_nums)
        self.first_l = self.patch_nums[0] ** 2
        self.rope = rope

        self.num_stages_minus_1 = len(self.patch_nums) - 1
        self.rng = torch.Generator(device="cuda")

        # 1. input (word) embedding
        self.word_embed = nn.Linear(self.Cvae, self.C)

        # 2. text embedding
        self.pooled_embed_size = 1280
        self.context_dim = 1280 + 768
        self.text_pooler = nn.Linear(self.pooled_embed_size, self.D)

        init_std = math.sqrt(1 / self.C / 3)
        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 3. position embedding
        if not self.rope:
            # absolute position embedding
            pos_1LC = []
            for i, pn in enumerate(self.patch_nums):
                pe = torch.empty(1, pn * pn, self.C)
                nn.init.trunc_normal_(pe, mean=0, std=init_std)
                pos_1LC.append(pe)
            pos_1LC = torch.cat(pos_1LC, dim=1)  # 1, L, C
            assert tuple(pos_1LC.shape) == (1, self.L, self.C)
            self.pos_1LC = nn.Parameter(pos_1LC)
            self.freqs_cis = None
        else:
            # RoPE position embedding
            assert (
                self.C // self.num_heads
            ) % 4 == 0, "2d rope needs head dim to be divisible by 4"
            patch_nums_m1 = tuple(pn - 1 if pn > 1 else 1 for pn in self.patch_nums)
            self.compute_cis = partial(compute_axial_cis, dim=self.C // self.num_heads)
            freqs_cis = []
            for i, pn in enumerate(self.patch_nums):
                norm_coeff = rope_size / patch_nums_m1[i]
                cur_freqs = self.compute_cis(
                    end_x=pn, end_y=pn, theta=rope_theta, norm_coeff=norm_coeff
                )
                freqs_cis.append(cur_freqs[None, ...])
            self.freqs_cis = torch.cat(freqs_cis, dim=1)  # 1, L, C // 2 -- complex

        # level embedding (similar to GPT's segment embedding,
        # used to distinguish different levels of token pyramid)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        # 4. backbone blocks
        self.drop_path_rate = drop_path_rate
        # stochastic depth decay rule (linearly increasing)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([])
        for block_idx in range(depth):
            self.blocks.append(
                AdaLNSelfCrossAttn(
                    cond_dim=self.D,
                    block_idx=block_idx,
                    embed_dim=self.C,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[block_idx],
                    last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
                    qk_norm=attn_l2_norm,
                    context_dim=self.context_dim,
                    use_swiglu_ffn=use_swiglu_ffn,
                    norm_eps=norm_eps,
                    use_crop_cond=use_crop_cond,
                )
            )

        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
        self.using_fused_add_norm_fn = any(fused_add_norm_fns)
        print(
            f"\n[constructor] ==== fused_if_available={fused_if_available} "
            f"(fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, "
            f"fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n"
            f" [Switti config ] embed_dim={embed_dim}, num_heads={num_heads}, "
            f"depth={depth}, mlp_ratio={mlp_ratio}\n"
            f" [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, "
            f"drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})",
            end="\n\n",
            flush=True,
        )

        # Prepare crop condition embedder
        self.use_crop_cond = use_crop_cond
        if use_crop_cond:
            # crop condition is represented with 4 int values; each is embedded to self.D // 4 dims
            assert self.D % 8 == 0
            self.crop_embed = GaussianFourierProjection(
                self.D // 2 // 4, set_W_to_weight=False, log=False, flip_sin_to_cos=False
            )
            self.crop_proj = nn.Linear(self.D, self.D)

        # 5. attention mask used in training (for masking out the future)
        # it won't be used in inference, since kv cache is enabled
        self.use_ar = use_ar
        d: torch.Tensor = torch.cat(
            [torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]
        ).view(1, self.L, 1)
        dT = d.transpose(1, 2)  # dT: 11L
        lvl_1L = dT[:, 0].contiguous()
        self.register_buffer("lvl_1L", lvl_1L)

        if self.use_ar:
            attn_bias_for_masking = torch.where(d >= dT, 0.0, -torch.inf)
        else:
            attn_bias_for_masking = torch.where(d == dT, 0.0, -torch.inf)

        attn_bias_for_masking = attn_bias_for_masking.reshape(1, 1, self.L, self.L)
        self.register_buffer(
            "attn_bias_for_masking", attn_bias_for_masking.contiguous()
        )

        # 6. classifier head
        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
        self.head = nn.Linear(self.C, self.V)

        # By default disable gradient checkpointing
        self.use_gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = False

    def get_logits(
        self,
        h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        cond_BD: Optional[torch.Tensor],
    ):
        if not isinstance(h_or_h_and_residual, torch.Tensor):
            h, resi = h_or_h_and_residual  # fused_add_norm must be used
            h = resi + self.blocks[-1].drop_path(h)
        else:  # fused_add_norm is not used
            h = h_or_h_and_residual
        return self.head(self.head_nm(h, cond_BD))

    def forward(
        self,
        x_BLCv_wo_first_l: torch.Tensor,
        prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        prompt_attn_bias: torch.Tensor,
        batch_height: list[int] | None = None,
        batch_width: list[int] | None = None,
    ) -> torch.Tensor:  # returns logits_BLV
        """
        :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
        :param prompt_embeds (B, context_len, self.context_dim):
            text features from pipe.text_encoder and pipe.text_encoder_2,
            concatenated along dim=-1, padded to longest along dim=1
        :param pooled_prompt_embeds (B, self.pooled_embed_size):
            pooled text features from pipe.text_encoder_2
        :param prompt_attn_bias (B, context_len):
            boolean mask to specify which tokens are not padding
        :param batch_height (B,): original height of images in a batch.
        :param batch_width (B,): original width of images in a batch.
            Only used when self.use_crop_cond = True
        :return: logits BLV, V is vocab_size
        """
        bg, ed = 0, self.L
        B = x_BLCv_wo_first_l.shape[0]
        with torch.amp.autocast('cuda', enabled=False):
            pooled_prompt_embeds = self.text_pooler(pooled_prompt_embeds)

            sos = cond_BD = pooled_prompt_embeds
            sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(
                B, self.first_l, -1
            )

            x_BLC = torch.cat(
                (sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1
            )
            x_BLC += self.lvl_embed(
                self.lvl_1L[:, :ed].expand(B, -1)
            )  # lvl: BLC; pos: 1LC
            if not self.rope:
                x_BLC += self.pos_1LC[:, :ed]
        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]

        if self.use_crop_cond:
            crop_coords = get_crop_condition(batch_height, batch_width).to(cond_BD.device)
            crop_embed = self.crop_embed(crop_coords.view(-1)).reshape(B, self.D)
            crop_cond = self.crop_proj(crop_embed)
        else:
            crop_cond = None

        # hack: get the dtype if mixed precision is used
        temp = x_BLC.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype

        x_BLC = x_BLC.to(dtype=main_type)
        cond_BD = cond_BD.to(dtype=main_type)
        attn_bias = attn_bias.to(dtype=main_type)

        for block in self.blocks:
            if self.use_gradient_checkpointing:
                x_BLC = torch.utils.checkpoint.checkpoint(
                    block,
                    x=x_BLC,
                    cond_BD=cond_BD,
                    attn_bias=attn_bias,
                    context=prompt_embeds,
                    freqs_cis=self.freqs_cis,
                    context_attn_bias=prompt_attn_bias,
                    crop_cond=crop_cond,
                    use_reentrant=False,
                )
            else:
                x_BLC = block(
                    x=x_BLC,
                    cond_BD=cond_BD,
                    attn_bias=attn_bias,
                    context=prompt_embeds,
                    freqs_cis=self.freqs_cis,
                    context_attn_bias=prompt_attn_bias,
                    crop_cond=crop_cond,
                )

        with torch.amp.autocast('cuda', enabled=not self.training):
            x_BLC = self.get_logits(x_BLC, cond_BD.float())

        return x_BLC  # logits BLV, V is vocab_size

    def init_weights(
        self,
        init_adaln=0.5,
        init_adaln_gamma=1e-5,
        init_head=0.02,
        init_std=0.02,
    ):
        if init_std < 0:
            init_std = (1 / self.C / 3) ** 0.5  # init_std < 0: automated

        print(f"[init_weights] {type(self).__name__} with {init_std=:g}")
        for m in self.modules():
            with_weight = hasattr(m, "weight") and m.weight is not None
            with_bias = hasattr(m, "bias") and m.bias is not None
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if with_bias:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()
            elif isinstance(
                m,
                (
                    nn.LayerNorm,
                    nn.BatchNorm1d,
                    nn.BatchNorm2d,
                    nn.BatchNorm3d,
                    nn.SyncBatchNorm,
                    nn.GroupNorm,
                    nn.InstanceNorm1d,
                    nn.InstanceNorm2d,
                    nn.InstanceNorm3d,
                ),
            ):
                if with_weight:
                    m.weight.data.fill_(1.0)
                if with_bias:
                    m.bias.data.zero_()

        if init_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(init_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                self.head[-1].weight.data.mul_(init_head)
                self.head[-1].bias.data.zero_()

        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
            if (
                hasattr(self.head_nm.ada_lin[-1], "bias")
                and self.head_nm.ada_lin[-1].bias is not None
            ):
                self.head_nm.ada_lin[-1].bias.data.zero_()

        depth = len(self.blocks)
        for block in self.blocks:
            block.attn.proj.weight.data.div_(math.sqrt(2 * depth))
            block.cross_attn.proj.weight.data.div_(math.sqrt(2 * depth))
            if hasattr(block.ffn, "fc2"):
                block.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))

            if hasattr(block, "ada_lin"):
                block.ada_lin[-1].weight.data[2 * self.C :].mul_(init_adaln)
                block.ada_lin[-1].weight.data[: 2 * self.C].mul_(init_adaln_gamma)
                if (
                    hasattr(block.ada_lin[-1], "bias")
                    and block.ada_lin[-1].bias is not None
                ):
                    block.ada_lin[-1].bias.data.zero_()
            elif hasattr(block, "ada_gss"):
                block.ada_gss.data[:, :, 2:].mul_(init_adaln)
                block.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        return f"drop_path_rate={self.drop_path_rate:g}"


class SwittiHF(Switti, PyTorchModelHubMixin):
    # tags=["image-generation"]):
    def __init__(
        self,
        depth=30,
        rope=True,
        rope_theta=10000,
        rope_size=128,
        use_swiglu_ffn=True,
        use_ar=False,
        use_crop_cond=True,
    ):
        heads = depth
        width = depth * 64
        super().__init__(
            depth=depth,
            embed_dim=width,
            num_heads=heads,
            patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
            rope=rope,
            rope_theta=rope_theta,
            rope_size=rope_size,
            use_swiglu_ffn=use_swiglu_ffn,
            use_ar=use_ar,
            use_crop_cond=use_crop_cond,
        )
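get_crop_condition mirrors SDXL-style size/crop conditioning: the original (height, width) is rescaled so the shorter side matches base_size, and the implied center-crop offsets are appended, giving four integers per image. A small check of what the transformer receives for a 1024x768 input, assuming the Space's requirements are installed and models.switti is importable from the repo root (the expected values in the comments are computed by the same arithmetic and may be off by one pixel due to float rounding):

import torch
from models.switti import get_crop_condition  # defined above in this commit

cond = get_crop_condition([1024], [768], base_size=512)
# short side 768 -> 512, so the resized size is roughly (682, 512)
# and the implied center crop starts near ((682 - 512) // 2, 0) = (85, 0)
print(cond)  # expected: tensor([[682, 512,  85,   0]], dtype=torch.int32)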
models/vqvae.py
ADDED
@@ -0,0 +1,184 @@
"""
References:
- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
"""

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from .basic_vae import Decoder, Encoder
from .quant import VectorQuantizer2


class VQVAE(nn.Module):
    def __init__(
        self,
        vocab_size=4096,
        z_channels=32,
        ch=160,
        dropout=0.0,
        beta=0.25,  # commitment loss weight
        using_znorm=False,  # whether to normalize when computing the nearest neighbors
        quant_conv_ks=3,  # quant conv kernel size
        quant_resi=0.5,  # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
        share_quant_resi=4,  # use 4 \phi layers for K scales: partially-shared \phi
        default_qresi_counts=0,  # if is 0: automatically set to len(v_patch_nums)
        # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
        v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
        test_mode=True,
    ):
        super().__init__()
        self.test_mode = test_mode
        self.V, self.Cvae = vocab_size, z_channels
        # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
        ddconfig = dict(
            dropout=dropout,
            ch=ch,
            z_channels=z_channels,
            in_channels=3,
            ch_mult=(1, 1, 2, 2, 4),
            num_res_blocks=2,  # from vq-f16/config.yaml above
            using_sa=True,
            using_mid_sa=True,  # from vq-f16/config.yaml above
            # resamp_with_conv=True,  # always True, removed.
        )
        ddconfig.pop("double_z", None)  # only KL-VAE should use double_z=True
        self.encoder = Encoder(double_z=False, **ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.vocab_size = vocab_size
        self.downsample = 2 ** (len(ddconfig["ch_mult"]) - 1)
        self.quantize: VectorQuantizer2 = VectorQuantizer2(
            vocab_size=vocab_size,
            Cvae=self.Cvae,
            using_znorm=using_znorm,
            beta=beta,
            default_qresi_counts=default_qresi_counts,
            v_patch_nums=v_patch_nums,
            quant_resi=quant_resi,
            share_quant_resi=share_quant_resi,
        )
        self.quant_conv = torch.nn.Conv2d(
            self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2
        )
        self.post_quant_conv = torch.nn.Conv2d(
            self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2
        )

        if self.test_mode:
            self.eval()
            [p.requires_grad_(False) for p in self.parameters()]

    # ===================== `forward` is only used in VAE training =====================
    def forward(self, inp, ret_usages=False):  # -> rec_B3HW, idx_N, loss
        VectorQuantizer2.forward
        f_hat, usages, vq_loss = self.quantize(
            self.quant_conv(self.encoder(inp)), ret_usages=ret_usages
        )
        return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss

    # ===================== `forward` is only used in VAE training =====================

    def fhat_to_img(self, f_hat: torch.Tensor):
        return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)

    def img_to_idxBl(
        self,
        inp_img_no_grad: torch.Tensor,
        v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
        noise_std: Optional[float] = None,
    ) -> List[torch.LongTensor]:  # return List[Bl]
        f = self.quant_conv(self.encoder(inp_img_no_grad))
        return self.quantize.f_to_idxBl_or_fhat(
            f, to_fhat=False, v_patch_nums=v_patch_nums, noise_std=noise_std,
        )

    def idxBl_to_img(
        self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False
    ) -> Union[List[torch.Tensor], torch.Tensor]:
        B = ms_idx_Bl[0].shape[0]
        ms_h_BChw = []
        for idx_Bl in ms_idx_Bl:
            l = idx_Bl.shape[1]
            pn = round(l**0.5)
            ms_h_BChw.append(
                self.quantize.embedding(idx_Bl)
                .transpose(1, 2)
                .view(B, self.Cvae, pn, pn)
            )
        return self.embed_to_img(
            ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one
        )

    def embed_to_img(
        self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False
    ) -> Union[List[torch.Tensor], torch.Tensor]:
        if last_one:
            return self.decoder(
                self.post_quant_conv(
                    self.quantize.embed_to_fhat(
                        ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True
                    )
                )
            ).clamp_(-1, 1)
        else:
            return [
                self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
                for f_hat in self.quantize.embed_to_fhat(
                    ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False
                )
            ]

    def img_to_reconstructed_img(
        self,
        x,
        v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
        last_one=False,
    ) -> List[torch.Tensor]:
        f = self.quant_conv(self.encoder(x))
        ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(
            f, to_fhat=True, v_patch_nums=v_patch_nums
        )
        if last_one:
            return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
        else:
            return [
                self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
                for f_hat in ls_f_hat_BChw
            ]

    def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
        if (
            "quantize.ema_vocab_hit_SV" in state_dict
            and state_dict["quantize.ema_vocab_hit_SV"].shape[0]
            != self.quantize.ema_vocab_hit_SV.shape[0]
        ):
            state_dict["quantize.ema_vocab_hit_SV"] = self.quantize.ema_vocab_hit_SV
        return super().load_state_dict(
            state_dict=state_dict, strict=strict, assign=assign
        )


class VQVAEHF(VQVAE, PyTorchModelHubMixin):
    def __init__(
        self,
        vocab_size=4096,
        z_channels=32,
        ch=160,
        test_mode=True,
        share_quant_resi=4,
        v_patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
    ):
        super().__init__(
            vocab_size=vocab_size,
            z_channels=z_channels,
            ch=ch,
            test_mode=test_mode,
            share_quant_resi=share_quant_resi,
            v_patch_nums=v_patch_nums,
        )
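VQVAEHF and SwittiHF both default to the 10-scale schedule (1, 2, 3, 4, 6, 9, 13, 18, 24, 32): the final scale is a 32x32 token map, 2240 tokens are predicted in total, and with the f=16 VAE above (ch_mult of length 5) that last grid corresponds to 512x512 pixels. A quick, self-contained sanity check of those numbers:

patch_nums = (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)   # default v_patch_nums above
downsample = 2 ** (len((1, 1, 2, 2, 4)) - 1)       # VQVAE downsample factor: 16

total_tokens = sum(pn * pn for pn in patch_nums)
print(total_tokens)                 # 2240 tokens across all scales
print(patch_nums[-1] * downsample)  # 512 -> native image resolution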
requirements.txt
CHANGED
@@ -1,6 +1,16 @@
huggingface_hub==0.26.2
transformers==4.45.2
diffusers==0.31.0
einops==0.8.0
pytz==2024.2
wandb==0.18.7
torch==2.4.1
decord==0.6.0
numpy==2.1.2
Pillow==11.0.0
pytz==2024.2
scipy==1.14.1
torchvision==0.19.1
tqdm==4.66.5
gradio==5.7.1
spaces==0.30.4