dmitriitochilkin committed on
Commit
ff49a48
1 Parent(s): a2337f4

add dependencies

Browse files
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ omegaconf==2.3.0
2
+ Pillow==10.1.0
3
+ einops==0.7.0
4
+ git+https://github.com/tatsy/torchmcubes.git
5
+ transformers==4.35.0
6
+ trimesh==4.0.5
7
+ rembg
8
+ huggingface-hub
tsr/__pycache__/system.cpython-310.pyc ADDED
Binary file (5.41 kB). View file
 
tsr/__pycache__/utils.cpython-310.pyc ADDED
Binary file (13.8 kB). View file
 
tsr/models/__pycache__/camera.cpython-310.pyc ADDED
Binary file (1.48 kB). View file
 
tsr/models/__pycache__/isosurface.cpython-310.pyc ADDED
Binary file (2.04 kB). View file
 
tsr/models/__pycache__/nerf_renderer.cpython-310.pyc ADDED
Binary file (5.29 kB). View file
 
tsr/models/__pycache__/network_utils.cpython-310.pyc ADDED
Binary file (3.42 kB). View file
 
tsr/models/isosurface.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchmcubes import marching_cubes
7
+
8
+
9
class IsosurfaceHelper(nn.Module):
    """Base class for modules that extract an isosurface from a gridded scalar field.

    Subclasses are expected to override ``grid_vertices``; the base
    implementation only raises ``NotImplementedError``.
    """

    # (min, max) coordinate interval; subclasses in this file pass it to
    # ``torch.linspace`` to build the sampling grid along every axis.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        """Return the positions of the grid vertices.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
15
+
16
+
17
class MarchingCubeHelper(IsosurfaceHelper):
    """Extract a triangle mesh from a dense cubic grid with marching cubes.

    Args:
        resolution: number of grid samples along each of the three axes.
    """

    def __init__(self, resolution: int) -> None:
        super().__init__()
        self.resolution = resolution
        self.mc_func: Callable = marching_cubes
        # Lazily built by ``grid_vertices`` and cached afterwards.
        self._grid_vertices: Optional[torch.FloatTensor] = None

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        """Return the ``(resolution ** 3, 3)`` grid vertex positions.

        The grid spans ``points_range`` along every axis, is laid out with
        ``indexing="ij"`` (first coordinate varies slowest), and is computed
        once and then cached on the instance.
        """
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            axis = torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            self._grid_vertices = torch.stack(
                [x.reshape(-1), y.reshape(-1), z.reshape(-1)], dim=-1
            )
        return self._grid_vertices

    def forward(
        self,
        level: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Run marching cubes on ``level`` at iso-value 0.

        Args:
            level: scalar field with ``resolution ** 3`` elements (any shape
                that can be reshaped to a cube of side ``resolution``).

        Returns:
            ``(v_pos, t_pos_idx)`` as returned by ``self.mc_func``, with the
            vertex coordinate columns reversed and divided by
            ``resolution - 1``.
        """
        # ``reshape`` (rather than ``view``) also accepts non-contiguous input.
        # The field is negated before extraction, which flips inside/outside.
        level = -level.reshape(self.resolution, self.resolution, self.resolution)
        v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
        # Reverse the coordinate columns.
        # NOTE(review): presumably mc_func returns vertices in reversed axis
        # order relative to the "ij" grid above -- confirm against torchmcubes.
        v_pos = v_pos[..., [2, 1, 0]]
        # Grid-index units -> [0, 1]; matches the default ``points_range`` only.
        v_pos = v_pos / (self.resolution - 1.0)
        return v_pos, t_pos_idx
tsr/models/nerf_renderer.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, reduce
7
+
8
+ from ..utils import (
9
+ BaseModule,
10
+ chunk_batch,
11
+ get_activation,
12
+ rays_intersect_bbox,
13
+ scale_tensor,
14
+ )
15
+
16
+
17
class TriplaneNeRFRenderer(BaseModule):
    """Volume renderer that samples triplane features and decodes them per point.

    Rays are intersected with a box of half-size ``cfg.radius``, sampled at
    ``cfg.num_samples_per_ray`` midpoints, decoded to density and colour, and
    alpha-composited over a white background.
    """

    @dataclass
    class Config(BaseModule.Config):
        radius: float

        feature_reduction: str = "concat"
        density_activation: str = "trunc_exp"
        density_bias: float = -1.0
        color_activation: str = "sigmoid"
        num_samples_per_ray: int = 128
        randomized: bool = False

    cfg: Config

    def configure(self) -> None:
        """Validate the configuration and disable chunking by default."""
        assert self.cfg.feature_reduction in ["concat", "mean"]
        self.chunk_size = 0

    def set_chunk_size(self, chunk_size: int):
        """Set how many points are decoded per chunk (0 disables chunking)."""
        assert (
            chunk_size >= 0
        ), "chunk_size must be a non-negative integer (0 for no chunking)."
        self.chunk_size = chunk_size

    def query_triplane(
        self,
        decoder: torch.nn.Module,
        positions: torch.Tensor,
        triplane: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Sample the triplane at ``positions`` and decode the features.

        Args:
            decoder: module mapping per-point features to a dict that holds at
                least ``"density"`` and ``"features"``.
            positions: tensor of shape ``(..., 3)`` in ``(-radius, radius)``.
            triplane: tensor of shape ``(3, Cp, Hp, Wp)``.

        Returns:
            The decoder output with ``"density_act"`` and ``"color"`` added,
            every entry reshaped to ``(*positions.shape[:-1], -1)``.
        """
        input_shape = positions.shape[:-1]
        positions = positions.reshape(-1, 3)

        # positions in (-radius, radius)
        # normalized to (-1, 1) for grid sample
        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        def _query_chunk(x):
            # One 2D lookup per plane, using coordinate pairs (0,1), (0,2), (1,2).
            indices2D: torch.Tensor = torch.stack(
                (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
                dim=-3,
            )
            out: torch.Tensor = F.grid_sample(
                # Identity rearrange: only asserts that there are 3 planes.
                rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
                rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
                align_corners=False,
                mode="bilinear",
            )
            if self.cfg.feature_reduction == "concat":
                out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
            elif self.cfg.feature_reduction == "mean":
                out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
            else:
                raise NotImplementedError

            net_out: Dict[str, torch.Tensor] = decoder(out)
            return net_out

        if self.chunk_size > 0:
            net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
        else:
            net_out = _query_chunk(positions)

        net_out["density_act"] = get_activation(self.cfg.density_activation)(
            net_out["density"] + self.cfg.density_bias
        )
        net_out["color"] = get_activation(self.cfg.color_activation)(
            net_out["features"]
        )

        net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}

        return net_out

    def _forward(
        self,
        decoder: torch.nn.Module,
        triplane: torch.Tensor,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        **kwargs,
    ):
        """Render one (unbatched) triplane along the given rays.

        Args:
            decoder: per-point decoder, see ``query_triplane``.
            triplane: tensor of shape ``(3, Cp, Hp, Wp)``.
            rays_o: ray origins, shape ``(..., 3)``.
            rays_d: ray directions, same shape as ``rays_o``.

        Returns:
            Composited RGB of shape ``(*rays_o.shape[:-1], 3)``; rays that miss
            the box are pure white.
        """
        rays_shape = rays_o.shape[:-1]
        rays_o = rays_o.view(-1, 3)
        rays_d = rays_d.view(-1, 3)
        n_rays = rays_o.shape[0]

        t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
        t_near, t_far = t_near[rays_valid], t_far[rays_valid]
        # Restrict origins/directions to the same rays as t_near/t_far.
        # Without this the broadcast below fails (or silently mis-broadcasts)
        # as soon as any ray misses the box.
        # NOTE(review): assumes rays_valid indexes the flattened ray axis, as
        # its use on comp_rgb / opacity below implies.
        rays_o, rays_d = rays_o[rays_valid], rays_d[rays_valid]

        t_vals = torch.linspace(
            0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
        )
        t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
        z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None]  # (N_rays, N_samples)

        xyz = (
            rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
        )  # (N_rays, N_sample, 3)

        mlp_out = self.query_triplane(
            decoder=decoder,
            positions=xyz,
            triplane=triplane,
        )

        eps = 1e-10
        # Step sizes are taken in normalised [0, 1] ray parameter, not in
        # world-space distance (z_vals), so they are identical for every ray.
        deltas = t_vals[1:] - t_vals[:-1]  # (N_samples,)
        alpha = 1 - torch.exp(
            -deltas * mlp_out["density_act"][..., 0]
        )  # (N_rays, N_samples)
        # Transmittance up to (but excluding) each sample.
        accum_prod = torch.cat(
            [
                torch.ones_like(alpha[:, :1]),
                torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
            ],
            dim=-1,
        )
        weights = alpha * accum_prod  # (N_rays, N_samples)
        comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2)  # (N_rays, 3)
        opacity_ = weights.sum(dim=-1)  # (N_rays)

        # Scatter the valid-ray results back into full-size buffers.
        comp_rgb = torch.zeros(
            n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
        )
        opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
        comp_rgb[rays_valid] = comp_rgb_
        opacity[rays_valid] = opacity_

        # Composite over a white background.
        comp_rgb += 1 - opacity[..., None]
        comp_rgb = comp_rgb.view(*rays_shape, 3)

        return comp_rgb

    def forward(
        self,
        decoder: torch.nn.Module,
        triplane: torch.Tensor,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
    ) -> torch.Tensor:
        """Render a single triplane (4-D) or a batch of triplanes (5-D).

        For batched input, ``rays_o`` and ``rays_d`` are indexed along their
        first axis in step with ``triplane`` and the results are stacked.
        """
        if triplane.ndim == 4:
            comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
        else:
            comp_rgb = torch.stack(
                [
                    self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
                    for i in range(triplane.shape[0])
                ],
                dim=0,
            )

        return comp_rgb

    def train(self, mode=True):
        """Switch training mode and update the ``randomized`` flag accordingly."""
        # NOTE(review): ``self.randomized`` is not read anywhere in this class.
        self.randomized = mode and self.cfg.randomized
        return super().train(mode=mode)

    def eval(self):
        """Switch to evaluation mode and clear the ``randomized`` flag."""
        self.randomized = False
        return super().eval()
tsr/models/network_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from ..utils import BaseModule
9
+
10
+
11
class TriplaneUpsampleNetwork(BaseModule):
    """Upsample each plane of a triplane by 2x with one transposed convolution."""

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int
        out_channels: int

    cfg: Config

    def configure(self) -> None:
        """Create the shared 2x transposed-convolution layer."""
        # kernel_size == stride == 2: every input pixel expands to a 2x2 patch.
        self.upsample = nn.ConvTranspose2d(
            self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
        )

    def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
        """Upsample ``triplanes`` of shape ``(B, 3, Ci, Hp, Wp)``.

        Returns:
            Tensor of shape ``(B, 3, Co, 2 * Hp, 2 * Wp)``.
        """
        # Fold the plane axis into the batch so all planes share one conv.
        flat = rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
        triplanes_up = rearrange(
            self.upsample(flat),
            "(B Np) Co Hp Wp -> B Np Co Hp Wp",
            Np=3,
        )
        return triplanes_up
33
+
34
+
35
class NeRFMLP(BaseModule):
    """MLP that maps per-point features to one density and three feature channels."""

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int
        n_neurons: int
        n_hidden_layers: int
        activation: str = "relu"
        bias: bool = True
        weight_init: Optional[str] = "kaiming_uniform"
        bias_init: Optional[str] = None

    cfg: Config

    def configure(self) -> None:
        """Build the layer stack as a single ``nn.Sequential``.

        The stack is: input linear + activation, ``n_hidden_layers - 1``
        hidden linear + activation pairs, then a final linear with 4 outputs.
        """
        linear_kwargs = dict(
            bias=self.cfg.bias,
            weight_init=self.cfg.weight_init,
            bias_init=self.cfg.bias_init,
        )
        layers = [
            self.make_linear(
                self.cfg.in_channels, self.cfg.n_neurons, **linear_kwargs
            ),
            self.make_activation(self.cfg.activation),
        ]
        for _ in range(self.cfg.n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    self.cfg.n_neurons, self.cfg.n_neurons, **linear_kwargs
                ),
                self.make_activation(self.cfg.activation),
            ]
        layers += [
            self.make_linear(
                self.cfg.n_neurons,
                4,  # density 1 + features 3
                **linear_kwargs,
            )
        ]
        self.layers = nn.Sequential(*layers)

    def make_linear(
        self,
        dim_in,
        dim_out,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        """Create an ``nn.Linear`` and optionally re-initialise it.

        Args:
            dim_in: input feature size.
            dim_out: output feature size.
            bias: whether the layer has a bias term.
            weight_init: ``None`` (keep PyTorch default) or ``"kaiming_uniform"``.
            bias_init: ``None`` (keep PyTorch default) or ``"zero"``; only
                consulted when ``bias`` is true.

        Raises:
            NotImplementedError: for any other ``weight_init`` / ``bias_init``.
        """
        layer = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init is None:
            pass
        elif weight_init == "kaiming_uniform":
            # The gain is always computed for ReLU, whatever the activation.
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
        else:
            raise NotImplementedError(f"unsupported weight_init: {weight_init!r}")

        if bias:
            if bias_init is None:
                pass
            elif bias_init == "zero":
                torch.nn.init.zeros_(layer.bias)
            else:
                raise NotImplementedError(f"unsupported bias_init: {bias_init!r}")

        return layer

    def make_activation(self, activation):
        """Return an in-place activation module for ``"relu"`` or ``"silu"``.

        Raises:
            NotImplementedError: for any other activation name.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(f"unsupported activation: {activation!r}")

    def forward(self, x):
        """Decode features ``x`` of shape ``(..., in_channels)``.

        Returns:
            Dict with ``"density"`` of shape ``(..., 1)`` and ``"features"``
            of shape ``(..., 3)``.
        """
        inp_shape = x.shape[:-1]
        x = x.reshape(-1, x.shape[-1])

        features = self.layers(x)
        features = features.reshape(*inp_shape, -1)
        # Channel 0 is the density; channels 1..3 are the feature/colour logits.
        out = {"density": features[..., 0:1], "features": features[..., 1:4]}

        return out
tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc ADDED
Binary file (18.6 kB). View file
 
tsr/models/tokenizers/__pycache__/image.cpython-310.pyc ADDED
Binary file (2.42 kB). View file
 
tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc ADDED
Binary file (1.76 kB). View file
 
tsr/models/tokenizers/image.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from huggingface_hub import hf_hub_download
8
+ from transformers.models.vit.modeling_vit import ViTModel
9
+
10
+ from ...utils import BaseModule
11
+
12
+
13
class DINOSingleImageTokenizer(BaseModule):
    """Turn images into token features with a ViT built from a hub config."""

    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
        enable_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Instantiate the ViT and register the normalisation buffers."""
        # Only ``config.json`` is downloaded: the model is built from its
        # config, so no pretrained weights are loaded here.
        # NOTE(review): presumably weights come from an enclosing checkpoint.
        self.model: ViTModel = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(
                    repo_id=self.cfg.pretrained_model_name_or_path,
                    filename="config.json",
                )
            )
        )

        if self.cfg.enable_gradient_checkpointing:
            self.model.encoder.gradient_checkpointing = True

        # Per-channel mean/std (the values match the commonly used ImageNet
        # statistics), shaped to broadcast over (B, N, C, H, W). They are
        # non-persistent, so they are not part of the state dict.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Encode ``images`` into local token features.

        Args:
            images: ``(B, C, H, W)`` or ``(B, N, C, H, W)`` tensor.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            ``(B, Ct, Nt)`` for 4-D input, ``(B, N, Ct, Nt)`` for 5-D input.
        """
        # A 4-D batch is treated as one view per sample and squeezed back later.
        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)

        batch_size = images.shape[0]
        images = (images - self.image_mean) / self.image_std
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
        )
        # (B*N, Nt, Ct) -> (B*N, Ct, Nt): put channels before tokens.
        local_features = out.last_hidden_state.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        """Not supported for this tokenizer."""
        raise NotImplementedError
tsr/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+
8
+ from ...utils import BaseModule
9
+
10
+
11
class Triplane1DTokenizer(BaseModule):
    """Learned triplane embeddings exposed as a flat 1-D token sequence."""

    @dataclass
    class Config(BaseModule.Config):
        plane_size: int
        num_channels: int

    cfg: Config

    def configure(self) -> None:
        """Create the learnable ``(3, C, P, P)`` embedding tensor."""
        # Standard-normal samples scaled by 1 / sqrt(num_channels).
        self.embeddings = nn.Parameter(
            torch.randn(
                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
                dtype=torch.float32,
            )
            / math.sqrt(self.cfg.num_channels)
        )

    def forward(self, batch_size: int) -> torch.Tensor:
        """Return the embeddings repeated ``batch_size`` times as tokens.

        Returns:
            Tensor of shape ``(batch_size, num_channels, 3 * plane_size ** 2)``.
        """
        batched = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size)
        return rearrange(batched, "B Np Ct Hp Wp -> B Ct (Np Hp Wp)")

    def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, Ct, Nt)`` tokens back to ``(B, 3, Ct, P, P)`` planes.

        Raises:
            AssertionError: if the token count or channel count does not match
                the configuration.
        """
        batch_size, Ct, Nt = tokens.shape
        assert Nt == self.cfg.plane_size**2 * 3, "token count must be 3 * plane_size ** 2"
        assert Ct == self.cfg.num_channels, "channel count must equal num_channels"
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )
tsr/models/transformer/__pycache__/attention.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc ADDED
Binary file (9.96 kB). View file
 
tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc ADDED
Binary file (7.47 kB). View file
 
tsr/models/transformer/attention.py ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+
21
+ class Attention(nn.Module):
22
+ r"""
23
+ A cross attention layer.
24
+
25
+ Parameters:
26
+ query_dim (`int`):
27
+ The number of channels in the query.
28
+ cross_attention_dim (`int`, *optional*):
29
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
30
+ heads (`int`, *optional*, defaults to 8):
31
+ The number of heads to use for multi-head attention.
32
+ dim_head (`int`, *optional*, defaults to 64):
33
+ The number of channels in each head.
34
+ dropout (`float`, *optional*, defaults to 0.0):
35
+ The dropout probability to use.
36
+ bias (`bool`, *optional*, defaults to False):
37
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
38
+ upcast_attention (`bool`, *optional*, defaults to False):
39
+ Set to `True` to upcast the attention computation to `float32`.
40
+ upcast_softmax (`bool`, *optional*, defaults to False):
41
+ Set to `True` to upcast the softmax computation to `float32`.
42
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
43
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
44
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
45
+ The number of groups to use for the group norm in the cross attention.
46
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
47
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
48
+ norm_num_groups (`int`, *optional*, defaults to `None`):
49
+ The number of groups to use for the group norm in the attention.
50
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
51
+ The number of channels to use for the spatial normalization.
52
+ out_bias (`bool`, *optional*, defaults to `True`):
53
+ Set to `True` to use a bias in the output linear layer.
54
+ scale_qk (`bool`, *optional*, defaults to `True`):
55
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
56
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
57
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
58
+ `added_kv_proj_dim` is not `None`.
59
+ eps (`float`, *optional*, defaults to 1e-5):
60
+ An additional value added to the denominator in group normalization that is used for numerical stability.
61
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
62
+ A factor to rescale the output by dividing it with this value.
63
+ residual_connection (`bool`, *optional*, defaults to `False`):
64
+ Set to `True` to add the residual connection to the output.
65
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
66
+ Set to `True` if the attention block is loaded from a deprecated state dict.
67
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
68
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
69
+ `AttnProcessor` otherwise.
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ query_dim: int,
75
+ cross_attention_dim: Optional[int] = None,
76
+ heads: int = 8,
77
+ dim_head: int = 64,
78
+ dropout: float = 0.0,
79
+ bias: bool = False,
80
+ upcast_attention: bool = False,
81
+ upcast_softmax: bool = False,
82
+ cross_attention_norm: Optional[str] = None,
83
+ cross_attention_norm_num_groups: int = 32,
84
+ added_kv_proj_dim: Optional[int] = None,
85
+ norm_num_groups: Optional[int] = None,
86
+ out_bias: bool = True,
87
+ scale_qk: bool = True,
88
+ only_cross_attention: bool = False,
89
+ eps: float = 1e-5,
90
+ rescale_output_factor: float = 1.0,
91
+ residual_connection: bool = False,
92
+ _from_deprecated_attn_block: bool = False,
93
+ processor: Optional["AttnProcessor"] = None,
94
+ out_dim: int = None,
95
+ ):
96
+ super().__init__()
97
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
98
+ self.query_dim = query_dim
99
+ self.cross_attention_dim = (
100
+ cross_attention_dim if cross_attention_dim is not None else query_dim
101
+ )
102
+ self.upcast_attention = upcast_attention
103
+ self.upcast_softmax = upcast_softmax
104
+ self.rescale_output_factor = rescale_output_factor
105
+ self.residual_connection = residual_connection
106
+ self.dropout = dropout
107
+ self.fused_projections = False
108
+ self.out_dim = out_dim if out_dim is not None else query_dim
109
+
110
+ # we make use of this private variable to know whether this class is loaded
111
+ # with an deprecated state dict so that we can convert it on the fly
112
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
113
+
114
+ self.scale_qk = scale_qk
115
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
116
+
117
+ self.heads = out_dim // dim_head if out_dim is not None else heads
118
+ # for slice_size > 0 the attention score computation
119
+ # is split across the batch axis to save memory
120
+ # You can set slice_size with `set_attention_slice`
121
+ self.sliceable_head_dim = heads
122
+
123
+ self.added_kv_proj_dim = added_kv_proj_dim
124
+ self.only_cross_attention = only_cross_attention
125
+
126
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
127
+ raise ValueError(
128
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
129
+ )
130
+
131
+ if norm_num_groups is not None:
132
+ self.group_norm = nn.GroupNorm(
133
+ num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
134
+ )
135
+ else:
136
+ self.group_norm = None
137
+
138
+ self.spatial_norm = None
139
+
140
+ if cross_attention_norm is None:
141
+ self.norm_cross = None
142
+ elif cross_attention_norm == "layer_norm":
143
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
144
+ elif cross_attention_norm == "group_norm":
145
+ if self.added_kv_proj_dim is not None:
146
+ # The given `encoder_hidden_states` are initially of shape
147
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
148
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
149
+ # before the projection, so we need to use `added_kv_proj_dim` as
150
+ # the number of channels for the group norm.
151
+ norm_cross_num_channels = added_kv_proj_dim
152
+ else:
153
+ norm_cross_num_channels = self.cross_attention_dim
154
+
155
+ self.norm_cross = nn.GroupNorm(
156
+ num_channels=norm_cross_num_channels,
157
+ num_groups=cross_attention_norm_num_groups,
158
+ eps=1e-5,
159
+ affine=True,
160
+ )
161
+ else:
162
+ raise ValueError(
163
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
164
+ )
165
+
166
+ linear_cls = nn.Linear
167
+
168
+ self.linear_cls = linear_cls
169
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
170
+
171
+ if not self.only_cross_attention:
172
+ # only relevant for the `AddedKVProcessor` classes
173
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
174
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
175
+ else:
176
+ self.to_k = None
177
+ self.to_v = None
178
+
179
+ if self.added_kv_proj_dim is not None:
180
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
181
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
182
+
183
+ self.to_out = nn.ModuleList([])
184
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
185
+ self.to_out.append(nn.Dropout(dropout))
186
+
187
+ # set attention processor
188
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
189
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
190
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
191
+ if processor is None:
192
+ processor = (
193
+ AttnProcessor2_0()
194
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
195
+ else AttnProcessor()
196
+ )
197
+ self.set_processor(processor)
198
+
199
+ def set_processor(self, processor: "AttnProcessor") -> None:
200
+ self.processor = processor
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states: torch.FloatTensor,
205
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
206
+ attention_mask: Optional[torch.FloatTensor] = None,
207
+ **cross_attention_kwargs,
208
+ ) -> torch.Tensor:
209
+ r"""
210
+ The forward method of the `Attention` class.
211
+
212
+ Args:
213
+ hidden_states (`torch.Tensor`):
214
+ The hidden states of the query.
215
+ encoder_hidden_states (`torch.Tensor`, *optional*):
216
+ The hidden states of the encoder.
217
+ attention_mask (`torch.Tensor`, *optional*):
218
+ The attention mask to use. If `None`, no mask is applied.
219
+ **cross_attention_kwargs:
220
+ Additional keyword arguments to pass along to the cross attention.
221
+
222
+ Returns:
223
+ `torch.Tensor`: The output of the attention layer.
224
+ """
225
+ # The `Attention` class can call different attention processors / attention functions
226
+ # here we simply pass along all tensors to the selected processor class
227
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
228
+ return self.processor(
229
+ self,
230
+ hidden_states,
231
+ encoder_hidden_states=encoder_hidden_states,
232
+ attention_mask=attention_mask,
233
+ **cross_attention_kwargs,
234
+ )
235
+
236
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
237
+ r"""
238
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
239
+ is the number of heads initialized while constructing the `Attention` class.
240
+
241
+ Args:
242
+ tensor (`torch.Tensor`): The tensor to reshape.
243
+
244
+ Returns:
245
+ `torch.Tensor`: The reshaped tensor.
246
+ """
247
+ head_size = self.heads
248
+ batch_size, seq_len, dim = tensor.shape
249
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
250
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
251
+ batch_size // head_size, seq_len, dim * head_size
252
+ )
253
+ return tensor
254
+
255
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
256
+ r"""
257
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
258
+ the number of heads initialized while constructing the `Attention` class.
259
+
260
+ Args:
261
+ tensor (`torch.Tensor`): The tensor to reshape.
262
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
263
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
264
+
265
+ Returns:
266
+ `torch.Tensor`: The reshaped tensor.
267
+ """
268
+ head_size = self.heads
269
+ batch_size, seq_len, dim = tensor.shape
270
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
271
+ tensor = tensor.permute(0, 2, 1, 3)
272
+
273
+ if out_dim == 3:
274
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
275
+
276
+ return tensor
277
+
278
+ def get_attention_scores(
279
+ self,
280
+ query: torch.Tensor,
281
+ key: torch.Tensor,
282
+ attention_mask: torch.Tensor = None,
283
+ ) -> torch.Tensor:
284
+ r"""
285
+ Compute the attention scores.
286
+
287
+ Args:
288
+ query (`torch.Tensor`): The query tensor.
289
+ key (`torch.Tensor`): The key tensor.
290
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
291
+
292
+ Returns:
293
+ `torch.Tensor`: The attention probabilities/scores.
294
+ """
295
+ dtype = query.dtype
296
+ if self.upcast_attention:
297
+ query = query.float()
298
+ key = key.float()
299
+
300
+ if attention_mask is None:
301
+ baddbmm_input = torch.empty(
302
+ query.shape[0],
303
+ query.shape[1],
304
+ key.shape[1],
305
+ dtype=query.dtype,
306
+ device=query.device,
307
+ )
308
+ beta = 0
309
+ else:
310
+ baddbmm_input = attention_mask
311
+ beta = 1
312
+
313
+ attention_scores = torch.baddbmm(
314
+ baddbmm_input,
315
+ query,
316
+ key.transpose(-1, -2),
317
+ beta=beta,
318
+ alpha=self.scale,
319
+ )
320
+ del baddbmm_input
321
+
322
+ if self.upcast_softmax:
323
+ attention_scores = attention_scores.float()
324
+
325
+ attention_probs = attention_scores.softmax(dim=-1)
326
+ del attention_scores
327
+
328
+ attention_probs = attention_probs.to(dtype)
329
+
330
+ return attention_probs
331
+
332
def prepare_attention_mask(
    self,
    attention_mask: torch.Tensor,
    target_length: int,
    batch_size: int,
    out_dim: int = 3,
) -> torch.Tensor:
    r"""
    Pad an attention mask along its last dimension and replicate it across attention heads.

    Args:
        attention_mask (`torch.Tensor`):
            The mask to prepare. `None` is returned unchanged.
        target_length (`int`):
            Compared against the mask's last dimension. When they differ, `target_length` zeros are appended
            to the last dimension (see the TODO in the body about padding by the remaining length instead).
        batch_size (`int`):
            Batch size. With `out_dim == 3` the mask is repeated per head only while its first dimension is
            smaller than `batch_size * self.heads`.
        out_dim (`int`, *optional*, defaults to `3`):
            `3` repeats the mask `self.heads` times along dim 0. `4` inserts a new dim 1 and repeats along it.
            Any other value leaves the (possibly padded) mask as is.

    Returns:
        `torch.Tensor`: The prepared attention mask, or `None` when no mask was supplied.
    """
    num_heads = self.heads
    if attention_mask is None:
        return attention_mask

    if attention_mask.shape[-1] != target_length:
        if attention_mask.device.type != "mps":
            # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
            #       we want to instead pad by (0, remaining_length), where remaining_length is:
            #       remaining_length: int = target_length - current_length
            # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
        else:
            # HACK: MPS does not support padding by more than the dimension of the input tensor,
            # so the zero block is built by hand and concatenated instead.
            zeros = torch.zeros(
                (attention_mask.shape[0], attention_mask.shape[1], target_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([attention_mask, zeros], dim=2)

    if out_dim == 4:
        # (batch, 1, ..., key) -> (batch, heads, ..., key)
        return attention_mask.unsqueeze(1).repeat_interleave(num_heads, dim=1)

    if out_dim == 3 and attention_mask.shape[0] < batch_size * num_heads:
        # Only expand when the mask has not already been repeated per head.
        return attention_mask.repeat_interleave(num_heads, dim=0)

    return attention_mask
390
+
391
def norm_encoder_hidden_states(
    self, encoder_hidden_states: torch.Tensor
) -> torch.Tensor:
    r"""
    Normalize the encoder hidden states with `self.norm_cross`.

    Args:
        encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

    Returns:
        `torch.Tensor`: The normalized encoder hidden states.

    Raises:
        AssertionError: If `self.norm_cross` is `None`, or is neither an `nn.LayerNorm` nor an
            `nn.GroupNorm`.
    """
    norm_cross = self.norm_cross

    # Explicit raises (rather than `assert`) so the checks survive `python -O`; the exception type is
    # kept as AssertionError so existing callers observe the same failure class.
    if norm_cross is None:
        raise AssertionError(
            "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
        )

    if isinstance(norm_cross, nn.LayerNorm):
        return norm_cross(encoder_hidden_states)

    if isinstance(norm_cross, nn.GroupNorm):
        # Group norm norms along the channels dimension and expects
        # input to be in the shape of (N, C, *). In this case, we want
        # to norm along the hidden dimension, so we need to move
        # (batch_size, sequence_length, hidden_size) ->
        # (batch_size, hidden_size, sequence_length)
        channels_first = encoder_hidden_states.transpose(1, 2)
        return norm_cross(channels_first).transpose(1, 2)

    raise AssertionError(
        f"Unsupported norm_cross type: {type(norm_cross).__name__}"
    )
423
+
424
@torch.no_grad()
def fuse_projections(self, fuse=True):
    """
    Build a single fused projection layer from the separate projection weights.

    When `self.cross_attention_dim == self.query_dim` the weights of `to_q`, `to_k` and `to_v` are
    concatenated along dim 0 and stored in a new `self.to_qkv`; otherwise only `to_k` and `to_v` are
    concatenated into `self.to_kv`. The fused layer is created with `bias=False`, on the device and dtype
    of `to_q.weight`.

    Args:
        fuse: Value stored in `self.fused_projections`. Note that the fused layer is built regardless of
            this flag.
    """
    is_cross_attention = self.cross_attention_dim != self.query_dim
    target_device = self.to_q.weight.data.device
    target_dtype = self.to_q.weight.data.dtype

    if is_cross_attention:
        fused_name = "to_kv"
        sources = (self.to_k, self.to_v)
    else:
        fused_name = "to_qkv"
        sources = (self.to_q, self.to_k, self.to_v)

    # Stack the (out_features, in_features) weight matrices row-wise.
    stacked = torch.cat([layer.weight.data for layer in sources])
    out_features, in_features = stacked.shape[0], stacked.shape[1]

    # Create the single projection layer and copy the stacked weights into it.
    fused_layer = self.linear_cls(
        in_features,
        out_features,
        bias=False,
        device=target_device,
        dtype=target_dtype,
    )
    fused_layer.weight.copy_(stacked)
    setattr(self, fused_name, fused_layer)

    self.fused_projections = fuse
457
+
458
+
459
class AttnProcessor:
    r"""
    Default attention processor.

    Computes attention from the probabilities returned by `attn.get_attention_scores` and a batched
    matrix product with the value projection.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        r"""
        Run attention for `attn` on `hidden_states`.

        Args:
            attn: The attention module supplying the projections, norms and helper methods.
            hidden_states: Query-side input, either 3-D `(batch, sequence, channels)` or 4-D
                `(batch, channels, height, width)`; 4-D input is flattened and restored on output.
            encoder_hidden_states: Key/value-side input. When `None`, `hidden_states` is used
                (self-attention).
            attention_mask: Optional mask, forwarded through `attn.prepare_attention_mask`.

        Returns:
            The attended hidden states, in the same layout as the input.
        """
        skip = hidden_states
        is_spatial = hidden_states.ndim == 4

        if is_spatial:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            flattened = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flattened.transpose(1, 2)

        # The key/value sequence decides the mask length.
        if encoder_hidden_states is None:
            kv_shape = hidden_states.shape
        else:
            kv_shape = encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_shape

        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            # group_norm expects channels in dim 1.
            channels_first = hidden_states.transpose(1, 2)
            hidden_states = attn.group_norm(channels_first).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](attended)
        output = attn.to_out[1](projected)

        if is_spatial:
            output = output.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            output = output + skip

        return output / attn.rescale_output_factor
531
+
532
+
533
class AttnProcessor2_0:
    r"""
    Attention processor built on `torch.nn.functional.scaled_dot_product_attention`.

    Construction fails with `ImportError` when the installed PyTorch does not provide that function.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""
        Run scaled dot-product attention for `attn` on `hidden_states`.

        Args:
            attn: The attention module supplying the projections, norms and helper methods.
            hidden_states: Query-side input, either 3-D `(batch, sequence, channels)` or 4-D
                `(batch, channels, height, width)`; 4-D input is flattened and restored on output.
            encoder_hidden_states: Key/value-side input. When `None`, `hidden_states` is used
                (self-attention).
            attention_mask: Optional mask. When given it is prepared via `attn.prepare_attention_mask`
                and reshaped to `(batch, heads, -1, last_dim)`.

        Returns:
            The attended hidden states, in the same layout as the input.
        """
        skip = hidden_states
        is_spatial = hidden_states.ndim == 4

        if is_spatial:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            flattened = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flattened.transpose(1, 2)

        # The key/value sequence decides the mask length.
        if encoder_hidden_states is None:
            kv_shape = hidden_states.shape
        else:
            kv_shape = encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = prepared.view(
                batch_size, attn.heads, -1, prepared.shape[-1]
            )

        if attn.group_norm is not None:
            # group_norm expects channels in dim 1.
            channels_first = hidden_states.transpose(1, 2)
            hidden_states = attn.group_norm(channels_first).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(tensor):
            # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
        merged = attended.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        merged = merged.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](merged)
        output = attn.to_out[1](projected)

        if is_spatial:
            output = output.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            output = output + skip

        return output / attn.rescale_output_factor
tsr/models/transformer/basic_transformer_block.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from .attention import Attention
22
+
23
+
24
class BasicTransformerBlock(nn.Module):
    r"""
    A pre-norm Transformer block: attention, optional second attention, then feed-forward, each wrapped
    in a residual connection.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            When set, the first attention layer is also built with `cross_attention_dim` and receives
            `encoder_hidden_states`.
        double_self_attention (`bool`, *optional*, defaults to `False`):
            When set, a second attention layer is always created and is built without `cross_attention_dim`.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Forwarded to the attention layers.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Only `"layer_norm"` is accepted.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        assert norm_type == "layer_norm"

        def make_norm() -> nn.LayerNorm:
            # Every sub-layer gets its own LayerNorm with identical settings.
            return nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        # Keyword arguments common to both attention layers.
        attention_kwargs = dict(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )

        # 1. First attention (self-attention unless `only_cross_attention`)
        self.norm1 = make_norm()
        self.attn1 = Attention(
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            **attention_kwargs,
        )

        # 2. Second attention (cross-attention, or self-attention if encoder_hidden_states is None)
        if cross_attention_dim is not None or double_self_attention:
            self.norm2 = make_norm()
            self.attn2 = Attention(
                cross_attention_dim=None
                if double_self_attention
                else cross_attention_dim,
                **attention_kwargs,
            )
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = make_norm()
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # Feed-forward chunking is disabled until `set_chunk_feed_forward` is called.
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Configure feed-forward chunking: slices of `chunk_size` along dimension `dim` (`None` disables)."""
        self._chunk_dim = dim
        self._chunk_size = chunk_size

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Apply the block to `hidden_states`.

        Args:
            hidden_states: Input to the block.
            attention_mask: Mask passed to the first attention layer.
            encoder_hidden_states: Conditioning passed to the second attention layer, and to the first one
                when `only_cross_attention` is set.
            encoder_attention_mask: Mask passed to the second attention layer.

        Returns:
            The transformed hidden states.

        Raises:
            ValueError: If chunking is enabled and the chunked dimension is not divisible by the chunk size.
        """
        # Normalization is always applied before the real computation in each sub-layer.
        # 1. First attention
        first_context = encoder_hidden_states if self.only_cross_attention else None
        attn_output = self.attn1(
            self.norm1(hidden_states),
            encoder_hidden_states=first_context,
            attention_mask=attention_mask,
        )
        hidden_states = attn_output + hidden_states

        # 2. Second attention (only when it was built)
        if self.attn2 is not None:
            attn_output = self.attn2(
                self.norm2(hidden_states),
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is None:
            ff_output = self.ff(norm_hidden_states)
        else:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            pieces = [
                self.ff(piece)
                for piece in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
            ]
            ff_output = torch.cat(pieces, dim=self._chunk_dim)

        return ff_output + hidden_states
187
+
188
+
189
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward. One of `"gelu"`, `"gelu-approximate"`,
            `"geglu"` or `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Single elif chain: exactly one branch is taken, and an unknown name fails loudly here
        # instead of leaving `act_fn` unbound further down.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(
                f"Unsupported activation_fn: {activation_fn!r}. Expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate'."
            )

        self.net = nn.ModuleList([])
        # project in (the activation module owns the input projection)
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the modules in `self.net` to `hidden_states` in order and return the result."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
240
+
241
+
242
class GELU(nn.Module):
    r"""
    Linear projection followed by GELU, with optional tanh approximation via `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU to `gate`, computing in float32 on MPS devices and casting back."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            upcast = gate.to(dtype=torch.float32)
            return F.gelu(upcast, approximate=self.approximate).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        """Project `hidden_states` to `dim_out` channels and apply the activation."""
        return self.gelu(self.proj(hidden_states))
269
+
270
+
271
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    The input is projected to `2 * dim_out` channels; one half is the value and the other half, passed
    through GELU, is the gate.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # A single projection produces both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU to `gate`, computing in float32 on MPS devices and casting back."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states, scale: float = 1.0):
        """Return `value * gelu(gate)`; `scale` is accepted but not used by this implementation."""
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        return value * self.gelu(gate)
296
+
297
+
298
class ApproximateGELU(nn.Module):
    r"""
    Linear projection followed by the sigmoid approximation of the Gaussian Error Linear Unit (GELU),
    `x * sigmoid(1.702 * x)`. For more details, see section 2: https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` to `dim_out` channels and apply the approximate GELU."""
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate
tsr/models/transformer/transformer_1d.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from ...utils import BaseModule
9
+ from .basic_transformer_block import BasicTransformerBlock
10
+
11
+
12
class Transformer1D(BaseModule):
    """
    A 1D Transformer model for sequence data.

    The input is group-normalized, projected to `num_attention_heads * attention_head_dim` channels,
    passed through `num_layers` `BasicTransformerBlock`s, projected back to `in_channels` and added to
    the original input.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input; the output projection also maps back to this size.
        out_channels (`int`, *optional*):
            Stored as `self.out_channels` (falling back to `in_channels`); not used by the output projection.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): Number of groups for the input `GroupNorm`.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        only_cross_attention, double_self_attention, upcast_attention, norm_type, norm_elementwise_affine:
            Forwarded unchanged to every `BasicTransformerBlock`.
        gradient_checkpointing (`bool`, *optional*, defaults to `False`):
            Checkpoint each block during training to trade compute for memory.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Attention geometry; the inner width is num_attention_heads * attention_head_dim.
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        # Channel counts of the input and (nominal) output.
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 1
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False
        activation_fn: str = "geglu"
        only_cross_attention: bool = False
        double_self_attention: bool = False
        upcast_attention: bool = False
        norm_type: str = "layer_norm"
        norm_elementwise_affine: bool = True
        gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the input norm/projection, the transformer blocks and the output projection from `self.cfg`."""
        cfg = self.cfg

        self.num_attention_heads = cfg.num_attention_heads
        self.attention_head_dim = cfg.attention_head_dim
        inner_dim = self.num_attention_heads * self.attention_head_dim

        # 1. Input layers
        self.in_channels = cfg.in_channels
        self.norm = torch.nn.GroupNorm(
            num_groups=cfg.norm_num_groups,
            num_channels=cfg.in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = nn.Linear(cfg.in_channels, inner_dim)

        # 2. Transformer blocks
        blocks = [
            BasicTransformerBlock(
                inner_dim,
                self.num_attention_heads,
                self.attention_head_dim,
                dropout=cfg.dropout,
                cross_attention_dim=cfg.cross_attention_dim,
                activation_fn=cfg.activation_fn,
                attention_bias=cfg.attention_bias,
                only_cross_attention=cfg.only_cross_attention,
                double_self_attention=cfg.double_self_attention,
                upcast_attention=cfg.upcast_attention,
                norm_type=cfg.norm_type,
                norm_elementwise_affine=cfg.norm_elementwise_affine,
            )
            for _ in range(cfg.num_layers)
        ]
        self.transformer_blocks = nn.ModuleList(blocks)

        # 3. Output layers
        if cfg.out_channels is None:
            self.out_channels = cfg.in_channels
        else:
            self.out_channels = cfg.out_channels

        # The output is added to the input in `forward`, so it is projected back to in_channels.
        self.proj_out = nn.Linear(inner_dim, cfg.in_channels)

        self.gradient_checkpointing = cfg.gradient_checkpointing

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Run the transformer over a channels-first sequence.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, channels, sequence_length)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Conditional embeddings forwarded to every transformer block.
            attention_mask (`torch.Tensor`, *optional*):
                Mask forwarded to the blocks as `attention_mask`. A 2-D mask is treated as
                `1 = keep, 0 = discard`, converted to an additive bias (`keep = +0, discard = -10000.0`)
                and given a singleton dimension at position 1. Masks with any other number of
                dimensions are passed through unchanged.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Mask forwarded to the blocks as `encoder_attention_mask`; converted the same way as
                `attention_mask` when 2-D.

        Returns:
            `torch.Tensor` with the same shape as the input: the block output projected back to
            `in_channels` plus the input (residual connection).
        """
        # A 2-D mask ([batch, key_tokens]) is a keep/discard mask rather than a bias. Convert it to a
        # bias and add a singleton query_tokens dimension ([batch, 1, key_tokens]) so it broadcasts over
        # attention scores shaped either
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn) or
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn).
        if attention_mask is not None and attention_mask.ndim == 2:
            keep = attention_mask.to(hidden_states.dtype)
            attention_mask = ((1 - keep) * -10000.0).unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            keep = encoder_attention_mask.to(hidden_states.dtype)
            encoder_attention_mask = ((1 - keep) * -10000.0).unsqueeze(1)

        # 1. Input
        batch, _, seq_len = hidden_states.shape
        residual = hidden_states

        normed = self.norm(hidden_states)
        channels = normed.shape[1]
        # (batch, channels, seq_len) -> (batch, seq_len, channels)
        tokens = normed.permute(0, 2, 1).reshape(batch, seq_len, channels)
        tokens = self.proj_in(tokens)

        # 2. Blocks
        use_checkpoint = self.training and self.gradient_checkpointing
        for block in self.transformer_blocks:
            if use_checkpoint:
                tokens = torch.utils.checkpoint.checkpoint(
                    block,
                    tokens,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    use_reentrant=False,
                )
            else:
                tokens = block(
                    tokens,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                )

        # 3. Output
        tokens = self.proj_out(tokens)
        # (batch, seq_len, channels) -> (batch, channels, seq_len)
        restored = tokens.reshape(batch, seq_len, channels).permute(0, 2, 1).contiguous()

        return restored + residual
+ return output
tsr/system.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import List, Union
5
+
6
+ import numpy as np
7
+ import PIL.Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import trimesh
11
+ from einops import rearrange
12
+ from huggingface_hub import hf_hub_download
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+
16
+ from .models.isosurface import MarchingCubeHelper
17
+ from .utils import (
18
+ BaseModule,
19
+ ImagePreprocessor,
20
+ find_class,
21
+ get_spherical_cameras,
22
+ scale_tensor,
23
+ )
24
+
25
+
26
class TSR(BaseModule):
    """Single-image-to-3D system built from configurable sub-modules.

    The pipeline wires together an image tokenizer, a learned token source
    (``tokenizer``), a backbone, a post-processor, a decoder and a renderer.
    Each of them is instantiated in :meth:`configure` from a dotted class path
    (``*_cls``) and its own config dict.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Side length passed to the image preprocessor for the conditioning image.
        cond_image_size: int

        # For every component below: ``<name>_cls`` is a dotted path resolved by
        # ``find_class`` and ``<name>`` is the config handed to its constructor.
        image_tokenizer_cls: str
        image_tokenizer: dict

        tokenizer_cls: str
        tokenizer: dict

        backbone_cls: str
        backbone: dict

        post_processor_cls: str
        post_processor: dict

        decoder_cls: str
        decoder: dict

        renderer_cls: str
        renderer: dict

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        """Build a model from a local directory or a Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: existing directory, otherwise it is
                used as the ``repo_id`` for ``hf_hub_download``.
            config_name: file name of the OmegaConf config.
            weight_name: file name of the checkpoint holding the state dict.

        Returns:
            The instantiated model with the checkpoint loaded (on CPU).
        """
        source = pretrained_model_name_or_path
        wanted = (config_name, weight_name)
        if os.path.isdir(source):
            config_path, weight_path = (os.path.join(source, fn) for fn in wanted)
        else:
            config_path, weight_path = (
                hf_hub_download(repo_id=source, filename=fn) for fn in wanted
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from sources you trust.
        state_dict = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model

    def configure(self):
        """Instantiate every sub-module named in the config."""
        # Same order as the config fields; each entry becomes ``self.<name>``.
        for component in (
            "image_tokenizer",
            "tokenizer",
            "backbone",
            "post_processor",
            "decoder",
            "renderer",
        ):
            component_cls = find_class(getattr(self.cfg, f"{component}_cls"))
            setattr(self, component, component_cls(getattr(self.cfg, component)))
        self.image_processor = ImagePreprocessor()
        # Created lazily by set_marching_cubes_resolution().
        self.isosurface_helper = None

    def forward(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        device: str,
    ) -> torch.FloatTensor:
        """Turn conditioning image(s) into scene codes.

        Args:
            image: a single image, a batch, or a list of images accepted by
                ``ImagePreprocessor``.
            device: device the preprocessed images are moved to.

        Returns:
            The output of ``post_processor`` applied to the detokenized
            backbone tokens.
        """
        processed = self.image_processor(image, self.cfg.cond_image_size)
        # Insert a singleton view axis after the batch axis.
        rgb_cond = processed[:, None].to(device)
        batch_size = rgb_cond.shape[0]

        channels_first = rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1)
        input_image_tokens: torch.Tensor = self.image_tokenizer(
            channels_first,
        )
        # Flatten (view, token) into one sequence axis for cross-attention.
        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
        )

        tokens: torch.Tensor = self.tokenizer(batch_size)
        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
        )

        return self.post_processor(self.tokenizer.detokenize(tokens))

    def render(
        self,
        scene_codes,
        n_views: int,
        elevation_deg: float = 0.0,
        camera_distance: float = 1.9,
        fovy_deg: float = 40.0,
        height: int = 256,
        width: int = 256,
        return_type: str = "pil",
    ):
        """Render every scene code from ``n_views`` cameras on a circle.

        Args:
            scene_codes: iterable of scene codes (also provides ``.device``).
            n_views: number of camera views per scene.
            elevation_deg, camera_distance, fovy_deg, height, width: forwarded
                to ``get_spherical_cameras``.
            return_type: ``"pt"`` (tensor), ``"np"`` (numpy array) or ``"pil"``
                (8-bit PIL image).

        Returns:
            A list (one entry per scene code) of lists (one entry per view).

        Raises:
            NotImplementedError: for any other ``return_type``; raised when the
                first rendered image is converted.
        """

        def process_output(image: torch.FloatTensor):
            if return_type == "pt":
                return image
            if return_type == "np":
                return image.detach().cpu().numpy()
            if return_type == "pil":
                as_uint8 = (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
                return Image.fromarray(as_uint8)
            raise NotImplementedError

        rays_o, rays_d = get_spherical_cameras(
            n_views, elevation_deg, camera_distance, fovy_deg, height, width
        )
        target_device = scene_codes.device
        rays_o = rays_o.to(target_device)
        rays_d = rays_d.to(target_device)

        all_scenes = []
        for scene_code in scene_codes:
            per_view = []
            for view_idx in range(n_views):
                with torch.no_grad():
                    frame = self.renderer(
                        self.decoder, scene_code, rays_o[view_idx], rays_d[view_idx]
                    )
                per_view.append(process_output(frame))
            all_scenes.append(per_view)

        return all_scenes

    def set_marching_cubes_resolution(self, resolution: int):
        """Make sure the isosurface helper exists and uses ``resolution``."""
        helper = self.isosurface_helper
        if helper is None or helper.resolution != resolution:
            self.isosurface_helper = MarchingCubeHelper(resolution)

    def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 20.0):
        """Extract one vertex-colored mesh per scene code.

        Args:
            scene_codes: iterable of scene codes (also provides ``.device``).
            resolution: grid resolution of the marching-cubes helper.
            threshold: value subtracted from the queried density before the
                (negated) field is passed to the isosurface helper.

        Returns:
            A list of ``trimesh.Trimesh`` objects.
        """
        self.set_marching_cubes_resolution(resolution)
        meshes = []
        for scene_code in scene_codes:
            helper = self.isosurface_helper
            radius = self.renderer.cfg.radius
            world_range = (-radius, radius)

            with torch.no_grad():
                # Map the helper's grid from its own range into the renderer's
                # [-radius, radius] box before querying the density.
                grid_points = scale_tensor(
                    helper.grid_vertices.to(scene_codes.device),
                    helper.points_range,
                    world_range,
                )
                density = self.renderer.query_triplane(
                    self.decoder,
                    grid_points,
                    scene_code,
                )["density_act"]

            v_pos, t_pos_idx = helper(-(density - threshold))
            v_pos = scale_tensor(v_pos, helper.points_range, world_range)

            with torch.no_grad():
                color = self.renderer.query_triplane(
                    self.decoder,
                    v_pos,
                    scene_code,
                )["color"]

            meshes.append(
                trimesh.Trimesh(
                    vertices=v_pos.cpu().numpy(),
                    faces=t_pos_idx.cpu().numpy(),
                    vertex_colors=color.cpu().numpy(),
                )
            )
        return meshes
tsr/utils.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import math
3
+ from collections import defaultdict
4
+ from dataclasses import dataclass
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import imageio
8
+ import numpy as np
9
+ import PIL.Image
10
+ import rembg
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from omegaconf import DictConfig, OmegaConf
15
+ from PIL import Image
16
+
17
+
18
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Build a structured OmegaConf config from ``fields`` and merge ``cfg`` into it.

    Args:
        fields: object passed to ``OmegaConf.structured`` (the callers in this
            file pass a dataclass type).
        cfg: values merged on top of the structured config.

    Returns:
        The merged config returned by ``OmegaConf.merge``.
    """
    schema = OmegaConf.structured(fields)
    return OmegaConf.merge(schema, cfg)
21
+
22
+
23
def find_class(cls_string):
    """Resolve a dotted path such as ``"package.module.Name"`` to the named object.

    Everything before the last dot is imported as a module; the part after the
    last dot is looked up as an attribute of that module.
    """
    module_path, _, attr_name = cls_string.rpartition(".")
    owner = importlib.import_module(module_path, package=None)
    return getattr(owner, attr_name)
29
+
30
+
31
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    """Build a 3x3 pinhole intrinsic matrix from a field of view.

    Args:
        fov: field of view in radians (the focal length is derived from ``H``).
        H: image height in pixels.
        W: image width in pixels.
        bs: when positive, the matrix is repeated ``bs`` times along a new
            leading axis.

    Returns:
        A float32 tensor of shape ``(3, 3)``, or ``(bs, 3, 3)`` when ``bs > 0``.
    """
    focal = 0.5 * H / np.tan(0.5 * fov)

    K = np.eye(3, dtype=np.float32)
    # Same focal length on both axes, principal point at the image centre.
    K[0, 0] = K[1, 1] = focal
    K[0, 2], K[1, 2] = W / 2.0, H / 2.0

    if bs > 0:
        K = np.repeat(K[None], bs, axis=0)

    return torch.from_numpy(K)
43
+
44
+
45
class BaseModule(nn.Module):
    """Base class for modules that are built from a structured config.

    Subclasses declare their options in a nested ``Config`` dataclass and
    create their sub-modules in :meth:`configure`, which the constructor calls
    after the config has been parsed.
    """

    @dataclass
    class Config:
        # Empty base config; subclasses extend it with their own fields.
        pass

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg`` against ``self.Config`` and run :meth:`configure`.

        Args:
            cfg: config values merged into the structured ``Config``.
            *args, **kwargs: forwarded unchanged to :meth:`configure`.
        """
        super().__init__()
        # ``self.Config`` resolves to the subclass's nested dataclass, so each
        # subclass is validated against its own schema.
        self.cfg = parse_structured(self.Config, cfg)
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        """Create the module's components; must be overridden by subclasses."""
        raise NotImplementedError
61
+
62
+
63
class ImagePreprocessor:
    """Convert PIL images, numpy arrays or tensors to resized float tensors."""

    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        """Convert one image (or one batch) to a tensor and resize it.

        Args:
            image: PIL image, numpy array or tensor. PIL images and ``uint8``
                arrays are scaled to ``[0, 1]``; other arrays and tensors are
                used as they are. The channel axis is expected last, i.e.
                ``(H, W, C)`` or ``(B, H, W, C)``.
            size: side length of the square output.

        Returns:
            A tensor of shape ``(size, size, C)``, or ``(B, size, size, C)``
            for 4-D input.

        Raises:
            TypeError: if ``image`` is not one of the supported types.
        """
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif not isinstance(image, torch.Tensor):
            # Previously fell through and failed later with an opaque
            # AttributeError on ``image.ndim``.
            raise TypeError(
                f"Unsupported image type: {type(image)}; expected PIL image, "
                "numpy array or torch tensor."
            )

        batched = image.ndim == 4

        if not batched:
            image = image[None, ...]
        # F.interpolate works on channels-first data, so move the channel axis
        # in front for the resize and back afterwards.
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        """Preprocess a single image, a batch, or a list of images.

        Args:
            image: a 4-D array/tensor is treated as a ready-made batch; anything
                else is treated as a single image or a list of single images
                that are resized individually and stacked.
            size: side length of the square output.

        Returns:
            A tensor whose first axis is the batch axis.
        """
        # ``torch.Tensor`` (not ``torch.FloatTensor``) so that batched CUDA,
        # half or double tensors are recognised as batches too; the legacy
        # FloatTensor check only matches CPU float32 tensors, which sent other
        # batches down the list path and produced an extra leading dimension.
        if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4:
            image = self.convert_and_resize(image, size)
        else:
            if not isinstance(image, list):
                image = [image]
            image = [self.convert_and_resize(im, size) for im in image]
            image = torch.stack(image, dim=0)
        return image
114
+
115
+
116
def rays_intersect_bbox(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    radius: float,
    near: float = 0.0,
    valid_thresh: float = 0.01,
):
    """Intersect rays with an axis-aligned bounding box.

    Args:
        rays_o: ray origins, shape ``(..., 3)``.
        rays_d: ray directions, shape ``(..., 3)``.
        radius: a Python number ``r`` describing the cube ``[-r, r]^3``, or a
            tensor indexed as ``radius[..., 0]`` (lower bounds) and
            ``radius[..., 1]`` (upper bounds) — presumably shape ``(3, 2)``.
        near: lower clamp applied to the entry distance.
        valid_thresh: a ray counts as hitting the box only if
            ``t_far - t_near`` exceeds this value.

    Returns:
        ``(t_near, t_far, rays_valid)`` where the distances have shape
        ``(..., 1)`` and the mask has shape ``(...)``. Rays that miss the box
        get ``t_near == t_far == 0``.
    """
    input_shape = rays_o.shape[:-1]
    rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
    # Avoid division by (near) zero for rays parallel to a box face.
    rays_d_valid = torch.where(
        rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
    )
    # isinstance (rather than an exact type match) also accepts float
    # subclasses such as numpy.float64.
    if isinstance(radius, (int, float)):
        radius = torch.tensor(
            [[-radius, radius], [-radius, radius], [-radius, radius]],
            dtype=torch.float32,
            device=rays_o.device,
        )
    # tighten the radius to make sure the intersection point lies in the bounding box
    radius = (1.0 - 1.0e-3) * radius
    # Slab method: per-axis distances to the two bounding planes.
    interx0 = (radius[..., 1] - rays_o) / rays_d_valid
    interx1 = (radius[..., 0] - rays_o) / rays_d_valid
    t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
    t_far = torch.maximum(interx0, interx1).amin(dim=-1)

    # check whether a ray intersects the bbox or not
    rays_valid = t_far - t_near > valid_thresh

    missed = ~rays_valid
    t_near[missed] = 0.0
    t_far[missed] = 0.0

    t_near = t_near.view(*input_shape, 1)
    t_far = t_far.view(*input_shape, 1)
    rays_valid = rays_valid.view(*input_shape)

    return t_near, t_far, rays_valid
151
+
152
+
153
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    """Call ``func`` on slices of its tensor arguments and merge the results.

    Every positional or keyword argument that is a tensor is sliced along its
    first axis into pieces of at most ``chunk_size``; all other arguments are
    passed through unchanged. The per-chunk results are concatenated along
    axis 0.

    Args:
        func: callable returning a tensor, a list/tuple, a dict, or ``None``.
            Container entries may be tensors or ``None``.
        chunk_size: maximum slice length; ``<= 0`` disables chunking and calls
            ``func`` once with the original arguments.
        *args, **kwargs: arguments for ``func``. The batch size is taken from
            the first tensor found among them.

    Returns:
        A value of the same kind ``func`` returns (tensor, list/tuple or dict)
        with tensors concatenated over all chunks, or ``None`` if every call
        returned ``None``.

    Raises:
        AssertionError: if no tensor argument is present.
        TypeError: if ``func`` returns an unsupported type, or a container
            entry mixes tensors with non-tensor values across chunks.
    """
    if chunk_size <= 0:
        return func(*args, **kwargs)
    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."
    out = defaultdict(list)
    out_type = None
    chunk_length = 0
    # max(1, B) to support B == 0
    for start in range(0, max(1, B), chunk_size):
        stop = start + chunk_size
        out_chunk = func(
            *[
                arg[start:stop] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[start:stop] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        # Normalize every supported return kind to a dict of entries.
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            chunk_length = len(out_chunk)
            out_chunk = dict(enumerate(out_chunk))
        elif isinstance(out_chunk, dict):
            pass
        else:
            # Raise instead of print + exit(1): a helper must not terminate
            # the whole process on a bad return value.
            raise TypeError(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
        for k, v in out_chunk.items():
            # Only tensors can be detached; ``None`` entries are allowed and
            # must survive when gradients are disabled.
            if isinstance(v, torch.Tensor) and not torch.is_grad_enabled():
                v = v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all(vv is None for vv in v):
            # allow None in return value
            out_merged[k] = None
        elif all(isinstance(vv, torch.Tensor) for vv in v):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    if out_type is torch.Tensor:
        return out_merged[0]
    elif out_type in [tuple, list]:
        return out_type([out_merged[i] for i in range(chunk_length)])
    elif out_type is dict:
        return out_merged
218
+
219
+
220
# A value range given either as a (min, max) pair or as a tensor indexed as
# [0] (min) and [1] (max).
ValidScale = Union[Tuple[float, float], torch.FloatTensor]


def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
    """Linearly remap ``dat`` from ``inp_scale`` to ``tgt_scale``.

    Args:
        dat: values to remap.
        inp_scale: source range ``(min, max)``; ``None`` means ``(0, 1)``.
        tgt_scale: target range ``(min, max)``; ``None`` means ``(0, 1)``.

    Returns:
        ``dat`` normalized by the source range and stretched to the target
        range.
    """
    source = (0, 1) if inp_scale is None else inp_scale
    target = (0, 1) if tgt_scale is None else tgt_scale
    if isinstance(target, torch.FloatTensor):
        assert dat.shape[-1] == target.shape[-1]

    src_min, src_max = source[0], source[1]
    tgt_min, tgt_max = target[0], target[1]

    normalized = (dat - src_min) / (src_max - src_min)
    return normalized * (tgt_max - tgt_min) + tgt_min
233
+
234
+
235
def get_activation(name) -> Callable:
    """Look up an activation function by (case-insensitive) name.

    ``None`` and ``"none"`` give the identity. ``"exp"``, ``"sigmoid"``,
    ``"tanh"`` and ``"softplus"`` are handled explicitly; any other name is
    looked up on ``torch.nn.functional``.

    Raises:
        ValueError: if the name is neither known nor an attribute of
            ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x

    key = name.lower()
    known = {
        "none": lambda x: x,
        "exp": lambda x: torch.exp(x),
        "sigmoid": lambda x: torch.sigmoid(x),
        "tanh": lambda x: torch.tanh(x),
        "softplus": lambda x: F.softplus(x),
    }
    if key in known:
        return known[key]

    try:
        return getattr(F, key)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {key}")
254
+
255
+
256
def get_ray_directions(
    H: int,
    W: int,
    focal: Union[float, Tuple[float, float]],
    principal: Optional[Tuple[float, float]] = None,
    use_pixel_centers: bool = True,
    normalize: bool = True,
) -> torch.FloatTensor:
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems

    Inputs:
        H, W: image height and width in pixels
        focal: a single number (same focal length on both axes, principal
            point at the image centre) or a pair ``(fx, fy)``
        principal: ``(cx, cy)``; required when ``focal`` is a pair
        use_pixel_centers: sample rays through pixel centres (offset 0.5)
            instead of pixel corners
        normalize: return unit-length directions
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
            (x right, y up, looking down -z)
    """
    pixel_center = 0.5 if use_pixel_centers else 0

    # Accept ints as well as floats; an int focal used to fall into the
    # tuple branch and fail while unpacking.
    if isinstance(focal, (int, float)):
        fx, fy = focal, focal
        cx, cy = W / 2, H / 2
    else:
        fx, fy = focal
        assert principal is not None
        cx, cy = principal

    # indexing="xy" yields (H, W) grids: i holds column and j row coordinates.
    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32) + pixel_center,
        torch.arange(H, dtype=torch.float32) + pixel_center,
        indexing="xy",
    )

    # Image rows grow downwards while camera y points up, hence the sign flip.
    directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)

    if normalize:
        directions = F.normalize(directions, dim=-1)

    return directions
296
+
297
+
298
def get_rays(
    directions,
    c2w,
    keepdim=False,
    noise_scale=0.0,
    normalize=False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Turn camera-space ray directions into world-space origins and directions.

    Supported shape combinations (see the per-branch comments):
      * ``directions`` (N_rays, 3) with ``c2w`` (4, 4) or (N_rays, 4, 4)
      * ``directions`` (H, W, 3) with ``c2w`` (4, 4) or (B, 4, 4)
      * ``directions`` (B, H, W, 3) with ``c2w`` (B, 4, 4)

    Args:
        directions: ray directions in camera coordinates, last axis of size 3.
        c2w: camera-to-world matrix/matrices; only rows 0-2 are read.
        keepdim: keep the leading shape of the result instead of flattening
            it to ``(-1, 3)``.
        noise_scale: when positive, one random 3-vector scaled by this value
            is added to all origins and another one to all directions.
        normalize: normalize the world-space directions to unit length.

    Returns:
        ``(rays_o, rays_d)`` with matching shapes.
    """
    # Rotate ray directions from camera coordinate to the world coordinate
    assert directions.shape[-1] == 3

    # In every branch the rotation is applied as a broadcasted multiply with
    # the 3x3 block of c2w followed by a sum over the last axis; the origin is
    # the translation column expanded to the shape of rays_d.
    if directions.ndim == 2:  # (N_rays, 3)
        if c2w.ndim == 2:  # (4, 4)
            c2w = c2w[None, :, :]
        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 3:  # (H, W, 3)
        assert c2w.ndim in [2, 3]
        if c2w.ndim == 2:  # (4, 4)
            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
                -1
            )  # (H, W, 3)
            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
        elif c2w.ndim == 3:  # (B, 4, 4)
            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
                -1
            )  # (B, H, W, 3)
            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 4:  # (B, H, W, 3)
        assert c2w.ndim == 3  # (B, 4, 4)
        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
            -1
        )  # (B, H, W, 3)
        rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    # NOTE(review): any other directions.ndim leaves rays_d / rays_o unset and
    # fails below with an UnboundLocalError.

    # add camera noise to avoid grid-like artifacts
    # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
    if noise_scale > 0:
        rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
        rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale

    if normalize:
        rays_d = F.normalize(rays_d, dim=-1)
    if not keepdim:
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    return rays_o, rays_d
345
+
346
+
347
def get_spherical_cameras(
    n_views: int,
    elevation_deg: float,
    camera_distance: float,
    fovy_deg: float,
    height: int,
    width: int,
):
    """Create rays for ``n_views`` cameras on a circle around the origin.

    The cameras share one elevation and distance, have azimuths evenly spaced
    over ``[0, 360)`` degrees, look at the origin and use +z as the up
    direction.

    Args:
        n_views: number of cameras.
        elevation_deg: elevation of every camera in degrees.
        camera_distance: distance of every camera from the origin.
        fovy_deg: vertical field of view in degrees.
        height: image height in pixels.
        width: image width in pixels.

    Returns:
        ``(rays_o, rays_d)`` as produced by ``get_rays(..., keepdim=True,
        normalize=True)`` — presumably of shape ``(n_views, height, width, 3)``.
    """
    azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
    elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
    camera_distances = torch.full_like(elevation_deg, camera_distance)

    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180

    # convert spherical coordinates to cartesian coordinates
    # right hand coordinate system, x back, y right, z up
    # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )

    # default scene center at origin
    center = torch.zeros_like(camera_positions)
    # default camera up direction as +z
    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)

    fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180

    lookat = F.normalize(center - camera_positions, dim=-1)
    # dim=-1 is required: without it torch.cross picks the first axis of size
    # 3, which is the view axis when n_views == 3, and gives wrong bases.
    right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
    c2w3x4 = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    # Append the homogeneous row [0, 0, 0, 1] to get 4x4 matrices.
    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
    c2w[:, 3, 3] = 1.0

    # get directions by dividing directions_unit_focal by focal length
    focal_length = 0.5 * height / torch.tan(0.5 * fovy)
    directions_unit_focal = get_ray_directions(
        H=height,
        W=width,
        focal=1.0,
    )
    directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
    directions[:, :, :, :2] = (
        directions[:, :, :, :2] / focal_length[:, None, None, None]
    )
    # must use normalize=True to normalize directions here
    rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)

    return rays_o, rays_d
406
+
407
+
408
def remove_background(
    image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    """Run ``rembg`` on ``image`` unless it already carries transparency.

    An RGBA image whose alpha channel has a minimum below 255 is returned
    unchanged, unless ``force`` is set.

    Args:
        image: input image.
        rembg_session: passed to ``rembg.remove`` as ``session``.
        force: run ``rembg`` even if the image already has transparency.
        **rembg_kwargs: extra keyword arguments for ``rembg.remove``.

    Returns:
        The result of ``rembg.remove`` or the untouched input image.
    """
    # getextrema()[3] is the (min, max) pair of the alpha band.
    already_transparent = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_transparent:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
421
+
422
+
423
def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    """Crop an RGBA image to its foreground and re-pad it to a square.

    The foreground is the bounding box of all pixels with alpha > 0. It is
    padded with transparent pixels to a square and then padded again so that
    the square occupies ``ratio`` of the final side length.

    Args:
        image: image with four channels (alpha last).
        ratio: fraction of the output side length covered by the foreground
            square.

    Returns:
        A new square PIL image.

    Raises:
        AssertionError: if the image does not have four channels.
        ValueError: if no pixel has alpha > 0.
    """
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    if alpha[0].size == 0:
        raise ValueError("Image has no foreground pixel (alpha > 0) to crop.")
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground; max() is the index of the last foreground row /
    # column, so the slice end must be one past it to keep that pixel
    fg = image[y1 : y2 + 1, x1 : x2 + 1]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = PIL.Image.fromarray(new_image)
    return new_image
462
+
463
+
464
def save_video(
    frames: List[PIL.Image.Image],
    output_path: str,
    fps: int = 30,
):
    """Write ``frames`` to ``output_path`` as a video using imageio.

    Args:
        frames: images to write, in order.
        output_path: destination file; passed to ``imageio.get_writer``.
        fps: frames per second of the output.
    """
    # use imageio to save video
    frames = [np.array(frame) for frame in frames]
    writer = imageio.get_writer(output_path, fps=fps)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Always release the writer, even if appending a frame fails.
        writer.close()
475
+
476
+
477
+ _dir2vec = {
478
+ "+x": np.array([1, 0, 0]),
479
+ "+y": np.array([0, 1, 0]),
480
+ "+z": np.array([0, 0, 1]),
481
+ "-x": np.array([-1, 0, 0]),
482
+ "-y": np.array([0, -1, 0]),
483
+ "-z": np.array([0, 0, -1]),
484
+ }
485
+
486
+
487
+ def to_gradio_3d_orientation(vertices):
488
+ z_, x_ = _dir2vec["+y"], _dir2vec["-z"]
489
+ y_ = np.cross(z_, x_)
490
+ std2mesh = np.stack([x_, y_, z_], axis=0).T
491
+ vertices = np.dot(std2mesh, vertices.T).T
492
+ return vertices