import sys import io import os import re import json import tarfile from functools import partial import webdataset as wds from webdataset import ResampledShards, DataPipeline, tarfile_to_samples from webdataset.filters import pipelinefilter from webdataset.tariterators import url_opener, group_by_keys from webdataset.handlers import reraise_exception from webdataset.gopen import gopen_schemes, gopen def pytorch_worker_info(group=None): # sourcery skip: use-contextlib-suppress """Return node and worker info for PyTorch and some distributed environments.""" rank = 0 world_size = 1 worker = 0 num_workers = 1 try: import torch.distributed if torch.distributed.is_available() and torch.distributed.is_initialized(): group = group or torch.distributed.group.WORLD rank = torch.distributed.get_rank(group=group) world_size = torch.distributed.get_world_size(group=group) except ModuleNotFoundError: pass try: import torch.utils.data worker_info = torch.utils.data.get_worker_info() if worker_info is not None: worker = worker_info.id num_workers = worker_info.num_workers except ModuleNotFoundError: pass return rank, world_size, worker, num_workers def pytorch_worker_seed(group=None): """Compute a distinct, deterministic RNG seed for each worker and node.""" rank, world_size, worker, num_workers = pytorch_worker_info(group=group) return rank * 1000 + worker def worker_seed_sat(group=None, seed=0): return pytorch_worker_seed(group=group) + seed * 23 class ConfiguredResampledShards(ResampledShards): def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True): from sat.helpers import print_rank0 try: from megatron.core.parallel_state import get_data_parallel_group group = get_data_parallel_group() print_rank0("Using megatron data parallel group.") except: from sat.mpu import get_data_parallel_group try: group = get_data_parallel_group() print_rank0("Using sat data parallel group.") except AssertionError: group = None print_rank0("No data parallel group is specified!") worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed) super().__init__(urls, nshards, worker_seed_sat_this, deterministic) class SimpleDistributedWebDataset(DataPipeline): def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000): # set shuffle_buffer = 1 to disable it, model-parallel will be different due to shuffle try: from sat.mpu import get_model_parallel_world_size if get_model_parallel_world_size() > 1: shuffle_buffer = 1 except Exception: pass super().__init__( ConfiguredResampledShards(path, seed), # Lots of shards are recommended, or not evenly tarfile_to_samples(), wds.shuffle(shuffle_buffer), process_fn, ) def tar_file_iterator_with_meta( fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None ): """Iterate over tar file, yielding filename, content pairs for the given tar stream. :param fileobj: byte stream suitable for tarfile :param meta_names: key of different items in meta file :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)") """ stream = tarfile.open(fileobj=fileobj, mode="r|*") data_dir, filename = fileobj.name.rsplit("/", 1) meta_data = {} # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}} if meta_stream is None: meta_file_name = filename.split(".")[0] + ".meta.jsonl" meta_path = os.path.join(data_dir, meta_file_name) if os.path.exists(meta_path): meta_stream = open(meta_path, "r") else: meta_file_name = meta_stream.name if meta_stream is not None: for lineno, line in enumerate(meta_stream): meta_list = [] try: meta_list.append(json.loads(line)) except Exception as exn: from sat.helpers import print_rank0 print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG") continue for item in meta_list: if not item["key"] in meta_data: meta_data[item["key"]] = {} for meta_name in meta_names: if meta_name in item: meta_data[item["key"]][meta_name] = item[meta_name] meta_stream.close() try: for tarinfo in stream: fname = tarinfo.name try: if not tarinfo.isreg(): continue if fname is None: continue if "/" not in fname and fname.startswith("__") and fname.endswith("__"): # skipping metadata for now continue if skip_meta is not None and re.match(skip_meta, fname): continue if fname.endswith(".txt") and suffix is not None: data = (stream.extractfile(tarinfo).read().decode() + suffix).encode() else: data = stream.extractfile(tarinfo).read() result = dict(fname=fname, data=data) yield result if fname.endswith(".id"): fid = fname.split(".")[0] if "-$#%@&" in fid: sfid = fid.split("-$#%@&")[0] else: sfid = fid meta_data_fid = meta_data.get(sfid, {}) for meta_name in meta_names: meta_fname = fid + "." + meta_name meta = meta_data_fid.get(meta_name, None) yield dict(fname=meta_fname, data=meta) stream.members = [] except Exception as exn: if hasattr(exn, "args") and len(exn.args) > 0: exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] if handler(exn): continue else: break except Exception as exn: print(exn) del stream def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception): """Expand a stream of open tar files into a stream of tar file contents. This returns an iterator over (filename, file_contents). """ for source in data: url = source["url"] try: assert isinstance(source, dict) assert "stream" in source for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]): assert isinstance(sample, dict) and "data" in sample and "fname" in sample sample["__url__"] = url yield sample except Exception as exn: exn.args = exn.args + (source.get("stream"), source.get("url")) if handler(exn): continue else: break def url_opener( data, handler, **kw, ): """Open URLs and yield a stream of url+stream pairs. Args: data: iterator over dict(url=...) handler: exception handler. kw: keyword arguments for gopen.gopen. Yields: a stream of url+stream pairs. """ for sample in data: assert isinstance(sample, dict), sample assert "url" in sample url = sample["url"] try: stream = gopen(url, **kw) if hasattr(stream, "meta_stream"): meta_stream = stream.meta_stream del stream.meta_stream else: meta_stream = None sample.update(stream=stream, meta_stream=meta_stream) yield sample except Exception as exn: exn.args = exn.args + (url,) if handler(exn): continue else: break def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception): streams = url_opener(src, handler=handler) files = tar_file_expander_with_meta(streams, meta_names, handler) samples = group_by_keys(files, handler=handler) return samples class MetaDistributedWebDataset(DataPipeline): """WebDataset with meta information files Extra Format: in webdataset (tar), for each sample there is a '.id'; for each tar file, there is a '.meta.jsonl' file with the same name; The '.meta.jsonl' file contains lines of json objects, each with a 'key' field to match '.id'. """ def __init__( self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None ): # os.environ['WDS_SHOW_SEED'] = '1' import torch if torch.distributed.get_rank() == 0: if include_dirs is not None: # /webdatasets/A,/webdatasets/C other_paths = [] include_dirs = include_dirs.split(",") for include_dir in include_dirs: if "*" in include_dir: include_dir, n = include_dir.split("*") n = int(n) else: n = 1 for cur_dir, dirs, files in os.walk(include_dir): for f in files: if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0: # other_paths.append(os.path.join(cur_dir,f)) other_paths.extend([os.path.join(cur_dir, f)] * n) # print(f'Adding dataset paths {",".join(other_paths)}') from braceexpand import braceexpand if len(path) > 0: # not "" path = list(braceexpand(path)) + other_paths else: path = other_paths path = [path] else: path = [ None, ] torch.distributed.broadcast_object_list(path, src=0) path = path[0] tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names) tarfile_to_samples = pipelinefilter(tarfile_samples) # if model parallel, shuffle_buffer should be 1 to disable shuffling try: from sat.mpu import get_model_parallel_world_size if get_model_parallel_world_size() > 1: shuffle_buffer = 1 except Exception: pass super().__init__( ConfiguredResampledShards(path, seed, nshards=nshards), tarfile_to_samples(), wds.shuffle(shuffle_buffer), process_fn, ) # rclone support from webdataset.gopen import Pipe def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32): """Open a URL with `curl`. :param url: rclone url, e.g. data:bucket1/foo.tar. data should be configured. :param mode: file mode :param bufsize: buffer size """ url = url.replace("rclone://", "") if mode[0] == "r": cmd = f"rclone cat '{url}'" return Pipe( cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141, 23], ) # skipcq: BAN-B604 elif mode[0] == "w": cmd = f"rclone cp - '{url}'" return Pipe( cmd, mode=mode, shell=True, bufsize=bufsize, ignore_status=[141, 26], ) # skipcq: BAN-B604 else: raise ValueError(f"{mode}: unknown mode") def gopen_boto3(url, mode="rb", bufsize=8192 * 2): """Open a URL with boto3 API. :param url: boto3 url, e.g. boto3://bucket1/foo.tar. data should be configured. :param mode: file mode :param bufsize: buffer size """ import boto3 # boto3.set_stream_logger('botocore', level='DEBUG') if url.startswith("boto3://"): url = url.replace("boto3://", "") need_meta = False else: url = url.replace("metaboto3://", "") need_meta = True endpoint_url = os.environ.get("S3_ENDPOINT_URL", None) access_key = os.environ.get("S3_ACCESS_KEY_ID", None) secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None) if mode[0] == "r": s3_client = boto3.client( "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key ) bucket, key = url.split("/", 1) if need_meta: # download a meta json meta_file_key = key.split(".")[0] + ".meta.jsonl" meta_stream = io.BytesIO() s3_client.download_fileobj(bucket, meta_file_key, meta_stream) meta_stream.seek(0) meta_stream.name = meta_file_key else: meta_stream = None # data tar stream response = s3_client.get_object(Bucket=bucket, Key=key) # Range optional response["Body"].name = key # actually not used response["Body"].meta_stream = meta_stream return response["Body"] else: raise ValueError(f"{mode}: unknown mode") gopen_schemes["rclone"] = gopen_rclone gopen_schemes["boto3"] = gopen_boto3 gopen_schemes["metaboto3"] = gopen_boto3