boris commited on
Commit
07eca3c
1 Parent(s): 92ccf4c

chore: remove unused files

Browse files
dalle_mini/dataset.py DELETED
@@ -1,122 +0,0 @@
1
- """
2
- An image-caption dataset dataloader.
3
- Luke Melas-Kyriazi, 2021
4
- """
5
- import warnings
6
- from typing import Optional, Callable
7
- from pathlib import Path
8
- import numpy as np
9
- import torch
10
- import pandas as pd
11
- from torch.utils.data import Dataset
12
- from torchvision.datasets.folder import default_loader
13
- from PIL import ImageFile
14
- from PIL.Image import DecompressionBombWarning
15
- ImageFile.LOAD_TRUNCATED_IMAGES = True
16
- warnings.filterwarnings("ignore", category=UserWarning)
17
- warnings.filterwarnings("ignore", category=DecompressionBombWarning)
18
-
19
-
20
- class CaptionDataset(Dataset):
21
- """
22
- A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
23
- returns the raw text rather than tokens. This is done on purpose, because
24
- it's easy to tokenize a batch of text after loading it from this dataset.
25
- """
26
-
27
- def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
28
- image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
29
- include_captions: bool = True):
30
- """
31
- :param images_root: folder where images are stored
32
- :param captions_path: path to csv that maps image filenames to captions
33
- :param image_transform: image transform pipeline
34
- :param text_transform: image transform pipeline
35
- :param image_transform_type: image transform type, either `torchvision` or `albumentations`
36
- :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
37
- """
38
-
39
- # Base path for images
40
- self.images_root = Path(images_root)
41
-
42
- # Load captions as DataFrame
43
- self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
44
- self.captions['image_file'] = self.captions['image_file'].astype(str)
45
-
46
- # PyTorch transformation pipeline for the image (normalizing, etc.)
47
- self.text_transform = text_transform
48
- self.image_transform = image_transform
49
- self.image_transform_type = image_transform_type.lower()
50
- assert self.image_transform_type in ['torchvision', 'albumentations']
51
-
52
- # Total number of datapoints
53
- self.size = len(self.captions)
54
-
55
- # Return image+captions or just images
56
- self.include_captions = include_captions
57
-
58
- def verify_that_all_images_exist(self):
59
- for image_file in self.captions['image_file']:
60
- p = self.images_root / image_file
61
- if not p.is_file():
62
- print(f'file does not exist: {p}')
63
-
64
- def _get_raw_image(self, i):
65
- image_file = self.captions.iloc[i]['image_file']
66
- image_path = self.images_root / image_file
67
- image = default_loader(image_path)
68
- return image
69
-
70
- def _get_raw_text(self, i):
71
- return self.captions.iloc[i]['caption']
72
-
73
- def __getitem__(self, i):
74
- image = self._get_raw_image(i)
75
- caption = self._get_raw_text(i)
76
- if self.image_transform is not None:
77
- if self.image_transform_type == 'torchvision':
78
- image = self.image_transform(image)
79
- elif self.image_transform_type == 'albumentations':
80
- image = self.image_transform(image=np.array(image))['image']
81
- else:
82
- raise NotImplementedError(f"{self.image_transform_type=}")
83
- return {'image': image, 'text': caption} if self.include_captions else image
84
-
85
- def __len__(self):
86
- return self.size
87
-
88
-
89
- if __name__ == "__main__":
90
- import albumentations as A
91
- from albumentations.pytorch import ToTensorV2
92
- from transformers import AutoTokenizer
93
-
94
- # Paths
95
- images_root = './images'
96
- captions_path = './images-list-clean.tsv'
97
-
98
- # Create transforms
99
- tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
100
- def tokenize(text):
101
- return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
102
- image_transform = A.Compose([
103
- A.Resize(256, 256), A.CenterCrop(256, 256),
104
- A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
105
-
106
- # Create dataset
107
- dataset = CaptionDataset(
108
- images_root=images_root,
109
- captions_path=captions_path,
110
- image_transform=image_transform,
111
- text_transform=tokenize,
112
- image_transform_type='albumentations')
113
-
114
- # Create dataloader
115
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
116
- batch = next(iter(dataloader))
117
- print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
118
-
119
- # # (Optional) Check that all the images exist
120
- # dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
121
- # dataset.verify_that_all_images_exist()
122
- # print('Done')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/README.md DELETED
@@ -1,122 +0,0 @@
1
- # Development Instructions for TPU
2
-
3
- ## Setup
4
-
5
- - Apply to the [TRC program](https://sites.research.google/trc/) for free TPU credits if you're elligible.
6
- - Follow the [Cloud TPU VM User's Guide](https://cloud.google.com/tpu/docs/users-guide-tpu-vm) to set up gcloud.
7
- - Verify `gcloud config list`, in particular account, project & zone.
8
- - Create a TPU VM per the guide and connect to it.
9
-
10
- When needing a larger disk:
11
-
12
- - Create a balanced persistent disk (SSD, so pricier than default HDD but much faster): `gcloud compute disks create DISK_NAME --size SIZE_IN_GB --type pd-balanced`
13
- - Attach the disk to your instance by adding `--data-disk source=REF` per ["Adding a persistent disk to a TPU VM" guide](https://cloud.google.com/tpu/docs/setup-persistent-disk), eg `gcloud alpha compute tpus tpu-vm create INSTANCE_NAME --accelerator-type=v3-8 --version=v2-alpha --data-disk source=projects/tpu-toys/zones/europe-west4-a/disks/DISK_NAME`
14
- - Format the partition as described in the guide.
15
- - Make sure to set up automatic remount of disk at restart.
16
-
17
- ## Connect VS Code
18
-
19
- - Find external IP in the UI or with `gcloud alpha compute tpus tpu-vm describe INSTANCE_NAME`
20
- - Verify you can connect in terminal with `ssh EXTERNAL_IP -i ~/.ssh/google_compute_engine`
21
- - Add the same command as ssh host in VS Code.
22
- - Check config file
23
-
24
- ```
25
- Host INSTANCE_NAME
26
- HostName EXTERNAL_IP
27
- IdentityFile ~/.ssh/google_compute_engine
28
- ```
29
-
30
- ## Environment configuration
31
-
32
- ### Use virtual environments (optional)
33
-
34
- We recommend using virtual environments (such as conda, venv or pyenv-virtualenv).
35
-
36
- If you want to use `pyenv` and `pyenv-virtualenv`:
37
-
38
- - Installation
39
-
40
- - [Set up build environment](https://github.com/pyenv/pyenv/wiki#suggested-build-environment)
41
- - Use [pyenv-installer](https://github.com/pyenv/pyenv-installer): `curl https://pyenv.run | bash`
42
- - bash set-up:
43
-
44
- ```bash
45
- echo '\n'\
46
- '# pyenv setup \n'\
47
- 'export PYENV_ROOT="$HOME/.pyenv" \n'\
48
- 'export PATH="$PYENV_ROOT/bin:$PATH" \n'\
49
- 'eval "$(pyenv init --path)" \n'\
50
- 'eval "$(pyenv init -)" \n'\
51
- 'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc
52
- ```
53
-
54
- - Usage
55
-
56
- - Install a python version: `pyenv install X.X.X`
57
- - Create a virtual environment: `pyenv virtualenv 3.9.6 dalle_env`
58
- - Activate: `pyenv activate dalle_env`
59
-
60
- Note: you can auto-activate your environment at a location with `echo dalle_env >> .python-version`
61
-
62
- ### Tools
63
-
64
- - Git
65
-
66
- - `git config --global user.email "name@domain.com"
67
- - `git config --global user.name "First Last"
68
-
69
- - Github CLI
70
-
71
- - See [installation instructions](https://github.com/cli/cli/blob/trunk/docs/install_linux.md)
72
- - `gh auth login`
73
-
74
- - Direnv
75
-
76
- - Install direnv: `sudo apt-get update && sudo apt-get install direnv`
77
- - bash set-up:
78
-
79
- ```bash
80
- echo -e '\n'\
81
- '# direnv setup \n'\
82
- 'eval "$(direnv hook bash)" \n' >> ~/.bashrc
83
- ```
84
-
85
- ### Set up repo
86
-
87
- - Clone repo: `gh repo clone borisdayma/dalle-mini`
88
- - If using `pyenv-virtualenv`, auto-activate env: `echo dalle_env >> .python-version`
89
-
90
- ## Environment
91
-
92
- - Install the following (use it later to update our dev requirements.txt)
93
-
94
- ```
95
- requests
96
- pillow
97
- jupyterlab
98
- ipywidgets
99
-
100
- -e ../datasets[streaming]
101
- -e ../transformers
102
- -e ../webdataset
103
-
104
- # JAX
105
- --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
106
- jax[tpu]>=0.2.16
107
- flax
108
- ```
109
-
110
- - `transformers-cli login`
111
-
112
- ---
113
-
114
- - set `HF_HOME="/mnt/disks/persist/cache/huggingface"` in `/etc/environment` and ensure you have required permissions, then restart.
115
-
116
- ## Working with datasets or models
117
-
118
- - Install [Git LFS](https://github.com/git-lfs/git-lfs/wiki/Installation)
119
- - Clone a dataset without large files: `GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/.../...`
120
- - Use a local [credential store](https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage) for caching credentials
121
- - Track specific extentions: `git lfs track "*.ext"`
122
- - See files tracked with LFS with `git lfs ls-files`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/data/CC12M_downloader.py DELETED
@@ -1,91 +0,0 @@
1
- # Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
2
-
3
- #%%
4
- import sys
5
- import os
6
- from datetime import datetime
7
- import pandas as pd
8
- import contexttimer
9
- from urllib.request import urlopen
10
- import requests
11
- from PIL import Image
12
- import torch
13
- from torchvision.transforms import functional as TF
14
- from multiprocessing import Pool
15
- from tqdm import tqdm
16
- import logging
17
-
18
- # Setup
19
- logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
20
- requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
21
-
22
-
23
- # # For downloading SVG images (I can't get this to work)
24
- # from io import BytesIO
25
- # import cairosvg
26
-
27
- #%%
28
- # Load data
29
- print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
30
- with contexttimer.Timer(prefix="Loading from tsv"):
31
- df = pd.read_csv('./cc12m.tsv', delimiter='\t', header=None)
32
-
33
- url_to_idx_map = {url: index for index, url, caption in df.itertuples()}
34
- print(f'Loaded {len(url_to_idx_map)} urls')
35
-
36
- #%%
37
- df.head()
38
-
39
- #%%
40
-
41
- # Note: it seems that there are no SVG images
42
- df.sample(10000)[1].str.contains('.svg').sum()
43
-
44
- #%%
45
- # Resize function
46
- def resize(img):
47
- max_size_of_short_side = 512
48
- if min(img.size) > max_size_of_short_side:
49
- img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
50
- return img
51
-
52
- base_dir = os.path.join(os.getcwd(), 'images')
53
-
54
- def process(item):
55
- url, image_id = item
56
- try:
57
- base_url = os.path.basename(url) # extract base url
58
- stem, ext = os.path.splitext(base_url) # split into stem and extension
59
- filename = f'{image_id:08d}---{stem}.jpg' # create filename
60
- filepath = os.path.join(base_dir, filename) # concat to get filepath
61
- if not os.path.isfile(filepath):
62
- # if filepath.endswith('.svg'):
63
- # raise NotImplementedError()
64
- # image_bytes = BytesIO() # create a bytestream
65
- # cairosvg.svg2png(url=url, write_to=image_bytes) # convert svg into image
66
- # else:
67
- req = requests.get(url, stream=True, timeout=1, verify=False).raw
68
- image = Image.open(req).convert('RGB')
69
- if min(image.size) > 512:
70
- image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
71
- # image = resize(image) # resize PIL image
72
- image.save(filepath) # save PIL image
73
- except Exception as e:
74
- logging.info(" ".join(repr(e).splitlines()))
75
- logging.error(url)
76
-
77
- #%%
78
- #for i, item in enumerate(tqdm(url_to_idx_map.items(), total=len(url_to_idx_map))):
79
- # process(item)
80
- # if i > 100:
81
- # break
82
-
83
- # Use multiprocessing for speed
84
- list_of_items = list(url_to_idx_map.items())
85
- print(len(list_of_items))
86
- list_of_items = list_of_items[10_000_000:]
87
- print(len(list_of_items))
88
- with Pool(128) as p:
89
- r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
90
- print('DONE')
91
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/data/CC3M_downloader.py DELETED
@@ -1,62 +0,0 @@
1
- '''
2
- This script was adapted from Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
3
- Few changes were made for the particular dataset. You're required to have the `.tsv` file downloaded in your directory.
4
- Find them here- [https://github.com/google-research-datasets/conceptual-captions]
5
- '''
6
-
7
- import sys
8
- import os
9
- from datetime import datetime
10
- import pandas as pd
11
- import contexttimer
12
- from urllib.request import urlopen
13
- import requests
14
- from PIL import Image
15
- import torch
16
- from torchvision.transforms import functional as TF
17
- from multiprocessing import Pool
18
- from tqdm import tqdm
19
- import logging
20
- import sys
21
-
22
- # Setup
23
- logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
24
- requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
25
-
26
- if len(sys.argv) != 3:
27
- print("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
28
- exit(1)
29
-
30
- # Load data
31
- print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
32
- with contexttimer.Timer(prefix="Loading from tsv"):
33
- df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)
34
-
35
- url_to_idx_map = {url: index for index, caption, url in df.itertuples()}
36
- print(f'Loaded {len(url_to_idx_map)} urls')
37
-
38
- base_dir = os.path.join(os.getcwd(), sys.argv[2])
39
-
40
- def process(item):
41
- url, image_id = item
42
- try:
43
- base_url = os.path.basename(url) # extract base url
44
- stem, ext = os.path.splitext(base_url) # split into stem and extension
45
- filename = f'{image_id:08d}---{stem}.jpg' # create filename
46
- filepath = os.path.join(base_dir, filename) # concat to get filepath
47
- if not os.path.isfile(filepath):
48
- req = requests.get(url, stream=True, timeout=1, verify=False).raw
49
- image = Image.open(req).convert('RGB')
50
- if min(image.size) > 512:
51
- image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
52
- image.save(filepath) # save PIL image
53
- except Exception as e:
54
- logging.info(" ".join(repr(e).splitlines()))
55
- logging.error(url)
56
-
57
- list_of_items = list(url_to_idx_map.items())
58
- print(len(list_of_items))
59
-
60
- with Pool(128) as p:
61
- r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
62
- print('DONE')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/data/README.md DELETED
@@ -1,3 +0,0 @@
1
- # Data
2
-
3
- Utility scripts for downloading CC3M and CC12M.
 
 
 
 
dev/encoding/vqgan-jax-encoding-streaming.ipynb DELETED
@@ -1,562 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# VQGAN JAX Encoding for 🤗 Datasets in streaming mode"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "ba7b31e6",
14
- "metadata": {},
15
- "source": [
16
- "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and 🤗 Datasets in streaming mode.\n",
17
- "\n",
18
- "This example uses our YFCC100M dataset, but it should be easy to adapt to any other image/caption dataset in the huggingface hub."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": null,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "import os\n",
40
- "\n",
41
- "import jax\n",
42
- "from jax import pmap"
43
- ]
44
- },
45
- {
46
- "cell_type": "markdown",
47
- "id": "c7c4c1e6",
48
- "metadata": {},
49
- "source": [
50
- "## Dataset and Parameters"
51
- ]
52
- },
53
- {
54
- "cell_type": "code",
55
- "execution_count": null,
56
- "id": "d45a289e",
57
- "metadata": {},
58
- "outputs": [],
59
- "source": [
60
- "import datasets\n",
61
- "from datasets import Dataset, load_dataset"
62
- ]
63
- },
64
- {
65
- "cell_type": "markdown",
66
- "id": "f26e4f18",
67
- "metadata": {},
68
- "source": [
69
- "We'll use the `validation` set for testing. Adjust accordingly."
70
- ]
71
- },
72
- {
73
- "cell_type": "code",
74
- "execution_count": null,
75
- "id": "28893c3e",
76
- "metadata": {},
77
- "outputs": [],
78
- "source": [
79
- "dataset = load_dataset('dalle-mini/YFCC100M_OpenAI_subset', use_auth_token=True, streaming=True, split='validation')"
80
- ]
81
- },
82
- {
83
- "cell_type": "code",
84
- "execution_count": null,
85
- "id": "33861477",
86
- "metadata": {},
87
- "outputs": [],
88
- "source": [
89
- "from pathlib import Path\n",
90
- "\n",
91
- "yfcc100m = Path.home()/'data'/'YFCC100M_OpenAI_subset'\n",
92
- "yfcc100m_output = yfcc100m/'encoded' # Output directory for encoded files"
93
- ]
94
- },
95
- {
96
- "cell_type": "code",
97
- "execution_count": null,
98
- "id": "6e7b71c4",
99
- "metadata": {},
100
- "outputs": [],
101
- "source": [
102
- "batch_size = 128 # Per device\n",
103
- "num_workers = 16 # Unused in streaming mode"
104
- ]
105
- },
106
- {
107
- "cell_type": "markdown",
108
- "id": "0793c26a",
109
- "metadata": {},
110
- "source": [
111
- "### Data preparation"
112
- ]
113
- },
114
- {
115
- "cell_type": "markdown",
116
- "id": "86415769",
117
- "metadata": {},
118
- "source": [
119
- "* Images: we transform them so they are center-cropped and square, all of the same size so we can build batches for TPU/GPU processing.\n",
120
- "* Captions: we extract a single `caption` column from the source data, by concatenating the cleaned title and description.\n",
121
- "\n",
122
- "These transformations are done using the Datasets `map` function. In the case of streaming datasets, transformations will run as needed instead of pre-processing the dataset at once."
123
- ]
124
- },
125
- {
126
- "cell_type": "markdown",
127
- "id": "0fdf1851",
128
- "metadata": {},
129
- "source": [
130
- "This helper function is used to decode images from the bytes retrieved in `streaming` mode."
131
- ]
132
- },
133
- {
134
- "cell_type": "code",
135
- "execution_count": null,
136
- "id": "5bbca804",
137
- "metadata": {},
138
- "outputs": [],
139
- "source": [
140
- "from PIL import Image\n",
141
- "import io\n",
142
- "\n",
143
- "def get_image(byte_stream):\n",
144
- " image = Image.open(io.BytesIO(byte_stream))\n",
145
- " return image.convert('RGB')"
146
- ]
147
- },
148
- {
149
- "cell_type": "markdown",
150
- "id": "b435290b",
151
- "metadata": {},
152
- "source": [
153
- "Image processing"
154
- ]
155
- },
156
- {
157
- "cell_type": "code",
158
- "execution_count": null,
159
- "id": "7e73dfa3",
160
- "metadata": {},
161
- "outputs": [],
162
- "source": [
163
- "def center_crop(image, max_size=256):\n",
164
- " # Note: we allow upscaling too. We should exclude small images. \n",
165
- " image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
166
- " image = TF.center_crop(image, output_size=2 * [max_size])\n",
167
- " return image\n",
168
- "\n",
169
- "preprocess_image = T.Compose([\n",
170
- " get_image,\n",
171
- " center_crop,\n",
172
- " T.ToTensor(),\n",
173
- " lambda t: t.permute(1, 2, 0) # Reorder, we need dimensions last\n",
174
- "])"
175
- ]
176
- },
177
- {
178
- "cell_type": "markdown",
179
- "id": "1e3ac8de",
180
- "metadata": {},
181
- "source": [
182
- "Caption preparation"
183
- ]
184
- },
185
- {
186
- "cell_type": "code",
187
- "execution_count": null,
188
- "id": "aadb4d23",
189
- "metadata": {},
190
- "outputs": [],
191
- "source": [
192
- "import string\n",
193
- "\n",
194
- "def create_caption(title, description):\n",
195
- " title = title.strip()\n",
196
- " description = description.strip()\n",
197
- " if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
198
- " return f'{title} {description}'"
199
- ]
200
- },
201
- {
202
- "cell_type": "markdown",
203
- "id": "3c4522b9",
204
- "metadata": {},
205
- "source": [
206
- "And this is the basic transformation function to use in `map`. We don't really need the `key`, but we'll keep it for reference. Since we are returning a new dictionary (as opposed to adding entries to the input), this also removes any metadata columns we don't need."
207
- ]
208
- },
209
- {
210
- "cell_type": "code",
211
- "execution_count": null,
212
- "id": "2566ff68",
213
- "metadata": {},
214
- "outputs": [],
215
- "source": [
216
- "def prepare_item(item):\n",
217
- " return {\n",
218
- " 'key': item['key'],\n",
219
- " 'caption': create_caption(item['title_clean'], item['description_clean']),\n",
220
- " 'image': preprocess_image(item['img'])\n",
221
- " }"
222
- ]
223
- },
224
- {
225
- "cell_type": "markdown",
226
- "id": "e519e475",
227
- "metadata": {},
228
- "source": [
229
- "Unlike when using non-streaming datasets, the following operation completes immediately in streaming mode. In streaming mode, `num_proc` is not supported."
230
- ]
231
- },
232
- {
233
- "cell_type": "code",
234
- "execution_count": null,
235
- "id": "10d7750e",
236
- "metadata": {},
237
- "outputs": [],
238
- "source": [
239
- "prepared_dataset = dataset.map(prepare_item, batched=False)"
240
- ]
241
- },
242
- {
243
- "cell_type": "code",
244
- "execution_count": null,
245
- "id": "a8595539",
246
- "metadata": {},
247
- "outputs": [],
248
- "source": [
249
- "%%time\n",
250
- "item = next(iter(prepared_dataset))"
251
- ]
252
- },
253
- {
254
- "cell_type": "code",
255
- "execution_count": null,
256
- "id": "04a6eeb4",
257
- "metadata": {},
258
- "outputs": [],
259
- "source": [
260
- "assert(list(item.keys()) == ['key', 'caption', 'image'])"
261
- ]
262
- },
263
- {
264
- "cell_type": "code",
265
- "execution_count": null,
266
- "id": "40d3115f",
267
- "metadata": {},
268
- "outputs": [],
269
- "source": [
270
- "item['image'].shape"
271
- ]
272
- },
273
- {
274
- "cell_type": "code",
275
- "execution_count": null,
276
- "id": "dd844e1c",
277
- "metadata": {},
278
- "outputs": [],
279
- "source": [
280
- "T.ToPILImage()(item['image'].permute(2, 0, 1))"
281
- ]
282
- },
283
- {
284
- "cell_type": "markdown",
285
- "id": "44d50a51",
286
- "metadata": {},
287
- "source": [
288
- "### Torch DataLoader"
289
- ]
290
- },
291
- {
292
- "cell_type": "markdown",
293
- "id": "17a4bbc6",
294
- "metadata": {},
295
- "source": [
296
- "We'll create a PyTorch DataLoader for convenience. This allows us to easily take batches of our desired size.\n",
297
- "\n",
298
- "We won't be using parallel processing of the DataLoader for now, as the items will be retrieved on the fly. We could attempt to do it using these recommendations: https://pytorch.org/docs/stable/data.html#multi-process-data-loading. For performance considerations, please refer to this thread: https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13"
299
- ]
300
- },
301
- {
302
- "cell_type": "code",
303
- "execution_count": null,
304
- "id": "e1c08b7e",
305
- "metadata": {},
306
- "outputs": [],
307
- "source": [
308
- "import torch\n",
309
- "from torch.utils.data import DataLoader"
310
- ]
311
- },
312
- {
313
- "cell_type": "code",
314
- "execution_count": null,
315
- "id": "6a296677",
316
- "metadata": {},
317
- "outputs": [],
318
- "source": [
319
- "torch_dataset = prepared_dataset.with_format(\"torch\")"
320
- ]
321
- },
322
- {
323
- "cell_type": "markdown",
324
- "id": "29ab13bc",
325
- "metadata": {},
326
- "source": [
327
- "**Note**: according to my tests, `num_workers` is not compatible with Datasets in streaming mode. Processes deadlock and there's no progress."
328
- ]
329
- },
330
- {
331
- "cell_type": "code",
332
- "execution_count": null,
333
- "id": "e2df5e13",
334
- "metadata": {},
335
- "outputs": [],
336
- "source": [
337
- "dataloader = DataLoader(torch_dataset, batch_size=batch_size * jax.device_count())"
338
- ]
339
- },
340
- {
341
- "cell_type": "code",
342
- "execution_count": null,
343
- "id": "c15e3783",
344
- "metadata": {},
345
- "outputs": [],
346
- "source": [
347
- "batch = next(iter(dataloader))"
348
- ]
349
- },
350
- {
351
- "cell_type": "code",
352
- "execution_count": null,
353
- "id": "71d027fe",
354
- "metadata": {},
355
- "outputs": [],
356
- "source": [
357
- "batch['image'].shape"
358
- ]
359
- },
360
- {
361
- "cell_type": "markdown",
362
- "id": "a354472b",
363
- "metadata": {},
364
- "source": [
365
- "## VQGAN-JAX model"
366
- ]
367
- },
368
- {
369
- "cell_type": "code",
370
- "execution_count": null,
371
- "id": "2fcf01d7",
372
- "metadata": {},
373
- "outputs": [],
374
- "source": [
375
- "from vqgan_jax.modeling_flax_vqgan import VQModel"
376
- ]
377
- },
378
- {
379
- "cell_type": "markdown",
380
- "id": "9daa636d",
381
- "metadata": {},
382
- "source": [
383
- "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
384
- ]
385
- },
386
- {
387
- "cell_type": "code",
388
- "execution_count": null,
389
- "id": "47a8b818",
390
- "metadata": {
391
- "scrolled": true
392
- },
393
- "outputs": [],
394
- "source": [
395
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
396
- ]
397
- },
398
- {
399
- "cell_type": "markdown",
400
- "id": "62ad01c3",
401
- "metadata": {},
402
- "source": [
403
- "## Encoding"
404
- ]
405
- },
406
- {
407
- "cell_type": "markdown",
408
- "id": "20357f74",
409
- "metadata": {},
410
- "source": [
411
- "Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, that will be jitted on first use."
412
- ]
413
- },
414
- {
415
- "cell_type": "code",
416
- "execution_count": null,
417
- "id": "6686b004",
418
- "metadata": {},
419
- "outputs": [],
420
- "source": [
421
- "from flax.training.common_utils import shard\n",
422
- "from functools import partial"
423
- ]
424
- },
425
- {
426
- "cell_type": "code",
427
- "execution_count": null,
428
- "id": "322a4619",
429
- "metadata": {},
430
- "outputs": [],
431
- "source": [
432
- "@partial(jax.pmap, axis_name=\"batch\")\n",
433
- "def encode(batch):\n",
434
- " # Not sure if we should `replicate` params, does not seem to have any effect\n",
435
- " _, indices = model.encode(batch)\n",
436
- " return indices"
437
- ]
438
- },
439
- {
440
- "cell_type": "markdown",
441
- "id": "14375a41",
442
- "metadata": {},
443
- "source": [
444
- "### Encoding loop"
445
- ]
446
- },
447
- {
448
- "cell_type": "code",
449
- "execution_count": null,
450
- "id": "ff6c10d4",
451
- "metadata": {},
452
- "outputs": [],
453
- "source": [
454
- "import os\n",
455
- "import pandas as pd\n",
456
- "\n",
457
- "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
458
- " output_dir.mkdir(parents=True, exist_ok=True)\n",
459
- " \n",
460
- " # Saving strategy:\n",
461
- " # - Create a new file every so often to prevent excessive file seeking.\n",
462
- " # - Save each batch after processing.\n",
463
- " # - Keep the file open until we are done with it.\n",
464
- " file = None \n",
465
- " for n, batch in enumerate(tqdm(iter(dataloader))):\n",
466
- " if (n % save_every == 0):\n",
467
- " if file is not None:\n",
468
- " file.close()\n",
469
- " split_num = n // save_every\n",
470
- " file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
471
- "\n",
472
- " images = batch[\"image\"].numpy()\n",
473
- " images = shard(images.squeeze())\n",
474
- " encoded = encode(images)\n",
475
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
476
- "\n",
477
- " keys = batch[\"key\"]\n",
478
- " captions = batch[\"caption\"]\n",
479
- "\n",
480
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
481
- " batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
482
- " batch_df.to_json(file, orient='records', lines=True)"
483
- ]
484
- },
485
- {
486
- "cell_type": "markdown",
487
- "id": "09ff75a3",
488
- "metadata": {},
489
- "source": [
490
- "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
491
- ]
492
- },
493
- {
494
- "cell_type": "code",
495
- "execution_count": null,
496
- "id": "96222bb4",
497
- "metadata": {},
498
- "outputs": [],
499
- "source": [
500
- "save_every = 318"
501
- ]
502
- },
503
- {
504
- "cell_type": "code",
505
- "execution_count": null,
506
- "id": "7704863d",
507
- "metadata": {},
508
- "outputs": [
509
- {
510
- "name": "stderr",
511
- "output_type": "stream",
512
- "text": [
513
- "28it [01:17, 1.60s/it]"
514
- ]
515
- }
516
- ],
517
- "source": [
518
- "encode_captioned_dataset(dataloader, yfcc100m_output, save_every=save_every)"
519
- ]
520
- },
521
- {
522
- "cell_type": "markdown",
523
- "id": "e266a70a",
524
- "metadata": {},
525
- "source": [
526
- "This is ~10-15 slower than local encoding from an SSD. For performance considerations, see the discussion at https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13."
527
- ]
528
- },
529
- {
530
- "cell_type": "markdown",
531
- "id": "8953dd84",
532
- "metadata": {},
533
- "source": [
534
- "----"
535
- ]
536
- }
537
- ],
538
- "metadata": {
539
- "interpreter": {
540
- "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
541
- },
542
- "kernelspec": {
543
- "display_name": "Python 3 (ipykernel)",
544
- "language": "python",
545
- "name": "python3"
546
- },
547
- "language_info": {
548
- "codemirror_mode": {
549
- "name": "ipython",
550
- "version": 3
551
- },
552
- "file_extension": ".py",
553
- "mimetype": "text/x-python",
554
- "name": "python",
555
- "nbconvert_exporter": "python",
556
- "pygments_lexer": "ipython3",
557
- "version": "3.8.10"
558
- }
559
- },
560
- "nbformat": 4,
561
- "nbformat_minor": 5
562
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/encoding/vqgan-jax-encoding-with-captions.ipynb DELETED
@@ -1,355 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# vqgan-jax-encoding-with-captions"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "875c82b3",
14
- "metadata": {},
15
- "source": [
16
- "Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n",
17
- "\n",
18
- "We process a `tsv` file with `image_file` and `caption` fields, and add a `vqgan_indices` column with indices extracted from a VQGAN-JAX model."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 1,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "from torch.utils.data import Dataset, DataLoader\n",
40
- "\n",
41
- "import jax\n",
42
- "from jax import pmap"
43
- ]
44
- },
45
- {
46
- "cell_type": "markdown",
47
- "id": "511c3b9e",
48
- "metadata": {},
49
- "source": [
50
- "## VQGAN-JAX model"
51
- ]
52
- },
53
- {
54
- "cell_type": "code",
55
- "execution_count": 2,
56
- "id": "2ca50dc7",
57
- "metadata": {},
58
- "outputs": [],
59
- "source": [
60
- "from vqgan_jax.modeling_flax_vqgan import VQModel"
61
- ]
62
- },
63
- {
64
- "cell_type": "markdown",
65
- "id": "7b60da9a",
66
- "metadata": {},
67
- "source": [
68
- "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
69
- ]
70
- },
71
- {
72
- "cell_type": "code",
73
- "execution_count": 3,
74
- "id": "29ce8b15",
75
- "metadata": {},
76
- "outputs": [
77
- {
78
- "data": {
79
- "application/vnd.jupyter.widget-view+json": {
80
- "model_id": "db406bdfc5d5428eaeae1631a04989dd",
81
- "version_major": 2,
82
- "version_minor": 0
83
- },
84
- "text/plain": [
85
- "Downloading: 0%| | 0.00/433 [00:00<?, ?B/s]"
86
- ]
87
- },
88
- "metadata": {},
89
- "output_type": "display_data"
90
- },
91
- {
92
- "data": {
93
- "application/vnd.jupyter.widget-view+json": {
94
- "model_id": "3e37f07fba6d48fca70313ae1fa8cc32",
95
- "version_major": 2,
96
- "version_minor": 0
97
- },
98
- "text/plain": [
99
- "Downloading: 0%| | 0.00/304M [00:00<?, ?B/s]"
100
- ]
101
- },
102
- "metadata": {},
103
- "output_type": "display_data"
104
- },
105
- {
106
- "name": "stderr",
107
- "output_type": "stream",
108
- "text": [
109
- "INFO:absl:Starting the local TPU driver.\n",
110
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
111
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n"
112
- ]
113
- },
114
- {
115
- "name": "stdout",
116
- "output_type": "stream",
117
- "text": [
118
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
119
- ]
120
- }
121
- ],
122
- "source": [
123
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
124
- ]
125
- },
126
- {
127
- "cell_type": "markdown",
128
- "id": "c7c4c1e6",
129
- "metadata": {},
130
- "source": [
131
- "## Dataset"
132
- ]
133
- },
134
- {
135
- "cell_type": "markdown",
136
- "id": "7014a7ce",
137
- "metadata": {},
138
- "source": [
139
- "We use Luke Melas-Kyriazi's `dataset.py` which reads image paths and captions from a tsv file that contains both. We only need the images for encoding."
140
- ]
141
- },
142
- {
143
- "cell_type": "code",
144
- "execution_count": 4,
145
- "id": "85832702",
146
- "metadata": {},
147
- "outputs": [],
148
- "source": [
149
- "from dalle_mini.dataset import *"
150
- ]
151
- },
152
- {
153
- "cell_type": "code",
154
- "execution_count": 5,
155
- "id": "81b19eca",
156
- "metadata": {},
157
- "outputs": [],
158
- "source": [
159
- "cc12m_images = '/data/CC12M/images'\n",
160
- "cc12m_list = '/data/CC12M/images-list-clean.tsv'\n",
161
- "# cc12m_list = '/data/CC12M/images-10000.tsv'\n",
162
- "cc12m_output = '/data/CC12M/images-encoded.tsv'"
163
- ]
164
- },
165
- {
166
- "cell_type": "code",
167
- "execution_count": 6,
168
- "id": "fecc9a00",
169
- "metadata": {},
170
- "outputs": [],
171
- "source": [
172
- "image_size = 256\n",
173
- "def image_transform(image):\n",
174
- " s = min(image.size)\n",
175
- " r = image_size / s\n",
176
- " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
177
- " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
178
- " image = TF.center_crop(image, output_size = 2 * [image_size])\n",
179
- " image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
180
- " image = image.permute(0, 2, 3, 1).numpy()\n",
181
- " return image"
182
- ]
183
- },
184
- {
185
- "cell_type": "code",
186
- "execution_count": 7,
187
- "id": "4ce2211f",
188
- "metadata": {},
189
- "outputs": [],
190
- "source": [
191
- "dataset = CaptionDataset(\n",
192
- " images_root=cc12m_images,\n",
193
- " captions_path=cc12m_list,\n",
194
- " image_transform=image_transform,\n",
195
- " image_transform_type='torchvision',\n",
196
- " include_captions=False\n",
197
- ")"
198
- ]
199
- },
200
- {
201
- "cell_type": "code",
202
- "execution_count": 8,
203
- "id": "cc922704",
204
- "metadata": {},
205
- "outputs": [
206
- {
207
- "data": {
208
- "text/plain": [
209
- "8592141"
210
- ]
211
- },
212
- "execution_count": 8,
213
- "metadata": {},
214
- "output_type": "execute_result"
215
- }
216
- ],
217
- "source": [
218
- "len(dataset)"
219
- ]
220
- },
221
- {
222
- "cell_type": "markdown",
223
- "id": "62ad01c3",
224
- "metadata": {},
225
- "source": [
226
- "## Encoding"
227
- ]
228
- },
229
- {
230
- "cell_type": "code",
231
- "execution_count": 9,
232
- "id": "88f36d0b",
233
- "metadata": {},
234
- "outputs": [],
235
- "source": [
236
- "def encode(model, batch):\n",
237
- "# print(\"jitting encode function\")\n",
238
- " _, indices = model.encode(batch)\n",
239
- " return indices"
240
- ]
241
- },
242
- {
243
- "cell_type": "code",
244
- "execution_count": 10,
245
- "id": "1f35f0cb",
246
- "metadata": {},
247
- "outputs": [],
248
- "source": [
249
- "def superbatch_generator(dataloader, num_tpus):\n",
250
- " iter_loader = iter(dataloader)\n",
251
- " for batch in iter_loader:\n",
252
- " superbatch = [batch.squeeze(1)]\n",
253
- " try:\n",
254
- " for b in range(num_tpus-1):\n",
255
- " batch = next(iter_loader)\n",
256
- " if batch is None:\n",
257
- " break\n",
258
- " # Skip incomplete last batch\n",
259
- " if batch.shape[0] == dataloader.batch_size:\n",
260
- " superbatch.append(batch.squeeze(1))\n",
261
- " except StopIteration:\n",
262
- " pass\n",
263
- " superbatch = torch.stack(superbatch, axis=0)\n",
264
- " yield superbatch"
265
- ]
266
- },
267
- {
268
- "cell_type": "code",
269
- "execution_count": 11,
270
- "id": "2210705b",
271
- "metadata": {},
272
- "outputs": [],
273
- "source": [
274
- "import os\n",
275
- "\n",
276
- "def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
277
- " if os.path.isfile(output_tsv):\n",
278
- " print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
279
- " return\n",
280
- " \n",
281
- " num_tpus = 8 \n",
282
- " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
283
- " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
284
- " \n",
285
- " p_encoder = pmap(lambda batch: encode(model, batch))\n",
286
- "\n",
287
- " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
288
- " # We keep the file open to prevent excessive file seeks.\n",
289
- " with open(output_tsv, \"w\") as file:\n",
290
- " iterations = len(dataset) // (batch_size * num_tpus)\n",
291
- " for n in tqdm(range(iterations)):\n",
292
- " superbatch = next(superbatches)\n",
293
- " encoded = p_encoder(superbatch.numpy())\n",
294
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
295
- "\n",
296
- " # Extract fields from the dataset internal `captions` property, and save to disk\n",
297
- " start_index = n * batch_size * num_tpus\n",
298
- " end_index = (n+1) * batch_size * num_tpus\n",
299
- " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
300
- " captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
301
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
302
- " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
303
- " batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)\n",
304
- " "
305
- ]
306
- },
307
- {
308
- "cell_type": "code",
309
- "execution_count": null,
310
- "id": "7704863d",
311
- "metadata": {},
312
- "outputs": [
313
- {
314
- "name": "stderr",
315
- "output_type": "stream",
316
- "text": [
317
- " 4%|██▋ | 621/16781 [07:09<3:02:46, 1.47it/s]"
318
- ]
319
- }
320
- ],
321
- "source": [
322
- "encode_captioned_dataset(dataset, cc12m_output, batch_size=64, num_workers=16)"
323
- ]
324
- },
325
- {
326
- "cell_type": "markdown",
327
- "id": "8953dd84",
328
- "metadata": {},
329
- "source": [
330
- "----"
331
- ]
332
- }
333
- ],
334
- "metadata": {
335
- "kernelspec": {
336
- "display_name": "Python 3 (ipykernel)",
337
- "language": "python",
338
- "name": "python3"
339
- },
340
- "language_info": {
341
- "codemirror_mode": {
342
- "name": "ipython",
343
- "version": 3
344
- },
345
- "file_extension": ".py",
346
- "mimetype": "text/x-python",
347
- "name": "python",
348
- "nbconvert_exporter": "python",
349
- "pygments_lexer": "ipython3",
350
- "version": "3.8.10"
351
- }
352
- },
353
- "nbformat": 4,
354
- "nbformat_minor": 5
355
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/encoding/vqgan-jax-encoding-yfcc100m.ipynb DELETED
@@ -1,1129 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# vqgan-jax-encoding-yfcc100m"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "ba7b31e6",
14
- "metadata": {},
15
- "source": [
16
- "Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n",
17
- "\n",
18
- "This dataset was prepared by @borisdayma in Json lines format."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 92,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "from torch.utils.data import Dataset, DataLoader\n",
40
- "from torchvision.datasets.folder import default_loader\n",
41
- "import os\n",
42
- "\n",
43
- "import jax\n",
44
- "from jax import pmap"
45
- ]
46
- },
47
- {
48
- "cell_type": "markdown",
49
- "id": "511c3b9e",
50
- "metadata": {},
51
- "source": [
52
- "## VQGAN-JAX model"
53
- ]
54
- },
55
- {
56
- "cell_type": "code",
57
- "execution_count": 93,
58
- "id": "2ca50dc7",
59
- "metadata": {},
60
- "outputs": [],
61
- "source": [
62
- "from vqgan_jax.modeling_flax_vqgan import VQModel"
63
- ]
64
- },
65
- {
66
- "cell_type": "markdown",
67
- "id": "7b60da9a",
68
- "metadata": {},
69
- "source": [
70
- "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
71
- ]
72
- },
73
- {
74
- "cell_type": "code",
75
- "execution_count": 167,
76
- "id": "29ce8b15",
77
- "metadata": {},
78
- "outputs": [
79
- {
80
- "name": "stdout",
81
- "output_type": "stream",
82
- "text": [
83
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
84
- ]
85
- }
86
- ],
87
- "source": [
88
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
89
- ]
90
- },
91
- {
92
- "cell_type": "markdown",
93
- "id": "c7c4c1e6",
94
- "metadata": {},
95
- "source": [
96
- "## Dataset"
97
- ]
98
- },
99
- {
100
- "cell_type": "code",
101
- "execution_count": 94,
102
- "id": "33861477",
103
- "metadata": {},
104
- "outputs": [],
105
- "source": [
106
- "import pandas as pd\n",
107
- "from pathlib import Path"
108
- ]
109
- },
110
- {
111
- "cell_type": "code",
112
- "execution_count": 134,
113
- "id": "81b19eca",
114
- "metadata": {},
115
- "outputs": [],
116
- "source": [
117
- "yfcc100m = Path('/home/khali/TPU-Test/YFCC100M_OpenAI_subset')\n",
118
- "# Images are 'sharded' from the following directory\n",
119
- "yfcc100m_images = yfcc100m/'data'/'data'/'images'\n",
120
- "yfcc100m_metadata = yfcc100m/'metadata_YFCC100M.jsonl'\n",
121
- "yfcc100m_output = yfcc100m/'metadata_encoded.tsv'"
122
- ]
123
- },
124
- {
125
- "cell_type": "markdown",
126
- "id": "1c58bb4a",
127
- "metadata": {},
128
- "source": [
129
- "### Cleanup"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "id": "1a14ae3d",
135
- "metadata": {},
136
- "source": [
137
- "We need to select entries with images that exist. Otherwise we can't build batches because `Dataloader` does not support `None` in batches. We use Huggingface Datasets, I understand they support threaded reading of jsonl files, and I was running out of memory when using pandas."
138
- ]
139
- },
140
- {
141
- "cell_type": "code",
142
- "execution_count": 96,
143
- "id": "7811648c",
144
- "metadata": {},
145
- "outputs": [],
146
- "source": [
147
- "import datasets\n",
148
- "from datasets import Dataset, load_dataset"
149
- ]
150
- },
151
- {
152
- "cell_type": "code",
153
- "execution_count": 10,
154
- "id": "4811a230",
155
- "metadata": {},
156
- "outputs": [
157
- {
158
- "name": "stderr",
159
- "output_type": "stream",
160
- "text": [
161
- "tcmalloc: large alloc 1254047744 bytes == 0xb2b08000 @ 0x7f9e78632680 0x7f9e78653824 0x585b92 0x504d56 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
162
- "tcmalloc: large alloc 1254047744 bytes == 0xfd74e000 @ 0x7f9e78632680 0x7f9e78653824 0x590214 0x586f90 0x56e1f3 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
163
- "tcmalloc: large alloc 5016190976 bytes == 0x148b42000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
164
- "tcmalloc: large alloc 5019099136 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
165
- "tcmalloc: large alloc 5019811840 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
166
- "tcmalloc: large alloc 5024571392 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
167
- "tcmalloc: large alloc 5021097984 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
168
- "tcmalloc: large alloc 5022818304 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
169
- "tcmalloc: large alloc 5020794880 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
170
- "tcmalloc: large alloc 5019451392 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
171
- "tcmalloc: large alloc 5020565504 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
172
- "tcmalloc: large alloc 5012561920 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
173
- "tcmalloc: large alloc 5021835264 bytes == 0x5f6cba000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
174
- "tcmalloc: large alloc 5017436160 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n"
175
- ]
176
- }
177
- ],
178
- "source": [
179
- "# The metadata is too bog to load into memory at once, so chopping it into chunks\n",
180
- "chunk_size=1000000\n",
181
- "batch_no=1\n",
182
- "for chunk in pd.read_json(yfcc100m_metadata, orient=\"records\", lines=True,chunksize=chunk_size):\n",
183
- " chunk.to_csv('./chunks/chunk'+str(batch_no)+'.tsv', sep=\"\\t\", index=False)\n",
184
- " batch_no+=1"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": 25,
190
- "id": "46b2f083",
191
- "metadata": {},
192
- "outputs": [
193
- {
194
- "data": {
195
- "text/html": [
196
- "<div>\n",
197
- "<style scoped>\n",
198
- " .dataframe tbody tr th:only-of-type {\n",
199
- " vertical-align: middle;\n",
200
- " }\n",
201
- "\n",
202
- " .dataframe tbody tr th {\n",
203
- " vertical-align: top;\n",
204
- " }\n",
205
- "\n",
206
- " .dataframe thead th {\n",
207
- " text-align: right;\n",
208
- " }\n",
209
- "</style>\n",
210
- "<table border=\"1\" class=\"dataframe\">\n",
211
- " <thead>\n",
212
- " <tr style=\"text-align: right;\">\n",
213
- " <th></th>\n",
214
- " <th>photoid</th>\n",
215
- " <th>uid</th>\n",
216
- " <th>unickname</th>\n",
217
- " <th>datetaken</th>\n",
218
- " <th>dateuploaded</th>\n",
219
- " <th>capturedevice</th>\n",
220
- " <th>title</th>\n",
221
- " <th>description</th>\n",
222
- " <th>usertags</th>\n",
223
- " <th>machinetags</th>\n",
224
- " <th>...</th>\n",
225
- " <th>licenseurl</th>\n",
226
- " <th>serverid</th>\n",
227
- " <th>farmid</th>\n",
228
- " <th>secret</th>\n",
229
- " <th>secretoriginal</th>\n",
230
- " <th>ext</th>\n",
231
- " <th>marker</th>\n",
232
- " <th>key</th>\n",
233
- " <th>title_clean</th>\n",
234
- " <th>description_clean</th>\n",
235
- " </tr>\n",
236
- " </thead>\n",
237
- " <tbody>\n",
238
- " <tr>\n",
239
- " <th>0</th>\n",
240
- " <td>137943</td>\n",
241
- " <td>48600072071@N01</td>\n",
242
- " <td>doctor+paradox</td>\n",
243
- " <td>2004-08-01 18:13:06.0</td>\n",
244
- " <td>1091409186</td>\n",
245
- " <td>NaN</td>\n",
246
- " <td>A+Picture+Share%21</td>\n",
247
- " <td>Antenna</td>\n",
248
- " <td>cameraphone,cayugaheights,green,hydrant,ithaca...</td>\n",
249
- " <td>NaN</td>\n",
250
- " <td>...</td>\n",
251
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
252
- " <td>1</td>\n",
253
- " <td>1</td>\n",
254
- " <td>1650c7cdc6</td>\n",
255
- " <td>1650c7cdc6</td>\n",
256
- " <td>jpg</td>\n",
257
- " <td>0</td>\n",
258
- " <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
259
- " <td>A Picture Share!</td>\n",
260
- " <td>Antenna</td>\n",
261
- " </tr>\n",
262
- " <tr>\n",
263
- " <th>1</th>\n",
264
- " <td>1246361</td>\n",
265
- " <td>44124324682@N01</td>\n",
266
- " <td>mharrsch</td>\n",
267
- " <td>2004-11-03 23:04:02.0</td>\n",
268
- " <td>1099523042</td>\n",
269
- " <td>NaN</td>\n",
270
- " <td>An+ornate+Roman+urn</td>\n",
271
- " <td>Photographed+at+the+%3Ca+href%3D%22http%3A%2F%...</td>\n",
272
- " <td>ancient,baltimore,burial,death,empire,funeral,...</td>\n",
273
- " <td>NaN</td>\n",
274
- " <td>...</td>\n",
275
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
276
- " <td>1</td>\n",
277
- " <td>1</td>\n",
278
- " <td>cf37054610</td>\n",
279
- " <td>cf37054610</td>\n",
280
- " <td>jpg</td>\n",
281
- " <td>0</td>\n",
282
- " <td>d29f01b149167d683f9ddde464bb3db</td>\n",
283
- " <td>An ornate Roman urn</td>\n",
284
- " <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
285
- " </tr>\n",
286
- " <tr>\n",
287
- " <th>2</th>\n",
288
- " <td>1251599</td>\n",
289
- " <td>51035803024@N01</td>\n",
290
- " <td>bmitd67</td>\n",
291
- " <td>2004-10-30 17:09:32.0</td>\n",
292
- " <td>1099538888</td>\n",
293
- " <td>Canon+PowerShot+S30</td>\n",
294
- " <td>Jai+%26+Tara+on+the+Cumberland</td>\n",
295
- " <td>Another+trip+for+the+happy+couple.</td>\n",
296
- " <td>blue+heron,cumberland+river,jai,tara,tennessee</td>\n",
297
- " <td>NaN</td>\n",
298
- " <td>...</td>\n",
299
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
300
- " <td>1</td>\n",
301
- " <td>1</td>\n",
302
- " <td>4a4234e32c</td>\n",
303
- " <td>4a4234e32c</td>\n",
304
- " <td>jpg</td>\n",
305
- " <td>0</td>\n",
306
- " <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
307
- " <td>Jai &amp; Tara on the Cumberland</td>\n",
308
- " <td>Another trip for the happy couple.</td>\n",
309
- " </tr>\n",
310
- " <tr>\n",
311
- " <th>3</th>\n",
312
- " <td>2348587</td>\n",
313
- " <td>73621375@N00</td>\n",
314
- " <td>Thom+Watson</td>\n",
315
- " <td>2004-12-18 21:08:09.0</td>\n",
316
- " <td>1103497228</td>\n",
317
- " <td>SONY+DSC-W1</td>\n",
318
- " <td>Castle+gate+-+%22lite-brited%22</td>\n",
319
- " <td>Taken+at+the+Miracle+of+Lights+display+in+Cent...</td>\n",
320
- " <td>bullrunpark,castle,centreville,christmas,decor...</td>\n",
321
- " <td>NaN</td>\n",
322
- " <td>...</td>\n",
323
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
324
- " <td>2</td>\n",
325
- " <td>1</td>\n",
326
- " <td>7162c974c3</td>\n",
327
- " <td>7162c974c3</td>\n",
328
- " <td>jpg</td>\n",
329
- " <td>0</td>\n",
330
- " <td>d29ce96395848478b1e8396e44899</td>\n",
331
- " <td>Castle gate - \"lite-brited\"</td>\n",
332
- " <td>Taken at the Miracle of Lights display in Cent...</td>\n",
333
- " </tr>\n",
334
- " <tr>\n",
335
- " <th>4</th>\n",
336
- " <td>3516047</td>\n",
337
- " <td>48600072071@N01</td>\n",
338
- " <td>doctor+paradox</td>\n",
339
- " <td>2005-01-18 16:44:18.0</td>\n",
340
- " <td>1106084658</td>\n",
341
- " <td>NaN</td>\n",
342
- " <td>A+Picture+Share%21</td>\n",
343
- " <td>Tabular</td>\n",
344
- " <td>cameraphone,moblog,unfound</td>\n",
345
- " <td>NaN</td>\n",
346
- " <td>...</td>\n",
347
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
348
- " <td>3</td>\n",
349
- " <td>1</td>\n",
350
- " <td>663e0d8b3d</td>\n",
351
- " <td>663e0d8b3d</td>\n",
352
- " <td>jpg</td>\n",
353
- " <td>0</td>\n",
354
- " <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
355
- " <td>A Picture Share!</td>\n",
356
- " <td>Tabular</td>\n",
357
- " </tr>\n",
358
- " <tr>\n",
359
- " <th>...</th>\n",
360
- " <td>...</td>\n",
361
- " <td>...</td>\n",
362
- " <td>...</td>\n",
363
- " <td>...</td>\n",
364
- " <td>...</td>\n",
365
- " <td>...</td>\n",
366
- " <td>...</td>\n",
367
- " <td>...</td>\n",
368
- " <td>...</td>\n",
369
- " <td>...</td>\n",
370
- " <td>...</td>\n",
371
- " <td>...</td>\n",
372
- " <td>...</td>\n",
373
- " <td>...</td>\n",
374
- " <td>...</td>\n",
375
- " <td>...</td>\n",
376
- " <td>...</td>\n",
377
- " <td>...</td>\n",
378
- " <td>...</td>\n",
379
- " <td>...</td>\n",
380
- " <td>...</td>\n",
381
- " </tr>\n",
382
- " <tr>\n",
383
- " <th>999995</th>\n",
384
- " <td>4648651054</td>\n",
385
- " <td>24511045@N04</td>\n",
386
- " <td>mtfrazier</td>\n",
387
- " <td>2010-05-02 15:47:45.0</td>\n",
388
- " <td>1275083371</td>\n",
389
- " <td>Canon+EOS+50D</td>\n",
390
- " <td>U.S.+Navy+Blue+Angels%3A+2010</td>\n",
391
- " <td>2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri</td>\n",
392
- " <td>NaN</td>\n",
393
- " <td>NaN</td>\n",
394
- " <td>...</td>\n",
395
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
396
- " <td>4072</td>\n",
397
- " <td>5</td>\n",
398
- " <td>2d12d73fb0</td>\n",
399
- " <td>dd5856ea42</td>\n",
400
- " <td>jpg</td>\n",
401
- " <td>0</td>\n",
402
- " <td>60fa2911cb81eb25b356e9fee978aef</td>\n",
403
- " <td>U.S. Navy Blue Angels: 2010</td>\n",
404
- " <td>2 May 2010 Sunday St. Joseph, Missouri</td>\n",
405
- " </tr>\n",
406
- " <tr>\n",
407
- " <th>999996</th>\n",
408
- " <td>4652130996</td>\n",
409
- " <td>21963865@N04</td>\n",
410
- " <td>GRAB1.0</td>\n",
411
- " <td>2010-05-29 19:23:10.0</td>\n",
412
- " <td>1275200833</td>\n",
413
- " <td>SONY+DSLR-A230</td>\n",
414
- " <td>Attempts+on+Her+Life</td>\n",
415
- " <td>BAPA+1+production+of+Martin+Crimp%27s+Attempts...</td>\n",
416
- " <td>NaN</td>\n",
417
- " <td>NaN</td>\n",
418
- " <td>...</td>\n",
419
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
420
- " <td>4003</td>\n",
421
- " <td>5</td>\n",
422
- " <td>8889121579</td>\n",
423
- " <td>2f46599456</td>\n",
424
- " <td>jpg</td>\n",
425
- " <td>0</td>\n",
426
- " <td>60f5ef5ce4c2d24566226abebd67d4</td>\n",
427
- " <td>Attempts on Her Life</td>\n",
428
- " <td>BAPA 1 production of Martin Crimp's Attempts o...</td>\n",
429
- " </tr>\n",
430
- " <tr>\n",
431
- " <th>999997</th>\n",
432
- " <td>4652568339</td>\n",
433
- " <td>64025277@N00</td>\n",
434
- " <td>1Sock</td>\n",
435
- " <td>2010-05-13 15:38:37.0</td>\n",
436
- " <td>1275234267</td>\n",
437
- " <td>Canon+EOS+DIGITAL+REBEL+XT</td>\n",
438
- " <td>Carlsbad+Caverns+3</td>\n",
439
- " <td>%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%...</td>\n",
440
- " <td>carlsbad,carlsbad+caverns,cave,faa,new+mexico,...</td>\n",
441
- " <td>NaN</td>\n",
442
- " <td>...</td>\n",
443
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
444
- " <td>4010</td>\n",
445
- " <td>5</td>\n",
446
- " <td>0a1808a69e</td>\n",
447
- " <td>cf6d348e3d</td>\n",
448
- " <td>jpg</td>\n",
449
- " <td>0</td>\n",
450
- " <td>60f029482d1d1028fda5281daf498f</td>\n",
451
- " <td>Carlsbad Caverns 3</td>\n",
452
- " <td>♥♥♥♥♥♥♥ Interested in purchasing this photogra...</td>\n",
453
- " </tr>\n",
454
- " <tr>\n",
455
- " <th>999998</th>\n",
456
- " <td>4653110895</td>\n",
457
- " <td>20483509@N00</td>\n",
458
- " <td>subberculture</td>\n",
459
- " <td>2010-05-30 15:37:05.0</td>\n",
460
- " <td>1275245596</td>\n",
461
- " <td>Canon+DIGITAL+IXUS+40</td>\n",
462
- " <td>Want</td>\n",
463
- " <td>Isn%27t+that+gorgeous%3F</td>\n",
464
- " <td>2010,edinburgh+museum,may,phonebox,wood</td>\n",
465
- " <td>NaN</td>\n",
466
- " <td>...</td>\n",
467
- " <td>http://creativecommons.org/licenses/by-sa/2.0/</td>\n",
468
- " <td>4066</td>\n",
469
- " <td>5</td>\n",
470
- " <td>77c3b3a254</td>\n",
471
- " <td>c4697e1511</td>\n",
472
- " <td>jpg</td>\n",
473
- " <td>0</td>\n",
474
- " <td>60f72775f433cf8de3efaeb431866153</td>\n",
475
- " <td>Want</td>\n",
476
- " <td>Isn't that gorgeous?</td>\n",
477
- " </tr>\n",
478
- " <tr>\n",
479
- " <th>999999</th>\n",
480
- " <td>4655503987</td>\n",
481
- " <td>8457193@N07</td>\n",
482
- " <td>zackojones</td>\n",
483
- " <td>2010-05-30 15:34:58.0</td>\n",
484
- " <td>1275310230</td>\n",
485
- " <td>Canon+EOS+7D</td>\n",
486
- " <td>Summertime</td>\n",
487
- " <td>You+gotta+love+it%21</td>\n",
488
- " <td>georgia,savannah,united+states,us</td>\n",
489
- " <td>NaN</td>\n",
490
- " <td>...</td>\n",
491
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
492
- " <td>4043</td>\n",
493
- " <td>5</td>\n",
494
- " <td>caff543bfe</td>\n",
495
- " <td>f60952ac4d</td>\n",
496
- " <td>jpg</td>\n",
497
- " <td>0</td>\n",
498
- " <td>60f687e11b913bce461e9525d8047e0</td>\n",
499
- " <td>Summertime</td>\n",
500
- " <td>You gotta love it!</td>\n",
501
- " </tr>\n",
502
- " </tbody>\n",
503
- "</table>\n",
504
- "<p>1000000 rows × 26 columns</p>\n",
505
- "</div>"
506
- ],
507
- "text/plain": [
508
- " photoid uid unickname datetaken \\\n",
509
- "0 137943 48600072071@N01 doctor+paradox 2004-08-01 18:13:06.0 \n",
510
- "1 1246361 44124324682@N01 mharrsch 2004-11-03 23:04:02.0 \n",
511
- "2 1251599 51035803024@N01 bmitd67 2004-10-30 17:09:32.0 \n",
512
- "3 2348587 73621375@N00 Thom+Watson 2004-12-18 21:08:09.0 \n",
513
- "4 3516047 48600072071@N01 doctor+paradox 2005-01-18 16:44:18.0 \n",
514
- "... ... ... ... ... \n",
515
- "999995 4648651054 24511045@N04 mtfrazier 2010-05-02 15:47:45.0 \n",
516
- "999996 4652130996 21963865@N04 GRAB1.0 2010-05-29 19:23:10.0 \n",
517
- "999997 4652568339 64025277@N00 1Sock 2010-05-13 15:38:37.0 \n",
518
- "999998 4653110895 20483509@N00 subberculture 2010-05-30 15:37:05.0 \n",
519
- "999999 4655503987 8457193@N07 zackojones 2010-05-30 15:34:58.0 \n",
520
- "\n",
521
- " dateuploaded capturedevice \\\n",
522
- "0 1091409186 NaN \n",
523
- "1 1099523042 NaN \n",
524
- "2 1099538888 Canon+PowerShot+S30 \n",
525
- "3 1103497228 SONY+DSC-W1 \n",
526
- "4 1106084658 NaN \n",
527
- "... ... ... \n",
528
- "999995 1275083371 Canon+EOS+50D \n",
529
- "999996 1275200833 SONY+DSLR-A230 \n",
530
- "999997 1275234267 Canon+EOS+DIGITAL+REBEL+XT \n",
531
- "999998 1275245596 Canon+DIGITAL+IXUS+40 \n",
532
- "999999 1275310230 Canon+EOS+7D \n",
533
- "\n",
534
- " title \\\n",
535
- "0 A+Picture+Share%21 \n",
536
- "1 An+ornate+Roman+urn \n",
537
- "2 Jai+%26+Tara+on+the+Cumberland \n",
538
- "3 Castle+gate+-+%22lite-brited%22 \n",
539
- "4 A+Picture+Share%21 \n",
540
- "... ... \n",
541
- "999995 U.S.+Navy+Blue+Angels%3A+2010 \n",
542
- "999996 Attempts+on+Her+Life \n",
543
- "999997 Carlsbad+Caverns+3 \n",
544
- "999998 Want \n",
545
- "999999 Summertime \n",
546
- "\n",
547
- " description \\\n",
548
- "0 Antenna \n",
549
- "1 Photographed+at+the+%3Ca+href%3D%22http%3A%2F%... \n",
550
- "2 Another+trip+for+the+happy+couple. \n",
551
- "3 Taken+at+the+Miracle+of+Lights+display+in+Cent... \n",
552
- "4 Tabular \n",
553
- "... ... \n",
554
- "999995 2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri \n",
555
- "999996 BAPA+1+production+of+Martin+Crimp%27s+Attempts... \n",
556
- "999997 %E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%... \n",
557
- "999998 Isn%27t+that+gorgeous%3F \n",
558
- "999999 You+gotta+love+it%21 \n",
559
- "\n",
560
- " usertags machinetags ... \\\n",
561
- "0 cameraphone,cayugaheights,green,hydrant,ithaca... NaN ... \n",
562
- "1 ancient,baltimore,burial,death,empire,funeral,... NaN ... \n",
563
- "2 blue+heron,cumberland+river,jai,tara,tennessee NaN ... \n",
564
- "3 bullrunpark,castle,centreville,christmas,decor... NaN ... \n",
565
- "4 cameraphone,moblog,unfound NaN ... \n",
566
- "... ... ... ... \n",
567
- "999995 NaN NaN ... \n",
568
- "999996 NaN NaN ... \n",
569
- "999997 carlsbad,carlsbad+caverns,cave,faa,new+mexico,... NaN ... \n",
570
- "999998 2010,edinburgh+museum,may,phonebox,wood NaN ... \n",
571
- "999999 georgia,savannah,united+states,us NaN ... \n",
572
- "\n",
573
- " licenseurl serverid farmid \\\n",
574
- "0 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
575
- "1 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
576
- "2 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
577
- "3 http://creativecommons.org/licenses/by-nc-sa/2.0/ 2 1 \n",
578
- "4 http://creativecommons.org/licenses/by-nc-sa/2.0/ 3 1 \n",
579
- "... ... ... ... \n",
580
- "999995 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4072 5 \n",
581
- "999996 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4003 5 \n",
582
- "999997 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4010 5 \n",
583
- "999998 http://creativecommons.org/licenses/by-sa/2.0/ 4066 5 \n",
584
- "999999 http://creativecommons.org/licenses/by-nc-sa/2.0/ 4043 5 \n",
585
- "\n",
586
- " secret secretoriginal ext marker \\\n",
587
- "0 1650c7cdc6 1650c7cdc6 jpg 0 \n",
588
- "1 cf37054610 cf37054610 jpg 0 \n",
589
- "2 4a4234e32c 4a4234e32c jpg 0 \n",
590
- "3 7162c974c3 7162c974c3 jpg 0 \n",
591
- "4 663e0d8b3d 663e0d8b3d jpg 0 \n",
592
- "... ... ... ... ... \n",
593
- "999995 2d12d73fb0 dd5856ea42 jpg 0 \n",
594
- "999996 8889121579 2f46599456 jpg 0 \n",
595
- "999997 0a1808a69e cf6d348e3d jpg 0 \n",
596
- "999998 77c3b3a254 c4697e1511 jpg 0 \n",
597
- "999999 caff543bfe f60952ac4d jpg 0 \n",
598
- "\n",
599
- " key title_clean \\\n",
600
- "0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
601
- "1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
602
- "2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
603
- "3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
604
- "4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
605
- "... ... ... \n",
606
- "999995 60fa2911cb81eb25b356e9fee978aef U.S. Navy Blue Angels: 2010 \n",
607
- "999996 60f5ef5ce4c2d24566226abebd67d4 Attempts on Her Life \n",
608
- "999997 60f029482d1d1028fda5281daf498f Carlsbad Caverns 3 \n",
609
- "999998 60f72775f433cf8de3efaeb431866153 Want \n",
610
- "999999 60f687e11b913bce461e9525d8047e0 Summertime \n",
611
- "\n",
612
- " description_clean \n",
613
- "0 Antenna \n",
614
- "1 Photographed at the Walters Art Museum, Baltim... \n",
615
- "2 Another trip for the happy couple. \n",
616
- "3 Taken at the Miracle of Lights display in Cent... \n",
617
- "4 Tabular \n",
618
- "... ... \n",
619
- "999995 2 May 2010 Sunday St. Joseph, Missouri \n",
620
- "999996 BAPA 1 production of Martin Crimp's Attempts o... \n",
621
- "999997 ♥♥♥♥♥♥♥ Interested in purchasing this photogra... \n",
622
- "999998 Isn't that gorgeous? \n",
623
- "999999 You gotta love it! \n",
624
- "\n",
625
- "[1000000 rows x 26 columns]"
626
- ]
627
- },
628
- "execution_count": 25,
629
- "metadata": {},
630
- "output_type": "execute_result"
631
- }
632
- ],
633
- "source": [
634
- "# looking up at a chunk\n",
635
- "pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")"
636
- ]
637
- },
638
- {
639
- "cell_type": "code",
640
- "execution_count": 98,
641
- "id": "c51c5597",
642
- "metadata": {},
643
- "outputs": [
644
- {
645
- "data": {
646
- "text/html": [
647
- "<div>\n",
648
- "<style scoped>\n",
649
- " .dataframe tbody tr th:only-of-type {\n",
650
- " vertical-align: middle;\n",
651
- " }\n",
652
- "\n",
653
- " .dataframe tbody tr th {\n",
654
- " vertical-align: top;\n",
655
- " }\n",
656
- "\n",
657
- " .dataframe thead th {\n",
658
- " text-align: right;\n",
659
- " }\n",
660
- "</style>\n",
661
- "<table border=\"1\" class=\"dataframe\">\n",
662
- " <thead>\n",
663
- " <tr style=\"text-align: right;\">\n",
664
- " <th></th>\n",
665
- " <th>key</th>\n",
666
- " <th>title_clean</th>\n",
667
- " <th>description_clean</th>\n",
668
- " <th>ext</th>\n",
669
- " </tr>\n",
670
- " </thead>\n",
671
- " <tbody>\n",
672
- " <tr>\n",
673
- " <th>0</th>\n",
674
- " <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
675
- " <td>A Picture Share!</td>\n",
676
- " <td>Antenna</td>\n",
677
- " <td>jpg</td>\n",
678
- " </tr>\n",
679
- " <tr>\n",
680
- " <th>1</th>\n",
681
- " <td>d29f01b149167d683f9ddde464bb3db</td>\n",
682
- " <td>An ornate Roman urn</td>\n",
683
- " <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
684
- " <td>jpg</td>\n",
685
- " </tr>\n",
686
- " <tr>\n",
687
- " <th>2</th>\n",
688
- " <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
689
- " <td>Jai &amp; Tara on the Cumberland</td>\n",
690
- " <td>Another trip for the happy couple.</td>\n",
691
- " <td>jpg</td>\n",
692
- " </tr>\n",
693
- " <tr>\n",
694
- " <th>3</th>\n",
695
- " <td>d29ce96395848478b1e8396e44899</td>\n",
696
- " <td>Castle gate - \"lite-brited\"</td>\n",
697
- " <td>Taken at the Miracle of Lights display in Cent...</td>\n",
698
- " <td>jpg</td>\n",
699
- " </tr>\n",
700
- " <tr>\n",
701
- " <th>4</th>\n",
702
- " <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
703
- " <td>A Picture Share!</td>\n",
704
- " <td>Tabular</td>\n",
705
- " <td>jpg</td>\n",
706
- " </tr>\n",
707
- " </tbody>\n",
708
- "</table>\n",
709
- "</div>"
710
- ],
711
- "text/plain": [
712
- " key title_clean \\\n",
713
- "0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
714
- "1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
715
- "2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
716
- "3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
717
- "4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
718
- "\n",
719
- " description_clean ext \n",
720
- "0 Antenna jpg \n",
721
- "1 Photographed at the Walters Art Museum, Baltim... jpg \n",
722
- "2 Another trip for the happy couple. jpg \n",
723
- "3 Taken at the Miracle of Lights display in Cent... jpg \n",
724
- "4 Tabular jpg "
725
- ]
726
- },
727
- "execution_count": 98,
728
- "metadata": {},
729
- "output_type": "execute_result"
730
- }
731
- ],
732
- "source": [
733
- "# Looking at a chunk with only the relevant columns that we need\n",
734
- "df = pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
735
- "df.head()"
736
- ]
737
- },
738
- {
739
- "cell_type": "markdown",
740
- "id": "cc1668f8",
741
- "metadata": {},
742
- "source": [
743
- "### Grabbing each chunks from the folder, cleaning it up, only taking the entries which image exist and appending it to the global df"
744
- ]
745
- },
746
- {
747
- "cell_type": "code",
748
- "execution_count": null,
749
- "id": "abbcccf3",
750
- "metadata": {},
751
- "outputs": [],
752
- "source": [
753
- "# the function that helps us to decide whether an image with certain id exists in storage, we only take the ones that we have the images for\n",
754
- "def image_exists(item):\n",
755
- " name, _, _, ext, _ = item\n",
756
- " root=str(yfcc100m_images)\n",
757
- " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".\"+ext)\n",
758
- " if image_path.exists():\n",
759
- " return True\n",
760
- " else:\n",
761
- " return None"
762
- ]
763
- },
764
- {
765
- "cell_type": "code",
766
- "execution_count": 86,
767
- "id": "44fa86ab",
768
- "metadata": {},
769
- "outputs": [],
770
- "source": [
771
- "# This cell does it all, grabs each chunk, cleans it up based on image existing condition, etc.\n",
772
- "global_df = pd.DataFrame()\n",
773
- "chunks_dir = \"./chunks\"\n",
774
- "for filename in os.listdir(chunks_dir):\n",
775
- " df = pd.read_csv(f\"./chunks/{str(filename)}\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
776
- " df['caption'] = df[\"title_clean\"]+\". \"+df['description_clean']\n",
777
- " df['is_exist'] = df.apply(image_exists, axis=1)\n",
778
- " df = df.dropna()[[\"key\", \"caption\"]]\n",
779
- " df.columns = ['image_file', 'caption']\n",
780
- " global_df = global_df.append(df, ignore_index=True)"
781
- ]
782
- },
783
- {
784
- "cell_type": "code",
785
- "execution_count": 89,
786
- "id": "45024fdc",
787
- "metadata": {},
788
- "outputs": [],
789
- "source": [
790
- "# saving the tsv to disk\n",
791
- "global_df.to_csv('./chunks/YFCC_subset_clean.tsv', sep=\"\\t\", index=False)"
792
- ]
793
- },
794
- {
795
- "cell_type": "code",
796
- "execution_count": 101,
797
- "id": "dca4eb73",
798
- "metadata": {},
799
- "outputs": [],
800
- "source": [
801
- "# loading the tsv from disk (for explicitness, also my electricity was gone, glad it happened after I saved to the disk :( )\n",
802
- "\n",
803
- "dataset = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")"
804
- ]
805
- },
806
- {
807
- "cell_type": "code",
808
- "execution_count": 153,
809
- "id": "a511264a",
810
- "metadata": {},
811
- "outputs": [],
812
- "source": [
813
- "\"\"\"\n",
814
- "Luke Melas-Kyriazi's dataset.py's modified version for YFCC\n",
815
- "\"\"\"\n",
816
- "import warnings\n",
817
- "from typing import Optional, Callable\n",
818
- "from pathlib import Path\n",
819
- "import numpy as np\n",
820
- "import torch\n",
821
- "import pandas as pd\n",
822
- "from torch.utils.data import Dataset\n",
823
- "from torchvision.datasets.folder import default_loader\n",
824
- "from PIL import ImageFile\n",
825
- "from PIL.Image import DecompressionBombWarning\n",
826
- "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
827
- "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
828
- "warnings.filterwarnings(\"ignore\", category=DecompressionBombWarning)\n",
829
- "\n",
830
- "\n",
831
- "class CaptionDataset(Dataset):\n",
832
- " \"\"\"\n",
833
- " A PyTorch Dataset class for (image, texts) tasks. Note that this dataset \n",
834
- " returns the raw text rather than tokens. This is done on purpose, because\n",
835
- " it's easy to tokenize a batch of text after loading it from this dataset.\n",
836
- " \"\"\"\n",
837
- "\n",
838
- " def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None, \n",
839
- " image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',\n",
840
- " include_captions: bool = True):\n",
841
- " \"\"\"\n",
842
- " :param images_root: folder where images are stored\n",
843
- " :param captions_path: path to csv that maps image filenames to captions\n",
844
- " :param image_transform: image transform pipeline\n",
845
- " :param text_transform: image transform pipeline\n",
846
- " :param image_transform_type: image transform type, either `torchvision` or `albumentations`\n",
847
- " :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.\n",
848
- " \"\"\"\n",
849
- "\n",
850
- " # Base path for images\n",
851
- " self.images_root = Path(images_root)\n",
852
- "\n",
853
- " # Load captions as DataFrame\n",
854
- " self.captions = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")\n",
855
- " self.captions['image_file'] = self.captions['image_file'].astype(str)\n",
856
- "\n",
857
- " # PyTorch transformation pipeline for the image (normalizing, etc.)\n",
858
- " self.text_transform = text_transform\n",
859
- " self.image_transform = image_transform\n",
860
- " self.image_transform_type = image_transform_type.lower()\n",
861
- " assert self.image_transform_type in ['torchvision', 'albumentations']\n",
862
- "\n",
863
- " # Total number of datapoints\n",
864
- " self.size = len(self.captions)\n",
865
- "\n",
866
- " # Return image+captions or just images\n",
867
- " self.include_captions = include_captions\n",
868
- " \n",
869
- " def image_exists(item):\n",
870
- " name, caption = item\n",
871
- " root=str(self.images_root)\n",
872
- " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
873
- "\n",
874
- " return image_path.exists()\n",
875
- "\n",
876
- " def verify_that_all_images_exist(self):\n",
877
- " for image_file in self.captions['image_file']:\n",
878
- " if not image_exists:\n",
879
- " print(f'file does not exist: {p}')\n",
880
- "\n",
881
- " def _get_raw_image(self, i):\n",
882
- " name = self.captions.iloc[i]['image_file']\n",
883
- " image_path = (Path(self.images_root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
884
- " image = default_loader(image_path)\n",
885
- " return image\n",
886
- "\n",
887
- " def _get_raw_text(self, i):\n",
888
- " return self.captions.iloc[i]['caption']\n",
889
- "\n",
890
- " def __getitem__(self, i):\n",
891
- " image = self._get_raw_image(i)\n",
892
- " caption = self._get_raw_text(i)\n",
893
- " if self.image_transform is not None:\n",
894
- " if self.image_transform_type == 'torchvision':\n",
895
- " image = self.image_transform(image)\n",
896
- " elif self.image_transform_type == 'albumentations':\n",
897
- " image = self.image_transform(image=np.array(image))['image']\n",
898
- " else:\n",
899
- " raise NotImplementedError(f\"{self.image_transform_type=}\")\n",
900
- " return {'image': image, 'text': caption} if self.include_captions else image\n",
901
- "\n",
902
- " def __len__(self):\n",
903
- " return self.size\n",
904
- "\n",
905
- "\n",
906
- "if __name__ == \"__main__\":\n",
907
- " import albumentations as A\n",
908
- " from albumentations.pytorch import ToTensorV2\n",
909
- " from transformers import AutoTokenizer\n",
910
- " \n",
911
- "\n",
912
- " images_root = \"/home/khali/TPU-Test/YFCC100M_OpenAI_subset/data/data/images\"\n",
913
- " captions_path = './YFCC_subset_clean.tsv'\n",
914
- " image_size = 256\n",
915
- " \n",
916
- " # Create transforms\n",
917
- " def image_transform(image):\n",
918
- " s = min(image.size)\n",
919
- " r = image_size / s\n",
920
- " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
921
- " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
922
- " image = TF.center_crop(image, output_size = 2 * [image_size])\n",
923
- " image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
924
- " image = image.permute(0, 2, 3, 1).numpy()\n",
925
- " return image\n",
926
- " \n",
927
- " # Create dataset\n",
928
- " dataset = CaptionDataset(\n",
929
- " images_root=images_root,\n",
930
- " captions_path=captions_path,\n",
931
- " image_transform=image_transform,\n",
932
- " image_transform_type='torchvision',\n",
933
- " include_captions=False\n",
934
- " )"
935
- ]
936
- },
937
- {
938
- "cell_type": "code",
939
- "execution_count": 155,
940
- "id": "cc922704",
941
- "metadata": {},
942
- "outputs": [
943
- {
944
- "data": {
945
- "text/plain": [
946
- "2483316"
947
- ]
948
- },
949
- "execution_count": 155,
950
- "metadata": {},
951
- "output_type": "execute_result"
952
- }
953
- ],
954
- "source": [
955
- "len(dataset)"
956
- ]
957
- },
958
- {
959
- "cell_type": "code",
960
- "execution_count": 156,
961
- "id": "6e47ba46",
962
- "metadata": {},
963
- "outputs": [],
964
- "source": [
965
- "dataloader = DataLoader(dataset, batch_size=32, num_workers=4)"
966
- ]
967
- },
968
- {
969
- "cell_type": "code",
970
- "execution_count": 1,
971
- "id": "c8a130eb",
972
- "metadata": {},
973
- "outputs": [],
974
- "source": [
975
- "# looking at a batch\n",
976
- "next(iter(dataloader))"
977
- ]
978
- },
979
- {
980
- "cell_type": "code",
981
- "execution_count": null,
982
- "id": "c192fd44",
983
- "metadata": {},
984
- "outputs": [],
985
- "source": [
986
- "# import matplotlib.pyplot as plt\n",
987
- "# for tensor_image, _ in dataloader:\n",
988
- "# print(tensor_image)\n",
989
- "# plt.imshow(tensor_image.permute(1, 2, 0))\n",
990
- "# break"
991
- ]
992
- },
993
- {
994
- "cell_type": "markdown",
995
- "id": "62ad01c3",
996
- "metadata": {},
997
- "source": [
998
- "## Encoding"
999
- ]
1000
- },
1001
- {
1002
- "cell_type": "code",
1003
- "execution_count": 158,
1004
- "id": "88f36d0b",
1005
- "metadata": {},
1006
- "outputs": [],
1007
- "source": [
1008
- "def encode(model, batch):\n",
1009
- "# print(\"jitting encode function\")\n",
1010
- " _, indices = model.encode(batch)\n",
1011
- " return indices"
1012
- ]
1013
- },
1014
- {
1015
- "cell_type": "code",
1016
- "execution_count": 160,
1017
- "id": "1f35f0cb",
1018
- "metadata": {},
1019
- "outputs": [],
1020
- "source": [
1021
- "def superbatch_generator(dataloader, num_tpus):\n",
1022
- " iter_loader = iter(dataloader)\n",
1023
- " for batch in iter_loader:\n",
1024
- " superbatch = [batch.squeeze(1)]\n",
1025
- " try:\n",
1026
- " for b in range(num_tpus-1):\n",
1027
- " batch = next(iter_loader)\n",
1028
- " if batch is None:\n",
1029
- " break\n",
1030
- " # Skip incomplete last batch\n",
1031
- " if batch.shape[0] == dataloader.batch_size:\n",
1032
- " superbatch.append(batch.squeeze(1))\n",
1033
- " except StopIteration:\n",
1034
- " pass\n",
1035
- " superbatch = torch.stack(superbatch, axis=0)\n",
1036
- " yield superbatch"
1037
- ]
1038
- },
1039
- {
1040
- "cell_type": "code",
1041
- "execution_count": 170,
1042
- "id": "2210705b",
1043
- "metadata": {},
1044
- "outputs": [],
1045
- "source": [
1046
- "import os\n",
1047
- "\n",
1048
- "def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
1049
- " if os.path.isfile(output_tsv):\n",
1050
- " print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
1051
- " return\n",
1052
- " \n",
1053
- " num_tpus = 8 \n",
1054
- " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
1055
- " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
1056
- " \n",
1057
- " p_encoder = pmap(lambda batch: encode(model, batch))\n",
1058
- "\n",
1059
- " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
1060
- " # We keep the file open to prevent excessive file seeks.\n",
1061
- " with open(output_tsv, \"w\") as file:\n",
1062
- " iterations = len(dataset) // (batch_size * num_tpus)\n",
1063
- " for n in tqdm(range(iterations)):\n",
1064
- " superbatch = next(superbatches)\n",
1065
- " encoded = p_encoder(superbatch.numpy())\n",
1066
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
1067
- "\n",
1068
- " # Extract fields from the dataset internal `captions` property, and save to disk\n",
1069
- " start_index = n * batch_size * num_tpus\n",
1070
- " end_index = (n+1) * batch_size * num_tpus\n",
1071
- " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
1072
- " captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
1073
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
1074
- " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
1075
- " batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)"
1076
- ]
1077
- },
1078
- {
1079
- "cell_type": "code",
1080
- "execution_count": 171,
1081
- "id": "7704863d",
1082
- "metadata": {},
1083
- "outputs": [
1084
- {
1085
- "name": "stderr",
1086
- "output_type": "stream",
1087
- "text": [
1088
- "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4850/4850 [2:27:51<00:00, 1.83s/it]\n"
1089
- ]
1090
- }
1091
- ],
1092
- "source": [
1093
- "encode_captioned_dataset(dataset, yfcc100m_output, batch_size=64, num_workers=16)"
1094
- ]
1095
- },
1096
- {
1097
- "cell_type": "markdown",
1098
- "id": "8953dd84",
1099
- "metadata": {},
1100
- "source": [
1101
- "----"
1102
- ]
1103
- }
1104
- ],
1105
- "metadata": {
1106
- "interpreter": {
1107
- "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
1108
- },
1109
- "kernelspec": {
1110
- "display_name": "Python 3 (ipykernel)",
1111
- "language": "python",
1112
- "name": "python3"
1113
- },
1114
- "language_info": {
1115
- "codemirror_mode": {
1116
- "name": "ipython",
1117
- "version": 3
1118
- },
1119
- "file_extension": ".py",
1120
- "mimetype": "text/x-python",
1121
- "name": "python",
1122
- "nbconvert_exporter": "python",
1123
- "pygments_lexer": "ipython3",
1124
- "version": "3.8.10"
1125
- }
1126
- },
1127
- "nbformat": 4,
1128
- "nbformat_minor": 5
1129
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/encoding/vqgan-jax-encoding.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
dev/environment.yaml DELETED
@@ -1,10 +0,0 @@
1
- name: dalle
2
- channels:
3
- - defaults
4
- dependencies:
5
- - python=3.9.5
6
- - pip=21.1.3
7
- - ipython=7.22.0
8
- - cudatoolkit
9
- - pip:
10
- - -r requirements.txt
 
 
 
 
 
 
 
 
 
 
 
dev/requirements.txt DELETED
@@ -1,14 +0,0 @@
1
- requests
2
- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
3
- jax[tpu]>=0.2.16
4
- transformers
5
- datasets
6
- flax
7
- jupyter
8
- wandb
9
- nltk
10
- optax
11
- git+https://github.com/patil-suraj/vqgan-jax.git@610d842dd33c739325a944102ed33acc07692dd5
12
-
13
- # Inference
14
- ftfy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/seq2seq/do_big_run.sh DELETED
@@ -1,21 +0,0 @@
1
- python run_seq2seq_flax.py \
2
- --dataset_repo_or_path dalle-mini/encoded \
3
- --train_file **/train/*/*.jsonl \
4
- --validation_file **/valid/*/*.jsonl \
5
- --len_train 129847128 \
6
- --len_eval 157312 \
7
- --eval_steps 1000 \
8
- --streaming \
9
- --normalize_text \
10
- --output_dir output \
11
- --per_device_train_batch_size 56 \
12
- --per_device_eval_batch_size 56 \
13
- --preprocessing_num_workers 80 \
14
- --warmup_steps 5000 \
15
- --gradient_accumulation_steps 8 \
16
- --do_train \
17
- --do_eval \
18
- --adafactor \
19
- --num_train_epochs 6 \
20
- --log_model \
21
- --learning_rate 0.005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/seq2seq/do_small_run.sh DELETED
@@ -1,19 +0,0 @@
1
- python run_seq2seq_flax.py \
2
- --dataset_repo_or_path dalle-mini/encoded \
3
- --train_file **/train/CC3M/*.jsonl \
4
- --validation_file **/valid/*/*.jsonl \
5
- --len_train 129847128 \
6
- --len_eval 157312 \
7
- --streaming \
8
- --output_dir output \
9
- --per_device_train_batch_size 16 \
10
- --per_device_eval_batch_size 16 \
11
- --preprocessing_num_workers 80 \
12
- --warmup_steps 125 \
13
- --gradient_accumulation_steps 8 \
14
- --do_train \
15
- --do_eval \
16
- --adafactor \
17
- --num_train_epochs 1 \
18
- --max_train_samples 10000 \
19
- --learning_rate 0.005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/vqgan/JAX_VQGAN_f16_16384_Reconstruction.ipynb DELETED
The diff for this file is too large to render. See raw diff