# m4/training/dataset_utils.py
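"""Utilities for building WebDataset pipelines over interleaved image/text web documents and over
image/caption pairs, as used by the m4 training code."""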
import logging
import random

import webdataset as wds
from webdataset.tariterators import group_by_keys, tar_file_expander, url_opener

from m4.training.types import DatasetTypes

meta_prefix = "__"
meta_suffix = "__"

logger = logging.getLogger(__name__)

trace = False


def webdoc_valid_sample(sample):
"""Check whether a sample is valid.
:param sample: sample to be checked
"""
return (
sample is not None
and isinstance(sample, dict)
and len(list(sample.keys())) > 0
and not sample.get("__bad__", False)
and sample_has_all_files(sample)
)


def sample_has_all_files(current_sample):
    """Check that every file listed in the sample's `metadata.value` (a newline-separated
    list of member file names) is actually present in the sample."""
    meta = current_sample.get("metadata.value", None)
    if meta is None:
        return False
    meta = meta.decode("utf-8")
    if len(meta) == 0:
        return False
    target_file_list = meta.split("\n")
    fname_keys = [key for key in current_sample.keys() if key.endswith(".fname")]
    fnames = [current_sample[key] for key in fname_keys]
    check = all([fname in fnames for fname in target_file_list])
    if not check:
        return False
    return True


class ImageDecoder:
    """Decode raw image bytes into a PIL image."""

    def __call__(self, bytes_):
        # Imported lazily so that importing this module does not require Pillow.
        import io
        import PIL.Image

        img = PIL.Image.open(io.BytesIO(bytes_))
        img.load()  # force decoding while the underlying buffer is still available
        return img


# Taken from https://github.com/mlfoundations/open_clip/blob/c48111dacac55db24878af229d8a5662c03e6f1c/src/training/data.py#L180-L183
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
return True


# Adapt group_by_keys to our webdocument format, in which each sample contains several text and image files
# https://github.com/webdataset/webdataset/blob/039d74319ae55e5696dcef89829be9671802cf70/webdataset/tariterators.py#L195-L250
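#
# For illustration only (the document ids and extensions below are hypothetical), a shard member
# layout this parser accepts could be:
#   doc0.0.text.txt      -> text block at position 0 of document "doc0"
#   doc0.1.image.jpg     -> image at position 1 of document "doc0"
#   doc0.metadata.txt    -> newline-separated list of the member names belonging to "doc0"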
def group_by_keys_interleaved(data, handler=log_and_continue):
"""Return function over iterator that groups key, value pairs into samples."""
current_sample = None
for filesample in data:
try:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
fname = fname.strip("./")
if fname.endswith(".metadata.txt"):
prefix, data_type, extension = fname.split(".")
suffix = data_type
else:
prefix, idx, data_type, extension = fname.split(".")
if data_type not in ["text", "image"]:
raise ValueError(f"{fname}: unknown data type {data_type}")
suffix = idx
            if trace:
                # Note: `idx` is only defined for non-metadata files, so log `suffix`
                # (the position index, or "metadata") instead.
                print(
                    f"prefix: {prefix}, suffix: {suffix}, data_type: {data_type}, extension: {extension}, keys:"
                    f" {current_sample.keys() if isinstance(current_sample, dict) else None}"
                )
if prefix is None:
continue
if current_sample is None or prefix != current_sample["__key__"]:
valid = webdoc_valid_sample(current_sample)
if valid:
yield current_sample
elif current_sample is not None:
logging.warning(f"{fname}: invalid sample {current_sample} ignored")
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffix in current_sample:
raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
current_sample[f"{suffix}.value"] = value
current_sample[f"{suffix}.type"] = data_type
current_sample[f"{suffix}.fname"] = fname
except Exception as exn:
exn.args = exn.args + (filesample.get("stream"), filesample.get("url"))
if handler(exn):
continue
else:
break
if webdoc_valid_sample(current_sample):
yield current_sample
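

# A sample yielded by `group_by_keys_interleaved` (and hence by `tarfile_to_webdocument_samples`
# below) is a flat dict. For a document with one text block and one image it would look roughly
# like this (illustrative values only):
#   {"__key__": "doc0", "__url__": "...",
#    "0.value": b"...", "0.type": "text", "0.fname": "doc0.0.text.txt",
#    "1.value": b"...", "1.type": "image", "1.fname": "doc0.1.image.jpg",
#    "metadata.value": b"...", "metadata.type": "metadata", "metadata.fname": "doc0.metadata.txt"}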


def _tarfile_to_webdocument_samples(src, handler=log_and_continue):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_interleaved(files, handler=handler)
    return samples


tarfile_to_webdocument_samples = wds.filters.pipelinefilter(_tarfile_to_webdocument_samples)
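
# `wds.filters.pipelinefilter` turns the generator above into a factory: calling it (optionally
# with keyword arguments such as `handler=`) returns a pipeline stage usable in `wds.DataPipeline`.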


def _collate_texts_and_images_webdocument(data, handler=log_and_continue):
for sample in data:
try:
max_example_indices = max(
[int(key.split(".")[0]) for key in sample.keys() if key.endswith(".value") and key != "metadata.value"]
)
texts = [None for _ in range(max_example_indices + 1)]
images = [None for _ in range(max_example_indices + 1)]
for idx in range(max_example_indices + 1):
if f"{idx}.value" not in sample:
continue
if "text" in sample[f"{idx}.type"]:
texts[idx] = sample[f"{idx}.value"]
elif "image" in sample[f"{idx}.type"]:
images[idx] = sample[f"{idx}.value"]
else:
raise ValueError(f"Unknown data type: {sample[f'{idx}.type']}")
example = {"__key__": sample["__key__"], "__url__": sample["__url__"], "texts": texts, "images": images}
yield example
except Exception as exn:
exn.args = exn.args + (sample.get("stream"), sample.get("url"))
if handler(exn):
continue
else:
break


collate_texts_and_images_webdocument = wds.filters.pipelinefilter(_collate_texts_and_images_webdocument)
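
# After collation, each web document carries position-aligned `texts` and `images` lists in which
# at most one entry is non-None per position, e.g. (illustrative values only):
#   {"__key__": "doc0", "__url__": "...", "texts": [b"some text", None], "images": [None, <image bytes>]}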


def _decode_image_and_text_webdocument(data, handler=log_and_continue):
image_decoder = ImageDecoder()
for sample in data:
try:
sample["images"] = [image_decoder(image) if image is not None else None for image in sample["images"]]
sample["texts"] = [text.decode("utf-8") if text is not None else None for text in sample["texts"]]
yield sample
except Exception as exn:
exn.args = exn.args + (sample.get("stream"), sample.get("url"))
if handler(exn):
continue
else:
break


decode_image_and_text_webdocument = wds.filters.pipelinefilter(_decode_image_and_text_webdocument)


def collate_dicts(samples):
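    # e.g. collate_dicts([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) -> {"a": [1, 3], "b": [2, 4]}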
keys = samples[0].keys()
batched_samples = {key: [sample[key] for sample in samples] for key in keys}
return batched_samples


def get_webdocuments_webdataset(
urls,
batch_size,
shuffle_initial_urls_list=False,
shuffle_before_split_by_node_buffer_size=100,
shuffle_before_split_by_worker_buffer_size=100,
shuffle_after_tarfile_to_samples_buffer_size=100,
shuffle_after_batching_buffer_size=1000,
):
if shuffle_initial_urls_list:
random.shuffle(urls)
pipeline_list = [wds.SimpleShardList(urls)]
if shuffle_before_split_by_node_buffer_size is not None:
pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))
pipeline_list.append(wds.split_by_node)
if shuffle_before_split_by_worker_buffer_size is not None:
pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))
pipeline_list.extend(
[
wds.split_by_worker,
tarfile_to_webdocument_samples(),
]
)
if shuffle_after_tarfile_to_samples_buffer_size is not None:
pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))
pipeline_list.extend(
[
collate_texts_and_images_webdocument(),
decode_image_and_text_webdocument(),
wds.batched(batch_size, collation_fn=collate_dicts, partial=True),
]
)
if shuffle_after_batching_buffer_size is not None:
pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))
dataset = wds.DataPipeline(pipeline_list)
return dataset


def split_keep_2(x):
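    # e.g. (hypothetical member name) "./sample0.image.jpg" -> ("sample0", "image")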
x = x.strip("./")
x_splitter = x.split(".")
return x_splitter[0], x_splitter[1]


def _tarfile_to_pair_samples(src, handler=log_and_continue):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys(files, keys=split_keep_2, handler=handler)
    return samples


tarfile_to_pair_samples = wds.filters.pipelinefilter(_tarfile_to_pair_samples)


def _decode_image_and_text_pairs(data, handler=log_and_continue):
image_decoder = ImageDecoder()
for sample in data:
try:
sample["image"] = image_decoder(sample["image"])
sample["text"] = sample["text"].decode("utf-8")
yield sample
except Exception as exn:
exn.args = exn.args + (sample.get("stream"), sample.get("url"))
if handler(exn):
continue
else:
break


decode_image_and_text_pairs = wds.filters.pipelinefilter(_decode_image_and_text_pairs)


def get_image_caption_pairs_webdataset(
urls,
batch_size,
shuffle_initial_urls_list=False,
shuffle_before_split_by_node_buffer_size=100,
shuffle_before_split_by_worker_buffer_size=100,
shuffle_after_tarfile_to_samples_buffer_size=100,
shuffle_after_batching_buffer_size=1000,
):
if shuffle_initial_urls_list:
random.shuffle(urls)
pipeline_list = [wds.SimpleShardList(urls)]
if shuffle_before_split_by_node_buffer_size is not None:
pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))
pipeline_list.append(wds.split_by_node)
if shuffle_before_split_by_worker_buffer_size is not None:
pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))
pipeline_list.extend(
[
wds.split_by_worker,
tarfile_to_pair_samples(handler=log_and_continue),
]
)
if shuffle_after_tarfile_to_samples_buffer_size is not None:
pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))
pipeline_list.extend(
[
decode_image_and_text_pairs(),
wds.batched(batch_size, collation_fn=collate_dicts, partial=True), # todo: check if partial is needed
]
)
if shuffle_after_batching_buffer_size is not None:
pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))
dataset = wds.DataPipeline(pipeline_list)
return dataset


def get_webdataset(
urls,
ds_type: DatasetTypes,
batch_size: int,
shuffle_initial_urls_list,
shuffle_before_split_by_node_buffer_size,
shuffle_before_split_by_worker_buffer_size,
shuffle_after_tarfile_to_samples_buffer_size,
shuffle_after_batching_buffer_size,
):
if ds_type == DatasetTypes.WEB_DOCUMENTS:
return get_webdocuments_webdataset(
urls,
batch_size,
shuffle_initial_urls_list,
shuffle_before_split_by_node_buffer_size,
shuffle_before_split_by_worker_buffer_size,
shuffle_after_tarfile_to_samples_buffer_size,
shuffle_after_batching_buffer_size,
)
elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
return get_image_caption_pairs_webdataset(
urls,
batch_size,
shuffle_initial_urls_list,
shuffle_before_split_by_node_buffer_size,
shuffle_before_split_by_worker_buffer_size,
shuffle_after_tarfile_to_samples_buffer_size,
shuffle_after_batching_buffer_size,
)
else:
raise ValueError(f"Unknown dataset type: {ds_type}")


def check_webdataset_command(command):
    """Validate a shard spec: S3 shards must be expressed as a `pipe:bash ... get_file.sh ... .tar`
    download command; anything that does not reference S3 is accepted as-is."""
    if "s3:/" not in command:
        return True
    command = command.strip()
    if not command.startswith("pipe:bash"):
        return False
    if not command.endswith(".tar"):
        return False
    if "get_file.sh" not in command:
        return False
    return True
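

# Minimal usage sketch, for illustration only: the shard names and batch size below are
# hypothetical, and the shuffle buffer sizes simply reuse this module's defaults.
if __name__ == "__main__":
    _example_urls = ["shard-00000.tar", "shard-00001.tar"]  # hypothetical shard paths
    _dataset = get_webdataset(
        urls=_example_urls,
        ds_type=DatasetTypes.IMAGE_CAPTION_PAIRS,
        batch_size=4,
        shuffle_initial_urls_list=True,
        shuffle_before_split_by_node_buffer_size=100,
        shuffle_before_split_by_worker_buffer_size=100,
        shuffle_after_tarfile_to_samples_buffer_size=100,
        shuffle_after_batching_buffer_size=1000,
    )
    for _batch in _dataset:
        # Each batch is a dict of lists (see `collate_dicts`); image/caption batches expose
        # "image" and "text" keys among others.
        print(len(_batch["text"]), "captions in the first batch")
        break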