boris committed
Commit
9c0e5c9
2 Parent(s): 3f0364c 86ba774

Merge pull request #8 from pcuenca/main

.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
dalle_mini/__init__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.0.1"
dalle_mini/dataset.py ADDED
@@ -0,0 +1,122 @@
+ """
+ An image-caption dataset dataloader.
+ Luke Melas-Kyriazi, 2021
+ """
+ import warnings
+ from typing import Optional, Callable
+ from pathlib import Path
+ import numpy as np
+ import torch
+ import pandas as pd
+ from torch.utils.data import Dataset
+ from torchvision.datasets.folder import default_loader
+ from PIL import ImageFile
+ from PIL.Image import DecompressionBombWarning
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=DecompressionBombWarning)
+
+
+ class CaptionDataset(Dataset):
+     """
+     A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
+     returns the raw text rather than tokens. This is done on purpose, because
+     it's easy to tokenize a batch of text after loading it from this dataset.
+     """
+
+     def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
+                  image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
+                  include_captions: bool = True):
+         """
+         :param images_root: folder where images are stored
+         :param captions_path: path to csv that maps image filenames to captions
+         :param image_transform: image transform pipeline
+         :param text_transform: text transform pipeline
+         :param image_transform_type: image transform type, either `torchvision` or `albumentations`
+         :param include_captions: Returns a dictionary with `image`, `text` if `True`; otherwise returns just the images.
+         """
+
+         # Base path for images
+         self.images_root = Path(images_root)
+
+         # Load captions as DataFrame
+         self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
+         self.captions['image_file'] = self.captions['image_file'].astype(str)
+
+         # PyTorch transformation pipeline for the image (normalizing, etc.)
+         self.text_transform = text_transform
+         self.image_transform = image_transform
+         self.image_transform_type = image_transform_type.lower()
+         assert self.image_transform_type in ['torchvision', 'albumentations']
+
+         # Total number of datapoints
+         self.size = len(self.captions)
+
+         # Return image+captions or just images
+         self.include_captions = include_captions
+
+     def verify_that_all_images_exist(self):
+         for image_file in self.captions['image_file']:
+             p = self.images_root / image_file
+             if not p.is_file():
+                 print(f'file does not exist: {p}')
+
+     def _get_raw_image(self, i):
+         image_file = self.captions.iloc[i]['image_file']
+         image_path = self.images_root / image_file
+         image = default_loader(image_path)
+         return image
+
+     def _get_raw_text(self, i):
+         return self.captions.iloc[i]['caption']
+
+     def __getitem__(self, i):
+         image = self._get_raw_image(i)
+         caption = self._get_raw_text(i)
+         if self.image_transform is not None:
+             if self.image_transform_type == 'torchvision':
+                 image = self.image_transform(image)
+             elif self.image_transform_type == 'albumentations':
+                 image = self.image_transform(image=np.array(image))['image']
+             else:
+                 raise NotImplementedError(f"{self.image_transform_type=}")
+         return {'image': image, 'text': caption} if self.include_captions else image
+
+     def __len__(self):
+         return self.size
+
+
+ if __name__ == "__main__":
+     import albumentations as A
+     from albumentations.pytorch import ToTensorV2
+     from transformers import AutoTokenizer
+
+     # Paths
+     images_root = './images'
+     captions_path = './images-list-clean.tsv'
+
+     # Create transforms
+     tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
+     def tokenize(text):
+         return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
+     image_transform = A.Compose([
+         A.Resize(256, 256), A.CenterCrop(256, 256),
+         A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
+
+     # Create dataset
+     dataset = CaptionDataset(
+         images_root=images_root,
+         captions_path=captions_path,
+         image_transform=image_transform,
+         text_transform=tokenize,
+         image_transform_type='albumentations')
+
+     # Create dataloader
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
+     batch = next(iter(dataloader))
+     print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
+
+     # # (Optional) Check that all the images exist
+     # dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
+     # dataset.verify_that_all_images_exist()
+     # print('Done')
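Since CaptionDataset deliberately returns raw caption strings, tokenization happens after batching. A minimal usage sketch, not part of this commit (the paths and the distilroberta-base tokenizer are placeholders), showing the torchvision code path plus post-hoc tokenization:

    import torch
    import torchvision.transforms as T
    from transformers import AutoTokenizer

    from dalle_mini.dataset import CaptionDataset

    # Simple torchvision pipeline; any transform returning a tensor works here.
    image_transform = T.Compose([T.Resize(256), T.CenterCrop(256), T.ToTensor()])

    dataset = CaptionDataset(
        images_root='./images',                   # placeholder image folder
        captions_path='./images-list-clean.tsv',  # placeholder tsv with image_file/caption columns
        image_transform=image_transform,
        image_transform_type='torchvision',
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=4)
    batch = next(iter(loader))                    # {'image': tensor, 'text': list of raw captions}

    # Tokenize the raw captions after loading, as the class docstring suggests.
    tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
    tokens = tokenizer(list(batch['text']), padding=True, truncation=True, return_tensors='pt')
    print(batch['image'].shape, tokens['input_ids'].shape)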
dalle_mini/vqgan_jax/__init__.py ADDED
File without changes
dalle_mini/vqgan_jax/configuration_vqgan.py ADDED
@@ -0,0 +1,40 @@
+ from typing import Tuple
+
+ from transformers import PretrainedConfig
+
+
+ class VQGANConfig(PretrainedConfig):
+     def __init__(
+         self,
+         ch: int = 128,
+         out_ch: int = 3,
+         in_channels: int = 3,
+         num_res_blocks: int = 2,
+         resolution: int = 256,
+         z_channels: int = 256,
+         ch_mult: Tuple = (1, 1, 2, 2, 4),
+         attn_resolutions: Tuple = (16,),
+         n_embed: int = 1024,
+         embed_dim: int = 256,
+         dropout: float = 0.0,
+         double_z: bool = False,
+         resamp_with_conv: bool = True,
+         give_pre_end: bool = False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.ch = ch
+         self.out_ch = out_ch
+         self.in_channels = in_channels
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.z_channels = z_channels
+         self.ch_mult = list(ch_mult)
+         self.attn_resolutions = list(attn_resolutions)
+         self.n_embed = n_embed
+         self.embed_dim = embed_dim
+         self.dropout = dropout
+         self.double_z = double_z
+         self.resamp_with_conv = resamp_with_conv
+         self.give_pre_end = give_pre_end
+         self.num_resolutions = len(ch_mult)
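VQGANConfig inherits the usual PretrainedConfig serialization, so it can be written out and reloaded like any transformers config. A small sketch, assuming the default values above and a placeholder output directory:

    from dalle_mini.vqgan_jax.configuration_vqgan import VQGANConfig

    config = VQGANConfig()                      # defaults: ch=128, ch_mult=(1, 1, 2, 2, 4), n_embed=1024, ...
    print(config.num_resolutions)               # 5, derived from len(ch_mult)

    config.save_pretrained("./vqgan-config")    # writes config.json to a placeholder directory
    reloaded = VQGANConfig.from_pretrained("./vqgan-config")
    assert reloaded.embed_dim == config.embed_dim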
dalle_mini/vqgan_jax/convert_pt_model_to_jax.py ADDED
@@ -0,0 +1,109 @@
+ import re
+
+ import jax.numpy as jnp
+ from flax.traverse_util import flatten_dict, unflatten_dict
+
+ import torch
+
+ from modeling_flax_vqgan import VQModel
+ from configuration_vqgan import VQGANConfig
+
+
+ regex = r"\w+[.]\d+"
+
+
+ def rename_key(key):
+     pats = re.findall(regex, key)
+     for pat in pats:
+         key = key.replace(pat, "_".join(pat.split(".")))
+     return key
+
+
+ # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
+ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
+     # convert pytorch tensor to numpy
+     pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+
+     random_flax_state_dict = flatten_dict(flax_model.params)
+     flax_state_dict = {}
+
+     remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
+         flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
+     )
+     add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
+         flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
+     )
+
+     # Need to change some parameters name to match Flax names so that we don't have to fork any layer
+     for pt_key, pt_tensor in pt_state_dict.items():
+         pt_tuple_key = tuple(pt_key.split("."))
+
+         has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
+         require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
+
+         if remove_base_model_prefix and has_base_model_prefix:
+             pt_tuple_key = pt_tuple_key[1:]
+         elif add_base_model_prefix and require_base_model_prefix:
+             pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
+
+         # Correctly rename weight parameters
+         if (
+             "norm" in pt_key
+             and (pt_tuple_key[-1] == "bias")
+             and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
+         ):
+             pt_tensor = pt_tensor[None, None, None, :]
+         elif (
+             "norm" in pt_key
+             and (pt_tuple_key[-1] == "bias")
+             and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
+         ):
+             pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+             pt_tensor = pt_tensor[None, None, None, :]
+         elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
+             pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+             pt_tensor = pt_tensor[None, None, None, :]
+         if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
+             pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
+         elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
+             # conv layer
+             pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+             pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
+         elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
+             # linear layer
+             pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+             pt_tensor = pt_tensor.T
+         elif pt_tuple_key[-1] == "gamma":
+             pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
+         elif pt_tuple_key[-1] == "beta":
+             pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
+
+         if pt_tuple_key in random_flax_state_dict:
+             if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
+                 raise ValueError(
+                     f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                     f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
+                 )
+
+         # also add unexpected weight so that warning is thrown
+         flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
+
+     return unflatten_dict(flax_state_dict)
+
+
+ def convert_model(config_path, pt_state_dict_path, save_path):
+     config = VQGANConfig.from_pretrained(config_path)
+     model = VQModel(config)
+
+     state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
+     keys = list(state_dict.keys())
+     for key in keys:
+         if key.startswith("loss"):
+             state_dict.pop(key)
+             continue
+         renamed_key = rename_key(key)
+         state_dict[renamed_key] = state_dict.pop(key)
+
+     state = convert_pytorch_state_dict_to_flax(state_dict, model)
+     model.params = unflatten_dict(state)
+     model.save_pretrained(save_path)
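The script has no command-line interface; convert_model is called directly from Python (it also imports modeling_flax_vqgan without a package prefix, so it is meant to run from inside this directory). A hedged invocation sketch with placeholder paths:

    from convert_pt_model_to_jax import convert_model

    convert_model(
        config_path="./vqgan_config",      # anything VQGANConfig.from_pretrained accepts (placeholder)
        pt_state_dict_path="./last.ckpt",  # Lightning checkpoint containing a "state_dict" entry (placeholder)
        save_path="./vqgan-jax",           # output directory for the converted Flax weights (placeholder)
    )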
dalle_mini/vqgan_jax/modeling_flax_vqgan.py ADDED
@@ -0,0 +1,609 @@
+ # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
+
+ from functools import partial
+ from typing import Tuple
+ import math
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import flax.linen as nn
+ from flax.core.frozen_dict import FrozenDict
+
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
+
+ from .configuration_vqgan import VQGANConfig
+
+
+ class Upsample(nn.Module):
+     in_channels: int
+     with_conv: bool
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         if self.with_conv:
+             self.conv = nn.Conv(
+                 self.in_channels,
+                 kernel_size=(3, 3),
+                 strides=(1, 1),
+                 padding=((1, 1), (1, 1)),
+                 dtype=self.dtype,
+             )
+
+     def __call__(self, hidden_states):
+         batch, height, width, channels = hidden_states.shape
+         hidden_states = jax.image.resize(
+             hidden_states,
+             shape=(batch, height * 2, width * 2, channels),
+             method="nearest",
+         )
+         if self.with_conv:
+             hidden_states = self.conv(hidden_states)
+         return hidden_states
+
+
+ class Downsample(nn.Module):
+     in_channels: int
+     with_conv: bool
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         if self.with_conv:
+             self.conv = nn.Conv(
+                 self.in_channels,
+                 kernel_size=(3, 3),
+                 strides=(2, 2),
+                 padding="VALID",
+                 dtype=self.dtype,
+             )
+
+     def __call__(self, hidden_states):
+         if self.with_conv:
+             pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
+             hidden_states = jnp.pad(hidden_states, pad_width=pad)
+             hidden_states = self.conv(hidden_states)
+         else:
+             hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
+         return hidden_states
+
+
+ class ResnetBlock(nn.Module):
+     in_channels: int
+     out_channels: int = None
+     use_conv_shortcut: bool = False
+     temb_channels: int = 512
+     dropout_prob: float = 0.0
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
+
+         self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+         self.conv1 = nn.Conv(
+             self.out_channels_,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             padding=((1, 1), (1, 1)),
+             dtype=self.dtype,
+         )
+
+         if self.temb_channels:
+             self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
+
+         self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+         self.dropout = nn.Dropout(self.dropout_prob)
+         self.conv2 = nn.Conv(
+             self.out_channels_,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             padding=((1, 1), (1, 1)),
+             dtype=self.dtype,
+         )
+
+         if self.in_channels != self.out_channels_:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = nn.Conv(
+                     self.out_channels_,
+                     kernel_size=(3, 3),
+                     strides=(1, 1),
+                     padding=((1, 1), (1, 1)),
+                     dtype=self.dtype,
+                 )
+             else:
+                 self.nin_shortcut = nn.Conv(
+                     self.out_channels_,
+                     kernel_size=(1, 1),
+                     strides=(1, 1),
+                     padding="VALID",
+                     dtype=self.dtype,
+                 )
+
+     def __call__(self, hidden_states, temb=None, deterministic: bool = True):
+         residual = hidden_states
+         hidden_states = self.norm1(hidden_states)
+         hidden_states = nn.swish(hidden_states)
+         hidden_states = self.conv1(hidden_states)
+
+         if temb is not None:
+             hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None]  # TODO: check shapes
+
+         hidden_states = self.norm2(hidden_states)
+         hidden_states = nn.swish(hidden_states)
+         hidden_states = self.dropout(hidden_states, deterministic)
+         hidden_states = self.conv2(hidden_states)
+
+         if self.in_channels != self.out_channels_:
+             if self.use_conv_shortcut:
+                 residual = self.conv_shortcut(residual)
+             else:
+                 residual = self.nin_shortcut(residual)
+
+         return hidden_states + residual
+
+
+ class AttnBlock(nn.Module):
+     in_channels: int
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         conv = partial(
+             nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
+         )
+
+         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+         self.q, self.k, self.v = conv(), conv(), conv()
+         self.proj_out = conv()
+
+     def __call__(self, hidden_states):
+         residual = hidden_states
+         hidden_states = self.norm(hidden_states)
+
+         query = self.q(hidden_states)
+         key = self.k(hidden_states)
+         value = self.v(hidden_states)
+
+         # compute attentions
+         batch, height, width, channels = query.shape
+         query = query.reshape((batch, height * width, channels))
+         key = key.reshape((batch, height * width, channels))
+         attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
+         attn_weights = attn_weights * (int(channels) ** -0.5)
+         attn_weights = nn.softmax(attn_weights, axis=2)
+
+         ## attend to values
+         value = value.reshape((batch, height * width, channels))
+         hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
+         hidden_states = hidden_states.reshape((batch, height, width, channels))
+
+         hidden_states = self.proj_out(hidden_states)
+         hidden_states = hidden_states + residual
+         return hidden_states
+
+
+ class UpsamplingBlock(nn.Module):
+     config: VQGANConfig
+     curr_res: int
+     block_idx: int
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         if self.block_idx == self.config.num_resolutions - 1:
+             block_in = self.config.ch * self.config.ch_mult[-1]
+         else:
+             block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
+
+         block_out = self.config.ch * self.config.ch_mult[self.block_idx]
+         self.temb_ch = 0
+
+         res_blocks = []
+         attn_blocks = []
+         for _ in range(self.config.num_res_blocks + 1):
+             res_blocks.append(
+                 ResnetBlock(
+                     block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
+                 )
+             )
+             block_in = block_out
+             if self.curr_res in self.config.attn_resolutions:
+                 attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
+
+         self.block = res_blocks
+         self.attn = attn_blocks
+
+         self.upsample = None
+         if self.block_idx != 0:
+             self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
+
+     def __call__(self, hidden_states, temb=None, deterministic: bool = True):
+         for res_block in self.block:
+             hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
+             for attn_block in self.attn:
+                 hidden_states = attn_block(hidden_states)
+
+         if self.upsample is not None:
+             hidden_states = self.upsample(hidden_states)
+
+         return hidden_states
+
+
+ class DownsamplingBlock(nn.Module):
+     config: VQGANConfig
+     curr_res: int
+     block_idx: int
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         in_ch_mult = (1,) + tuple(self.config.ch_mult)
+         block_in = self.config.ch * in_ch_mult[self.block_idx]
+         block_out = self.config.ch * self.config.ch_mult[self.block_idx]
+         self.temb_ch = 0
+
+         res_blocks = []
+         attn_blocks = []
+         for _ in range(self.config.num_res_blocks):
+             res_blocks.append(
+                 ResnetBlock(
+                     block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
+                 )
+             )
+             block_in = block_out
+             if self.curr_res in self.config.attn_resolutions:
+                 attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
+
+         self.block = res_blocks
+         self.attn = attn_blocks
+
+         self.downsample = None
+         if self.block_idx != self.config.num_resolutions - 1:
+             self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
+
+     def __call__(self, hidden_states, temb=None, deterministic: bool = True):
+         for res_block in self.block:
+             hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
+             for attn_block in self.attn:
+                 hidden_states = attn_block(hidden_states)
+
+         if self.downsample is not None:
+             hidden_states = self.downsample(hidden_states)
+
+         return hidden_states
+
+
+ class MidBlock(nn.Module):
+     in_channels: int
+     temb_channels: int
+     dropout: float
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         self.block_1 = ResnetBlock(
+             self.in_channels,
+             self.in_channels,
+             temb_channels=self.temb_channels,
+             dropout_prob=self.dropout,
+             dtype=self.dtype,
+         )
+         self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
+         self.block_2 = ResnetBlock(
+             self.in_channels,
+             self.in_channels,
+             temb_channels=self.temb_channels,
+             dropout_prob=self.dropout,
+             dtype=self.dtype,
+         )
+
+     def __call__(self, hidden_states, temb=None, deterministic: bool = True):
+         hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
+         hidden_states = self.attn_1(hidden_states)
+         hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
+         return hidden_states
+
+
+ class Encoder(nn.Module):
+     config: VQGANConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         self.temb_ch = 0
+
+         # downsampling
+         self.conv_in = nn.Conv(
+             self.config.ch,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             padding=((1, 1), (1, 1)),
+             dtype=self.dtype,
+         )
+
+         curr_res = self.config.resolution
+         downsample_blocks = []
+         for i_level in range(self.config.num_resolutions):
+             downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
+
+             if i_level != self.config.num_resolutions - 1:
+                 curr_res = curr_res // 2
+         self.down = downsample_blocks
+
+         # middle
+         mid_channels = self.config.ch * self.config.ch_mult[-1]
+         self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
+
+         # end
+         self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+         self.conv_out = nn.Conv(
+             2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             padding=((1, 1), (1, 1)),
+             dtype=self.dtype,
+         )
+
+     def __call__(self, pixel_values, deterministic: bool = True):
+         # timestep embedding
+         temb = None
+
+         # downsampling
+         hidden_states = self.conv_in(pixel_values)
+         for block in self.down:
+             hidden_states = block(hidden_states, temb, deterministic=deterministic)
+
+         # middle
+         hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
+
+         # end
+         hidden_states = self.norm_out(hidden_states)
+         hidden_states = nn.swish(hidden_states)
+         hidden_states = self.conv_out(hidden_states)
+
+         return hidden_states
+
+
+ class Decoder(nn.Module):
+     config: VQGANConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         self.temb_ch = 0
+
+         # compute in_ch_mult, block_in and curr_res at lowest res
+         block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
+         curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
+         self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
+         print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+         # z to block_in
+         self.conv_in = nn.Conv(
+             block_in,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             padding=((1, 1), (1, 1)),
+             dtype=self.dtype,
+         )
+
+         # middle
+         self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
+
+         # upsampling
+         upsample_blocks = []
+         for i_level in reversed(range(self.config.num_resolutions)):
+             upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
+             if i_level != 0:
+                 curr_res = curr_res * 2
+         self.up = list(reversed(upsample_blocks))  # reverse to get consistent order
+
+         # end
+         self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
+         self.conv_out = nn.Conv(
+             self.config.out_ch,
+             kernel_size=(3, 3),
+             strides=(1, 1),
+             padding=((1, 1), (1, 1)),
+             dtype=self.dtype,
+         )
+
+     def __call__(self, hidden_states, deterministic: bool = True):
+         # timestep embedding
+         temb = None
+
+         # z to block_in
+         hidden_states = self.conv_in(hidden_states)
+
+         # middle
+         hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
+
+         # upsampling
+         for block in reversed(self.up):
+             hidden_states = block(hidden_states, temb, deterministic=deterministic)
+
+         # end
+         if self.config.give_pre_end:
+             return hidden_states
+
+         hidden_states = self.norm_out(hidden_states)
+         hidden_states = nn.swish(hidden_states)
+         hidden_states = self.conv_out(hidden_states)
+
+         return hidden_states
+
+
+ class VectorQuantizer(nn.Module):
+     """
+     see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+     ____________________________________________
+     Discretization bottleneck part of the VQ-VAE.
+     Inputs:
+     - n_e : number of embeddings
+     - e_dim : dimension of embedding
+     - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+     _____________________________________________
+     """
+
+     config: VQGANConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype)  # TODO: init
+
+     def __call__(self, hidden_states):
+         """
+         Inputs the output of the encoder network z and maps it to a discrete
+         one-hot vector that is the index of the closest embedding vector e_j
+         z (continuous) -> z_q (discrete)
+         z.shape = (batch, channel, height, width)
+         quantization pipeline:
+             1. get encoder input (B,C,H,W)
+             2. flatten input to (B*H*W,C)
+         """
+         # flatten
+         hidden_states_flattended = hidden_states.reshape((-1, self.config.embed_dim))
+
+         # dummy op to init the weights, so we can access them below
+         self.embedding(jnp.ones((1, 1), dtype="i4"))
+
+         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+         emb_weights = self.variables["params"]["embedding"]["embedding"]
+         distance = (
+             jnp.sum(hidden_states_flattended ** 2, axis=1, keepdims=True)
+             + jnp.sum(emb_weights ** 2, axis=1)
+             - 2 * jnp.dot(hidden_states_flattended, emb_weights.T)
+         )
+
+         # get quantized latent vectors
+         min_encoding_indices = jnp.argmin(distance, axis=1)
+         z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
+
+         # reshape to (batch, num_tokens)
+         min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
+
+         # compute the codebook_loss (q_loss) outside the model
+         # here we return the embeddings and indices
+         return z_q, min_encoding_indices
+
+     def get_codebook_entry(self, indices, shape=None):
+         # indices are expected to be of shape (batch, num_tokens)
+         # get quantized latent vectors
+         batch, num_tokens = indices.shape
+         z_q = self.embedding(indices)
+         z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
+         return z_q
+
+
+ class VQModule(nn.Module):
+     config: VQGANConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         self.encoder = Encoder(self.config, dtype=self.dtype)
+         self.decoder = Decoder(self.config, dtype=self.dtype)
+         self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
+         self.quant_conv = nn.Conv(
+             self.config.embed_dim,
+             kernel_size=(1, 1),
+             strides=(1, 1),
+             padding="VALID",
+             dtype=self.dtype,
+         )
+         self.post_quant_conv = nn.Conv(
+             self.config.z_channels,
+             kernel_size=(1, 1),
+             strides=(1, 1),
+             padding="VALID",
+             dtype=self.dtype,
+         )
+
+     def encode(self, pixel_values, deterministic: bool = True):
+         hidden_states = self.encoder(pixel_values, deterministic=deterministic)
+         hidden_states = self.quant_conv(hidden_states)
+         quant_states, indices = self.quantize(hidden_states)
+         return quant_states, indices
+
+     def decode(self, hidden_states, deterministic: bool = True):
+         hidden_states = self.post_quant_conv(hidden_states)
+         hidden_states = self.decoder(hidden_states, deterministic=deterministic)
+         return hidden_states
+
+     def decode_code(self, code_b):
+         hidden_states = self.quantize.get_codebook_entry(code_b)
+         hidden_states = self.decode(hidden_states)
+         return hidden_states
+
+     def __call__(self, pixel_values, deterministic: bool = True):
+         quant_states, indices = self.encode(pixel_values, deterministic)
+         hidden_states = self.decode(quant_states, deterministic)
+         return hidden_states, indices
+
+
+ class VQGANPreTrainedModel(FlaxPreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface
+     for downloading and loading pretrained models.
+     """
+
+     config_class = VQGANConfig
+     base_model_prefix = "model"
+     module_class: nn.Module = None
+
+     def __init__(
+         self,
+         config: VQGANConfig,
+         input_shape: Tuple = (1, 256, 256, 3),
+         seed: int = 0,
+         dtype: jnp.dtype = jnp.float32,
+         **kwargs,
+     ):
+         module = self.module_class(config=config, dtype=dtype, **kwargs)
+         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+
+     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+         # init input tensors
+         pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
+         params_rng, dropout_rng = jax.random.split(rng)
+         rngs = {"params": params_rng, "dropout": dropout_rng}
+
+         return self.module.init(rngs, pixel_values)["params"]
+
+     def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
+         # Handle any PRNG if needed
+         rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+         return self.module.apply(
+             {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
+         )
+
+     def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
+         # Handle any PRNG if needed
+         rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+         return self.module.apply(
+             {"params": params or self.params},
+             jnp.array(hidden_states),
+             not train,
+             rngs=rngs,
+             method=self.module.decode,
+         )
+
+     def decode_code(self, indices, params: dict = None):
+         return self.module.apply(
+             {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
+         )
+
+     def __call__(
+         self,
+         pixel_values,
+         params: dict = None,
+         dropout_rng: jax.random.PRNGKey = None,
+         train: bool = False,
+     ):
+         # Handle any PRNG if needed
+         rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+         return self.module.apply(
+             {"params": params or self.params},
+             jnp.array(pixel_values),
+             not train,
+             rngs=rngs,
+         )
+
+
+ class VQModel(VQGANPreTrainedModel):
+     module_class = VQModule
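A minimal round-trip sketch for this model, assuming the flax-community/vqgan_f16_16384 checkpoint used in the notebooks below; inputs are channels-last (NHWC) 256x256 batches:

    import jax.numpy as jnp
    from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel

    model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

    pixel_values = jnp.zeros((1, 256, 256, 3), dtype=jnp.float32)  # NHWC, as expected by the Flax convolutions
    quant_states, indices = model.encode(pixel_values)
    print(indices.shape)         # (1, 256): a 16x16 grid of codebook indices for a 256x256 input

    reconstruction = model.decode_code(indices)
    print(reconstruction.shape)  # (1, 256, 256, 3)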
encoding/vqgan-jax-encoding-with-captions.ipynb ADDED
@@ -0,0 +1,363 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# vqgan-jax-encoding-with-captions"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "875c82b3",
14
+ "metadata": {},
15
+ "source": [
16
+ "Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n",
17
+ "\n",
18
+ "We process a `tsv` file with `image_file` and `caption` fields, and add a `vqgan_indices` column with indices extracted from a VQGAN-JAX model."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "3b59489e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import io\n",
29
+ "\n",
30
+ "import requests\n",
31
+ "from PIL import Image\n",
32
+ "import numpy as np\n",
33
+ "from tqdm import tqdm\n",
34
+ "\n",
35
+ "import torch\n",
36
+ "import torchvision.transforms as T\n",
37
+ "import torchvision.transforms.functional as TF\n",
38
+ "from torchvision.transforms import InterpolationMode\n",
39
+ "from torch.utils.data import Dataset, DataLoader\n",
40
+ "\n",
41
+ "import jax\n",
42
+ "from jax import pmap"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "markdown",
47
+ "id": "511c3b9e",
48
+ "metadata": {},
49
+ "source": [
50
+ "## VQGAN-JAX model"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "id": "bb408f6c",
56
+ "metadata": {},
57
+ "source": [
58
+ "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 2,
64
+ "id": "2ca50dc7",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "id": "7b60da9a",
74
+ "metadata": {},
75
+ "source": [
76
+ "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 3,
82
+ "id": "29ce8b15",
83
+ "metadata": {},
84
+ "outputs": [
85
+ {
86
+ "data": {
87
+ "application/vnd.jupyter.widget-view+json": {
88
+ "model_id": "db406bdfc5d5428eaeae1631a04989dd",
89
+ "version_major": 2,
90
+ "version_minor": 0
91
+ },
92
+ "text/plain": [
93
+ "Downloading: 0%| | 0.00/433 [00:00<?, ?B/s]"
94
+ ]
95
+ },
96
+ "metadata": {},
97
+ "output_type": "display_data"
98
+ },
99
+ {
100
+ "data": {
101
+ "application/vnd.jupyter.widget-view+json": {
102
+ "model_id": "3e37f07fba6d48fca70313ae1fa8cc32",
103
+ "version_major": 2,
104
+ "version_minor": 0
105
+ },
106
+ "text/plain": [
107
+ "Downloading: 0%| | 0.00/304M [00:00<?, ?B/s]"
108
+ ]
109
+ },
110
+ "metadata": {},
111
+ "output_type": "display_data"
112
+ },
113
+ {
114
+ "name": "stderr",
115
+ "output_type": "stream",
116
+ "text": [
117
+ "INFO:absl:Starting the local TPU driver.\n",
118
+ "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
119
+ "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n"
120
+ ]
121
+ },
122
+ {
123
+ "name": "stdout",
124
+ "output_type": "stream",
125
+ "text": [
126
+ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
127
+ ]
128
+ }
129
+ ],
130
+ "source": [
131
+ "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "markdown",
136
+ "id": "c7c4c1e6",
137
+ "metadata": {},
138
+ "source": [
139
+ "## Dataset"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "id": "7014a7ce",
145
+ "metadata": {},
146
+ "source": [
147
+ "We use Luke Melas-Kyriazi's `dataset.py` which reads image paths and captions from a tsv file that contains both. We only need the images for encoding."
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 4,
153
+ "id": "85832702",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "from dalle_mini.dataset import *"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 5,
163
+ "id": "81b19eca",
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "cc12m_images = '/data/CC12M/images'\n",
168
+ "cc12m_list = '/data/CC12M/images-list-clean.tsv'\n",
169
+ "# cc12m_list = '/data/CC12M/images-10000.tsv'\n",
170
+ "cc12m_output = '/data/CC12M/images-encoded.tsv'"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": 6,
176
+ "id": "fecc9a00",
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "image_size = 256\n",
181
+ "def image_transform(image):\n",
182
+ " s = min(image.size)\n",
183
+ " r = image_size / s\n",
184
+ " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
185
+ " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
186
+ " image = TF.center_crop(image, output_size = 2 * [image_size])\n",
187
+ " image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
188
+ " image = image.permute(0, 2, 3, 1).numpy()\n",
189
+ " return image"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 7,
195
+ "id": "4ce2211f",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "dataset = CaptionDataset(\n",
200
+ " images_root=cc12m_images,\n",
201
+ " captions_path=cc12m_list,\n",
202
+ " image_transform=image_transform,\n",
203
+ " image_transform_type='torchvision',\n",
204
+ " include_captions=False\n",
205
+ ")"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": 8,
211
+ "id": "cc922704",
212
+ "metadata": {},
213
+ "outputs": [
214
+ {
215
+ "data": {
216
+ "text/plain": [
217
+ "8592141"
218
+ ]
219
+ },
220
+ "execution_count": 8,
221
+ "metadata": {},
222
+ "output_type": "execute_result"
223
+ }
224
+ ],
225
+ "source": [
226
+ "len(dataset)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "id": "62ad01c3",
232
+ "metadata": {},
233
+ "source": [
234
+ "## Encoding"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": 9,
240
+ "id": "88f36d0b",
241
+ "metadata": {},
242
+ "outputs": [],
243
+ "source": [
244
+ "def encode(model, batch):\n",
245
+ "# print(\"jitting encode function\")\n",
246
+ " _, indices = model.encode(batch)\n",
247
+ " return indices"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": 10,
253
+ "id": "1f35f0cb",
254
+ "metadata": {},
255
+ "outputs": [],
256
+ "source": [
257
+ "def superbatch_generator(dataloader, num_tpus):\n",
258
+ " iter_loader = iter(dataloader)\n",
259
+ " for batch in iter_loader:\n",
260
+ " superbatch = [batch.squeeze(1)]\n",
261
+ " try:\n",
262
+ " for b in range(num_tpus-1):\n",
263
+ " batch = next(iter_loader)\n",
264
+ " if batch is None:\n",
265
+ " break\n",
266
+ " # Skip incomplete last batch\n",
267
+ " if batch.shape[0] == dataloader.batch_size:\n",
268
+ " superbatch.append(batch.squeeze(1))\n",
269
+ " except StopIteration:\n",
270
+ " pass\n",
271
+ " superbatch = torch.stack(superbatch, axis=0)\n",
272
+ " yield superbatch"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": 11,
278
+ "id": "2210705b",
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "import os\n",
283
+ "\n",
284
+ "def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
285
+ " if os.path.isfile(output_tsv):\n",
286
+ " print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
287
+ " return\n",
288
+ " \n",
289
+ " num_tpus = 8 \n",
290
+ " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
291
+ " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
292
+ " \n",
293
+ " p_encoder = pmap(lambda batch: encode(model, batch))\n",
294
+ "\n",
295
+ " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
296
+ " # We keep the file open to prevent excessive file seeks.\n",
297
+ " with open(output_tsv, \"w\") as file:\n",
298
+ " iterations = len(dataset) // (batch_size * num_tpus)\n",
299
+ " for n in tqdm(range(iterations)):\n",
300
+ " superbatch = next(superbatches)\n",
301
+ " encoded = p_encoder(superbatch.numpy())\n",
302
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
303
+ "\n",
304
+ " # Extract fields from the dataset internal `captions` property, and save to disk\n",
305
+ " start_index = n * batch_size * num_tpus\n",
306
+ " end_index = (n+1) * batch_size * num_tpus\n",
307
+ " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
308
+ " captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
309
+ " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
310
+ " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
311
+ " batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)\n",
312
+ " "
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": null,
318
+ "id": "7704863d",
319
+ "metadata": {},
320
+ "outputs": [
321
+ {
322
+ "name": "stderr",
323
+ "output_type": "stream",
324
+ "text": [
325
+ " 4%|██▋ | 621/16781 [07:09<3:02:46, 1.47it/s]"
326
+ ]
327
+ }
328
+ ],
329
+ "source": [
330
+ "encode_captioned_dataset(dataset, cc12m_output, batch_size=64, num_workers=16)"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "id": "8953dd84",
336
+ "metadata": {},
337
+ "source": [
338
+ "----"
339
+ ]
340
+ }
341
+ ],
342
+ "metadata": {
343
+ "kernelspec": {
344
+ "display_name": "Python 3 (ipykernel)",
345
+ "language": "python",
346
+ "name": "python3"
347
+ },
348
+ "language_info": {
349
+ "codemirror_mode": {
350
+ "name": "ipython",
351
+ "version": 3
352
+ },
353
+ "file_extension": ".py",
354
+ "mimetype": "text/x-python",
355
+ "name": "python",
356
+ "nbconvert_exporter": "python",
357
+ "pygments_lexer": "ipython3",
358
+ "version": "3.8.10"
359
+ }
360
+ },
361
+ "nbformat": 4,
362
+ "nbformat_minor": 5
363
+ }
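The superbatch_generator above stacks one batch per TPU core so that pmap can run the encode step on all devices in parallel. A standalone sketch of that pattern with dummy data and a trivial per-device function (device count taken from jax.local_device_count(), so it also runs on CPU with a single device):

    import numpy as np
    import jax
    from jax import pmap

    num_devices = jax.local_device_count()   # 8 on a TPU v3-8, 1 on CPU
    batch_size = 4

    # Superbatch shaped (devices, per-device batch, H, W, C), like the stacked batches in the notebook.
    superbatch = np.zeros((num_devices, batch_size, 256, 256, 3), dtype=np.float32)

    # pmap maps the function over the leading device axis; here a stand-in for the VQGAN encode call.
    p_fn = pmap(lambda batch: batch.mean(axis=(1, 2, 3)))
    out = p_fn(superbatch)
    print(out.shape)                         # (num_devices, batch_size)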
encoding/vqgan-jax-encoding.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
model/data-pipeline.ipynb ADDED
@@ -0,0 +1,385 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "bf8fb38a",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Data Pipeline"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "9b83dcb9",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from dataclasses import dataclass, field\n",
19
+ "from pathlib import Path\n",
20
+ "\n",
21
+ "import datasets\n",
22
+ "from datasets import Dataset, load_dataset\n",
23
+ "import numpy as np\n",
24
+ "\n",
25
+ "from transformers import BartTokenizer\n",
26
+ "\n",
27
+ "from tqdm import tqdm\n",
28
+ "\n",
29
+ "import jax\n",
30
+ "import jax.numpy as jnp\n",
31
+ "\n",
32
+ "from flax.training.common_utils import shard"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "a661a89e",
38
+ "metadata": {},
39
+ "source": [
40
+ "File containing image paths, captions and VQGAN-encoded indices."
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 2,
46
+ "id": "0e84e889",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "datafile = '/data/CC12M/images-encoded-10000.tsv' # 9999 encoded images from CC12M"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "id": "7fdc640b",
56
+ "metadata": {},
57
+ "source": [
58
+ "TODO: generate train/test splits if necessary."
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 3,
64
+ "id": "cc6789b4",
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "name": "stderr",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "Using custom data configuration default-91833df78e844785\n",
72
+ "Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
73
+ ]
74
+ }
75
+ ],
76
+ "source": [
77
+ "dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 4,
83
+ "id": "f3ed4919",
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "data": {
88
+ "text/plain": [
89
+ "DatasetDict({\n",
90
+ " train: Dataset({\n",
91
+ " features: ['image_file', 'caption', 'encoding'],\n",
92
+ " num_rows: 9999\n",
93
+ " })\n",
94
+ "})"
95
+ ]
96
+ },
97
+ "execution_count": 4,
98
+ "metadata": {},
99
+ "output_type": "execute_result"
100
+ }
101
+ ],
102
+ "source": [
103
+ "dataset"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 5,
109
+ "id": "a70c7354",
110
+ "metadata": {},
111
+ "outputs": [
112
+ {
113
+ "data": {
114
+ "text/plain": [
115
+ "Dataset({\n",
116
+ " features: ['image_file', 'caption', 'encoding'],\n",
117
+ " num_rows: 9999\n",
118
+ "})"
119
+ ]
120
+ },
121
+ "execution_count": 5,
122
+ "metadata": {},
123
+ "output_type": "execute_result"
124
+ }
125
+ ],
126
+ "source": [
127
+ "dataset = dataset[\"train\"]\n",
128
+ "dataset"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "id": "a73454cf",
134
+ "metadata": {},
135
+ "source": [
136
+ "We don't really need the `image_file` field for training. We'll drop it during pre-processing because we won't be able to numericalize it to a `jnp.array`, which would be required in JAX."
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "id": "7c0fa992",
142
+ "metadata": {},
143
+ "source": [
144
+ "## Preprocessing"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "id": "a0e36582",
150
+ "metadata": {},
151
+ "source": [
152
+ "The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 6,
158
+ "id": "d46f6ac5",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
163
+ "max_length = 256 # Read from data_args.max_source_length\n",
164
+ "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
165
+ "image_bos = 16384 # Max token is 16383 in our VQGAN configuration"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": 7,
171
+ "id": "4cac6643",
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "def preprocess_function(examples):\n",
176
+ " inputs = examples[\"caption\"]\n",
177
+ "# inputs = [prefix + inp for inp in inputs] # Do we need this?\n",
178
+ " model_inputs = tokenizer(\n",
179
+ " inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
180
+ " )\n",
181
+ "\n",
182
+ " model_inputs[\"labels\"] = [[image_bos] + eval(indices) for indices in examples['encoding']]\n",
183
+ "\n",
184
+ " return model_inputs"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 8,
190
+ "id": "e6a4cb91",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "num_workers = 48 # We have 96 processors in the TPU\n",
195
+ "column_names = dataset.column_names\n",
196
+ "input_dataset = dataset.map(preprocess_function,\n",
197
+ " remove_columns=column_names,\n",
198
+ " batched=True,\n",
199
+ " num_proc=48\n",
200
+ ")"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 9,
206
+ "id": "a9b1b467",
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):\n",
211
+ " \"\"\"\n",
212
+ " Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n",
213
+ " Shuffle batches if `shuffle` is `True`.\n",
214
+ " \"\"\"\n",
215
+ " steps_per_epoch = len(dataset) // batch_size\n",
216
+ "\n",
217
+ " if shuffle:\n",
218
+ " batch_idx = jax.random.permutation(rng, len(dataset))\n",
219
+ " else:\n",
220
+ " batch_idx = jnp.arange(len(dataset))\n",
221
+ "\n",
222
+ " batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
223
+ " batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
224
+ "\n",
225
+ " for idx in batch_idx:\n",
226
+ " batch = dataset[idx] \n",
227
+ " batch = {k: jnp.array(v) for k, v in batch.items()}\n",
228
+ " batch = shard(batch)\n",
229
+ " yield batch"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 10,
235
+ "id": "0a628505",
236
+ "metadata": {},
237
+ "outputs": [
238
+ {
239
+ "name": "stderr",
240
+ "output_type": "stream",
241
+ "text": [
242
+ "INFO:absl:Starting the local TPU driver.\n",
243
+ "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
244
+ "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
245
+ ]
246
+ }
247
+ ],
248
+ "source": [
249
+ "rng = jax.random.PRNGKey(23) # Use training_args.seed\n",
250
+ "batch_size = 64 # Per device\n",
251
+ "super_batch_size = batch_size * jax.device_count()"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": 11,
257
+ "id": "b3a5ce7d",
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": 12,
267
+ "id": "67aa8f9c",
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": [
271
+ "superbatch = next(iter(loader))"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": 13,
277
+ "id": "7cd99402",
278
+ "metadata": {},
279
+ "outputs": [
280
+ {
281
+ "data": {
282
+ "text/plain": [
283
+ "dict_keys(['attention_mask', 'input_ids', 'labels'])"
284
+ ]
285
+ },
286
+ "execution_count": 13,
287
+ "metadata": {},
288
+ "output_type": "execute_result"
289
+ }
290
+ ],
291
+ "source": [
292
+ "superbatch.keys()"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": 14,
298
+ "id": "652a4a9e",
299
+ "metadata": {},
300
+ "outputs": [
301
+ {
302
+ "data": {
303
+ "text/plain": [
304
+ "8"
305
+ ]
306
+ },
307
+ "execution_count": 14,
308
+ "metadata": {},
309
+ "output_type": "execute_result"
310
+ }
311
+ ],
312
+ "source": [
313
+ "len(superbatch[\"labels\"])"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "execution_count": 15,
319
+ "id": "de7de4e8",
320
+ "metadata": {},
321
+ "outputs": [
322
+ {
323
+ "data": {
324
+ "text/plain": [
325
+ "(8, 64, 257)"
326
+ ]
327
+ },
328
+ "execution_count": 15,
329
+ "metadata": {},
330
+ "output_type": "execute_result"
331
+ }
332
+ ],
333
+ "source": [
334
+ "superbatch[\"labels\"].shape"
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "markdown",
339
+ "id": "6800153b",
340
+ "metadata": {},
341
+ "source": [
342
+ "Any image sequence should begin with `image_bos`:"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": 16,
348
+ "id": "cfe23a71",
349
+ "metadata": {},
350
+ "outputs": [],
351
+ "source": [
352
+ "assert superbatch[\"labels\"][1][5][0].item() == image_bos"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "id": "0fb899b4",
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": []
362
+ }
363
+ ],
364
+ "metadata": {
365
+ "kernelspec": {
366
+ "display_name": "Python 3 (ipykernel)",
367
+ "language": "python",
368
+ "name": "python3"
369
+ },
370
+ "language_info": {
371
+ "codemirror_mode": {
372
+ "name": "ipython",
373
+ "version": 3
374
+ },
375
+ "file_extension": ".py",
376
+ "mimetype": "text/x-python",
377
+ "name": "python",
378
+ "nbconvert_exporter": "python",
379
+ "pygments_lexer": "ipython3",
380
+ "version": "3.8.10"
381
+ }
382
+ },
383
+ "nbformat": 4,
384
+ "nbformat_minor": 5
385
+ }
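For reference, the (8, 64, 257) shape of superbatch["labels"] above comes from flax's shard, which splits the leading batch dimension of every array into (local devices, per-device batch). A small sketch with dummy arrays, assuming the same 257-token labels (image_bos plus 256 VQGAN indices):

    import numpy as np
    import jax
    from flax.training.common_utils import shard

    num_devices = jax.local_device_count()   # 8 on the TPU used here, 1 on CPU
    per_device = 4
    seq_len = 257                            # image_bos + 256 VQGAN indices

    batch = {
        "input_ids": np.zeros((num_devices * per_device, 256), dtype=np.int32),
        "labels": np.zeros((num_devices * per_device, seq_len), dtype=np.int32),
    }
    sharded = shard(batch)
    print(sharded["labels"].shape)           # (num_devices, per_device, 257)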