boris commited on
Commit
cbfa520
2 Parent(s): db7d521 741bf32

Merge pull request #112 from borisdayma/’cleanup’

Browse files
.github/workflows/style.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Lint
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ lint:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v2
14
+ - uses: psf/black@stable
15
+ - uses: actions/setup-python@v2
16
+ with:
17
+ python-version: 3.9
18
+ - name: Install requirements
19
+ run: pip install ".[dev]"
20
+ - uses: jamescurtin/isort-action@master
Makefile ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .PHONY: style
2
+
3
+ style:
4
+ black .
5
+ isort .
README.md CHANGED
@@ -4,8 +4,8 @@ emoji: 🥑
4
  colorFrom: yellow
5
  colorTo: green
6
  sdk: streamlit
7
- app_file: app/app.py
8
- pinned: false
9
  ---
10
 
11
  # DALL·E Mini
@@ -28,7 +28,9 @@ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini
28
 
29
  ### Dependencies Installation
30
 
31
- For development, use [`dev/requirements.txt`](dev/requirements.txt) or [`dev/environment.yaml`](dev/environment.yaml).
 
 
32
 
33
  ### Training of VQGAN
34
 
@@ -42,15 +44,15 @@ Use [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax).
42
 
43
  ### Training of Seq2Seq
44
 
45
- Refer to [`dev/seq2seq`](dev/seq2seq) folder.
46
 
47
  You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
48
 
49
  ### Inference Pipeline
50
 
51
- To generate sample predictions and understand the inference pipeline step by step, refer to [`dev/inference/inference_pipeline.ipynb`](dev/inference/inference_pipeline.ipynb).
52
 
53
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/dev/inference/inference_pipeline.ipynb)
54
 
55
  ## FAQ
56
 
 
4
  colorFrom: yellow
5
  colorTo: green
6
  sdk: streamlit
7
+ app_file: app/streamlit/app.py
8
+ pinned: True
9
  ---
10
 
11
  # DALL·E Mini
 
28
 
29
  ### Dependencies Installation
30
 
31
+ For inference only, use `pip install git+https://github.com/borisdayma/dalle-mini.git`.
32
+
33
+ For development, clone the repo and use `pip install -e ".[dev]"`.
34
 
35
  ### Training of VQGAN
36
 
 
44
 
45
  ### Training of Seq2Seq
46
 
47
+ Use [`tools/train/train.py`](tools/train/train.py).
48
 
49
  You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
50
 
51
  ### Inference Pipeline
52
 
53
+ To generate sample predictions and understand the inference pipeline step by step, refer to [`tools/inference/inference_pipeline.ipynb`](tools/inference/inference_pipeline.ipynb).
54
 
55
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb)
56
 
57
  ## FAQ
58
 
app/gradio/app_gradio.py CHANGED
@@ -2,51 +2,62 @@
2
  # coding: utf-8
3
 
4
  # Uncomment to run on cpu
5
- #import os
6
- #os.environ["JAX_PLATFORM_NAME"] = "cpu"
7
 
8
  import random
9
 
 
10
  import jax
11
- import flax.linen as nn
12
- from flax.training.common_utils import shard
13
- from flax.jax_utils import replicate, unreplicate
14
-
15
- from transformers import BartTokenizer, FlaxBartForConditionalGeneration
16
-
17
- from PIL import Image
18
  import numpy as np
19
- import matplotlib.pyplot as plt
 
 
20
 
 
 
21
  from vqgan_jax.modeling_flax_vqgan import VQModel
 
22
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
23
 
24
- # ## CLIP Scoring
25
- from transformers import CLIPProcessor, FlaxCLIPModel
26
 
27
- import gradio as gr
 
28
 
29
- from dalle_mini.helpers import captioned_strip
 
 
 
 
30
 
31
 
32
- DALLE_REPO = 'flax-community/dalle-mini'
33
- DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
 
 
 
 
34
 
35
- VQGAN_REPO = 'flax-community/vqgan_f16_16384'
36
- VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'
 
 
 
 
 
37
 
38
- tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
39
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
40
- vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
41
 
42
  def custom_to_pil(x):
43
- x = np.clip(x, 0., 1.)
44
- x = (255*x).astype(np.uint8)
45
  x = Image.fromarray(x)
46
  if not x.mode == "RGB":
47
  x = x.convert("RGB")
48
  return x
49
 
 
50
  def generate(input, rng, params):
51
  return model.generate(
52
  **input,
@@ -59,9 +70,11 @@ def generate(input, rng, params):
59
  params=params,
60
  )
61
 
 
62
  def get_images(indices, params):
63
  return vqgan.decode_code(indices, params=params)
64
 
 
65
  p_generate = jax.pmap(generate, "batch")
66
  p_get_images = jax.pmap(get_images, "batch")
67
 
@@ -73,9 +86,16 @@ print("Initialize FlaxCLIPModel")
73
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
74
  print("Initialize CLIPProcessor")
75
 
 
76
  def hallucinate(prompt, num_images=64):
77
  prompt = [prompt] * jax.device_count()
78
- inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
 
 
 
 
 
 
79
  inputs = shard(inputs)
80
 
81
  all_images = []
@@ -92,6 +112,7 @@ def hallucinate(prompt, num_images=64):
92
  all_images.append(custom_to_pil(image))
93
  return all_images
94
 
 
95
  def clip_top_k(prompt, images, k=8):
96
  inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
97
  outputs = clip(**inputs)
@@ -99,24 +120,29 @@ def clip_top_k(prompt, images, k=8):
99
  scores = np.array(logits[0]).argsort()[-k:][::-1]
100
  return [images[score] for score in scores]
101
 
 
102
  def compose_predictions(images, caption=None):
103
  increased_h = 0 if caption is None else 48
104
  w, h = images[0].size[0], images[0].size[1]
105
- img = Image.new("RGB", (len(images)*w, h + increased_h))
106
  for i, img_ in enumerate(images):
107
- img.paste(img_, (i*w, increased_h))
108
 
109
  if caption is not None:
110
  draw = ImageDraw.Draw(img)
111
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
112
- draw.text((20, 3), caption, (255,255,255), font=font)
 
 
113
  return img
114
 
 
115
  def top_k_predictions(prompt, num_candidates=32, k=8):
116
  images = hallucinate(prompt, num_images=num_candidates)
117
  images = clip_top_k(prompt, images, k=k)
118
  return images
119
 
 
120
  def run_inference(prompt, num_images=32, num_preds=8):
121
  images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
122
  predictions = captioned_strip(images)
@@ -125,23 +151,28 @@ def run_inference(prompt, num_images=32, num_preds=8):
125
  """
126
  return (output_title, predictions)
127
 
 
128
  outputs = [
129
- gr.outputs.HTML(label=""), # To be used as title
130
- gr.outputs.Image(label=''),
131
  ]
132
 
133
  description = """
134
  DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
135
  """
136
- gr.Interface(run_inference,
137
- inputs=[gr.inputs.Textbox(label='What do you want to see?')],
138
- outputs=outputs,
139
- title='DALL·E mini',
 
140
  description=description,
141
  article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
142
- layout='vertical',
143
- theme='huggingface',
144
- examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
 
 
 
145
  allow_flagging=False,
146
  live=False,
147
  # server_port=8999
 
2
  # coding: utf-8
3
 
4
  # Uncomment to run on cpu
5
+ # import os
6
+ # os.environ["JAX_PLATFORM_NAME"] = "cpu"
7
 
8
  import random
9
 
10
+ import gradio as gr
11
  import jax
 
 
 
 
 
 
 
12
  import numpy as np
13
+ from flax.jax_utils import replicate
14
+ from flax.training.common_utils import shard
15
+ from PIL import Image, ImageDraw, ImageFont
16
 
17
+ # ## CLIP Scoring
18
+ from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
19
  from vqgan_jax.modeling_flax_vqgan import VQModel
20
+
21
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
22
 
23
+ DALLE_REPO = "flax-community/dalle-mini"
24
+ DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
25
 
26
+ VQGAN_REPO = "flax-community/vqgan_f16_16384"
27
+ VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"
28
 
29
+ tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
30
+ model = CustomFlaxBartForConditionalGeneration.from_pretrained(
31
+ DALLE_REPO, revision=DALLE_COMMIT_ID
32
+ )
33
+ vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
34
 
35
 
36
+ def captioned_strip(images, caption=None, rows=1):
37
+ increased_h = 0 if caption is None else 48
38
+ w, h = images[0].size[0], images[0].size[1]
39
+ img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
40
+ for i, img_ in enumerate(images):
41
+ img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
42
 
43
+ if caption is not None:
44
+ draw = ImageDraw.Draw(img)
45
+ font = ImageFont.truetype(
46
+ "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
47
+ )
48
+ draw.text((20, 3), caption, (255, 255, 255), font=font)
49
+ return img
50
 
 
 
 
51
 
52
  def custom_to_pil(x):
53
+ x = np.clip(x, 0.0, 1.0)
54
+ x = (255 * x).astype(np.uint8)
55
  x = Image.fromarray(x)
56
  if not x.mode == "RGB":
57
  x = x.convert("RGB")
58
  return x
59
 
60
+
61
  def generate(input, rng, params):
62
  return model.generate(
63
  **input,
 
70
  params=params,
71
  )
72
 
73
+
74
  def get_images(indices, params):
75
  return vqgan.decode_code(indices, params=params)
76
 
77
+
78
  p_generate = jax.pmap(generate, "batch")
79
  p_get_images = jax.pmap(get_images, "batch")
80
 
 
86
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
87
  print("Initialize CLIPProcessor")
88
 
89
+
90
  def hallucinate(prompt, num_images=64):
91
  prompt = [prompt] * jax.device_count()
92
+ inputs = tokenizer(
93
+ prompt,
94
+ return_tensors="jax",
95
+ padding="max_length",
96
+ truncation=True,
97
+ max_length=128,
98
+ ).data
99
  inputs = shard(inputs)
100
 
101
  all_images = []
 
112
  all_images.append(custom_to_pil(image))
113
  return all_images
114
 
115
+
116
  def clip_top_k(prompt, images, k=8):
117
  inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
118
  outputs = clip(**inputs)
 
120
  scores = np.array(logits[0]).argsort()[-k:][::-1]
121
  return [images[score] for score in scores]
122
 
123
+
124
  def compose_predictions(images, caption=None):
125
  increased_h = 0 if caption is None else 48
126
  w, h = images[0].size[0], images[0].size[1]
127
+ img = Image.new("RGB", (len(images) * w, h + increased_h))
128
  for i, img_ in enumerate(images):
129
+ img.paste(img_, (i * w, increased_h))
130
 
131
  if caption is not None:
132
  draw = ImageDraw.Draw(img)
133
+ font = ImageFont.truetype(
134
+ "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
135
+ )
136
+ draw.text((20, 3), caption, (255, 255, 255), font=font)
137
  return img
138
 
139
+
140
  def top_k_predictions(prompt, num_candidates=32, k=8):
141
  images = hallucinate(prompt, num_images=num_candidates)
142
  images = clip_top_k(prompt, images, k=k)
143
  return images
144
 
145
+
146
  def run_inference(prompt, num_images=32, num_preds=8):
147
  images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
148
  predictions = captioned_strip(images)
 
151
  """
152
  return (output_title, predictions)
153
 
154
+
155
  outputs = [
156
+ gr.outputs.HTML(label=""), # To be used as title
157
+ gr.outputs.Image(label=""),
158
  ]
159
 
160
  description = """
161
  DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
162
  """
163
+ gr.Interface(
164
+ run_inference,
165
+ inputs=[gr.inputs.Textbox(label="What do you want to see?")],
166
+ outputs=outputs,
167
+ title="DALL·E mini",
168
  description=description,
169
  article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
170
+ layout="vertical",
171
+ theme="huggingface",
172
+ examples=[
173
+ ["an armchair in the shape of an avocado"],
174
+ ["snowy mountains by the sea"],
175
+ ],
176
  allow_flagging=False,
177
  live=False,
178
  # server_port=8999
app/gradio/app_gradio_ngrok.py DELETED
@@ -1,89 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- import requests
5
- from PIL import Image
6
- import numpy as np
7
- import matplotlib.pyplot as plt
8
- from io import BytesIO
9
- import base64
10
- import os
11
-
12
- import gradio as gr
13
-
14
- from dalle_mini.helpers import captioned_strip
15
-
16
-
17
- backend_url = os.environ["BACKEND_SERVER"]
18
-
19
-
20
- class ServiceError(Exception):
21
- def __init__(self, status_code):
22
- self.status_code = status_code
23
-
24
- def get_images_from_ngrok(prompt):
25
- r = requests.post(
26
- backend_url,
27
- json={"prompt": prompt}
28
- )
29
- if r.status_code == 200:
30
- images = r.json()["images"]
31
- images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
32
- return images
33
- else:
34
- raise ServiceError(r.status_code)
35
-
36
- def run_inference(prompt):
37
- try:
38
- images = get_images_from_ngrok(prompt)
39
- predictions = captioned_strip(images)
40
- output_title = f"""
41
- <p style="font-size:22px; font-style:bold">Best predictions</p>
42
- <p>We asked our model to generate 128 candidates for your prompt:</p>
43
-
44
- <pre>
45
-
46
- <b>{prompt}</b>
47
- </pre>
48
- <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
49
- similarity of the text and the image representations.</p>
50
-
51
- <p>This is the result:</p>
52
- """
53
-
54
- output_description = """
55
- <p>Read our <a style="color:blue;" href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">full report</a> for more details on how this works.<p>
56
- <p style='text-align: center'>Created with <a style="color:blue;" href="https://github.com/borisdayma/dalle-mini">DALL·E mini</a></p>
57
- """
58
-
59
- except ServiceError:
60
- output_title = f"""
61
- Sorry, there was an error retrieving the images. Please, try again later or <a href="mailto:pcuenca-dalle@guenever.net">contact us here</a>.
62
- """
63
- predictions = None
64
- output_description = ""
65
-
66
- return (output_title, predictions, output_description)
67
-
68
- outputs = [
69
- gr.outputs.HTML(label=""), # To be used as title
70
- gr.outputs.Image(label=''),
71
- gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
72
- ]
73
-
74
- description = """
75
- Welcome to DALL·E-mini, a text-to-image generation model.
76
- """
77
- gr.Interface(run_inference,
78
- inputs=[gr.inputs.Textbox(label='Prompt')],
79
- outputs=outputs,
80
- title='DALL·E mini',
81
- description=description,
82
- article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
83
- layout='vertical',
84
- theme='huggingface',
85
- examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
86
- allow_flagging=False,
87
- live=False,
88
- # server_name="0.0.0.0", # Bind to all interfaces
89
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/{app.py → streamlit/app.py} RENAMED
@@ -1,9 +1,10 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
 
4
- from dalle_mini.backend import ServiceError, get_images_from_backend
5
  import streamlit as st
6
 
 
 
7
  st.sidebar.markdown(
8
  """
9
  <style>
@@ -50,7 +51,7 @@ if prompt != "":
50
  <div class="st-b7">
51
  <div class="css-whx05o e13vu3m50">
52
  <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
53
- <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/img/loading.gif" width="30"/>
54
  Generating predictions for: <b>{prompt}</b>
55
  </div>
56
  </div>
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
 
 
4
  import streamlit as st
5
 
6
+ from .backend import ServiceError, get_images_from_backend
7
+
8
  st.sidebar.markdown(
9
  """
10
  <style>
 
51
  <div class="st-b7">
52
  <div class="css-whx05o e13vu3m50">
53
  <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
54
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
55
  Generating predictions for: <b>{prompt}</b>
56
  </div>
57
  </div>
{dalle_mini → app/streamlit}/backend.py RENAMED
@@ -1,17 +1,17 @@
1
- import requests
2
- from io import BytesIO
3
  import base64
 
 
 
4
  from PIL import Image
5
 
 
6
  class ServiceError(Exception):
7
  def __init__(self, status_code):
8
  self.status_code = status_code
9
 
 
10
  def get_images_from_backend(prompt, backend_url):
11
- r = requests.post(
12
- backend_url,
13
- json={"prompt": prompt}
14
- )
15
  if r.status_code == 200:
16
  images = r.json()["images"]
17
  images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
 
 
 
1
  import base64
2
+ from io import BytesIO
3
+
4
+ import requests
5
  from PIL import Image
6
 
7
+
8
  class ServiceError(Exception):
9
  def __init__(self, status_code):
10
  self.status_code = status_code
11
 
12
+
13
  def get_images_from_backend(prompt, backend_url):
14
+ r = requests.post(backend_url, json={"prompt": prompt})
 
 
 
15
  if r.status_code == 200:
16
  images = r.json()["images"]
17
  images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
app/{img → streamlit/img}/loading.gif RENAMED
File without changes
dalle_mini/data.py CHANGED
@@ -1,10 +1,12 @@
1
  from dataclasses import dataclass, field
2
- from datasets import load_dataset, Dataset
3
  from functools import partial
4
- import numpy as np
5
  import jax
6
  import jax.numpy as jnp
 
 
7
  from flax.training.common_utils import shard
 
8
  from .text import TextNormalizer
9
 
10
 
 
1
  from dataclasses import dataclass, field
 
2
  from functools import partial
3
+
4
  import jax
5
  import jax.numpy as jnp
6
+ import numpy as np
7
+ from datasets import Dataset, load_dataset
8
  from flax.training.common_utils import shard
9
+
10
  from .text import TextNormalizer
11
 
12
 
dalle_mini/dataset.py DELETED
@@ -1,122 +0,0 @@
1
- """
2
- An image-caption dataset dataloader.
3
- Luke Melas-Kyriazi, 2021
4
- """
5
- import warnings
6
- from typing import Optional, Callable
7
- from pathlib import Path
8
- import numpy as np
9
- import torch
10
- import pandas as pd
11
- from torch.utils.data import Dataset
12
- from torchvision.datasets.folder import default_loader
13
- from PIL import ImageFile
14
- from PIL.Image import DecompressionBombWarning
15
- ImageFile.LOAD_TRUNCATED_IMAGES = True
16
- warnings.filterwarnings("ignore", category=UserWarning)
17
- warnings.filterwarnings("ignore", category=DecompressionBombWarning)
18
-
19
-
20
- class CaptionDataset(Dataset):
21
- """
22
- A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
23
- returns the raw text rather than tokens. This is done on purpose, because
24
- it's easy to tokenize a batch of text after loading it from this dataset.
25
- """
26
-
27
- def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
28
- image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
29
- include_captions: bool = True):
30
- """
31
- :param images_root: folder where images are stored
32
- :param captions_path: path to csv that maps image filenames to captions
33
- :param image_transform: image transform pipeline
34
- :param text_transform: image transform pipeline
35
- :param image_transform_type: image transform type, either `torchvision` or `albumentations`
36
- :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
37
- """
38
-
39
- # Base path for images
40
- self.images_root = Path(images_root)
41
-
42
- # Load captions as DataFrame
43
- self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
44
- self.captions['image_file'] = self.captions['image_file'].astype(str)
45
-
46
- # PyTorch transformation pipeline for the image (normalizing, etc.)
47
- self.text_transform = text_transform
48
- self.image_transform = image_transform
49
- self.image_transform_type = image_transform_type.lower()
50
- assert self.image_transform_type in ['torchvision', 'albumentations']
51
-
52
- # Total number of datapoints
53
- self.size = len(self.captions)
54
-
55
- # Return image+captions or just images
56
- self.include_captions = include_captions
57
-
58
- def verify_that_all_images_exist(self):
59
- for image_file in self.captions['image_file']:
60
- p = self.images_root / image_file
61
- if not p.is_file():
62
- print(f'file does not exist: {p}')
63
-
64
- def _get_raw_image(self, i):
65
- image_file = self.captions.iloc[i]['image_file']
66
- image_path = self.images_root / image_file
67
- image = default_loader(image_path)
68
- return image
69
-
70
- def _get_raw_text(self, i):
71
- return self.captions.iloc[i]['caption']
72
-
73
- def __getitem__(self, i):
74
- image = self._get_raw_image(i)
75
- caption = self._get_raw_text(i)
76
- if self.image_transform is not None:
77
- if self.image_transform_type == 'torchvision':
78
- image = self.image_transform(image)
79
- elif self.image_transform_type == 'albumentations':
80
- image = self.image_transform(image=np.array(image))['image']
81
- else:
82
- raise NotImplementedError(f"{self.image_transform_type=}")
83
- return {'image': image, 'text': caption} if self.include_captions else image
84
-
85
- def __len__(self):
86
- return self.size
87
-
88
-
89
- if __name__ == "__main__":
90
- import albumentations as A
91
- from albumentations.pytorch import ToTensorV2
92
- from transformers import AutoTokenizer
93
-
94
- # Paths
95
- images_root = './images'
96
- captions_path = './images-list-clean.tsv'
97
-
98
- # Create transforms
99
- tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
100
- def tokenize(text):
101
- return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
102
- image_transform = A.Compose([
103
- A.Resize(256, 256), A.CenterCrop(256, 256),
104
- A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
105
-
106
- # Create dataset
107
- dataset = CaptionDataset(
108
- images_root=images_root,
109
- captions_path=captions_path,
110
- image_transform=image_transform,
111
- text_transform=tokenize,
112
- image_transform_type='albumentations')
113
-
114
- # Create dataloader
115
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
116
- batch = next(iter(dataloader))
117
- print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
118
-
119
- # # (Optional) Check that all the images exist
120
- # dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
121
- # dataset.verify_that_all_images_exist()
122
- # print('Done')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dalle_mini/helpers.py DELETED
@@ -1,14 +0,0 @@
1
- from PIL import Image, ImageDraw, ImageFont
2
-
3
- def captioned_strip(images, caption=None, rows=1):
4
- increased_h = 0 if caption is None else 48
5
- w, h = images[0].size[0], images[0].size[1]
6
- img = Image.new("RGB", (len(images)*w//rows, h*rows + increased_h))
7
- for i, img_ in enumerate(images):
8
- img.paste(img_, (i//rows*w, increased_h + (i % rows) * h))
9
-
10
- if caption is not None:
11
- draw = ImageDraw.Draw(img)
12
- font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
13
- draw.text((20, 3), caption, (255,255,255), font=font)
14
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dalle_mini/model.py CHANGED
@@ -1,16 +1,14 @@
1
- import jax
2
  import flax.linen as nn
3
-
 
4
  from transformers.models.bart.modeling_flax_bart import (
5
- FlaxBartModule,
6
- FlaxBartForConditionalGenerationModule,
7
- FlaxBartForConditionalGeneration,
8
- FlaxBartEncoder,
9
  FlaxBartDecoder,
 
 
 
 
10
  )
11
 
12
- from transformers import BartConfig
13
-
14
 
15
  class CustomFlaxBartModule(FlaxBartModule):
16
  def setup(self):
@@ -46,6 +44,11 @@ class CustomFlaxBartForConditionalGenerationModule(
46
  FlaxBartForConditionalGenerationModule
47
  ):
48
  def setup(self):
 
 
 
 
 
49
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
50
  self.lm_head = nn.Dense(
51
  self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
 
 
1
  import flax.linen as nn
2
+ import jax
3
+ from transformers import BartConfig
4
  from transformers.models.bart.modeling_flax_bart import (
 
 
 
 
5
  FlaxBartDecoder,
6
+ FlaxBartEncoder,
7
+ FlaxBartForConditionalGeneration,
8
+ FlaxBartForConditionalGenerationModule,
9
+ FlaxBartModule,
10
  )
11
 
 
 
12
 
13
  class CustomFlaxBartModule(FlaxBartModule):
14
  def setup(self):
 
44
  FlaxBartForConditionalGenerationModule
45
  ):
46
  def setup(self):
47
+ # set default config
48
+ self.config.normalize_text = getattr(self.config, "normalize_text", False)
49
+ self.config.image_length = getattr(self.config, "image_length", 256)
50
+ self.config.image_vocab_size = getattr(self.config, "image_vocab_size", 16384)
51
+
52
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
53
  self.lm_head = nn.Dense(
54
  self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
dalle_mini/text.py CHANGED
@@ -2,13 +2,15 @@
2
  Utilities for processing text.
3
  """
4
 
 
 
 
 
5
  from pathlib import Path
6
- from unidecode import unidecode
7
 
8
- import re, math, random, html
9
  import ftfy
10
-
11
  from huggingface_hub import hf_hub_download
 
12
 
13
  # based on wiki word occurence
14
  person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
 
2
  Utilities for processing text.
3
  """
4
 
5
+ import html
6
+ import math
7
+ import random
8
+ import re
9
  from pathlib import Path
 
10
 
 
11
  import ftfy
 
12
  from huggingface_hub import hf_hub_download
13
+ from unidecode import unidecode
14
 
15
  # based on wiki word occurence
16
  person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
dev/README.md DELETED
@@ -1,122 +0,0 @@
1
- # Development Instructions for TPU
2
-
3
- ## Setup
4
-
5
- - Apply to the [TRC program](https://sites.research.google/trc/) for free TPU credits if you're elligible.
6
- - Follow the [Cloud TPU VM User's Guide](https://cloud.google.com/tpu/docs/users-guide-tpu-vm) to set up gcloud.
7
- - Verify `gcloud config list`, in particular account, project & zone.
8
- - Create a TPU VM per the guide and connect to it.
9
-
10
- When needing a larger disk:
11
-
12
- - Create a balanced persistent disk (SSD, so pricier than default HDD but much faster): `gcloud compute disks create DISK_NAME --size SIZE_IN_GB --type pd-balanced`
13
- - Attach the disk to your instance by adding `--data-disk source=REF` per ["Adding a persistent disk to a TPU VM" guide](https://cloud.google.com/tpu/docs/setup-persistent-disk), eg `gcloud alpha compute tpus tpu-vm create INSTANCE_NAME --accelerator-type=v3-8 --version=v2-alpha --data-disk source=projects/tpu-toys/zones/europe-west4-a/disks/DISK_NAME`
14
- - Format the partition as described in the guide.
15
- - Make sure to set up automatic remount of disk at restart.
16
-
17
- ## Connect VS Code
18
-
19
- - Find external IP in the UI or with `gcloud alpha compute tpus tpu-vm describe INSTANCE_NAME`
20
- - Verify you can connect in terminal with `ssh EXTERNAL_IP -i ~/.ssh/google_compute_engine`
21
- - Add the same command as ssh host in VS Code.
22
- - Check config file
23
-
24
- ```
25
- Host INSTANCE_NAME
26
- HostName EXTERNAL_IP
27
- IdentityFile ~/.ssh/google_compute_engine
28
- ```
29
-
30
- ## Environment configuration
31
-
32
- ### Use virtual environments (optional)
33
-
34
- We recommend using virtual environments (such as conda, venv or pyenv-virtualenv).
35
-
36
- If you want to use `pyenv` and `pyenv-virtualenv`:
37
-
38
- - Installation
39
-
40
- - [Set up build environment](https://github.com/pyenv/pyenv/wiki#suggested-build-environment)
41
- - Use [pyenv-installer](https://github.com/pyenv/pyenv-installer): `curl https://pyenv.run | bash`
42
- - bash set-up:
43
-
44
- ```bash
45
- echo '\n'\
46
- '# pyenv setup \n'\
47
- 'export PYENV_ROOT="$HOME/.pyenv" \n'\
48
- 'export PATH="$PYENV_ROOT/bin:$PATH" \n'\
49
- 'eval "$(pyenv init --path)" \n'\
50
- 'eval "$(pyenv init -)" \n'\
51
- 'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc
52
- ```
53
-
54
- - Usage
55
-
56
- - Install a python version: `pyenv install X.X.X`
57
- - Create a virtual environment: `pyenv virtualenv 3.9.6 dalle_env`
58
- - Activate: `pyenv activate dalle_env`
59
-
60
- Note: you can auto-activate your environment at a location with `echo dalle_env >> .python-version`
61
-
62
- ### Tools
63
-
64
- - Git
65
-
66
- - `git config --global user.email "name@domain.com"
67
- - `git config --global user.name "First Last"
68
-
69
- - Github CLI
70
-
71
- - See [installation instructions](https://github.com/cli/cli/blob/trunk/docs/install_linux.md)
72
- - `gh auth login`
73
-
74
- - Direnv
75
-
76
- - Install direnv: `sudo apt-get update && sudo apt-get install direnv`
77
- - bash set-up:
78
-
79
- ```bash
80
- echo -e '\n'\
81
- '# direnv setup \n'\
82
- 'eval "$(direnv hook bash)" \n' >> ~/.bashrc
83
- ```
84
-
85
- ### Set up repo
86
-
87
- - Clone repo: `gh repo clone borisdayma/dalle-mini`
88
- - If using `pyenv-virtualenv`, auto-activate env: `echo dalle_env >> .python-version`
89
-
90
- ## Environment
91
-
92
- - Install the following (use it later to update our dev requirements.txt)
93
-
94
- ```
95
- requests
96
- pillow
97
- jupyterlab
98
- ipywidgets
99
-
100
- -e ../datasets[streaming]
101
- -e ../transformers
102
- -e ../webdataset
103
-
104
- # JAX
105
- --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
106
- jax[tpu]>=0.2.16
107
- flax
108
- ```
109
-
110
- - `transformers-cli login`
111
-
112
- ---
113
-
114
- - set `HF_HOME="/mnt/disks/persist/cache/huggingface"` in `/etc/environment` and ensure you have required permissions, then restart.
115
-
116
- ## Working with datasets or models
117
-
118
- - Install [Git LFS](https://github.com/git-lfs/git-lfs/wiki/Installation)
119
- - Clone a dataset without large files: `GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/.../...`
120
- - Use a local [credential store](https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage) for caching credentials
121
- - Track specific extentions: `git lfs track "*.ext"`
122
- - See files tracked with LFS with `git lfs ls-files`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/data/CC12M_downloader.py DELETED
@@ -1,91 +0,0 @@
1
- # Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
2
-
3
- #%%
4
- import sys
5
- import os
6
- from datetime import datetime
7
- import pandas as pd
8
- import contexttimer
9
- from urllib.request import urlopen
10
- import requests
11
- from PIL import Image
12
- import torch
13
- from torchvision.transforms import functional as TF
14
- from multiprocessing import Pool
15
- from tqdm import tqdm
16
- import logging
17
-
18
- # Setup
19
- logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
20
- requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
21
-
22
-
23
- # # For downloading SVG images (I can't get this to work)
24
- # from io import BytesIO
25
- # import cairosvg
26
-
27
- #%%
28
- # Load data
29
- print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
30
- with contexttimer.Timer(prefix="Loading from tsv"):
31
- df = pd.read_csv('./cc12m.tsv', delimiter='\t', header=None)
32
-
33
- url_to_idx_map = {url: index for index, url, caption in df.itertuples()}
34
- print(f'Loaded {len(url_to_idx_map)} urls')
35
-
36
- #%%
37
- df.head()
38
-
39
- #%%
40
-
41
- # Note: it seems that there are no SVG images
42
- df.sample(10000)[1].str.contains('.svg').sum()
43
-
44
- #%%
45
- # Resize function
46
- def resize(img):
47
- max_size_of_short_side = 512
48
- if min(img.size) > max_size_of_short_side:
49
- img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
50
- return img
51
-
52
- base_dir = os.path.join(os.getcwd(), 'images')
53
-
54
- def process(item):
55
- url, image_id = item
56
- try:
57
- base_url = os.path.basename(url) # extract base url
58
- stem, ext = os.path.splitext(base_url) # split into stem and extension
59
- filename = f'{image_id:08d}---{stem}.jpg' # create filename
60
- filepath = os.path.join(base_dir, filename) # concat to get filepath
61
- if not os.path.isfile(filepath):
62
- # if filepath.endswith('.svg'):
63
- # raise NotImplementedError()
64
- # image_bytes = BytesIO() # create a bytestream
65
- # cairosvg.svg2png(url=url, write_to=image_bytes) # convert svg into image
66
- # else:
67
- req = requests.get(url, stream=True, timeout=1, verify=False).raw
68
- image = Image.open(req).convert('RGB')
69
- if min(image.size) > 512:
70
- image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
71
- # image = resize(image) # resize PIL image
72
- image.save(filepath) # save PIL image
73
- except Exception as e:
74
- logging.info(" ".join(repr(e).splitlines()))
75
- logging.error(url)
76
-
77
- #%%
78
- #for i, item in enumerate(tqdm(url_to_idx_map.items(), total=len(url_to_idx_map))):
79
- # process(item)
80
- # if i > 100:
81
- # break
82
-
83
- # Use multiprocessing for speed
84
- list_of_items = list(url_to_idx_map.items())
85
- print(len(list_of_items))
86
- list_of_items = list_of_items[10_000_000:]
87
- print(len(list_of_items))
88
- with Pool(128) as p:
89
- r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
90
- print('DONE')
91
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/data/CC3M_downloader.py DELETED
@@ -1,62 +0,0 @@
1
- '''
2
- This script was adapted from Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
3
- Few changes were made for the particular dataset. You're required to have the `.tsv` file downloaded in your directory.
4
- Find them here- [https://github.com/google-research-datasets/conceptual-captions]
5
- '''
6
-
7
- import sys
8
- import os
9
- from datetime import datetime
10
- import pandas as pd
11
- import contexttimer
12
- from urllib.request import urlopen
13
- import requests
14
- from PIL import Image
15
- import torch
16
- from torchvision.transforms import functional as TF
17
- from multiprocessing import Pool
18
- from tqdm import tqdm
19
- import logging
20
- import sys
21
-
22
- # Setup
23
- logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
24
- requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
25
-
26
- if len(sys.argv) != 3:
27
- print("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
28
- exit(1)
29
-
30
- # Load data
31
- print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
32
- with contexttimer.Timer(prefix="Loading from tsv"):
33
- df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)
34
-
35
- url_to_idx_map = {url: index for index, caption, url in df.itertuples()}
36
- print(f'Loaded {len(url_to_idx_map)} urls')
37
-
38
- base_dir = os.path.join(os.getcwd(), sys.argv[2])
39
-
40
- def process(item):
41
- url, image_id = item
42
- try:
43
- base_url = os.path.basename(url) # extract base url
44
- stem, ext = os.path.splitext(base_url) # split into stem and extension
45
- filename = f'{image_id:08d}---{stem}.jpg' # create filename
46
- filepath = os.path.join(base_dir, filename) # concat to get filepath
47
- if not os.path.isfile(filepath):
48
- req = requests.get(url, stream=True, timeout=1, verify=False).raw
49
- image = Image.open(req).convert('RGB')
50
- if min(image.size) > 512:
51
- image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
52
- image.save(filepath) # save PIL image
53
- except Exception as e:
54
- logging.info(" ".join(repr(e).splitlines()))
55
- logging.error(url)
56
-
57
- list_of_items = list(url_to_idx_map.items())
58
- print(len(list_of_items))
59
-
60
- with Pool(128) as p:
61
- r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
62
- print('DONE')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/data/README.md DELETED
@@ -1,3 +0,0 @@
1
- # Data
2
-
3
- Utility scripts for downloading CC3M and CC12M.
 
 
 
 
dev/encoding/vqgan-jax-encoding-streaming.ipynb DELETED
@@ -1,562 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# VQGAN JAX Encoding for 🤗 Datasets in streaming mode"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "ba7b31e6",
14
- "metadata": {},
15
- "source": [
16
- "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and 🤗 Datasets in streaming mode.\n",
17
- "\n",
18
- "This example uses our YFCC100M dataset, but it should be easy to adapt to any other image/caption dataset in the huggingface hub."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": null,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "import os\n",
40
- "\n",
41
- "import jax\n",
42
- "from jax import pmap"
43
- ]
44
- },
45
- {
46
- "cell_type": "markdown",
47
- "id": "c7c4c1e6",
48
- "metadata": {},
49
- "source": [
50
- "## Dataset and Parameters"
51
- ]
52
- },
53
- {
54
- "cell_type": "code",
55
- "execution_count": null,
56
- "id": "d45a289e",
57
- "metadata": {},
58
- "outputs": [],
59
- "source": [
60
- "import datasets\n",
61
- "from datasets import Dataset, load_dataset"
62
- ]
63
- },
64
- {
65
- "cell_type": "markdown",
66
- "id": "f26e4f18",
67
- "metadata": {},
68
- "source": [
69
- "We'll use the `validation` set for testing. Adjust accordingly."
70
- ]
71
- },
72
- {
73
- "cell_type": "code",
74
- "execution_count": null,
75
- "id": "28893c3e",
76
- "metadata": {},
77
- "outputs": [],
78
- "source": [
79
- "dataset = load_dataset('dalle-mini/YFCC100M_OpenAI_subset', use_auth_token=True, streaming=True, split='validation')"
80
- ]
81
- },
82
- {
83
- "cell_type": "code",
84
- "execution_count": null,
85
- "id": "33861477",
86
- "metadata": {},
87
- "outputs": [],
88
- "source": [
89
- "from pathlib import Path\n",
90
- "\n",
91
- "yfcc100m = Path.home()/'data'/'YFCC100M_OpenAI_subset'\n",
92
- "yfcc100m_output = yfcc100m/'encoded' # Output directory for encoded files"
93
- ]
94
- },
95
- {
96
- "cell_type": "code",
97
- "execution_count": null,
98
- "id": "6e7b71c4",
99
- "metadata": {},
100
- "outputs": [],
101
- "source": [
102
- "batch_size = 128 # Per device\n",
103
- "num_workers = 16 # Unused in streaming mode"
104
- ]
105
- },
106
- {
107
- "cell_type": "markdown",
108
- "id": "0793c26a",
109
- "metadata": {},
110
- "source": [
111
- "### Data preparation"
112
- ]
113
- },
114
- {
115
- "cell_type": "markdown",
116
- "id": "86415769",
117
- "metadata": {},
118
- "source": [
119
- "* Images: we transform them so they are center-cropped and square, all of the same size so we can build batches for TPU/GPU processing.\n",
120
- "* Captions: we extract a single `caption` column from the source data, by concatenating the cleaned title and description.\n",
121
- "\n",
122
- "These transformations are done using the Datasets `map` function. In the case of streaming datasets, transformations will run as needed instead of pre-processing the dataset at once."
123
- ]
124
- },
125
- {
126
- "cell_type": "markdown",
127
- "id": "0fdf1851",
128
- "metadata": {},
129
- "source": [
130
- "This helper function is used to decode images from the bytes retrieved in `streaming` mode."
131
- ]
132
- },
133
- {
134
- "cell_type": "code",
135
- "execution_count": null,
136
- "id": "5bbca804",
137
- "metadata": {},
138
- "outputs": [],
139
- "source": [
140
- "from PIL import Image\n",
141
- "import io\n",
142
- "\n",
143
- "def get_image(byte_stream):\n",
144
- " image = Image.open(io.BytesIO(byte_stream))\n",
145
- " return image.convert('RGB')"
146
- ]
147
- },
148
- {
149
- "cell_type": "markdown",
150
- "id": "b435290b",
151
- "metadata": {},
152
- "source": [
153
- "Image processing"
154
- ]
155
- },
156
- {
157
- "cell_type": "code",
158
- "execution_count": null,
159
- "id": "7e73dfa3",
160
- "metadata": {},
161
- "outputs": [],
162
- "source": [
163
- "def center_crop(image, max_size=256):\n",
164
- " # Note: we allow upscaling too. We should exclude small images. \n",
165
- " image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
166
- " image = TF.center_crop(image, output_size=2 * [max_size])\n",
167
- " return image\n",
168
- "\n",
169
- "preprocess_image = T.Compose([\n",
170
- " get_image,\n",
171
- " center_crop,\n",
172
- " T.ToTensor(),\n",
173
- " lambda t: t.permute(1, 2, 0) # Reorder, we need dimensions last\n",
174
- "])"
175
- ]
176
- },
177
- {
178
- "cell_type": "markdown",
179
- "id": "1e3ac8de",
180
- "metadata": {},
181
- "source": [
182
- "Caption preparation"
183
- ]
184
- },
185
- {
186
- "cell_type": "code",
187
- "execution_count": null,
188
- "id": "aadb4d23",
189
- "metadata": {},
190
- "outputs": [],
191
- "source": [
192
- "import string\n",
193
- "\n",
194
- "def create_caption(title, description):\n",
195
- " title = title.strip()\n",
196
- " description = description.strip()\n",
197
- " if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
198
- " return f'{title} {description}'"
199
- ]
200
- },
201
- {
202
- "cell_type": "markdown",
203
- "id": "3c4522b9",
204
- "metadata": {},
205
- "source": [
206
- "And this is the basic transformation function to use in `map`. We don't really need the `key`, but we'll keep it for reference. Since we are returning a new dictionary (as opposed to adding entries to the input), this also removes any metadata columns we don't need."
207
- ]
208
- },
209
- {
210
- "cell_type": "code",
211
- "execution_count": null,
212
- "id": "2566ff68",
213
- "metadata": {},
214
- "outputs": [],
215
- "source": [
216
- "def prepare_item(item):\n",
217
- " return {\n",
218
- " 'key': item['key'],\n",
219
- " 'caption': create_caption(item['title_clean'], item['description_clean']),\n",
220
- " 'image': preprocess_image(item['img'])\n",
221
- " }"
222
- ]
223
- },
224
- {
225
- "cell_type": "markdown",
226
- "id": "e519e475",
227
- "metadata": {},
228
- "source": [
229
- "Unlike when using non-streaming datasets, the following operation completes immediately in streaming mode. In streaming mode, `num_proc` is not supported."
230
- ]
231
- },
232
- {
233
- "cell_type": "code",
234
- "execution_count": null,
235
- "id": "10d7750e",
236
- "metadata": {},
237
- "outputs": [],
238
- "source": [
239
- "prepared_dataset = dataset.map(prepare_item, batched=False)"
240
- ]
241
- },
242
- {
243
- "cell_type": "code",
244
- "execution_count": null,
245
- "id": "a8595539",
246
- "metadata": {},
247
- "outputs": [],
248
- "source": [
249
- "%%time\n",
250
- "item = next(iter(prepared_dataset))"
251
- ]
252
- },
253
- {
254
- "cell_type": "code",
255
- "execution_count": null,
256
- "id": "04a6eeb4",
257
- "metadata": {},
258
- "outputs": [],
259
- "source": [
260
- "assert(list(item.keys()) == ['key', 'caption', 'image'])"
261
- ]
262
- },
263
- {
264
- "cell_type": "code",
265
- "execution_count": null,
266
- "id": "40d3115f",
267
- "metadata": {},
268
- "outputs": [],
269
- "source": [
270
- "item['image'].shape"
271
- ]
272
- },
273
- {
274
- "cell_type": "code",
275
- "execution_count": null,
276
- "id": "dd844e1c",
277
- "metadata": {},
278
- "outputs": [],
279
- "source": [
280
- "T.ToPILImage()(item['image'].permute(2, 0, 1))"
281
- ]
282
- },
283
- {
284
- "cell_type": "markdown",
285
- "id": "44d50a51",
286
- "metadata": {},
287
- "source": [
288
- "### Torch DataLoader"
289
- ]
290
- },
291
- {
292
- "cell_type": "markdown",
293
- "id": "17a4bbc6",
294
- "metadata": {},
295
- "source": [
296
- "We'll create a PyTorch DataLoader for convenience. This allows us to easily take batches of our desired size.\n",
297
- "\n",
298
- "We won't be using parallel processing of the DataLoader for now, as the items will be retrieved on the fly. We could attempt to do it using these recommendations: https://pytorch.org/docs/stable/data.html#multi-process-data-loading. For performance considerations, please refer to this thread: https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13"
299
- ]
300
- },
301
- {
302
- "cell_type": "code",
303
- "execution_count": null,
304
- "id": "e1c08b7e",
305
- "metadata": {},
306
- "outputs": [],
307
- "source": [
308
- "import torch\n",
309
- "from torch.utils.data import DataLoader"
310
- ]
311
- },
312
- {
313
- "cell_type": "code",
314
- "execution_count": null,
315
- "id": "6a296677",
316
- "metadata": {},
317
- "outputs": [],
318
- "source": [
319
- "torch_dataset = prepared_dataset.with_format(\"torch\")"
320
- ]
321
- },
322
- {
323
- "cell_type": "markdown",
324
- "id": "29ab13bc",
325
- "metadata": {},
326
- "source": [
327
- "**Note**: according to my tests, `num_workers` is not compatible with Datasets in streaming mode. Processes deadlock and there's no progress."
328
- ]
329
- },
330
- {
331
- "cell_type": "code",
332
- "execution_count": null,
333
- "id": "e2df5e13",
334
- "metadata": {},
335
- "outputs": [],
336
- "source": [
337
- "dataloader = DataLoader(torch_dataset, batch_size=batch_size * jax.device_count())"
338
- ]
339
- },
340
- {
341
- "cell_type": "code",
342
- "execution_count": null,
343
- "id": "c15e3783",
344
- "metadata": {},
345
- "outputs": [],
346
- "source": [
347
- "batch = next(iter(dataloader))"
348
- ]
349
- },
350
- {
351
- "cell_type": "code",
352
- "execution_count": null,
353
- "id": "71d027fe",
354
- "metadata": {},
355
- "outputs": [],
356
- "source": [
357
- "batch['image'].shape"
358
- ]
359
- },
360
- {
361
- "cell_type": "markdown",
362
- "id": "a354472b",
363
- "metadata": {},
364
- "source": [
365
- "## VQGAN-JAX model"
366
- ]
367
- },
368
- {
369
- "cell_type": "code",
370
- "execution_count": null,
371
- "id": "2fcf01d7",
372
- "metadata": {},
373
- "outputs": [],
374
- "source": [
375
- "from vqgan_jax.modeling_flax_vqgan import VQModel"
376
- ]
377
- },
378
- {
379
- "cell_type": "markdown",
380
- "id": "9daa636d",
381
- "metadata": {},
382
- "source": [
383
- "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
384
- ]
385
- },
386
- {
387
- "cell_type": "code",
388
- "execution_count": null,
389
- "id": "47a8b818",
390
- "metadata": {
391
- "scrolled": true
392
- },
393
- "outputs": [],
394
- "source": [
395
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
396
- ]
397
- },
398
- {
399
- "cell_type": "markdown",
400
- "id": "62ad01c3",
401
- "metadata": {},
402
- "source": [
403
- "## Encoding"
404
- ]
405
- },
406
- {
407
- "cell_type": "markdown",
408
- "id": "20357f74",
409
- "metadata": {},
410
- "source": [
411
- "Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, that will be jitted on first use."
412
- ]
413
- },
414
- {
415
- "cell_type": "code",
416
- "execution_count": null,
417
- "id": "6686b004",
418
- "metadata": {},
419
- "outputs": [],
420
- "source": [
421
- "from flax.training.common_utils import shard\n",
422
- "from functools import partial"
423
- ]
424
- },
425
- {
426
- "cell_type": "code",
427
- "execution_count": null,
428
- "id": "322a4619",
429
- "metadata": {},
430
- "outputs": [],
431
- "source": [
432
- "@partial(jax.pmap, axis_name=\"batch\")\n",
433
- "def encode(batch):\n",
434
- " # Not sure if we should `replicate` params, does not seem to have any effect\n",
435
- " _, indices = model.encode(batch)\n",
436
- " return indices"
437
- ]
438
- },
439
- {
440
- "cell_type": "markdown",
441
- "id": "14375a41",
442
- "metadata": {},
443
- "source": [
444
- "### Encoding loop"
445
- ]
446
- },
447
- {
448
- "cell_type": "code",
449
- "execution_count": null,
450
- "id": "ff6c10d4",
451
- "metadata": {},
452
- "outputs": [],
453
- "source": [
454
- "import os\n",
455
- "import pandas as pd\n",
456
- "\n",
457
- "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
458
- " output_dir.mkdir(parents=True, exist_ok=True)\n",
459
- " \n",
460
- " # Saving strategy:\n",
461
- " # - Create a new file every so often to prevent excessive file seeking.\n",
462
- " # - Save each batch after processing.\n",
463
- " # - Keep the file open until we are done with it.\n",
464
- " file = None \n",
465
- " for n, batch in enumerate(tqdm(iter(dataloader))):\n",
466
- " if (n % save_every == 0):\n",
467
- " if file is not None:\n",
468
- " file.close()\n",
469
- " split_num = n // save_every\n",
470
- " file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
471
- "\n",
472
- " images = batch[\"image\"].numpy()\n",
473
- " images = shard(images.squeeze())\n",
474
- " encoded = encode(images)\n",
475
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
476
- "\n",
477
- " keys = batch[\"key\"]\n",
478
- " captions = batch[\"caption\"]\n",
479
- "\n",
480
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
481
- " batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
482
- " batch_df.to_json(file, orient='records', lines=True)"
483
- ]
484
- },
485
- {
486
- "cell_type": "markdown",
487
- "id": "09ff75a3",
488
- "metadata": {},
489
- "source": [
490
- "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
491
- ]
492
- },
493
- {
494
- "cell_type": "code",
495
- "execution_count": null,
496
- "id": "96222bb4",
497
- "metadata": {},
498
- "outputs": [],
499
- "source": [
500
- "save_every = 318"
501
- ]
502
- },
503
- {
504
- "cell_type": "code",
505
- "execution_count": null,
506
- "id": "7704863d",
507
- "metadata": {},
508
- "outputs": [
509
- {
510
- "name": "stderr",
511
- "output_type": "stream",
512
- "text": [
513
- "28it [01:17, 1.60s/it]"
514
- ]
515
- }
516
- ],
517
- "source": [
518
- "encode_captioned_dataset(dataloader, yfcc100m_output, save_every=save_every)"
519
- ]
520
- },
521
- {
522
- "cell_type": "markdown",
523
- "id": "e266a70a",
524
- "metadata": {},
525
- "source": [
526
- "This is ~10-15 slower than local encoding from an SSD. For performance considerations, see the discussion at https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13."
527
- ]
528
- },
529
- {
530
- "cell_type": "markdown",
531
- "id": "8953dd84",
532
- "metadata": {},
533
- "source": [
534
- "----"
535
- ]
536
- }
537
- ],
538
- "metadata": {
539
- "interpreter": {
540
- "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
541
- },
542
- "kernelspec": {
543
- "display_name": "Python 3 (ipykernel)",
544
- "language": "python",
545
- "name": "python3"
546
- },
547
- "language_info": {
548
- "codemirror_mode": {
549
- "name": "ipython",
550
- "version": 3
551
- },
552
- "file_extension": ".py",
553
- "mimetype": "text/x-python",
554
- "name": "python",
555
- "nbconvert_exporter": "python",
556
- "pygments_lexer": "ipython3",
557
- "version": "3.8.10"
558
- }
559
- },
560
- "nbformat": 4,
561
- "nbformat_minor": 5
562
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/encoding/vqgan-jax-encoding-with-captions.ipynb DELETED
@@ -1,355 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# vqgan-jax-encoding-with-captions"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "875c82b3",
14
- "metadata": {},
15
- "source": [
16
- "Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n",
17
- "\n",
18
- "We process a `tsv` file with `image_file` and `caption` fields, and add a `vqgan_indices` column with indices extracted from a VQGAN-JAX model."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 1,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "from torch.utils.data import Dataset, DataLoader\n",
40
- "\n",
41
- "import jax\n",
42
- "from jax import pmap"
43
- ]
44
- },
45
- {
46
- "cell_type": "markdown",
47
- "id": "511c3b9e",
48
- "metadata": {},
49
- "source": [
50
- "## VQGAN-JAX model"
51
- ]
52
- },
53
- {
54
- "cell_type": "code",
55
- "execution_count": 2,
56
- "id": "2ca50dc7",
57
- "metadata": {},
58
- "outputs": [],
59
- "source": [
60
- "from vqgan_jax.modeling_flax_vqgan import VQModel"
61
- ]
62
- },
63
- {
64
- "cell_type": "markdown",
65
- "id": "7b60da9a",
66
- "metadata": {},
67
- "source": [
68
- "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
69
- ]
70
- },
71
- {
72
- "cell_type": "code",
73
- "execution_count": 3,
74
- "id": "29ce8b15",
75
- "metadata": {},
76
- "outputs": [
77
- {
78
- "data": {
79
- "application/vnd.jupyter.widget-view+json": {
80
- "model_id": "db406bdfc5d5428eaeae1631a04989dd",
81
- "version_major": 2,
82
- "version_minor": 0
83
- },
84
- "text/plain": [
85
- "Downloading: 0%| | 0.00/433 [00:00<?, ?B/s]"
86
- ]
87
- },
88
- "metadata": {},
89
- "output_type": "display_data"
90
- },
91
- {
92
- "data": {
93
- "application/vnd.jupyter.widget-view+json": {
94
- "model_id": "3e37f07fba6d48fca70313ae1fa8cc32",
95
- "version_major": 2,
96
- "version_minor": 0
97
- },
98
- "text/plain": [
99
- "Downloading: 0%| | 0.00/304M [00:00<?, ?B/s]"
100
- ]
101
- },
102
- "metadata": {},
103
- "output_type": "display_data"
104
- },
105
- {
106
- "name": "stderr",
107
- "output_type": "stream",
108
- "text": [
109
- "INFO:absl:Starting the local TPU driver.\n",
110
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
111
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n"
112
- ]
113
- },
114
- {
115
- "name": "stdout",
116
- "output_type": "stream",
117
- "text": [
118
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
119
- ]
120
- }
121
- ],
122
- "source": [
123
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
124
- ]
125
- },
126
- {
127
- "cell_type": "markdown",
128
- "id": "c7c4c1e6",
129
- "metadata": {},
130
- "source": [
131
- "## Dataset"
132
- ]
133
- },
134
- {
135
- "cell_type": "markdown",
136
- "id": "7014a7ce",
137
- "metadata": {},
138
- "source": [
139
- "We use Luke Melas-Kyriazi's `dataset.py` which reads image paths and captions from a tsv file that contains both. We only need the images for encoding."
140
- ]
141
- },
142
- {
143
- "cell_type": "code",
144
- "execution_count": 4,
145
- "id": "85832702",
146
- "metadata": {},
147
- "outputs": [],
148
- "source": [
149
- "from dalle_mini.dataset import *"
150
- ]
151
- },
152
- {
153
- "cell_type": "code",
154
- "execution_count": 5,
155
- "id": "81b19eca",
156
- "metadata": {},
157
- "outputs": [],
158
- "source": [
159
- "cc12m_images = '/data/CC12M/images'\n",
160
- "cc12m_list = '/data/CC12M/images-list-clean.tsv'\n",
161
- "# cc12m_list = '/data/CC12M/images-10000.tsv'\n",
162
- "cc12m_output = '/data/CC12M/images-encoded.tsv'"
163
- ]
164
- },
165
- {
166
- "cell_type": "code",
167
- "execution_count": 6,
168
- "id": "fecc9a00",
169
- "metadata": {},
170
- "outputs": [],
171
- "source": [
172
- "image_size = 256\n",
173
- "def image_transform(image):\n",
174
- " s = min(image.size)\n",
175
- " r = image_size / s\n",
176
- " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
177
- " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
178
- " image = TF.center_crop(image, output_size = 2 * [image_size])\n",
179
- " image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
180
- " image = image.permute(0, 2, 3, 1).numpy()\n",
181
- " return image"
182
- ]
183
- },
184
- {
185
- "cell_type": "code",
186
- "execution_count": 7,
187
- "id": "4ce2211f",
188
- "metadata": {},
189
- "outputs": [],
190
- "source": [
191
- "dataset = CaptionDataset(\n",
192
- " images_root=cc12m_images,\n",
193
- " captions_path=cc12m_list,\n",
194
- " image_transform=image_transform,\n",
195
- " image_transform_type='torchvision',\n",
196
- " include_captions=False\n",
197
- ")"
198
- ]
199
- },
200
- {
201
- "cell_type": "code",
202
- "execution_count": 8,
203
- "id": "cc922704",
204
- "metadata": {},
205
- "outputs": [
206
- {
207
- "data": {
208
- "text/plain": [
209
- "8592141"
210
- ]
211
- },
212
- "execution_count": 8,
213
- "metadata": {},
214
- "output_type": "execute_result"
215
- }
216
- ],
217
- "source": [
218
- "len(dataset)"
219
- ]
220
- },
221
- {
222
- "cell_type": "markdown",
223
- "id": "62ad01c3",
224
- "metadata": {},
225
- "source": [
226
- "## Encoding"
227
- ]
228
- },
229
- {
230
- "cell_type": "code",
231
- "execution_count": 9,
232
- "id": "88f36d0b",
233
- "metadata": {},
234
- "outputs": [],
235
- "source": [
236
- "def encode(model, batch):\n",
237
- "# print(\"jitting encode function\")\n",
238
- " _, indices = model.encode(batch)\n",
239
- " return indices"
240
- ]
241
- },
242
- {
243
- "cell_type": "code",
244
- "execution_count": 10,
245
- "id": "1f35f0cb",
246
- "metadata": {},
247
- "outputs": [],
248
- "source": [
249
- "def superbatch_generator(dataloader, num_tpus):\n",
250
- " iter_loader = iter(dataloader)\n",
251
- " for batch in iter_loader:\n",
252
- " superbatch = [batch.squeeze(1)]\n",
253
- " try:\n",
254
- " for b in range(num_tpus-1):\n",
255
- " batch = next(iter_loader)\n",
256
- " if batch is None:\n",
257
- " break\n",
258
- " # Skip incomplete last batch\n",
259
- " if batch.shape[0] == dataloader.batch_size:\n",
260
- " superbatch.append(batch.squeeze(1))\n",
261
- " except StopIteration:\n",
262
- " pass\n",
263
- " superbatch = torch.stack(superbatch, axis=0)\n",
264
- " yield superbatch"
265
- ]
266
- },
267
- {
268
- "cell_type": "code",
269
- "execution_count": 11,
270
- "id": "2210705b",
271
- "metadata": {},
272
- "outputs": [],
273
- "source": [
274
- "import os\n",
275
- "\n",
276
- "def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
277
- " if os.path.isfile(output_tsv):\n",
278
- " print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
279
- " return\n",
280
- " \n",
281
- " num_tpus = 8 \n",
282
- " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
283
- " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
284
- " \n",
285
- " p_encoder = pmap(lambda batch: encode(model, batch))\n",
286
- "\n",
287
- " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
288
- " # We keep the file open to prevent excessive file seeks.\n",
289
- " with open(output_tsv, \"w\") as file:\n",
290
- " iterations = len(dataset) // (batch_size * num_tpus)\n",
291
- " for n in tqdm(range(iterations)):\n",
292
- " superbatch = next(superbatches)\n",
293
- " encoded = p_encoder(superbatch.numpy())\n",
294
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
295
- "\n",
296
- " # Extract fields from the dataset internal `captions` property, and save to disk\n",
297
- " start_index = n * batch_size * num_tpus\n",
298
- " end_index = (n+1) * batch_size * num_tpus\n",
299
- " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
300
- " captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
301
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
302
- " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
303
- " batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)\n",
304
- " "
305
- ]
306
- },
307
- {
308
- "cell_type": "code",
309
- "execution_count": null,
310
- "id": "7704863d",
311
- "metadata": {},
312
- "outputs": [
313
- {
314
- "name": "stderr",
315
- "output_type": "stream",
316
- "text": [
317
- " 4%|██▋ | 621/16781 [07:09<3:02:46, 1.47it/s]"
318
- ]
319
- }
320
- ],
321
- "source": [
322
- "encode_captioned_dataset(dataset, cc12m_output, batch_size=64, num_workers=16)"
323
- ]
324
- },
325
- {
326
- "cell_type": "markdown",
327
- "id": "8953dd84",
328
- "metadata": {},
329
- "source": [
330
- "----"
331
- ]
332
- }
333
- ],
334
- "metadata": {
335
- "kernelspec": {
336
- "display_name": "Python 3 (ipykernel)",
337
- "language": "python",
338
- "name": "python3"
339
- },
340
- "language_info": {
341
- "codemirror_mode": {
342
- "name": "ipython",
343
- "version": 3
344
- },
345
- "file_extension": ".py",
346
- "mimetype": "text/x-python",
347
- "name": "python",
348
- "nbconvert_exporter": "python",
349
- "pygments_lexer": "ipython3",
350
- "version": "3.8.10"
351
- }
352
- },
353
- "nbformat": 4,
354
- "nbformat_minor": 5
355
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/encoding/vqgan-jax-encoding-yfcc100m.ipynb DELETED
@@ -1,1129 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# vqgan-jax-encoding-yfcc100m"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "ba7b31e6",
14
- "metadata": {},
15
- "source": [
16
- "Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n",
17
- "\n",
18
- "This dataset was prepared by @borisdayma in Json lines format."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 92,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "from torch.utils.data import Dataset, DataLoader\n",
40
- "from torchvision.datasets.folder import default_loader\n",
41
- "import os\n",
42
- "\n",
43
- "import jax\n",
44
- "from jax import pmap"
45
- ]
46
- },
47
- {
48
- "cell_type": "markdown",
49
- "id": "511c3b9e",
50
- "metadata": {},
51
- "source": [
52
- "## VQGAN-JAX model"
53
- ]
54
- },
55
- {
56
- "cell_type": "code",
57
- "execution_count": 93,
58
- "id": "2ca50dc7",
59
- "metadata": {},
60
- "outputs": [],
61
- "source": [
62
- "from vqgan_jax.modeling_flax_vqgan import VQModel"
63
- ]
64
- },
65
- {
66
- "cell_type": "markdown",
67
- "id": "7b60da9a",
68
- "metadata": {},
69
- "source": [
70
- "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
71
- ]
72
- },
73
- {
74
- "cell_type": "code",
75
- "execution_count": 167,
76
- "id": "29ce8b15",
77
- "metadata": {},
78
- "outputs": [
79
- {
80
- "name": "stdout",
81
- "output_type": "stream",
82
- "text": [
83
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
84
- ]
85
- }
86
- ],
87
- "source": [
88
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
89
- ]
90
- },
91
- {
92
- "cell_type": "markdown",
93
- "id": "c7c4c1e6",
94
- "metadata": {},
95
- "source": [
96
- "## Dataset"
97
- ]
98
- },
99
- {
100
- "cell_type": "code",
101
- "execution_count": 94,
102
- "id": "33861477",
103
- "metadata": {},
104
- "outputs": [],
105
- "source": [
106
- "import pandas as pd\n",
107
- "from pathlib import Path"
108
- ]
109
- },
110
- {
111
- "cell_type": "code",
112
- "execution_count": 134,
113
- "id": "81b19eca",
114
- "metadata": {},
115
- "outputs": [],
116
- "source": [
117
- "yfcc100m = Path('/home/khali/TPU-Test/YFCC100M_OpenAI_subset')\n",
118
- "# Images are 'sharded' from the following directory\n",
119
- "yfcc100m_images = yfcc100m/'data'/'data'/'images'\n",
120
- "yfcc100m_metadata = yfcc100m/'metadata_YFCC100M.jsonl'\n",
121
- "yfcc100m_output = yfcc100m/'metadata_encoded.tsv'"
122
- ]
123
- },
124
- {
125
- "cell_type": "markdown",
126
- "id": "1c58bb4a",
127
- "metadata": {},
128
- "source": [
129
- "### Cleanup"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "id": "1a14ae3d",
135
- "metadata": {},
136
- "source": [
137
- "We need to select entries with images that exist. Otherwise we can't build batches because `Dataloader` does not support `None` in batches. We use Huggingface Datasets, I understand they support threaded reading of jsonl files, and I was running out of memory when using pandas."
138
- ]
139
- },
140
- {
141
- "cell_type": "code",
142
- "execution_count": 96,
143
- "id": "7811648c",
144
- "metadata": {},
145
- "outputs": [],
146
- "source": [
147
- "import datasets\n",
148
- "from datasets import Dataset, load_dataset"
149
- ]
150
- },
151
- {
152
- "cell_type": "code",
153
- "execution_count": 10,
154
- "id": "4811a230",
155
- "metadata": {},
156
- "outputs": [
157
- {
158
- "name": "stderr",
159
- "output_type": "stream",
160
- "text": [
161
- "tcmalloc: large alloc 1254047744 bytes == 0xb2b08000 @ 0x7f9e78632680 0x7f9e78653824 0x585b92 0x504d56 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
162
- "tcmalloc: large alloc 1254047744 bytes == 0xfd74e000 @ 0x7f9e78632680 0x7f9e78653824 0x590214 0x586f90 0x56e1f3 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
163
- "tcmalloc: large alloc 5016190976 bytes == 0x148b42000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
164
- "tcmalloc: large alloc 5019099136 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
165
- "tcmalloc: large alloc 5019811840 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
166
- "tcmalloc: large alloc 5024571392 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
167
- "tcmalloc: large alloc 5021097984 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
168
- "tcmalloc: large alloc 5022818304 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
169
- "tcmalloc: large alloc 5020794880 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
170
- "tcmalloc: large alloc 5019451392 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
171
- "tcmalloc: large alloc 5020565504 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
172
- "tcmalloc: large alloc 5012561920 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
173
- "tcmalloc: large alloc 5021835264 bytes == 0x5f6cba000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
174
- "tcmalloc: large alloc 5017436160 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n"
175
- ]
176
- }
177
- ],
178
- "source": [
179
- "# The metadata is too bog to load into memory at once, so chopping it into chunks\n",
180
- "chunk_size=1000000\n",
181
- "batch_no=1\n",
182
- "for chunk in pd.read_json(yfcc100m_metadata, orient=\"records\", lines=True,chunksize=chunk_size):\n",
183
- " chunk.to_csv('./chunks/chunk'+str(batch_no)+'.tsv', sep=\"\\t\", index=False)\n",
184
- " batch_no+=1"
185
- ]
186
- },
187
- {
188
- "cell_type": "code",
189
- "execution_count": 25,
190
- "id": "46b2f083",
191
- "metadata": {},
192
- "outputs": [
193
- {
194
- "data": {
195
- "text/html": [
196
- "<div>\n",
197
- "<style scoped>\n",
198
- " .dataframe tbody tr th:only-of-type {\n",
199
- " vertical-align: middle;\n",
200
- " }\n",
201
- "\n",
202
- " .dataframe tbody tr th {\n",
203
- " vertical-align: top;\n",
204
- " }\n",
205
- "\n",
206
- " .dataframe thead th {\n",
207
- " text-align: right;\n",
208
- " }\n",
209
- "</style>\n",
210
- "<table border=\"1\" class=\"dataframe\">\n",
211
- " <thead>\n",
212
- " <tr style=\"text-align: right;\">\n",
213
- " <th></th>\n",
214
- " <th>photoid</th>\n",
215
- " <th>uid</th>\n",
216
- " <th>unickname</th>\n",
217
- " <th>datetaken</th>\n",
218
- " <th>dateuploaded</th>\n",
219
- " <th>capturedevice</th>\n",
220
- " <th>title</th>\n",
221
- " <th>description</th>\n",
222
- " <th>usertags</th>\n",
223
- " <th>machinetags</th>\n",
224
- " <th>...</th>\n",
225
- " <th>licenseurl</th>\n",
226
- " <th>serverid</th>\n",
227
- " <th>farmid</th>\n",
228
- " <th>secret</th>\n",
229
- " <th>secretoriginal</th>\n",
230
- " <th>ext</th>\n",
231
- " <th>marker</th>\n",
232
- " <th>key</th>\n",
233
- " <th>title_clean</th>\n",
234
- " <th>description_clean</th>\n",
235
- " </tr>\n",
236
- " </thead>\n",
237
- " <tbody>\n",
238
- " <tr>\n",
239
- " <th>0</th>\n",
240
- " <td>137943</td>\n",
241
- " <td>48600072071@N01</td>\n",
242
- " <td>doctor+paradox</td>\n",
243
- " <td>2004-08-01 18:13:06.0</td>\n",
244
- " <td>1091409186</td>\n",
245
- " <td>NaN</td>\n",
246
- " <td>A+Picture+Share%21</td>\n",
247
- " <td>Antenna</td>\n",
248
- " <td>cameraphone,cayugaheights,green,hydrant,ithaca...</td>\n",
249
- " <td>NaN</td>\n",
250
- " <td>...</td>\n",
251
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
252
- " <td>1</td>\n",
253
- " <td>1</td>\n",
254
- " <td>1650c7cdc6</td>\n",
255
- " <td>1650c7cdc6</td>\n",
256
- " <td>jpg</td>\n",
257
- " <td>0</td>\n",
258
- " <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
259
- " <td>A Picture Share!</td>\n",
260
- " <td>Antenna</td>\n",
261
- " </tr>\n",
262
- " <tr>\n",
263
- " <th>1</th>\n",
264
- " <td>1246361</td>\n",
265
- " <td>44124324682@N01</td>\n",
266
- " <td>mharrsch</td>\n",
267
- " <td>2004-11-03 23:04:02.0</td>\n",
268
- " <td>1099523042</td>\n",
269
- " <td>NaN</td>\n",
270
- " <td>An+ornate+Roman+urn</td>\n",
271
- " <td>Photographed+at+the+%3Ca+href%3D%22http%3A%2F%...</td>\n",
272
- " <td>ancient,baltimore,burial,death,empire,funeral,...</td>\n",
273
- " <td>NaN</td>\n",
274
- " <td>...</td>\n",
275
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
276
- " <td>1</td>\n",
277
- " <td>1</td>\n",
278
- " <td>cf37054610</td>\n",
279
- " <td>cf37054610</td>\n",
280
- " <td>jpg</td>\n",
281
- " <td>0</td>\n",
282
- " <td>d29f01b149167d683f9ddde464bb3db</td>\n",
283
- " <td>An ornate Roman urn</td>\n",
284
- " <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
285
- " </tr>\n",
286
- " <tr>\n",
287
- " <th>2</th>\n",
288
- " <td>1251599</td>\n",
289
- " <td>51035803024@N01</td>\n",
290
- " <td>bmitd67</td>\n",
291
- " <td>2004-10-30 17:09:32.0</td>\n",
292
- " <td>1099538888</td>\n",
293
- " <td>Canon+PowerShot+S30</td>\n",
294
- " <td>Jai+%26+Tara+on+the+Cumberland</td>\n",
295
- " <td>Another+trip+for+the+happy+couple.</td>\n",
296
- " <td>blue+heron,cumberland+river,jai,tara,tennessee</td>\n",
297
- " <td>NaN</td>\n",
298
- " <td>...</td>\n",
299
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
300
- " <td>1</td>\n",
301
- " <td>1</td>\n",
302
- " <td>4a4234e32c</td>\n",
303
- " <td>4a4234e32c</td>\n",
304
- " <td>jpg</td>\n",
305
- " <td>0</td>\n",
306
- " <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
307
- " <td>Jai &amp; Tara on the Cumberland</td>\n",
308
- " <td>Another trip for the happy couple.</td>\n",
309
- " </tr>\n",
310
- " <tr>\n",
311
- " <th>3</th>\n",
312
- " <td>2348587</td>\n",
313
- " <td>73621375@N00</td>\n",
314
- " <td>Thom+Watson</td>\n",
315
- " <td>2004-12-18 21:08:09.0</td>\n",
316
- " <td>1103497228</td>\n",
317
- " <td>SONY+DSC-W1</td>\n",
318
- " <td>Castle+gate+-+%22lite-brited%22</td>\n",
319
- " <td>Taken+at+the+Miracle+of+Lights+display+in+Cent...</td>\n",
320
- " <td>bullrunpark,castle,centreville,christmas,decor...</td>\n",
321
- " <td>NaN</td>\n",
322
- " <td>...</td>\n",
323
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
324
- " <td>2</td>\n",
325
- " <td>1</td>\n",
326
- " <td>7162c974c3</td>\n",
327
- " <td>7162c974c3</td>\n",
328
- " <td>jpg</td>\n",
329
- " <td>0</td>\n",
330
- " <td>d29ce96395848478b1e8396e44899</td>\n",
331
- " <td>Castle gate - \"lite-brited\"</td>\n",
332
- " <td>Taken at the Miracle of Lights display in Cent...</td>\n",
333
- " </tr>\n",
334
- " <tr>\n",
335
- " <th>4</th>\n",
336
- " <td>3516047</td>\n",
337
- " <td>48600072071@N01</td>\n",
338
- " <td>doctor+paradox</td>\n",
339
- " <td>2005-01-18 16:44:18.0</td>\n",
340
- " <td>1106084658</td>\n",
341
- " <td>NaN</td>\n",
342
- " <td>A+Picture+Share%21</td>\n",
343
- " <td>Tabular</td>\n",
344
- " <td>cameraphone,moblog,unfound</td>\n",
345
- " <td>NaN</td>\n",
346
- " <td>...</td>\n",
347
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
348
- " <td>3</td>\n",
349
- " <td>1</td>\n",
350
- " <td>663e0d8b3d</td>\n",
351
- " <td>663e0d8b3d</td>\n",
352
- " <td>jpg</td>\n",
353
- " <td>0</td>\n",
354
- " <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
355
- " <td>A Picture Share!</td>\n",
356
- " <td>Tabular</td>\n",
357
- " </tr>\n",
358
- " <tr>\n",
359
- " <th>...</th>\n",
360
- " <td>...</td>\n",
361
- " <td>...</td>\n",
362
- " <td>...</td>\n",
363
- " <td>...</td>\n",
364
- " <td>...</td>\n",
365
- " <td>...</td>\n",
366
- " <td>...</td>\n",
367
- " <td>...</td>\n",
368
- " <td>...</td>\n",
369
- " <td>...</td>\n",
370
- " <td>...</td>\n",
371
- " <td>...</td>\n",
372
- " <td>...</td>\n",
373
- " <td>...</td>\n",
374
- " <td>...</td>\n",
375
- " <td>...</td>\n",
376
- " <td>...</td>\n",
377
- " <td>...</td>\n",
378
- " <td>...</td>\n",
379
- " <td>...</td>\n",
380
- " <td>...</td>\n",
381
- " </tr>\n",
382
- " <tr>\n",
383
- " <th>999995</th>\n",
384
- " <td>4648651054</td>\n",
385
- " <td>24511045@N04</td>\n",
386
- " <td>mtfrazier</td>\n",
387
- " <td>2010-05-02 15:47:45.0</td>\n",
388
- " <td>1275083371</td>\n",
389
- " <td>Canon+EOS+50D</td>\n",
390
- " <td>U.S.+Navy+Blue+Angels%3A+2010</td>\n",
391
- " <td>2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri</td>\n",
392
- " <td>NaN</td>\n",
393
- " <td>NaN</td>\n",
394
- " <td>...</td>\n",
395
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
396
- " <td>4072</td>\n",
397
- " <td>5</td>\n",
398
- " <td>2d12d73fb0</td>\n",
399
- " <td>dd5856ea42</td>\n",
400
- " <td>jpg</td>\n",
401
- " <td>0</td>\n",
402
- " <td>60fa2911cb81eb25b356e9fee978aef</td>\n",
403
- " <td>U.S. Navy Blue Angels: 2010</td>\n",
404
- " <td>2 May 2010 Sunday St. Joseph, Missouri</td>\n",
405
- " </tr>\n",
406
- " <tr>\n",
407
- " <th>999996</th>\n",
408
- " <td>4652130996</td>\n",
409
- " <td>21963865@N04</td>\n",
410
- " <td>GRAB1.0</td>\n",
411
- " <td>2010-05-29 19:23:10.0</td>\n",
412
- " <td>1275200833</td>\n",
413
- " <td>SONY+DSLR-A230</td>\n",
414
- " <td>Attempts+on+Her+Life</td>\n",
415
- " <td>BAPA+1+production+of+Martin+Crimp%27s+Attempts...</td>\n",
416
- " <td>NaN</td>\n",
417
- " <td>NaN</td>\n",
418
- " <td>...</td>\n",
419
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
420
- " <td>4003</td>\n",
421
- " <td>5</td>\n",
422
- " <td>8889121579</td>\n",
423
- " <td>2f46599456</td>\n",
424
- " <td>jpg</td>\n",
425
- " <td>0</td>\n",
426
- " <td>60f5ef5ce4c2d24566226abebd67d4</td>\n",
427
- " <td>Attempts on Her Life</td>\n",
428
- " <td>BAPA 1 production of Martin Crimp's Attempts o...</td>\n",
429
- " </tr>\n",
430
- " <tr>\n",
431
- " <th>999997</th>\n",
432
- " <td>4652568339</td>\n",
433
- " <td>64025277@N00</td>\n",
434
- " <td>1Sock</td>\n",
435
- " <td>2010-05-13 15:38:37.0</td>\n",
436
- " <td>1275234267</td>\n",
437
- " <td>Canon+EOS+DIGITAL+REBEL+XT</td>\n",
438
- " <td>Carlsbad+Caverns+3</td>\n",
439
- " <td>%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%...</td>\n",
440
- " <td>carlsbad,carlsbad+caverns,cave,faa,new+mexico,...</td>\n",
441
- " <td>NaN</td>\n",
442
- " <td>...</td>\n",
443
- " <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
444
- " <td>4010</td>\n",
445
- " <td>5</td>\n",
446
- " <td>0a1808a69e</td>\n",
447
- " <td>cf6d348e3d</td>\n",
448
- " <td>jpg</td>\n",
449
- " <td>0</td>\n",
450
- " <td>60f029482d1d1028fda5281daf498f</td>\n",
451
- " <td>Carlsbad Caverns 3</td>\n",
452
- " <td>♥♥♥♥♥♥♥ Interested in purchasing this photogra...</td>\n",
453
- " </tr>\n",
454
- " <tr>\n",
455
- " <th>999998</th>\n",
456
- " <td>4653110895</td>\n",
457
- " <td>20483509@N00</td>\n",
458
- " <td>subberculture</td>\n",
459
- " <td>2010-05-30 15:37:05.0</td>\n",
460
- " <td>1275245596</td>\n",
461
- " <td>Canon+DIGITAL+IXUS+40</td>\n",
462
- " <td>Want</td>\n",
463
- " <td>Isn%27t+that+gorgeous%3F</td>\n",
464
- " <td>2010,edinburgh+museum,may,phonebox,wood</td>\n",
465
- " <td>NaN</td>\n",
466
- " <td>...</td>\n",
467
- " <td>http://creativecommons.org/licenses/by-sa/2.0/</td>\n",
468
- " <td>4066</td>\n",
469
- " <td>5</td>\n",
470
- " <td>77c3b3a254</td>\n",
471
- " <td>c4697e1511</td>\n",
472
- " <td>jpg</td>\n",
473
- " <td>0</td>\n",
474
- " <td>60f72775f433cf8de3efaeb431866153</td>\n",
475
- " <td>Want</td>\n",
476
- " <td>Isn't that gorgeous?</td>\n",
477
- " </tr>\n",
478
- " <tr>\n",
479
- " <th>999999</th>\n",
480
- " <td>4655503987</td>\n",
481
- " <td>8457193@N07</td>\n",
482
- " <td>zackojones</td>\n",
483
- " <td>2010-05-30 15:34:58.0</td>\n",
484
- " <td>1275310230</td>\n",
485
- " <td>Canon+EOS+7D</td>\n",
486
- " <td>Summertime</td>\n",
487
- " <td>You+gotta+love+it%21</td>\n",
488
- " <td>georgia,savannah,united+states,us</td>\n",
489
- " <td>NaN</td>\n",
490
- " <td>...</td>\n",
491
- " <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
492
- " <td>4043</td>\n",
493
- " <td>5</td>\n",
494
- " <td>caff543bfe</td>\n",
495
- " <td>f60952ac4d</td>\n",
496
- " <td>jpg</td>\n",
497
- " <td>0</td>\n",
498
- " <td>60f687e11b913bce461e9525d8047e0</td>\n",
499
- " <td>Summertime</td>\n",
500
- " <td>You gotta love it!</td>\n",
501
- " </tr>\n",
502
- " </tbody>\n",
503
- "</table>\n",
504
- "<p>1000000 rows × 26 columns</p>\n",
505
- "</div>"
506
- ],
507
- "text/plain": [
508
- " photoid uid unickname datetaken \\\n",
509
- "0 137943 48600072071@N01 doctor+paradox 2004-08-01 18:13:06.0 \n",
510
- "1 1246361 44124324682@N01 mharrsch 2004-11-03 23:04:02.0 \n",
511
- "2 1251599 51035803024@N01 bmitd67 2004-10-30 17:09:32.0 \n",
512
- "3 2348587 73621375@N00 Thom+Watson 2004-12-18 21:08:09.0 \n",
513
- "4 3516047 48600072071@N01 doctor+paradox 2005-01-18 16:44:18.0 \n",
514
- "... ... ... ... ... \n",
515
- "999995 4648651054 24511045@N04 mtfrazier 2010-05-02 15:47:45.0 \n",
516
- "999996 4652130996 21963865@N04 GRAB1.0 2010-05-29 19:23:10.0 \n",
517
- "999997 4652568339 64025277@N00 1Sock 2010-05-13 15:38:37.0 \n",
518
- "999998 4653110895 20483509@N00 subberculture 2010-05-30 15:37:05.0 \n",
519
- "999999 4655503987 8457193@N07 zackojones 2010-05-30 15:34:58.0 \n",
520
- "\n",
521
- " dateuploaded capturedevice \\\n",
522
- "0 1091409186 NaN \n",
523
- "1 1099523042 NaN \n",
524
- "2 1099538888 Canon+PowerShot+S30 \n",
525
- "3 1103497228 SONY+DSC-W1 \n",
526
- "4 1106084658 NaN \n",
527
- "... ... ... \n",
528
- "999995 1275083371 Canon+EOS+50D \n",
529
- "999996 1275200833 SONY+DSLR-A230 \n",
530
- "999997 1275234267 Canon+EOS+DIGITAL+REBEL+XT \n",
531
- "999998 1275245596 Canon+DIGITAL+IXUS+40 \n",
532
- "999999 1275310230 Canon+EOS+7D \n",
533
- "\n",
534
- " title \\\n",
535
- "0 A+Picture+Share%21 \n",
536
- "1 An+ornate+Roman+urn \n",
537
- "2 Jai+%26+Tara+on+the+Cumberland \n",
538
- "3 Castle+gate+-+%22lite-brited%22 \n",
539
- "4 A+Picture+Share%21 \n",
540
- "... ... \n",
541
- "999995 U.S.+Navy+Blue+Angels%3A+2010 \n",
542
- "999996 Attempts+on+Her+Life \n",
543
- "999997 Carlsbad+Caverns+3 \n",
544
- "999998 Want \n",
545
- "999999 Summertime \n",
546
- "\n",
547
- " description \\\n",
548
- "0 Antenna \n",
549
- "1 Photographed+at+the+%3Ca+href%3D%22http%3A%2F%... \n",
550
- "2 Another+trip+for+the+happy+couple. \n",
551
- "3 Taken+at+the+Miracle+of+Lights+display+in+Cent... \n",
552
- "4 Tabular \n",
553
- "... ... \n",
554
- "999995 2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri \n",
555
- "999996 BAPA+1+production+of+Martin+Crimp%27s+Attempts... \n",
556
- "999997 %E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%... \n",
557
- "999998 Isn%27t+that+gorgeous%3F \n",
558
- "999999 You+gotta+love+it%21 \n",
559
- "\n",
560
- " usertags machinetags ... \\\n",
561
- "0 cameraphone,cayugaheights,green,hydrant,ithaca... NaN ... \n",
562
- "1 ancient,baltimore,burial,death,empire,funeral,... NaN ... \n",
563
- "2 blue+heron,cumberland+river,jai,tara,tennessee NaN ... \n",
564
- "3 bullrunpark,castle,centreville,christmas,decor... NaN ... \n",
565
- "4 cameraphone,moblog,unfound NaN ... \n",
566
- "... ... ... ... \n",
567
- "999995 NaN NaN ... \n",
568
- "999996 NaN NaN ... \n",
569
- "999997 carlsbad,carlsbad+caverns,cave,faa,new+mexico,... NaN ... \n",
570
- "999998 2010,edinburgh+museum,may,phonebox,wood NaN ... \n",
571
- "999999 georgia,savannah,united+states,us NaN ... \n",
572
- "\n",
573
- " licenseurl serverid farmid \\\n",
574
- "0 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
575
- "1 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
576
- "2 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
577
- "3 http://creativecommons.org/licenses/by-nc-sa/2.0/ 2 1 \n",
578
- "4 http://creativecommons.org/licenses/by-nc-sa/2.0/ 3 1 \n",
579
- "... ... ... ... \n",
580
- "999995 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4072 5 \n",
581
- "999996 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4003 5 \n",
582
- "999997 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4010 5 \n",
583
- "999998 http://creativecommons.org/licenses/by-sa/2.0/ 4066 5 \n",
584
- "999999 http://creativecommons.org/licenses/by-nc-sa/2.0/ 4043 5 \n",
585
- "\n",
586
- " secret secretoriginal ext marker \\\n",
587
- "0 1650c7cdc6 1650c7cdc6 jpg 0 \n",
588
- "1 cf37054610 cf37054610 jpg 0 \n",
589
- "2 4a4234e32c 4a4234e32c jpg 0 \n",
590
- "3 7162c974c3 7162c974c3 jpg 0 \n",
591
- "4 663e0d8b3d 663e0d8b3d jpg 0 \n",
592
- "... ... ... ... ... \n",
593
- "999995 2d12d73fb0 dd5856ea42 jpg 0 \n",
594
- "999996 8889121579 2f46599456 jpg 0 \n",
595
- "999997 0a1808a69e cf6d348e3d jpg 0 \n",
596
- "999998 77c3b3a254 c4697e1511 jpg 0 \n",
597
- "999999 caff543bfe f60952ac4d jpg 0 \n",
598
- "\n",
599
- " key title_clean \\\n",
600
- "0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
601
- "1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
602
- "2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
603
- "3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
604
- "4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
605
- "... ... ... \n",
606
- "999995 60fa2911cb81eb25b356e9fee978aef U.S. Navy Blue Angels: 2010 \n",
607
- "999996 60f5ef5ce4c2d24566226abebd67d4 Attempts on Her Life \n",
608
- "999997 60f029482d1d1028fda5281daf498f Carlsbad Caverns 3 \n",
609
- "999998 60f72775f433cf8de3efaeb431866153 Want \n",
610
- "999999 60f687e11b913bce461e9525d8047e0 Summertime \n",
611
- "\n",
612
- " description_clean \n",
613
- "0 Antenna \n",
614
- "1 Photographed at the Walters Art Museum, Baltim... \n",
615
- "2 Another trip for the happy couple. \n",
616
- "3 Taken at the Miracle of Lights display in Cent... \n",
617
- "4 Tabular \n",
618
- "... ... \n",
619
- "999995 2 May 2010 Sunday St. Joseph, Missouri \n",
620
- "999996 BAPA 1 production of Martin Crimp's Attempts o... \n",
621
- "999997 ♥♥♥♥♥♥♥ Interested in purchasing this photogra... \n",
622
- "999998 Isn't that gorgeous? \n",
623
- "999999 You gotta love it! \n",
624
- "\n",
625
- "[1000000 rows x 26 columns]"
626
- ]
627
- },
628
- "execution_count": 25,
629
- "metadata": {},
630
- "output_type": "execute_result"
631
- }
632
- ],
633
- "source": [
634
- "# looking up at a chunk\n",
635
- "pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")"
636
- ]
637
- },
638
- {
639
- "cell_type": "code",
640
- "execution_count": 98,
641
- "id": "c51c5597",
642
- "metadata": {},
643
- "outputs": [
644
- {
645
- "data": {
646
- "text/html": [
647
- "<div>\n",
648
- "<style scoped>\n",
649
- " .dataframe tbody tr th:only-of-type {\n",
650
- " vertical-align: middle;\n",
651
- " }\n",
652
- "\n",
653
- " .dataframe tbody tr th {\n",
654
- " vertical-align: top;\n",
655
- " }\n",
656
- "\n",
657
- " .dataframe thead th {\n",
658
- " text-align: right;\n",
659
- " }\n",
660
- "</style>\n",
661
- "<table border=\"1\" class=\"dataframe\">\n",
662
- " <thead>\n",
663
- " <tr style=\"text-align: right;\">\n",
664
- " <th></th>\n",
665
- " <th>key</th>\n",
666
- " <th>title_clean</th>\n",
667
- " <th>description_clean</th>\n",
668
- " <th>ext</th>\n",
669
- " </tr>\n",
670
- " </thead>\n",
671
- " <tbody>\n",
672
- " <tr>\n",
673
- " <th>0</th>\n",
674
- " <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
675
- " <td>A Picture Share!</td>\n",
676
- " <td>Antenna</td>\n",
677
- " <td>jpg</td>\n",
678
- " </tr>\n",
679
- " <tr>\n",
680
- " <th>1</th>\n",
681
- " <td>d29f01b149167d683f9ddde464bb3db</td>\n",
682
- " <td>An ornate Roman urn</td>\n",
683
- " <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
684
- " <td>jpg</td>\n",
685
- " </tr>\n",
686
- " <tr>\n",
687
- " <th>2</th>\n",
688
- " <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
689
- " <td>Jai &amp; Tara on the Cumberland</td>\n",
690
- " <td>Another trip for the happy couple.</td>\n",
691
- " <td>jpg</td>\n",
692
- " </tr>\n",
693
- " <tr>\n",
694
- " <th>3</th>\n",
695
- " <td>d29ce96395848478b1e8396e44899</td>\n",
696
- " <td>Castle gate - \"lite-brited\"</td>\n",
697
- " <td>Taken at the Miracle of Lights display in Cent...</td>\n",
698
- " <td>jpg</td>\n",
699
- " </tr>\n",
700
- " <tr>\n",
701
- " <th>4</th>\n",
702
- " <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
703
- " <td>A Picture Share!</td>\n",
704
- " <td>Tabular</td>\n",
705
- " <td>jpg</td>\n",
706
- " </tr>\n",
707
- " </tbody>\n",
708
- "</table>\n",
709
- "</div>"
710
- ],
711
- "text/plain": [
712
- " key title_clean \\\n",
713
- "0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
714
- "1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
715
- "2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
716
- "3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
717
- "4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
718
- "\n",
719
- " description_clean ext \n",
720
- "0 Antenna jpg \n",
721
- "1 Photographed at the Walters Art Museum, Baltim... jpg \n",
722
- "2 Another trip for the happy couple. jpg \n",
723
- "3 Taken at the Miracle of Lights display in Cent... jpg \n",
724
- "4 Tabular jpg "
725
- ]
726
- },
727
- "execution_count": 98,
728
- "metadata": {},
729
- "output_type": "execute_result"
730
- }
731
- ],
732
- "source": [
733
- "# Looking at a chunk with only the relevant columns that we need\n",
734
- "df = pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
735
- "df.head()"
736
- ]
737
- },
738
- {
739
- "cell_type": "markdown",
740
- "id": "cc1668f8",
741
- "metadata": {},
742
- "source": [
743
- "### Grabbing each chunks from the folder, cleaning it up, only taking the entries which image exist and appending it to the global df"
744
- ]
745
- },
746
- {
747
- "cell_type": "code",
748
- "execution_count": null,
749
- "id": "abbcccf3",
750
- "metadata": {},
751
- "outputs": [],
752
- "source": [
753
- "# the function that helps us to decide whether an image with certain id exists in storage, we only take the ones that we have the images for\n",
754
- "def image_exists(item):\n",
755
- " name, _, _, ext, _ = item\n",
756
- " root=str(yfcc100m_images)\n",
757
- " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".\"+ext)\n",
758
- " if image_path.exists():\n",
759
- " return True\n",
760
- " else:\n",
761
- " return None"
762
- ]
763
- },
764
- {
765
- "cell_type": "code",
766
- "execution_count": 86,
767
- "id": "44fa86ab",
768
- "metadata": {},
769
- "outputs": [],
770
- "source": [
771
- "# This cell does it all, grabs each chunk, cleans it up based on image existing condition, etc.\n",
772
- "global_df = pd.DataFrame()\n",
773
- "chunks_dir = \"./chunks\"\n",
774
- "for filename in os.listdir(chunks_dir):\n",
775
- " df = pd.read_csv(f\"./chunks/{str(filename)}\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
776
- " df['caption'] = df[\"title_clean\"]+\". \"+df['description_clean']\n",
777
- " df['is_exist'] = df.apply(image_exists, axis=1)\n",
778
- " df = df.dropna()[[\"key\", \"caption\"]]\n",
779
- " df.columns = ['image_file', 'caption']\n",
780
- " global_df = global_df.append(df, ignore_index=True)"
781
- ]
782
- },
783
- {
784
- "cell_type": "code",
785
- "execution_count": 89,
786
- "id": "45024fdc",
787
- "metadata": {},
788
- "outputs": [],
789
- "source": [
790
- "# saving the tsv to disk\n",
791
- "global_df.to_csv('./chunks/YFCC_subset_clean.tsv', sep=\"\\t\", index=False)"
792
- ]
793
- },
794
- {
795
- "cell_type": "code",
796
- "execution_count": 101,
797
- "id": "dca4eb73",
798
- "metadata": {},
799
- "outputs": [],
800
- "source": [
801
- "# loading the tsv from disk (for explicitness, also my electricity was gone, glad it happened after I saved to the disk :( )\n",
802
- "\n",
803
- "dataset = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")"
804
- ]
805
- },
806
- {
807
- "cell_type": "code",
808
- "execution_count": 153,
809
- "id": "a511264a",
810
- "metadata": {},
811
- "outputs": [],
812
- "source": [
813
- "\"\"\"\n",
814
- "Luke Melas-Kyriazi's dataset.py's modified version for YFCC\n",
815
- "\"\"\"\n",
816
- "import warnings\n",
817
- "from typing import Optional, Callable\n",
818
- "from pathlib import Path\n",
819
- "import numpy as np\n",
820
- "import torch\n",
821
- "import pandas as pd\n",
822
- "from torch.utils.data import Dataset\n",
823
- "from torchvision.datasets.folder import default_loader\n",
824
- "from PIL import ImageFile\n",
825
- "from PIL.Image import DecompressionBombWarning\n",
826
- "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
827
- "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
828
- "warnings.filterwarnings(\"ignore\", category=DecompressionBombWarning)\n",
829
- "\n",
830
- "\n",
831
- "class CaptionDataset(Dataset):\n",
832
- " \"\"\"\n",
833
- " A PyTorch Dataset class for (image, texts) tasks. Note that this dataset \n",
834
- " returns the raw text rather than tokens. This is done on purpose, because\n",
835
- " it's easy to tokenize a batch of text after loading it from this dataset.\n",
836
- " \"\"\"\n",
837
- "\n",
838
- " def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None, \n",
839
- " image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',\n",
840
- " include_captions: bool = True):\n",
841
- " \"\"\"\n",
842
- " :param images_root: folder where images are stored\n",
843
- " :param captions_path: path to csv that maps image filenames to captions\n",
844
- " :param image_transform: image transform pipeline\n",
845
- " :param text_transform: image transform pipeline\n",
846
- " :param image_transform_type: image transform type, either `torchvision` or `albumentations`\n",
847
- " :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.\n",
848
- " \"\"\"\n",
849
- "\n",
850
- " # Base path for images\n",
851
- " self.images_root = Path(images_root)\n",
852
- "\n",
853
- " # Load captions as DataFrame\n",
854
- " self.captions = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")\n",
855
- " self.captions['image_file'] = self.captions['image_file'].astype(str)\n",
856
- "\n",
857
- " # PyTorch transformation pipeline for the image (normalizing, etc.)\n",
858
- " self.text_transform = text_transform\n",
859
- " self.image_transform = image_transform\n",
860
- " self.image_transform_type = image_transform_type.lower()\n",
861
- " assert self.image_transform_type in ['torchvision', 'albumentations']\n",
862
- "\n",
863
- " # Total number of datapoints\n",
864
- " self.size = len(self.captions)\n",
865
- "\n",
866
- " # Return image+captions or just images\n",
867
- " self.include_captions = include_captions\n",
868
- " \n",
869
- " def image_exists(item):\n",
870
- " name, caption = item\n",
871
- " root=str(self.images_root)\n",
872
- " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
873
- "\n",
874
- " return image_path.exists()\n",
875
- "\n",
876
- " def verify_that_all_images_exist(self):\n",
877
- " for image_file in self.captions['image_file']:\n",
878
- " if not image_exists:\n",
879
- " print(f'file does not exist: {p}')\n",
880
- "\n",
881
- " def _get_raw_image(self, i):\n",
882
- " name = self.captions.iloc[i]['image_file']\n",
883
- " image_path = (Path(self.images_root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
884
- " image = default_loader(image_path)\n",
885
- " return image\n",
886
- "\n",
887
- " def _get_raw_text(self, i):\n",
888
- " return self.captions.iloc[i]['caption']\n",
889
- "\n",
890
- " def __getitem__(self, i):\n",
891
- " image = self._get_raw_image(i)\n",
892
- " caption = self._get_raw_text(i)\n",
893
- " if self.image_transform is not None:\n",
894
- " if self.image_transform_type == 'torchvision':\n",
895
- " image = self.image_transform(image)\n",
896
- " elif self.image_transform_type == 'albumentations':\n",
897
- " image = self.image_transform(image=np.array(image))['image']\n",
898
- " else:\n",
899
- " raise NotImplementedError(f\"{self.image_transform_type=}\")\n",
900
- " return {'image': image, 'text': caption} if self.include_captions else image\n",
901
- "\n",
902
- " def __len__(self):\n",
903
- " return self.size\n",
904
- "\n",
905
- "\n",
906
- "if __name__ == \"__main__\":\n",
907
- " import albumentations as A\n",
908
- " from albumentations.pytorch import ToTensorV2\n",
909
- " from transformers import AutoTokenizer\n",
910
- " \n",
911
- "\n",
912
- " images_root = \"/home/khali/TPU-Test/YFCC100M_OpenAI_subset/data/data/images\"\n",
913
- " captions_path = './YFCC_subset_clean.tsv'\n",
914
- " image_size = 256\n",
915
- " \n",
916
- " # Create transforms\n",
917
- " def image_transform(image):\n",
918
- " s = min(image.size)\n",
919
- " r = image_size / s\n",
920
- " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
921
- " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
922
- " image = TF.center_crop(image, output_size = 2 * [image_size])\n",
923
- " image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
924
- " image = image.permute(0, 2, 3, 1).numpy()\n",
925
- " return image\n",
926
- " \n",
927
- " # Create dataset\n",
928
- " dataset = CaptionDataset(\n",
929
- " images_root=images_root,\n",
930
- " captions_path=captions_path,\n",
931
- " image_transform=image_transform,\n",
932
- " image_transform_type='torchvision',\n",
933
- " include_captions=False\n",
934
- " )"
935
- ]
936
- },
937
- {
938
- "cell_type": "code",
939
- "execution_count": 155,
940
- "id": "cc922704",
941
- "metadata": {},
942
- "outputs": [
943
- {
944
- "data": {
945
- "text/plain": [
946
- "2483316"
947
- ]
948
- },
949
- "execution_count": 155,
950
- "metadata": {},
951
- "output_type": "execute_result"
952
- }
953
- ],
954
- "source": [
955
- "len(dataset)"
956
- ]
957
- },
958
- {
959
- "cell_type": "code",
960
- "execution_count": 156,
961
- "id": "6e47ba46",
962
- "metadata": {},
963
- "outputs": [],
964
- "source": [
965
- "dataloader = DataLoader(dataset, batch_size=32, num_workers=4)"
966
- ]
967
- },
968
- {
969
- "cell_type": "code",
970
- "execution_count": 1,
971
- "id": "c8a130eb",
972
- "metadata": {},
973
- "outputs": [],
974
- "source": [
975
- "# looking at a batch\n",
976
- "next(iter(dataloader))"
977
- ]
978
- },
979
- {
980
- "cell_type": "code",
981
- "execution_count": null,
982
- "id": "c192fd44",
983
- "metadata": {},
984
- "outputs": [],
985
- "source": [
986
- "# import matplotlib.pyplot as plt\n",
987
- "# for tensor_image, _ in dataloader:\n",
988
- "# print(tensor_image)\n",
989
- "# plt.imshow(tensor_image.permute(1, 2, 0))\n",
990
- "# break"
991
- ]
992
- },
993
- {
994
- "cell_type": "markdown",
995
- "id": "62ad01c3",
996
- "metadata": {},
997
- "source": [
998
- "## Encoding"
999
- ]
1000
- },
1001
- {
1002
- "cell_type": "code",
1003
- "execution_count": 158,
1004
- "id": "88f36d0b",
1005
- "metadata": {},
1006
- "outputs": [],
1007
- "source": [
1008
- "def encode(model, batch):\n",
1009
- "# print(\"jitting encode function\")\n",
1010
- " _, indices = model.encode(batch)\n",
1011
- " return indices"
1012
- ]
1013
- },
1014
- {
1015
- "cell_type": "code",
1016
- "execution_count": 160,
1017
- "id": "1f35f0cb",
1018
- "metadata": {},
1019
- "outputs": [],
1020
- "source": [
1021
- "def superbatch_generator(dataloader, num_tpus):\n",
1022
- " iter_loader = iter(dataloader)\n",
1023
- " for batch in iter_loader:\n",
1024
- " superbatch = [batch.squeeze(1)]\n",
1025
- " try:\n",
1026
- " for b in range(num_tpus-1):\n",
1027
- " batch = next(iter_loader)\n",
1028
- " if batch is None:\n",
1029
- " break\n",
1030
- " # Skip incomplete last batch\n",
1031
- " if batch.shape[0] == dataloader.batch_size:\n",
1032
- " superbatch.append(batch.squeeze(1))\n",
1033
- " except StopIteration:\n",
1034
- " pass\n",
1035
- " superbatch = torch.stack(superbatch, axis=0)\n",
1036
- " yield superbatch"
1037
- ]
1038
- },
1039
- {
1040
- "cell_type": "code",
1041
- "execution_count": 170,
1042
- "id": "2210705b",
1043
- "metadata": {},
1044
- "outputs": [],
1045
- "source": [
1046
- "import os\n",
1047
- "\n",
1048
- "def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
1049
- " if os.path.isfile(output_tsv):\n",
1050
- " print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
1051
- " return\n",
1052
- " \n",
1053
- " num_tpus = 8 \n",
1054
- " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
1055
- " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
1056
- " \n",
1057
- " p_encoder = pmap(lambda batch: encode(model, batch))\n",
1058
- "\n",
1059
- " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
1060
- " # We keep the file open to prevent excessive file seeks.\n",
1061
- " with open(output_tsv, \"w\") as file:\n",
1062
- " iterations = len(dataset) // (batch_size * num_tpus)\n",
1063
- " for n in tqdm(range(iterations)):\n",
1064
- " superbatch = next(superbatches)\n",
1065
- " encoded = p_encoder(superbatch.numpy())\n",
1066
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
1067
- "\n",
1068
- " # Extract fields from the dataset internal `captions` property, and save to disk\n",
1069
- " start_index = n * batch_size * num_tpus\n",
1070
- " end_index = (n+1) * batch_size * num_tpus\n",
1071
- " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
1072
- " captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
1073
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
1074
- " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
1075
- " batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)"
1076
- ]
1077
- },
1078
- {
1079
- "cell_type": "code",
1080
- "execution_count": 171,
1081
- "id": "7704863d",
1082
- "metadata": {},
1083
- "outputs": [
1084
- {
1085
- "name": "stderr",
1086
- "output_type": "stream",
1087
- "text": [
1088
- "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4850/4850 [2:27:51<00:00, 1.83s/it]\n"
1089
- ]
1090
- }
1091
- ],
1092
- "source": [
1093
- "encode_captioned_dataset(dataset, yfcc100m_output, batch_size=64, num_workers=16)"
1094
- ]
1095
- },
1096
- {
1097
- "cell_type": "markdown",
1098
- "id": "8953dd84",
1099
- "metadata": {},
1100
- "source": [
1101
- "----"
1102
- ]
1103
- }
1104
- ],
1105
- "metadata": {
1106
- "interpreter": {
1107
- "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
1108
- },
1109
- "kernelspec": {
1110
- "display_name": "Python 3 (ipykernel)",
1111
- "language": "python",
1112
- "name": "python3"
1113
- },
1114
- "language_info": {
1115
- "codemirror_mode": {
1116
- "name": "ipython",
1117
- "version": 3
1118
- },
1119
- "file_extension": ".py",
1120
- "mimetype": "text/x-python",
1121
- "name": "python",
1122
- "nbconvert_exporter": "python",
1123
- "pygments_lexer": "ipython3",
1124
- "version": "3.8.10"
1125
- }
1126
- },
1127
- "nbformat": 4,
1128
- "nbformat_minor": 5
1129
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/encoding/vqgan-jax-encoding.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
dev/environment.yaml DELETED
@@ -1,10 +0,0 @@
1
- name: dalle
2
- channels:
3
- - defaults
4
- dependencies:
5
- - python=3.9.5
6
- - pip=21.1.3
7
- - ipython=7.22.0
8
- - cudatoolkit
9
- - pip:
10
- - -r requirements.txt
 
 
 
 
 
 
 
 
 
 
 
dev/requirements.txt DELETED
@@ -1,14 +0,0 @@
1
- requests
2
- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
3
- jax[tpu]>=0.2.16
4
- transformers
5
- datasets
6
- flax
7
- jupyter
8
- wandb
9
- nltk
10
- optax
11
- git+https://github.com/patil-suraj/vqgan-jax.git@610d842dd33c739325a944102ed33acc07692dd5
12
-
13
- # Inference
14
- ftfy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/seq2seq/do_big_run.sh DELETED
@@ -1,21 +0,0 @@
1
- python run_seq2seq_flax.py \
2
- --dataset_repo_or_path dalle-mini/encoded \
3
- --train_file **/train/*/*.jsonl \
4
- --validation_file **/valid/*/*.jsonl \
5
- --len_train 129847128 \
6
- --len_eval 157312 \
7
- --eval_steps 1000 \
8
- --streaming \
9
- --normalize_text \
10
- --output_dir output \
11
- --per_device_train_batch_size 56 \
12
- --per_device_eval_batch_size 56 \
13
- --preprocessing_num_workers 80 \
14
- --warmup_steps 5000 \
15
- --gradient_accumulation_steps 8 \
16
- --do_train \
17
- --do_eval \
18
- --adafactor \
19
- --num_train_epochs 6 \
20
- --log_model \
21
- --learning_rate 0.005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/seq2seq/do_small_run.sh DELETED
@@ -1,19 +0,0 @@
1
- python run_seq2seq_flax.py \
2
- --dataset_repo_or_path dalle-mini/encoded \
3
- --train_file **/train/CC3M/*.jsonl \
4
- --validation_file **/valid/*/*.jsonl \
5
- --len_train 129847128 \
6
- --len_eval 157312 \
7
- --streaming \
8
- --output_dir output \
9
- --per_device_train_batch_size 16 \
10
- --per_device_eval_batch_size 16 \
11
- --preprocessing_num_workers 80 \
12
- --warmup_steps 125 \
13
- --gradient_accumulation_steps 8 \
14
- --do_train \
15
- --do_eval \
16
- --adafactor \
17
- --num_train_epochs 1 \
18
- --max_train_samples 10000 \
19
- --learning_rate 0.005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev/vqgan/JAX_VQGAN_f16_16384_Reconstruction.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [tool.isort]
2
+ profile = "black"
setup.cfg CHANGED
@@ -16,3 +16,11 @@ install_requires =
16
  ftfy
17
  jax
18
  flax
 
 
 
 
 
 
 
 
 
16
  ftfy
17
  jax
18
  flax
19
+
20
+ [options.extras_require]
21
+ dev =
22
+ tqdm
23
+ wandb
24
+ optax
25
+ black[jupyter]
26
+ isort
setup.py CHANGED
@@ -1,4 +1,4 @@
1
  from setuptools import setup
2
 
3
  if __name__ == "__main__":
4
- setup()
 
1
  from setuptools import setup
2
 
3
  if __name__ == "__main__":
4
+ setup()
dev/encoding/vqgan-jax-encoding-webdataset.ipynb → tools/dataset/encode_dataset.ipynb RENAMED
@@ -5,7 +5,7 @@
5
  "id": "d0b72877",
6
  "metadata": {},
7
  "source": [
8
- "# VQGAN JAX Encoding for `webdataset`"
9
  ]
10
  },
11
  {
@@ -15,7 +15,11 @@
15
  "source": [
16
  "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
17
  "\n",
18
- "This example uses a small subset of YFCC100M we created for testing, but it should be easy to adapt to any other image/caption dataset in the `webdataset` format."
 
 
 
 
19
  ]
20
  },
21
  {
@@ -25,19 +29,15 @@
25
  "metadata": {},
26
  "outputs": [],
27
  "source": [
28
- "import numpy as np\n",
29
- "from tqdm import tqdm\n",
30
  "\n",
31
- "import torch\n",
32
  "import torchvision.transforms as T\n",
33
- "import torchvision.transforms.functional as TF\n",
34
- "from torchvision.transforms import InterpolationMode\n",
35
- "import math\n",
36
  "\n",
37
  "import webdataset as wds\n",
38
  "\n",
39
  "import jax\n",
40
- "from jax import pmap"
 
41
  ]
42
  },
43
  {
@@ -45,184 +45,110 @@
45
  "id": "c7c4c1e6",
46
  "metadata": {},
47
  "source": [
48
- "## Dataset and Parameters"
49
- ]
50
- },
51
- {
52
- "cell_type": "markdown",
53
- "id": "9822850f",
54
- "metadata": {},
55
- "source": [
56
- "The following is the list of shards we'll process. We hardcode the length of data so that we can see nice progress bars using `tqdm`."
57
  ]
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": null,
62
  "id": "1265dbfe",
63
  "metadata": {},
64
  "outputs": [],
65
  "source": [
66
- "shards = 'https://huggingface.co/datasets/dalle-mini/YFCC100M_OpenAI_subset/resolve/main/data/shard-{0000..0008}.tar'\n",
67
- "length = 8320"
68
- ]
69
- },
70
- {
71
- "cell_type": "markdown",
72
- "id": "7e38fa14",
73
- "metadata": {},
74
- "source": [
75
- "If we are extra cautious or our server is unreliable, we can enable retries by providing a custom `curl` retrieval command:"
76
- ]
77
- },
78
- {
79
- "cell_type": "code",
80
- "execution_count": null,
81
- "id": "4c8c5960",
82
- "metadata": {},
83
- "outputs": [],
84
- "source": [
85
- "# Enable curl retries to try to work around temporary network / server errors.\n",
86
- "# This shouldn't be necessary when using reliable servers.\n",
87
- "# shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {shards} || true'"
88
- ]
89
- },
90
- {
91
- "cell_type": "code",
92
- "execution_count": null,
93
- "id": "13c6631b",
94
- "metadata": {},
95
- "outputs": [],
96
- "source": [
97
- "from pathlib import Path\n",
98
  "\n",
99
- "# Output directory for encoded files\n",
100
- "encoded_output = Path.home()/'data'/'wds'/'encoded'\n",
 
 
101
  "\n",
102
- "batch_size = 128 # Per device\n",
103
- "num_workers = 8 # For parallel processing"
 
 
 
104
  ]
105
  },
106
  {
107
  "cell_type": "code",
108
- "execution_count": null,
109
- "id": "3435fb85",
110
- "metadata": {},
111
- "outputs": [],
112
- "source": [
113
- "bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
114
- "batches = math.ceil(length / bs)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  ]
116
  },
117
  {
118
  "cell_type": "markdown",
119
- "id": "88598e4b",
120
  "metadata": {},
121
  "source": [
122
- "Image processing"
123
- ]
124
- },
125
- {
126
- "cell_type": "code",
127
- "execution_count": null,
128
- "id": "669b35df",
129
- "metadata": {},
130
- "outputs": [],
131
- "source": [
132
- "def center_crop(image, max_size=256):\n",
133
- " # Note: we allow upscaling too. We should exclude small images. \n",
134
- " image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
135
- " image = TF.center_crop(image, output_size=2 * [max_size])\n",
136
- " return image\n",
137
- "\n",
138
- "preprocess_image = T.Compose([\n",
139
- " center_crop,\n",
140
- " T.ToTensor(),\n",
141
- " lambda t: t.permute(1, 2, 0) # Reorder, we need dimensions last\n",
142
- "])"
143
  ]
144
  },
145
  {
146
  "cell_type": "markdown",
147
- "id": "a185e90c",
148
  "metadata": {},
149
  "source": [
150
- "Caption preparation.\n",
151
- "\n",
152
- "Note that we receive the contents of the `json` structure, which will be replaced by the string we return.\n",
153
- "If we want to keep other fields inside `json`, we can add `caption` as a new field."
154
  ]
155
  },
156
  {
157
  "cell_type": "code",
158
  "execution_count": null,
159
- "id": "423ee10e",
160
  "metadata": {},
161
  "outputs": [],
162
  "source": [
163
- "def create_caption(item):\n",
164
- " title = item['title_clean'].strip()\n",
165
- " description = item['description_clean'].strip()\n",
166
- " if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
167
- " return f'{title} {description}'"
 
168
  ]
169
  },
170
  {
171
  "cell_type": "markdown",
172
- "id": "8d3a95db",
173
  "metadata": {},
174
  "source": [
175
- "When an error occurs (a download is disconnected, an image cannot be decoded, etc) the process stops with an exception. We can use one of the exception handlers provided by the `webdataset` library, such as `wds.warn_and_continue` or `wds.ignore_and_continue` to ignore the offending entry and keep iterating.\n",
176
- "\n",
177
- "**IMPORTANT WARNING:** Do not use error handlers to ignore exceptions until you have tested that your processing pipeline works fine. Otherwise, the process will continue trying to find a valid entry, and it will consume your whole dataset without doing any work.\n",
178
- "\n",
179
- "We can also create our custom exception handler as demonstrated here:"
180
  ]
181
  },
182
  {
183
- "cell_type": "code",
184
- "execution_count": null,
185
- "id": "369d9719",
186
- "metadata": {},
187
- "outputs": [],
188
- "source": [
189
- "# UNUSED - Log exceptions to a file\n",
190
- "def ignore_and_log(exn):\n",
191
- " with open('errors.txt', 'a') as f:\n",
192
- " f.write(f'{repr(exn)}\\n')\n",
193
- " return True"
194
- ]
195
- },
196
- {
197
- "cell_type": "code",
198
- "execution_count": null,
199
- "id": "27de1414",
200
- "metadata": {},
201
- "outputs": [],
202
- "source": [
203
- "# Or simply use `wds.ignore_and_continue`\n",
204
- "exception_handler = wds.warn_and_continue"
205
- ]
206
- },
207
- {
208
- "cell_type": "code",
209
- "execution_count": null,
210
- "id": "5149b6d5",
211
  "metadata": {},
212
- "outputs": [],
213
  "source": [
214
- "dataset = wds.WebDataset(shards,\n",
215
- " length=batches, # Hint so `len` is implemented\n",
216
- " shardshuffle=False, # Keep same order for encoded files for easier bookkeeping. Set to `True` for training.\n",
217
- " handler=exception_handler, # Ignore read errors instead of failing.\n",
218
- ")\n",
219
- "\n",
220
- "dataset = (dataset \n",
221
- " .decode('pil') # decode image with PIL\n",
222
- "# .map_dict(jpg=preprocess_image, json=create_caption, handler=exception_handler) # Process fields with functions defined above\n",
223
- " .map_dict(jpg=preprocess_image, json=create_caption) # Process fields with functions defined above\n",
224
- " .to_tuple('__key__', 'jpg', 'json') # filter to keep only key (for reference), image, caption.\n",
225
- " .batched(bs)) # better to batch in the dataset (but we could also do it in the dataloader) - this arg does not affect speed and we could remove it"
226
  ]
227
  },
228
  {
@@ -235,7 +161,7 @@
235
  "outputs": [],
236
  "source": [
237
  "%%time\n",
238
- "keys, images, captions = next(iter(dataset))"
239
  ]
240
  },
241
  {
@@ -251,54 +177,50 @@
251
  {
252
  "cell_type": "code",
253
  "execution_count": null,
254
- "id": "c24693c0",
255
  "metadata": {},
256
  "outputs": [],
257
  "source": [
258
- "T.ToPILImage()(images[0].permute(2, 0, 1))"
259
- ]
260
- },
261
- {
262
- "cell_type": "markdown",
263
- "id": "44d50a51",
264
- "metadata": {},
265
- "source": [
266
- "### Torch DataLoader"
267
  ]
268
  },
269
  {
270
  "cell_type": "code",
271
  "execution_count": null,
272
- "id": "e2df5e13",
273
  "metadata": {},
274
  "outputs": [],
275
  "source": [
276
- "dl = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=num_workers)"
277
  ]
278
  },
279
  {
280
  "cell_type": "markdown",
281
- "id": "a354472b",
282
  "metadata": {},
283
  "source": [
284
- "## VQGAN-JAX model"
285
  ]
286
  },
287
  {
288
  "cell_type": "code",
289
  "execution_count": null,
290
- "id": "2fcf01d7",
291
  "metadata": {},
292
  "outputs": [],
293
  "source": [
294
- "from vqgan_jax.modeling_flax_vqgan import VQModel"
 
 
295
  ]
296
  },
297
  {
298
  "cell_type": "markdown",
299
- "id": "9daa636d",
300
  "metadata": {},
301
  "source": [
 
 
302
  "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
303
  ]
304
  },
@@ -311,7 +233,11 @@
311
  },
312
  "outputs": [],
313
  "source": [
314
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
 
 
 
 
315
  ]
316
  },
317
  {
@@ -327,18 +253,7 @@
327
  "id": "20357f74",
328
  "metadata": {},
329
  "source": [
330
- "Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, that will be jitted on first use."
331
- ]
332
- },
333
- {
334
- "cell_type": "code",
335
- "execution_count": null,
336
- "id": "6686b004",
337
- "metadata": {},
338
- "outputs": [],
339
- "source": [
340
- "from flax.training.common_utils import shard\n",
341
- "from functools import partial"
342
  ]
343
  },
344
  {
@@ -348,21 +263,17 @@
348
  "metadata": {},
349
  "outputs": [],
350
  "source": [
 
 
 
 
351
  "@partial(jax.pmap, axis_name=\"batch\")\n",
352
- "def encode(batch):\n",
353
  " # Not sure if we should `replicate` params, does not seem to have any effect\n",
354
- " _, indices = model.encode(batch)\n",
355
  " return indices"
356
  ]
357
  },
358
- {
359
- "cell_type": "markdown",
360
- "id": "14375a41",
361
- "metadata": {},
362
- "source": [
363
- "### Encoding loop"
364
- ]
365
- },
366
  {
367
  "cell_type": "code",
368
  "execution_count": null,
@@ -370,49 +281,48 @@
370
  "metadata": {},
371
  "outputs": [],
372
  "source": [
373
- "import os\n",
374
  "import pandas as pd\n",
375
  "\n",
376
- "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
377
- " output_dir.mkdir(parents=True, exist_ok=True)\n",
378
  "\n",
379
- " # Saving strategy:\n",
380
- " # - Create a new file every so often to prevent excessive file seeking.\n",
381
- " # - Save each batch after processing.\n",
382
- " # - Keep the file open until we are done with it.\n",
383
- " file = None \n",
384
- " for n, (keys, images, captions) in enumerate(tqdm(dataloader)):\n",
385
- " if (n % save_every == 0):\n",
386
- " if file is not None:\n",
387
- " file.close()\n",
388
- " split_num = n // save_every\n",
389
- " file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
390
- "\n",
391
- " images = shard(images.numpy().squeeze())\n",
392
- " encoded = encode(images)\n",
 
 
 
 
393
  " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
 
 
394
  "\n",
395
- " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
396
- " batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
397
- " batch_df.to_json(file, orient='records', lines=True)"
398
- ]
399
- },
400
- {
401
- "cell_type": "markdown",
402
- "id": "09ff75a3",
403
- "metadata": {},
404
- "source": [
405
- "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
406
- ]
407
- },
408
- {
409
- "cell_type": "code",
410
- "execution_count": null,
411
- "id": "96222bb4",
412
- "metadata": {},
413
- "outputs": [],
414
- "source": [
415
- "save_every = 318"
416
  ]
417
  },
418
  {
@@ -422,7 +332,7 @@
422
  "metadata": {},
423
  "outputs": [],
424
  "source": [
425
- "encode_captioned_dataset(dl, encoded_output, save_every=save_every)"
426
  ]
427
  },
428
  {
@@ -453,7 +363,7 @@
453
  "name": "python",
454
  "nbconvert_exporter": "python",
455
  "pygments_lexer": "ipython3",
456
- "version": "3.8.10"
457
  }
458
  },
459
  "nbformat": 4,
 
5
  "id": "d0b72877",
6
  "metadata": {},
7
  "source": [
8
+ "# Pre-encoding a dataset for DALLE·mini"
9
  ]
10
  },
11
  {
 
15
  "source": [
16
  "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
17
  "\n",
18
+ "Adapt it to your own dataset and image encoder.\n",
19
+ "\n",
20
+ "At the end you should have a dataset of pairs:\n",
21
+ "* a caption defined as a string\n",
22
+ "* an encoded image defined as a list of int."
23
  ]
24
  },
25
  {
 
29
  "metadata": {},
30
  "outputs": [],
31
  "source": [
32
+ "from tqdm.notebook import tqdm\n",
 
33
  "\n",
 
34
  "import torchvision.transforms as T\n",
 
 
 
35
  "\n",
36
  "import webdataset as wds\n",
37
  "\n",
38
  "import jax\n",
39
+ "import braceexpand\n",
40
+ "from pathlib import Path"
41
  ]
42
  },
43
  {
 
45
  "id": "c7c4c1e6",
46
  "metadata": {},
47
  "source": [
48
+ "## Configuration Parameters"
 
 
 
 
 
 
 
 
49
  ]
50
  },
51
  {
52
  "cell_type": "code",
53
+ "execution_count": 3,
54
  "id": "1265dbfe",
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
58
+ "shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n",
59
+ "encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  "\n",
61
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
62
+ " \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
63
+ " \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
64
+ ")\n",
65
  "\n",
66
+ "# good defaults for a TPU v3-8\n",
67
+ "batch_size = 128 # Per device\n",
68
+ "num_workers = 8 # For parallel processing\n",
69
+ "total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
70
+ "save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)"
71
  ]
72
  },
73
  {
74
  "cell_type": "code",
75
+ "execution_count": 5,
76
+ "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
77
+ "metadata": {},
78
+ "outputs": [
79
+ {
80
+ "data": {
81
+ "text/plain": [
82
+ "['XXX/shard-0000.tar',\n",
83
+ " 'XXX/shard-0001.tar',\n",
84
+ " 'XXX/shard-0002.tar',\n",
85
+ " 'XXX/shard-0003.tar',\n",
86
+ " 'XXX/shard-0004.tar',\n",
87
+ " 'XXX/shard-0005.tar',\n",
88
+ " 'XXX/shard-0006.tar',\n",
89
+ " 'XXX/shard-0007.tar',\n",
90
+ " 'XXX/shard-0008.tar']"
91
+ ]
92
+ },
93
+ "execution_count": 5,
94
+ "metadata": {},
95
+ "output_type": "execute_result"
96
+ }
97
+ ],
98
+ "source": [
99
+ "shards = list(\n",
100
+ " braceexpand.braceexpand(shards)\n",
101
+ ") # better display for tqdm with known length"
102
  ]
103
  },
104
  {
105
  "cell_type": "markdown",
106
+ "id": "75dba8e2",
107
  "metadata": {},
108
  "source": [
109
+ "## Load data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  ]
111
  },
112
  {
113
  "cell_type": "markdown",
114
+ "id": "a1e8fb95",
115
  "metadata": {},
116
  "source": [
117
+ "We load data using `webdataset`."
 
 
 
118
  ]
119
  },
120
  {
121
  "cell_type": "code",
122
  "execution_count": null,
123
+ "id": "9ef5de9e",
124
  "metadata": {},
125
  "outputs": [],
126
  "source": [
127
+ "ds = (\n",
128
+ " wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
129
+ " .decode(\"rgb\", handler=wds.warn_and_continue)\n",
130
+ " .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n",
131
+ " .batched(total_bs) # load in batch per worker (faster)\n",
132
+ ")"
133
  ]
134
  },
135
  {
136
  "cell_type": "markdown",
137
+ "id": "90981824",
138
  "metadata": {},
139
  "source": [
140
+ "Note:\n",
141
+ "* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
142
+ "* you may need to resize images in your pipeline (with `map_dict` for example), we assume they are already set to 256x256.\n",
143
+ "* you can also filter out some items using `select`."
 
144
  ]
145
  },
146
  {
147
+ "cell_type": "markdown",
148
+ "id": "129c377d",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  "metadata": {},
 
150
  "source": [
151
+ "We can now inspect our data."
 
 
 
 
 
 
 
 
 
 
 
152
  ]
153
  },
154
  {
 
161
  "outputs": [],
162
  "source": [
163
  "%%time\n",
164
+ "images, captions = next(iter(ds))"
165
  ]
166
  },
167
  {
 
177
  {
178
  "cell_type": "code",
179
  "execution_count": null,
180
+ "id": "5acfc4d8",
181
  "metadata": {},
182
  "outputs": [],
183
  "source": [
184
+ "captions[:10]"
 
 
 
 
 
 
 
 
185
  ]
186
  },
187
  {
188
  "cell_type": "code",
189
  "execution_count": null,
190
+ "id": "c24693c0",
191
  "metadata": {},
192
  "outputs": [],
193
  "source": [
194
+ "T.ToPILImage()(images[0].permute(2, 0, 1))"
195
  ]
196
  },
197
  {
198
  "cell_type": "markdown",
199
+ "id": "3059ffb1",
200
  "metadata": {},
201
  "source": [
202
+ "Finally we create our dataloader."
203
  ]
204
  },
205
  {
206
  "cell_type": "code",
207
  "execution_count": null,
208
+ "id": "c227c551",
209
  "metadata": {},
210
  "outputs": [],
211
  "source": [
212
+ "dl = (\n",
213
+ " wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
214
+ ") # avoid partial batch at the end of each worker"
215
  ]
216
  },
217
  {
218
  "cell_type": "markdown",
219
+ "id": "a354472b",
220
  "metadata": {},
221
  "source": [
222
+ "## Image encoder\n",
223
+ "\n",
224
  "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
225
  ]
226
  },
 
233
  },
234
  "outputs": [],
235
  "source": [
236
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
237
+ "from flax.jax_utils import replicate\n",
238
+ "\n",
239
+ "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n",
240
+ "vqgan_params = replicate(vqgan.params)"
241
  ]
242
  },
243
  {
 
253
  "id": "20357f74",
254
  "metadata": {},
255
  "source": [
256
+ "Encoding is really simple using `shard` to automatically distribute batches across devices and `pmap`."
 
 
 
 
 
 
 
 
 
 
 
257
  ]
258
  },
259
  {
 
263
  "metadata": {},
264
  "outputs": [],
265
  "source": [
266
+ "from flax.training.common_utils import shard\n",
267
+ "from functools import partial\n",
268
+ "\n",
269
+ "\n",
270
  "@partial(jax.pmap, axis_name=\"batch\")\n",
271
+ "def p_encode(batch, params):\n",
272
  " # Not sure if we should `replicate` params, does not seem to have any effect\n",
273
+ " _, indices = vqgan.encode(batch, params=params)\n",
274
  " return indices"
275
  ]
276
  },
 
 
 
 
 
 
 
 
277
  {
278
  "cell_type": "code",
279
  "execution_count": null,
 
281
  "metadata": {},
282
  "outputs": [],
283
  "source": [
 
284
  "import pandas as pd\n",
285
  "\n",
 
 
286
  "\n",
287
+ "def encode_dataset(dataloader, output_dir, save_frequency):\n",
288
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
289
+ " all_captions = []\n",
290
+ " all_encoding = []\n",
291
+ " n_file = 1\n",
292
+ " for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
293
+ " images = images.numpy()\n",
294
+ " n = len(images) // 8 * 8\n",
295
+ " if n != len(images):\n",
296
+ " # get the max number of images we can (multiple of 8)\n",
297
+ " print(f\"Different sizes {n} vs {len(images)}\")\n",
298
+ " images = images[:n]\n",
299
+ " captions = captions[:n]\n",
300
+ " if not len(captions):\n",
301
+ " print(f\"No images/captions in batch...\")\n",
302
+ " continue\n",
303
+ " images = shard(images)\n",
304
+ " encoded = p_encode(images, vqgan_params)\n",
305
  " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
306
+ " all_captions.extend(captions)\n",
307
+ " all_encoding.extend(encoded.tolist())\n",
308
  "\n",
309
+ " # save files\n",
310
+ " if (idx + 1) % save_frequency == 0:\n",
311
+ " print(f\"Saving file {n_file}\")\n",
312
+ " batch_df = pd.DataFrame.from_dict(\n",
313
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
314
+ " )\n",
315
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
316
+ " all_captions = []\n",
317
+ " all_encoding = []\n",
318
+ " n_file += 1\n",
319
+ "\n",
320
+ " if len(all_captions):\n",
321
+ " print(f\"Saving final file {n_file}\")\n",
322
+ " batch_df = pd.DataFrame.from_dict(\n",
323
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
324
+ " )\n",
325
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
 
 
 
 
326
  ]
327
  },
328
  {
 
332
  "metadata": {},
333
  "outputs": [],
334
  "source": [
335
+ "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
336
  ]
337
  },
338
  {
 
363
  "name": "python",
364
  "nbconvert_exporter": "python",
365
  "pygments_lexer": "ipython3",
366
+ "version": "3.9.7"
367
  }
368
  },
369
  "nbformat": 4,
tools/inference/inference_pipeline.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
tools/inference/log_inference_samples.ipynb CHANGED
@@ -31,11 +31,14 @@
31
  "metadata": {},
32
  "outputs": [],
33
  "source": [
34
- "run_ids = ['63otg87g']\n",
35
- "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
36
- "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n",
37
- "latest_only = True # log only latest or all versions\n",
38
- "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
 
 
 
39
  "add_clip_32 = False"
40
  ]
41
  },
@@ -63,8 +66,8 @@
63
  "num_images = 128\n",
64
  "top_k = 8\n",
65
  "text_normalizer = TextNormalizer()\n",
66
- "padding_item = 'NONE'\n",
67
- "seed = random.randint(0, 2**32-1)\n",
68
  "key = jax.random.PRNGKey(seed)\n",
69
  "api = wandb.Api()"
70
  ]
@@ -100,12 +103,15 @@
100
  "def p_decode(indices, params):\n",
101
  " return vqgan.decode_code(indices, params=params)\n",
102
  "\n",
 
103
  "@partial(jax.pmap, axis_name=\"batch\")\n",
104
  "def p_clip16(inputs, params):\n",
105
  " logits = clip16(params=params, **inputs).logits_per_image\n",
106
  " return logits\n",
107
  "\n",
 
108
  "if add_clip_32:\n",
 
109
  " @partial(jax.pmap, axis_name=\"batch\")\n",
110
  " def p_clip32(inputs, params):\n",
111
  " logits = clip32(params=params, **inputs).logits_per_image\n",
@@ -119,13 +125,13 @@
119
  "metadata": {},
120
  "outputs": [],
121
  "source": [
122
- "with open('samples.txt', encoding='utf8') as f:\n",
123
  " samples = [l.strip() for l in f.readlines()]\n",
124
  " # make list multiple of batch_size by adding elements\n",
125
  " samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
126
  " samples.extend(samples_to_add)\n",
127
  " # reshape\n",
128
- " samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
129
  ]
130
  },
131
  {
@@ -138,9 +144,17 @@
138
  "def get_artifact_versions(run_id, latest_only=False):\n",
139
  " try:\n",
140
  " if latest_only:\n",
141
- " return [api.artifact(type='bart_model', name=f'{ENTITY}/{PROJECT}/model-{run_id}:latest')]\n",
 
 
 
 
142
  " else:\n",
143
- " return api.artifact_versions(type_name='bart_model', name=f'{ENTITY}/{PROJECT}/model-{run_id}', per_page=10000)\n",
 
 
 
 
144
  " except:\n",
145
  " return []"
146
  ]
@@ -153,7 +167,7 @@
153
  "outputs": [],
154
  "source": [
155
  "def get_training_config(run_id):\n",
156
- " training_run = api.run(f'{ENTITY}/{PROJECT}/{run_id}')\n",
157
  " config = training_run.config\n",
158
  " return config"
159
  ]
@@ -168,8 +182,8 @@
168
  "# retrieve inference run details\n",
169
  "def get_last_inference_version(run_id):\n",
170
  " try:\n",
171
- " inference_run = api.run(f'dalle-mini/dalle-mini/{run_id}-clip16{suffix}')\n",
172
- " return inference_run.summary.get('version', None)\n",
173
  " except:\n",
174
  " return None"
175
  ]
@@ -183,7 +197,6 @@
183
  "source": [
184
  "# compile functions - needed only once per run\n",
185
  "def pmap_model_function(model):\n",
186
- " \n",
187
  " @partial(jax.pmap, axis_name=\"batch\")\n",
188
  " def _generate(tokenized_prompt, key, params):\n",
189
  " return model.generate(\n",
@@ -195,7 +208,7 @@
195
  " top_k=gen_top_k,\n",
196
  " top_p=gen_top_p\n",
197
  " )\n",
198
- " \n",
199
  " return _generate"
200
  ]
201
  },
@@ -222,13 +235,21 @@
222
  "training_config = get_training_config(run_id)\n",
223
  "run = None\n",
224
  "p_generate = None\n",
225
- "model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n",
 
 
 
 
 
 
 
 
226
  "for artifact in artifact_versions:\n",
227
- " print(f'Processing artifact: {artifact.name}')\n",
228
  " version = int(artifact.version[1:])\n",
229
  " results16, results32 = [], []\n",
230
- " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)]\n",
231
- " \n",
232
  " if latest_only:\n",
233
  " assert last_inference_version is None or version > last_inference_version\n",
234
  " else:\n",
@@ -236,14 +257,23 @@
236
  " # we should start from v0\n",
237
  " assert version == 0\n",
238
  " elif version <= last_inference_version:\n",
239
- " print(f'v{version} has already been logged (versions logged up to v{last_inference_version}')\n",
 
 
240
  " else:\n",
241
  " # check we are logging the correct version\n",
242
  " assert version == last_inference_version + 1\n",
243
  "\n",
244
  " # start/resume corresponding run\n",
245
  " if run is None:\n",
246
- " run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip16{suffix}', resume='allow')\n",
 
 
 
 
 
 
 
247
  "\n",
248
  " # work in temporary directory\n",
249
  " with tempfile.TemporaryDirectory() as tmp:\n",
@@ -264,64 +294,109 @@
264
  "\n",
265
  " # process one batch of captions\n",
266
  " for batch in tqdm(samples):\n",
267
- " processed_prompts = [text_normalizer(x) for x in batch] if model.config.normalize_text else list(batch)\n",
 
 
 
 
268
  "\n",
269
  " # repeat the prompts to distribute over each device and tokenize\n",
270
  " processed_prompts = processed_prompts * jax.device_count()\n",
271
- " tokenized_prompt = tokenizer(processed_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
 
 
 
 
 
 
272
  " tokenized_prompt = shard(tokenized_prompt)\n",
273
  "\n",
274
  " # generate images\n",
275
  " images = []\n",
276
- " pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=True)\n",
 
 
 
 
277
  " for i in pbar:\n",
278
  " key, subkey = jax.random.split(key)\n",
279
- " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
 
 
280
  " encoded_images = encoded_images.sequences[..., 1:]\n",
281
  " decoded_images = p_decode(encoded_images, vqgan_params)\n",
282
- " decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
 
 
283
  " for img in decoded_images:\n",
284
- " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
 
 
285
  "\n",
286
- " def add_clip_results(results, processor, p_clip, clip_params): \n",
287
- " clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
 
 
 
 
 
 
 
288
  " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
289
- " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
290
- " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
 
 
 
 
 
 
 
291
  " clip_inputs = shard(clip_inputs)\n",
292
  " logits = p_clip(clip_inputs, clip_params)\n",
293
  " logits = logits.reshape(-1, num_images)\n",
294
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
295
  " logits = jax.device_get(logits)\n",
296
  " # add to results table\n",
297
- " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
298
- " if sample == padding_item: continue\n",
 
 
 
299
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
300
- " top_images = [wandb.Image(cur_images[x], caption=f'Score: {scores[x]:.2f}') for x in idx]\n",
 
 
 
301
  " results.append([sample] + top_images)\n",
302
- " \n",
303
  " # get clip scores\n",
304
- " pbar.set_description('Calculating CLIP 16 scores')\n",
305
  " add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
306
- " \n",
307
  " # get clip 32 scores\n",
308
  " if add_clip_32:\n",
309
- " pbar.set_description('Calculating CLIP 32 scores')\n",
310
  " add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
311
  "\n",
312
  " pbar.close()\n",
313
  "\n",
314
- " \n",
315
- "\n",
316
  " # log results\n",
317
  " table = wandb.Table(columns=columns, data=results16)\n",
318
- " run.log({'Samples': table, 'version': version})\n",
319
  " wandb.finish()\n",
320
- " \n",
321
- " if add_clip_32: \n",
322
- " run = wandb.init(job_type='inference', entity='dalle-mini', project='dalle-mini', config=training_config, id=f'{run_id}-clip32{suffix}', resume='allow')\n",
 
 
 
 
 
 
 
323
  " table = wandb.Table(columns=columns, data=results32)\n",
324
- " run.log({'Samples': table, 'version': version})\n",
325
  " wandb.finish()\n",
326
  " run = None # ensure we don't log on this run"
327
  ]
 
31
  "metadata": {},
32
  "outputs": [],
33
  "source": [
34
+ "run_ids = [\"63otg87g\"]\n",
35
+ "ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n",
36
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
37
+ " \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
38
+ " \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
39
+ ")\n",
40
+ "latest_only = True # log only latest or all versions\n",
41
+ "suffix = \"\" # mainly for duplicate inference runs with a deleted version\n",
42
  "add_clip_32 = False"
43
  ]
44
  },
 
66
  "num_images = 128\n",
67
  "top_k = 8\n",
68
  "text_normalizer = TextNormalizer()\n",
69
+ "padding_item = \"NONE\"\n",
70
+ "seed = random.randint(0, 2 ** 32 - 1)\n",
71
  "key = jax.random.PRNGKey(seed)\n",
72
  "api = wandb.Api()"
73
  ]
 
103
  "def p_decode(indices, params):\n",
104
  " return vqgan.decode_code(indices, params=params)\n",
105
  "\n",
106
+ "\n",
107
  "@partial(jax.pmap, axis_name=\"batch\")\n",
108
  "def p_clip16(inputs, params):\n",
109
  " logits = clip16(params=params, **inputs).logits_per_image\n",
110
  " return logits\n",
111
  "\n",
112
+ "\n",
113
  "if add_clip_32:\n",
114
+ "\n",
115
  " @partial(jax.pmap, axis_name=\"batch\")\n",
116
  " def p_clip32(inputs, params):\n",
117
  " logits = clip32(params=params, **inputs).logits_per_image\n",
 
125
  "metadata": {},
126
  "outputs": [],
127
  "source": [
128
+ "with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
129
  " samples = [l.strip() for l in f.readlines()]\n",
130
  " # make list multiple of batch_size by adding elements\n",
131
  " samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
132
  " samples.extend(samples_to_add)\n",
133
  " # reshape\n",
134
+ " samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
135
  ]
136
  },
137
  {
 
144
  "def get_artifact_versions(run_id, latest_only=False):\n",
145
  " try:\n",
146
  " if latest_only:\n",
147
+ " return [\n",
148
+ " api.artifact(\n",
149
+ " type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
150
+ " )\n",
151
+ " ]\n",
152
  " else:\n",
153
+ " return api.artifact_versions(\n",
154
+ " type_name=\"bart_model\",\n",
155
+ " name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
156
+ " per_page=10000,\n",
157
+ " )\n",
158
  " except:\n",
159
  " return []"
160
  ]
 
167
  "outputs": [],
168
  "source": [
169
  "def get_training_config(run_id):\n",
170
+ " training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
171
  " config = training_run.config\n",
172
  " return config"
173
  ]
 
182
  "# retrieve inference run details\n",
183
  "def get_last_inference_version(run_id):\n",
184
  " try:\n",
185
+ " inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
186
+ " return inference_run.summary.get(\"version\", None)\n",
187
  " except:\n",
188
  " return None"
189
  ]
 
197
  "source": [
198
  "# compile functions - needed only once per run\n",
199
  "def pmap_model_function(model):\n",
 
200
  " @partial(jax.pmap, axis_name=\"batch\")\n",
201
  " def _generate(tokenized_prompt, key, params):\n",
202
  " return model.generate(\n",
 
208
  " top_k=gen_top_k,\n",
209
  " top_p=gen_top_p\n",
210
  " )\n",
211
+ "\n",
212
  " return _generate"
213
  ]
214
  },
 
235
  "training_config = get_training_config(run_id)\n",
236
  "run = None\n",
237
  "p_generate = None\n",
238
+ "model_files = [\n",
239
+ " \"config.json\",\n",
240
+ " \"flax_model.msgpack\",\n",
241
+ " \"merges.txt\",\n",
242
+ " \"special_tokens_map.json\",\n",
243
+ " \"tokenizer.json\",\n",
244
+ " \"tokenizer_config.json\",\n",
245
+ " \"vocab.json\",\n",
246
+ "]\n",
247
  "for artifact in artifact_versions:\n",
248
+ " print(f\"Processing artifact: {artifact.name}\")\n",
249
  " version = int(artifact.version[1:])\n",
250
  " results16, results32 = [], []\n",
251
+ " columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
252
+ "\n",
253
  " if latest_only:\n",
254
  " assert last_inference_version is None or version > last_inference_version\n",
255
  " else:\n",
 
257
  " # we should start from v0\n",
258
  " assert version == 0\n",
259
  " elif version <= last_inference_version:\n",
260
+ " print(\n",
261
+ " f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
262
+ " )\n",
263
  " else:\n",
264
  " # check we are logging the correct version\n",
265
  " assert version == last_inference_version + 1\n",
266
  "\n",
267
  " # start/resume corresponding run\n",
268
  " if run is None:\n",
269
+ " run = wandb.init(\n",
270
+ " job_type=\"inference\",\n",
271
+ " entity=\"dalle-mini\",\n",
272
+ " project=\"dalle-mini\",\n",
273
+ " config=training_config,\n",
274
+ " id=f\"{run_id}-clip16{suffix}\",\n",
275
+ " resume=\"allow\",\n",
276
+ " )\n",
277
  "\n",
278
  " # work in temporary directory\n",
279
  " with tempfile.TemporaryDirectory() as tmp:\n",
 
294
  "\n",
295
  " # process one batch of captions\n",
296
  " for batch in tqdm(samples):\n",
297
+ " processed_prompts = (\n",
298
+ " [text_normalizer(x) for x in batch]\n",
299
+ " if model.config.normalize_text\n",
300
+ " else list(batch)\n",
301
+ " )\n",
302
  "\n",
303
  " # repeat the prompts to distribute over each device and tokenize\n",
304
  " processed_prompts = processed_prompts * jax.device_count()\n",
305
+ " tokenized_prompt = tokenizer(\n",
306
+ " processed_prompts,\n",
307
+ " return_tensors=\"jax\",\n",
308
+ " padding=\"max_length\",\n",
309
+ " truncation=True,\n",
310
+ " max_length=128,\n",
311
+ " ).data\n",
312
  " tokenized_prompt = shard(tokenized_prompt)\n",
313
  "\n",
314
  " # generate images\n",
315
  " images = []\n",
316
+ " pbar = tqdm(\n",
317
+ " range(num_images // jax.device_count()),\n",
318
+ " desc=\"Generating Images\",\n",
319
+ " leave=True,\n",
320
+ " )\n",
321
  " for i in pbar:\n",
322
  " key, subkey = jax.random.split(key)\n",
323
+ " encoded_images = p_generate(\n",
324
+ " tokenized_prompt, shard_prng_key(subkey), model_params\n",
325
+ " )\n",
326
  " encoded_images = encoded_images.sequences[..., 1:]\n",
327
  " decoded_images = p_decode(encoded_images, vqgan_params)\n",
328
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
329
+ " (-1, 256, 256, 3)\n",
330
+ " )\n",
331
  " for img in decoded_images:\n",
332
+ " images.append(\n",
333
+ " Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
334
+ " )\n",
335
  "\n",
336
+ " def add_clip_results(results, processor, p_clip, clip_params):\n",
337
+ " clip_inputs = processor(\n",
338
+ " text=batch,\n",
339
+ " images=images,\n",
340
+ " return_tensors=\"np\",\n",
341
+ " padding=\"max_length\",\n",
342
+ " max_length=77,\n",
343
+ " truncation=True,\n",
344
+ " ).data\n",
345
  " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
346
+ " images_per_prompt_indices = np.asarray(\n",
347
+ " range(0, len(images), batch_size)\n",
348
+ " )\n",
349
+ " clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
350
+ " list(\n",
351
+ " clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
352
+ " for i in range(batch_size)\n",
353
+ " )\n",
354
+ " )\n",
355
  " clip_inputs = shard(clip_inputs)\n",
356
  " logits = p_clip(clip_inputs, clip_params)\n",
357
  " logits = logits.reshape(-1, num_images)\n",
358
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
359
  " logits = jax.device_get(logits)\n",
360
  " # add to results table\n",
361
+ " for i, (idx, scores, sample) in enumerate(\n",
362
+ " zip(top_scores, logits, batch)\n",
363
+ " ):\n",
364
+ " if sample == padding_item:\n",
365
+ " continue\n",
366
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
367
+ " top_images = [\n",
368
+ " wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
369
+ " for x in idx\n",
370
+ " ]\n",
371
  " results.append([sample] + top_images)\n",
372
+ "\n",
373
  " # get clip scores\n",
374
+ " pbar.set_description(\"Calculating CLIP 16 scores\")\n",
375
  " add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
376
+ "\n",
377
  " # get clip 32 scores\n",
378
  " if add_clip_32:\n",
379
+ " pbar.set_description(\"Calculating CLIP 32 scores\")\n",
380
  " add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
381
  "\n",
382
  " pbar.close()\n",
383
  "\n",
 
 
384
  " # log results\n",
385
  " table = wandb.Table(columns=columns, data=results16)\n",
386
+ " run.log({\"Samples\": table, \"version\": version})\n",
387
  " wandb.finish()\n",
388
+ "\n",
389
+ " if add_clip_32:\n",
390
+ " run = wandb.init(\n",
391
+ " job_type=\"inference\",\n",
392
+ " entity=\"dalle-mini\",\n",
393
+ " project=\"dalle-mini\",\n",
394
+ " config=training_config,\n",
395
+ " id=f\"{run_id}-clip32{suffix}\",\n",
396
+ " resume=\"allow\",\n",
397
+ " )\n",
398
  " table = wandb.Table(columns=columns, data=results32)\n",
399
+ " run.log({\"Samples\": table, \"version\": version})\n",
400
  " wandb.finish()\n",
401
  " run = None # ensure we don't log on this run"
402
  ]
{dev/seq2seq → tools/train}/sweep.yaml RENAMED
File without changes
dev/seq2seq/run_seq2seq_flax.py → tools/train/train.py RENAMED
@@ -18,37 +18,31 @@ Fine-tuning the library models for seq2seq, text to image.
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
21
- import os
22
  import logging
 
23
  import sys
24
  import time
25
- from dataclasses import dataclass, field
26
  from pathlib import Path
27
  from typing import Callable, Optional
28
- import json
29
 
30
  import datasets
31
- from datasets import Dataset
32
- from tqdm import tqdm
33
- from dataclasses import asdict
34
-
35
  import jax
36
  import jax.numpy as jnp
37
  import optax
38
  import transformers
 
 
39
  from flax import jax_utils, traverse_util
40
- from flax.serialization import from_bytes, to_bytes
41
  from flax.jax_utils import unreplicate
 
42
  from flax.training import train_state
43
  from flax.training.common_utils import get_metrics, onehot, shard_prng_key
44
- from transformers import (
45
- AutoTokenizer,
46
- HfArgumentParser,
47
- )
48
  from transformers.models.bart.modeling_flax_bart import BartConfig
49
 
50
- import wandb
51
-
52
  from dalle_mini.data import Dataset
53
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
54
 
@@ -797,7 +791,7 @@ def main():
797
 
798
  # init variables
799
  last_time = time.perf_counter()
800
- train_metric = None
801
 
802
  for epoch in epochs:
803
  state.replace(epoch=jax_utils.replicate(epoch))
@@ -821,20 +815,20 @@ def main():
821
  last_time = new_time
822
 
823
  # train step
824
- state, train_metric = p_train_step(
825
  state, batch, jax_utils.replicate(delta_time)
826
  )
827
  step = unreplicate(state.step)
828
 
829
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
830
  # log metrics
831
- wandb_log(unreplicate(train_metric), step=step, prefix="train")
832
  # log state parameters
833
  state_dict = {
834
  k.split("_")[-1]: unreplicate(getattr(state, k))
835
  for k in ["epoch", "train_time", "train_samples"]
836
  }
837
- wandb_log(state_dict, step=step, prefix="train")
838
 
839
  eval_metrics = None
840
  if training_args.eval_steps and step % training_args.eval_steps == 0:
@@ -844,12 +838,12 @@ def main():
844
  run_save_model(state, eval_metrics)
845
 
846
  # log final train metrics
847
- if train_metric is not None:
848
- train_metric = unreplicate(train_metric)
849
- wandb_log(train_metric, step=step, prefix="train")
850
 
851
  epochs.write(
852
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
853
  )
854
 
855
  # Final evaluation
 
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
21
+ import json
22
  import logging
23
+ import os
24
  import sys
25
  import time
26
+ from dataclasses import asdict, dataclass, field
27
  from pathlib import Path
28
  from typing import Callable, Optional
 
29
 
30
  import datasets
 
 
 
 
31
  import jax
32
  import jax.numpy as jnp
33
  import optax
34
  import transformers
35
+ import wandb
36
+ from datasets import Dataset
37
  from flax import jax_utils, traverse_util
 
38
  from flax.jax_utils import unreplicate
39
+ from flax.serialization import from_bytes, to_bytes
40
  from flax.training import train_state
41
  from flax.training.common_utils import get_metrics, onehot, shard_prng_key
42
+ from tqdm import tqdm
43
+ from transformers import AutoTokenizer, HfArgumentParser
 
 
44
  from transformers.models.bart.modeling_flax_bart import BartConfig
45
 
 
 
46
  from dalle_mini.data import Dataset
47
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
48
 
 
791
 
792
  # init variables
793
  last_time = time.perf_counter()
794
+ train_metrics = None
795
 
796
  for epoch in epochs:
797
  state.replace(epoch=jax_utils.replicate(epoch))
 
815
  last_time = new_time
816
 
817
  # train step
818
+ state, train_metrics = p_train_step(
819
  state, batch, jax_utils.replicate(delta_time)
820
  )
821
  step = unreplicate(state.step)
822
 
823
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
824
  # log metrics
825
+ metrics = unreplicate(train_metrics)
826
  # log state parameters
827
  state_dict = {
828
  k.split("_")[-1]: unreplicate(getattr(state, k))
829
  for k in ["epoch", "train_time", "train_samples"]
830
  }
831
+ wandb_log({**metrics, **state_dict}, step=step, prefix="train")
832
 
833
  eval_metrics = None
834
  if training_args.eval_steps and step % training_args.eval_steps == 0:
 
838
  run_save_model(state, eval_metrics)
839
 
840
  # log final train metrics
841
+ if train_metrics is not None:
842
+ train_metrics = unreplicate(train_metrics)
843
+ wandb_log(train_metrics, step=step, prefix="train")
844
 
845
  epochs.write(
846
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
847
  )
848
 
849
  # Final evaluation