JustinLin610 committed on
Commit
dd78d66
1 Parent(s): 9eb2477

remove unnecessary files

models/clip/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .clip import *
 
 
models/clip/clip.py DELETED
@@ -1,229 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
- from typing import Any, Union, List
6
- from pkg_resources import packaging
7
-
8
- import torch
9
- from PIL import Image
10
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
- from tqdm import tqdm
12
-
13
- from .model import build_model
14
- from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
-
16
- try:
17
- from torchvision.transforms import InterpolationMode
18
- BICUBIC = InterpolationMode.BICUBIC
19
- except ImportError:
20
- BICUBIC = Image.BICUBIC
21
-
22
-
23
- if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
- warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
-
26
-
27
- __all__ = ["available_models", "load", "tokenize"]
28
- _tokenizer = _Tokenizer()
29
-
30
- _MODELS = {
31
- "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
- "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
- "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
- "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
- "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
36
- "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
37
- }
38
-
39
-
40
- def _download(url: str, root: str):
41
- os.makedirs(root, exist_ok=True)
42
- filename = os.path.basename(url)
43
-
44
- expected_sha256 = url.split("/")[-2]
45
- download_target = os.path.join(root, filename)
46
-
47
- if os.path.exists(download_target) and not os.path.isfile(download_target):
48
- raise RuntimeError(f"{download_target} exists and is not a regular file")
49
-
50
- if os.path.isfile(download_target):
51
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
52
- return download_target
53
- else:
54
- warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
55
-
56
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
57
- with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
58
- while True:
59
- buffer = source.read(8192)
60
- if not buffer:
61
- break
62
-
63
- output.write(buffer)
64
- loop.update(len(buffer))
65
-
66
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
67
- raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
68
-
69
- return download_target
70
-
71
-
72
- def _convert_image_to_rgb(image):
73
- return image.convert("RGB")
74
-
75
-
76
- def _transform(n_px):
77
- return Compose([
78
- Resize(n_px, interpolation=BICUBIC),
79
- CenterCrop(n_px),
80
- _convert_image_to_rgb,
81
- ToTensor(),
82
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
83
- ])
84
-
85
-
86
- def available_models() -> List[str]:
87
- """Returns the names of available CLIP models"""
88
- return list(_MODELS.keys())
89
-
90
-
91
- def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
92
- """Load a CLIP model
93
-
94
- Parameters
95
- ----------
96
- name : str
97
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
98
-
99
- device : Union[str, torch.device]
100
- The device to put the loaded model
101
-
102
- jit : bool
103
- Whether to load the optimized JIT model or more hackable non-JIT model (default).
104
-
105
- download_root: str
106
- path to download the model files; by default, it uses "~/.cache/clip"
107
-
108
- Returns
109
- -------
110
- model : torch.nn.Module
111
- The CLIP model
112
-
113
- preprocess : Callable[[PIL.Image], torch.Tensor]
114
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
115
- """
116
- if name in _MODELS:
117
- model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
118
- elif os.path.isfile(name):
119
- model_path = name
120
- else:
121
- raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
122
-
123
- try:
124
- # loading JIT archive
125
- model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
126
- state_dict = None
127
- except RuntimeError:
128
- # loading saved state dict
129
- if jit:
130
- warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
131
- jit = False
132
- state_dict = torch.load(model_path, map_location="cpu")
133
-
134
- if not jit:
135
- model = build_model(state_dict or model.state_dict()).to(device)
136
- if str(device) == "cpu":
137
- model.float()
138
- return model, _transform(model.visual.input_resolution)
139
-
140
- # patch the device names
141
- device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
142
- device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
143
-
144
- def patch_device(module):
145
- try:
146
- graphs = [module.graph] if hasattr(module, "graph") else []
147
- except RuntimeError:
148
- graphs = []
149
-
150
- if hasattr(module, "forward1"):
151
- graphs.append(module.forward1.graph)
152
-
153
- for graph in graphs:
154
- for node in graph.findAllNodes("prim::Constant"):
155
- if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
156
- node.copyAttributes(device_node)
157
-
158
- model.apply(patch_device)
159
- patch_device(model.encode_image)
160
- patch_device(model.encode_text)
161
-
162
- # patch dtype to float32 on CPU
163
- if str(device) == "cpu":
164
- float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
165
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
166
- float_node = float_input.node()
167
-
168
- def patch_float(module):
169
- try:
170
- graphs = [module.graph] if hasattr(module, "graph") else []
171
- except RuntimeError:
172
- graphs = []
173
-
174
- if hasattr(module, "forward1"):
175
- graphs.append(module.forward1.graph)
176
-
177
- for graph in graphs:
178
- for node in graph.findAllNodes("aten::to"):
179
- inputs = list(node.inputs())
180
- for i in [1, 2]: # dtype can be the second or third argument to aten::to()
181
- if inputs[i].node()["value"] == 5:
182
- inputs[i].node().copyAttributes(float_node)
183
-
184
- model.apply(patch_float)
185
- patch_float(model.encode_image)
186
- patch_float(model.encode_text)
187
-
188
- model.float()
189
-
190
- return model, _transform(model.input_resolution.item())
191
-
192
-
193
- def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
194
- """
195
- Returns the tokenized representation of given input string(s)
196
-
197
- Parameters
198
- ----------
199
- texts : Union[str, List[str]]
200
- An input string or a list of input strings to tokenize
201
-
202
- context_length : int
203
- The context length to use; all CLIP models use 77 as the context length
204
-
205
- truncate: bool
206
- Whether to truncate the text in case its encoding is longer than the context length
207
-
208
- Returns
209
- -------
210
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
211
- """
212
- if isinstance(texts, str):
213
- texts = [texts]
214
-
215
- sot_token = _tokenizer.encoder["<|startoftext|>"]
216
- eot_token = _tokenizer.encoder["<|endoftext|>"]
217
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
218
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
219
-
220
- for i, tokens in enumerate(all_tokens):
221
- if len(tokens) > context_length:
222
- if truncate:
223
- tokens = tokens[:context_length]
224
- tokens[-1] = eot_token
225
- else:
226
- raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
227
- result[i, :len(tokens)] = torch.tensor(tokens)
228
-
229
- return result
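For reference, a minimal sketch of how the deleted clip module was typically consumed downstream. It is a hedged example: it assumes the sibling model/tokenizer files are still importable as models.clip, that the checkpoint download succeeds, and "example.jpg" is a placeholder path.

    import torch
    from PIL import Image
    from models.clip import load, tokenize  # re-exported via models/clip/__init__.py

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = load("ViT-B/16", device=device)  # non-JIT path by default

    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
    text = tokenize(["a photo of a cat", "a photo of a dog"]).to(device)

    with torch.no_grad():
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1)  # per-image probabilities over the two captions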
 
 
models/clip/model.py DELETED
@@ -1,437 +0,0 @@
1
- from collections import OrderedDict
2
- from typing import Tuple, Union
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import nn
8
-
9
-
10
- class Bottleneck(nn.Module):
11
- expansion = 4
12
-
13
- def __init__(self, inplanes, planes, stride=1):
14
- super().__init__()
15
-
16
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
- self.bn1 = nn.BatchNorm2d(planes)
19
-
20
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21
- self.bn2 = nn.BatchNorm2d(planes)
22
-
23
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24
-
25
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27
-
28
- self.relu = nn.ReLU(inplace=True)
29
- self.downsample = None
30
- self.stride = stride
31
-
32
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
33
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34
- self.downsample = nn.Sequential(OrderedDict([
35
- ("-1", nn.AvgPool2d(stride)),
36
- ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37
- ("1", nn.BatchNorm2d(planes * self.expansion))
38
- ]))
39
-
40
- def forward(self, x: torch.Tensor):
41
- identity = x
42
-
43
- out = self.relu(self.bn1(self.conv1(x)))
44
- out = self.relu(self.bn2(self.conv2(out)))
45
- out = self.avgpool(out)
46
- out = self.bn3(self.conv3(out))
47
-
48
- if self.downsample is not None:
49
- identity = self.downsample(x)
50
-
51
- out += identity
52
- out = self.relu(out)
53
- return out
54
-
55
-
56
- class AttentionPool2d(nn.Module):
57
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58
- super().__init__()
59
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60
- self.k_proj = nn.Linear(embed_dim, embed_dim)
61
- self.q_proj = nn.Linear(embed_dim, embed_dim)
62
- self.v_proj = nn.Linear(embed_dim, embed_dim)
63
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64
- self.num_heads = num_heads
65
-
66
- def forward(self, x):
67
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70
- x, _ = F.multi_head_attention_forward(
71
- query=x, key=x, value=x,
72
- embed_dim_to_check=x.shape[-1],
73
- num_heads=self.num_heads,
74
- q_proj_weight=self.q_proj.weight,
75
- k_proj_weight=self.k_proj.weight,
76
- v_proj_weight=self.v_proj.weight,
77
- in_proj_weight=None,
78
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79
- bias_k=None,
80
- bias_v=None,
81
- add_zero_attn=False,
82
- dropout_p=0,
83
- out_proj_weight=self.c_proj.weight,
84
- out_proj_bias=self.c_proj.bias,
85
- use_separate_proj_weight=True,
86
- training=self.training,
87
- need_weights=False
88
- )
89
-
90
- return x[0]
91
-
92
-
93
- class ModifiedResNet(nn.Module):
94
- """
95
- A ResNet class that is similar to torchvision's but contains the following changes:
96
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98
- - The final pooling layer is a QKV attention instead of an average pool
99
- """
100
-
101
- def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102
- super().__init__()
103
- self.output_dim = output_dim
104
- self.input_resolution = input_resolution
105
-
106
- # the 3-layer stem
107
- self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108
- self.bn1 = nn.BatchNorm2d(width // 2)
109
- self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110
- self.bn2 = nn.BatchNorm2d(width // 2)
111
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112
- self.bn3 = nn.BatchNorm2d(width)
113
- self.avgpool = nn.AvgPool2d(2)
114
- self.relu = nn.ReLU(inplace=True)
115
-
116
- # residual layers
117
- self._inplanes = width # this is a *mutable* variable used during construction
118
- self.layer1 = self._make_layer(width, layers[0])
119
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122
-
123
- embed_dim = width * 32 # the ResNet feature dimension
124
- self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125
-
126
- def _make_layer(self, planes, blocks, stride=1):
127
- layers = [Bottleneck(self._inplanes, planes, stride)]
128
-
129
- self._inplanes = planes * Bottleneck.expansion
130
- for _ in range(1, blocks):
131
- layers.append(Bottleneck(self._inplanes, planes))
132
-
133
- return nn.Sequential(*layers)
134
-
135
- def forward(self, x):
136
- def stem(x):
137
- for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138
- x = self.relu(bn(conv(x)))
139
- x = self.avgpool(x)
140
- return x
141
-
142
- x = x.type(self.conv1.weight.dtype)
143
- x = stem(x)
144
- x = self.layer1(x)
145
- x = self.layer2(x)
146
- x = self.layer3(x)
147
- x = self.layer4(x)
148
- x = self.attnpool(x)
149
-
150
- return x
151
-
152
-
153
- class LayerNorm(nn.LayerNorm):
154
- """Subclass torch's LayerNorm to handle fp16."""
155
-
156
- def forward(self, x: torch.Tensor):
157
- orig_type = x.dtype
158
- ret = super().forward(x.type(torch.float32))
159
- return ret.type(orig_type)
160
-
161
-
162
- class QuickGELU(nn.Module):
163
- def forward(self, x: torch.Tensor):
164
- return x * torch.sigmoid(1.702 * x)
165
-
166
-
167
- class ResidualAttentionBlock(nn.Module):
168
- def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169
- super().__init__()
170
-
171
- self.attn = nn.MultiheadAttention(d_model, n_head)
172
- self.ln_1 = LayerNorm(d_model)
173
- self.mlp = nn.Sequential(OrderedDict([
174
- ("c_fc", nn.Linear(d_model, d_model * 4)),
175
- ("gelu", QuickGELU()),
176
- ("c_proj", nn.Linear(d_model * 4, d_model))
177
- ]))
178
- self.ln_2 = LayerNorm(d_model)
179
- self.attn_mask = attn_mask
180
-
181
- def attention(self, x: torch.Tensor):
182
- self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183
- return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184
-
185
- def forward(self, x: torch.Tensor):
186
- x = x + self.attention(self.ln_1(x))
187
- x = x + self.mlp(self.ln_2(x))
188
- return x
189
-
190
-
191
- class Transformer(nn.Module):
192
- def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193
- super().__init__()
194
- self.width = width
195
- self.layers = layers
196
- self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197
-
198
- def forward(self, x: torch.Tensor):
199
- return self.resblocks(x)
200
-
201
-
202
- class VisionTransformer(nn.Module):
203
- def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
204
- super().__init__()
205
- self.input_resolution = input_resolution
206
- self.output_dim = output_dim
207
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
208
-
209
- scale = width ** -0.5
210
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
211
- self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
212
- self.ln_pre = LayerNorm(width)
213
-
214
- self.transformer = Transformer(width, layers, heads)
215
-
216
- self.ln_post = LayerNorm(width)
217
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
218
-
219
- def forward(self, x: torch.Tensor):
220
- x = self.conv1(x) # shape = [*, width, grid, grid]
221
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
222
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
223
- x = torch.cat(
224
- [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
225
- x], dim=1) # shape = [*, grid ** 2 + 1, width]
226
- x = x + self.positional_embedding.to(x.dtype)
227
- x = self.ln_pre(x)
228
-
229
- x = x.permute(1, 0, 2) # NLD -> LND
230
- x = self.transformer(x)
231
- x = x.permute(1, 0, 2) # LND -> NLD
232
-
233
- x = self.ln_post(x[:, 0, :])
234
-
235
- if self.proj is not None:
236
- x = x @ self.proj
237
-
238
- return x
239
-
240
-
241
- class CLIP(nn.Module):
242
- def __init__(self,
243
- embed_dim: int,
244
- # vision
245
- image_resolution: int,
246
- vision_layers: Union[Tuple[int, int, int, int], int],
247
- vision_width: int,
248
- vision_patch_size: int,
249
- # text
250
- context_length: int,
251
- vocab_size: int,
252
- transformer_width: int,
253
- transformer_heads: int,
254
- transformer_layers: int
255
- ):
256
- super().__init__()
257
-
258
- self.context_length = context_length
259
- self.input_resolution = image_resolution
260
-
261
- if isinstance(vision_layers, (tuple, list)):
262
- vision_heads = vision_width * 32 // 64
263
- self.visual = ModifiedResNet(
264
- layers=vision_layers,
265
- output_dim=embed_dim,
266
- heads=vision_heads,
267
- input_resolution=image_resolution,
268
- width=vision_width
269
- )
270
- else:
271
- vision_heads = vision_width // 64
272
- self.visual = VisionTransformer(
273
- input_resolution=image_resolution,
274
- patch_size=vision_patch_size,
275
- width=vision_width,
276
- layers=vision_layers,
277
- heads=vision_heads,
278
- output_dim=embed_dim
279
- )
280
-
281
- self.transformer = Transformer(
282
- width=transformer_width,
283
- layers=transformer_layers,
284
- heads=transformer_heads,
285
- attn_mask=self.build_attention_mask()
286
- )
287
-
288
- self.vocab_size = vocab_size
289
- self.token_embedding = nn.Embedding(vocab_size, transformer_width)
290
- self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
291
- self.ln_final = LayerNorm(transformer_width)
292
-
293
- self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
294
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
295
-
296
- self.initialize_parameters()
297
-
298
- def initialize_parameters(self):
299
- nn.init.normal_(self.token_embedding.weight, std=0.02)
300
- nn.init.normal_(self.positional_embedding, std=0.01)
301
-
302
- if isinstance(self.visual, ModifiedResNet):
303
- if self.visual.attnpool is not None:
304
- std = self.visual.attnpool.c_proj.in_features ** -0.5
305
- nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
306
- nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
307
- nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
308
- nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
309
-
310
- for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
311
- for name, param in resnet_block.named_parameters():
312
- if name.endswith("bn3.weight"):
313
- nn.init.zeros_(param)
314
-
315
- proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
316
- attn_std = self.transformer.width ** -0.5
317
- fc_std = (2 * self.transformer.width) ** -0.5
318
- for block in self.transformer.resblocks:
319
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
320
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
321
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
322
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
323
-
324
- if self.text_projection is not None:
325
- nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
326
-
327
- def build_attention_mask(self):
328
- # lazily create causal attention mask, with full attention between the vision tokens
329
- # pytorch uses additive attention mask; fill with -inf
330
- mask = torch.empty(self.context_length, self.context_length)
331
- mask.fill_(float("-inf"))
332
- mask.triu_(1) # zero out the lower diagonal
333
- return mask
334
-
335
- @property
336
- def dtype(self):
337
- return self.visual.conv1.weight.dtype
338
-
339
- def encode_image(self, image):
340
- return self.visual(image.type(self.dtype))
341
-
342
- def encode_text(self, text):
343
- x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
344
-
345
- x = x + self.positional_embedding.type(self.dtype)
346
- x = x.permute(1, 0, 2) # NLD -> LND
347
- x = self.transformer(x)
348
- x = x.permute(1, 0, 2) # LND -> NLD
349
- x = self.ln_final(x).type(self.dtype)
350
-
351
- # x.shape = [batch_size, n_ctx, transformer.width]
352
- # take features from the eot embedding (eot_token is the highest number in each sequence)
353
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
354
-
355
- return x
356
-
357
- def forward(self, image, text):
358
- image_features = self.encode_image(image)
359
- text_features = self.encode_text(text)
360
-
361
- # normalized features
362
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
363
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
364
-
365
- # cosine similarity as logits
366
- logit_scale = self.logit_scale.exp()
367
- logits_per_image = logit_scale * image_features @ text_features.t()
368
- logits_per_text = logits_per_image.t()
369
-
370
- # shape = [global_batch_size, global_batch_size]
371
- return logits_per_image, logits_per_text
372
-
373
-
374
- def convert_weights(model: nn.Module):
375
- """Convert applicable model parameters to fp16"""
376
-
377
- def _convert_weights_to_fp16(l):
378
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
379
- l.weight.data = l.weight.data.half()
380
- if l.bias is not None:
381
- l.bias.data = l.bias.data.half()
382
-
383
- if isinstance(l, nn.MultiheadAttention):
384
- for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
385
- tensor = getattr(l, attr)
386
- if tensor is not None:
387
- tensor.data = tensor.data.half()
388
-
389
- for name in ["text_projection", "proj"]:
390
- if hasattr(l, name):
391
- attr = getattr(l, name)
392
- if attr is not None:
393
- attr.data = attr.data.half()
394
-
395
- model.apply(_convert_weights_to_fp16)
396
-
397
-
398
- def build_model(state_dict: dict):
399
- vit = "visual.proj" in state_dict
400
-
401
- if vit:
402
- vision_width = state_dict["visual.conv1.weight"].shape[0]
403
- vision_layers = len(
404
- [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
405
- vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
406
- grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
407
- image_resolution = vision_patch_size * grid_size
408
- else:
409
- counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in
410
- [1, 2, 3, 4]]
411
- vision_layers = tuple(counts)
412
- vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
413
- output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
414
- vision_patch_size = None
415
- assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
416
- image_resolution = output_width * 32
417
-
418
- embed_dim = state_dict["text_projection"].shape[1]
419
- context_length = state_dict["positional_embedding"].shape[0]
420
- vocab_size = state_dict["token_embedding.weight"].shape[0]
421
- transformer_width = state_dict["ln_final.weight"].shape[0]
422
- transformer_heads = transformer_width // 64
423
- transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
424
-
425
- model = CLIP(
426
- embed_dim,
427
- image_resolution, vision_layers, vision_width, vision_patch_size,
428
- context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
429
- )
430
-
431
- for key in ["input_resolution", "context_length", "vocab_size"]:
432
- if key in state_dict:
433
- del state_dict[key]
434
-
435
- convert_weights(model)
436
- model.load_state_dict(state_dict)
437
- return model.eval()
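build_model is what the non-JIT branch of clip.load calls: it infers the hyperparameters (ViT vs. ResNet, widths, layer counts, resolution) from tensor shapes in the state dict, converts the weights to fp16, and returns the model in eval mode. A minimal sketch, assuming "clip_state_dict.pt" (an illustrative name) holds a plain state dict rather than a JIT archive:

    import torch
    from models.clip.model import build_model

    state_dict = torch.load("clip_state_dict.pt", map_location="cpu")
    model = build_model(state_dict)   # architecture inferred from the checkpoint's tensor shapes
    model = model.float().eval()      # cast back to fp32 for CPU inference, mirroring clip.load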
 
 
models/clip/simple_tokenizer.py DELETED
@@ -1,132 +0,0 @@
1
- import gzip
2
- import html
3
- import os
4
- from functools import lru_cache
5
-
6
- import ftfy
7
- import regex as re
8
-
9
-
10
- @lru_cache()
11
- def default_bpe():
12
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
-
14
-
15
- @lru_cache()
16
- def bytes_to_unicode():
17
- """
18
- Returns list of utf-8 byte and a corresponding list of unicode strings.
19
- The reversible bpe codes work on unicode strings.
20
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
- This is a significant percentage of your normal, say, 32K bpe vocab.
23
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
- And avoids mapping to whitespace/control characters the bpe code barfs on.
25
- """
26
- bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
- cs = bs[:]
28
- n = 0
29
- for b in range(2**8):
30
- if b not in bs:
31
- bs.append(b)
32
- cs.append(2**8+n)
33
- n += 1
34
- cs = [chr(n) for n in cs]
35
- return dict(zip(bs, cs))
36
-
37
-
38
- def get_pairs(word):
39
- """Return set of symbol pairs in a word.
40
- Word is represented as tuple of symbols (symbols being variable-length strings).
41
- """
42
- pairs = set()
43
- prev_char = word[0]
44
- for char in word[1:]:
45
- pairs.add((prev_char, char))
46
- prev_char = char
47
- return pairs
48
-
49
-
50
- def basic_clean(text):
51
- text = ftfy.fix_text(text)
52
- text = html.unescape(html.unescape(text))
53
- return text.strip()
54
-
55
-
56
- def whitespace_clean(text):
57
- text = re.sub(r'\s+', ' ', text)
58
- text = text.strip()
59
- return text
60
-
61
-
62
- class SimpleTokenizer(object):
63
- def __init__(self, bpe_path: str = default_bpe()):
64
- self.byte_encoder = bytes_to_unicode()
65
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
- merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
- merges = merges[1:49152-256-2+1]
68
- merges = [tuple(merge.split()) for merge in merges]
69
- vocab = list(bytes_to_unicode().values())
70
- vocab = vocab + [v+'</w>' for v in vocab]
71
- for merge in merges:
72
- vocab.append(''.join(merge))
73
- vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
- self.encoder = dict(zip(vocab, range(len(vocab))))
75
- self.decoder = {v: k for k, v in self.encoder.items()}
76
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
- self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
- self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
-
80
- def bpe(self, token):
81
- if token in self.cache:
82
- return self.cache[token]
83
- word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
- pairs = get_pairs(word)
85
-
86
- if not pairs:
87
- return token+'</w>'
88
-
89
- while True:
90
- bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
- if bigram not in self.bpe_ranks:
92
- break
93
- first, second = bigram
94
- new_word = []
95
- i = 0
96
- while i < len(word):
97
- try:
98
- j = word.index(first, i)
99
- new_word.extend(word[i:j])
100
- i = j
101
- except:
102
- new_word.extend(word[i:])
103
- break
104
-
105
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
- new_word.append(first+second)
107
- i += 2
108
- else:
109
- new_word.append(word[i])
110
- i += 1
111
- new_word = tuple(new_word)
112
- word = new_word
113
- if len(word) == 1:
114
- break
115
- else:
116
- pairs = get_pairs(word)
117
- word = ' '.join(word)
118
- self.cache[token] = word
119
- return word
120
-
121
- def encode(self, text):
122
- bpe_tokens = []
123
- text = whitespace_clean(basic_clean(text)).lower()
124
- for token in re.findall(self.pat, text):
125
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
- return bpe_tokens
128
-
129
- def decode(self, tokens):
130
- text = ''.join([self.decoder[token] for token in tokens])
131
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
- return text
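A short round-trip sketch for the deleted tokenizer. It assumes ftfy and regex are installed and that the bundled bpe_simple_vocab_16e6.txt.gz sits next to the module, as default_bpe() expects:

    from models.clip.simple_tokenizer import SimpleTokenizer

    tokenizer = SimpleTokenizer()
    ids = tokenizer.encode("a photo of a cat")   # BPE ids without <|startoftext|>/<|endoftext|>
    text = tokenizer.decode(ids)                 # roughly "a photo of a cat " (</w> markers become spaces)
    sot = tokenizer.encoder["<|startoftext|>"]   # special tokens used by clip.tokenize
    eot = tokenizer.encoder["<|endoftext|>"]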
 
 
models/taming/lr_scheduler.py DELETED
@@ -1,39 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import numpy as np
7
-
8
-
9
- class LambdaWarmUpCosineScheduler:
10
- """
11
- note: use with a base_lr of 1.0
12
- """
13
- def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
14
- self.lr_warm_up_steps = warm_up_steps
15
- self.lr_start = lr_start
16
- self.lr_min = lr_min
17
- self.lr_max = lr_max
18
- self.lr_max_decay_steps = max_decay_steps
19
- self.last_lr = 0.
20
- self.verbosity_interval = verbosity_interval
21
-
22
- def schedule(self, n):
23
- if self.verbosity_interval > 0:
24
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
25
- if n < self.lr_warm_up_steps:
26
- lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
27
- self.last_lr = lr
28
- return lr
29
- else:
30
- t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
31
- t = min(t, 1.0)
32
- lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
33
- 1 + np.cos(t * np.pi))
34
- self.last_lr = lr
35
- return lr
36
-
37
- def __call__(self, n):
38
- return self.schedule(n)
39
-
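The "base_lr of 1.0" note means the scheduler returns an absolute learning rate, so it is meant to be wrapped in LambdaLR on an optimizer whose lr is 1.0. A hedged sketch with made-up step counts and rates:

    import torch
    from models.taming.lr_scheduler import LambdaWarmUpCosineScheduler

    schedule = LambdaWarmUpCosineScheduler(
        warm_up_steps=1_000, lr_min=1e-6, lr_max=5e-4, lr_start=1e-6,
        max_decay_steps=100_000)                      # illustrative values, not a published config

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1.0)      # base_lr of 1.0, per the class docstring
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

    for step in range(10):
        optimizer.step()
        lr_scheduler.step()                           # lr now equals schedule(step + 1)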
 
 
models/taming/models/vqgan.py DELETED
@@ -1,262 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import pytorch_lightning as pl
4
-
5
- from models.taming.util import instantiate_from_config
6
-
7
- from models.taming.modules.diffusionmodules.model import Encoder, Decoder
8
- from models.taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
- from models.taming.modules.vqvae.quantize import GumbelQuantize
10
-
11
- class VQModel(pl.LightningModule):
12
- def __init__(self,
13
- ddconfig,
14
- lossconfig,
15
- n_embed,
16
- embed_dim,
17
- ckpt_path=None,
18
- ignore_keys=[],
19
- image_key="image",
20
- colorize_nlabels=None,
21
- monitor=None,
22
- remap=None,
23
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
24
- ):
25
- super().__init__()
26
- self.image_key = image_key
27
- self.encoder = Encoder(**ddconfig)
28
- self.decoder = Decoder(**ddconfig)
29
- self.loss = instantiate_from_config(lossconfig)
30
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
31
- remap=remap, sane_index_shape=sane_index_shape)
32
- self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
33
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
34
- if ckpt_path is not None:
35
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
36
- self.image_key = image_key
37
- if colorize_nlabels is not None:
38
- assert type(colorize_nlabels)==int
39
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
40
- if monitor is not None:
41
- self.monitor = monitor
42
-
43
- def init_from_ckpt(self, path, ignore_keys=list()):
44
- sd = torch.load(path, map_location="cpu")["state_dict"]
45
- keys = list(sd.keys())
46
- for k in keys:
47
- for ik in ignore_keys:
48
- if k.startswith(ik):
49
- print("Deleting key {} from state_dict.".format(k))
50
- del sd[k]
51
- self.load_state_dict(sd, strict=False)
52
- print(f"Restored from {path}")
53
-
54
- def encode(self, x):
55
- h = self.encoder(x)
56
- h = self.quant_conv(h)
57
- quant, emb_loss, info = self.quantize(h)
58
- return quant, emb_loss, info
59
-
60
- def decode(self, quant):
61
- quant = self.post_quant_conv(quant)
62
- dec = self.decoder(quant)
63
- return dec
64
-
65
- def decode_code(self, code_b):
66
- quant_b = self.quantize.embed_code(code_b)
67
- dec = self.decode(quant_b)
68
- return dec
69
-
70
- def forward(self, input):
71
- quant, diff, _ = self.encode(input)
72
- dec = self.decode(quant)
73
- return dec, diff
74
-
75
- def get_input(self, batch, k):
76
- x = batch[k]
77
- if len(x.shape) == 3:
78
- x = x[..., None]
79
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
80
- return x.float()
81
-
82
- def training_step(self, batch, batch_idx, optimizer_idx):
83
- x = self.get_input(batch, self.image_key)
84
- xrec, qloss = self(x)
85
-
86
- if optimizer_idx == 0:
87
- # autoencode
88
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
89
- last_layer=self.get_last_layer(), split="train")
90
-
91
- self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
92
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
93
- return aeloss
94
-
95
- if optimizer_idx == 1:
96
- # discriminator
97
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
98
- last_layer=self.get_last_layer(), split="train")
99
- self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
100
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
101
- return discloss
102
-
103
- def validation_step(self, batch, batch_idx):
104
- x = self.get_input(batch, self.image_key)
105
- xrec, qloss = self(x)
106
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
107
- last_layer=self.get_last_layer(), split="val")
108
-
109
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
110
- last_layer=self.get_last_layer(), split="val")
111
- rec_loss = log_dict_ae["val/rec_loss"]
112
- self.log("val/rec_loss", rec_loss,
113
- prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
114
- self.log("val/aeloss", aeloss,
115
- prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
116
- self.log_dict(log_dict_ae)
117
- self.log_dict(log_dict_disc)
118
- return self.log_dict
119
-
120
- def configure_optimizers(self):
121
- lr = self.learning_rate
122
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
123
- list(self.decoder.parameters())+
124
- list(self.quantize.parameters())+
125
- list(self.quant_conv.parameters())+
126
- list(self.post_quant_conv.parameters()),
127
- lr=lr, betas=(0.5, 0.9))
128
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
129
- lr=lr, betas=(0.5, 0.9))
130
- return [opt_ae, opt_disc], []
131
-
132
- def get_last_layer(self):
133
- return self.decoder.conv_out.weight
134
-
135
- def log_images(self, batch, **kwargs):
136
- log = dict()
137
- x = self.get_input(batch, self.image_key)
138
- x = x.to(self.device)
139
- xrec, _ = self(x)
140
- if x.shape[1] > 3:
141
- # colorize with random projection
142
- assert xrec.shape[1] > 3
143
- x = self.to_rgb(x)
144
- xrec = self.to_rgb(xrec)
145
- log["inputs"] = x
146
- log["reconstructions"] = xrec
147
- return log
148
-
149
- def to_rgb(self, x):
150
- assert self.image_key == "segmentation"
151
- if not hasattr(self, "colorize"):
152
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
153
- x = F.conv2d(x, weight=self.colorize)
154
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
155
- return x
156
-
157
-
158
- class GumbelVQ(VQModel):
159
- def __init__(self,
160
- ddconfig,
161
- lossconfig,
162
- n_embed,
163
- embed_dim,
164
- temperature_scheduler_config,
165
- ckpt_path=None,
166
- ignore_keys=[],
167
- image_key="image",
168
- colorize_nlabels=None,
169
- monitor=None,
170
- kl_weight=1e-8,
171
- remap=None,
172
- ):
173
-
174
- z_channels = ddconfig["z_channels"]
175
- super().__init__(ddconfig,
176
- lossconfig,
177
- n_embed,
178
- embed_dim,
179
- ckpt_path=None,
180
- ignore_keys=ignore_keys,
181
- image_key=image_key,
182
- colorize_nlabels=colorize_nlabels,
183
- monitor=monitor,
184
- )
185
-
186
- self.loss.n_classes = n_embed
187
- self.vocab_size = n_embed
188
-
189
- self.quantize = GumbelQuantize(z_channels, embed_dim,
190
- n_embed=n_embed,
191
- kl_weight=kl_weight, temp_init=1.0,
192
- remap=remap)
193
-
194
- self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
195
-
196
- if ckpt_path is not None:
197
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
198
-
199
- def temperature_scheduling(self):
200
- self.quantize.temperature = self.temperature_scheduler(self.global_step)
201
-
202
- def encode_to_prequant(self, x):
203
- h = self.encoder(x)
204
- h = self.quant_conv(h)
205
- return h
206
-
207
- def decode_code(self, code_b):
208
- quant_b = self.quantize.get_codebook_entry(code_b.view(-1), list(code_b.size())+[self.quantize.embedding_dim])
209
- dec = self.decode(quant_b)
210
- return dec
211
-
212
- def training_step(self, batch, batch_idx, optimizer_idx):
213
- self.temperature_scheduling()
214
- x = self.get_input(batch, self.image_key)
215
- xrec, qloss = self(x)
216
-
217
- if optimizer_idx == 0:
218
- # autoencode
219
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
220
- last_layer=self.get_last_layer(), split="train")
221
-
222
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
223
- self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
224
- return aeloss
225
-
226
- if optimizer_idx == 1:
227
- # discriminator
228
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
229
- last_layer=self.get_last_layer(), split="train")
230
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
231
- return discloss
232
-
233
- def validation_step(self, batch, batch_idx):
234
- x = self.get_input(batch, self.image_key)
235
- xrec, qloss = self(x, return_pred_indices=True)
236
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
237
- last_layer=self.get_last_layer(), split="val")
238
-
239
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
240
- last_layer=self.get_last_layer(), split="val")
241
- rec_loss = log_dict_ae["val/rec_loss"]
242
- self.log("val/rec_loss", rec_loss,
243
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
244
- self.log("val/aeloss", aeloss,
245
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
246
- self.log_dict(log_dict_ae)
247
- self.log_dict(log_dict_disc)
248
- return self.log_dict
249
-
250
- def log_images(self, batch, **kwargs):
251
- log = dict()
252
- x = self.get_input(batch, self.image_key)
253
- x = x.to(self.device)
254
- # encode
255
- h = self.encoder(x)
256
- h = self.quant_conv(h)
257
- quant, _, _ = self.quantize(h)
258
- # decode
259
- x_rec = self.decode(quant)
260
- log["inputs"] = x
261
- log["reconstructions"] = x_rec
262
- return log
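For orientation, a downsized construction of the deleted VQModel showing the encode → quantize → decode path. Everything here is an assumption-laden sketch: the hyperparameters are deliberately tiny (not the published VQGAN settings), the loss is stubbed with torch.nn.Identity, and it presumes models.taming.util.instantiate_from_config follows the usual taming-transformers convention of instantiating config["target"] with config.get("params", {}):

    import torch
    from models.taming.models.vqgan import VQModel

    ddconfig = dict(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
                    attn_resolutions=[], dropout=0.0, in_channels=3,
                    resolution=64, z_channels=16, double_z=False)
    model = VQModel(ddconfig, lossconfig={"target": "torch.nn.Identity"},
                    n_embed=512, embed_dim=16).eval()

    x = torch.randn(1, 3, 64, 64)                 # dummy image batch
    with torch.no_grad():
        quant, emb_loss, info = model.encode(x)   # quantized latents; codebook indices are in `info`
        x_rec = model.decode(quant)               # reconstruction with the same shape as `x`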
 
 
models/taming/modules/diffusionmodules/model.py DELETED
@@ -1,776 +0,0 @@
1
- # pytorch_diffusion + derived encoder decoder
2
- import math
3
- import torch
4
- import torch.nn as nn
5
- import numpy as np
6
-
7
-
8
- def get_timestep_embedding(timesteps, embedding_dim):
9
- """
10
- This matches the implementation in Denoising Diffusion Probabilistic Models:
11
- From Fairseq.
12
- Build sinusoidal embeddings.
13
- This matches the implementation in tensor2tensor, but differs slightly
14
- from the description in Section 3.5 of "Attention Is All You Need".
15
- """
16
- assert len(timesteps.shape) == 1
17
-
18
- half_dim = embedding_dim // 2
19
- emb = math.log(10000) / (half_dim - 1)
20
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
21
- emb = emb.to(device=timesteps.device)
22
- emb = timesteps.float()[:, None] * emb[None, :]
23
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
24
- if embedding_dim % 2 == 1: # zero pad
25
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
26
- return emb
27
-
28
-
29
- def nonlinearity(x):
30
- # swish
31
- return x*torch.sigmoid(x)
32
-
33
-
34
- def Normalize(in_channels):
35
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36
-
37
-
38
- class Upsample(nn.Module):
39
- def __init__(self, in_channels, with_conv):
40
- super().__init__()
41
- self.with_conv = with_conv
42
- if self.with_conv:
43
- self.conv = torch.nn.Conv2d(in_channels,
44
- in_channels,
45
- kernel_size=3,
46
- stride=1,
47
- padding=1)
48
-
49
- def forward(self, x):
50
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
- if self.with_conv:
52
- x = self.conv(x)
53
- return x
54
-
55
-
56
- class Downsample(nn.Module):
57
- def __init__(self, in_channels, with_conv):
58
- super().__init__()
59
- self.with_conv = with_conv
60
- if self.with_conv:
61
- # no asymmetric padding in torch conv, must do it ourselves
62
- self.conv = torch.nn.Conv2d(in_channels,
63
- in_channels,
64
- kernel_size=3,
65
- stride=2,
66
- padding=0)
67
-
68
- def forward(self, x):
69
- if self.with_conv:
70
- pad = (0,1,0,1)
71
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
- x = self.conv(x)
73
- else:
74
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
- return x
76
-
77
-
78
- class ResnetBlock(nn.Module):
79
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
80
- dropout, temb_channels=512):
81
- super().__init__()
82
- self.in_channels = in_channels
83
- out_channels = in_channels if out_channels is None else out_channels
84
- self.out_channels = out_channels
85
- self.use_conv_shortcut = conv_shortcut
86
-
87
- self.norm1 = Normalize(in_channels)
88
- self.conv1 = torch.nn.Conv2d(in_channels,
89
- out_channels,
90
- kernel_size=3,
91
- stride=1,
92
- padding=1)
93
- if temb_channels > 0:
94
- self.temb_proj = torch.nn.Linear(temb_channels,
95
- out_channels)
96
- self.norm2 = Normalize(out_channels)
97
- self.dropout = torch.nn.Dropout(dropout)
98
- self.conv2 = torch.nn.Conv2d(out_channels,
99
- out_channels,
100
- kernel_size=3,
101
- stride=1,
102
- padding=1)
103
- if self.in_channels != self.out_channels:
104
- if self.use_conv_shortcut:
105
- self.conv_shortcut = torch.nn.Conv2d(in_channels,
106
- out_channels,
107
- kernel_size=3,
108
- stride=1,
109
- padding=1)
110
- else:
111
- self.nin_shortcut = torch.nn.Conv2d(in_channels,
112
- out_channels,
113
- kernel_size=1,
114
- stride=1,
115
- padding=0)
116
-
117
- def forward(self, x, temb):
118
- h = x
119
- h = self.norm1(h)
120
- h = nonlinearity(h)
121
- h = self.conv1(h)
122
-
123
- if temb is not None:
124
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
125
-
126
- h = self.norm2(h)
127
- h = nonlinearity(h)
128
- h = self.dropout(h)
129
- h = self.conv2(h)
130
-
131
- if self.in_channels != self.out_channels:
132
- if self.use_conv_shortcut:
133
- x = self.conv_shortcut(x)
134
- else:
135
- x = self.nin_shortcut(x)
136
-
137
- return x+h
138
-
139
-
140
- class AttnBlock(nn.Module):
141
- def __init__(self, in_channels):
142
- super().__init__()
143
- self.in_channels = in_channels
144
-
145
- self.norm = Normalize(in_channels)
146
- self.q = torch.nn.Conv2d(in_channels,
147
- in_channels,
148
- kernel_size=1,
149
- stride=1,
150
- padding=0)
151
- self.k = torch.nn.Conv2d(in_channels,
152
- in_channels,
153
- kernel_size=1,
154
- stride=1,
155
- padding=0)
156
- self.v = torch.nn.Conv2d(in_channels,
157
- in_channels,
158
- kernel_size=1,
159
- stride=1,
160
- padding=0)
161
- self.proj_out = torch.nn.Conv2d(in_channels,
162
- in_channels,
163
- kernel_size=1,
164
- stride=1,
165
- padding=0)
166
-
167
-
168
- def forward(self, x):
169
- h_ = x
170
- h_ = self.norm(h_)
171
- q = self.q(h_)
172
- k = self.k(h_)
173
- v = self.v(h_)
174
-
175
- # compute attention
176
- b,c,h,w = q.shape
177
- q = q.reshape(b,c,h*w)
178
- q = q.permute(0,2,1) # b,hw,c
179
- k = k.reshape(b,c,h*w) # b,c,hw
180
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
- w_ = w_ * (int(c)**(-0.5))
182
- w_ = torch.nn.functional.softmax(w_, dim=2)
183
-
184
- # attend to values
185
- v = v.reshape(b,c,h*w)
186
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
187
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
- h_ = h_.reshape(b,c,h,w)
189
-
190
- h_ = self.proj_out(h_)
191
-
192
- return x+h_
193
-
194
-
195
- class Model(nn.Module):
196
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
197
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
198
- resolution, use_timestep=True):
199
- super().__init__()
200
- self.ch = ch
201
- self.temb_ch = self.ch*4
202
- self.num_resolutions = len(ch_mult)
203
- self.num_res_blocks = num_res_blocks
204
- self.resolution = resolution
205
- self.in_channels = in_channels
206
-
207
- self.use_timestep = use_timestep
208
- if self.use_timestep:
209
- # timestep embedding
210
- self.temb = nn.Module()
211
- self.temb.dense = nn.ModuleList([
212
- torch.nn.Linear(self.ch,
213
- self.temb_ch),
214
- torch.nn.Linear(self.temb_ch,
215
- self.temb_ch),
216
- ])
217
-
218
- # downsampling
219
- self.conv_in = torch.nn.Conv2d(in_channels,
220
- self.ch,
221
- kernel_size=3,
222
- stride=1,
223
- padding=1)
224
-
225
- curr_res = resolution
226
- in_ch_mult = (1,)+tuple(ch_mult)
227
- self.down = nn.ModuleList()
228
- for i_level in range(self.num_resolutions):
229
- block = nn.ModuleList()
230
- attn = nn.ModuleList()
231
- block_in = ch*in_ch_mult[i_level]
232
- block_out = ch*ch_mult[i_level]
233
- for i_block in range(self.num_res_blocks):
234
- block.append(ResnetBlock(in_channels=block_in,
235
- out_channels=block_out,
236
- temb_channels=self.temb_ch,
237
- dropout=dropout))
238
- block_in = block_out
239
- if curr_res in attn_resolutions:
240
- attn.append(AttnBlock(block_in))
241
- down = nn.Module()
242
- down.block = block
243
- down.attn = attn
244
- if i_level != self.num_resolutions-1:
245
- down.downsample = Downsample(block_in, resamp_with_conv)
246
- curr_res = curr_res // 2
247
- self.down.append(down)
248
-
249
- # middle
250
- self.mid = nn.Module()
251
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
252
- out_channels=block_in,
253
- temb_channels=self.temb_ch,
254
- dropout=dropout)
255
- self.mid.attn_1 = AttnBlock(block_in)
256
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
257
- out_channels=block_in,
258
- temb_channels=self.temb_ch,
259
- dropout=dropout)
260
-
261
- # upsampling
262
- self.up = nn.ModuleList()
263
- for i_level in reversed(range(self.num_resolutions)):
264
- block = nn.ModuleList()
265
- attn = nn.ModuleList()
266
- block_out = ch*ch_mult[i_level]
267
- skip_in = ch*ch_mult[i_level]
268
- for i_block in range(self.num_res_blocks+1):
269
- if i_block == self.num_res_blocks:
270
- skip_in = ch*in_ch_mult[i_level]
271
- block.append(ResnetBlock(in_channels=block_in+skip_in,
272
- out_channels=block_out,
273
- temb_channels=self.temb_ch,
274
- dropout=dropout))
275
- block_in = block_out
276
- if curr_res in attn_resolutions:
277
- attn.append(AttnBlock(block_in))
278
- up = nn.Module()
279
- up.block = block
280
- up.attn = attn
281
- if i_level != 0:
282
- up.upsample = Upsample(block_in, resamp_with_conv)
283
- curr_res = curr_res * 2
284
- self.up.insert(0, up) # prepend to get consistent order
285
-
286
- # end
287
- self.norm_out = Normalize(block_in)
288
- self.conv_out = torch.nn.Conv2d(block_in,
289
- out_ch,
290
- kernel_size=3,
291
- stride=1,
292
- padding=1)
293
-
294
-
295
- def forward(self, x, t=None):
296
- #assert x.shape[2] == x.shape[3] == self.resolution
297
-
298
- if self.use_timestep:
299
- # timestep embedding
300
- assert t is not None
301
- temb = get_timestep_embedding(t, self.ch)
302
- temb = self.temb.dense[0](temb)
303
- temb = nonlinearity(temb)
304
- temb = self.temb.dense[1](temb)
305
- else:
306
- temb = None
307
-
308
- # downsampling
309
- hs = [self.conv_in(x)]
310
- for i_level in range(self.num_resolutions):
311
- for i_block in range(self.num_res_blocks):
312
- h = self.down[i_level].block[i_block](hs[-1], temb)
313
- if len(self.down[i_level].attn) > 0:
314
- h = self.down[i_level].attn[i_block](h)
315
- hs.append(h)
316
- if i_level != self.num_resolutions-1:
317
- hs.append(self.down[i_level].downsample(hs[-1]))
318
-
319
- # middle
320
- h = hs[-1]
321
- h = self.mid.block_1(h, temb)
322
- h = self.mid.attn_1(h)
323
- h = self.mid.block_2(h, temb)
324
-
325
- # upsampling
326
- for i_level in reversed(range(self.num_resolutions)):
327
- for i_block in range(self.num_res_blocks+1):
328
- h = self.up[i_level].block[i_block](
329
- torch.cat([h, hs.pop()], dim=1), temb)
330
- if len(self.up[i_level].attn) > 0:
331
- h = self.up[i_level].attn[i_block](h)
332
- if i_level != 0:
333
- h = self.up[i_level].upsample(h)
334
-
335
- # end
336
- h = self.norm_out(h)
337
- h = nonlinearity(h)
338
- h = self.conv_out(h)
339
- return h
340
-
341
-
342
- class Encoder(nn.Module):
343
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
344
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
345
- resolution, z_channels, double_z=True, **ignore_kwargs):
346
- super().__init__()
347
- self.ch = ch
348
- self.temb_ch = 0
349
- self.num_resolutions = len(ch_mult)
350
- self.num_res_blocks = num_res_blocks
351
- self.resolution = resolution
352
- self.in_channels = in_channels
353
-
354
- # downsampling
355
- self.conv_in = torch.nn.Conv2d(in_channels,
356
- self.ch,
357
- kernel_size=3,
358
- stride=1,
359
- padding=1)
360
-
361
- curr_res = resolution
362
- in_ch_mult = (1,)+tuple(ch_mult)
363
- self.down = nn.ModuleList()
364
- for i_level in range(self.num_resolutions):
365
- block = nn.ModuleList()
366
- attn = nn.ModuleList()
367
- block_in = ch*in_ch_mult[i_level]
368
- block_out = ch*ch_mult[i_level]
369
- for i_block in range(self.num_res_blocks):
370
- block.append(ResnetBlock(in_channels=block_in,
371
- out_channels=block_out,
372
- temb_channels=self.temb_ch,
373
- dropout=dropout))
374
- block_in = block_out
375
- if curr_res in attn_resolutions:
376
- attn.append(AttnBlock(block_in))
377
- down = nn.Module()
378
- down.block = block
379
- down.attn = attn
380
- if i_level != self.num_resolutions-1:
381
- down.downsample = Downsample(block_in, resamp_with_conv)
382
- curr_res = curr_res // 2
383
- self.down.append(down)
384
-
385
- # middle
386
- self.mid = nn.Module()
387
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
388
- out_channels=block_in,
389
- temb_channels=self.temb_ch,
390
- dropout=dropout)
391
- self.mid.attn_1 = AttnBlock(block_in)
392
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
393
- out_channels=block_in,
394
- temb_channels=self.temb_ch,
395
- dropout=dropout)
396
-
397
- # end
398
- self.norm_out = Normalize(block_in)
399
- self.conv_out = torch.nn.Conv2d(block_in,
400
- 2*z_channels if double_z else z_channels,
401
- kernel_size=3,
402
- stride=1,
403
- padding=1)
404
-
405
-
406
- def forward(self, x):
407
- #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
408
-
409
- # timestep embedding
410
- temb = None
411
-
412
- # downsampling
413
- hs = [self.conv_in(x)]
414
- for i_level in range(self.num_resolutions):
415
- for i_block in range(self.num_res_blocks):
416
- h = self.down[i_level].block[i_block](hs[-1], temb)
417
- if len(self.down[i_level].attn) > 0:
418
- h = self.down[i_level].attn[i_block](h)
419
- hs.append(h)
420
- if i_level != self.num_resolutions-1:
421
- hs.append(self.down[i_level].downsample(hs[-1]))
422
-
423
- # middle
424
- h = hs[-1]
425
- h = self.mid.block_1(h, temb)
426
- h = self.mid.attn_1(h)
427
- h = self.mid.block_2(h, temb)
428
-
429
- # end
430
- h = self.norm_out(h)
431
- h = nonlinearity(h)
432
- h = self.conv_out(h)
433
- return h
434
-
435
-
436
- class Decoder(nn.Module):
437
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
438
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
439
- resolution, z_channels, give_pre_end=False, **ignorekwargs):
440
- super().__init__()
441
- self.ch = ch
442
- self.temb_ch = 0
443
- self.num_resolutions = len(ch_mult)
444
- self.num_res_blocks = num_res_blocks
445
- self.resolution = resolution
446
- self.in_channels = in_channels
447
- self.give_pre_end = give_pre_end
448
-
449
- # compute in_ch_mult, block_in and curr_res at lowest res
450
- in_ch_mult = (1,)+tuple(ch_mult)
451
- block_in = ch*ch_mult[self.num_resolutions-1]
452
- curr_res = resolution // 2**(self.num_resolutions-1)
453
- self.z_shape = (1,z_channels,curr_res,curr_res)
454
- print("Working with z of shape {} = {} dimensions.".format(
455
- self.z_shape, np.prod(self.z_shape)))
456
-
457
- # z to block_in
458
- self.conv_in = torch.nn.Conv2d(z_channels,
459
- block_in,
460
- kernel_size=3,
461
- stride=1,
462
- padding=1)
463
-
464
- # middle
465
- self.mid = nn.Module()
466
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
467
- out_channels=block_in,
468
- temb_channels=self.temb_ch,
469
- dropout=dropout)
470
- self.mid.attn_1 = AttnBlock(block_in)
471
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
472
- out_channels=block_in,
473
- temb_channels=self.temb_ch,
474
- dropout=dropout)
475
-
476
- # upsampling
477
- self.up = nn.ModuleList()
478
- for i_level in reversed(range(self.num_resolutions)):
479
- block = nn.ModuleList()
480
- attn = nn.ModuleList()
481
- block_out = ch*ch_mult[i_level]
482
- for i_block in range(self.num_res_blocks+1):
483
- block.append(ResnetBlock(in_channels=block_in,
484
- out_channels=block_out,
485
- temb_channels=self.temb_ch,
486
- dropout=dropout))
487
- block_in = block_out
488
- if curr_res in attn_resolutions:
489
- attn.append(AttnBlock(block_in))
490
- up = nn.Module()
491
- up.block = block
492
- up.attn = attn
493
- if i_level != 0:
494
- up.upsample = Upsample(block_in, resamp_with_conv)
495
- curr_res = curr_res * 2
496
- self.up.insert(0, up) # prepend to get consistent order
497
-
498
- # end
499
- self.norm_out = Normalize(block_in)
500
- self.conv_out = torch.nn.Conv2d(block_in,
501
- out_ch,
502
- kernel_size=3,
503
- stride=1,
504
- padding=1)
505
-
506
- def forward(self, z):
507
- #assert z.shape[1:] == self.z_shape[1:]
508
- self.last_z_shape = z.shape
509
-
510
- # timestep embedding
511
- temb = None
512
-
513
- # z to block_in
514
- h = self.conv_in(z)
515
-
516
- # middle
517
- h = self.mid.block_1(h, temb)
518
- h = self.mid.attn_1(h)
519
- h = self.mid.block_2(h, temb)
520
-
521
- # upsampling
522
- for i_level in reversed(range(self.num_resolutions)):
523
- for i_block in range(self.num_res_blocks+1):
524
- h = self.up[i_level].block[i_block](h, temb)
525
- if len(self.up[i_level].attn) > 0:
526
- h = self.up[i_level].attn[i_block](h)
527
- if i_level != 0:
528
- h = self.up[i_level].upsample(h)
529
-
530
- # end
531
- if self.give_pre_end:
532
- return h
533
-
534
- h = self.norm_out(h)
535
- h = nonlinearity(h)
536
- h = self.conv_out(h)
537
- return h
538
-
539
-
540
- class VUNet(nn.Module):
541
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
542
- attn_resolutions, dropout=0.0, resamp_with_conv=True,
543
- in_channels, c_channels,
544
- resolution, z_channels, use_timestep=False, **ignore_kwargs):
545
- super().__init__()
546
- self.ch = ch
547
- self.temb_ch = self.ch*4
548
- self.num_resolutions = len(ch_mult)
549
- self.num_res_blocks = num_res_blocks
550
- self.resolution = resolution
551
-
552
- self.use_timestep = use_timestep
553
- if self.use_timestep:
554
- # timestep embedding
555
- self.temb = nn.Module()
556
- self.temb.dense = nn.ModuleList([
557
- torch.nn.Linear(self.ch,
558
- self.temb_ch),
559
- torch.nn.Linear(self.temb_ch,
560
- self.temb_ch),
561
- ])
562
-
563
- # downsampling
564
- self.conv_in = torch.nn.Conv2d(c_channels,
565
- self.ch,
566
- kernel_size=3,
567
- stride=1,
568
- padding=1)
569
-
570
- curr_res = resolution
571
- in_ch_mult = (1,)+tuple(ch_mult)
572
- self.down = nn.ModuleList()
573
- for i_level in range(self.num_resolutions):
574
- block = nn.ModuleList()
575
- attn = nn.ModuleList()
576
- block_in = ch*in_ch_mult[i_level]
577
- block_out = ch*ch_mult[i_level]
578
- for i_block in range(self.num_res_blocks):
579
- block.append(ResnetBlock(in_channels=block_in,
580
- out_channels=block_out,
581
- temb_channels=self.temb_ch,
582
- dropout=dropout))
583
- block_in = block_out
584
- if curr_res in attn_resolutions:
585
- attn.append(AttnBlock(block_in))
586
- down = nn.Module()
587
- down.block = block
588
- down.attn = attn
589
- if i_level != self.num_resolutions-1:
590
- down.downsample = Downsample(block_in, resamp_with_conv)
591
- curr_res = curr_res // 2
592
- self.down.append(down)
593
-
594
- self.z_in = torch.nn.Conv2d(z_channels,
595
- block_in,
596
- kernel_size=1,
597
- stride=1,
598
- padding=0)
599
- # middle
600
- self.mid = nn.Module()
601
- self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
602
- out_channels=block_in,
603
- temb_channels=self.temb_ch,
604
- dropout=dropout)
605
- self.mid.attn_1 = AttnBlock(block_in)
606
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
607
- out_channels=block_in,
608
- temb_channels=self.temb_ch,
609
- dropout=dropout)
610
-
611
- # upsampling
612
- self.up = nn.ModuleList()
613
- for i_level in reversed(range(self.num_resolutions)):
614
- block = nn.ModuleList()
615
- attn = nn.ModuleList()
616
- block_out = ch*ch_mult[i_level]
617
- skip_in = ch*ch_mult[i_level]
618
- for i_block in range(self.num_res_blocks+1):
619
- if i_block == self.num_res_blocks:
620
- skip_in = ch*in_ch_mult[i_level]
621
- block.append(ResnetBlock(in_channels=block_in+skip_in,
622
- out_channels=block_out,
623
- temb_channels=self.temb_ch,
624
- dropout=dropout))
625
- block_in = block_out
626
- if curr_res in attn_resolutions:
627
- attn.append(AttnBlock(block_in))
628
- up = nn.Module()
629
- up.block = block
630
- up.attn = attn
631
- if i_level != 0:
632
- up.upsample = Upsample(block_in, resamp_with_conv)
633
- curr_res = curr_res * 2
634
- self.up.insert(0, up) # prepend to get consistent order
635
-
636
- # end
637
- self.norm_out = Normalize(block_in)
638
- self.conv_out = torch.nn.Conv2d(block_in,
639
- out_ch,
640
- kernel_size=3,
641
- stride=1,
642
- padding=1)
643
-
644
-
645
- def forward(self, x, z, t=None):  # t must be passed when use_timestep=True
646
- #assert x.shape[2] == x.shape[3] == self.resolution
647
-
648
- if self.use_timestep:
649
- # timestep embedding
650
- assert t is not None
651
- temb = get_timestep_embedding(t, self.ch)
652
- temb = self.temb.dense[0](temb)
653
- temb = nonlinearity(temb)
654
- temb = self.temb.dense[1](temb)
655
- else:
656
- temb = None
657
-
658
- # downsampling
659
- hs = [self.conv_in(x)]
660
- for i_level in range(self.num_resolutions):
661
- for i_block in range(self.num_res_blocks):
662
- h = self.down[i_level].block[i_block](hs[-1], temb)
663
- if len(self.down[i_level].attn) > 0:
664
- h = self.down[i_level].attn[i_block](h)
665
- hs.append(h)
666
- if i_level != self.num_resolutions-1:
667
- hs.append(self.down[i_level].downsample(hs[-1]))
668
-
669
- # middle
670
- h = hs[-1]
671
- z = self.z_in(z)
672
- h = torch.cat((h,z),dim=1)
673
- h = self.mid.block_1(h, temb)
674
- h = self.mid.attn_1(h)
675
- h = self.mid.block_2(h, temb)
676
-
677
- # upsampling
678
- for i_level in reversed(range(self.num_resolutions)):
679
- for i_block in range(self.num_res_blocks+1):
680
- h = self.up[i_level].block[i_block](
681
- torch.cat([h, hs.pop()], dim=1), temb)
682
- if len(self.up[i_level].attn) > 0:
683
- h = self.up[i_level].attn[i_block](h)
684
- if i_level != 0:
685
- h = self.up[i_level].upsample(h)
686
-
687
- # end
688
- h = self.norm_out(h)
689
- h = nonlinearity(h)
690
- h = self.conv_out(h)
691
- return h
692
-
693
-
694
- class SimpleDecoder(nn.Module):
695
- def __init__(self, in_channels, out_channels, *args, **kwargs):
696
- super().__init__()
697
- self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
698
- ResnetBlock(in_channels=in_channels,
699
- out_channels=2 * in_channels,
700
- temb_channels=0, dropout=0.0),
701
- ResnetBlock(in_channels=2 * in_channels,
702
- out_channels=4 * in_channels,
703
- temb_channels=0, dropout=0.0),
704
- ResnetBlock(in_channels=4 * in_channels,
705
- out_channels=2 * in_channels,
706
- temb_channels=0, dropout=0.0),
707
- nn.Conv2d(2*in_channels, in_channels, 1),
708
- Upsample(in_channels, with_conv=True)])
709
- # end
710
- self.norm_out = Normalize(in_channels)
711
- self.conv_out = torch.nn.Conv2d(in_channels,
712
- out_channels,
713
- kernel_size=3,
714
- stride=1,
715
- padding=1)
716
-
717
- def forward(self, x):
718
- for i, layer in enumerate(self.model):
719
- if i in [1,2,3]:
720
- x = layer(x, None)
721
- else:
722
- x = layer(x)
723
-
724
- h = self.norm_out(x)
725
- h = nonlinearity(h)
726
- x = self.conv_out(h)
727
- return x
728
-
729
-
730
- class UpsampleDecoder(nn.Module):
731
- def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
732
- ch_mult=(2,2), dropout=0.0):
733
- super().__init__()
734
- # upsampling
735
- self.temb_ch = 0
736
- self.num_resolutions = len(ch_mult)
737
- self.num_res_blocks = num_res_blocks
738
- block_in = in_channels
739
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
740
- self.res_blocks = nn.ModuleList()
741
- self.upsample_blocks = nn.ModuleList()
742
- for i_level in range(self.num_resolutions):
743
- res_block = []
744
- block_out = ch * ch_mult[i_level]
745
- for i_block in range(self.num_res_blocks + 1):
746
- res_block.append(ResnetBlock(in_channels=block_in,
747
- out_channels=block_out,
748
- temb_channels=self.temb_ch,
749
- dropout=dropout))
750
- block_in = block_out
751
- self.res_blocks.append(nn.ModuleList(res_block))
752
- if i_level != self.num_resolutions - 1:
753
- self.upsample_blocks.append(Upsample(block_in, True))
754
- curr_res = curr_res * 2
755
-
756
- # end
757
- self.norm_out = Normalize(block_in)
758
- self.conv_out = torch.nn.Conv2d(block_in,
759
- out_channels,
760
- kernel_size=3,
761
- stride=1,
762
- padding=1)
763
-
764
- def forward(self, x):
765
- # upsampling
766
- h = x
767
- for k, i_level in enumerate(range(self.num_resolutions)):
768
- for i_block in range(self.num_res_blocks + 1):
769
- h = self.res_blocks[i_level][i_block](h, None)
770
- if i_level != self.num_resolutions - 1:
771
- h = self.upsample_blocks[k](h)
772
- h = self.norm_out(h)
773
- h = nonlinearity(h)
774
- h = self.conv_out(h)
775
- return h
776
-
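For orientation, the Encoder and Decoder removed above form the convolutional backbone of the VQ autoencoder. Below is a minimal usage sketch, not part of this commit; the import path and the hyperparameters (ch, ch_mult, resolution, z_channels) are assumptions chosen only for illustration.

import torch
from models.taming.modules.diffusionmodules.model import Encoder, Decoder  # import path assumed

# Toy configuration: 64x64 RGB images, two downsampling stages, 256-dim latents.
enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=3, resolution=64,
              z_channels=256, double_z=False)
dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=3, resolution=64,
              z_channels=256)

x = torch.randn(1, 3, 64, 64)   # input batch
z = enc(x)                      # latent feature map, here (1, 256, 16, 16)
x_rec = dec(z)                  # reconstruction, back to (1, 3, 64, 64)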
 
models/taming/modules/discriminator/model.py DELETED
@@ -1,67 +0,0 @@
1
- import functools
2
- import torch.nn as nn
3
-
4
-
5
- from models.taming.modules.util import ActNorm
6
-
7
-
8
- def weights_init(m):
9
- classname = m.__class__.__name__
10
- if classname.find('Conv') != -1:
11
- nn.init.normal_(m.weight.data, 0.0, 0.02)
12
- elif classname.find('BatchNorm') != -1:
13
- nn.init.normal_(m.weight.data, 1.0, 0.02)
14
- nn.init.constant_(m.bias.data, 0)
15
-
16
-
17
- class NLayerDiscriminator(nn.Module):
18
- """Defines a PatchGAN discriminator as in Pix2Pix
19
- --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
- """
21
- def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22
- """Construct a PatchGAN discriminator
23
- Parameters:
24
- input_nc (int) -- the number of channels in input images
25
- ndf (int) -- the number of filters in the last conv layer
26
- n_layers (int) -- the number of conv layers in the discriminator
27
- norm_layer -- normalization layer
28
- """
29
- super(NLayerDiscriminator, self).__init__()
30
- if not use_actnorm:
31
- norm_layer = nn.BatchNorm2d
32
- else:
33
- norm_layer = ActNorm
34
- if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
- use_bias = norm_layer.func != nn.BatchNorm2d
36
- else:
37
- use_bias = norm_layer != nn.BatchNorm2d
38
-
39
- kw = 4
40
- padw = 1
41
- sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
- nf_mult = 1
43
- nf_mult_prev = 1
44
- for n in range(1, n_layers): # gradually increase the number of filters
45
- nf_mult_prev = nf_mult
46
- nf_mult = min(2 ** n, 8)
47
- sequence += [
48
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
- norm_layer(ndf * nf_mult),
50
- nn.LeakyReLU(0.2, True)
51
- ]
52
-
53
- nf_mult_prev = nf_mult
54
- nf_mult = min(2 ** n_layers, 8)
55
- sequence += [
56
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
- norm_layer(ndf * nf_mult),
58
- nn.LeakyReLU(0.2, True)
59
- ]
60
-
61
- sequence += [
62
- nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
- self.main = nn.Sequential(*sequence)
64
-
65
- def forward(self, input):
66
- """Standard forward."""
67
- return self.main(input)
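A hedged usage sketch of the PatchGAN discriminator deleted here; the input size and values are illustrative, not from this commit.

import torch
from models.taming.modules.discriminator.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
disc.apply(weights_init)                     # normal init for conv / batchnorm weights
logits = disc(torch.randn(2, 3, 256, 256))   # per-patch real/fake logits, here (2, 1, 30, 30)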
 
models/taming/modules/losses/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from models.taming.modules.losses.vqperceptual import DummyLoss
2
-
 
 
 
models/taming/modules/losses/lpips.py DELETED
@@ -1,123 +0,0 @@
1
- """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
-
3
- import torch
4
- import torch.nn as nn
5
- from torchvision import models
6
- from collections import namedtuple
7
-
8
- from models.taming.util import get_ckpt_path
9
-
10
-
11
- class LPIPS(nn.Module):
12
- # Learned perceptual metric
13
- def __init__(self, use_dropout=True):
14
- super().__init__()
15
- self.scaling_layer = ScalingLayer()
16
- self.chns = [64, 128, 256, 512, 512] # vgg16 features
17
- self.net = vgg16(pretrained=True, requires_grad=False)
18
- self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19
- self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20
- self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21
- self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22
- self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23
- self.load_from_pretrained()
24
- for param in self.parameters():
25
- param.requires_grad = False
26
-
27
- def load_from_pretrained(self, name="vgg_lpips"):
28
- ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29
- self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30
- print("loaded pretrained LPIPS loss from {}".format(ckpt))
31
-
32
- @classmethod
33
- def from_pretrained(cls, name="vgg_lpips"):
34
- if name != "vgg_lpips":
35
- raise NotImplementedError
36
- model = cls()
37
- ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")  # root matches load_from_pretrained
38
- model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39
- return model
40
-
41
- def forward(self, input, target):
42
- in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43
- outs0, outs1 = self.net(in0_input), self.net(in1_input)
44
- feats0, feats1, diffs = {}, {}, {}
45
- lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46
- for kk in range(len(self.chns)):
47
- feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48
- diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49
-
50
- res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51
- val = res[0]
52
- for l in range(1, len(self.chns)):
53
- val += res[l]
54
- return val
55
-
56
-
57
- class ScalingLayer(nn.Module):
58
- def __init__(self):
59
- super(ScalingLayer, self).__init__()
60
- self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61
- self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62
-
63
- def forward(self, inp):
64
- return (inp - self.shift) / self.scale
65
-
66
-
67
- class NetLinLayer(nn.Module):
68
- """ A single linear layer which does a 1x1 conv """
69
- def __init__(self, chn_in, chn_out=1, use_dropout=False):
70
- super(NetLinLayer, self).__init__()
71
- layers = [nn.Dropout(), ] if (use_dropout) else []
72
- layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73
- self.model = nn.Sequential(*layers)
74
-
75
-
76
- class vgg16(torch.nn.Module):
77
- def __init__(self, requires_grad=False, pretrained=True):
78
- super(vgg16, self).__init__()
79
- vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80
- self.slice1 = torch.nn.Sequential()
81
- self.slice2 = torch.nn.Sequential()
82
- self.slice3 = torch.nn.Sequential()
83
- self.slice4 = torch.nn.Sequential()
84
- self.slice5 = torch.nn.Sequential()
85
- self.N_slices = 5
86
- for x in range(4):
87
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
88
- for x in range(4, 9):
89
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
90
- for x in range(9, 16):
91
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
92
- for x in range(16, 23):
93
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
94
- for x in range(23, 30):
95
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
96
- if not requires_grad:
97
- for param in self.parameters():
98
- param.requires_grad = False
99
-
100
- def forward(self, X):
101
- h = self.slice1(X)
102
- h_relu1_2 = h
103
- h = self.slice2(h)
104
- h_relu2_2 = h
105
- h = self.slice3(h)
106
- h_relu3_3 = h
107
- h = self.slice4(h)
108
- h_relu4_3 = h
109
- h = self.slice5(h)
110
- h_relu5_3 = h
111
- vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112
- out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113
- return out
114
-
115
-
116
- def normalize_tensor(x,eps=1e-10):
117
- norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118
- return x/(norm_factor+eps)
119
-
120
-
121
- def spatial_average(x, keepdim=True):
122
- return x.mean([2,3],keepdim=keepdim)
123
-
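As a point of reference, the LPIPS module above is used roughly as follows; this is a sketch under the assumption that inputs are scaled to [-1, 1] and that the vgg_lpips checkpoint can be fetched via get_ckpt_path.

import torch
from models.taming.modules.losses.lpips import LPIPS

perceptual = LPIPS().eval()                # downloads/loads the vgg_lpips checkpoint on first use
x = torch.rand(1, 3, 256, 256) * 2 - 1     # images roughly in [-1, 1]
y = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = perceptual(x, y)                # shape (1, 1, 1, 1); larger means perceptually more different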
 
models/taming/modules/losses/segmentation.py DELETED
@@ -1,22 +0,0 @@
1
- import torch.nn as nn
2
- import torch.nn.functional as F
3
-
4
-
5
- class BCELoss(nn.Module):
6
- def forward(self, prediction, target):
7
- loss = F.binary_cross_entropy_with_logits(prediction,target)
8
- return loss, {}
9
-
10
-
11
- class BCELossWithQuant(nn.Module):
12
- def __init__(self, codebook_weight=1.):
13
- super().__init__()
14
- self.codebook_weight = codebook_weight
15
-
16
- def forward(self, qloss, target, prediction, split):
17
- bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18
- loss = bce_loss + self.codebook_weight*qloss
19
- return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20
- "{}/bce_loss".format(split): bce_loss.detach().mean(),
21
- "{}/quant_loss".format(split): qloss.detach().mean()
22
- }
 
models/taming/modules/losses/vqperceptual.py DELETED
@@ -1,136 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from models.taming.modules.losses.lpips import LPIPS
6
- from models.taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
-
8
-
9
- class DummyLoss(nn.Module):
10
- def __init__(self):
11
- super().__init__()
12
-
13
-
14
- def adopt_weight(weight, global_step, threshold=0, value=0.):
15
- if global_step < threshold:
16
- weight = value
17
- return weight
18
-
19
-
20
- def hinge_d_loss(logits_real, logits_fake):
21
- loss_real = torch.mean(F.relu(1. - logits_real))
22
- loss_fake = torch.mean(F.relu(1. + logits_fake))
23
- d_loss = 0.5 * (loss_real + loss_fake)
24
- return d_loss
25
-
26
-
27
- def vanilla_d_loss(logits_real, logits_fake):
28
- d_loss = 0.5 * (
29
- torch.mean(torch.nn.functional.softplus(-logits_real)) +
30
- torch.mean(torch.nn.functional.softplus(logits_fake)))
31
- return d_loss
32
-
33
-
34
- class VQLPIPSWithDiscriminator(nn.Module):
35
- def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36
- disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37
- perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38
- disc_ndf=64, disc_loss="hinge"):
39
- super().__init__()
40
- assert disc_loss in ["hinge", "vanilla"]
41
- self.codebook_weight = codebook_weight
42
- self.pixel_weight = pixelloss_weight
43
- self.perceptual_loss = LPIPS().eval()
44
- self.perceptual_weight = perceptual_weight
45
-
46
- self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47
- n_layers=disc_num_layers,
48
- use_actnorm=use_actnorm,
49
- ndf=disc_ndf
50
- ).apply(weights_init)
51
- self.discriminator_iter_start = disc_start
52
- if disc_loss == "hinge":
53
- self.disc_loss = hinge_d_loss
54
- elif disc_loss == "vanilla":
55
- self.disc_loss = vanilla_d_loss
56
- else:
57
- raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58
- print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59
- self.disc_factor = disc_factor
60
- self.discriminator_weight = disc_weight
61
- self.disc_conditional = disc_conditional
62
-
63
- def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64
- if last_layer is not None:
65
- nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67
- else:
68
- nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69
- g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70
-
71
- d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72
- d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73
- d_weight = d_weight * self.discriminator_weight
74
- return d_weight
75
-
76
- def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77
- global_step, last_layer=None, cond=None, split="train"):
78
- rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79
- if self.perceptual_weight > 0:
80
- p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81
- rec_loss = rec_loss + self.perceptual_weight * p_loss
82
- else:
83
- p_loss = torch.tensor([0.0])
84
-
85
- nll_loss = rec_loss
86
- #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87
- nll_loss = torch.mean(nll_loss)
88
-
89
- # now the GAN part
90
- if optimizer_idx == 0:
91
- # generator update
92
- if cond is None:
93
- assert not self.disc_conditional
94
- logits_fake = self.discriminator(reconstructions.contiguous())
95
- else:
96
- assert self.disc_conditional
97
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98
- g_loss = -torch.mean(logits_fake)
99
-
100
- try:
101
- d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102
- except RuntimeError:
103
- assert not self.training
104
- d_weight = torch.tensor(0.0)
105
-
106
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107
- loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108
-
109
- log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110
- "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111
- "{}/nll_loss".format(split): nll_loss.detach().mean(),
112
- "{}/rec_loss".format(split): rec_loss.detach().mean(),
113
- "{}/p_loss".format(split): p_loss.detach().mean(),
114
- "{}/d_weight".format(split): d_weight.detach(),
115
- "{}/disc_factor".format(split): torch.tensor(disc_factor),
116
- "{}/g_loss".format(split): g_loss.detach().mean(),
117
- }
118
- return loss, log
119
-
120
- if optimizer_idx == 1:
121
- # second pass for discriminator update
122
- if cond is None:
123
- logits_real = self.discriminator(inputs.contiguous().detach())
124
- logits_fake = self.discriminator(reconstructions.contiguous().detach())
125
- else:
126
- logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128
-
129
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130
- d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131
-
132
- log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133
- "{}/logits_real".format(split): logits_real.detach().mean(),
134
- "{}/logits_fake".format(split): logits_fake.detach().mean()
135
- }
136
- return d_loss, log
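The combined loss above is called once per optimizer. A minimal sketch of the discriminator step (optimizer_idx=1) follows; the tensors and hyperparameters are illustrative assumptions, not taken from this commit.

import torch
from models.taming.modules.losses.vqperceptual import VQLPIPSWithDiscriminator

loss_fn = VQLPIPSWithDiscriminator(disc_start=10000, codebook_weight=1.0, disc_weight=0.8)
inputs = torch.randn(2, 3, 64, 64)         # real images
recons = torch.randn(2, 3, 64, 64)         # autoencoder reconstructions
qloss = torch.tensor(0.1)                  # codebook/commitment loss from the quantizer

d_loss, d_log = loss_fn(qloss, inputs, recons, optimizer_idx=1,
                        global_step=0, split="train")
# The generator step (optimizer_idx=0) additionally needs last_layer (the decoder's
# final conv weight) so calculate_adaptive_weight can balance the GAN term.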
 
models/taming/modules/misc/coord.py DELETED
@@ -1,31 +0,0 @@
1
- import torch
2
-
3
- class CoordStage(object):
4
- def __init__(self, n_embed, down_factor):
5
- self.n_embed = n_embed
6
- self.down_factor = down_factor
7
-
8
- def eval(self):
9
- return self
10
-
11
- def encode(self, c):
12
- """fake vqmodel interface"""
13
- assert 0.0 <= c.min() and c.max() <= 1.0
14
- b,ch,h,w = c.shape
15
- assert ch == 1
16
-
17
- c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18
- mode="area")
19
- c = c.clamp(0.0, 1.0)
20
- c = self.n_embed*c
21
- c_quant = c.round()
22
- c_ind = c_quant.to(dtype=torch.long)
23
-
24
- info = None, None, c_ind
25
- return c_quant, None, info
26
-
27
- def decode(self, c):
28
- c = c/self.n_embed
29
- c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30
- mode="nearest")
31
- return c
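For reference, a small round trip through the fake-vqmodel interface of CoordStage; the sizes are illustrative only.

import torch
from models.taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
c = torch.rand(1, 1, 256, 256)                   # coordinate map in [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(c)      # pooled to (1, 1, 16, 16) and rounded
c_up = stage.decode(c_quant)                     # nearest-neighbour upsample back to 256x256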
 
models/taming/modules/util.py DELETED
@@ -1,130 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- def count_params(model):
6
- total_params = sum(p.numel() for p in model.parameters())
7
- return total_params
8
-
9
-
10
- class ActNorm(nn.Module):
11
- def __init__(self, num_features, logdet=False, affine=True,
12
- allow_reverse_init=False):
13
- assert affine
14
- super().__init__()
15
- self.logdet = logdet
16
- self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17
- self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18
- self.allow_reverse_init = allow_reverse_init
19
-
20
- self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21
-
22
- def initialize(self, input):
23
- with torch.no_grad():
24
- flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25
- mean = (
26
- flatten.mean(1)
27
- .unsqueeze(1)
28
- .unsqueeze(2)
29
- .unsqueeze(3)
30
- .permute(1, 0, 2, 3)
31
- )
32
- std = (
33
- flatten.std(1)
34
- .unsqueeze(1)
35
- .unsqueeze(2)
36
- .unsqueeze(3)
37
- .permute(1, 0, 2, 3)
38
- )
39
-
40
- self.loc.data.copy_(-mean)
41
- self.scale.data.copy_(1 / (std + 1e-6))
42
-
43
- def forward(self, input, reverse=False):
44
- if reverse:
45
- return self.reverse(input)
46
- if len(input.shape) == 2:
47
- input = input[:,:,None,None]
48
- squeeze = True
49
- else:
50
- squeeze = False
51
-
52
- _, _, height, width = input.shape
53
-
54
- if self.training and self.initialized.item() == 0:
55
- self.initialize(input)
56
- self.initialized.fill_(1)
57
-
58
- h = self.scale * (input + self.loc)
59
-
60
- if squeeze:
61
- h = h.squeeze(-1).squeeze(-1)
62
-
63
- if self.logdet:
64
- log_abs = torch.log(torch.abs(self.scale))
65
- logdet = height*width*torch.sum(log_abs)
66
- logdet = logdet * torch.ones(input.shape[0]).to(input)
67
- return h, logdet
68
-
69
- return h
70
-
71
- def reverse(self, output):
72
- if self.training and self.initialized.item() == 0:
73
- if not self.allow_reverse_init:
74
- raise RuntimeError(
75
- "Initializing ActNorm in reverse direction is "
76
- "disabled by default. Use allow_reverse_init=True to enable."
77
- )
78
- else:
79
- self.initialize(output)
80
- self.initialized.fill_(1)
81
-
82
- if len(output.shape) == 2:
83
- output = output[:,:,None,None]
84
- squeeze = True
85
- else:
86
- squeeze = False
87
-
88
- h = output / self.scale - self.loc
89
-
90
- if squeeze:
91
- h = h.squeeze(-1).squeeze(-1)
92
- return h
93
-
94
-
95
- class AbstractEncoder(nn.Module):
96
- def __init__(self):
97
- super().__init__()
98
-
99
- def encode(self, *args, **kwargs):
100
- raise NotImplementedError
101
-
102
-
103
- class Labelator(AbstractEncoder):
104
- """Net2Net Interface for Class-Conditional Model"""
105
- def __init__(self, n_classes, quantize_interface=True):
106
- super().__init__()
107
- self.n_classes = n_classes
108
- self.quantize_interface = quantize_interface
109
-
110
- def encode(self, c):
111
- c = c[:,None]
112
- if self.quantize_interface:
113
- return c, None, [None, None, c.long()]
114
- return c
115
-
116
-
117
- class SOSProvider(AbstractEncoder):
118
- # for unconditional training
119
- def __init__(self, sos_token, quantize_interface=True):
120
- super().__init__()
121
- self.sos_token = sos_token
122
- self.quantize_interface = quantize_interface
123
-
124
- def encode(self, x):
125
- # get batch size from data and replicate sos_token
126
- c = torch.ones(x.shape[0], 1)*self.sos_token
127
- c = c.long().to(x.device)
128
- if self.quantize_interface:
129
- return c, None, [None, None, c]
130
- return c
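ActNorm above uses data-dependent initialisation: the first training batch sets loc/scale so that the output is roughly zero-mean and unit-variance. A hedged sketch, with illustrative shapes:

import torch
from models.taming.modules.util import ActNorm

norm = ActNorm(num_features=64)
norm.train()
x = torch.randn(8, 64, 32, 32)
y = norm(x)                      # first call initialises loc/scale from this batch
x_back = norm(y, reverse=True)   # inverse transform recovers (approximately) the input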
 
models/taming/modules/vqvae/quantize.py DELETED
@@ -1,445 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
- from torch import einsum
6
- from einops import rearrange
7
-
8
-
9
- class VectorQuantizer(nn.Module):
10
- """
11
- see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
12
- ____________________________________________
13
- Discretization bottleneck part of the VQ-VAE.
14
- Inputs:
15
- - n_e : number of embeddings
16
- - e_dim : dimension of embedding
17
- - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
18
- _____________________________________________
19
- """
20
-
21
- # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
22
- # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
23
- # used wherever VectorQuantizer has been used before and is additionally
24
- # more efficient.
25
- def __init__(self, n_e, e_dim, beta):
26
- super(VectorQuantizer, self).__init__()
27
- self.n_e = n_e
28
- self.e_dim = e_dim
29
- self.beta = beta
30
-
31
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
32
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
33
-
34
- def forward(self, z):
35
- """
36
- Inputs the output of the encoder network z and maps it to a discrete
37
- one-hot vector that is the index of the closest embedding vector e_j
38
- z (continuous) -> z_q (discrete)
39
- z.shape = (batch, channel, height, width)
40
- quantization pipeline:
41
- 1. get encoder input (B,C,H,W)
42
- 2. flatten input to (B*H*W,C)
43
- """
44
- # reshape z -> (batch, height, width, channel) and flatten
45
- z = z.permute(0, 2, 3, 1).contiguous()
46
- z_flattened = z.view(-1, self.e_dim)
47
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
48
-
49
- d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
50
- torch.sum(self.embedding.weight**2, dim=1) - 2 * \
51
- torch.matmul(z_flattened, self.embedding.weight.t())
52
-
53
- ## could possibly replace this here
54
- # #\start...
55
- # find closest encodings
56
- min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
57
-
58
- min_encodings = torch.zeros(
59
- min_encoding_indices.shape[0], self.n_e).to(z)
60
- min_encodings.scatter_(1, min_encoding_indices, 1)
61
-
62
- # dtype min encodings: torch.float32
63
- # min_encodings shape: torch.Size([2048, 512])
64
- # min_encoding_indices.shape: torch.Size([2048, 1])
65
-
66
- # get quantized latent vectors
67
- z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
68
- #.........\end
69
-
70
- # with:
71
- # .........\start
72
- #min_encoding_indices = torch.argmin(d, dim=1)
73
- #z_q = self.embedding(min_encoding_indices)
74
- # ......\end......... (TODO)
75
-
76
- # compute loss for embedding
77
- loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
78
- torch.mean((z_q - z.detach()) ** 2)
79
-
80
- # preserve gradients
81
- z_q = z + (z_q - z).detach()
82
-
83
- # perplexity
84
- e_mean = torch.mean(min_encodings, dim=0)
85
- perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
86
-
87
- # reshape back to match original input shape
88
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
89
-
90
- return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
91
-
92
- def get_codebook_entry(self, indices, shape):
93
- # shape specifying (batch, height, width, channel)
94
- # TODO: check for more easy handling with nn.Embedding
95
- min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
96
- min_encodings.scatter_(1, indices[:,None], 1)
97
-
98
- # get quantized latent vectors
99
- z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
100
-
101
- if shape is not None:
102
- z_q = z_q.view(shape)
103
-
104
- # reshape back to match original input shape
105
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
106
-
107
- return z_q
108
-
109
-
110
- class GumbelQuantize(nn.Module):
111
- """
112
- credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
113
- Gumbel Softmax trick quantizer
114
- Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
115
- https://arxiv.org/abs/1611.01144
116
- """
117
- def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
118
- kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
119
- remap=None, unknown_index="random"):
120
- super().__init__()
121
-
122
- self.embedding_dim = embedding_dim
123
- self.n_embed = n_embed
124
-
125
- self.straight_through = straight_through
126
- self.temperature = temp_init
127
- self.kl_weight = kl_weight
128
-
129
- self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
130
- self.embed = nn.Embedding(n_embed, embedding_dim)
131
-
132
- self.use_vqinterface = use_vqinterface
133
-
134
- self.remap = remap
135
- if self.remap is not None:
136
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
137
- self.re_embed = self.used.shape[0]
138
- self.unknown_index = unknown_index # "random" or "extra" or integer
139
- if self.unknown_index == "extra":
140
- self.unknown_index = self.re_embed
141
- self.re_embed = self.re_embed+1
142
- print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
143
- f"Using {self.unknown_index} for unknown indices.")
144
- else:
145
- self.re_embed = n_embed
146
-
147
- def remap_to_used(self, inds):
148
- ishape = inds.shape
149
- assert len(ishape)>1
150
- inds = inds.reshape(ishape[0],-1)
151
- used = self.used.to(inds)
152
- match = (inds[:,:,None]==used[None,None,...]).long()
153
- new = match.argmax(-1)
154
- unknown = match.sum(2)<1
155
- if self.unknown_index == "random":
156
- new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
157
- else:
158
- new[unknown] = self.unknown_index
159
- return new.reshape(ishape)
160
-
161
- def unmap_to_all(self, inds):
162
- ishape = inds.shape
163
- assert len(ishape)>1
164
- inds = inds.reshape(ishape[0],-1)
165
- used = self.used.to(inds)
166
- if self.re_embed > self.used.shape[0]: # extra token
167
- inds[inds>=self.used.shape[0]] = 0 # simply set to zero
168
- back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
169
- return back.reshape(ishape)
170
-
171
- def forward(self, z, temp=None, return_logits=False):
172
- # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
173
- hard = self.straight_through if self.training else True
174
- temp = self.temperature if temp is None else temp
175
-
176
- logits = self.proj(z)
177
- if self.remap is not None:
178
- # continue only with used logits
179
- full_zeros = torch.zeros_like(logits)
180
- logits = logits[:,self.used,...]
181
-
182
- soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
183
- if self.remap is not None:
184
- # go back to all entries but unused set to zero
185
- full_zeros[:,self.used,...] = soft_one_hot
186
- soft_one_hot = full_zeros
187
- z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
188
-
189
- # + kl divergence to the prior loss
190
- qy = F.softmax(logits, dim=1)
191
- diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
192
-
193
- ind = soft_one_hot.argmax(dim=1)
194
- if self.remap is not None:
195
- ind = self.remap_to_used(ind)
196
- if self.use_vqinterface:
197
- if return_logits:
198
- return z_q, diff, (None, None, ind), logits
199
- return z_q, diff, (None, None, ind)
200
- return z_q, diff, ind
201
-
202
- def get_codebook_entry(self, indices, shape):
203
- b, h, w, c = shape
204
- assert b*h*w == indices.shape[0]
205
- indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
206
- if self.remap is not None:
207
- indices = self.unmap_to_all(indices)
208
- one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
209
- z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
210
- return z_q
211
-
212
-
213
- class VectorQuantizer2(nn.Module):
214
- """
215
- Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
216
- avoids costly matrix multiplications and allows for post-hoc remapping of indices.
217
- """
218
- # NOTE: due to a bug the beta term was applied to the wrong term. for
219
- # backwards compatibility we use the buggy version by default, but you can
220
- # specify legacy=False to fix it.
221
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
222
- sane_index_shape=False, legacy=True):
223
- super().__init__()
224
- self.n_e = n_e
225
- self.e_dim = e_dim
226
- self.beta = beta
227
- self.legacy = legacy
228
-
229
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
230
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
231
-
232
- self.remap = remap
233
- if self.remap is not None:
234
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
235
- self.re_embed = self.used.shape[0]
236
- self.unknown_index = unknown_index # "random" or "extra" or integer
237
- if self.unknown_index == "extra":
238
- self.unknown_index = self.re_embed
239
- self.re_embed = self.re_embed+1
240
- print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
241
- f"Using {self.unknown_index} for unknown indices.")
242
- else:
243
- self.re_embed = n_e
244
-
245
- self.sane_index_shape = sane_index_shape
246
-
247
- def remap_to_used(self, inds):
248
- ishape = inds.shape
249
- assert len(ishape)>1
250
- inds = inds.reshape(ishape[0],-1)
251
- used = self.used.to(inds)
252
- match = (inds[:,:,None]==used[None,None,...]).long()
253
- new = match.argmax(-1)
254
- unknown = match.sum(2)<1
255
- if self.unknown_index == "random":
256
- new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
257
- else:
258
- new[unknown] = self.unknown_index
259
- return new.reshape(ishape)
260
-
261
- def unmap_to_all(self, inds):
262
- ishape = inds.shape
263
- assert len(ishape)>1
264
- inds = inds.reshape(ishape[0],-1)
265
- used = self.used.to(inds)
266
- if self.re_embed > self.used.shape[0]: # extra token
267
- inds[inds>=self.used.shape[0]] = 0 # simply set to zero
268
- back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
269
- return back.reshape(ishape)
270
-
271
- def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
272
- assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
273
- assert rescale_logits==False, "Only for interface compatible with Gumbel"
274
- assert return_logits==False, "Only for interface compatible with Gumbel"
275
- # reshape z -> (batch, height, width, channel) and flatten
276
- z = rearrange(z, 'b c h w -> b h w c').contiguous()
277
- z_flattened = z.view(-1, self.e_dim)
278
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
279
-
280
- d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
281
- torch.sum(self.embedding.weight**2, dim=1) - 2 * \
282
- torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
283
-
284
- min_encoding_indices = torch.argmin(d, dim=1)
285
- z_q = self.embedding(min_encoding_indices).view(z.shape)
286
- perplexity = None
287
- min_encodings = None
288
-
289
- # compute loss for embedding
290
- if not self.legacy:
291
- loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
292
- torch.mean((z_q - z.detach()) ** 2)
293
- else:
294
- loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
295
- torch.mean((z_q - z.detach()) ** 2)
296
-
297
- # preserve gradients
298
- z_q = z + (z_q - z).detach()
299
-
300
- # reshape back to match original input shape
301
- z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
302
-
303
- if self.remap is not None:
304
- min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
305
- min_encoding_indices = self.remap_to_used(min_encoding_indices)
306
- min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
307
-
308
- if self.sane_index_shape:
309
- min_encoding_indices = min_encoding_indices.reshape(
310
- z_q.shape[0], z_q.shape[2], z_q.shape[3])
311
-
312
- return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
313
-
314
- def get_codebook_entry(self, indices, shape):
315
- # shape specifying (batch, height, width, channel)
316
- if self.remap is not None:
317
- indices = indices.reshape(shape[0],-1) # add batch axis
318
- indices = self.unmap_to_all(indices)
319
- indices = indices.reshape(-1) # flatten again
320
-
321
- # get quantized latent vectors
322
- z_q = self.embedding(indices)
323
-
324
- if shape is not None:
325
- z_q = z_q.view(shape)
326
- # reshape back to match original input shape
327
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
328
-
329
- return z_q
330
-
331
- class EmbeddingEMA(nn.Module):
332
- def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
333
- super().__init__()
334
- self.decay = decay
335
- self.eps = eps
336
- weight = torch.randn(num_tokens, codebook_dim)
337
- self.weight = nn.Parameter(weight, requires_grad = False)
338
- self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
339
- self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
340
- self.update = True
341
-
342
- def forward(self, embed_id):
343
- return F.embedding(embed_id, self.weight)
344
-
345
- def cluster_size_ema_update(self, new_cluster_size):
346
- self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
347
-
348
- def embed_avg_ema_update(self, new_embed_avg):
349
- self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
350
-
351
- def weight_update(self, num_tokens):
352
- n = self.cluster_size.sum()
353
- smoothed_cluster_size = (
354
- (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
355
- )
356
- #normalize embedding average with smoothed cluster size
357
- embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
358
- self.weight.data.copy_(embed_normalized)
359
-
360
-
361
- class EMAVectorQuantizer(nn.Module):
362
- def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
363
- remap=None, unknown_index="random"):
364
- super().__init__()
365
- self.codebook_dim = embedding_dim
366
- self.num_tokens = n_embed
367
- self.beta = beta
368
- self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
369
-
370
- self.remap = remap
371
- if self.remap is not None:
372
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
373
- self.re_embed = self.used.shape[0]
374
- self.unknown_index = unknown_index # "random" or "extra" or integer
375
- if self.unknown_index == "extra":
376
- self.unknown_index = self.re_embed
377
- self.re_embed = self.re_embed+1
378
- print(f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
379
- f"Using {self.unknown_index} for unknown indices.")
380
- else:
381
- self.re_embed = n_embed
382
-
383
- def remap_to_used(self, inds):
384
- ishape = inds.shape
385
- assert len(ishape)>1
386
- inds = inds.reshape(ishape[0],-1)
387
- used = self.used.to(inds)
388
- match = (inds[:,:,None]==used[None,None,...]).long()
389
- new = match.argmax(-1)
390
- unknown = match.sum(2)<1
391
- if self.unknown_index == "random":
392
- new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
393
- else:
394
- new[unknown] = self.unknown_index
395
- return new.reshape(ishape)
396
-
397
- def unmap_to_all(self, inds):
398
- ishape = inds.shape
399
- assert len(ishape)>1
400
- inds = inds.reshape(ishape[0],-1)
401
- used = self.used.to(inds)
402
- if self.re_embed > self.used.shape[0]: # extra token
403
- inds[inds>=self.used.shape[0]] = 0 # simply set to zero
404
- back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
405
- return back.reshape(ishape)
406
-
407
- def forward(self, z):
408
- # reshape z -> (batch, height, width, channel) and flatten
409
- #z, 'b c h w -> b h w c'
410
- z = rearrange(z, 'b c h w -> b h w c')
411
- z_flattened = z.reshape(-1, self.codebook_dim)
412
-
413
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
414
- d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
415
- self.embedding.weight.pow(2).sum(dim=1) - 2 * \
416
- torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
417
-
418
-
419
- encoding_indices = torch.argmin(d, dim=1)
420
-
421
- z_q = self.embedding(encoding_indices).view(z.shape)
422
- encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
423
- avg_probs = torch.mean(encodings, dim=0)
424
- perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
425
-
426
- if self.training and self.embedding.update:
427
- #EMA cluster size
428
- encodings_sum = encodings.sum(0)
429
- self.embedding.cluster_size_ema_update(encodings_sum)
430
- #EMA embedding average
431
- embed_sum = encodings.transpose(0,1) @ z_flattened
432
- self.embedding.embed_avg_ema_update(embed_sum)
433
- #normalize embed_avg and update weight
434
- self.embedding.weight_update(self.num_tokens)
435
-
436
- # compute loss for embedding
437
- loss = self.beta * F.mse_loss(z_q.detach(), z)
438
-
439
- # preserve gradients
440
- z_q = z + (z_q - z).detach()
441
-
442
- # reshape back to match original input shape
443
- #z_q, 'b h w c -> b c h w'
444
- z_q = rearrange(z_q, 'b h w c -> b c h w')
445
- return z_q, loss, (perplexity, encodings, encoding_indices)
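To illustrate how the quantizers above are driven, here is a sketch with VectorQuantizer2; shapes and codebook sizes are assumptions, not values from this commit.

import torch
from models.taming.modules.vqvae.quantize import VectorQuantizer2

quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25, legacy=False)
z = torch.randn(2, 256, 16, 16)                    # encoder output (B, C, H, W)
z_q, commit_loss, (_, _, indices) = quantizer(z)   # z_q keeps z's shape
# The straight-through estimator (z + (z_q - z).detach()) lets gradients flow through
# the quantization step; `indices` are the codebook ids per spatial location.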
 
models/taming/util.py DELETED
@@ -1,172 +0,0 @@
1
- import os, hashlib
2
- import requests
3
- from tqdm import tqdm
4
- import importlib
5
-
6
- URL_MAP = {
7
- "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
8
- }
9
-
10
- CKPT_MAP = {
11
- "vgg_lpips": "vgg.pth"
12
- }
13
-
14
- MD5_MAP = {
15
- "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
16
- }
17
-
18
-
19
- def get_obj_from_str(string, reload=False):
20
- module, cls = string.rsplit(".", 1)
21
- if reload:
22
- module_imp = importlib.import_module(module)
23
- importlib.reload(module_imp)
24
- return getattr(importlib.import_module(module, package=None), cls)
25
-
26
-
27
- def instantiate_from_config(config):
28
- if not "target" in config:
29
- raise KeyError("Expected key `target` to instantiate.")
30
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
31
-
32
-
33
- def download(url, local_path, chunk_size=1024):
34
- os.makedirs(os.path.split(local_path)[0], exist_ok=True)
35
- with requests.get(url, stream=True) as r:
36
- total_size = int(r.headers.get("content-length", 0))
37
- with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
38
- with open(local_path, "wb") as f:
39
- for data in r.iter_content(chunk_size=chunk_size):
40
- if data:
41
- f.write(data)
42
- pbar.update(chunk_size)
43
-
44
-
45
- def md5_hash(path):
46
- with open(path, "rb") as f:
47
- content = f.read()
48
- return hashlib.md5(content).hexdigest()
49
-
50
-
51
- def get_ckpt_path(name, root, check=False):
52
- assert name in URL_MAP
53
- path = os.path.join(root, CKPT_MAP[name])
54
- if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
55
- print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
56
- download(URL_MAP[name], path)
57
- md5 = md5_hash(path)
58
- assert md5 == MD5_MAP[name], md5
59
- return path
60
-
61
-
62
- class KeyNotFoundError(Exception):
63
- def __init__(self, cause, keys=None, visited=None):
64
- self.cause = cause
65
- self.keys = keys
66
- self.visited = visited
67
- messages = list()
68
- if keys is not None:
69
- messages.append("Key not found: {}".format(keys))
70
- if visited is not None:
71
- messages.append("Visited: {}".format(visited))
72
- messages.append("Cause:\n{}".format(cause))
73
- message = "\n".join(messages)
74
- super().__init__(message)
75
-
76
-
77
- def retrieve(
78
- list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
79
- ):
80
- """Given a nested list or dict return the desired value at key expanding
81
- callable nodes if necessary and :attr:`expand` is ``True``. The expansion
82
- is done in-place.
83
-
84
- Parameters
85
- ----------
86
- list_or_dict : list or dict
87
- Possibly nested list or dictionary.
88
- key : str
89
- key/to/value, path like string describing all keys necessary to
90
- consider to get to the desired value. List indices can also be
91
- passed here.
92
- splitval : str
93
- String that defines the delimiter between keys of the
94
- different depth levels in `key`.
95
- default : obj
96
- Value returned if :attr:`key` is not found.
97
- expand : bool
98
- Whether to expand callable nodes on the path or not.
99
-
100
- Returns
101
- -------
102
- The desired value or if :attr:`default` is not ``None`` and the
103
- :attr:`key` is not found returns ``default``.
104
-
105
- Raises
106
- ------
107
- Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
108
- ``None``.
109
- """
110
-
111
- keys = key.split(splitval)
112
-
113
- success = True
114
- try:
115
- visited = []
116
- parent = None
117
- last_key = None
118
- for key in keys:
119
- if callable(list_or_dict):
120
- if not expand:
121
- raise KeyNotFoundError(
122
- ValueError(
123
- "Trying to get past callable node with expand=False."
124
- ),
125
- keys=keys,
126
- visited=visited,
127
- )
128
- list_or_dict = list_or_dict()
129
- parent[last_key] = list_or_dict
130
-
131
- last_key = key
132
- parent = list_or_dict
133
-
134
- try:
135
- if isinstance(list_or_dict, dict):
136
- list_or_dict = list_or_dict[key]
137
- else:
138
- list_or_dict = list_or_dict[int(key)]
139
- except (KeyError, IndexError, ValueError) as e:
140
- raise KeyNotFoundError(e, keys=keys, visited=visited)
141
-
142
- visited += [key]
143
- # final expansion of retrieved value
144
- if expand and callable(list_or_dict):
145
- list_or_dict = list_or_dict()
146
- parent[last_key] = list_or_dict
147
- except KeyNotFoundError as e:
148
- if default is None:
149
- raise e
150
- else:
151
- list_or_dict = default
152
- success = False
153
-
154
- if not pass_success:
155
- return list_or_dict
156
- else:
157
- return list_or_dict, success
158
-
159
-
160
- if __name__ == "__main__":
161
- config = {"keya": "a",
162
- "keyb": "b",
163
- "keyc":
164
- {"cc1": 1,
165
- "cc2": 2,
166
- }
167
- }
168
- from omegaconf import OmegaConf
169
-
170
- config = OmegaConf.create(config)
171
- print(config)
172
- retrieve(config, "keya")
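The retrieve helper above walks a nested list/dict with slash-separated keys, optionally expanding callable nodes. A short example with a hypothetical config:

from models.taming.util import retrieve

cfg = {"model": {"params": {"embed_dim": 256}}}
retrieve(cfg, "model/params/embed_dim")            # -> 256
retrieve(cfg, "model/params/missing", default=0)   # -> 0 instead of raising KeyNotFoundError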
 
utils/eval_utils.py CHANGED
@@ -63,7 +63,7 @@ def eval_ocr(task, generator, models, sample, **kwargs):
63
 
64
 
65
  def eval_step(task, generator, models, sample, **kwargs):
66
- if task.cfg._name == "ocr":
67
  return eval_ocr(task, generator, models, sample, **kwargs)
68
  else:
69
  raise NotImplementedError
 
63
 
64
 
65
  def eval_step(task, generator, models, sample, **kwargs):
66
+ if task.cfg._name == "ocr":
67
  return eval_ocr(task, generator, models, sample, **kwargs)
68
  else:
69
  raise NotImplementedError