YashNagraj75 commited on
Commit
ec73463
·
1 Parent(s): e1d97a8

Add blocks for ddpm base unet (it's simpler)

Browse files
Files changed (1) hide show
  1. model_blocks/unet_base.py +34 -0
model_blocks/unet_base.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ def get_time_embedding(time_steps, temb_dim):
10
+ r"""
11
+ Convert time steps tensor into an embedding using the
12
+ sinusoidal time embedding formula
13
+ :param time_steps: 1D tensor of length batch size
14
+ :param temb_dim: Dimension of the embedding
15
+ :return: BxD embedding representation of B time steps
16
+ """
17
+ assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
18
+
19
+ # factor = 10000^(2i/d_model)
20
+ factor = 10000 ** (
21
+ torch.arange(
22
+ start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device
23
+ )
24
+ / (temb_dim // 2)
25
+ )
26
+
27
+ # pos / factor
28
+ # timesteps B -> B, 1 -> B, temb_dim
29
+ t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
30
+ t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
31
+ return t_emb
32
+
33
+
34
+ class