jadechoghari committed on
Commit 687f75f · verified · 1 Parent(s): 3ea616d

Create mv_unet.py

Files changed (1)
  1. unet/mv_unet.py +146 -0
unet/mv_unet.py ADDED
@@ -0,0 +1,146 @@
+ import math
+ import torch
+ import torch.nn as nn
+ from diffusers import ModelMixin, ConfigMixin
+ from einops import rearrange, repeat
+ from .mv_attention import SPADTransformer as SpatialTransformer
+ from .openaimodel import UNetModel, TimestepBlock
+
+ # we define the timestep_embedding
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+     """
+     Create sinusoidal timestep embeddings.
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     if not repeat_only:
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         ).to(device=timesteps.device)
+         args = timesteps[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     else:
+         # simply broadcast the raw timesteps across the embedding dimension
+         embedding = repeat(timesteps, 'b -> b d', d=dim)
+     return embedding
+
+ class SPADUnetModel(ModelMixin, ConfigMixin):
+     def __init__(self, image_size=32, in_channels=4, out_channels=4, model_channels=320,
+                  attention_resolutions=(4, 2, 1), num_res_blocks=2, channel_mult=(1, 2, 4, 4),
+                  num_heads=8, use_spatial_transformer=True, transformer_depth=1, context_dim=768,
+                  use_checkpoint=False, legacy=False, **kwargs):
+         super().__init__()
+         self.image_size = image_size
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.model_channels = model_channels
+         self.attention_resolutions = attention_resolutions
+         self.num_res_blocks = num_res_blocks
+         self.channel_mult = channel_mult
+         self.num_heads = num_heads
+         self.use_spatial_transformer = use_spatial_transformer
+         self.transformer_depth = transformer_depth
+         self.context_dim = context_dim
+         self.use_checkpoint = use_checkpoint
+         self.legacy = legacy
+
+         # we initialize the UNetModel
+         self.unet = UNetModel(image_size, in_channels, out_channels, model_channels,
+                               attention_resolutions, num_res_blocks, channel_mult,
+                               num_heads=num_heads, context_dim=context_dim, **kwargs)
+
+     def encode(self, h, emb, context, blocks):
+         hs = []
+         n_objects, n_views = h.shape[:2]
+         for i, block in enumerate(blocks):
+             for j, layer in enumerate(block):
+                 if isinstance(layer, SpatialTransformer):
+                     h = layer(h, context)
+                 elif isinstance(layer, TimestepBlock):
+                     # squash first two dims (single pass)
+                     h = rearrange(h, "n v c h w -> (n v) c h w")
+                     emb = rearrange(emb, "n v c -> (n v) c")
+                     # apply layer
+                     h = layer(h, emb)
+                     # unsquash first two dims
+                     h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
+                     emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
+                 else:
+                     # squash first two dims (single pass)
+                     h = rearrange(h, "n v c h w -> (n v) c h w")
+                     # apply layer
+                     h = layer(h)
+                     # unsquash first two dims
+                     h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
+             hs.append(h)
+         return hs
+
+     def decode(self, h, hs, emb, context, xdtype, last=False, return_outputs=False):
+         ho = []
+         n_objects, n_views = h.shape[:2]
+         for i, block in enumerate(self.unet.output_blocks):
+             h = torch.cat([h, hs[-(i + 1)]], dim=2)
+             for j, layer in enumerate(block):
+                 if isinstance(layer, SpatialTransformer):
+                     h = layer(h, context)
+                 elif isinstance(layer, TimestepBlock):
+                     # squash first two dims (single pass)
+                     h = rearrange(h, "n v c h w -> (n v) c h w")
+                     emb = rearrange(emb, "n v c -> (n v) c")
+                     # apply layer
+                     h = layer(h, emb)
+                     # unsquash first two dims
+                     h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
+                     emb = rearrange(emb, "(n v) c -> n v c", n=n_objects, v=n_views)
+                 else:
+                     # squash first two dims (single pass)
+                     h = rearrange(h, "n v c h w -> (n v) c h w")
+                     # apply layer
+                     h = layer(h)
+                     # unsquash first two dims
+                     h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
+             ho.append(h)
+
+         # process last layer
+         h = h.type(xdtype)
+         h = rearrange(h, "n v c h w -> (n v) c h w")
+         if last:
+             # changed code here to make compatible with diffusers unet
+             h = self.unet.out(h)
+         h = rearrange(h, "(n v) c h w -> n v c h w", n=n_objects, v=n_views)
+         ho.append(h)
+         return ho if return_outputs else h
+
+     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+         n_objects, n_views = x.shape[:2]
+         timesteps = rearrange(timesteps, "n v -> (n v)")
+         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+         time = self.unet.time_embed(t_emb)
+         time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)
+
+         if len(context) == 2:
+             txt, cam = context
+         elif len(context) == 3:
+             txt, cam, epi_mask = context
+             txt = (txt, epi_mask)
+         else:
+             raise ValueError("context must be (txt, cam) or (txt, cam, epi_mask)")
+
+         if x.shape[2] > 4:
+             # split the extra plucker-coordinate channels off the 4 latent channels
+             plucker, x = x[:, :, 4:], x[:, :, :4]
+             txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)
+
+         time_cam = time + cam
+         del time, cam
+
+         h = x.type(self.dtype)
+         hs = self.encode(h, time_cam, txt, self.unet.input_blocks)
+         h = self.encode(hs[-1], time_cam, txt, [self.unet.middle_block])[0]
+         h = self.decode(h, hs, time_cam, txt, x.dtype, last=True)
+
+         return h
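
For reference, a minimal usage sketch of the forward pass. The shapes below are assumptions read off the rearrange patterns and the defaults above (4 latent channels, model_channels=320 and thus a 1280-wide time/camera embedding, context_dim=768 with a hypothetical 77-token text sequence); they are not specified in this commit.

    import torch
    from unet.mv_unet import SPADUnetModel

    n_objects, n_views = 2, 4                                # objects in the batch, views per object
    model = SPADUnetModel()                                  # default config from this file

    x = torch.randn(n_objects, n_views, 4, 32, 32)           # per-view latents (n, v, c, h, w)
    timesteps = torch.randint(0, 1000, (n_objects, n_views)) # one timestep per view
    txt = torch.randn(n_objects, n_views, 77, 768)           # assumed per-view text embeddings (context_dim=768)
    cam = torch.randn(n_objects, n_views, 4 * 320)           # assumed camera embedding; must match the time-embedding width

    out = model(x, timesteps=timesteps, context=(txt, cam))  # expected shape: (n_objects, n_views, 4, 32, 32)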