boris commited on
Commit
b88a523
1 Parent(s): 31da1e5

chore: remove duplicate files

Browse files

Former-commit-id: c5805ff37276abff87e8ccb0ab756a8d1f3b0bf3

app/dalle_mini/__init__.py DELETED
@@ -1 +0,0 @@
1
- __version__ = "0.0.1"
 
 
app/dalle_mini/dataset.py DELETED
@@ -1,122 +0,0 @@
1
- """
2
- An image-caption dataset dataloader.
3
- Luke Melas-Kyriazi, 2021
4
- """
5
- import warnings
6
- from typing import Optional, Callable
7
- from pathlib import Path
8
- import numpy as np
9
- import torch
10
- import pandas as pd
11
- from torch.utils.data import Dataset
12
- from torchvision.datasets.folder import default_loader
13
- from PIL import ImageFile
14
- from PIL.Image import DecompressionBombWarning
15
- ImageFile.LOAD_TRUNCATED_IMAGES = True
16
- warnings.filterwarnings("ignore", category=UserWarning)
17
- warnings.filterwarnings("ignore", category=DecompressionBombWarning)
18
-
19
-
20
- class CaptionDataset(Dataset):
21
- """
22
- A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
23
- returns the raw text rather than tokens. This is done on purpose, because
24
- it's easy to tokenize a batch of text after loading it from this dataset.
25
- """
26
-
27
- def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
28
- image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
29
- include_captions: bool = True):
30
- """
31
- :param images_root: folder where images are stored
32
- :param captions_path: path to csv that maps image filenames to captions
33
- :param image_transform: image transform pipeline
34
- :param text_transform: image transform pipeline
35
- :param image_transform_type: image transform type, either `torchvision` or `albumentations`
36
- :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
37
- """
38
-
39
- # Base path for images
40
- self.images_root = Path(images_root)
41
-
42
- # Load captions as DataFrame
43
- self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
44
- self.captions['image_file'] = self.captions['image_file'].astype(str)
45
-
46
- # PyTorch transformation pipeline for the image (normalizing, etc.)
47
- self.text_transform = text_transform
48
- self.image_transform = image_transform
49
- self.image_transform_type = image_transform_type.lower()
50
- assert self.image_transform_type in ['torchvision', 'albumentations']
51
-
52
- # Total number of datapoints
53
- self.size = len(self.captions)
54
-
55
- # Return image+captions or just images
56
- self.include_captions = include_captions
57
-
58
- def verify_that_all_images_exist(self):
59
- for image_file in self.captions['image_file']:
60
- p = self.images_root / image_file
61
- if not p.is_file():
62
- print(f'file does not exist: {p}')
63
-
64
- def _get_raw_image(self, i):
65
- image_file = self.captions.iloc[i]['image_file']
66
- image_path = self.images_root / image_file
67
- image = default_loader(image_path)
68
- return image
69
-
70
- def _get_raw_text(self, i):
71
- return self.captions.iloc[i]['caption']
72
-
73
- def __getitem__(self, i):
74
- image = self._get_raw_image(i)
75
- caption = self._get_raw_text(i)
76
- if self.image_transform is not None:
77
- if self.image_transform_type == 'torchvision':
78
- image = self.image_transform(image)
79
- elif self.image_transform_type == 'albumentations':
80
- image = self.image_transform(image=np.array(image))['image']
81
- else:
82
- raise NotImplementedError(f"{self.image_transform_type=}")
83
- return {'image': image, 'text': caption} if self.include_captions else image
84
-
85
- def __len__(self):
86
- return self.size
87
-
88
-
89
- if __name__ == "__main__":
90
- import albumentations as A
91
- from albumentations.pytorch import ToTensorV2
92
- from transformers import AutoTokenizer
93
-
94
- # Paths
95
- images_root = './images'
96
- captions_path = './images-list-clean.tsv'
97
-
98
- # Create transforms
99
- tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
100
- def tokenize(text):
101
- return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
102
- image_transform = A.Compose([
103
- A.Resize(256, 256), A.CenterCrop(256, 256),
104
- A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
105
-
106
- # Create dataset
107
- dataset = CaptionDataset(
108
- images_root=images_root,
109
- captions_path=captions_path,
110
- image_transform=image_transform,
111
- text_transform=tokenize,
112
- image_transform_type='albumentations')
113
-
114
- # Create dataloader
115
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
116
- batch = next(iter(dataloader))
117
- print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
118
-
119
- # # (Optional) Check that all the images exist
120
- # dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
121
- # dataset.verify_that_all_images_exist()
122
- # print('Done')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/dalle_mini/vqgan_jax/__init__.py DELETED
File without changes
app/dalle_mini/vqgan_jax/configuration_vqgan.py DELETED
@@ -1,40 +0,0 @@
1
- from typing import Tuple
2
-
3
- from transformers import PretrainedConfig
4
-
5
-
6
- class VQGANConfig(PretrainedConfig):
7
- def __init__(
8
- self,
9
- ch: int = 128,
10
- out_ch: int = 3,
11
- in_channels: int = 3,
12
- num_res_blocks: int = 2,
13
- resolution: int = 256,
14
- z_channels: int = 256,
15
- ch_mult: Tuple = (1, 1, 2, 2, 4),
16
- attn_resolutions: int = (16,),
17
- n_embed: int = 1024,
18
- embed_dim: int = 256,
19
- dropout: float = 0.0,
20
- double_z: bool = False,
21
- resamp_with_conv: bool = True,
22
- give_pre_end: bool = False,
23
- **kwargs,
24
- ):
25
- super().__init__(**kwargs)
26
- self.ch = ch
27
- self.out_ch = out_ch
28
- self.in_channels = in_channels
29
- self.num_res_blocks = num_res_blocks
30
- self.resolution = resolution
31
- self.z_channels = z_channels
32
- self.ch_mult = list(ch_mult)
33
- self.attn_resolutions = list(attn_resolutions)
34
- self.n_embed = n_embed
35
- self.embed_dim = embed_dim
36
- self.dropout = dropout
37
- self.double_z = double_z
38
- self.resamp_with_conv = resamp_with_conv
39
- self.give_pre_end = give_pre_end
40
- self.num_resolutions = len(ch_mult)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py DELETED
@@ -1,109 +0,0 @@
1
- import re
2
-
3
- import jax.numpy as jnp
4
- from flax.traverse_util import flatten_dict, unflatten_dict
5
-
6
- import torch
7
-
8
- from modeling_flax_vqgan import VQModel
9
- from configuration_vqgan import VQGANConfig
10
-
11
-
12
- regex = r"\w+[.]\d+"
13
-
14
-
15
- def rename_key(key):
16
- pats = re.findall(regex, key)
17
- for pat in pats:
18
- key = key.replace(pat, "_".join(pat.split(".")))
19
- return key
20
-
21
-
22
- # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
23
- def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
24
- # convert pytorch tensor to numpy
25
- pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
26
-
27
- random_flax_state_dict = flatten_dict(flax_model.params)
28
- flax_state_dict = {}
29
-
30
- remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
31
- flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
32
- )
33
- add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
34
- flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
35
- )
36
-
37
- # Need to change some parameters name to match Flax names so that we don't have to fork any layer
38
- for pt_key, pt_tensor in pt_state_dict.items():
39
- pt_tuple_key = tuple(pt_key.split("."))
40
-
41
- has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
42
- require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
43
-
44
- if remove_base_model_prefix and has_base_model_prefix:
45
- pt_tuple_key = pt_tuple_key[1:]
46
- elif add_base_model_prefix and require_base_model_prefix:
47
- pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
48
-
49
- # Correctly rename weight parameters
50
- if (
51
- "norm" in pt_key
52
- and (pt_tuple_key[-1] == "bias")
53
- and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
54
- ):
55
- pt_tensor = pt_tensor[None, None, None, :]
56
- elif (
57
- "norm" in pt_key
58
- and (pt_tuple_key[-1] == "bias")
59
- and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
60
- ):
61
- pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
62
- pt_tensor = pt_tensor[None, None, None, :]
63
- elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
64
- pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
65
- pt_tensor = pt_tensor[None, None, None, :]
66
- if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
67
- pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
68
- elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
69
- # conv layer
70
- pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
71
- pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
72
- elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
73
- # linear layer
74
- pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
75
- pt_tensor = pt_tensor.T
76
- elif pt_tuple_key[-1] == "gamma":
77
- pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
78
- elif pt_tuple_key[-1] == "beta":
79
- pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
80
-
81
- if pt_tuple_key in random_flax_state_dict:
82
- if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
83
- raise ValueError(
84
- f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
85
- f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
86
- )
87
-
88
- # also add unexpected weight so that warning is thrown
89
- flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
90
-
91
- return unflatten_dict(flax_state_dict)
92
-
93
-
94
- def convert_model(config_path, pt_state_dict_path, save_path):
95
- config = VQGANConfig.from_pretrained(config_path)
96
- model = VQModel(config)
97
-
98
- state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
99
- keys = list(state_dict.keys())
100
- for key in keys:
101
- if key.startswith("loss"):
102
- state_dict.pop(key)
103
- continue
104
- renamed_key = rename_key(key)
105
- state_dict[renamed_key] = state_dict.pop(key)
106
-
107
- state = convert_pytorch_state_dict_to_flax(state_dict, model)
108
- model.params = unflatten_dict(state)
109
- model.save_pretrained(save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/dalle_mini/vqgan_jax/modeling_flax_vqgan.py DELETED
@@ -1,609 +0,0 @@
1
- # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
2
-
3
- from functools import partial
4
- from typing import Tuple
5
- import math
6
-
7
- import jax
8
- import jax.numpy as jnp
9
- import numpy as np
10
- import flax.linen as nn
11
- from flax.core.frozen_dict import FrozenDict
12
-
13
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
14
-
15
- from .configuration_vqgan import VQGANConfig
16
-
17
-
18
- class Upsample(nn.Module):
19
- in_channels: int
20
- with_conv: bool
21
- dtype: jnp.dtype = jnp.float32
22
-
23
- def setup(self):
24
- if self.with_conv:
25
- self.conv = nn.Conv(
26
- self.in_channels,
27
- kernel_size=(3, 3),
28
- strides=(1, 1),
29
- padding=((1, 1), (1, 1)),
30
- dtype=self.dtype,
31
- )
32
-
33
- def __call__(self, hidden_states):
34
- batch, height, width, channels = hidden_states.shape
35
- hidden_states = jax.image.resize(
36
- hidden_states,
37
- shape=(batch, height * 2, width * 2, channels),
38
- method="nearest",
39
- )
40
- if self.with_conv:
41
- hidden_states = self.conv(hidden_states)
42
- return hidden_states
43
-
44
-
45
- class Downsample(nn.Module):
46
- in_channels: int
47
- with_conv: bool
48
- dtype: jnp.dtype = jnp.float32
49
-
50
- def setup(self):
51
- if self.with_conv:
52
- self.conv = nn.Conv(
53
- self.in_channels,
54
- kernel_size=(3, 3),
55
- strides=(2, 2),
56
- padding="VALID",
57
- dtype=self.dtype,
58
- )
59
-
60
- def __call__(self, hidden_states):
61
- if self.with_conv:
62
- pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
63
- hidden_states = jnp.pad(hidden_states, pad_width=pad)
64
- hidden_states = self.conv(hidden_states)
65
- else:
66
- hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
67
- return hidden_states
68
-
69
-
70
- class ResnetBlock(nn.Module):
71
- in_channels: int
72
- out_channels: int = None
73
- use_conv_shortcut: bool = False
74
- temb_channels: int = 512
75
- dropout_prob: float = 0.0
76
- dtype: jnp.dtype = jnp.float32
77
-
78
- def setup(self):
79
- self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
80
-
81
- self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
82
- self.conv1 = nn.Conv(
83
- self.out_channels_,
84
- kernel_size=(3, 3),
85
- strides=(1, 1),
86
- padding=((1, 1), (1, 1)),
87
- dtype=self.dtype,
88
- )
89
-
90
- if self.temb_channels:
91
- self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
92
-
93
- self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
94
- self.dropout = nn.Dropout(self.dropout_prob)
95
- self.conv2 = nn.Conv(
96
- self.out_channels_,
97
- kernel_size=(3, 3),
98
- strides=(1, 1),
99
- padding=((1, 1), (1, 1)),
100
- dtype=self.dtype,
101
- )
102
-
103
- if self.in_channels != self.out_channels_:
104
- if self.use_conv_shortcut:
105
- self.conv_shortcut = nn.Conv(
106
- self.out_channels_,
107
- kernel_size=(3, 3),
108
- strides=(1, 1),
109
- padding=((1, 1), (1, 1)),
110
- dtype=self.dtype,
111
- )
112
- else:
113
- self.nin_shortcut = nn.Conv(
114
- self.out_channels_,
115
- kernel_size=(1, 1),
116
- strides=(1, 1),
117
- padding="VALID",
118
- dtype=self.dtype,
119
- )
120
-
121
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
122
- residual = hidden_states
123
- hidden_states = self.norm1(hidden_states)
124
- hidden_states = nn.swish(hidden_states)
125
- hidden_states = self.conv1(hidden_states)
126
-
127
- if temb is not None:
128
- hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None] # TODO: check shapes
129
-
130
- hidden_states = self.norm2(hidden_states)
131
- hidden_states = nn.swish(hidden_states)
132
- hidden_states = self.dropout(hidden_states, deterministic)
133
- hidden_states = self.conv2(hidden_states)
134
-
135
- if self.in_channels != self.out_channels_:
136
- if self.use_conv_shortcut:
137
- residual = self.conv_shortcut(residual)
138
- else:
139
- residual = self.nin_shortcut(residual)
140
-
141
- return hidden_states + residual
142
-
143
-
144
- class AttnBlock(nn.Module):
145
- in_channels: int
146
- dtype: jnp.dtype = jnp.float32
147
-
148
- def setup(self):
149
- conv = partial(
150
- nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
151
- )
152
-
153
- self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
154
- self.q, self.k, self.v = conv(), conv(), conv()
155
- self.proj_out = conv()
156
-
157
- def __call__(self, hidden_states):
158
- residual = hidden_states
159
- hidden_states = self.norm(hidden_states)
160
-
161
- query = self.q(hidden_states)
162
- key = self.k(hidden_states)
163
- value = self.v(hidden_states)
164
-
165
- # compute attentions
166
- batch, height, width, channels = query.shape
167
- query = query.reshape((batch, height * width, channels))
168
- key = key.reshape((batch, height * width, channels))
169
- attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
170
- attn_weights = attn_weights * (int(channels) ** -0.5)
171
- attn_weights = nn.softmax(attn_weights, axis=2)
172
-
173
- ## attend to values
174
- value = value.reshape((batch, height * width, channels))
175
- hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
176
- hidden_states = hidden_states.reshape((batch, height, width, channels))
177
-
178
- hidden_states = self.proj_out(hidden_states)
179
- hidden_states = hidden_states + residual
180
- return hidden_states
181
-
182
-
183
- class UpsamplingBlock(nn.Module):
184
- config: VQGANConfig
185
- curr_res: int
186
- block_idx: int
187
- dtype: jnp.dtype = jnp.float32
188
-
189
- def setup(self):
190
- if self.block_idx == self.config.num_resolutions - 1:
191
- block_in = self.config.ch * self.config.ch_mult[-1]
192
- else:
193
- block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
194
-
195
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
196
- self.temb_ch = 0
197
-
198
- res_blocks = []
199
- attn_blocks = []
200
- for _ in range(self.config.num_res_blocks + 1):
201
- res_blocks.append(
202
- ResnetBlock(
203
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
204
- )
205
- )
206
- block_in = block_out
207
- if self.curr_res in self.config.attn_resolutions:
208
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
209
-
210
- self.block = res_blocks
211
- self.attn = attn_blocks
212
-
213
- self.upsample = None
214
- if self.block_idx != 0:
215
- self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
216
-
217
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
218
- for res_block in self.block:
219
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
220
- for attn_block in self.attn:
221
- hidden_states = attn_block(hidden_states)
222
-
223
- if self.upsample is not None:
224
- hidden_states = self.upsample(hidden_states)
225
-
226
- return hidden_states
227
-
228
-
229
- class DownsamplingBlock(nn.Module):
230
- config: VQGANConfig
231
- curr_res: int
232
- block_idx: int
233
- dtype: jnp.dtype = jnp.float32
234
-
235
- def setup(self):
236
- in_ch_mult = (1,) + tuple(self.config.ch_mult)
237
- block_in = self.config.ch * in_ch_mult[self.block_idx]
238
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
239
- self.temb_ch = 0
240
-
241
- res_blocks = []
242
- attn_blocks = []
243
- for _ in range(self.config.num_res_blocks):
244
- res_blocks.append(
245
- ResnetBlock(
246
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
247
- )
248
- )
249
- block_in = block_out
250
- if self.curr_res in self.config.attn_resolutions:
251
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
252
-
253
- self.block = res_blocks
254
- self.attn = attn_blocks
255
-
256
- self.downsample = None
257
- if self.block_idx != self.config.num_resolutions - 1:
258
- self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
259
-
260
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
261
- for res_block in self.block:
262
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
263
- for attn_block in self.attn:
264
- hidden_states = attn_block(hidden_states)
265
-
266
- if self.downsample is not None:
267
- hidden_states = self.downsample(hidden_states)
268
-
269
- return hidden_states
270
-
271
-
272
- class MidBlock(nn.Module):
273
- in_channels: int
274
- temb_channels: int
275
- dropout: float
276
- dtype: jnp.dtype = jnp.float32
277
-
278
- def setup(self):
279
- self.block_1 = ResnetBlock(
280
- self.in_channels,
281
- self.in_channels,
282
- temb_channels=self.temb_channels,
283
- dropout_prob=self.dropout,
284
- dtype=self.dtype,
285
- )
286
- self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
287
- self.block_2 = ResnetBlock(
288
- self.in_channels,
289
- self.in_channels,
290
- temb_channels=self.temb_channels,
291
- dropout_prob=self.dropout,
292
- dtype=self.dtype,
293
- )
294
-
295
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
296
- hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
297
- hidden_states = self.attn_1(hidden_states)
298
- hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
299
- return hidden_states
300
-
301
-
302
- class Encoder(nn.Module):
303
- config: VQGANConfig
304
- dtype: jnp.dtype = jnp.float32
305
-
306
- def setup(self):
307
- self.temb_ch = 0
308
-
309
- # downsampling
310
- self.conv_in = nn.Conv(
311
- self.config.ch,
312
- kernel_size=(3, 3),
313
- strides=(1, 1),
314
- padding=((1, 1), (1, 1)),
315
- dtype=self.dtype,
316
- )
317
-
318
- curr_res = self.config.resolution
319
- downsample_blocks = []
320
- for i_level in range(self.config.num_resolutions):
321
- downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
322
-
323
- if i_level != self.config.num_resolutions - 1:
324
- curr_res = curr_res // 2
325
- self.down = downsample_blocks
326
-
327
- # middle
328
- mid_channels = self.config.ch * self.config.ch_mult[-1]
329
- self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
330
-
331
- # end
332
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
333
- self.conv_out = nn.Conv(
334
- 2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
335
- kernel_size=(3, 3),
336
- strides=(1, 1),
337
- padding=((1, 1), (1, 1)),
338
- dtype=self.dtype,
339
- )
340
-
341
- def __call__(self, pixel_values, deterministic: bool = True):
342
- # timestep embedding
343
- temb = None
344
-
345
- # downsampling
346
- hidden_states = self.conv_in(pixel_values)
347
- for block in self.down:
348
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
349
-
350
- # middle
351
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
352
-
353
- # end
354
- hidden_states = self.norm_out(hidden_states)
355
- hidden_states = nn.swish(hidden_states)
356
- hidden_states = self.conv_out(hidden_states)
357
-
358
- return hidden_states
359
-
360
-
361
- class Decoder(nn.Module):
362
- config: VQGANConfig
363
- dtype: jnp.dtype = jnp.float32
364
-
365
- def setup(self):
366
- self.temb_ch = 0
367
-
368
- # compute in_ch_mult, block_in and curr_res at lowest res
369
- block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
370
- curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
371
- self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
372
- print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
373
-
374
- # z to block_in
375
- self.conv_in = nn.Conv(
376
- block_in,
377
- kernel_size=(3, 3),
378
- strides=(1, 1),
379
- padding=((1, 1), (1, 1)),
380
- dtype=self.dtype,
381
- )
382
-
383
- # middle
384
- self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
385
-
386
- # upsampling
387
- upsample_blocks = []
388
- for i_level in reversed(range(self.config.num_resolutions)):
389
- upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
390
- if i_level != 0:
391
- curr_res = curr_res * 2
392
- self.up = list(reversed(upsample_blocks)) # reverse to get consistent order
393
-
394
- # end
395
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
396
- self.conv_out = nn.Conv(
397
- self.config.out_ch,
398
- kernel_size=(3, 3),
399
- strides=(1, 1),
400
- padding=((1, 1), (1, 1)),
401
- dtype=self.dtype,
402
- )
403
-
404
- def __call__(self, hidden_states, deterministic: bool = True):
405
- # timestep embedding
406
- temb = None
407
-
408
- # z to block_in
409
- hidden_states = self.conv_in(hidden_states)
410
-
411
- # middle
412
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
413
-
414
- # upsampling
415
- for block in reversed(self.up):
416
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
417
-
418
- # end
419
- if self.config.give_pre_end:
420
- return hidden_states
421
-
422
- hidden_states = self.norm_out(hidden_states)
423
- hidden_states = nn.swish(hidden_states)
424
- hidden_states = self.conv_out(hidden_states)
425
-
426
- return hidden_states
427
-
428
-
429
- class VectorQuantizer(nn.Module):
430
- """
431
- see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
432
- ____________________________________________
433
- Discretization bottleneck part of the VQ-VAE.
434
- Inputs:
435
- - n_e : number of embeddings
436
- - e_dim : dimension of embedding
437
- - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
438
- _____________________________________________
439
- """
440
-
441
- config: VQGANConfig
442
- dtype: jnp.dtype = jnp.float32
443
-
444
- def setup(self):
445
- self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype) # TODO: init
446
-
447
- def __call__(self, hidden_states):
448
- """
449
- Inputs the output of the encoder network z and maps it to a discrete
450
- one-hot vector that is the index of the closest embedding vector e_j
451
- z (continuous) -> z_q (discrete)
452
- z.shape = (batch, channel, height, width)
453
- quantization pipeline:
454
- 1. get encoder input (B,C,H,W)
455
- 2. flatten input to (B*H*W,C)
456
- """
457
- # flatten
458
- hidden_states_flattended = hidden_states.reshape((-1, self.config.embed_dim))
459
-
460
- # dummy op to init the weights, so we can access them below
461
- self.embedding(jnp.ones((1, 1), dtype="i4"))
462
-
463
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
464
- emb_weights = self.variables["params"]["embedding"]["embedding"]
465
- distance = (
466
- jnp.sum(hidden_states_flattended ** 2, axis=1, keepdims=True)
467
- + jnp.sum(emb_weights ** 2, axis=1)
468
- - 2 * jnp.dot(hidden_states_flattended, emb_weights.T)
469
- )
470
-
471
- # get quantized latent vectors
472
- min_encoding_indices = jnp.argmin(distance, axis=1)
473
- z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
474
-
475
- # reshape to (batch, num_tokens)
476
- min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
477
-
478
- # compute the codebook_loss (q_loss) outside the model
479
- # here we return the embeddings and indices
480
- return z_q, min_encoding_indices
481
-
482
- def get_codebook_entry(self, indices, shape=None):
483
- # indices are expected to be of shape (batch, num_tokens)
484
- # get quantized latent vectors
485
- batch, num_tokens = indices.shape
486
- z_q = self.embedding(indices)
487
- z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
488
- return z_q
489
-
490
-
491
- class VQModule(nn.Module):
492
- config: VQGANConfig
493
- dtype: jnp.dtype = jnp.float32
494
-
495
- def setup(self):
496
- self.encoder = Encoder(self.config, dtype=self.dtype)
497
- self.decoder = Decoder(self.config, dtype=self.dtype)
498
- self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
499
- self.quant_conv = nn.Conv(
500
- self.config.embed_dim,
501
- kernel_size=(1, 1),
502
- strides=(1, 1),
503
- padding="VALID",
504
- dtype=self.dtype,
505
- )
506
- self.post_quant_conv = nn.Conv(
507
- self.config.z_channels,
508
- kernel_size=(1, 1),
509
- strides=(1, 1),
510
- padding="VALID",
511
- dtype=self.dtype,
512
- )
513
-
514
- def encode(self, pixel_values, deterministic: bool = True):
515
- hidden_states = self.encoder(pixel_values, deterministic=deterministic)
516
- hidden_states = self.quant_conv(hidden_states)
517
- quant_states, indices = self.quantize(hidden_states)
518
- return quant_states, indices
519
-
520
- def decode(self, hidden_states, deterministic: bool = True):
521
- hidden_states = self.post_quant_conv(hidden_states)
522
- hidden_states = self.decoder(hidden_states, deterministic=deterministic)
523
- return hidden_states
524
-
525
- def decode_code(self, code_b):
526
- hidden_states = self.quantize.get_codebook_entry(code_b)
527
- hidden_states = self.decode(hidden_states)
528
- return hidden_states
529
-
530
- def __call__(self, pixel_values, deterministic: bool = True):
531
- quant_states, indices = self.encode(pixel_values, deterministic)
532
- hidden_states = self.decode(quant_states, deterministic)
533
- return hidden_states, indices
534
-
535
-
536
- class VQGANPreTrainedModel(FlaxPreTrainedModel):
537
- """
538
- An abstract class to handle weights initialization and a simple interface
539
- for downloading and loading pretrained models.
540
- """
541
-
542
- config_class = VQGANConfig
543
- base_model_prefix = "model"
544
- module_class: nn.Module = None
545
-
546
- def __init__(
547
- self,
548
- config: VQGANConfig,
549
- input_shape: Tuple = (1, 256, 256, 3),
550
- seed: int = 0,
551
- dtype: jnp.dtype = jnp.float32,
552
- **kwargs,
553
- ):
554
- module = self.module_class(config=config, dtype=dtype, **kwargs)
555
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
556
-
557
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
558
- # init input tensors
559
- pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
560
- params_rng, dropout_rng = jax.random.split(rng)
561
- rngs = {"params": params_rng, "dropout": dropout_rng}
562
-
563
- return self.module.init(rngs, pixel_values)["params"]
564
-
565
- def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
566
- # Handle any PRNG if needed
567
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
568
-
569
- return self.module.apply(
570
- {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
571
- )
572
-
573
- def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
574
- # Handle any PRNG if needed
575
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
576
-
577
- return self.module.apply(
578
- {"params": params or self.params},
579
- jnp.array(hidden_states),
580
- not train,
581
- rngs=rngs,
582
- method=self.module.decode,
583
- )
584
-
585
- def decode_code(self, indices, params: dict = None):
586
- return self.module.apply(
587
- {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
588
- )
589
-
590
- def __call__(
591
- self,
592
- pixel_values,
593
- params: dict = None,
594
- dropout_rng: jax.random.PRNGKey = None,
595
- train: bool = False,
596
- ):
597
- # Handle any PRNG if needed
598
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
599
-
600
- return self.module.apply(
601
- {"params": params or self.params},
602
- jnp.array(pixel_values),
603
- not train,
604
- rngs=rngs,
605
- )
606
-
607
-
608
- class VQModel(VQGANPreTrainedModel):
609
- module_class = VQModule