File size: 2,349 Bytes
5231633 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import json
import subprocess
import yaml
import os
from .bucketeer import Bucketeer
class MultiFilter():
    """Callable predicate that checks a sample's JSON metadata against rules.

    ``rules`` maps a field name (or a tuple of field names) to a callable.
    A plain key passes that single field to its callable; a tuple key passes
    the named fields as positional arguments, in tuple order. A sample is
    accepted only when every rule returns a truthy value.

    ``default`` is the verdict returned when a sample cannot be evaluated at
    all (no ``'json'`` entry, undecodable JSON, a missing field, or a rule
    that raises).
    """

    def __init__(self, rules, default=False):
        self.rules = rules
        self.default = default

    def __call__(self, x):
        """Return True if ``x['json']`` satisfies every rule.

        Args:
            x: Mapping with a ``'json'`` entry holding either an already
                decoded mapping or the raw JSON as ``bytes``.

        Returns:
            ``True``/``False`` from the rules, or ``self.default`` when
            evaluating the sample raised an exception.
        """
        try:
            x_json = x['json']
            if isinstance(x_json, bytes):
                x_json = json.loads(x_json)
            # Evaluate every rule eagerly (no short-circuit) so an error in
            # any rule consistently routes the sample to the fallback.
            validations = [
                self._apply(rule, key, x_json)
                for key, rule in self.rules.items()
            ]
            return all(validations)
        except Exception:
            # Deliberate best-effort: a malformed sample must not abort the
            # data pipeline. Honour the configured fallback instead of a
            # hard-coded False, which left ``default`` without any effect.
            return self.default

    @staticmethod
    def _apply(rule, key, x_json):
        """Call ``rule`` with the field(s) of ``x_json`` selected by ``key``."""
        if isinstance(key, tuple):
            return rule(*[x_json[field] for field in key])
        return rule(x_json[key])
class MultiGetter():
    """Callable that extracts values from JSON metadata via a rule table.

    ``rules`` maps a field name (or a tuple of field names) to a callable.
    A plain key feeds that one field to its callable; a tuple key feeds the
    named fields as positional arguments, in tuple order.
    """

    def __init__(self, rules):
        self.rules = rules

    def _evaluate(self, key, rule, record):
        """Run one rule against the field(s) of ``record`` named by ``key``."""
        if not isinstance(key, tuple):
            return rule(record[key])
        selected = [record[name] for name in key]
        return rule(*selected)

    def __call__(self, x_json):
        """Apply every rule to ``x_json`` and return the collected results.

        ``x_json`` may be an already decoded mapping or raw JSON ``bytes``.
        The results come back as a list in rule order, except that exactly
        one rule yields its bare value instead of a one-element list.
        Lookup or rule errors are not caught here.
        """
        record = json.loads(x_json) if isinstance(x_json, bytes) else x_json
        results = [
            self._evaluate(key, rule, record)
            for key, rule in self.rules.items()
        ]
        # A single rule is unwrapped so callers get the value directly.
        return results[0] if len(results) == 1 else results
def setup_webdataset_path(paths, cache_path=None):
    """Return a ``pipe:aws s3 cp { ... } -`` string listing tar shards.

    Args:
        paths: A single path or a list of paths. An entry ending in ``.tar``
            is used as-is; any other entry is listed recursively with
            ``aws s3 ls`` and every ``.tar`` result found under it is used.
        cache_path: Optional YAML file holding the resolved shard list. If
            the file exists it is loaded and ``paths`` is not consulted;
            otherwise the freshly resolved list is written to it. With
            ``None`` (the default) no cache is read or written.

    Returns:
        ``"pipe:aws s3 cp { <shards joined by ','> } -"``.

    Raises:
        subprocess.CalledProcessError: If the listing shell command exits
            with a non-zero status.
    """
    if cache_path is not None and os.path.exists(cache_path):
        with open(cache_path, 'r', encoding='utf-8') as file:
            # An empty cache file loads as None; treat it as "no shards"
            # instead of failing later when the list is joined.
            tar_paths = yaml.safe_load(file) or []
    else:
        if isinstance(paths, str):
            paths = [paths]
        tar_paths = []
        for path in paths:
            if path.strip().endswith(".tar"):
                # Avoid looking up s3 if we already have a tar file
                tar_paths.append(path)
                continue
            # First three '/'-separated parts, e.g. "s3://bucket".
            bucket = "/".join(path.split("/")[:3])
            # NOTE(review): ``path`` is interpolated unquoted into a shell
            # command line, so only trusted values may be passed here. The
            # pipeline's exit status is awk's, so ``check=True`` does not
            # surface a failing ``aws s3 ls``.
            command = f"aws s3 ls {path} --recursive | awk '{{print $4}}'"
            result = subprocess.run(
                command, stdout=subprocess.PIPE, shell=True, check=True
            )
            # awk emits the fourth column of each listing line (presumably
            # the object key); prefix it with the bucket to form a full path.
            listed = result.stdout.decode('utf-8').split()
            tar_paths.extend(
                f"{bucket}/{name}" for name in listed if name.endswith(".tar")
            )

        # Only persist the listing when a cache file was requested; calling
        # open(None, 'w') would raise TypeError for the default argument.
        if cache_path is not None:
            with open(cache_path, 'w', encoding='utf-8') as outfile:
                yaml.dump(tar_paths, outfile, default_flow_style=False)

    tar_paths_str = ",".join(str(p) for p in tar_paths)
    return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
|