|
import tarfile |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import List |
|
|
|
import boto3 |
|
from joblib import Parallel, delayed |
|
|
|
from realfake.utils import get_user_name, inject_args, Args |
|
|
|
|
|
# Options consumed by ``main``. The two indices select shards as a half-open
# range: ``start_idx <= shard < end_idx``.
# NOTE(review): how ``Args`` populates these fields is defined in
# ``realfake.utils`` and is not visible here.
class DownloadArgs(Args):
    # Index of the first shard to fetch (inclusive).
    start_idx: int = 0
    # Index at which to stop (exclusive).
    end_idx: int = 5247
    # When true, the ``.tar`` image shards are skipped and only ``.parquet``
    # metadata files are fetched.
    metadata_only: bool = False
|
|
|
|
|
@inject_args
def main(args: DownloadArgs) -> None:
    """Fetch the selected shards from S3 into the user's data directory.

    Prints the arguments, prepares the image directory (unless only metadata
    was requested) and its sibling ``.metadata`` directory, lists the matching
    objects and hands each one to ``download_and_extract`` through a joblib
    multiprocessing pool.

    Args:
        args: Shard range and the metadata-only switch.
    """
    print(args)

    bucket = "s-datasets"
    prefix = "laion-aesthetic/data/laion2B-en-aesthetic/"
    first, last = args.start_idx, args.end_idx
    # Half-open range: ``last`` itself is not fetched.
    shard_ids = list(range(first, last))

    images_dir = Path(f"/fsx/{get_user_name()}/data/real_aes_{first}_{last}")
    if not args.metadata_only:
        images_dir.mkdir(parents=True, exist_ok=True)

    # Metadata lives next to the image directory and is always created.
    metadata_dir = images_dir.parent / f"{images_dir.name}.metadata"
    metadata_dir.mkdir(parents=True, exist_ok=True)

    jobs = get_jobs(
        shard_ids, bucket, prefix, images_dir, metadata_dir, args.metadata_only
    )

    tasks = (delayed(download_and_extract)(job) for job in jobs)
    Parallel(n_jobs=-1, backend="multiprocessing", verbose=100)(tasks)
|
|
|
|
|
@dataclass
class Job:
    """A single S3 object to download.

    Attributes:
        bucket: Name of the S3 bucket holding the object.
        key: Object key within the bucket, kept as a ``Path`` so its name,
            stem and suffix can be inspected.
        output_dir: Local directory the object is downloaded into.
    """

    bucket: str
    key: Path
    output_dir: Path
|
|
|
|
|
def get_jobs(
    keys_range: list,
    bucket: str,
    prefix: str,
    output_dir: Path,
    metadata_dir: Path,
    metadata_only: bool,
) -> List[Job]:
    """List the objects under an S3 prefix and build a job for each wanted one.

    Every listed key is inspected by suffix: ``.tar`` objects become jobs
    targeting ``output_dir`` (unless ``metadata_only`` is set) and
    ``.parquet`` objects become jobs targeting ``metadata_dir``. A key is
    only selected when its stem, parsed as an integer, is in ``keys_range``.
    Keys with any other suffix are ignored.

    Args:
        keys_range: Integer shard indices to select.
        bucket: S3 bucket to list.
        prefix: Key prefix that restricts the listing.
        output_dir: Destination directory for ``.tar`` objects.
        metadata_dir: Destination directory for ``.parquet`` objects.
        metadata_only: When true, ``.tar`` objects are not selected.

    Returns:
        The jobs in listing order; empty when nothing matches.

    Raises:
        ValueError: If a ``.tar`` or ``.parquet`` key has a stem that is not
            an integer.
    """
    client = boto3.client("s3")
    # Build the lookup once: set membership is constant-time, whereas the
    # incoming list would be scanned for every listed key.
    wanted = set(keys_range)

    token, jobs = None, []

    # Follow continuation tokens until the listing is no longer truncated.
    while True:
        conf = dict(Bucket=bucket, Prefix=prefix)
        if token is not None:
            conf["ContinuationToken"] = token
        response = client.list_objects_v2(**conf)

        # "Contents" is missing from the response when no object matches, so
        # fall back to an empty page instead of iterating over ``None``.
        for item in response.get("Contents", []):
            key = Path(item["Key"])
            if key.suffix == ".tar" and int(key.stem) in wanted and not metadata_only:
                jobs.append(Job(bucket, key, output_dir))
            elif key.suffix == ".parquet" and int(key.stem) in wanted:
                jobs.append(Job(bucket, key, metadata_dir))

        if not response["IsTruncated"]:
            break
        token = response["NextContinuationToken"]

    return jobs
|
|
|
|
|
def download_and_extract(job: Job) -> None:
    """Download one S3 object and, for tar shards, unpack its JPEG members.

    The object ``job.key`` is fetched from ``job.bucket`` into
    ``job.output_dir`` under its base name. When the downloaded file is a
    ``.tar`` archive, every member whose name ends in ``.jpg`` and that does
    not already exist under ``job.output_dir`` is extracted, and the archive
    is then deleted. Files with any other suffix are kept as downloaded.

    Args:
        job: Bucket, object key and destination directory to process.
    """
    client = boto3.client("s3")
    filename = job.output_dir / job.key.name

    print(f"{job.key}: downloading...")
    # boto3 documents ``Filename`` as ``str``; older releases reject
    # path-like objects, so hand over a plain string.
    client.download_file(job.bucket, str(job.key), str(filename))

    if filename.suffix == ".tar":
        print(f"{job.key}: extracting...")
        # NOTE(review): member names are used as-is and are not checked for
        # path traversal; this assumes the bucket contents are trusted.
        with tarfile.open(filename) as tar:
            # Iterate the members themselves: extracting by name would look
            # the member up again on every call, which is quadratic in the
            # number of archive entries.
            for member in tar:
                if not member.name.endswith(".jpg"):
                    continue
                if (job.output_dir / member.name).exists():
                    continue
                tar.extract(member, job.output_dir)
        # NOTE(review): only archives are removed after extraction; other
        # downloads (e.g. parquet metadata) are kept — confirm intent.
        filename.unlink()

    print(f"{job.key}: done!")
|
|
|
|
|
|
|
# Script entry point. ``main`` is called without arguments here; presumably
# the ``inject_args`` decorator supplies the ``DownloadArgs`` instance
# (confirm in ``realfake.utils``).
if __name__ == "__main__":
    main()
|
|