realfake / realfake /bin /download_s3.py
devforfu
Movie stills binary classifier
c1f3687
import tarfile
from dataclasses import dataclass
from pathlib import Path
from typing import List
import boto3
from joblib import Parallel, delayed
from realfake.utils import get_user_name, inject_args, Args
class DownloadArgs(Args):
    """Options for the S3 shard download script.

    ``main`` turns the two indices into ``range(start_idx, end_idx)``, so the
    shard interval is half-open: ``start_idx`` inclusive, ``end_idx`` exclusive.
    """

    # First shard index to fetch (inclusive).
    start_idx: int = 0
    # Shard index to stop before (exclusive).
    end_idx: int = 5247
    # When True, only the ``.parquet`` metadata shards are fetched; ``.tar`` shards are skipped.
    metadata_only: bool = False
@inject_args
def main(args: DownloadArgs) -> None:
    """Fetch shards ``[args.start_idx, args.end_idx)`` from the LAION-aesthetic S3 prefix.

    Matching ``.tar`` shards are downloaded into
    ``/fsx/<user>/data/real_aes_<start>_<end>`` and their ``.jpg`` members are
    extracted there. Matching ``.parquet`` shards go into the sibling
    directory ``real_aes_<start>_<end>.metadata``. With ``args.metadata_only``
    set, the image directory is not created and only metadata is fetched.
    """
    print(args)
    first, last = args.start_idx, args.end_idx
    bucket = "s-datasets"
    prefix = "laion-aesthetic/data/laion2B-en-aesthetic/"
    wanted_keys = list(range(first, last))

    images_dir = Path(f"/fsx/{get_user_name()}/data/real_aes_{first}_{last}")
    if not args.metadata_only:
        images_dir.mkdir(parents=True, exist_ok=True)
    meta_dir = images_dir.parent / f"{images_dir.name}.metadata"
    meta_dir.mkdir(parents=True, exist_ok=True)

    jobs = get_jobs(wanted_keys, bucket, prefix, images_dir, meta_dir, args.metadata_only)

    # n_jobs=-1 asks joblib for one worker per CPU; verbose=100 makes it report progress.
    run_parallel = Parallel(n_jobs=-1, backend="multiprocessing", verbose=100)
    run_parallel(delayed(download_and_extract)(job) for job in jobs)
@dataclass
class Job:
    """One S3 object to fetch, plus the local directory it is stored in."""

    # Name of the S3 bucket holding the object.
    bucket: str
    # Full object key inside the bucket; its last path component is used as the local filename.
    key: Path
    # Local directory the object is downloaded into (``.tar`` shards are also extracted here).
    output_dir: Path
def get_jobs(
    keys_range: list,
    bucket: str,
    prefix: str,
    output_dir: Path,
    metadata_dir: Path,
    metadata_only: bool,
) -> List[Job]:
    """List the shards under ``s3://<bucket>/<prefix>`` and build one download job per match.

    Keys are matched on their numeric stem (``.../00042.tar`` -> 42). A ``.tar``
    shard whose index is in ``keys_range`` becomes a job targeting
    ``output_dir``, unless ``metadata_only`` is set. A matching ``.parquet``
    shard becomes a job targeting ``metadata_dir``. Every other object is
    ignored, including ``.tar``/``.parquet`` keys whose stem is not an integer.

    Args:
        keys_range: Shard indices to fetch. Any iterable of ints is accepted.
        bucket: S3 bucket name.
        prefix: Key prefix to list under.
        output_dir: Destination directory for ``.tar`` shards.
        metadata_dir: Destination directory for ``.parquet`` shards.
        metadata_only: When True, ``.tar`` shards are skipped entirely.

    Returns:
        The jobs, in S3 listing order.
    """
    # Build the set once so each membership test is O(1) instead of a scan
    # over the (potentially thousands-long) index list for every key.
    wanted = set(keys_range)
    # The paginator follows ContinuationToken/IsTruncated for us.
    paginator = boto3.client("s3").get_paginator("list_objects_v2")

    jobs = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # S3 omits "Contents" entirely from a page that lists no objects.
        for item in page.get("Contents", []):
            key = Path(item["Key"])
            if key.suffix == ".tar" and not metadata_only:
                target = output_dir
            elif key.suffix == ".parquet":
                target = metadata_dir
            else:
                continue
            try:
                index = int(key.stem)
            except ValueError:
                # Not a numbered shard, so it can never fall in the requested range.
                continue
            if index in wanted:
                jobs.append(Job(bucket, key, target))
    return jobs
def _extract_jpgs(archive: Path, output_dir: Path) -> None:
    """Extract every ``.jpg`` member of the tar ``archive`` into ``output_dir``.

    Members whose destination path already exists are skipped, e.g. when
    re-running over a partially extracted directory.
    """
    with tarfile.open(archive) as tar:
        # Hand TarInfo objects straight to extract(). Passing a member *name*
        # makes tarfile resolve it via getmember(), a linear scan over all
        # members, which turns extracting a large shard into quadratic work.
        for member in tar.getmembers():
            if not member.name.endswith(".jpg"):
                continue
            if (output_dir / member.name).exists():
                continue
            # NOTE(review): member names are not sanitized, so this trusts the
            # bucket's contents (no absolute or "../" paths). Where tarfile
            # extraction filters are available, consider filter="data".
            tar.extract(member, output_dir)


def download_and_extract(job: Job) -> None:
    """Download ``job.key`` into ``job.output_dir``; for a ``.tar`` shard, unpack its images.

    The object is saved as ``job.output_dir / job.key.name``. A ``.tar`` file
    has its ``.jpg`` members extracted next to it and is then deleted. Any
    other file (e.g. ``.parquet`` metadata) is left in place as downloaded.
    """
    client = boto3.client("s3")
    filename = job.output_dir / job.key.name
    print(f"{job.key}: downloading...")
    # boto3 documents Filename as a str, and not every boto3/s3transfer
    # release accepts a pathlib.Path here, so convert explicitly.
    client.download_file(job.bucket, str(job.key), str(filename))
    if filename.suffix == ".tar":
        print(f"{job.key}: extracting...")
        _extract_jpgs(filename, job.output_dir)
        # Only the extracted images are kept; remove the downloaded archive.
        filename.unlink()
    print(f"{job.key}: done!")
if __name__ == "__main__":
    # No arguments are passed explicitly: main is wrapped by @inject_args,
    # which presumably builds the DownloadArgs instance (e.g. from the CLI).
    main()