"""Simple data input from .jsonl files."""

import hashlib
import json
from multiprocessing.pool import ThreadPool
import os
import tempfile
import urllib.request

from absl import logging
import big_vision.datasets.core as ds_core
import jax
import numpy as np
import overrides
import tensorflow as tf


def cached_download(url, dest=None, verbose=True):
  """Download `url` to local file and return path to that, but with caching."""
  dest = dest or os.path.join(tempfile.gettempdir(), "bv")
  os.makedirs(dest, exist_ok=True)
  dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest())

  if os.path.isfile(dest):
    return dest

  if verbose:
    print(f"\rRetrieving {url} into {dest}", end="", flush=True)

  with urllib.request.urlopen(url) as f:
    data = f.read()
  with open(dest, "wb+") as f:
    f.write(data)
  return dest
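
# Example use of cached_download (a sketch; the URL below is hypothetical):
#   path = cached_download("https://example.com/images/cat.jpg")
#   # A second call with the same URL returns the already-downloaded file,
#   # since the cache filename is the md5 hash of the URL.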


class DataSource(ds_core.DataSource):
  """.jsonl DataSource."""

  def __init__(self, fname, *, fopen_keys=(), download_keys=(),
               start=0, stop=float("inf")):
    """Create data-source that's jsonl + data files (eg images).

    This correctly supports multi-host in that each host only reads a subset of
    the dataset automatically. However, currently, all hosts download all items
    if `download_keys` is specified. TODO: b/lbeyer - This can be improved.

    Args:
      fname: str, the path to the jsonl file that holds the dataset.
      fopen_keys: collection of str or dict, the keys in the dataset whose
        string value actually is a file-path that should be opened and read,
        and its content is what goes into the batch (eg image filenames,
        commonly ["image"]).
        If a dict, the values are folders prefixed to the filenames.
        Supports gs:// for reading from buckets.
      download_keys: collection of str, the keys in the dataset whose string
        value actually is a URL from which the file should be downloaded first.
        Files are downloaded to a persistent tmp folder using the URL hash as
        the filename. If the file already exists, the download is skipped.
        Must be a subset of `fopen_keys`.
      start: int, index of the first row to use; use for slicing the data.
      stop: int or inf, index of the row after the last one to use.

    Note:
      This simple data input does not allow for nested/hierarchical values,
      or in any way more complicated values like vectors. Use TFDS for that.

      The way the start/stop arguments are used is as in list slicing
      [start:stop].
    """
    self.examples = []

    with tf.io.gfile.GFile(fname) as f:
      for i, line in enumerate(f):
        if (start or 0) <= i < (stop or float("inf")):
          try:
            self.examples.append(json.loads(line))
          except json.decoder.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in line {i}:\n{line}") from e

    if download_keys:
      for k in download_keys:
        assert k in fopen_keys, (
            f"{k} in download_keys but missing from fopen_keys {fopen_keys}")

      logging.info(
          f"\u001b[33mNOTE\u001b[0m: Downloading {download_keys} "
          f"for dataset {fname} ({len(self.examples)} examples) ...")

      def _dl_one(ex):
        for k in download_keys:
          ex[k] = cached_download(ex[k])

      ThreadPool(100).map(_dl_one, self.examples)
      print("Done")
      logging.info("\u001b[33mNOTE\u001b[0m: Done downloading.")

    if isinstance(fopen_keys, (list, tuple)):
      self.fopen_keys = {k: "" for k in fopen_keys}
    else:
      self.fopen_keys = fopen_keys or {}

    for ex in self.examples:
      for k, dirname in self.fopen_keys.items():
        ex[k] = os.path.join(dirname, ex[k])

  def _indices(self, *, process_split=True, process_index=None):
    """Returns example indices, restricted to this process' split if asked."""
    indices = np.arange(len(self.examples))

    if not process_split:
      return list(indices)

    pid = jax.process_index() if process_index is None else process_index
    return list(np.array_split(indices, jax.process_count())[pid])

  @overrides.overrides
  def get_tfdata(self, ordered=False, *, process_split=True, allow_cache=True):
    del allow_cache
    assert not process_split or len(self.examples) >= jax.process_count(), (
        "Process splitting the data with fewer examples than processes!?")

    my_idxs = self._indices(process_split=process_split)
    if not ordered:
      np.random.shuffle(my_idxs)

    dataset = tf.data.Dataset.from_generator(
        generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs),
        output_signature={
            "id": _guess_signature("0"),
            **{k: _guess_signature(v) for k, v in self.examples[0].items()},
        })

    def _read_files(example):
      for k in self.fopen_keys:
        example[k] = tf.io.read_file(example[k])
      return example
    dataset = dataset.map(_read_files)

    return dataset
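
  # Note (descriptive, based on the code above): each element of the returned
  # dataset is a dict with an "id" string plus the original jsonl keys; for
  # keys listed in `fopen_keys`, the value is the raw, undecoded file contents
  # read by tf.io.read_file.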

  @property
  @overrides.overrides
  def total_examples(self):
    return len(self.examples)

  @overrides.overrides
  def num_examples_per_process(self):
    return [len(self._indices(process_index=pid))
            for pid in range(jax.process_count())]


def _guess_signature(value):
  # Derives a TensorSpec from a scalar example value, e.g. a Python string
  # yields tf.TensorSpec(shape=(), dtype=tf.string).
  return tf.TensorSpec.from_tensor(tf.constant(value))