Spaces:
Runtime error
Runtime error
import logging | |
import random | |
import webdataset as wds | |
from webdataset.tariterators import group_by_keys, tar_file_expander, url_opener | |
from m4.training.types import DatasetTypes | |
# Double-underscore markers. They are not referenced by the code in this section;
# presumably they mirror webdataset's tariterators constants of the same names.
meta_prefix = meta_suffix = "__"

logger = logging.getLogger(__name__)

# Debug switch: when True, group_by_keys_interleaved prints one line per tar member.
trace = False
def webdoc_valid_sample(sample):
    """Return True if ``sample`` is a usable web-document sample.

    A sample is usable when it is a non-empty dict, it does not carry a truthy
    ``"__bad__"`` entry, and ``sample_has_all_files`` confirms that every file
    listed in its metadata is present.

    :param sample: grouped sample to check (``None`` is accepted and rejected).
    :return: ``True`` or ``False``.
    """
    if sample is None or not isinstance(sample, dict):
        return False
    if len(sample) == 0:
        return False
    if sample.get("__bad__", False):
        return False
    return sample_has_all_files(sample)
def sample_has_all_files(current_sample):
    """Check that every file listed in a sample's metadata is present in the sample.

    The metadata (stored under ``"metadata.value"``) lists one file name per line.
    Each listed name must equal one of the sample's ``"<suffix>.fname"`` values.

    :param current_sample: grouped sample dict.
    :return: ``False`` if the metadata is missing, contains no file names, or
        lists a file that is not in the sample; ``True`` otherwise.
    """
    meta = current_sample.get("metadata.value", None)
    if meta is None:
        return False
    # Raw tar data is bytes; also accept an already-decoded string.
    if isinstance(meta, (bytes, bytearray)):
        meta = meta.decode("utf-8")
    # Skip blank lines. Otherwise a trailing newline in the metadata file yields an
    # "" entry that can never match a file name and rejects the whole sample.
    target_file_list = [line for line in meta.splitlines() if line]
    if not target_file_list:
        return False
    # A set makes each membership test O(1) instead of scanning a list.
    present_fnames = {value for key, value in current_sample.items() if key.endswith(".fname")}
    return all(fname in present_fnames for fname in target_file_list)
class ImageDecoder:
    """Callable that turns encoded image bytes into a loaded PIL image."""

    def __call__(self, bytes_):
        """Decode ``bytes_`` with PIL and read the pixel data immediately.

        ``Image.open`` only identifies the image lazily, so ``load()`` is called to
        read the data now. As a result, decoding errors should surface in this call.
        """
        from io import BytesIO

        from PIL import Image

        decoded = Image.open(BytesIO(bytes_))
        decoded.load()
        return decoded
# Taken from https://github.com/mlfoundations/open_clip/blob/c48111dacac55db24878af229d8a5662c03e6f1c/src/training/data.py#L180-L183
def log_and_continue(exn):
    """Webdataset exception handler: log ``exn`` as a warning and keep going.

    :param exn: exception raised while processing an item.
    :return: always ``True``, which the pipeline stages in this module treat as
        "skip this item and continue".
    """
    message = f"Handling webdataset error ({repr(exn)}). Ignoring."
    logging.warning(message)
    return True
# Adapt group_by_keys to our web-document format, in which each sample contains several text and image files
# https://github.com/webdataset/webdataset/blob/039d74319ae55e5696dcef89829be9671802cf70/webdataset/tariterators.py#L195-L250
def group_by_keys_interleaved(data, handler=log_and_continue):
    """Group consecutive tar members that share a prefix into web-document samples.

    Member names are stripped of leading/trailing ``.`` and ``/`` characters and
    must have one of these forms:

    * ``<prefix>.metadata.txt`` -- stored under the ``metadata.*`` keys;
    * ``<prefix>.<idx>.<text|image>.<ext>`` -- stored under the ``<idx>.*`` keys.

    Each member adds three keys to the current sample:

    * ``<suffix>.value`` -- the raw data;
    * ``<suffix>.type`` -- ``metadata``, ``text`` or ``image``;
    * ``<suffix>.fname`` -- the stripped name.

    The sample also carries ``__key__`` (the prefix) and ``__url__``. A sample is
    emitted when the prefix changes and again at the end of ``data``, but only if
    ``webdoc_valid_sample`` accepts it. Rejected samples are logged.

    :param data: iterable of dicts with ``fname``, ``data`` and ``__url__`` keys,
        such as the output of ``tar_file_expander``.
    :param handler: called with any exception raised while processing a member.
        A truthy return skips that member; a falsy return stops iteration.
    :raises ValueError: (routed through ``handler``) for names that do not match
        the forms above, unknown data types, or a suffix seen twice in one sample.
    """
    current_sample = None
    for filesample in data:
        try:
            if not isinstance(filesample, dict):
                raise TypeError(f"expected a dict from the tar expander, got {type(filesample).__name__}")
            fname, value = filesample["fname"], filesample["data"]
            fname = fname.strip("./")
            if fname.endswith(".metadata.txt"):
                prefix, data_type, extension = fname.split(".")
                # Metadata has no index. Set idx so the trace print cannot hit an unbound or stale name.
                idx = None
                suffix = data_type
            else:
                prefix, idx, data_type, extension = fname.split(".")
                if data_type not in ["text", "image"]:
                    raise ValueError(f"{fname}: unknown data type {data_type}")
                suffix = idx
            if trace:
                print(
                    f"prefix: {prefix}, idx: {idx}, data_type: {data_type}, extension: {extension}, keys:"
                    f" {current_sample.keys() if isinstance(current_sample, dict) else None}"
                )
            if current_sample is None or prefix != current_sample["__key__"]:
                # A new prefix starts a new document, so flush the previous one first.
                if webdoc_valid_sample(current_sample):
                    yield current_sample
                elif current_sample is not None:
                    # Log the key and file list only; the sample values can be large raw bytes.
                    logging.warning(
                        f"{current_sample['__key__']}: invalid sample ignored (keys: {sorted(current_sample)})"
                    )
                current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
            # Values live under "<suffix>.value", so that is the key that reveals a duplicate member.
            if f"{suffix}.value" in current_sample:
                raise ValueError(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
            current_sample[f"{suffix}.value"] = value
            current_sample[f"{suffix}.type"] = data_type
            current_sample[f"{suffix}.fname"] = fname
        except Exception as exn:
            # Attach the member name and source url so the failing shard can be located.
            if isinstance(filesample, dict):
                exn.args = exn.args + (filesample.get("fname"), filesample.get("__url__"))
            if handler(exn):
                continue
            else:
                break
    if webdoc_valid_sample(current_sample):
        yield current_sample
    elif current_sample is not None:
        logging.warning(f"{current_sample['__key__']}: invalid sample ignored (keys: {sorted(current_sample)})")
def _tarfile_to_webdocument_samples(src, handler=log_and_continue):
    """Pipeline stage: turn shard urls into grouped web-document samples.

    Opens each url, expands the tar members, and groups them with
    ``group_by_keys_interleaved``. ``handler`` is forwarded to all three steps.
    """
    return group_by_keys_interleaved(
        tar_file_expander(url_opener(src, handler=handler), handler=handler),
        handler=handler,
    )


# Stage factory: calling it (optionally with handler=...) yields a pipeline stage.
tarfile_to_webdocument_samples = wds.filters.pipelinefilter(_tarfile_to_webdocument_samples)
def _collate_texts_and_images_webdocument(data, handler=log_and_continue):
    """Convert grouped web-document samples into parallel ``texts``/``images`` lists.

    For each sample, every ``"<idx>.value"`` entry (except ``"metadata.value"``) is
    placed at position ``idx``:

    * in ``texts`` if ``"<idx>.type"`` contains ``"text"``;
    * in ``images`` if it contains ``"image"``.

    Positions with no entry, and the other list's slot at each filled position,
    stay ``None``. Both lists have length ``max(idx) + 1``.

    :param data: iterable of samples from ``group_by_keys_interleaved``.
    :param handler: exception handler. A truthy return skips the sample; a falsy
        return stops iteration.
    :yield: dicts with ``__key__``, ``__url__``, ``texts`` and ``images``.
    """
    for sample in data:
        try:
            indices = [
                int(key.split(".")[0])
                for key in sample.keys()
                if key.endswith(".value") and key != "metadata.value"
            ]
            # max() raises ValueError for a sample without any text/image entry; the handler decides.
            length = max(indices) + 1
            texts = [None] * length
            images = [None] * length
            for idx in range(length):
                if f"{idx}.value" not in sample:
                    continue
                data_type = sample[f"{idx}.type"]
                if "text" in data_type:
                    texts[idx] = sample[f"{idx}.value"]
                elif "image" in data_type:
                    images[idx] = sample[f"{idx}.value"]
                else:
                    raise ValueError(f"Unknown data type: {data_type}")
            yield {"__key__": sample["__key__"], "__url__": sample["__url__"], "texts": texts, "images": images}
        except Exception as exn:
            # Samples carry "__key__"/"__url__" (not "stream"/"url"), so report those for debugging.
            if isinstance(sample, dict):
                exn.args = exn.args + (sample.get("__key__"), sample.get("__url__"))
            if handler(exn):
                continue
            else:
                break


collate_texts_and_images_webdocument = wds.filters.pipelinefilter(_collate_texts_and_images_webdocument)
def _decode_image_and_text_webdocument(data, handler=log_and_continue):
    """Decode the ``images`` and ``texts`` lists of each web-document sample in place.

    Images are decoded with ``ImageDecoder`` and texts as UTF-8. ``None``
    placeholders stay ``None``.

    :param data: iterable of samples from ``_collate_texts_and_images_webdocument``.
    :param handler: exception handler. A truthy return skips the sample; a falsy
        return stops iteration.
    :yield: the same sample dicts with decoded contents.
    """
    image_decoder = ImageDecoder()
    for sample in data:
        try:
            sample["images"] = [None if image is None else image_decoder(image) for image in sample["images"]]
            sample["texts"] = [None if text is None else text.decode("utf-8") for text in sample["texts"]]
            yield sample
        except Exception as exn:
            # Samples carry "__key__"/"__url__" (not "stream"/"url"), so report those for debugging.
            if isinstance(sample, dict):
                exn.args = exn.args + (sample.get("__key__"), sample.get("__url__"))
            if handler(exn):
                continue
            else:
                break


decode_image_and_text_webdocument = wds.filters.pipelinefilter(_decode_image_and_text_webdocument)
def collate_dicts(samples):
    """Collate a list of dicts into one dict of lists, keyed by the first sample's keys.

    :param samples: list of dicts that share the same keys.
    :return: ``{key: [sample[key] for sample in samples]}``. An empty ``samples``
        gives ``{}`` instead of raising IndexError.
    :raises KeyError: if a later sample lacks a key present in the first one.
    """
    if not samples:
        return {}
    return {key: [sample[key] for sample in samples] for key in samples[0].keys()}
def get_webdocuments_webdataset(
    urls,
    batch_size,
    shuffle_initial_urls_list=False,
    shuffle_before_split_by_node_buffer_size=100,
    shuffle_before_split_by_worker_buffer_size=100,
    shuffle_after_tarfile_to_samples_buffer_size=100,
    shuffle_after_batching_buffer_size=1000,
):
    """Build the webdataset pipeline for interleaved web documents.

    Stages, in order (each shuffle stage is skipped when its buffer size is ``None``):

    1. shard list
    2. shuffle, then ``split_by_node``
    3. shuffle, then ``split_by_worker``
    4. tar members to grouped samples, then shuffle
    5. collate texts/images, then decode
    6. batch with ``collate_dicts`` (partial last batch kept), then shuffle batches

    :param urls: sequence of shard urls or pipe commands.
    :param batch_size: number of documents per batch.
    :param shuffle_initial_urls_list: shuffle the shard order once up front. A
        shuffled copy is used, so the caller's list is left untouched.
    :param shuffle_before_split_by_node_buffer_size: shuffle buffer before node split, or ``None``.
    :param shuffle_before_split_by_worker_buffer_size: shuffle buffer before worker split, or ``None``.
    :param shuffle_after_tarfile_to_samples_buffer_size: sample shuffle buffer, or ``None``.
    :param shuffle_after_batching_buffer_size: batch shuffle buffer, or ``None``.
        NOTE(review): this buffer holds whole decoded batches, so large values
        may be memory-heavy.
    :return: a ``wds.DataPipeline``.
    :raises TypeError: if ``shuffle_initial_urls_list`` is set and ``urls`` is a single string.
    """
    if shuffle_initial_urls_list:
        if isinstance(urls, str):
            raise TypeError("shuffle_initial_urls_list=True requires a sequence of urls, not a single string")
        # random.shuffle works in place; shuffle a copy so the caller's list is not reordered.
        urls = list(urls)
        random.shuffle(urls)

    def maybe_shuffle(buffer_size):
        # A buffer size of None disables that shuffle stage.
        return [] if buffer_size is None else [wds.shuffle(buffer_size)]

    pipeline_list = [
        wds.SimpleShardList(urls),
        *maybe_shuffle(shuffle_before_split_by_node_buffer_size),
        wds.split_by_node,
        *maybe_shuffle(shuffle_before_split_by_worker_buffer_size),
        wds.split_by_worker,
        tarfile_to_webdocument_samples(),
        *maybe_shuffle(shuffle_after_tarfile_to_samples_buffer_size),
        collate_texts_and_images_webdocument(),
        decode_image_and_text_webdocument(),
        wds.batched(batch_size, collation_fn=collate_dicts, partial=True),
        *maybe_shuffle(shuffle_after_batching_buffer_size),
    ]
    return wds.DataPipeline(pipeline_list)
def split_keep_2(x):
    """Split a tar member name into ``(prefix, field)`` for ``group_by_keys``.

    Leading and trailing ``.``/``/`` characters are stripped. The name is then
    split on ``.`` and the first two components are returned, for example
    ``"./abc.image.jpg"`` -> ``("abc", "image")``.

    Names with fewer than two components return ``(None, None)``. This is the
    "skip this file" convention that ``group_by_keys`` checks with
    ``if prefix is None: continue``, and it replaces an IndexError.
    """
    parts = x.strip("./").split(".")
    if len(parts) < 2:
        return None, None
    return parts[0], parts[1]
def _tarfile_to_pair_samples(src, handler=log_and_continue):
    """Pipeline stage: turn shard urls into image/caption samples.

    Opens each url and expands the tar members. It then groups them with
    webdataset's ``group_by_keys``, using ``split_keep_2`` to derive
    ``(prefix, field)`` from each member name. ``handler`` is forwarded to every step.
    """
    opened = url_opener(src, handler=handler)
    members = tar_file_expander(opened, handler=handler)
    return group_by_keys(members, keys=split_keep_2, handler=handler)


# Stage factory: calling it (optionally with handler=...) yields a pipeline stage.
tarfile_to_pair_samples = wds.filters.pipelinefilter(_tarfile_to_pair_samples)
def _decode_image_and_text_pairs(data, handler=log_and_continue):
    """Decode each image/caption sample in place.

    ``sample["image"]`` is decoded with ``ImageDecoder`` and ``sample["text"]`` as UTF-8.

    :param data: iterable of grouped pair samples.
    :param handler: exception handler. A truthy return skips the sample; a falsy
        return stops iteration.
    :yield: the same sample dicts with decoded ``image`` and ``text``.
    """
    image_decoder = ImageDecoder()
    for sample in data:
        try:
            sample["image"] = image_decoder(sample["image"])
            sample["text"] = sample["text"].decode("utf-8")
            yield sample
        except Exception as exn:
            # Grouped samples carry "__key__"/"__url__" (as group_by_keys_interleaved sets them),
            # not "stream"/"url", so report those for debugging.
            if isinstance(sample, dict):
                exn.args = exn.args + (sample.get("__key__"), sample.get("__url__"))
            if handler(exn):
                continue
            else:
                break


decode_image_and_text_pairs = wds.filters.pipelinefilter(_decode_image_and_text_pairs)
def get_image_caption_pairs_webdataset(
    urls,
    batch_size,
    shuffle_initial_urls_list=False,
    shuffle_before_split_by_node_buffer_size=100,
    shuffle_before_split_by_worker_buffer_size=100,
    shuffle_after_tarfile_to_samples_buffer_size=100,
    shuffle_after_batching_buffer_size=1000,
):
    """Build the webdataset pipeline for image/caption pairs.

    Stages, in order (each shuffle stage is skipped when its buffer size is ``None``):

    1. shard list
    2. shuffle, then ``split_by_node``
    3. shuffle, then ``split_by_worker``
    4. tar members to pair samples, then shuffle
    5. decode
    6. batch with ``collate_dicts`` (partial last batch kept), then shuffle batches

    :param urls: sequence of shard urls or pipe commands.
    :param batch_size: number of pairs per batch.
    :param shuffle_initial_urls_list: shuffle the shard order once up front. A
        shuffled copy is used, so the caller's list is left untouched.
    :param shuffle_before_split_by_node_buffer_size: shuffle buffer before node split, or ``None``.
    :param shuffle_before_split_by_worker_buffer_size: shuffle buffer before worker split, or ``None``.
    :param shuffle_after_tarfile_to_samples_buffer_size: sample shuffle buffer, or ``None``.
    :param shuffle_after_batching_buffer_size: batch shuffle buffer, or ``None``.
        NOTE(review): this buffer holds whole decoded batches, so large values
        may be memory-heavy.
    :return: a ``wds.DataPipeline``.
    :raises TypeError: if ``shuffle_initial_urls_list`` is set and ``urls`` is a single string.
    """
    if shuffle_initial_urls_list:
        if isinstance(urls, str):
            raise TypeError("shuffle_initial_urls_list=True requires a sequence of urls, not a single string")
        # random.shuffle works in place; shuffle a copy so the caller's list is not reordered.
        urls = list(urls)
        random.shuffle(urls)

    def maybe_shuffle(buffer_size):
        # A buffer size of None disables that shuffle stage.
        return [] if buffer_size is None else [wds.shuffle(buffer_size)]

    pipeline_list = [
        wds.SimpleShardList(urls),
        *maybe_shuffle(shuffle_before_split_by_node_buffer_size),
        wds.split_by_node,
        *maybe_shuffle(shuffle_before_split_by_worker_buffer_size),
        wds.split_by_worker,
        tarfile_to_pair_samples(handler=log_and_continue),
        *maybe_shuffle(shuffle_after_tarfile_to_samples_buffer_size),
        decode_image_and_text_pairs(),
        # TODO: check whether partial=True is needed here.
        wds.batched(batch_size, collation_fn=collate_dicts, partial=True),
        *maybe_shuffle(shuffle_after_batching_buffer_size),
    ]
    return wds.DataPipeline(pipeline_list)
def get_webdataset(
    urls,
    ds_type: DatasetTypes,
    batch_size: int,
    shuffle_initial_urls_list,
    shuffle_before_split_by_node_buffer_size,
    shuffle_before_split_by_worker_buffer_size,
    shuffle_after_tarfile_to_samples_buffer_size,
    shuffle_after_batching_buffer_size,
):
    """Build the webdataset pipeline matching ``ds_type``.

    ``DatasetTypes.WEB_DOCUMENTS`` dispatches to ``get_webdocuments_webdataset`` and
    ``DatasetTypes.IMAGE_CAPTION_PAIRS`` to ``get_image_caption_pairs_webdataset``.
    All shuffle settings are forwarded unchanged.

    :raises ValueError: for any other ``ds_type``.
    """
    if ds_type == DatasetTypes.WEB_DOCUMENTS:
        builder = get_webdocuments_webdataset
    elif ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
        builder = get_image_caption_pairs_webdataset
    else:
        raise ValueError(f"Unknown dataset type: {ds_type}")
    return builder(
        urls,
        batch_size,
        shuffle_initial_urls_list=shuffle_initial_urls_list,
        shuffle_before_split_by_node_buffer_size=shuffle_before_split_by_node_buffer_size,
        shuffle_before_split_by_worker_buffer_size=shuffle_before_split_by_worker_buffer_size,
        shuffle_after_tarfile_to_samples_buffer_size=shuffle_after_tarfile_to_samples_buffer_size,
        shuffle_after_batching_buffer_size=shuffle_after_batching_buffer_size,
    )
def check_webdataset_command(command):
    """Return whether ``command`` is an acceptable shard location.

    Anything that does not mention ``s3:/`` is accepted as-is. S3 locations must,
    after stripping surrounding whitespace, meet all of these conditions:

    * start with ``pipe:bash``;
    * invoke ``get_file.sh``;
    * end with ``.tar``.

    NOTE(review): these are plain substring/prefix/suffix checks. They do not stop
    extra shell commands from being embedded, so only use this on trusted urls.
    """
    if "s3:/" not in command:
        return True
    cleaned = command.strip()
    return (
        cleaned.startswith("pipe:bash")
        and cleaned.endswith(".tar")
        and "get_file.sh" in cleaned
    )