devforfu commited on
Commit
c1f3687
1 Parent(s): 4334bbd

Movie stills binary classifier

Browse files
metadata/movies_plus.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:750d4828d1d5051b390519fa1e964b45305a5ffae3d1ef50b783568452bc13fa
3
+ size 5992428
realfake/bin/download_s3.py CHANGED
@@ -1,23 +1,34 @@
1
- from __future__ import annotations
2
  import tarfile
3
  from dataclasses import dataclass
4
  from pathlib import Path
 
5
 
6
  import boto3
7
  from joblib import Parallel, delayed
8
 
9
- from realfake.utils import get_user_name
10
 
11
 
12
- def main() -> None:
 
 
 
 
 
 
 
 
13
  bucket, prefix = "s-datasets", "laion-aesthetic/data/laion2B-en-aesthetic/"
14
- start_idx, end_idx = 400, 700
15
  keys_range = list(range(start_idx, end_idx))
16
 
17
  output_dir = Path(f"/fsx/{get_user_name()}/data/real_aes_{start_idx}_{end_idx}")
18
- output_dir.mkdir(parents=True, exist_ok=True)
 
 
 
19
 
20
- jobs = get_jobs(keys_range, bucket, prefix, output_dir)
21
 
22
  Parallel(n_jobs=-1, backend="multiprocessing", verbose=100)(delayed(download_and_extract)(job) for job in jobs)
23
 
@@ -29,7 +40,14 @@ class Job:
29
  output_dir: Path
30
 
31
 
32
- def get_jobs(keys_range: list, bucket: str, prefix: str, output_dir: Path) -> list[Job]:
 
 
 
 
 
 
 
33
  client = boto3.client("s3")
34
 
35
  token, jobs = None, []
@@ -41,8 +59,10 @@ def get_jobs(keys_range: list, bucket: str, prefix: str, output_dir: Path) -> li
41
 
42
  for item in response.get("Contents"):
43
  key = Path(item["Key"])
44
- if key.suffix == ".tar" and int(key.stem) in keys_range:
45
  jobs.append(Job(bucket, key, output_dir))
 
 
46
 
47
  if not response["IsTruncated"]: break
48
  token = response["NextContinuationToken"]
@@ -52,19 +72,24 @@ def get_jobs(keys_range: list, bucket: str, prefix: str, output_dir: Path) -> li
52
 
53
  def download_and_extract(job: Job) -> None:
54
  client = boto3.client("s3")
55
- tar_file = job.output_dir / job.key.name
56
 
57
  print(f"{job.key}: downloading...")
58
- client.download_file(job.bucket, str(job.key), tar_file)
59
-
60
- print(f"{job.key}: extracting...")
61
- with tarfile.open(tar_file) as tar:
62
- for name in tar.getnames():
63
- if name.endswith(".jpg"):
64
- tar.extract(name, job.output_dir)
65
-
 
 
 
 
 
66
  print(f"{job.key}: done!")
67
- tar_file.unlink()
68
 
69
 
70
  if __name__ == "__main__":
 
 
1
  import tarfile
2
  from dataclasses import dataclass
3
  from pathlib import Path
4
+ from typing import List
5
 
6
  import boto3
7
  from joblib import Parallel, delayed
8
 
9
+ from realfake.utils import get_user_name, inject_args, Args
10
 
11
 
12
class DownloadArgs(Args):
    """Command-line arguments for the S3 shard-download script."""

    # Shard-index range to fetch: start inclusive, end exclusive (fed to range()).
    start_idx: int = 0
    # 5247 is presumably the total shard count of the dataset -- TODO confirm.
    end_idx: int = 5247
    # When True, only .parquet metadata objects are queued; .tar shards are skipped.
    metadata_only: bool = False
16
+
17
+
18
@inject_args
def main(args: DownloadArgs) -> None:
    """Download the selected range of LAION-aesthetic S3 shards in parallel.

    Builds the image output directory (unless only metadata is requested) and a
    sibling ".metadata" directory, collects the matching download jobs, and runs
    them across all cores with joblib.
    """
    print(args)
    bucket = "s-datasets"
    prefix = "laion-aesthetic/data/laion2B-en-aesthetic/"
    keys_range = list(range(args.start_idx, args.end_idx))

    output_dir = Path(f"/fsx/{get_user_name()}/data/real_aes_{args.start_idx}_{args.end_idx}")
    if not args.metadata_only:
        output_dir.mkdir(parents=True, exist_ok=True)
    # Parquet metadata lives next to the image folder, suffixed with ".metadata".
    metadata_dir = output_dir.parent / f"{output_dir.name}.metadata"
    metadata_dir.mkdir(parents=True, exist_ok=True)

    jobs = get_jobs(keys_range, bucket, prefix, output_dir, metadata_dir, args.metadata_only)

    runner = Parallel(n_jobs=-1, backend="multiprocessing", verbose=100)
    runner(delayed(download_and_extract)(job) for job in jobs)
34
 
 
40
  output_dir: Path
41
 
42
 
43
+ def get_jobs(
44
+ keys_range: list,
45
+ bucket: str,
46
+ prefix: str,
47
+ output_dir: Path,
48
+ metadata_dir: Path,
49
+ metadata_only: bool,
50
+ ) -> List[Job]:
51
  client = boto3.client("s3")
52
 
53
  token, jobs = None, []
 
59
 
60
  for item in response.get("Contents"):
61
  key = Path(item["Key"])
62
+ if key.suffix == ".tar" and int(key.stem) in keys_range and not metadata_only:
63
  jobs.append(Job(bucket, key, output_dir))
64
+ elif key.suffix == ".parquet" and int(key.stem) in keys_range:
65
+ jobs.append(Job(bucket, key, metadata_dir))
66
 
67
  if not response["IsTruncated"]: break
68
  token = response["NextContinuationToken"]
 
72
 
73
def download_and_extract(job: Job) -> None:
    """Download one S3 object; if it is a .tar shard, extract its .jpg members.

    Tar archives are deleted once extracted; any other object type (e.g. the
    .parquet metadata files) is kept on disk as downloaded.

    Args:
        job: Bucket/key pair plus the directory to download into.
    """
    client = boto3.client("s3")
    filename = job.output_dir / job.key.name

    print(f"{job.key}: downloading...")
    # boto3's download_file documents Filename as a string path, so convert
    # the Path explicitly (the key is likewise stringified).
    client.download_file(job.bucket, str(job.key), str(filename))

    if filename.suffix == ".tar":
        print(f"{job.key}: extracting...")
        with tarfile.open(filename) as tar:
            for name in tar.getnames():
                # Skip members already extracted by a previous (interrupted) run.
                if (job.output_dir / name).exists():
                    continue
                if name.endswith(".jpg"):
                    tar.extract(name, job.output_dir)
        # The archive itself is no longer needed once the images are on disk.
        filename.unlink()

    print(f"{job.key}: done!")
92
+
93
 
94
 
95
  if __name__ == "__main__":
realfake/utils.py CHANGED
@@ -122,5 +122,8 @@ def find_latest_checkpoint(dirname: Path) -> Path:
122
  return latest
123
 
124
 
125
- def list_files(dirname: Path, exts: list[str]) -> list:
126
- return [fn for fn in Path(dirname).iterdir() for ext in exts if fn.match(f"*.{ext}")]
 
 
 
 
122
  return latest
123
 
124
 
125
+ def list_files(dirname: Path, exts: list[str] | None = None) -> list:
126
+ files = Path(dirname).iterdir()
127
+ if not exts:
128
+ return list(files)
129
+ return [fn for fn in files for ext in exts if fn.match(f"*.{ext}")]
submit_movie.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash -l

# SLURM SUBMIT SCRIPT
# Launches the movie-stills binary-classifier training run on one 8-GPU node.
#SBATCH --partition=g40
#SBATCH --nodes=1
#SBATCH --gpus=8
#SBATCH --cpus-per-gpu=6
#SBATCH --job-name=realfake
#SBATCH --comment=laion
# Send SIGUSR1 90 seconds before the time limit -- presumably so the trainer
# can checkpoint/requeue; TODO confirm the training script handles the signal.
#SBATCH --signal=SIGUSR1@90

source "${HOME}/venv/bin/activate"

export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
# Make the project package importable by the training script.
export PYTHONPATH="${HOME}/realfake"

echo "Working directory: `pwd`"

# Train on the movies_plus.jsonl metadata; short flags (-e, -fe, -bs, ...) are
# defined by train_cluster.py -- see that script for their meaning.
srun python3 realfake/train_cluster.py \
    -jf "${HOME}/realfake/metadata/movies_plus.jsonl" \
    -mn convnext_small -e=40 -fe=40 -bs=128 -wl=1 -fw=0.08 \
    --acceleratorparams.devices=8 \
    --acceleratorparams.strategy=ddp_find_unused_parameters_false