Add files using upload-large-folder tool
- REG/__pycache__/dataset.cpython-312.pyc +0 -0
- REG/__pycache__/loss.cpython-312.pyc +0 -0
- REG/__pycache__/loss.cpython-313.pyc +0 -0
- REG/__pycache__/sample_from_checkpoint.cpython-313.pyc +0 -0
- REG/__pycache__/sample_from_checkpoint_ddp.cpython-313.pyc +0 -0
- REG/__pycache__/samplers.cpython-312.pyc +0 -0
- REG/__pycache__/samplers.cpython-313.pyc +0 -0
- REG/__pycache__/train.cpython-313.pyc +0 -0
- REG/__pycache__/utils.cpython-312.pyc +0 -0
- REG/models/__pycache__/mocov3_vit.cpython-310.pyc +0 -0
- REG/models/__pycache__/mocov3_vit.cpython-312.pyc +0 -0
- REG/models/__pycache__/sit.cpython-310.pyc +0 -0
- REG/models/__pycache__/sit.cpython-312.pyc +0 -0
- REG/preprocessing/README.md +25 -0
- REG/preprocessing/dataset_image_encoder.py +353 -0
- REG/preprocessing/dataset_prepare_convert.sh +11 -0
- REG/preprocessing/dataset_prepare_encode.sh +9 -0
- REG/preprocessing/dataset_tools.py +422 -0
- REG/preprocessing/dnnlib/__init__.py +8 -0
- REG/preprocessing/dnnlib/__pycache__/__init__.cpython-312.pyc +0 -0
- REG/preprocessing/dnnlib/__pycache__/util.cpython-312.pyc +0 -0
- REG/preprocessing/dnnlib/util.py +485 -0
- REG/preprocessing/encoders.py +103 -0
- REG/preprocessing/torch_utils/__init__.py +8 -0
- REG/preprocessing/torch_utils/distributed.py +140 -0
- REG/preprocessing/torch_utils/misc.py +277 -0
- REG/preprocessing/torch_utils/persistence.py +257 -0
- REG/preprocessing/torch_utils/training_stats.py +283 -0
- REG/wandb/debug-internal.log +21 -0
- REG/wandb/debug.log +22 -0
- REG/wandb/run-20260322_141726-2yw08kz9/files/config.yaml +203 -0
- REG/wandb/run-20260322_141726-2yw08kz9/files/output.log +27 -0
- REG/wandb/run-20260322_141726-2yw08kz9/files/requirements.txt +168 -0
- REG/wandb/run-20260322_141726-2yw08kz9/files/wandb-metadata.json +101 -0
- REG/wandb/run-20260322_141726-2yw08kz9/files/wandb-summary.json +1 -0
- REG/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log +7 -0
- REG/wandb/run-20260322_141726-2yw08kz9/logs/debug.log +22 -0
- REG/wandb/run-20260322_141726-2yw08kz9/run-2yw08kz9.wandb +0 -0
- REG/wandb/run-20260322_141833-vm0y8t9t/files/output.log +0 -0
- REG/wandb/run-20260322_141833-vm0y8t9t/files/requirements.txt +168 -0
- REG/wandb/run-20260322_141833-vm0y8t9t/files/wandb-metadata.json +101 -0
- REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log +6 -0
- REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log +20 -0
- REG/wandb/run-20260322_150022-yhxc5cgu/files/output.log +19 -0
- REG/wandb/run-20260322_150022-yhxc5cgu/files/requirements.txt +168 -0
- REG/wandb/run-20260322_150022-yhxc5cgu/files/wandb-metadata.json +101 -0
- REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug-internal.log +7 -0
- REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug.log +22 -0
- REG/wandb/run-20260322_150022-yhxc5cgu/run-yhxc5cgu.wandb +0 -0
- REG/wandb/run-20260322_150443-e3yw9ii4/run-e3yw9ii4.wandb +0 -0
REG/__pycache__/dataset.cpython-312.pyc ADDED (binary, 10.3 kB)
REG/__pycache__/loss.cpython-312.pyc ADDED (binary, 9.98 kB)
REG/__pycache__/loss.cpython-313.pyc ADDED (binary, 8.75 kB)
REG/__pycache__/sample_from_checkpoint.cpython-313.pyc ADDED (binary, 15.7 kB)
REG/__pycache__/sample_from_checkpoint_ddp.cpython-313.pyc ADDED (binary, 22 kB)
REG/__pycache__/samplers.cpython-312.pyc ADDED (binary, 31.3 kB)
REG/__pycache__/samplers.cpython-313.pyc ADDED (binary, 31.6 kB)
REG/__pycache__/train.cpython-313.pyc ADDED (binary, 33.4 kB)
REG/__pycache__/utils.cpython-312.pyc ADDED (binary, 10.8 kB)
REG/models/__pycache__/mocov3_vit.cpython-310.pyc ADDED (binary, 6.5 kB)
REG/models/__pycache__/mocov3_vit.cpython-312.pyc ADDED (binary, 12.4 kB)
REG/models/__pycache__/sit.cpython-310.pyc ADDED (binary, 13.4 kB)
REG/models/__pycache__/sit.cpython-312.pyc ADDED (binary, 22.9 kB)
REG/preprocessing/README.md
ADDED

<h1 align="center"> Preprocessing Guide </h1>

#### Dataset download

We follow the preprocessing code used in [edm2](https://github.com/NVlabs/edm2). We made several edits to this code: (1) we removed everything unnecessary for preprocessing, since the code is only used for preprocessing; (2) we use the [-1, 1] range for inputs to the Stable Diffusion VAE (as in DiT and SiT), unlike edm2, which uses the [0, 1] range; and (3) we preprocess to 256x256 resolution (or 512x512 resolution).

After downloading ImageNet, please run the following scripts (update 256x256 to 512x512 if you want to run experiments at 512x512 resolution).

Convert raw ImageNet data to a ZIP archive at 256x256 resolution:
```bash
bash dataset_prepare_convert.sh
```

Convert the pixel data to VAE latents:
```bash
bash dataset_prepare_encode.sh
```

Here, `YOUR_DOWNLOAD_PATH` is the directory where you downloaded the dataset, and `TARGET_PATH` is the directory where the preprocessed images and the corresponding compressed latent vectors will be saved. This directory is used by the experiment scripts.

## Acknowledgement

This code is mainly built upon the [edm2](https://github.com/NVlabs/edm2) repository.
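The [-1, 1] input range noted in edit (2) above differs from edm2's [0, 1] range by a simple affine map. A minimal sketch of the intended pixel normalization, assuming uint8 input (the helper name is illustrative, not from this repo):

```python
import numpy as np

def to_vae_input(img_uint8: np.ndarray) -> np.ndarray:
    # Map uint8 pixels in [0, 255] to [-1, 1], as expected by the
    # Stable Diffusion VAE in DiT/SiT-style pipelines.
    # (edm2 would instead use img_uint8 / 255.0 for a [0, 1] range.)
    return img_uint8.astype(np.float32) / 127.5 - 1.0
```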
REG/preprocessing/dataset_image_encoder.py
ADDED

```python
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Tool for creating ZIP/PNG based datasets."""

from collections.abc import Iterator
from dataclasses import dataclass
import functools
import io
import json
import os
import re
import zipfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import click
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm

from encoders import StabilityVAEEncoder
from utils import load_encoders
from torchvision.transforms import Normalize
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)

def preprocess_raw_image(x, enc_type):
    resolution = x.shape[-1]
    if 'clip' in enc_type:
        x = x / 255.
        x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
        x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
    elif 'mocov3' in enc_type or 'mae' in enc_type:
        x = x / 255.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
    elif 'dinov2' in enc_type:
        x = x / 255.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
    elif 'dinov1' in enc_type:
        x = x / 255.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
    elif 'jepa' in enc_type:
        x = x / 255.
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')

    return x

#----------------------------------------------------------------------------

@dataclass
class ImageEntry:
    img: np.ndarray
    label: Optional[int]

#----------------------------------------------------------------------------
# Parse a 'M,N' or 'MxN' integer tuple.
# Example: '4x2' returns (4, 2).

def parse_tuple(s: str) -> Tuple[int, int]:
    m = re.match(r'^(\d+)[x,](\d+)$', s)
    if m:
        return int(m.group(1)), int(m.group(2))
    raise click.ClickException(f'cannot parse tuple {s}')

#----------------------------------------------------------------------------

def maybe_min(a: int, b: Optional[int]) -> int:
    if b is not None:
        return min(a, b)
    return a

#----------------------------------------------------------------------------

def file_ext(name: Union[str, Path]) -> str:
    return str(name).split('.')[-1]

#----------------------------------------------------------------------------

def is_image_ext(fname: Union[str, Path]) -> bool:
    ext = file_ext(fname).lower()
    return f'.{ext}' in PIL.Image.EXTENSION

#----------------------------------------------------------------------------

def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    input_images = []
    def _recurse_dirs(root: str): # workaround Path().rglob() slowness
        with os.scandir(root) as it:
            for e in it:
                if e.is_file():
                    input_images.append(os.path.join(root, e.name))
                elif e.is_dir():
                    _recurse_dirs(os.path.join(root, e.name))
    _recurse_dirs(source_dir)
    input_images = sorted([f for f in input_images if is_image_ext(f)])

    arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
    max_idx = maybe_min(len(input_images), max_images)

    # Load labels.
    labels = dict()
    meta_fname = os.path.join(source_dir, 'dataset.json')
    if os.path.isfile(meta_fname):
        with open(meta_fname, 'r') as file:
            data = json.load(file)['labels']
            if data is not None:
                labels = {x[0]: x[1] for x in data}

    # No labels available => determine from top-level directory names.
    if len(labels) == 0:
        toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
        toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
        if len(toplevel_indices) > 1:
            labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}

    def iterate_images():
        for idx, fname in enumerate(input_images):
            img = np.array(PIL.Image.open(fname).convert('RGB')) #.transpose(2, 0, 1)
            yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
            if idx >= max_idx - 1:
                break
    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    with zipfile.ZipFile(source, mode='r') as z:
        input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
        max_idx = maybe_min(len(input_images), max_images)

        # Load labels.
        labels = dict()
        if 'dataset.json' in z.namelist():
            with z.open('dataset.json', 'r') as file:
                data = json.load(file)['labels']
                if data is not None:
                    labels = {x[0]: x[1] for x in data}

    def iterate_images():
        with zipfile.ZipFile(source, mode='r') as z:
            for idx, fname in enumerate(input_images):
                with z.open(fname, 'r') as file:
                    img = np.array(PIL.Image.open(file).convert('RGB'))
                yield ImageEntry(img=img, label=labels.get(fname))
                if idx >= max_idx - 1:
                    break
    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    def scale(width, height, img):
        w = img.shape[1]
        h = img.shape[0]
        if width == w and height == h:
            return img
        img = PIL.Image.fromarray(img, 'RGB')
        ww = width if width is not None else w
        hh = height if height is not None else h
        img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        img = np.array(img)

        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    def center_crop_imagenet(image_size: int, arr: np.ndarray):
        """
        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        pil_image = PIL.Image.fromarray(arr)
        while min(*pil_image.size) >= 2 * image_size:
            new_size = tuple(x // 2 for x in pil_image.size)
            assert len(new_size) == 2
            pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)

        scale = image_size / min(*pil_image.size)
        new_size = tuple(round(x * scale) for x in pil_image.size)
        assert len(new_size) == 2
        pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)

        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    if transform == 'center-crop-dhariwal':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        if output_width != output_height:
            raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_imagenet, output_width)
    assert False, 'unknown transform'

#----------------------------------------------------------------------------

def open_dataset(source, *, max_images: Optional[int]):
    if os.path.isdir(source):
        return open_image_folder(source, max_images=max_images)
    elif os.path.isfile(source):
        if file_ext(source) == 'zip':
            return open_image_zip(source, max_images=max_images)
        else:
            raise click.ClickException(f'Only zip archives are supported: {source}')
    else:
        raise click.ClickException(f'Missing input file or directory: {source}')

#----------------------------------------------------------------------------

def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
    dest_ext = file_ext(dest)

    if dest_ext == 'zip':
        if os.path.dirname(dest) != '':
            os.makedirs(os.path.dirname(dest), exist_ok=True)
        zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
        def zip_write_bytes(fname: str, data: Union[bytes, str]):
            zf.writestr(fname, data)
        return '', zip_write_bytes, zf.close
    else:
        # If the output folder already exists, check that it is empty.
        #
        # Note: creating the output directory is not strictly necessary as
        # folder_write_bytes() also mkdirs, but it's better to give an error
        # message earlier in case the dest folder somehow cannot be created.
        if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
            raise click.ClickException('--dest folder must be empty')
        os.makedirs(dest, exist_ok=True)

        def folder_write_bytes(fname: str, data: Union[bytes, str]):
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, 'wb') as fout:
                if isinstance(data, str):
                    data = data.encode('utf8')
                fout.write(data)
        return dest, folder_write_bytes, lambda: None

#----------------------------------------------------------------------------

@click.group()
def cmdline():
    '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
    if os.environ.get('WORLD_SIZE', '1') != '1':
        raise click.ClickException('Distributed execution is not supported.')

#----------------------------------------------------------------------------

@cmdline.command()
@click.option('--source',     help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest',       help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
@click.option('--enc-type',   help='Image encoder type', metavar='PATH', type=str, default='dinov2-vit-b')
@click.option('--resolution', help='Input image resolution', metavar='INT', type=int, default=256)
def encode(
    source: str,
    dest: str,
    max_images: Optional[int],
    enc_type,
    resolution
):
    """Encode pixel data to image-encoder features."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder, encoder_type, architectures = load_encoders(enc_type, device, resolution)
    encoder, encoder_type, architectures = encoder[0], encoder_type[0], architectures[0]
    print("Encoder loaded.")

    PIL.Image.init()
    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')

    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)
    print("Dataset opened.")
    labels = []

    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        with torch.no_grad():
            img_tensor = torch.tensor(image.img).to(device).permute(2, 0, 1).unsqueeze(0)
            raw_image_ = preprocess_raw_image(img_tensor, encoder_type)
            z = encoder.forward_features(raw_image_)
            if 'dinov2' in encoder_type: z = z['x_norm_patchtokens']
            z = z.detach().cpu().numpy()

        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img-feature-{idx_str}.npy'

        f = io.BytesIO()
        np.save(f, z)
        save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
        labels.append([archive_fname, image.label] if image.label is not None else None)

    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()

if __name__ == "__main__":
    cmdline()

#----------------------------------------------------------------------------
```
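The `encode` command above saves one NumPy feature array per image, at archive member paths of the form `00000/img-feature-00000000.npy` (the first five digits of the zero-padded index form the subdirectory). A minimal sketch of reading one feature back out of a zip destination, under those naming assumptions:

```python
import io
import zipfile

import numpy as np

def load_feature(archive_path: str, index: int) -> np.ndarray:
    # Rebuild the member name used by encode():
    # '<first 5 digits of the 8-digit index>/img-feature-<8 digits>.npy'
    idx_str = f'{index:08d}'
    member = f'{idx_str[:5]}/img-feature-{idx_str}.npy'
    with zipfile.ZipFile(archive_path) as z:
        with z.open(member) as f:
            # np.load needs a seekable buffer, so read into BytesIO first.
            return np.load(io.BytesIO(f.read()))
```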
REG/preprocessing/dataset_prepare_convert.sh
ADDED

```bash
# 256x256
python preprocessing/dataset_tools.py convert \
    --source=/home/share/imagenet/train \
    --dest=/home/share/imagenet_vae/imagenet_256_vae \
    --resolution=256x256 \
    --transform=center-crop-dhariwal
```
REG/preprocessing/dataset_prepare_encode.sh
ADDED

```bash
# 256x256
python preprocessing/dataset_tools.py encode \
    --source=/home/share/imagenet_vae/imagenet_256_vae \
    --dest=/home/share/imagenet_vae/vae-sd-256
```
REG/preprocessing/dataset_tools.py
ADDED (422 lines; the diff below is truncated after line 129)

```python
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Tool for creating ZIP/PNG based datasets."""

from collections.abc import Iterator
from dataclasses import dataclass
import functools
import io
import json
import os
import re
import zipfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import click
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm

from encoders import StabilityVAEEncoder

#----------------------------------------------------------------------------

@dataclass
class ImageEntry:
    img: np.ndarray
    label: Optional[int]

#----------------------------------------------------------------------------
# Parse a 'M,N' or 'MxN' integer tuple.
# Example: '4x2' returns (4, 2).

def parse_tuple(s: str) -> Tuple[int, int]:
    m = re.match(r'^(\d+)[x,](\d+)$', s)
    if m:
        return int(m.group(1)), int(m.group(2))
    raise click.ClickException(f'cannot parse tuple {s}')

#----------------------------------------------------------------------------

def maybe_min(a: int, b: Optional[int]) -> int:
    if b is not None:
        return min(a, b)
    return a

#----------------------------------------------------------------------------

def file_ext(name: Union[str, Path]) -> str:
    return str(name).split('.')[-1]

#----------------------------------------------------------------------------

def is_image_ext(fname: Union[str, Path]) -> bool:
    ext = file_ext(fname).lower()
    return f'.{ext}' in PIL.Image.EXTENSION

#----------------------------------------------------------------------------

def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    input_images = []
    def _recurse_dirs(root: str): # workaround Path().rglob() slowness
        with os.scandir(root) as it:
            for e in it:
                if e.is_file():
                    input_images.append(os.path.join(root, e.name))
                elif e.is_dir():
                    _recurse_dirs(os.path.join(root, e.name))
    _recurse_dirs(source_dir)
    input_images = sorted([f for f in input_images if is_image_ext(f)])

    arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
    max_idx = maybe_min(len(input_images), max_images)

    # Load labels.
    labels = dict()
    meta_fname = os.path.join(source_dir, 'dataset.json')
    if os.path.isfile(meta_fname):
        with open(meta_fname, 'r') as file:
            data = json.load(file)['labels']
            if data is not None:
                labels = {x[0]: x[1] for x in data}

    # No labels available => determine from top-level directory names.
    if len(labels) == 0:
        toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
        toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
        if len(toplevel_indices) > 1:
            labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}

    def iterate_images():
        for idx, fname in enumerate(input_images):
            img = np.array(PIL.Image.open(fname).convert('RGB'))
            yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
            if idx >= max_idx - 1:
                break
    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
    with zipfile.ZipFile(source, mode='r') as z:
        input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
        max_idx = maybe_min(len(input_images), max_images)

        # Load labels.
        labels = dict()
        if 'dataset.json' in z.namelist():
            with z.open('dataset.json', 'r') as file:
                data = json.load(file)['labels']
                if data is not None:
                    labels = {x[0]: x[1] for x in data}

    def iterate_images():
        with zipfile.ZipFile(source, mode='r') as z:
            for idx, fname in enumerate(input_images):
                with z.open(fname, 'r') as file:
                    img = np.array(PIL.Image.open(file).convert('RGB'))
                yield ImageEntry(img=img, label=labels.get(fname))
                if idx >= max_idx - 1:
                    break
    return max_idx, iterate_images()

#----------------------------------------------------------------------------
```
|
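As a standalone sketch of the fallback labeling rule in `open_image_folder` above (used when no `dataset.json` exists): each top-level directory name becomes a class, and classes are numbered in sorted order. The filenames below are hypothetical.

```python
# Hypothetical archive filenames; files at the root would get the '' class.
arch_fnames = ['cat/img0.png', 'dog/img1.png', 'cat/img2.png']

# Map each file to its top-level directory, then number the directories
# in sorted order -- the same scheme open_image_folder uses.
toplevel_names = {f: f.split('/')[0] if '/' in f else '' for f in arch_fnames}
toplevel_indices = {name: idx for idx, name in enumerate(sorted(set(toplevel_names.values())))}
labels = {f: toplevel_indices[n] for f, n in toplevel_names.items()}

print(labels)  # {'cat/img0.png': 0, 'dog/img1.png': 1, 'cat/img2.png': 0}
```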
def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    def scale(width, height, img):
        w = img.shape[1]
        h = img.shape[0]
        if width == w and height == h:
            return img
        img = PIL.Image.fromarray(img, 'RGB')
        ww = width if width is not None else w
        hh = height if height is not None else h
        img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        img = np.array(img)

        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    def center_crop_imagenet(image_size: int, arr: np.ndarray):
        """
        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        pil_image = PIL.Image.fromarray(arr)
        while min(*pil_image.size) >= 2 * image_size:
            new_size = tuple(x // 2 for x in pil_image.size)
            assert len(new_size) == 2
            pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)

        scale = image_size / min(*pil_image.size)
        new_size = tuple(round(x * scale) for x in pil_image.size)
        assert len(new_size) == 2
        pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)

        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    if transform == 'center-crop-dhariwal':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        if output_width != output_height:
            raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_imagenet, output_width)
    assert False, 'unknown transform'

#----------------------------------------------------------------------------
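The `center_crop` geometry above can be checked in isolation: the largest centered square (`crop = min(H, W)`) is sliced out before the resize. A minimal sketch with a dummy 480x640 array:

```python
import numpy as np

# Dummy 480x640 RGB image; only the shapes matter here.
img = np.zeros((480, 640, 3), dtype=np.uint8)

# Same slicing as center_crop(): largest centered square.
crop = int(np.min(img.shape[:2]))  # 480
square = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2,
             (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]

print(square.shape)  # (480, 480, 3)
```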
def open_dataset(source, *, max_images: Optional[int]):
    if os.path.isdir(source):
        return open_image_folder(source, max_images=max_images)
    elif os.path.isfile(source):
        if file_ext(source) == 'zip':
            return open_image_zip(source, max_images=max_images)
        else:
            raise click.ClickException(f'Only zip archives are supported: {source}')
    else:
        raise click.ClickException(f'Missing input file or directory: {source}')

#----------------------------------------------------------------------------

def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
    dest_ext = file_ext(dest)

    if dest_ext == 'zip':
        if os.path.dirname(dest) != '':
            os.makedirs(os.path.dirname(dest), exist_ok=True)
        zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
        def zip_write_bytes(fname: str, data: Union[bytes, str]):
            zf.writestr(fname, data)
        return '', zip_write_bytes, zf.close
    else:
        # If the output folder already exists, check that it is empty.
        #
        # Note: creating the output directory is not strictly necessary as
        # folder_write_bytes() also mkdirs, but it's better to give an error
        # message earlier in case the dest folder somehow cannot be created.
        if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
            raise click.ClickException('--dest folder must be empty')
        os.makedirs(dest, exist_ok=True)

        def folder_write_bytes(fname: str, data: Union[bytes, str]):
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, 'wb') as fout:
                if isinstance(data, str):
                    data = data.encode('utf8')
                fout.write(data)
        return dest, folder_write_bytes, lambda: None

#----------------------------------------------------------------------------

@click.group()
def cmdline():
    '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
    if os.environ.get('WORLD_SIZE', '1') != '1':
        raise click.ClickException('Distributed execution is not supported.')

#----------------------------------------------------------------------------
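A minimal illustration of the `open_dest()` callback contract above: the caller receives `(root, save_bytes, close)` and writes files the same way whether the destination is a folder or a zip archive. This sketch exercises only the zip branch, against a temporary path.

```python
import json
import os
import tempfile
import zipfile

# Temporary destination; stands in for a user-supplied --dest=out.zip.
dest = os.path.join(tempfile.mkdtemp(), 'out.zip')

# The zip branch of open_dest(): a writestr-based callback plus a closer.
zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
save_bytes = lambda fname, data: zf.writestr(fname, data)
close = zf.close

save_bytes('dataset.json', json.dumps({'labels': None}))
close()

with zipfile.ZipFile(dest) as z:
    print(z.namelist())  # ['dataset.json']
```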
@cmdline.command()
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
@click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide', 'center-crop-dhariwal']))
@click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
def convert(
    source: str,
    dest: str,
    max_images: Optional[int],
    transform: Optional[str],
    resolution: Optional[Tuple[int, int]]
):
    """Convert an image dataset into archive format for training.

    Specifying the input images:

    \b
    --source path/                  Recursively load all images from path/
    --source dataset.zip            Load all images from dataset.zip

    Specifying the output format and path:

    \b
    --dest /path/to/dir             Save output files under /path/to/dir
    --dest /path/to/dataset.zip     Save output files into /path/to/dataset.zip

    The output dataset format can be either an image folder or an uncompressed zip archive.
    Zip archives make it easier to move datasets around file servers and clusters, and may
    offer better training performance on network file systems.

    Images within the dataset archive will be stored as uncompressed PNG.
    Uncompressed PNGs can be efficiently decoded in the training loop.

    Class labels are stored in a file called 'dataset.json' that is stored at the
    dataset root folder. This file has the following structure:

    \b
    {
        "labels": [
            ["00000/img00000000.png",6],
            ["00000/img00000001.png",9],
            ... repeated for every image in the dataset
            ["00049/img00049999.png",1]
        ]
    }

    If the 'dataset.json' file cannot be found, class labels are determined from
    top-level directory names.

    Image scale/crop and resolution requirements:

    Output images must be square-shaped and they must all have the same power-of-two
    dimensions.

    To scale arbitrary input image size to a specific width and height, use the
    --resolution option. Output resolution will be either the original
    input resolution (if resolution was not specified) or the one specified with
    --resolution option.

    The --transform=center-crop-dhariwal selects a crop/rescale mode that is intended
    to exactly match the results obtained for ImageNet in common diffusion model literature:

    \b
    python dataset_tool.py convert --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \\
        --dest=datasets/img64.zip --resolution=64x64 --transform=center-crop-dhariwal
    """
    PIL.Image.init()
    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')
    print("Begin!!!!!!!!")
    num_files, input_iter = open_dataset(source, max_images=max_images)
    print("open_dataset is over")
    archive_root_dir, save_bytes, close_dest = open_dest(dest)
    print("open_dest is over")
    transform_image = make_transform(transform, *resolution if resolution is not None else (None, None))
    dataset_attrs = None

    labels = []
    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img{idx_str}.png'

        # Apply crop and resize.
        img = transform_image(image.img)
        if img is None:
            continue

        # Error check to require uniform image attributes across
        # the whole dataset.
        assert img.ndim == 3
        cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0]}
        if dataset_attrs is None:
            dataset_attrs = cur_image_attrs
            width = dataset_attrs['width']
            height = dataset_attrs['height']
            if width != height:
                raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
            if width != 2 ** int(np.floor(np.log2(width))):
                raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
        elif dataset_attrs != cur_image_attrs:
            err = [f'  dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
            raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))

        # Save the image as an uncompressed PNG.
        img = PIL.Image.fromarray(img)
        image_bits = io.BytesIO()
        img.save(image_bits, format='png', compress_level=0, optimize=False)
        save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
        labels.append([archive_fname, image.label] if image.label is not None else None)

    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()

#----------------------------------------------------------------------------
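The `dataset.json` layout documented in the `convert()` docstring above is a single `"labels"` list of `[filename, class_index]` pairs. A minimal round-trip sketch (filenames are illustrative):

```python
import json

# Two entries in the documented [filename, class_index] format.
metadata = {'labels': [
    ['00000/img00000000.png', 6],
    ['00000/img00000001.png', 9],
]}

# Serialize as convert() does, then rebuild a filename -> label lookup
# the way open_image_folder()/open_image_zip() do on load.
blob = json.dumps(metadata)
labels = {fname: label for fname, label in json.loads(blob)['labels']}

print(labels['00000/img00000001.png'])  # 9
```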
@cmdline.command()
@click.option('--model-url', help='VAE encoder model', metavar='URL', type=str, default='stabilityai/sd-vae-ft-mse', show_default=True)
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
def encode(
    model_url: str,
    source: str,
    dest: str,
    max_images: Optional[int],
):
    """Encode pixel data to VAE latents."""
    PIL.Image.init()
    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')

    vae = StabilityVAEEncoder(vae_name=model_url, batch_size=1)
    print("VAE is over!!!")
    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)
    print("Data is over!!!")
    labels = []
    #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        img_tensor = torch.tensor(image.img).to('cuda').permute(2, 0, 1).unsqueeze(0)
        mean_std = vae.encode_pixels(img_tensor)[0].cpu()
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img-mean-std-{idx_str}.npy'

        f = io.BytesIO()
        np.save(f, mean_std)
        save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
        labels.append([archive_fname, image.label] if image.label is not None else None)

    metadata = {'labels': labels if all(x is not None for x in labels) else None}
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()

if __name__ == "__main__":
    cmdline()

#----------------------------------------------------------------------------
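The per-image `.npy` files written by `encode()` above can be round-tripped the same way they are produced: serialize through an in-memory buffer, then hand the bytes to `save_bytes()`. A minimal sketch (the array shape is hypothetical; the real shape depends on the VAE and input resolution):

```python
import io
import numpy as np

# Stand-in for one image's VAE output (hypothetical shape).
mean_std = np.random.randn(8, 32, 32).astype(np.float32)

# Serialize to bytes as encode() does, then load them back.
f = io.BytesIO()
np.save(f, mean_std)
loaded = np.load(io.BytesIO(f.getvalue()))

print(loaded.shape, loaded.dtype)  # (8, 32, 32) float32
```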
REG/preprocessing/dnnlib/__init__.py
ADDED
@@ -0,0 +1,8 @@

# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

from .util import EasyDict, make_cache_dir_path
REG/preprocessing/dnnlib/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (291 Bytes).

REG/preprocessing/dnnlib/__pycache__/util.cpython-312.pyc
ADDED
Binary file (22.6 kB).
REG/preprocessing/dnnlib/util.py
ADDED
@@ -0,0 +1,485 @@
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This work is licensed under a Creative Commons
|
| 4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
| 5 |
+
# You should have received a copy of the license along with this
|
| 6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 7 |
+
|
| 8 |
+
"""Miscellaneous utility classes and functions."""
|
| 9 |
+
|
| 10 |
+
import ctypes
|
| 11 |
+
import fnmatch
|
| 12 |
+
import importlib
|
| 13 |
+
import inspect
|
| 14 |
+
import numpy as np
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import sys
|
| 18 |
+
import types
|
| 19 |
+
import io
|
| 20 |
+
import pickle
|
| 21 |
+
import re
|
| 22 |
+
import requests
|
| 23 |
+
import html
|
| 24 |
+
import hashlib
|
| 25 |
+
import glob
|
| 26 |
+
import tempfile
|
| 27 |
+
import urllib
|
| 28 |
+
import urllib.parse
|
| 29 |
+
import uuid
|
| 30 |
+
|
| 31 |
+
from typing import Any, Callable, BinaryIO, List, Tuple, Union, Optional
|
| 32 |
+
|
| 33 |
+
# Util classes
|
| 34 |
+
# ------------------------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class EasyDict(dict):
|
| 38 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
| 39 |
+
|
| 40 |
+
def __getattr__(self, name: str) -> Any:
|
| 41 |
+
try:
|
| 42 |
+
return self[name]
|
| 43 |
+
except KeyError:
|
| 44 |
+
raise AttributeError(name)
|
| 45 |
+
|
| 46 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 47 |
+
self[name] = value
|
| 48 |
+
|
| 49 |
+
def __delattr__(self, name: str) -> None:
|
| 50 |
+
del self[name]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Logger(object):
|
| 54 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
| 55 |
+
|
| 56 |
+
def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
|
| 57 |
+
self.file = None
|
| 58 |
+
|
| 59 |
+
if file_name is not None:
|
| 60 |
+
self.file = open(file_name, file_mode)
|
| 61 |
+
|
| 62 |
+
self.should_flush = should_flush
|
| 63 |
+
self.stdout = sys.stdout
|
| 64 |
+
self.stderr = sys.stderr
|
| 65 |
+
|
| 66 |
+
sys.stdout = self
|
| 67 |
+
sys.stderr = self
|
| 68 |
+
|
| 69 |
+
def __enter__(self) -> "Logger":
|
| 70 |
+
return self
|
| 71 |
+
|
| 72 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
| 73 |
+
self.close()
|
| 74 |
+
|
| 75 |
+
def write(self, text: Union[str, bytes]) -> None:
|
| 76 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
| 77 |
+
if isinstance(text, bytes):
|
| 78 |
+
text = text.decode()
|
| 79 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
| 80 |
+
return
|
| 81 |
+
|
| 82 |
+
if self.file is not None:
|
| 83 |
+
self.file.write(text)
|
| 84 |
+
|
| 85 |
+
self.stdout.write(text)
|
| 86 |
+
|
| 87 |
+
if self.should_flush:
|
| 88 |
+
self.flush()
|
| 89 |
+
|
| 90 |
+
def flush(self) -> None:
|
| 91 |
+
"""Flush written text to both stdout and a file, if open."""
|
| 92 |
+
if self.file is not None:
|
| 93 |
+
self.file.flush()
|
| 94 |
+
|
| 95 |
+
self.stdout.flush()
|
| 96 |
+
|
| 97 |
+
def close(self) -> None:
|
| 98 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
| 99 |
+
self.flush()
|
| 100 |
+
|
| 101 |
+
# if using multiple loggers, prevent closing in wrong order
|
| 102 |
+
if sys.stdout is self:
|
| 103 |
+
sys.stdout = self.stdout
|
| 104 |
+
if sys.stderr is self:
|
| 105 |
+
sys.stderr = self.stderr
|
| 106 |
+
|
| 107 |
+
if self.file is not None:
|
| 108 |
+
self.file.close()
|
| 109 |
+
self.file = None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Cache directories
|
| 113 |
+
# ------------------------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
_dnnlib_cache_dir = None
|
| 116 |
+
|
| 117 |
+
def set_cache_dir(path: str) -> None:
|
| 118 |
+
global _dnnlib_cache_dir
|
| 119 |
+
_dnnlib_cache_dir = path
|
| 120 |
+
|
| 121 |
+
def make_cache_dir_path(*paths: str) -> str:
|
| 122 |
+
if _dnnlib_cache_dir is not None:
|
| 123 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
| 124 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
| 125 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
| 126 |
+
if 'HOME' in os.environ:
|
| 127 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
| 128 |
+
if 'USERPROFILE' in os.environ:
|
| 129 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
| 130 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
| 131 |
+
|
| 132 |
+
# Small util functions
|
| 133 |
+
# ------------------------------------------------------------------------------------------
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def format_time(seconds: Union[int, float]) -> str:
|
| 137 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
| 138 |
+
s = int(np.rint(seconds))
|
| 139 |
+
|
| 140 |
+
if s < 60:
|
| 141 |
+
return "{0}s".format(s)
|
| 142 |
+
elif s < 60 * 60:
|
| 143 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
| 144 |
+
elif s < 24 * 60 * 60:
|
| 145 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
| 146 |
+
else:
|
| 147 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
|
| 151 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
| 152 |
+
s = int(np.rint(seconds))
|
| 153 |
+
|
| 154 |
+
if s < 60:
|
| 155 |
+
return "{0}s".format(s)
|
| 156 |
+
elif s < 60 * 60:
|
| 157 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
| 158 |
+
elif s < 24 * 60 * 60:
|
| 159 |
+
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
| 160 |
+
else:
|
| 161 |
+
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def tuple_product(t: Tuple) -> Any:
|
| 165 |
+
"""Calculate the product of the tuple elements."""
|
| 166 |
+
result = 1
|
| 167 |
+
|
| 168 |
+
for v in t:
|
| 169 |
+
result *= v
|
| 170 |
+
|
| 171 |
+
return result
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
_str_to_ctype = {
|
| 175 |
+
"uint8": ctypes.c_ubyte,
|
| 176 |
+
"uint16": ctypes.c_uint16,
|
| 177 |
+
"uint32": ctypes.c_uint32,
|
| 178 |
+
"uint64": ctypes.c_uint64,
|
| 179 |
+
"int8": ctypes.c_byte,
|
| 180 |
+
"int16": ctypes.c_int16,
|
| 181 |
+
"int32": ctypes.c_int32,
|
| 182 |
+
"int64": ctypes.c_int64,
|
| 183 |
+
"float32": ctypes.c_float,
|
| 184 |
+
"float64": ctypes.c_double
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
| 189 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
| 190 |
+
type_str = None
|
| 191 |
+
|
| 192 |
+
if isinstance(type_obj, str):
|
| 193 |
+
type_str = type_obj
|
| 194 |
+
elif hasattr(type_obj, "__name__"):
|
| 195 |
+
type_str = type_obj.__name__
|
| 196 |
+
elif hasattr(type_obj, "name"):
|
| 197 |
+
type_str = type_obj.name
|
| 198 |
+
else:
|
| 199 |
+
raise RuntimeError("Cannot infer type name from input")
|
| 200 |
+
|
| 201 |
+
assert type_str in _str_to_ctype.keys()
|
| 202 |
+
|
| 203 |
+
my_dtype = np.dtype(type_str)
|
| 204 |
+
my_ctype = _str_to_ctype[type_str]
|
| 205 |
+
|
| 206 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
| 207 |
+
|
| 208 |
+
return my_dtype, my_ctype
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def is_pickleable(obj: Any) -> bool:
|
| 212 |
+
try:
|
| 213 |
+
with io.BytesIO() as stream:
|
| 214 |
+
pickle.dump(obj, stream)
|
| 215 |
+
return True
|
| 216 |
+
except:
|
| 217 |
+
return False
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# Functionality to import modules/objects by name, and call functions by name
|
| 221 |
+
# ------------------------------------------------------------------------------------------
|
| 222 |
+
|
| 223 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
| 224 |
+
"""Searches for the underlying module behind the name to some python object.
|
| 225 |
+
Returns the module and the object name (original name with module part removed)."""
|
| 226 |
+
|
| 227 |
+
# allow convenience shorthands, substitute them by full names
|
| 228 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
| 229 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
| 230 |
+
|
| 231 |
+
# list alternatives for (module_name, local_obj_name)
|
| 232 |
+
parts = obj_name.split(".")
|
| 233 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
| 234 |
+
|
| 235 |
+
# try each alternative in turn
|
| 236 |
+
```python
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj


def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)


def call_func_by_name(*args, func_name: Union[str, Callable], **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name) if isinstance(func_name, str) else func_name
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


def construct_class_by_name(*args, class_name: Union[str, type], **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    return call_func_by_name(*args, func_name=class_name, **kwargs)


def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))


def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        fname = sys.modules[module].__file__
        assert fname is not None
        module = os.path.splitext(os.path.basename(fname))[0]
    return module + "." + obj.__name__
```
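The resolution helpers above split a dotted name into a module part and an attribute part, trying every split point. A standard-library-only sketch of the same idea (simplifying assumption: the module is taken to be everything before the last dot, instead of searching all candidate splits as the full implementation does):

```python
# Minimal sketch of dotted-name resolution, mirroring get_obj_from_module /
# get_obj_by_name above. Assumption: module name is everything up to the
# last dot, so 'os.path.join' resolves as import('os.path') + getattr('join').
import importlib
import types
from typing import Any

def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverse 'a.b.c' via repeated getattr, starting from the given module."""
    obj = module
    for part in obj_name.split('.') if obj_name else []:
        obj = getattr(obj, part)
    return obj

def get_obj_by_name_simple(name: str) -> Any:
    """Resolve 'pkg.mod.attr' by importing 'pkg.mod' and fetching 'attr'."""
    module_name, _, obj_name = name.rpartition('.')
    module = importlib.import_module(module_name)
    return get_obj_from_module(module, obj_name)

fn = get_obj_by_name_simple('os.path.join')
print(fn.__name__)  # -> join
```

The full version's candidate search matters when an attribute name itself contains dots (e.g. a nested class); this simplified form covers only the common module-plus-attribute case.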
```python
# File system helpers
# ------------------------------------------------------------------------------------------

def list_dir_recursively_with_ignore(dir_path: str, ignores: Optional[List[str]] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths."""
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result


def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for file in files:
        target_dir_name = os.path.dirname(file[1])

        # will create all intermediate-level directories
        os.makedirs(target_dir_name, exist_ok=True)
        shutil.copyfile(file[0], file[1])
```
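A small usage sketch for the recursive-listing helper (standard library only; the directory layout and the inline `list_files` reimplementation are made up for the example, so the sketch is self-contained and does not depend on importing this module):

```python
# Hypothetical usage of list_dir_recursively_with_ignore semantics: build a
# small tree, then list it while ignoring '__pycache__' dirs and '*.pyc' files.
import fnmatch
import os
import shutil
import tempfile

def list_files(dir_path, ignores=()):
    result = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs[:] = [d for d in dirs if not fnmatch.fnmatch(d, ignore_)]  # prune in place
            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
        for f in files:
            abs_path = os.path.join(root, f)
            result.append((abs_path, os.path.relpath(abs_path, dir_path)))
    return result

src = tempfile.mkdtemp()
os.makedirs(os.path.join(src, 'pkg', '__pycache__'))
for rel in [('pkg', 'a.py'), ('pkg', 'a.pyc'), ('pkg', '__pycache__', 'a.cpython-312.pyc')]:
    open(os.path.join(src, *rel), 'w').close()

files = list_files(src, ignores=['__pycache__', '*.pyc'])
print(sorted(rel for _, rel in files))  # -> ['pkg/a.py'] (on POSIX)
shutil.rmtree(src)
```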
```python
# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or not "://" in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = urllib.parse.urlparse(obj)
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
        res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
    except:
        return False
    return True

# Note on static typing: a better API would be to split 'open_url' into 'open_url' and
# 'download_url' with separate return types (BinaryIO, str). As the `return_filename=True`
# case is somewhat uncommon, we just pretend like this function never returns a string
# and type-ignore the return value for those cases.
def open_url(url: str, cache_dir: Optional[str] = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> BinaryIO:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like a URL scheme, so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb") # type: ignore

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname()
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb") # type: ignore

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb") # type: ignore

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = urllib.parse.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    assert url_data is not None

    # Save to cache.
    if cache:
        assert url_name is not None
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        safe_name = safe_name[:min(len(safe_name), 128)]
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file # type: ignore

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
```
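The validation logic in `is_url` can be exercised in isolation; a self-contained, standard-library sketch of the same checks (scheme present, netloc contains a dot, both the URL and its root survive `urlparse`):

```python
# Stand-alone version of the is_url checks above.
import urllib.parse
from typing import Any

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = urllib.parse.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        return False
    return True

print(is_url('https://example.com/model.pkl'))                # -> True
print(is_url('/local/path/model.pkl'))                        # -> False
print(is_url('file:///tmp/model.pkl'))                        # -> False
print(is_url('file:///tmp/model.pkl', allow_file_urls=True))  # -> True
```

Note that `file://` URLs have an empty netloc, which is why they fail the generic checks and need the explicit `allow_file_urls` escape hatch.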
REG/preprocessing/encoders.py
ADDED
@@ -0,0 +1,103 @@
```python
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Converting between pixel and latent representations of image data."""

import os
import warnings
import numpy as np
import torch
from torch_utils import persistence
from torch_utils import misc

warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
warnings.filterwarnings('ignore', '`resume_download` is deprecated')

#----------------------------------------------------------------------------
# Abstract base class for encoders/decoders that convert back and forth
# between pixel and latent representations of image data.
#
# Logically, "raw pixels" are first encoded into "raw latents" that are
# then further encoded into "final latents". Decoding, on the other hand,
# goes directly from the final latents to raw pixels. The final latents are
# used as inputs and outputs of the model, whereas the raw latents are
# stored in the dataset. This separation provides added flexibility in terms
# of performing just-in-time adjustments, such as data whitening, without
# having to construct a new dataset.
#
# All image data is represented as PyTorch tensors in NCHW order.
# Raw pixels are represented as 3-channel uint8.

@persistence.persistent_class
class Encoder:
    def __init__(self):
        pass

    def init(self, device): # force lazy init to happen now
        pass

    def __getstate__(self):
        return self.__dict__

    def encode_pixels(self, x): # raw pixels => raw latents
        raise NotImplementedError # to be overridden by subclass

#----------------------------------------------------------------------------
# Pre-trained VAE encoder from Stability AI.

@persistence.persistent_class
class StabilityVAEEncoder(Encoder):
    def __init__(self,
        vae_name   = 'stabilityai/sd-vae-ft-mse', # Name of the VAE to use.
        batch_size = 8,                           # Batch size to use when running the VAE.
    ):
        super().__init__()
        self.vae_name = vae_name
        self.batch_size = int(batch_size)
        self._vae = None

    def init(self, device): # force lazy init to happen now
        super().init(device)
        if self._vae is None:
            self._vae = load_stability_vae(self.vae_name, device=device)
        else:
            self._vae.to(device)

    def __getstate__(self):
        return dict(super().__getstate__(), _vae=None) # do not pickle the vae

    def _run_vae_encoder(self, x):
        d = self._vae.encode(x)['latent_dist']
        return torch.cat([d.mean, d.std], dim=1)

    def encode_pixels(self, x): # raw pixels => raw latents
        self.init(x.device)
        x = x.to(torch.float32) / 127.5 - 1
        x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)])
        return x

#----------------------------------------------------------------------------

def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')):
    import dnnlib
    cache_dir = dnnlib.make_cache_dir_path('diffusers')
    os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
    os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
    os.environ['HF_HOME'] = cache_dir

    import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
    try:
        # First try with local_files_only=True to avoid a network round-trip
        # if the model is already in the local cache.
        vae = diffusers.models.AutoencoderKL.from_pretrained(
            vae_name, cache_dir=cache_dir, local_files_only=True
        )
    except:
        # Could not load the model from cache; try again without local_files_only.
        vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir)
    return vae.eval().requires_grad_(False).to(device)

#----------------------------------------------------------------------------
```
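The pixel-range convention in `encode_pixels` is easy to get wrong when preparing inputs: raw uint8 pixels in [0, 255] are mapped to floats in [-1, 1] via `x / 127.5 - 1` before the VAE sees them. The same arithmetic in plain Python, as a quick standalone check (no torch needed):

```python
# encode_pixels' normalization: uint8 [0, 255] -> float [-1, 1].
def normalize_pixels(values):
    return [v / 127.5 - 1.0 for v in values]

print(normalize_pixels([0, 255]))  # -> [-1.0, 1.0]
```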
REG/preprocessing/torch_utils/__init__.py
ADDED
@@ -0,0 +1,8 @@
```python
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

# empty
```
REG/preprocessing/torch_utils/distributed.py
ADDED
@@ -0,0 +1,140 @@
```python
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import os
import re
import socket
import torch
import torch.distributed
from . import training_stats

_sync_device = None

#----------------------------------------------------------------------------

def init():
    global _sync_device

    if not torch.distributed.is_initialized():
        # Setup some reasonable defaults for env-based distributed init if
        # not set by the running environment.
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind(('', 0))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            os.environ['MASTER_PORT'] = str(s.getsockname()[1])
            s.close()
        if 'RANK' not in os.environ:
            os.environ['RANK'] = '0'
        if 'LOCAL_RANK' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
        if 'WORLD_SIZE' not in os.environ:
            os.environ['WORLD_SIZE'] = '1'
        backend = 'gloo' if os.name == 'nt' else 'nccl'
        torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))

    _sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)

#----------------------------------------------------------------------------

def get_rank():
    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

#----------------------------------------------------------------------------

def get_world_size():
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1

#----------------------------------------------------------------------------

def should_stop():
    return False

#----------------------------------------------------------------------------

def should_suspend():
    return False

#----------------------------------------------------------------------------

def request_suspend():
    pass

#----------------------------------------------------------------------------

def update_progress(cur, total):
    pass

#----------------------------------------------------------------------------

def print0(*args, **kwargs):
    if get_rank() == 0:
        print(*args, **kwargs)

#----------------------------------------------------------------------------

class CheckpointIO:
    def __init__(self, **kwargs):
        self._state_objs = kwargs

    def save(self, pt_path, verbose=True):
        if verbose:
            print0(f'Saving {pt_path} ... ', end='', flush=True)
        data = dict()
        for name, obj in self._state_objs.items():
            if obj is None:
                data[name] = None
            elif isinstance(obj, dict):
                data[name] = obj
            elif hasattr(obj, 'state_dict'):
                data[name] = obj.state_dict()
            elif hasattr(obj, '__getstate__'):
                data[name] = obj.__getstate__()
            elif hasattr(obj, '__dict__'):
                data[name] = obj.__dict__
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if get_rank() == 0:
            torch.save(data, pt_path)
        if verbose:
            print0('done')

    def load(self, pt_path, verbose=True):
        if verbose:
            print0(f'Loading {pt_path} ... ', end='', flush=True)
        data = torch.load(pt_path, map_location=torch.device('cpu'))
        for name, obj in self._state_objs.items():
            if obj is None:
                pass
            elif isinstance(obj, dict):
                obj.clear()
                obj.update(data[name])
            elif hasattr(obj, 'load_state_dict'):
                obj.load_state_dict(data[name])
            elif hasattr(obj, '__setstate__'):
                obj.__setstate__(data[name])
            elif hasattr(obj, '__dict__'):
                obj.__dict__.clear()
                obj.__dict__.update(data[name])
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if verbose:
            print0('done')

    def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
        fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
        if len(fnames) == 0:
            return None
        pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1))))
        self.load(pt_path, verbose=verbose)
        return pt_path

#----------------------------------------------------------------------------
```
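The filename selection in `CheckpointIO.load_latest` (fullmatch a pattern, keep the file with the largest captured step) can be tried in isolation. A standard-library sketch; note the default pattern above uses an unescaped `.` before `pt`, which matches any character but is harmless in practice, while the sketch escapes it:

```python
# Stand-alone sketch of CheckpointIO.load_latest's filename selection.
import re

def latest_checkpoint(fnames, pattern=r'training-state-(\d+)\.pt'):
    matches = [f for f in fnames if re.fullmatch(pattern, f)]
    if not matches:
        return None
    # Largest captured step wins, regardless of zero-padding.
    return max(matches, key=lambda f: int(re.fullmatch(pattern, f).group(1)))

names = ['training-state-0002.pt', 'training-state-0010.pt', 'notes.txt']
print(latest_checkpoint(names))  # -> training-state-0010.pt
```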
REG/preprocessing/torch_utils/misc.py
ADDED
@@ -0,0 +1,277 @@
```python
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import re
import contextlib
import functools
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Re-seed torch & numpy random generators based on the given arguments.

def set_random_seed(*args):
    seed = hash(args) % (1 << 31)
    torch.manual_seed(seed)
    np.random.seed(seed)

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)

#----------------------------------------------------------------------------
# Cached construction of temporary tensors in pinned CPU memory.

@functools.lru_cache(None)
def pinned_buf(shape, dtype):
    return torch.empty(shape, dtype=dtype).pin_memory()

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager
def suppress_tracer_warnings():
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    yield
    warnings.filters.remove(flt)

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    decorator.__name__ = fn.__name__
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, start_idx=0):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        warnings.filterwarnings('ignore', '`data_source` argument is not used and will be removed')
        super().__init__(dataset)
        self.dataset_size = len(dataset)
        self.start_idx = start_idx + rank
        self.stride = num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        idx = self.start_idx
        epoch = None
        while True:
            if epoch != idx // self.dataset_size:
                epoch = idx // self.dataset_size
                order = np.arange(self.dataset_size)
                if self.shuffle:
                    np.random.RandomState(hash((self.seed, epoch)) % (1 << 31)).shuffle(order)
            yield int(order[idx % self.dataset_size])
            idx += self.stride
```
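`InfiniteSampler` interleaves ranks by striding through a per-epoch order: each rank starts at `start_idx + rank` and advances by `num_replicas`, reshuffling at epoch boundaries. The index arithmetic can be sketched without torch or numpy (shuffling omitted for determinism):

```python
# Pure-Python sketch of InfiniteSampler's striding logic. Each rank visits a
# disjoint, interleaved subset of indices and wraps around indefinitely.
import itertools

def infinite_indices(dataset_size, rank=0, num_replicas=1, start_idx=0):
    idx = start_idx + rank
    order = list(range(dataset_size))  # identity order; the real sampler reshuffles per epoch
    while True:
        yield order[idx % dataset_size]
        idx += num_replicas

r0 = list(itertools.islice(infinite_indices(4, rank=0, num_replicas=2), 4))
r1 = list(itertools.islice(infinite_indices(4, rank=1, num_replicas=2), 4))
print(r0)  # -> [0, 2, 0, 2]
print(r1)  # -> [1, 3, 1, 3]
```

Together the two ranks cover every index exactly once per pass, which is what makes the sampler safe for DistributedDataParallel training.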
#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name])

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

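The typical use of `ddp_sync` is gradient accumulation: suppress the all-reduce on every micro-batch except the last. The control flow can be exercised without torch at all; `FakeDDP` below is a hypothetical stand-in for `DistributedDataParallel` that just records whether `no_sync()` was entered.

```python
import contextlib

class FakeDDP:
    # Hypothetical stand-in for torch.nn.parallel.DistributedDataParallel,
    # counting how many times gradient sync was suppressed via no_sync().
    def __init__(self):
        self.sync_suppressed = 0
    @contextlib.contextmanager
    def no_sync(self):
        self.sync_suppressed += 1
        yield

@contextlib.contextmanager
def ddp_sync(module, sync):
    # Same control flow as the utility above, minus the torch type checks.
    if sync or not isinstance(module, FakeDDP):
        yield
    else:
        with module.no_sync():
            yield

model = FakeDDP()
num_accum_rounds = 4
for round_idx in range(num_accum_rounds):
    # Only synchronize gradients on the final micro-batch of the round.
    with ddp_sync(model, sync=(round_idx == num_accum_rounds - 1)):
        pass  # forward/backward would go here
assert model.sync_suppressed == num_accum_rounds - 1
```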
#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = torch.nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

@torch.no_grad()
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()

#----------------------------------------------------------------------------
# Tile a batch of images into a 2D grid.

def tile_images(x, w, h):
    assert x.ndim == 4 # NCHW => CHW
    return x.reshape(h, w, *x.shape[1:]).permute(2, 0, 3, 1, 4).reshape(x.shape[1], h * x.shape[2], w * x.shape[3])

#----------------------------------------------------------------------------
REG/preprocessing/torch_utils/persistence.py
ADDED
@@ -0,0 +1,257 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Facilities for pickling Python code alongside other data.

The pickled code is automatically imported into a separate Python module
during unpickling. This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""

import sys
import pickle
import io
import inspect
import copy
import uuid
import types
import functools
import dnnlib

#----------------------------------------------------------------------------

_version            = 6         # internal version number
_decorators         = set()     # {decorator_class, ...}
_import_hooks       = []        # [hook_function, ...]
_module_to_src_dict = dict()    # {module: src, ...}
_src_to_module_dict = dict()    # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. This feature can be disabled on a per-instance basis
    by setting `self._record_init_args = False` in the constructor.

    A typical use case is to first unpickle a previous instance of a
    persistent class, and then upgrade it to use the latest version of
    the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    @functools.wraps(orig_class, updated=())
    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            record_init_args = getattr(self, '_record_init_args', True)
            self._init_args = copy.deepcopy(args) if record_init_args else None
            self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
            assert orig_class.__name__ in orig_module.__dict__
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            assert self._init_args is not None
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            assert self._init_kwargs is not None
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        pass
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:        Type of the persistent object, e.g. `'class'`.
        version:     Internal version number of `torch_utils.persistence`.
        module_src:  Original source code of the Python module.
        class_name:  Class name in the original Python module.
        state:       Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    obj = decorator_class.__new__(decorator_class)

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Query the source code of a given Python module.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Get or create a Python module for the given source code.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        exec(src, module.__dict__) # pylint: disable=exec-used
    return module

#----------------------------------------------------------------------------

def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.
    """
    def recurse(obj):
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------
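The init-args bookkeeping that `persistent_class` performs can be illustrated in isolation. The sketch below keeps only that part of the mechanism — no source-code capture, no `dnnlib`, no pickleability check — and the `record_init_args` decorator and `Linear` class are hypothetical names for illustration:

```python
import copy
import functools

def record_init_args(orig_class):
    # Minimal analogue of the Decorator subclass above: intercept __init__,
    # deep-copy the constructor arguments, and keep them on the instance
    # so an equivalent object can be rebuilt later.
    @functools.wraps(orig_class, updated=())
    class Decorator(orig_class):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
    return Decorator

@record_init_args
class Linear:
    def __init__(self, num_inputs, num_outputs, bias=True):
        self.shape = (num_outputs, num_inputs)
        self.bias = bias

layer = Linear(3, 4, bias=False)
# The recorded arguments suffice to rebuild an equivalent instance,
# mirroring the `MyNetwork(*old_net.init_args, **old_net.init_kwargs)` pattern.
clone = type(layer)(*layer._init_args, **layer._init_kwargs)
assert clone.shape == (4, 3) and clone.bias is False
```

`functools.wraps(..., updated=())` keeps the subclass presenting itself under the original class name, which is what lets the real decorator stay transparent to user code.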
REG/preprocessing/torch_utils/training_stats.py
ADDED
@@ -0,0 +1,283 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
code."""

import re
import numpy as np
import torch
import dnnlib

from . import misc

#----------------------------------------------------------------------------

_num_moments    = 3             # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype   = torch.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype  = torch.float64 # Data type to use for the internal counters.
_rank           = 0             # Rank of the current process.
_sync_device    = None          # Device to use for multiprocess communication. None = single-process.
_sync_called    = False         # Has _sync() been called yet?
_counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------

def init_multiprocessing(rank, sync_device):
    r"""Initializes `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device

#----------------------------------------------------------------------------

@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries. NaNs and Infs are
    ignored.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    if name not in _counters:
        _counters[name] = dict()

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    elems = elems.detach().flatten().to(_reduce_dtype)
    square = elems.square()
    finite = square.isfinite()
    moments = torch.stack([
        finite.sum(dtype=_reduce_dtype),
        torch.where(finite, elems, 0).sum(),
        torch.where(finite, square, 0).sum(),
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    device = moments.device
    if device not in _counters[name]:
        _counters[name][device] = torch.zeros_like(moments)
    _counters[name][device].add_(moments)
    return value

#----------------------------------------------------------------------------

def report0(name, value):
    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
    but ignores any scalars provided by the other processes.
    See `report()` for further details.
    """
    report(name, value if _rank == 0 else [])
    return value

#----------------------------------------------------------------------------

class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are first collected into internal counters that are not
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: False).
    """
    def __init__(self, regex='.*', keep_previous=False):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()
        self._moments = dict()
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            if name not in self._cumulative:
                self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            delta = cumulative - self._cumulative[name]
            self._cumulative[name].copy_(cumulative)
            if float(delta[0]) != 0:
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns the number of scalars that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        return int(delta[0])

    def mean(self, name):
        r"""Returns the mean of the scalars that were accumulated for the
        given statistic between the last two calls to `update()`, or NaN if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0:
            return float('nan')
        return float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars that were
        accumulated for the given statistic between the last two calls to
        `update()`, or NaN if no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if int(delta[0]) == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        return np.sqrt(max(raw_var - np.square(mean), 0))

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`. The contents are as follows:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)

#----------------------------------------------------------------------------

def _sync(names):
    r"""Synchronize the global cumulative counters across devices and
    processes. Called internally by `Collector.update()`.
    """
    if len(names) == 0:
        return []
    global _sync_called
    _sync_called = True

    # Check that all ranks have the same set of names.
    if _sync_device is not None:
        value = hash(tuple(tuple(ord(char) for char in name) for name in names))
        other = torch.as_tensor(value, dtype=torch.int64, device=_sync_device)
        torch.distributed.broadcast(tensor=other, src=0)
        if value != int(other.cpu()):
            raise ValueError('Training statistics are inconsistent between ranks')

    # Collect deltas within current rank.
    deltas = []
    device = _sync_device if _sync_device is not None else torch.device('cpu')
    for name in names:
        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            delta.add_(counter.to(device))
            counter.copy_(torch.zeros_like(counter))
        deltas.append(delta)
    deltas = torch.stack(deltas)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(deltas)

    # Update cumulative values.
    deltas = deltas.cpu()
    for idx, name in enumerate(names):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(deltas[idx])

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]

#----------------------------------------------------------------------------
# Convenience.

default_collector = Collector()

#----------------------------------------------------------------------------
REG/wandb/debug-internal.log
ADDED
@@ -0,0 +1,21 @@
{"time":"2026-04-08T18:26:46.552297532+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
{"time":"2026-04-08T18:26:47.20146143+08:00","level":"INFO","msg":"stream: created new stream","id":"xtwg5t5s"}
{"time":"2026-04-08T18:26:47.201551011+08:00","level":"INFO","msg":"handler: started","stream_id":"xtwg5t5s"}
{"time":"2026-04-08T18:26:47.202423643+08:00","level":"INFO","msg":"stream: started","id":"xtwg5t5s"}
{"time":"2026-04-08T18:26:47.202450453+08:00","level":"INFO","msg":"writer: started","stream_id":"xtwg5t5s"}
{"time":"2026-04-08T18:26:47.202479681+08:00","level":"INFO","msg":"sender: started","stream_id":"xtwg5t5s"}
{"time":"2026-04-09T00:59:33.394616937+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
{"time":"2026-04-09T15:26:36.673675921+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": read tcp 172.20.98.30:37630->35.186.228.49:443: read: connection reset by peer"}
{"time":"2026-04-09T15:32:51.675782111+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": read tcp 172.20.98.30:55710->35.186.228.49:443: read: connection reset by peer"}
{"time":"2026-04-09T15:33:36.688517829+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": EOF"}
{"time":"2026-04-10T00:33:41.365462236+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
{"time":"2026-04-10T06:11:35.438909216+08:00","level":"INFO","msg":"api: retrying HTTP error","status":429,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"{\"error\":\"rate limit exceeded: per_run limit on filestream requests\"}"}
{"time":"2026-04-11T02:04:06.260667043+08:00","level":"INFO","msg":"api: retrying HTTP error","status":500,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"{\"error\":\"context deadline exceeded\"}"}
{"time":"2026-04-11T10:00:44.531212038+08:00","level":"INFO","msg":"api: retrying HTTP error","status":500,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"{\"error\":\"context deadline exceeded\"}"}
{"time":"2026-04-11T10:20:26.360393211+08:00","level":"INFO","msg":"api: retrying HTTP error","status":500,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"{\"error\":\"context deadline exceeded\"}"}
{"time":"2026-04-12T21:59:44.458847327+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
{"time":"2026-04-13T00:04:28.494081102+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": read tcp 172.20.98.30:35484->35.186.228.49:443: read: connection reset by peer"}
{"time":"2026-04-13T02:39:57.535775934+08:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
{"time":"2026-04-13T02:39:58.493368195+08:00","level":"INFO","msg":"handler: closed","stream_id":"xtwg5t5s"}
{"time":"2026-04-13T02:39:58.494772782+08:00","level":"INFO","msg":"sender: closed","stream_id":"xtwg5t5s"}
{"time":"2026-04-13T02:39:58.49521181+08:00","level":"INFO","msg":"stream: closed","id":"xtwg5t5s"}
REG/wandb/debug.log
ADDED
@@ -0,0 +1,22 @@
2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_setup.py:_flush():81] Configure stats pid to 128263
2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_setup.py:_flush():81] Loading settings from environment variables
2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260408_182646-xtwg5t5s/logs/debug.log
2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260408_182646-xtwg5t5s/logs/debug-internal.log
2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:init():844] calling init triggers
2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
config: {'_wandb': {}}
2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:init():892] starting backend
2026-04-08 18:26:46,532 INFO MainThread:128263 [wandb_init.py:init():895] sending inform_init request
2026-04-08 18:26:46,548 INFO MainThread:128263 [wandb_init.py:init():903] backend started and connected
2026-04-08 18:26:46,551 INFO MainThread:128263 [wandb_init.py:init():973] updated telemetry
2026-04-08 18:26:46,572 INFO MainThread:128263 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
2026-04-08 18:26:47,862 INFO MainThread:128263 [wandb_init.py:init():1042] starting run threads in backend
2026-04-08 18:26:47,956 INFO MainThread:128263 [wandb_run.py:_console_start():2524] atexit reg
2026-04-08 18:26:47,956 INFO MainThread:128263 [wandb_run.py:_redirect():2373] redirect: wrap_raw
2026-04-08 18:26:47,956 INFO MainThread:128263 [wandb_run.py:_redirect():2442] Wrapping output streams.
2026-04-08 18:26:47,956 INFO MainThread:128263 [wandb_run.py:_redirect():2465] Redirects installed.
2026-04-08 18:26:48,108 INFO MainThread:128263 [wandb_init.py:init():1082] run started, returning control to user process
2026-04-08 18:26:48,108 INFO MainThread:128263 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment-0.75-0.01-one-step', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'resume_from_ckpt': '/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75-0.01-one-step/checkpoints/1920000.pt', 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 14000, 'max_train_steps': 10000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.005, 't_c': 0.75, 'ot_cls': True, 'tc_velocity_loss_coeff': 2.0}
2026-04-13 02:35:32,832 INFO wandb-AsyncioManager-main:128263 [service_client.py:_forward_responses():134] Reached EOF.
2026-04-13 02:35:32,833 INFO wandb-AsyncioManager-main:128263 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
REG/wandb/run-20260322_141726-2yw08kz9/files/config.yaml
ADDED
@@ -0,0 +1,203 @@
_wandb:
  value:
    cli_version: 0.25.0
    e:
      257k9ot60u1bv0aiwlacsvutj9c72h7y:
        args:
        - --report-to
        - wandb
        - --allow-tf32
        - --mixed-precision
        - bf16
        - --seed
        - "0"
        - --path-type
        - linear
        - --prediction
        - v
        - --weighting
        - uniform
        - --model
        - SiT-XL/2
        - --enc-type
        - dinov2-vit-b
        - --encoder-depth
        - "8"
        - --proj-coeff
        - "0.5"
        - --output-dir
        - exps
        - --exp-name
        - jsflow-experiment
        - --batch-size
        - "256"
        - --data-dir
        - /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
        - --semantic-features-dir
        - /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
        - --learning-rate
        - "0.00005"
        - --t-c
        - "0.5"
        - --cls
        - "0.2"
        - --ot-cls
        codePath: train.py
        codePathLocal: train.py
        cpu_count: 96
        cpu_count_logical: 192
        cudaVersion: "13.0"
        disk:
          /:
            total: "3838880616448"
            used: "357556633600"
        email: 2365972933@qq.com
        executable: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
        git:
          commit: 021ea2e50c38c5803bd9afff16316958a01fbd1d
          remote: https://github.com/Martinser/REG.git
        gpu: NVIDIA H100 80GB HBM3
        gpu_count: 4
        gpu_nvidia:
        - architecture: Hopper
          cudaCores: 16896
          memoryTotal: "85520809984"
          name: NVIDIA H100 80GB HBM3
          uuid: GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc
        - architecture: Hopper
          cudaCores: 16896
          memoryTotal: "85520809984"
          name: NVIDIA H100 80GB HBM3
          uuid: GPU-a09f2421-99e6-a72e-63bd-fd7452510758
        - architecture: Hopper
          cudaCores: 16896
          memoryTotal: "85520809984"
          name: NVIDIA H100 80GB HBM3
          uuid: GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d
        - architecture: Hopper
          cudaCores: 16896
          memoryTotal: "85520809984"
          name: NVIDIA H100 80GB HBM3
          uuid: GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e
        host: 24c964746905d416ce09d045f9a06f23-taskrole1-0
        memory:
          total: "2164115296256"
        os: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
        program: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py
        python: CPython 3.12.9
        root: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG
        startedAt: "2026-03-22T06:17:26.670763Z"
        writerId: 257k9ot60u1bv0aiwlacsvutj9c72h7y
    m: []
    python_version: 3.12.9
    t:
      "1":
      - 1
      - 5
      - 11
      - 41
      - 49
      - 53
      - 63
      - 71
      - 83
      - 98
      "2":
      - 1
      - 5
      - 11
      - 41
      - 49
      - 53
      - 63
      - 71
      - 83
      - 98
      "3":
      - 13
      - 61
      "4": 3.12.9
      "5": 0.25.0
      "6": 4.53.2
      "12": 0.25.0
      "13": linux-x86_64
adam_beta1:
  value: 0.9
adam_beta2:
  value: 0.999
adam_epsilon:
  value: 1e-08
adam_weight_decay:
  value: 0
allow_tf32:
  value: true
batch_size:
  value: 256
cfg_prob:
  value: 0.1
checkpointing_steps:
  value: 10000
cls:
  value: 0.2
data_dir:
  value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
enc_type:
  value: dinov2-vit-b
encoder_depth:
  value: 8
epochs:
  value: 1400
exp_name:
  value: jsflow-experiment
fused_attn:
  value: true
gradient_accumulation_steps:
  value: 1
learning_rate:
  value: 5e-05
legacy:
  value: false
logging_dir:
  value: logs
max_grad_norm:
  value: 1
max_train_steps:
  value: 1000000
mixed_precision:
  value: bf16
model:
  value: SiT-XL/2
num_classes:
  value: 1000
num_workers:
  value: 4
ops_head:
  value: 16
ot_cls:
  value: true
output_dir:
  value: exps
path_type:
  value: linear
prediction:
  value: v
proj_coeff:
  value: 0.5
qk_norm:
  value: false
report_to:
  value: wandb
resolution:
  value: 256
resume_step:
  value: 0
sampling_steps:
  value: 10000
seed:
  value: 0
semantic_features_dir:
  value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
t_c:
  value: 0.5
weighting:
  value: uniform
REG/wandb/run-20260322_141726-2yw08kz9/files/output.log
ADDED
@@ -0,0 +1,27 @@
Steps: 0%| | 1/1000000 [00:02<614:34:39, 2.21s/it][2026-03-22 14:17:31] Generating EMA samples done.
[2026-03-22 14:17:31] Step: 1, Training Logs: loss_final: 3.278940, loss_mean: 1.706308, proj_loss: 0.001541, loss_mean_cls: 1.571091, grad_norm: 1.481672
Steps: 0%| | 2/1000000 [00:02<289:06:04, 1.04s/it] [2026-03-22 14:17:31] Step: 2, Training Logs: loss_final: 3.211831, loss_mean: 1.688932, proj_loss: -0.010287, loss_mean_cls: 1.533185, grad_norm: 1.055476
Steps: 0%| | 3/1000000 [00:02<187:48:39, 1.48it/s] [2026-03-22 14:17:31] Step: 3, Training Logs: loss_final: 3.201248, loss_mean: 1.663205, proj_loss: -0.019184, loss_mean_cls: 1.557227, grad_norm: 1.116387
Steps: 0%| | 4/1000000 [00:02<140:12:43, 1.98it/s] [2026-03-22 14:17:32] Step: 4, Training Logs: loss_final: 3.198367, loss_mean: 1.682051, proj_loss: -0.026376, loss_mean_cls: 1.542691, grad_norm: 0.722294
Steps: 0%| | 5/1000000 [00:03<113:52:43, 2.44it/s] [2026-03-22 14:17:32] Step: 5, Training Logs: loss_final: 3.140483, loss_mean: 1.679105, proj_loss: -0.034564, loss_mean_cls: 1.495943, grad_norm: 0.811589
Steps: 0%| | 6/1000000 [00:03<97:59:40, 2.83it/s] [2026-03-22 14:17:32] Step: 6, Training Logs: loss_final: 2.988440, loss_mean: 1.682339, proj_loss: -0.039506, loss_mean_cls: 1.345606, grad_norm: 0.931524
Steps: 0%| | 7/1000000 [00:03<87:55:00, 3.16it/s] [2026-03-22 14:17:32] Step: 7, Training Logs: loss_final: 3.111949, loss_mean: 1.690802, proj_loss: -0.042757, loss_mean_cls: 1.463904, grad_norm: 0.830852
Steps: 0%| | 8/1000000 [00:03<81:19:20, 3.42it/s] [2026-03-22 14:17:33] Step: 8, Training Logs: loss_final: 3.278931, loss_mean: 1.660797, proj_loss: -0.045011, loss_mean_cls: 1.663145, grad_norm: 0.847438
Steps: 0%| | 9/1000000 [00:04<76:56:10, 3.61it/s] [2026-03-22 14:17:33] Step: 9, Training Logs: loss_final: 3.221569, loss_mean: 1.658834, proj_loss: -0.046031, loss_mean_cls: 1.608767, grad_norm: 0.909827
Steps: 0%| | 10/1000000 [00:04<73:57:18, 3.76it/s] [2026-03-22 14:17:33] Step: 10, Training Logs: loss_final: 3.216744, loss_mean: 1.665229, proj_loss: -0.047761, loss_mean_cls: 1.599277, grad_norm: 1.014574
Steps: 0%| | 11/1000000 [00:04<71:52:01, 3.87it/s] [2026-03-22 14:17:33] Step: 11, Training Logs: loss_final: 3.216658, loss_mean: 1.649915, proj_loss: -0.049347, loss_mean_cls: 1.616090, grad_norm: 1.028789
Steps: 0%| | 12/1000000 [00:04<70:26:20, 3.94it/s] [2026-03-22 14:17:34] Step: 12, Training Logs: loss_final: 3.155676, loss_mean: 1.624463, proj_loss: -0.049856, loss_mean_cls: 1.581069, grad_norm: 1.231291
Steps: 0%| | 13/1000000 [00:05<69:25:29, 4.00it/s] Traceback (most recent call last):
  File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 527, in <module>
    main(args)
  File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 415, in main
    "loss_final": accelerator.gather(loss).mean().detach().item(),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
[rank0]: Traceback (most recent call last):
[rank0]:   File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 527, in <module>
[rank0]:     main(args)
[rank0]:   File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 415, in main
[rank0]:     "loss_final": accelerator.gather(loss).mean().detach().item(),
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: KeyboardInterrupt
REG/wandb/run-20260322_141726-2yw08kz9/files/requirements.txt
ADDED
@@ -0,0 +1,168 @@
dill==0.3.8
mkl-service==2.4.0
mpmath==1.3.0
typing_extensions==4.12.2
urllib3==2.3.0
torch==2.5.1
ptyprocess==0.7.0
traitlets==5.14.3
pyasn1==0.6.1
opencv-python-headless==4.12.0.88
nest-asyncio==1.6.0
kiwisolver==1.4.8
click==8.2.1
fire==0.7.1
diffusers==0.35.1
accelerate==1.7.0
ipykernel==6.29.5
peft==0.17.1
attrs==24.3.0
six==1.17.0
numpy==2.0.1
yarl==1.18.0
huggingface_hub==0.34.4
Bottleneck==1.4.2
numexpr==2.11.0
dataclasses==0.6
typing-inspection==0.4.1
safetensors==0.5.3
pyparsing==3.2.3
psutil==7.0.0
imageio==2.37.0
debugpy==1.8.14
cycler==0.12.1
pyasn1_modules==0.4.2
matplotlib-inline==0.1.7
matplotlib==3.10.3
jedi==0.19.2
tokenizers==0.21.2
seaborn==0.13.2
timm==1.0.15
aiohappyeyeballs==2.6.1
hf-xet==1.1.8
multidict==6.1.0
tqdm==4.67.1
wheel==0.45.1
simsimd==6.5.1
sentencepiece==0.2.1
grpcio==1.74.0
asttokens==3.0.0
absl-py==2.3.1
stack-data==0.6.3
pandas==2.3.0
importlib_metadata==8.7.0
pytorch-image-generation-metrics==0.6.1
frozenlist==1.5.0
MarkupSafe==3.0.2
setuptools==78.1.1
multiprocess==0.70.15
pip==25.1
requests==2.32.3
mkl_random==1.2.8
tensorboard-plugin-wit==1.8.1
ExifRead-nocycle==3.0.1
webdataset==0.2.111
threadpoolctl==3.6.0
pyarrow==21.0.0
executing==2.2.0
decorator==5.2.1
contourpy==1.3.2
annotated-types==0.7.0
scikit-learn==1.7.1
jupyter_client==8.6.3
albumentations==1.4.24
wandb==0.25.0
certifi==2025.8.3
idna==3.7
xxhash==3.5.0
Jinja2==3.1.6
python-dateutil==2.9.0.post0
aiosignal==1.4.0
triton==3.1.0
torchvision==0.20.1
stringzilla==3.12.6
pure_eval==0.2.3
braceexpand==0.1.7
zipp==3.22.0
oauthlib==3.3.1
Markdown==3.8.2
fsspec==2025.3.0
fonttools==4.58.2
comm==0.2.2
ipython==9.3.0
img2dataset==1.47.0
networkx==3.4.2
PySocks==1.7.1
tzdata==2025.2
smmap==5.0.2
mkl_fft==1.3.11
sentry-sdk==2.29.1
Pygments==2.19.1
pexpect==4.9.0
ftfy==6.3.1
einops==0.8.1
requests-oauthlib==2.0.0
gitdb==4.0.12
albucore==0.0.23
torchdiffeq==0.2.5
GitPython==3.1.44
bitsandbytes==0.47.0
pytorch-fid==0.3.0
clean-fid==0.1.35
pytorch-gan-metrics==0.5.4
Brotli==1.0.9
charset-normalizer==3.3.2
gmpy2==2.2.1
pillow==11.1.0
PyYAML==6.0.2
tornado==6.5.1
termcolor==3.1.0
setproctitle==1.3.6
scipy==1.15.3
regex==2024.11.6
protobuf==6.31.1
platformdirs==4.3.8
joblib==1.5.1
cachetools==4.2.4
ipython_pygments_lexers==1.1.1
google-auth==1.35.0
transformers==4.53.2
torch-fidelity==0.3.0
tensorboard==2.4.0
filelock==3.17.0
packaging==25.0
propcache==0.3.1
pytz==2025.2
aiohttp==3.11.10
wcwidth==0.2.13
clip==0.2.0
Werkzeug==3.1.3
tensorboard-data-server==0.6.1
sympy==1.13.1
pyzmq==26.4.0
pydantic_core==2.33.2
prompt_toolkit==3.0.51
parso==0.8.4
docker-pycreds==0.4.0
rsa==4.9.1
pydantic==2.11.5
jupyter_core==5.8.1
google-auth-oauthlib==0.4.6
datasets==4.0.0
torch-tb-profiler==0.4.3
autocommand==2.2.2
backports.tarfile==1.2.0
importlib_metadata==8.0.0
jaraco.collections==5.1.0
jaraco.context==5.3.0
jaraco.functools==4.0.1
more-itertools==10.3.0
packaging==24.2
platformdirs==4.2.2
typeguard==4.3.0
inflect==7.3.1
jaraco.text==3.12.1
tomli==2.0.1
typing_extensions==4.12.2
wheel==0.45.1
zipp==3.19.2
REG/wandb/run-20260322_141726-2yw08kz9/files/wandb-metadata.json
ADDED
@@ -0,0 +1,101 @@
{
  "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
  "python": "CPython 3.12.9",
  "startedAt": "2026-03-22T06:17:26.670763Z",
  "args": [
    "--report-to",
    "wandb",
    "--allow-tf32",
    "--mixed-precision",
    "bf16",
    "--seed",
    "0",
    "--path-type",
    "linear",
    "--prediction",
    "v",
    "--weighting",
    "uniform",
    "--model",
    "SiT-XL/2",
    "--enc-type",
    "dinov2-vit-b",
    "--encoder-depth",
    "8",
    "--proj-coeff",
    "0.5",
    "--output-dir",
    "exps",
    "--exp-name",
    "jsflow-experiment",
    "--batch-size",
    "256",
    "--data-dir",
    "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
    "--semantic-features-dir",
    "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
    "--learning-rate",
    "0.00005",
    "--t-c",
    "0.5",
    "--cls",
    "0.2",
    "--ot-cls"
  ],
  "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
  "codePath": "train.py",
  "codePathLocal": "train.py",
  "git": {
    "remote": "https://github.com/Martinser/REG.git",
    "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
  },
  "email": "2365972933@qq.com",
  "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
  "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
  "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
  "cpu_count": 96,
  "cpu_count_logical": 192,
  "gpu": "NVIDIA H100 80GB HBM3",
  "gpu_count": 4,
  "disk": {
    "/": {
      "total": "3838880616448",
      "used": "357556633600"
    }
  },
  "memory": {
    "total": "2164115296256"
  },
  "gpu_nvidia": [
    {
      "name": "NVIDIA H100 80GB HBM3",
      "memoryTotal": "85520809984",
      "cudaCores": 16896,
      "architecture": "Hopper",
      "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
    },
    {
      "name": "NVIDIA H100 80GB HBM3",
      "memoryTotal": "85520809984",
      "cudaCores": 16896,
      "architecture": "Hopper",
      "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
    },
    {
      "name": "NVIDIA H100 80GB HBM3",
      "memoryTotal": "85520809984",
      "cudaCores": 16896,
      "architecture": "Hopper",
      "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
    },
    {
      "name": "NVIDIA H100 80GB HBM3",
      "memoryTotal": "85520809984",
      "cudaCores": 16896,
      "architecture": "Hopper",
      "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
    }
  ],
  "cudaVersion": "13.0",
  "writerId": "257k9ot60u1bv0aiwlacsvutj9c72h7y"
}
REG/wandb/run-20260322_141726-2yw08kz9/files/wandb-summary.json
ADDED
@@ -0,0 +1 @@
+{"loss_mean_cls":1.5810688734054565,"_timestamp":1.7741602540511734e+09,"_runtime":5.247627056,"loss_mean":1.6244629621505737,"proj_loss":-0.04985573887825012,"grad_norm":1.2312908172607422,"_wandb":{"runtime":5},"_step":12,"loss_final":3.1556761264801025}
REG/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log
ADDED
@@ -0,0 +1,7 @@
+{"time":"2026-03-22T14:17:27.013311984+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
+{"time":"2026-03-22T14:17:28.347732261+08:00","level":"INFO","msg":"stream: created new stream","id":"2yw08kz9"}
+{"time":"2026-03-22T14:17:28.347960938+08:00","level":"INFO","msg":"handler: started","stream_id":"2yw08kz9"}
+{"time":"2026-03-22T14:17:28.348671928+08:00","level":"INFO","msg":"stream: started","id":"2yw08kz9"}
+{"time":"2026-03-22T14:17:28.348731034+08:00","level":"INFO","msg":"sender: started","stream_id":"2yw08kz9"}
+{"time":"2026-03-22T14:17:28.348748525+08:00","level":"INFO","msg":"writer: started","stream_id":"2yw08kz9"}
+{"time":"2026-03-22T14:17:34.316421629+08:00","level":"INFO","msg":"stream: closing","id":"2yw08kz9"}
REG/wandb/run-20260322_141726-2yw08kz9/logs/debug.log
ADDED
@@ -0,0 +1,22 @@
+2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
+2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Configure stats pid to 316313
+2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Loading settings from environment variables
+2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141726-2yw08kz9/logs/debug.log
+2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log
+2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():844] calling init triggers
+2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
+config: {'_wandb': {}}
+2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():892] starting backend
+2026-03-22 14:17:26,994 INFO MainThread:316313 [wandb_init.py:init():895] sending inform_init request
+2026-03-22 14:17:27,008 INFO MainThread:316313 [wandb_init.py:init():903] backend started and connected
+2026-03-22 14:17:27,011 INFO MainThread:316313 [wandb_init.py:init():973] updated telemetry
+2026-03-22 14:17:27,025 INFO MainThread:316313 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
+2026-03-22 14:17:29,067 INFO MainThread:316313 [wandb_init.py:init():1042] starting run threads in backend
+2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_console_start():2524] atexit reg
+2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_redirect():2373] redirect: wrap_raw
+2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_redirect():2442] Wrapping output streams.
+2026-03-22 14:17:29,159 INFO MainThread:316313 [wandb_run.py:_redirect():2465] Redirects installed.
+2026-03-22 14:17:29,163 INFO MainThread:316313 [wandb_init.py:init():1082] run started, returning control to user process
+2026-03-22 14:17:29,163 INFO MainThread:316313 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 10000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
+2026-03-22 14:17:34,316 INFO wandb-AsyncioManager-main:316313 [service_client.py:_forward_responses():134] Reached EOF.
+2026-03-22 14:17:34,316 INFO wandb-AsyncioManager-main:316313 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
REG/wandb/run-20260322_141726-2yw08kz9/run-2yw08kz9.wandb
ADDED
Binary file (7 Bytes).

REG/wandb/run-20260322_141833-vm0y8t9t/files/output.log
ADDED
The diff for this file is too large to render.
REG/wandb/run-20260322_141833-vm0y8t9t/files/requirements.txt
ADDED
@@ -0,0 +1,168 @@
+dill==0.3.8
+mkl-service==2.4.0
+mpmath==1.3.0
+typing_extensions==4.12.2
+urllib3==2.3.0
+torch==2.5.1
+ptyprocess==0.7.0
+traitlets==5.14.3
+pyasn1==0.6.1
+opencv-python-headless==4.12.0.88
+nest-asyncio==1.6.0
+kiwisolver==1.4.8
+click==8.2.1
+fire==0.7.1
+diffusers==0.35.1
+accelerate==1.7.0
+ipykernel==6.29.5
+peft==0.17.1
+attrs==24.3.0
+six==1.17.0
+numpy==2.0.1
+yarl==1.18.0
+huggingface_hub==0.34.4
+Bottleneck==1.4.2
+numexpr==2.11.0
+dataclasses==0.6
+typing-inspection==0.4.1
+safetensors==0.5.3
+pyparsing==3.2.3
+psutil==7.0.0
+imageio==2.37.0
+debugpy==1.8.14
+cycler==0.12.1
+pyasn1_modules==0.4.2
+matplotlib-inline==0.1.7
+matplotlib==3.10.3
+jedi==0.19.2
+tokenizers==0.21.2
+seaborn==0.13.2
+timm==1.0.15
+aiohappyeyeballs==2.6.1
+hf-xet==1.1.8
+multidict==6.1.0
+tqdm==4.67.1
+wheel==0.45.1
+simsimd==6.5.1
+sentencepiece==0.2.1
+grpcio==1.74.0
+asttokens==3.0.0
+absl-py==2.3.1
+stack-data==0.6.3
+pandas==2.3.0
+importlib_metadata==8.7.0
+pytorch-image-generation-metrics==0.6.1
+frozenlist==1.5.0
+MarkupSafe==3.0.2
+setuptools==78.1.1
+multiprocess==0.70.15
+pip==25.1
+requests==2.32.3
+mkl_random==1.2.8
+tensorboard-plugin-wit==1.8.1
+ExifRead-nocycle==3.0.1
+webdataset==0.2.111
+threadpoolctl==3.6.0
+pyarrow==21.0.0
+executing==2.2.0
+decorator==5.2.1
+contourpy==1.3.2
+annotated-types==0.7.0
+scikit-learn==1.7.1
+jupyter_client==8.6.3
+albumentations==1.4.24
+wandb==0.25.0
+certifi==2025.8.3
+idna==3.7
+xxhash==3.5.0
+Jinja2==3.1.6
+python-dateutil==2.9.0.post0
+aiosignal==1.4.0
+triton==3.1.0
+torchvision==0.20.1
+stringzilla==3.12.6
+pure_eval==0.2.3
+braceexpand==0.1.7
+zipp==3.22.0
+oauthlib==3.3.1
+Markdown==3.8.2
+fsspec==2025.3.0
+fonttools==4.58.2
+comm==0.2.2
+ipython==9.3.0
+img2dataset==1.47.0
+networkx==3.4.2
+PySocks==1.7.1
+tzdata==2025.2
+smmap==5.0.2
+mkl_fft==1.3.11
+sentry-sdk==2.29.1
+Pygments==2.19.1
+pexpect==4.9.0
+ftfy==6.3.1
+einops==0.8.1
+requests-oauthlib==2.0.0
+gitdb==4.0.12
+albucore==0.0.23
+torchdiffeq==0.2.5
+GitPython==3.1.44
+bitsandbytes==0.47.0
+pytorch-fid==0.3.0
+clean-fid==0.1.35
+pytorch-gan-metrics==0.5.4
+Brotli==1.0.9
+charset-normalizer==3.3.2
+gmpy2==2.2.1
+pillow==11.1.0
+PyYAML==6.0.2
+tornado==6.5.1
+termcolor==3.1.0
+setproctitle==1.3.6
+scipy==1.15.3
+regex==2024.11.6
+protobuf==6.31.1
+platformdirs==4.3.8
+joblib==1.5.1
+cachetools==4.2.4
+ipython_pygments_lexers==1.1.1
+google-auth==1.35.0
+transformers==4.53.2
+torch-fidelity==0.3.0
+tensorboard==2.4.0
+filelock==3.17.0
+packaging==25.0
+propcache==0.3.1
+pytz==2025.2
+aiohttp==3.11.10
+wcwidth==0.2.13
+clip==0.2.0
+Werkzeug==3.1.3
+tensorboard-data-server==0.6.1
+sympy==1.13.1
+pyzmq==26.4.0
+pydantic_core==2.33.2
+prompt_toolkit==3.0.51
+parso==0.8.4
+docker-pycreds==0.4.0
+rsa==4.9.1
+pydantic==2.11.5
+jupyter_core==5.8.1
+google-auth-oauthlib==0.4.6
+datasets==4.0.0
+torch-tb-profiler==0.4.3
+autocommand==2.2.2
+backports.tarfile==1.2.0
+importlib_metadata==8.0.0
+jaraco.collections==5.1.0
+jaraco.context==5.3.0
+jaraco.functools==4.0.1
+more-itertools==10.3.0
+packaging==24.2
+platformdirs==4.2.2
+typeguard==4.3.0
+inflect==7.3.1
+jaraco.text==3.12.1
+tomli==2.0.1
+typing_extensions==4.12.2
+wheel==0.45.1
+zipp==3.19.2
REG/wandb/run-20260322_141833-vm0y8t9t/files/wandb-metadata.json
ADDED
@@ -0,0 +1,101 @@
+{
+  "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
+  "python": "CPython 3.12.9",
+  "startedAt": "2026-03-22T06:18:33.208941Z",
+  "args": [
+    "--report-to",
+    "wandb",
+    "--allow-tf32",
+    "--mixed-precision",
+    "bf16",
+    "--seed",
+    "0",
+    "--path-type",
+    "linear",
+    "--prediction",
+    "v",
+    "--weighting",
+    "uniform",
+    "--model",
+    "SiT-XL/2",
+    "--enc-type",
+    "dinov2-vit-b",
+    "--encoder-depth",
+    "8",
+    "--proj-coeff",
+    "0.5",
+    "--output-dir",
+    "exps",
+    "--exp-name",
+    "jsflow-experiment",
+    "--batch-size",
+    "256",
+    "--data-dir",
+    "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
+    "--semantic-features-dir",
+    "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
+    "--learning-rate",
+    "0.00005",
+    "--t-c",
+    "0.5",
+    "--cls",
+    "0.2",
+    "--ot-cls"
+  ],
+  "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
+  "codePath": "train.py",
+  "codePathLocal": "train.py",
+  "git": {
+    "remote": "https://github.com/Martinser/REG.git",
+    "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
+  },
+  "email": "2365972933@qq.com",
+  "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
+  "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
+  "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
+  "cpu_count": 96,
+  "cpu_count_logical": 192,
+  "gpu": "NVIDIA H100 80GB HBM3",
+  "gpu_count": 4,
+  "disk": {
+    "/": {
+      "total": "3838880616448",
+      "used": "357556703232"
+    }
+  },
+  "memory": {
+    "total": "2164115296256"
+  },
+  "gpu_nvidia": [
+    {
+      "name": "NVIDIA H100 80GB HBM3",
+      "memoryTotal": "85520809984",
+      "cudaCores": 16896,
+      "architecture": "Hopper",
+      "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
+    },
+    {
+      "name": "NVIDIA H100 80GB HBM3",
+      "memoryTotal": "85520809984",
+      "cudaCores": 16896,
+      "architecture": "Hopper",
+      "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
+    },
+    {
+      "name": "NVIDIA H100 80GB HBM3",
+      "memoryTotal": "85520809984",
+      "cudaCores": 16896,
+      "architecture": "Hopper",
+      "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
+    },
+    {
+      "name": "NVIDIA H100 80GB HBM3",
+      "memoryTotal": "85520809984",
+      "cudaCores": 16896,
+      "architecture": "Hopper",
+      "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
+    }
+  ],
+  "cudaVersion": "13.0",
+  "writerId": "gklxguwapb72cxij4696gj37bh1rbthi"
+}
REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log
ADDED
@@ -0,0 +1,6 @@
+{"time":"2026-03-22T14:18:33.472940651+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
+{"time":"2026-03-22T14:18:35.380852704+08:00","level":"INFO","msg":"stream: created new stream","id":"vm0y8t9t"}
+{"time":"2026-03-22T14:18:35.381056887+08:00","level":"INFO","msg":"handler: started","stream_id":"vm0y8t9t"}
+{"time":"2026-03-22T14:18:35.382108345+08:00","level":"INFO","msg":"writer: started","stream_id":"vm0y8t9t"}
+{"time":"2026-03-22T14:18:35.382119604+08:00","level":"INFO","msg":"stream: started","id":"vm0y8t9t"}
+{"time":"2026-03-22T14:18:35.382161533+08:00","level":"INFO","msg":"sender: started","stream_id":"vm0y8t9t"}
REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log
ADDED
@@ -0,0 +1,20 @@
+2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
+2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Configure stats pid to 318585
+2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Loading settings from environment variables
+2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log
+2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log
+2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():844] calling init triggers
+2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
+config: {'_wandb': {}}
+2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():892] starting backend
+2026-03-22 14:18:33,460 INFO MainThread:318585 [wandb_init.py:init():895] sending inform_init request
+2026-03-22 14:18:33,470 INFO MainThread:318585 [wandb_init.py:init():903] backend started and connected
+2026-03-22 14:18:33,472 INFO MainThread:318585 [wandb_init.py:init():973] updated telemetry
+2026-03-22 14:18:33,485 INFO MainThread:318585 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
+2026-03-22 14:18:36,829 INFO MainThread:318585 [wandb_init.py:init():1042] starting run threads in backend
+2026-03-22 14:18:36,920 INFO MainThread:318585 [wandb_run.py:_console_start():2524] atexit reg
+2026-03-22 14:18:36,920 INFO MainThread:318585 [wandb_run.py:_redirect():2373] redirect: wrap_raw
+2026-03-22 14:18:36,921 INFO MainThread:318585 [wandb_run.py:_redirect():2442] Wrapping output streams.
+2026-03-22 14:18:36,921 INFO MainThread:318585 [wandb_run.py:_redirect():2465] Redirects installed.
+2026-03-22 14:18:36,924 INFO MainThread:318585 [wandb_init.py:init():1082] run started, returning control to user process
+2026-03-22 14:18:36,924 INFO MainThread:318585 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 10000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
REG/wandb/run-20260322_150022-yhxc5cgu/files/output.log
ADDED
@@ -0,0 +1,19 @@
+Steps: 0%| | 1/1000000 [00:02<652:30:07, 2.35s/it][2026-03-22 15:00:28] Generating EMA samples for evaluation (t=1→0 and t=0.5)...
+Traceback (most recent call last):
+  File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 628, in <module>
+    main(args)
+  File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 425, in main
+    cls_init = torch.randn(n_samples, base_model.semantic_channels, device=device)
+                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+  File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
+    raise AttributeError(
+AttributeError: 'SiT' object has no attribute 'semantic_channels'
+[rank0]: Traceback (most recent call last):
+[rank0]:   File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 628, in <module>
+[rank0]:     main(args)
+[rank0]:   File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 425, in main
+[rank0]:     cls_init = torch.randn(n_samples, base_model.semantic_channels, device=device)
+[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+[rank0]:   File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
+[rank0]:     raise AttributeError(
+[rank0]: AttributeError: 'SiT' object has no attribute 'semantic_channels'
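The traceback above originates in `nn.Module.__getattr__`, which raises `AttributeError` whenever a lookup misses the module's parameters, buffers, and submodules — so `base_model.semantic_channels` aborts the run the moment `train.py` reaches the EMA-sampling branch. A minimal sketch of a defensive lookup; the `SiT` stand-in, the `semantic_dim` helper, and the 768 default (the dinov2-vit-b feature width) are illustrative assumptions, not code taken from the repository:

```python
class SiT:
    """Minimal stand-in for the SiT backbone defined in REG/models/sit.py."""
    pass


def semantic_dim(model, default=768):
    # nn.Module.__getattr__ raises AttributeError for missing attributes,
    # which is what aborted this run; getattr with a fallback sidesteps it.
    # 768 is a hypothetical default (dinov2-vit-b's feature width), not a
    # value read from the repository.
    return getattr(model, "semantic_channels", default)


base_model = SiT()
print(semantic_dim(base_model))      # 768 (attribute missing, fallback used)

base_model.semantic_channels = 1024  # once the model actually defines it
print(semantic_dim(base_model))      # 1024
```

The alternative fix is to define `semantic_channels` on the model at construction time so the attribute access in `train.py` succeeds as written.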
REG/wandb/run-20260322_150022-yhxc5cgu/files/requirements.txt
ADDED
@@ -0,0 +1,168 @@
+dill==0.3.8
+mkl-service==2.4.0
+mpmath==1.3.0
+typing_extensions==4.12.2
+urllib3==2.3.0
+torch==2.5.1
+ptyprocess==0.7.0
+traitlets==5.14.3
+pyasn1==0.6.1
+opencv-python-headless==4.12.0.88
+nest-asyncio==1.6.0
+kiwisolver==1.4.8
+click==8.2.1
+fire==0.7.1
+diffusers==0.35.1
+accelerate==1.7.0
+ipykernel==6.29.5
+peft==0.17.1
+attrs==24.3.0
+six==1.17.0
+numpy==2.0.1
+yarl==1.18.0
+huggingface_hub==0.34.4
+Bottleneck==1.4.2
+numexpr==2.11.0
+dataclasses==0.6
+typing-inspection==0.4.1
+safetensors==0.5.3
+pyparsing==3.2.3
+psutil==7.0.0
+imageio==2.37.0
+debugpy==1.8.14
+cycler==0.12.1
+pyasn1_modules==0.4.2
+matplotlib-inline==0.1.7
+matplotlib==3.10.3
+jedi==0.19.2
+tokenizers==0.21.2
+seaborn==0.13.2
+timm==1.0.15
+aiohappyeyeballs==2.6.1
+hf-xet==1.1.8
+multidict==6.1.0
+tqdm==4.67.1
+wheel==0.45.1
+simsimd==6.5.1
+sentencepiece==0.2.1
+grpcio==1.74.0
+asttokens==3.0.0
+absl-py==2.3.1
+stack-data==0.6.3
+pandas==2.3.0
+importlib_metadata==8.7.0
+pytorch-image-generation-metrics==0.6.1
+frozenlist==1.5.0
+MarkupSafe==3.0.2
+setuptools==78.1.1
+multiprocess==0.70.15
+pip==25.1
+requests==2.32.3
+mkl_random==1.2.8
+tensorboard-plugin-wit==1.8.1
+ExifRead-nocycle==3.0.1
+webdataset==0.2.111
+threadpoolctl==3.6.0
+pyarrow==21.0.0
+executing==2.2.0
+decorator==5.2.1
+contourpy==1.3.2
+annotated-types==0.7.0
+scikit-learn==1.7.1
+jupyter_client==8.6.3
+albumentations==1.4.24
+wandb==0.25.0
+certifi==2025.8.3
+idna==3.7
+xxhash==3.5.0
+Jinja2==3.1.6
+python-dateutil==2.9.0.post0
+aiosignal==1.4.0
+triton==3.1.0
+torchvision==0.20.1
+stringzilla==3.12.6
+pure_eval==0.2.3
+braceexpand==0.1.7
+zipp==3.22.0
+oauthlib==3.3.1
+Markdown==3.8.2
+fsspec==2025.3.0
+fonttools==4.58.2
+comm==0.2.2
+ipython==9.3.0
+img2dataset==1.47.0
+networkx==3.4.2
+PySocks==1.7.1
+tzdata==2025.2
+smmap==5.0.2
+mkl_fft==1.3.11
+sentry-sdk==2.29.1
+Pygments==2.19.1
+pexpect==4.9.0
+ftfy==6.3.1
+einops==0.8.1
+requests-oauthlib==2.0.0
+gitdb==4.0.12
+albucore==0.0.23
+torchdiffeq==0.2.5
+GitPython==3.1.44
+bitsandbytes==0.47.0
+pytorch-fid==0.3.0
+clean-fid==0.1.35
+pytorch-gan-metrics==0.5.4
+Brotli==1.0.9
+charset-normalizer==3.3.2
+gmpy2==2.2.1
+pillow==11.1.0
+PyYAML==6.0.2
+tornado==6.5.1
+termcolor==3.1.0
+setproctitle==1.3.6
+scipy==1.15.3
+regex==2024.11.6
+protobuf==6.31.1
+platformdirs==4.3.8
+joblib==1.5.1
+cachetools==4.2.4
+ipython_pygments_lexers==1.1.1
+google-auth==1.35.0
+transformers==4.53.2
+torch-fidelity==0.3.0
+tensorboard==2.4.0
+filelock==3.17.0
+packaging==25.0
+propcache==0.3.1
+pytz==2025.2
+aiohttp==3.11.10
+wcwidth==0.2.13
+clip==0.2.0
+Werkzeug==3.1.3
+tensorboard-data-server==0.6.1
+sympy==1.13.1
+pyzmq==26.4.0
+pydantic_core==2.33.2
+prompt_toolkit==3.0.51
+parso==0.8.4
+docker-pycreds==0.4.0
+rsa==4.9.1
+pydantic==2.11.5
+jupyter_core==5.8.1
+google-auth-oauthlib==0.4.6
+datasets==4.0.0
+torch-tb-profiler==0.4.3
+autocommand==2.2.2
+backports.tarfile==1.2.0
+importlib_metadata==8.0.0
+jaraco.collections==5.1.0
+jaraco.context==5.3.0
+jaraco.functools==4.0.1
+more-itertools==10.3.0
+packaging==24.2
+platformdirs==4.2.2
+typeguard==4.3.0
+inflect==7.3.1
+jaraco.text==3.12.1
+tomli==2.0.1
+typing_extensions==4.12.2
+wheel==0.45.1
+zipp==3.19.2
REG/wandb/run-20260322_150022-yhxc5cgu/files/wandb-metadata.json
ADDED
@@ -0,0 +1,101 @@
+{
+  "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
+  "python": "CPython 3.12.9",
+  "startedAt": "2026-03-22T07:00:22.092510Z",
+  "args": [
+    "--report-to",
+    "wandb",
+    "--allow-tf32",
+    "--mixed-precision",
+    "bf16",
+    "--seed",
+    "0",
+    "--path-type",
+    "linear",
+    "--prediction",
+    "v",
+    "--weighting",
+    "uniform",
+    "--model",
+    "SiT-XL/2",
+    "--enc-type",
+    "dinov2-vit-b",
+    "--encoder-depth",
+    "8",
+    "--proj-coeff",
+    "0.5",
+    "--output-dir",
+    "exps",
+    "--exp-name",
+    "jsflow-experiment",
+    "--batch-size",
+    "256",
+    "--data-dir",
+    "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
+    "--semantic-features-dir",
+    "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
+    "--learning-rate",
+    "0.00005",
+    "--t-c",
+    "0.5",
+    "--cls",
+    "0.2",
+    "--ot-cls"
+  ],
+  "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
+  "codePath": "train.py",
+  "codePathLocal": "train.py",
+  "git": {
+    "remote": "https://github.com/Martinser/REG.git",
+    "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
+  },
+  "email": "2365972933@qq.com",
+  "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
+  "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
+  "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
+  "cpu_count": 96,
+  "cpu_count_logical": 192,
+  "gpu": "NVIDIA H100 80GB HBM3",
+  "gpu_count": 4,
+  "disk": {
+    "/": {
+      "total": "3838880616448",
+      "used": "357557354496"
+    }
+  },
+  "memory": {
+    "total": "2164115296256"
+  },
+  "gpu_nvidia": [
+    {
+      "name": "NVIDIA H100 80GB HBM3",
+      "memoryTotal": "85520809984",
+      "cudaCores": 16896,
+      "architecture": "Hopper",
+      "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
+    },
+    {
+      "name": "NVIDIA H100 80GB HBM3",
+      "memoryTotal": "85520809984",
+      "cudaCores": 16896,
+      "architecture": "Hopper",
+      "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
+    },
+    {
+      "name": "NVIDIA H100 80GB HBM3",
+      "memoryTotal": "85520809984",
+      "cudaCores": 16896,
+      "architecture": "Hopper",
+      "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
+    },
+    {
+      "name": "NVIDIA H100 80GB HBM3",
+      "memoryTotal": "85520809984",
+      "cudaCores": 16896,
+      "architecture": "Hopper",
+      "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
+    }
+  ],
+  "cudaVersion": "13.0",
+  "writerId": "ucanic8s891x6sl28vnbha78lzoecw66"
+}
REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug-internal.log
ADDED
@@ -0,0 +1,7 @@
+{"time":"2026-03-22T15:00:22.432399726+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
+{"time":"2026-03-22T15:00:25.799578446+08:00","level":"INFO","msg":"stream: created new stream","id":"yhxc5cgu"}
+{"time":"2026-03-22T15:00:25.799734466+08:00","level":"INFO","msg":"handler: started","stream_id":"yhxc5cgu"}
+{"time":"2026-03-22T15:00:25.80075778+08:00","level":"INFO","msg":"stream: started","id":"yhxc5cgu"}
+{"time":"2026-03-22T15:00:25.800786229+08:00","level":"INFO","msg":"writer: started","stream_id":"yhxc5cgu"}
+{"time":"2026-03-22T15:00:25.800837858+08:00","level":"INFO","msg":"sender: started","stream_id":"yhxc5cgu"}
+{"time":"2026-03-22T15:00:28.913273863+08:00","level":"INFO","msg":"stream: closing","id":"yhxc5cgu"}
REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug.log
ADDED
@@ -0,0 +1,22 @@
+2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
+2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_setup.py:_flush():81] Configure stats pid to 323629
+2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_setup.py:_flush():81] Loading settings from environment variables
+2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug.log
+2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug-internal.log
+2026-03-22 15:00:22,125 INFO MainThread:323629 [wandb_init.py:init():844] calling init triggers
+2026-03-22 15:00:22,125 INFO MainThread:323629 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
+config: {'_wandb': {}}
+2026-03-22 15:00:22,125 INFO MainThread:323629 [wandb_init.py:init():892] starting backend
+2026-03-22 15:00:22,416 INFO MainThread:323629 [wandb_init.py:init():895] sending inform_init request
+2026-03-22 15:00:22,429 INFO MainThread:323629 [wandb_init.py:init():903] backend started and connected
+2026-03-22 15:00:22,431 INFO MainThread:323629 [wandb_init.py:init():973] updated telemetry
+2026-03-22 15:00:22,447 INFO MainThread:323629 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
+2026-03-22 15:00:26,403 INFO MainThread:323629 [wandb_init.py:init():1042] starting run threads in backend
+2026-03-22 15:00:26,494 INFO MainThread:323629 [wandb_run.py:_console_start():2524] atexit reg
+2026-03-22 15:00:26,494 INFO MainThread:323629 [wandb_run.py:_redirect():2373] redirect: wrap_raw
+2026-03-22 15:00:26,494 INFO MainThread:323629 [wandb_run.py:_redirect():2442] Wrapping output streams.
+2026-03-22 15:00:26,495 INFO MainThread:323629 [wandb_run.py:_redirect():2465] Redirects installed.
+2026-03-22 15:00:26,500 INFO MainThread:323629 [wandb_init.py:init():1082] run started, returning control to user process
+2026-03-22 15:00:26,500 INFO MainThread:323629 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
+2026-03-22 15:00:28,913 INFO wandb-AsyncioManager-main:323629 [service_client.py:_forward_responses():134] Reached EOF.
+2026-03-22 15:00:28,913 INFO wandb-AsyncioManager-main:323629 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
REG/wandb/run-20260322_150022-yhxc5cgu/run-yhxc5cgu.wandb
ADDED
Binary file (7 Bytes)

REG/wandb/run-20260322_150443-e3yw9ii4/run-e3yw9ii4.wandb
ADDED
Binary file (7 Bytes)