keithhon committed on
Commit 60eb46a
1 Parent(s): bceda6b

Upload dalle/models/stage1/vqgan.py with huggingface_hub

Files changed (1)
  1. dalle/models/stage1/vqgan.py +93 -0
dalle/models/stage1/vqgan.py ADDED
@@ -0,0 +1,93 @@
+ # ------------------------------------------------------------------------------------
+ # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
+ # ------------------------------------------------------------------------------------
+
+ import torch
+ import torch.nn as nn
+ from typing import List, Tuple, Optional
+ from einops import rearrange
+ from omegaconf import OmegaConf
+ from .layers import Encoder, Decoder
+
+
+ class VectorQuantizer(nn.Module):
+     """
+     Simplified version of the VectorQuantizer in the original VQGAN repository,
+     with modules unnecessary for sampling removed.
+     """
+     def __init__(self, dim: int, n_embed: int, beta: float) -> None:
+         super().__init__()
+         self.n_embed = n_embed
+         self.dim = dim
+         self.beta = beta
+
+         self.embedding = nn.Embedding(self.n_embed, self.dim)
+         self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)
+
+     def forward(self,
+                 z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+         z = rearrange(z, 'b c h w -> b h w c').contiguous()  # [B,C,H,W] -> [B,H,W,C]
+         z_flattened = z.view(-1, self.dim)
+
+         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+             torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+             torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+         min_encoding_indices = torch.argmin(d, dim=1)
+         z_q = self.embedding(min_encoding_indices).view(z.shape)
+         return z_q, min_encoding_indices
+
+     def get_codebook_entry(self,
+                            indices: torch.LongTensor,
+                            shape: Optional[List[int]] = None) -> torch.FloatTensor:
+         z_q = self.embedding(indices)
+         if shape is not None:
+             z_q = z_q.view(shape)
+             z_q = z_q.permute(0, 3, 1, 2).contiguous()
+         return z_q
+
+
+ class VQGAN(nn.Module):
+     def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None:
+         super().__init__()
+         self.encoder = Encoder(**hparams)
+         self.decoder = Decoder(**hparams)
+         self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25)
+         self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1)
+         self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1)
+         self.latent_dim = hparams.attn_resolutions[0]
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         quant = self.encode(x)
+         dec = self.decode(quant)
+         return dec
+
+     def encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         h = self.encoder(x)
+         h = self.quant_conv(h)
+         quant = self.quantize(h)[0]
+         quant = rearrange(quant, 'b h w c -> b c h w').contiguous()
+         return quant
+
+     def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
+         quant = self.post_quant_conv(quant)
+         dec = self.decoder(quant)
+         return dec
+
+     def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor:
+         quant = self.quantize.get_codebook_entry(code)
+         quant = quant.permute(0, 3, 1, 2)
+         dec = self.decode(quant)
+         return dec
+
+     def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
+         h = self.encoder(x)
+         h = self.quant_conv(h)
+         codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2)
+         return codes
+
+     def from_ckpt(self, path: str, strict: bool = True) -> None:
+         ckpt = torch.load(path, map_location='cpu')['state_dict']
+         self.load_state_dict(ckpt, strict=strict)
+         print(f'{path} successfully restored.')
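A note on the quantizer for readers skimming the diff: VectorQuantizer.forward finds each latent vector's nearest codebook entry by expanding the squared distance ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e, which turns the pairwise lookup into a single einsum instead of an explicit loop over codebook entries. The sketch below is a minimal, self-contained check of that identity; it does not import the repository's .layers module, and the tensor sizes are illustrative rather than the model's real configuration.

import torch
import torch.nn as nn
from einops import rearrange

n_embed, dim = 16, 8                                   # toy sizes for illustration only
embedding = nn.Embedding(n_embed, dim)

z = torch.randn(2, dim, 4, 4)                          # [B, C, H, W] latents
z_flat = rearrange(z, 'b c h w -> (b h w) c')          # flatten the spatial grid

# ||z||^2 + ||e||^2 - 2 z.e, the same expression used in the diff
d = torch.sum(z_flat ** 2, dim=1, keepdim=True) \
    + torch.sum(embedding.weight ** 2, dim=1) \
    - 2 * torch.einsum('bd,dn->bn', z_flat, rearrange(embedding.weight, 'n d -> d n'))

# Compare against the direct pairwise distance computation
d_ref = torch.cdist(z_flat, embedding.weight) ** 2
assert torch.allclose(d, d_ref, atol=1e-4)
indices = torch.argmin(d, dim=1)                       # one codebook id per latent vector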
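And, for completeness, a hedged usage sketch of the uploaded class. The config path, checkpoint name, and size values below are placeholders, not values taken from this commit; the real ones come from the stage-1 configs and weights that accompany this repository, and hparams must carry the fields that Encoder, Decoder, and VQGAN.__init__ read (z_channels, attn_resolutions, and so on).

import torch
from omegaconf import OmegaConf
from dalle.models.stage1.vqgan import VQGAN

hparams = OmegaConf.load('configs/stage1.yaml')                # hypothetical config path
model = VQGAN(n_embed=16384, embed_dim=256, hparams=hparams)   # sizes are assumptions
model.from_ckpt('stage1_last.ckpt')                            # hypothetical checkpoint file
model.eval()

x = torch.randn(1, 3, 256, 256)                   # stand-in for a normalized RGB batch
with torch.no_grad():
    codes = model.get_codes(x)                    # [B, latent_dim ** 2] token ids
    grid = codes.view(1, model.latent_dim, model.latent_dim)
    recon = model.decode_code(grid)               # tokens back to an image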