import tarfile
from dataclasses import dataclass
from pathlib import Path
from typing import List

import boto3
from joblib import Parallel, delayed

from realfake.utils import get_user_name, inject_args, Args


class DownloadArgs(Args):
    """Arguments for the LAION-aesthetic shard downloader (injected into ``main`` by ``inject_args``)."""

    start_idx: int = 0  # first shard index to fetch (inclusive)
    end_idx: int = 5247  # shard index upper bound (exclusive, as passed to range())
    metadata_only: bool = False  # when True, only the .parquet files are fetched; .tar shards are skipped


@inject_args
def main(args: DownloadArgs) -> None:
    """Download LAION-aesthetic shards with index in ``[start_idx, end_idx)`` from S3.

    Images from ``.tar`` shards go to ``/fsx/<user>/data/real_aes_<start>_<end>``;
    ``.parquet`` files go to a sibling ``<that dir>.metadata`` directory.
    Downloads run in parallel with joblib's multiprocessing backend.
    """
    print(args)
    bucket, prefix = "s-datasets", "laion-aesthetic/data/laion2B-en-aesthetic/"
    start_idx, end_idx = args.start_idx, args.end_idx
    keys_range = list(range(start_idx, end_idx))
    output_dir = Path(f"/fsx/{get_user_name()}/data/real_aes_{start_idx}_{end_idx}")
    if not args.metadata_only:
        # The image directory is only needed when .tar shards will be extracted.
        output_dir.mkdir(parents=True, exist_ok=True)
    metadata_dir = output_dir.parent / f"{output_dir.name}.metadata"
    metadata_dir.mkdir(parents=True, exist_ok=True)
    jobs = get_jobs(keys_range, bucket, prefix, output_dir, metadata_dir, args.metadata_only)
    Parallel(n_jobs=-1, backend="multiprocessing", verbose=100)(
        delayed(download_and_extract)(job) for job in jobs
    )


@dataclass
class Job:
    """One S3 object to download and the local directory it is written to."""

    bucket: str
    key: Path  # full S3 object key as returned by list_objects_v2 (includes the listing prefix)
    output_dir: Path


def get_jobs(
    keys_range: list,
    bucket: str,
    prefix: str,
    output_dir: Path,
    metadata_dir: Path,
    metadata_only: bool,
) -> List[Job]:
    """List ``bucket/prefix`` and build a download job for every selected shard file.

    A key is selected when its stem is a decimal integer contained in ``keys_range``.
    Selected ``.tar`` keys target ``output_dir`` (skipped when ``metadata_only``);
    selected ``.parquet`` keys target ``metadata_dir``. Other keys are ignored.

    Returns:
        The jobs, in S3 listing order.
    """
    wanted = set(keys_range)  # build once: O(1) membership per listed key instead of a list scan
    client = boto3.client("s3")
    paginator = client.get_paginator("list_objects_v2")
    jobs = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # "Contents" is absent when a page (or the whole prefix) holds no objects.
        for item in page.get("Contents", []):
            key = Path(item["Key"])
            # Non-numeric stems cannot be shard indices; skip them instead of crashing in int().
            if not key.stem.isdecimal() or int(key.stem) not in wanted:
                continue
            if key.suffix == ".tar" and not metadata_only:
                jobs.append(Job(bucket, key, output_dir))
            elif key.suffix == ".parquet":
                jobs.append(Job(bucket, key, metadata_dir))
    return jobs


def _extract_images(tar_path: Path, output_dir: Path) -> None:
    """Extract the regular-file ``.jpg`` members of ``tar_path`` into ``output_dir``.

    Members whose target file already exists are skipped, so re-runs are cheap.
    Members that would land outside ``output_dir`` (absolute or ``..`` paths) or
    that are not regular files (links, devices, directories) are also skipped.
    """
    root = output_dir.resolve()
    with tarfile.open(tar_path) as tar:
        # Pass TarInfo objects to extract(): extract(name) re-scans the member list on
        # every call, which makes a name-based loop quadratic in the archive size.
        for member in tar.getmembers():
            if not member.name.endswith(".jpg") or not member.isfile():
                continue
            target = (root / member.name).resolve()
            try:
                target.relative_to(root)
            except ValueError:
                # Refuse path traversal out of the extraction directory.
                continue
            if target.exists():
                continue
            tar.extract(member, output_dir)


def download_and_extract(job: Job) -> None:
    """Download ``job.key`` into ``job.output_dir``.

    If the object is a ``.tar`` archive, its ``.jpg`` images are extracted next to it
    and the archive is deleted afterwards. Other files (e.g. ``.parquet``) are kept as-is.
    """
    # A fresh client per call, so no client object is shared across joblib worker processes.
    client = boto3.client("s3")
    filename = job.output_dir / job.key.name
    print(f"{job.key}: downloading...")
    # boto3 documents Filename as a str; convert the Path explicitly.
    client.download_file(job.bucket, str(job.key), str(filename))
    if filename.suffix == ".tar":
        print(f"{job.key}: extracting...")
        _extract_images(filename, job.output_dir)
        filename.unlink()
    print(f"{job.key}: done!")


if __name__ == "__main__":
    main()