import logging
import random

import webdataset as wds
from webdataset.tariterators import group_by_keys, tar_file_expander, url_opener

from m4.training.types import DatasetTypes

meta_prefix = "__"
meta_suffix = "__"

logger = logging.getLogger(__name__)

trace = False


def webdoc_valid_sample(sample):
    """Check whether a sample is valid.

    :param sample: sample to be checked
    """
    return (
        sample is not None
        and isinstance(sample, dict)
        and len(list(sample.keys())) > 0
        and not sample.get("__bad__", False)
        and sample_has_all_files(sample)
    )


def sample_has_all_files(current_sample):
    meta = current_sample.get("metadata.value", None)
    if meta is None:
        return False
    meta = meta.decode("utf-8")
    if len(meta) == 0:
        return False
    target_file_list = meta.split("\n")
    fname_keys = [key for key in current_sample.keys() if key.endswith(".fname")]
    fnames = [current_sample[key] for key in fname_keys]
    check = all([fname in fnames for fname in target_file_list])
    if not check:
        return False
    return True


class ImageDecoder:
    def __call__(self, bytes_):
        import io

        import PIL.Image

        img = PIL.Image.open(io.BytesIO(bytes_))
        img.load()
        return img


# Taken from https://github.com/mlfoundations/open_clip/blob/c48111dacac55db24878af229d8a5662c03e6f1c/src/training/data.py#L180-L183
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True


# Adapt group_by_keys to our webdocument format in which each sample contains several text and image files
# https://github.com/webdataset/webdataset/blob/039d74319ae55e5696dcef89829be9671802cf70/webdataset/tariterators.py#L195-L250
def group_by_keys_interleaved(data, handler=log_and_continue):
    """Return function over iterator that groups key, value pairs into samples."""
    current_sample = None
    for filesample in data:
        try:
            assert isinstance(filesample, dict)
            fname, value = filesample["fname"], filesample["data"]
            fname = fname.strip("./")
            if fname.endswith(".metadata.txt"):
                prefix, data_type, extension = fname.split(".")
                suffix = data_type
            else:
                prefix, idx, data_type, extension = fname.split(".")
                if data_type not in ["text", "image"]:
                    raise ValueError(f"{fname}: unknown data type {data_type}")
                suffix = idx
                if trace:
                    print(
                        f"prefix: {prefix}, idx: {idx}, data_type: {data_type}, extension: {extension}, keys:"
                        f" {current_sample.keys() if isinstance(current_sample, dict) else None}"
                    )
            if prefix is None:
                continue
            if current_sample is None or prefix != current_sample["__key__"]:
                valid = webdoc_valid_sample(current_sample)
                if valid:
                    yield current_sample
                elif current_sample is not None:
                    logging.warning(f"{fname}: invalid sample {current_sample} ignored")
                current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
            if suffix in current_sample:
                raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
            current_sample[f"{suffix}.value"] = value
            current_sample[f"{suffix}.type"] = data_type
            current_sample[f"{suffix}.fname"] = fname
        except Exception as exn:
            exn.args = exn.args + (filesample.get("stream"), filesample.get("url"))
            if handler(exn):
                continue
            else:
                break
    if webdoc_valid_sample(current_sample):
        yield current_sample


def _tarfile_to_webdocument_samples(src, handler=log_and_continue):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_interleaved(files, handler=handler)
    return samples


tarfile_to_webdocument_samples = wds.filters.pipelinefilter(_tarfile_to_webdocument_samples)
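

# Illustrative sketch only: the shard member names below are assumptions inferred
# from the parsing logic above, not taken from a real shard. For a webdocument
# packed as
#   docA.0.text.txt, docA.1.image.jpg, docA.metadata.txt
# group_by_keys_interleaved yields roughly
#   {
#       "__key__": "docA",
#       "__url__": "<shard url>",
#       "0.value": b"...", "0.type": "text", "0.fname": "docA.0.text.txt",
#       "1.value": b"...", "1.type": "image", "1.fname": "docA.1.image.jpg",
#       "metadata.value": b"docA.0.text.txt\ndocA.1.image.jpg", ...
#   }
# where metadata.value is expected to list every file of the sample, one name per
# line, so that sample_has_all_files can drop truncated samples.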


def _collate_texts_and_images_webdocument(data, handler=log_and_continue):
    for sample in data:
        try:
            max_example_indices = max(
                [int(key.split(".")[0]) for key in sample.keys() if key.endswith(".value") and key != "metadata.value"]
            )
            texts = [None for _ in range(max_example_indices + 1)]
            images = [None for _ in range(max_example_indices + 1)]
            for idx in range(max_example_indices + 1):
                if f"{idx}.value" not in sample:
                    continue
                if "text" in sample[f"{idx}.type"]:
                    texts[idx] = sample[f"{idx}.value"]
                elif "image" in sample[f"{idx}.type"]:
                    images[idx] = sample[f"{idx}.value"]
                else:
                    raise ValueError(f"Unknown data type: {sample[f'{idx}.type']}")
            example = {"__key__": sample["__key__"], "__url__": sample["__url__"], "texts": texts, "images": images}
            yield example
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


collate_texts_and_images_webdocument = wds.filters.pipelinefilter(_collate_texts_and_images_webdocument)


def _decode_image_and_text_webdocument(data, handler=log_and_continue):
    image_decoder = ImageDecoder()
    for sample in data:
        try:
            sample["images"] = [image_decoder(image) if image is not None else None for image in sample["images"]]
            sample["texts"] = [text.decode("utf-8") if text is not None else None for text in sample["texts"]]
            yield sample
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


decode_image_and_text_webdocument = wds.filters.pipelinefilter(_decode_image_and_text_webdocument)


def collate_dicts(samples):
    keys = samples[0].keys()
    batched_samples = {key: [sample[key] for sample in samples] for key in keys}
    return batched_samples


def get_webdocuments_webdataset(
    urls,
    batch_size,
    shuffle_initial_urls_list=False,
    shuffle_before_split_by_node_buffer_size=100,
    shuffle_before_split_by_worker_buffer_size=100,
    shuffle_after_tarfile_to_samples_buffer_size=100,
    shuffle_after_batching_buffer_size=1000,
):
    if shuffle_initial_urls_list:
        random.shuffle(urls)

    pipeline_list = [wds.SimpleShardList(urls)]

    if shuffle_before_split_by_node_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))

    pipeline_list.append(wds.split_by_node)

    if shuffle_before_split_by_worker_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))

    pipeline_list.extend(
        [
            wds.split_by_worker,
            tarfile_to_webdocument_samples(),
        ]
    )

    if shuffle_after_tarfile_to_samples_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))

    pipeline_list.extend(
        [
            collate_texts_and_images_webdocument(),
            decode_image_and_text_webdocument(),
            wds.batched(batch_size, collation_fn=collate_dicts, partial=True),
        ]
    )

    if shuffle_after_batching_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))

    dataset = wds.DataPipeline(pipeline_list)
    return dataset


def split_keep_2(x):
    x = x.strip("./")
    x_splitter = x.split(".")
    return x_splitter[0], x_splitter[1]


def _tarfile_to_pair_samples(src, handler=log_and_continue):
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys(files, keys=split_keep_2, handler=handler)
    return samples


tarfile_to_pair_samples = wds.filters.pipelinefilter(_tarfile_to_pair_samples)
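

# Illustrative note (assumed filenames): for image/caption shards whose members are
# named e.g. "000123.image.jpg" and "000123.text.txt", split_keep_2 maps both names
# to the key "000123" with suffixes "image" and "text", so group_by_keys yields
# samples of the form {"__key__": "000123", "__url__": ..., "image": <raw bytes>,
# "text": <utf-8 bytes>}, which is what _decode_image_and_text_pairs below expects.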


def _decode_image_and_text_pairs(data, handler=log_and_continue):
    image_decoder = ImageDecoder()
    for sample in data:
        try:
            sample["image"] = image_decoder(sample["image"])
            sample["text"] = sample["text"].decode("utf-8")
            yield sample
        except Exception as exn:
            exn.args = exn.args + (sample.get("stream"), sample.get("url"))
            if handler(exn):
                continue
            else:
                break


decode_image_and_text_pairs = wds.filters.pipelinefilter(_decode_image_and_text_pairs)


def get_image_caption_pairs_webdataset(
    urls,
    batch_size,
    shuffle_initial_urls_list=False,
    shuffle_before_split_by_node_buffer_size=100,
    shuffle_before_split_by_worker_buffer_size=100,
    shuffle_after_tarfile_to_samples_buffer_size=100,
    shuffle_after_batching_buffer_size=1000,
):
    if shuffle_initial_urls_list:
        random.shuffle(urls)

    pipeline_list = [wds.SimpleShardList(urls)]

    if shuffle_before_split_by_node_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_node_buffer_size))

    pipeline_list.append(wds.split_by_node)

    if shuffle_before_split_by_worker_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_before_split_by_worker_buffer_size))

    pipeline_list.extend(
        [
            wds.split_by_worker,
            tarfile_to_pair_samples(handler=log_and_continue),
        ]
    )

    if shuffle_after_tarfile_to_samples_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_tarfile_to_samples_buffer_size))

    pipeline_list.extend(
        [
            decode_image_and_text_pairs(),
            wds.batched(batch_size, collation_fn=collate_dicts, partial=True),  # todo: check if partial is needed
        ]
    )

    if shuffle_after_batching_buffer_size is not None:
        pipeline_list.append(wds.shuffle(shuffle_after_batching_buffer_size))

    dataset = wds.DataPipeline(pipeline_list)
    return dataset


def get_webdataset(
    urls,
    ds_type: DatasetTypes,
    batch_size: int,
    shuffle_initial_urls_list,
    shuffle_before_split_by_node_buffer_size,
    shuffle_before_split_by_worker_buffer_size,
    shuffle_after_tarfile_to_samples_buffer_size,
    shuffle_after_batching_buffer_size,
):
    if ds_type == DatasetTypes.WEB_DOCUMENTS:
        return get_webdocuments_webdataset(
            urls,
            batch_size,
            shuffle_initial_urls_list,
            shuffle_before_split_by_node_buffer_size,
            shuffle_before_split_by_worker_buffer_size,
            shuffle_after_tarfile_to_samples_buffer_size,
            shuffle_after_batching_buffer_size,
        )
    elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
        return get_image_caption_pairs_webdataset(
            urls,
            batch_size,
            shuffle_initial_urls_list,
            shuffle_before_split_by_node_buffer_size,
            shuffle_before_split_by_worker_buffer_size,
            shuffle_after_tarfile_to_samples_buffer_size,
            shuffle_after_batching_buffer_size,
        )
    else:
        raise ValueError(f"Unknown dataset type: {ds_type}")


def check_webdataset_command(command):
    if "s3:/" not in command:
        return True

    command = command.strip()
    if not command.startswith("pipe:bash"):
        return False

    if not command.endswith(".tar"):
        return False

    if "get_file.sh" not in command:
        return False

    return True
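

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training entry points. The shard names
    # below are hypothetical placeholders; real runs pass actual .tar shard URLs or
    # "pipe:bash ... get_file.sh ..." commands (see check_webdataset_command above).
    example_urls = [f"shard-{i:05d}.tar" for i in range(10)]
    dataset = get_webdataset(
        urls=example_urls,
        ds_type=DatasetTypes.IMAGE_CAPTION_PAIRS,
        batch_size=4,
        shuffle_initial_urls_list=True,
        shuffle_before_split_by_node_buffer_size=100,
        shuffle_before_split_by_worker_buffer_size=100,
        shuffle_after_tarfile_to_samples_buffer_size=100,
        shuffle_after_batching_buffer_size=1000,
    )
    # Each batch is a dict of lists built by collate_dicts, e.g. batch["image"] is a
    # list of PIL images and batch["text"] a list of decoded caption strings.
    for batch in dataset:
        print(batch["__key__"])
        break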