xiangzai committed on
Commit b5e1f6d · verified · 1 Parent(s): 342a08e

Add files using upload-large-folder tool

Files changed (50)
  1. REG/__pycache__/dataset.cpython-312.pyc +0 -0
  2. REG/__pycache__/loss.cpython-312.pyc +0 -0
  3. REG/__pycache__/loss.cpython-313.pyc +0 -0
  4. REG/__pycache__/sample_from_checkpoint.cpython-313.pyc +0 -0
  5. REG/__pycache__/sample_from_checkpoint_ddp.cpython-313.pyc +0 -0
  6. REG/__pycache__/samplers.cpython-312.pyc +0 -0
  7. REG/__pycache__/samplers.cpython-313.pyc +0 -0
  8. REG/__pycache__/train.cpython-313.pyc +0 -0
  9. REG/__pycache__/utils.cpython-312.pyc +0 -0
  10. REG/models/__pycache__/mocov3_vit.cpython-310.pyc +0 -0
  11. REG/models/__pycache__/mocov3_vit.cpython-312.pyc +0 -0
  12. REG/models/__pycache__/sit.cpython-310.pyc +0 -0
  13. REG/models/__pycache__/sit.cpython-312.pyc +0 -0
  14. REG/preprocessing/README.md +25 -0
  15. REG/preprocessing/dataset_image_encoder.py +353 -0
  16. REG/preprocessing/dataset_prepare_convert.sh +11 -0
  17. REG/preprocessing/dataset_prepare_encode.sh +9 -0
  18. REG/preprocessing/dataset_tools.py +422 -0
  19. REG/preprocessing/dnnlib/__init__.py +8 -0
  20. REG/preprocessing/dnnlib/__pycache__/__init__.cpython-312.pyc +0 -0
  21. REG/preprocessing/dnnlib/__pycache__/util.cpython-312.pyc +0 -0
  22. REG/preprocessing/dnnlib/util.py +485 -0
  23. REG/preprocessing/encoders.py +103 -0
  24. REG/preprocessing/torch_utils/__init__.py +8 -0
  25. REG/preprocessing/torch_utils/distributed.py +140 -0
  26. REG/preprocessing/torch_utils/misc.py +277 -0
  27. REG/preprocessing/torch_utils/persistence.py +257 -0
  28. REG/preprocessing/torch_utils/training_stats.py +283 -0
  29. REG/wandb/debug-internal.log +21 -0
  30. REG/wandb/debug.log +22 -0
  31. REG/wandb/run-20260322_141726-2yw08kz9/files/config.yaml +203 -0
  32. REG/wandb/run-20260322_141726-2yw08kz9/files/output.log +27 -0
  33. REG/wandb/run-20260322_141726-2yw08kz9/files/requirements.txt +168 -0
  34. REG/wandb/run-20260322_141726-2yw08kz9/files/wandb-metadata.json +101 -0
  35. REG/wandb/run-20260322_141726-2yw08kz9/files/wandb-summary.json +1 -0
  36. REG/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log +7 -0
  37. REG/wandb/run-20260322_141726-2yw08kz9/logs/debug.log +22 -0
  38. REG/wandb/run-20260322_141726-2yw08kz9/run-2yw08kz9.wandb +0 -0
  39. REG/wandb/run-20260322_141833-vm0y8t9t/files/output.log +0 -0
  40. REG/wandb/run-20260322_141833-vm0y8t9t/files/requirements.txt +168 -0
  41. REG/wandb/run-20260322_141833-vm0y8t9t/files/wandb-metadata.json +101 -0
  42. REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log +6 -0
  43. REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log +20 -0
  44. REG/wandb/run-20260322_150022-yhxc5cgu/files/output.log +19 -0
  45. REG/wandb/run-20260322_150022-yhxc5cgu/files/requirements.txt +168 -0
  46. REG/wandb/run-20260322_150022-yhxc5cgu/files/wandb-metadata.json +101 -0
  47. REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug-internal.log +7 -0
  48. REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug.log +22 -0
  49. REG/wandb/run-20260322_150022-yhxc5cgu/run-yhxc5cgu.wandb +0 -0
  50. REG/wandb/run-20260322_150443-e3yw9ii4/run-e3yw9ii4.wandb +0 -0
REG/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (10.3 kB).
 
REG/__pycache__/loss.cpython-312.pyc ADDED
Binary file (9.98 kB).
 
REG/__pycache__/loss.cpython-313.pyc ADDED
Binary file (8.75 kB).
 
REG/__pycache__/sample_from_checkpoint.cpython-313.pyc ADDED
Binary file (15.7 kB).
 
REG/__pycache__/sample_from_checkpoint_ddp.cpython-313.pyc ADDED
Binary file (22 kB).
 
REG/__pycache__/samplers.cpython-312.pyc ADDED
Binary file (31.3 kB).
 
REG/__pycache__/samplers.cpython-313.pyc ADDED
Binary file (31.6 kB).
 
REG/__pycache__/train.cpython-313.pyc ADDED
Binary file (33.4 kB).
 
REG/__pycache__/utils.cpython-312.pyc ADDED
Binary file (10.8 kB).
 
REG/models/__pycache__/mocov3_vit.cpython-310.pyc ADDED
Binary file (6.5 kB).
 
REG/models/__pycache__/mocov3_vit.cpython-312.pyc ADDED
Binary file (12.4 kB).
 
REG/models/__pycache__/sit.cpython-310.pyc ADDED
Binary file (13.4 kB).
 
REG/models/__pycache__/sit.cpython-312.pyc ADDED
Binary file (22.9 kB).
 
REG/preprocessing/README.md ADDED
@@ -0,0 +1,25 @@
+ <h1 align="center"> Preprocessing Guide
+ </h1>
+
+ #### Dataset download
+
+ We follow the preprocessing code used in [edm2](https://github.com/NVlabs/edm2), with several edits: (1) we removed everything unrelated to preprocessing, since this code is used only for preprocessing; (2) we feed the stable diffusion VAE inputs in the [-1, 1] range (as in DiT and SiT), unlike edm2, which uses the [0, 1] range; and (3) we preprocess to 256x256 resolution (or 512x512 resolution).
+
+ After downloading ImageNet, run the following scripts (update 256x256 to 512x512 if you want to run experiments at 512x512 resolution):
+
+ Convert raw ImageNet data to a ZIP archive at 256x256 resolution:
+ ```bash
+ bash dataset_prepare_convert.sh
+ ```
+
+ Convert the pixel data to VAE latents:
+
+ ```bash
+ bash dataset_prepare_encode.sh
+ ```
+
+ Here, `YOUR_DOWNLOAD_PATH` is the directory where you downloaded the dataset, and `TARGET_PATH` is the directory where the preprocessed images and the corresponding compressed latent vectors will be saved. This directory is used by your experiment scripts.
+
+ ## Acknowledgement
+
+ This code is mainly built upon the [edm2](https://github.com/NVlabs/edm2) repository.
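The README's edit (2) concerns only the input scaling convention for the VAE: images are mapped from uint8 [0, 255] to [-1, 1] rather than edm2's [0, 1]. A minimal sketch of that normalization (the helper name `to_vae_input` is ours, not from this repo):

```python
import numpy as np

def to_vae_input(img_uint8: np.ndarray) -> np.ndarray:
    """Map a uint8 image in [0, 255] to float32 in [-1, 1] (DiT/SiT convention)."""
    return img_uint8.astype(np.float32) / 127.5 - 1.0

# Pixel values 0, 128, 255 map to -1.0, ~0.004, 1.0 respectively.
img = np.array([[0, 128, 255]], dtype=np.uint8)
x = to_vae_input(img)
```

edm2's [0, 1] convention would instead divide by 255 without the shift; both carry the same information, but the VAE checkpoint must be fed the range it was trained with.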
REG/preprocessing/dataset_image_encoder.py ADDED
@@ -0,0 +1,353 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Tool for creating ZIP/PNG based datasets."""
+
+ from collections.abc import Iterator
+ from dataclasses import dataclass
+ import functools
+ import io
+ import json
+ import os
+ import re
+ import zipfile
+ from pathlib import Path
+ from typing import Callable, Optional, Tuple, Union
+ import click
+ import numpy as np
+ import PIL.Image
+ import torch
+ from tqdm import tqdm
+
+ from encoders import StabilityVAEEncoder
+ from utils import load_encoders
+ from torchvision.transforms import Normalize
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
+
+ def preprocess_raw_image(x, enc_type):
+     resolution = x.shape[-1]
+     if 'clip' in enc_type:
+         x = x / 255.
+         x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
+         x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
+     elif 'mocov3' in enc_type or 'mae' in enc_type:
+         x = x / 255.
+         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+     elif 'dinov2' in enc_type:
+         x = x / 255.
+         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+         x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
+     elif 'dinov1' in enc_type:
+         x = x / 255.
+         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+     elif 'jepa' in enc_type:
+         x = x / 255.
+         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+         x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
+
+     return x
+
+
+ #----------------------------------------------------------------------------
+
+ @dataclass
+ class ImageEntry:
+     img: np.ndarray
+     label: Optional[int]
+
+ #----------------------------------------------------------------------------
+ # Parse a 'M,N' or 'MxN' integer tuple.
+ # Example: '4x2' returns (4,2)
+
+ def parse_tuple(s: str) -> Tuple[int, int]:
+     m = re.match(r'^(\d+)[x,](\d+)$', s)
+     if m:
+         return int(m.group(1)), int(m.group(2))
+     raise click.ClickException(f'cannot parse tuple {s}')
+
+ #----------------------------------------------------------------------------
+
+ def maybe_min(a: int, b: Optional[int]) -> int:
+     if b is not None:
+         return min(a, b)
+     return a
+
+ #----------------------------------------------------------------------------
+
+ def file_ext(name: Union[str, Path]) -> str:
+     return str(name).split('.')[-1]
+
+ #----------------------------------------------------------------------------
+
+ def is_image_ext(fname: Union[str, Path]) -> bool:
+     ext = file_ext(fname).lower()
+     return f'.{ext}' in PIL.Image.EXTENSION
+
+ #----------------------------------------------------------------------------
+
+ def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
+     input_images = []
+     def _recurse_dirs(root: str): # workaround Path().rglob() slowness
+         with os.scandir(root) as it:
+             for e in it:
+                 if e.is_file():
+                     input_images.append(os.path.join(root, e.name))
+                 elif e.is_dir():
+                     _recurse_dirs(os.path.join(root, e.name))
+     _recurse_dirs(source_dir)
+     input_images = sorted([f for f in input_images if is_image_ext(f)])
+
+     arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
+     max_idx = maybe_min(len(input_images), max_images)
+
+     # Load labels.
+     labels = dict()
+     meta_fname = os.path.join(source_dir, 'dataset.json')
+     if os.path.isfile(meta_fname):
+         with open(meta_fname, 'r') as file:
+             data = json.load(file)['labels']
+             if data is not None:
+                 labels = {x[0]: x[1] for x in data}
+
+     # No labels available => determine from top-level directory names.
+     if len(labels) == 0:
+         toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
+         toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
+         if len(toplevel_indices) > 1:
+             labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}
+
+     def iterate_images():
+         for idx, fname in enumerate(input_images):
+             img = np.array(PIL.Image.open(fname).convert('RGB'))#.transpose(2, 0, 1)
+             yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
+             if idx >= max_idx - 1:
+                 break
+     return max_idx, iterate_images()
+
+ #----------------------------------------------------------------------------
+
+ def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
+     with zipfile.ZipFile(source, mode='r') as z:
+         input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
+         max_idx = maybe_min(len(input_images), max_images)
+
+         # Load labels.
+         labels = dict()
+         if 'dataset.json' in z.namelist():
+             with z.open('dataset.json', 'r') as file:
+                 data = json.load(file)['labels']
+                 if data is not None:
+                     labels = {x[0]: x[1] for x in data}
+
+     def iterate_images():
+         with zipfile.ZipFile(source, mode='r') as z:
+             for idx, fname in enumerate(input_images):
+                 with z.open(fname, 'r') as file:
+                     img = np.array(PIL.Image.open(file).convert('RGB'))
+                 yield ImageEntry(img=img, label=labels.get(fname))
+                 if idx >= max_idx - 1:
+                     break
+     return max_idx, iterate_images()
+
+ #----------------------------------------------------------------------------
+
+ def make_transform(
+     transform: Optional[str],
+     output_width: Optional[int],
+     output_height: Optional[int]
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
+     def scale(width, height, img):
+         w = img.shape[1]
+         h = img.shape[0]
+         if width == w and height == h:
+             return img
+         img = PIL.Image.fromarray(img, 'RGB')
+         ww = width if width is not None else w
+         hh = height if height is not None else h
+         img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
+         return np.array(img)
+
+     def center_crop(width, height, img):
+         crop = np.min(img.shape[:2])
+         img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
+         img = PIL.Image.fromarray(img, 'RGB')
+         img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
+         return np.array(img)
+
+     def center_crop_wide(width, height, img):
+         ch = int(np.round(width * img.shape[0] / img.shape[1]))
+         if img.shape[1] < width or ch < height:
+             return None
+
+         img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
+         img = PIL.Image.fromarray(img, 'RGB')
+         img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
+         img = np.array(img)
+
+         canvas = np.zeros([width, width, 3], dtype=np.uint8)
+         canvas[(width - height) // 2 : (width + height) // 2, :] = img
+         return canvas
+
+     def center_crop_imagenet(image_size: int, arr: np.ndarray):
+         """
+         Center cropping implementation from ADM.
+         https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+         """
+         pil_image = PIL.Image.fromarray(arr)
+         while min(*pil_image.size) >= 2 * image_size:
+             new_size = tuple(x // 2 for x in pil_image.size)
+             assert len(new_size) == 2
+             pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)
+
+         scale = image_size / min(*pil_image.size)
+         new_size = tuple(round(x * scale) for x in pil_image.size)
+         assert len(new_size) == 2
+         pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)
+
+         arr = np.array(pil_image)
+         crop_y = (arr.shape[0] - image_size) // 2
+         crop_x = (arr.shape[1] - image_size) // 2
+         return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
+
+     if transform is None:
+         return functools.partial(scale, output_width, output_height)
+     if transform == 'center-crop':
+         if output_width is None or output_height is None:
+             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
+         return functools.partial(center_crop, output_width, output_height)
+     if transform == 'center-crop-wide':
+         if output_width is None or output_height is None:
+             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
+         return functools.partial(center_crop_wide, output_width, output_height)
+     if transform == 'center-crop-dhariwal':
+         if output_width is None or output_height is None:
+             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
+         if output_width != output_height:
+             raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
+         return functools.partial(center_crop_imagenet, output_width)
+     assert False, 'unknown transform'
+
+ #----------------------------------------------------------------------------
+
+ def open_dataset(source, *, max_images: Optional[int]):
+     if os.path.isdir(source):
+         return open_image_folder(source, max_images=max_images)
+     elif os.path.isfile(source):
+         if file_ext(source) == 'zip':
+             return open_image_zip(source, max_images=max_images)
+         else:
+             raise click.ClickException(f'Only zip archives are supported: {source}')
+     else:
+         raise click.ClickException(f'Missing input file or directory: {source}')
+
+ #----------------------------------------------------------------------------
+
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
+     dest_ext = file_ext(dest)
+
+     if dest_ext == 'zip':
+         if os.path.dirname(dest) != '':
+             os.makedirs(os.path.dirname(dest), exist_ok=True)
+         zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
+         def zip_write_bytes(fname: str, data: Union[bytes, str]):
+             zf.writestr(fname, data)
+         return '', zip_write_bytes, zf.close
+     else:
+         # If the output folder already exists, check that it is
+         # empty.
+         #
+         # Note: creating the output directory is not strictly
+         # necessary as folder_write_bytes() also mkdirs, but it's better
+         # to give an error message earlier in case the dest folder
+         # somehow cannot be created.
+         if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
+             raise click.ClickException('--dest folder must be empty')
+         os.makedirs(dest, exist_ok=True)
+
+         def folder_write_bytes(fname: str, data: Union[bytes, str]):
+             os.makedirs(os.path.dirname(fname), exist_ok=True)
+             with open(fname, 'wb') as fout:
+                 if isinstance(data, str):
+                     data = data.encode('utf8')
+                 fout.write(data)
+         return dest, folder_write_bytes, lambda: None
+
+ #----------------------------------------------------------------------------
+
+ @click.group()
+ def cmdline():
+     '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
+     if os.environ.get('WORLD_SIZE', '1') != '1':
+         raise click.ClickException('Distributed execution is not supported.')
+
+ #----------------------------------------------------------------------------
+
+ @cmdline.command()
+ @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
+ @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
+ @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
+ @click.option('--enc-type', help='Feature encoder type', metavar='STR', type=str, default='dinov2-vit-b')
+ @click.option('--resolution', help='Input image resolution', metavar='INT', type=int, default=256)
+
+ def encode(
+     source: str,
+     dest: str,
+     max_images: Optional[int],
+     enc_type,
+     resolution
+ ):
+     """Encode pixel data to image-encoder features."""
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     encoder, encoder_type, architectures = load_encoders(enc_type, device, resolution)
+     encoder, encoder_type, architectures = encoder[0], encoder_type[0], architectures[0]
+     print('Encoder loaded.')
+
+     PIL.Image.init()
+     if dest == '':
+         raise click.ClickException('--dest output filename or directory must not be an empty string')
+
+     num_files, input_iter = open_dataset(source, max_images=max_images)
+     archive_root_dir, save_bytes, close_dest = open_dest(dest)
+     print('Dataset opened.')
+     labels = []
+
+     for idx, image in tqdm(enumerate(input_iter), total=num_files):
+         with torch.no_grad():
+             img_tensor = torch.tensor(image.img).to(device).permute(2, 0, 1).unsqueeze(0)
+             raw_image_ = preprocess_raw_image(img_tensor, encoder_type)
+             z = encoder.forward_features(raw_image_)
+             if 'dinov2' in encoder_type: z = z['x_norm_patchtokens']
+             z = z.detach().cpu().numpy()
+
+         idx_str = f'{idx:08d}'
+         archive_fname = f'{idx_str[:5]}/img-feature-{idx_str}.npy'
+
+         f = io.BytesIO()
+         np.save(f, z)
+         save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
+         labels.append([archive_fname, image.label] if image.label is not None else None)
+
+     metadata = {'labels': labels if all(x is not None for x in labels) else None}
+     save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
+     close_dest()
+
+ if __name__ == "__main__":
+     cmdline()
+
+ #----------------------------------------------------------------------------
REG/preprocessing/dataset_prepare_convert.sh ADDED
@@ -0,0 +1,11 @@
+ #256
+ python preprocessing/dataset_tools.py convert \
+     --source=/home/share/imagenet/train \
+     --dest=/home/share/imagenet_vae/imagenet_256_vae \
+     --resolution=256x256 \
+     --transform=center-crop-dhariwal
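The `--resolution=256x256` argument above is parsed by `parse_tuple()` in `dataset_tools.py`, which accepts either 'MxN' or 'M,N'. A minimal standalone sketch of that parsing logic (using `ValueError` in place of `click.ClickException` to keep the sketch dependency-free):

```python
import re

def parse_tuple(s: str) -> tuple[int, int]:
    # Accept 'MxN' or 'M,N', e.g. '256x256' -> (256, 256).
    m = re.match(r'^(\d+)[x,](\d+)$', s)
    if m:
        return int(m.group(1)), int(m.group(2))
    raise ValueError(f'cannot parse tuple {s}')

res = parse_tuple('256x256')
```

Note that the `center-crop-dhariwal` transform additionally requires the two parsed values to be equal, since the ADM-style crop is square.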
REG/preprocessing/dataset_prepare_encode.sh ADDED
@@ -0,0 +1,9 @@
+ #256
+ python preprocessing/dataset_tools.py encode \
+     --source=/home/share/imagenet_vae/imagenet_256_vae \
+     --dest=/home/share/imagenet_vae/vae-sd-256
REG/preprocessing/dataset_tools.py ADDED
@@ -0,0 +1,422 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Tool for creating ZIP/PNG based datasets."""
+
+ from collections.abc import Iterator
+ from dataclasses import dataclass
+ import functools
+ import io
+ import json
+ import os
+ import re
+ import zipfile
+ from pathlib import Path
+ from typing import Callable, Optional, Tuple, Union
+ import click
+ import numpy as np
+ import PIL.Image
+ import torch
+ from tqdm import tqdm
+
+ from encoders import StabilityVAEEncoder
+
+ #----------------------------------------------------------------------------
+
+ @dataclass
+ class ImageEntry:
+     img: np.ndarray
+     label: Optional[int]
+
+ #----------------------------------------------------------------------------
+ # Parse a 'M,N' or 'MxN' integer tuple.
+ # Example: '4x2' returns (4,2)
+
+ def parse_tuple(s: str) -> Tuple[int, int]:
+     m = re.match(r'^(\d+)[x,](\d+)$', s)
+     if m:
+         return int(m.group(1)), int(m.group(2))
+     raise click.ClickException(f'cannot parse tuple {s}')
+
+ #----------------------------------------------------------------------------
+
+ def maybe_min(a: int, b: Optional[int]) -> int:
+     if b is not None:
+         return min(a, b)
+     return a
+
+ #----------------------------------------------------------------------------
+
+ def file_ext(name: Union[str, Path]) -> str:
+     return str(name).split('.')[-1]
+
+ #----------------------------------------------------------------------------
+
+ def is_image_ext(fname: Union[str, Path]) -> bool:
+     ext = file_ext(fname).lower()
+     return f'.{ext}' in PIL.Image.EXTENSION
+
+ #----------------------------------------------------------------------------
+
+ def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
+     input_images = []
+     def _recurse_dirs(root: str): # workaround Path().rglob() slowness
+         with os.scandir(root) as it:
+             for e in it:
+                 if e.is_file():
+                     input_images.append(os.path.join(root, e.name))
+                 elif e.is_dir():
+                     _recurse_dirs(os.path.join(root, e.name))
+     _recurse_dirs(source_dir)
+     input_images = sorted([f for f in input_images if is_image_ext(f)])
+
+     arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
+     max_idx = maybe_min(len(input_images), max_images)
+
+     # Load labels.
+     labels = dict()
+     meta_fname = os.path.join(source_dir, 'dataset.json')
+     if os.path.isfile(meta_fname):
+         with open(meta_fname, 'r') as file:
+             data = json.load(file)['labels']
+             if data is not None:
+                 labels = {x[0]: x[1] for x in data}
+
+     # No labels available => determine from top-level directory names.
+     if len(labels) == 0:
+         toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
+         toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
+         if len(toplevel_indices) > 1:
+             labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}
+
+     def iterate_images():
+         for idx, fname in enumerate(input_images):
+             img = np.array(PIL.Image.open(fname).convert('RGB'))
+             yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
+             if idx >= max_idx - 1:
+                 break
+     return max_idx, iterate_images()
+
+ #----------------------------------------------------------------------------
+
+ def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
+     with zipfile.ZipFile(source, mode='r') as z:
+         input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
+         max_idx = maybe_min(len(input_images), max_images)
+
+         # Load labels.
+         labels = dict()
+         if 'dataset.json' in z.namelist():
+             with z.open('dataset.json', 'r') as file:
+                 data = json.load(file)['labels']
+                 if data is not None:
+                     labels = {x[0]: x[1] for x in data}
+
+     def iterate_images():
+         with zipfile.ZipFile(source, mode='r') as z:
+             for idx, fname in enumerate(input_images):
+                 with z.open(fname, 'r') as file:
+                     img = np.array(PIL.Image.open(file).convert('RGB'))
+                 yield ImageEntry(img=img, label=labels.get(fname))
+                 if idx >= max_idx - 1:
+                     break
+     return max_idx, iterate_images()
+
+ #----------------------------------------------------------------------------
+
+ def make_transform(
+     transform: Optional[str],
+     output_width: Optional[int],
+     output_height: Optional[int]
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
+     def scale(width, height, img):
+         w = img.shape[1]
+         h = img.shape[0]
+         if width == w and height == h:
+             return img
+         img = PIL.Image.fromarray(img, 'RGB')
+         ww = width if width is not None else w
+         hh = height if height is not None else h
+         img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
+         return np.array(img)
+
+     def center_crop(width, height, img):
+         crop = np.min(img.shape[:2])
+         img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
+         img = PIL.Image.fromarray(img, 'RGB')
+         img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
+         return np.array(img)
+
+     def center_crop_wide(width, height, img):
+         ch = int(np.round(width * img.shape[0] / img.shape[1]))
+         if img.shape[1] < width or ch < height:
+             return None
+
+         img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
+         img = PIL.Image.fromarray(img, 'RGB')
+         img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
+         img = np.array(img)
+
+         canvas = np.zeros([width, width, 3], dtype=np.uint8)
+         canvas[(width - height) // 2 : (width + height) // 2, :] = img
+         return canvas
+
+     def center_crop_imagenet(image_size: int, arr: np.ndarray):
+         """
+         Center cropping implementation from ADM.
+         https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+         """
+         pil_image = PIL.Image.fromarray(arr)
+         while min(*pil_image.size) >= 2 * image_size:
+             new_size = tuple(x // 2 for x in pil_image.size)
+             assert len(new_size) == 2
+             pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)
+
+         scale = image_size / min(*pil_image.size)
+         new_size = tuple(round(x * scale) for x in pil_image.size)
+         assert len(new_size) == 2
+         pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)
+
+         arr = np.array(pil_image)
+         crop_y = (arr.shape[0] - image_size) // 2
+         crop_x = (arr.shape[1] - image_size) // 2
+         return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
+
+     if transform is None:
+         return functools.partial(scale, output_width, output_height)
+     if transform == 'center-crop':
+         if output_width is None or output_height is None:
+             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
+         return functools.partial(center_crop, output_width, output_height)
+     if transform == 'center-crop-wide':
+         if output_width is None or output_height is None:
+             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
+         return functools.partial(center_crop_wide, output_width, output_height)
+     if transform == 'center-crop-dhariwal':
+         if output_width is None or output_height is None:
+             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
+         if output_width != output_height:
+             raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
+         return functools.partial(center_crop_imagenet, output_width)
+     assert False, 'unknown transform'
+
+ #----------------------------------------------------------------------------
+
+ def open_dataset(source, *, max_images: Optional[int]):
+     if os.path.isdir(source):
+         return open_image_folder(source, max_images=max_images)
+     elif os.path.isfile(source):
+         if file_ext(source) == 'zip':
+             return open_image_zip(source, max_images=max_images)
+         else:
+             raise click.ClickException(f'Only zip archives are supported: {source}')
+     else:
+         raise click.ClickException(f'Missing input file or directory: {source}')
+
+ #----------------------------------------------------------------------------
+
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
+     dest_ext = file_ext(dest)
+
+     if dest_ext == 'zip':
+         if os.path.dirname(dest) != '':
+             os.makedirs(os.path.dirname(dest), exist_ok=True)
+         zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
+         def zip_write_bytes(fname: str, data: Union[bytes, str]):
+             zf.writestr(fname, data)
+         return '', zip_write_bytes, zf.close
+     else:
233
+ # If the output folder already exists, check that is is
234
+ # empty.
235
+ #
236
+ # Note: creating the output directory is not strictly
237
+ # necessary as folder_write_bytes() also mkdirs, but it's better
238
+ # to give an error message earlier in case the dest folder
239
+ # somehow cannot be created.
240
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
241
+ raise click.ClickException('--dest folder must be empty')
242
+ os.makedirs(dest, exist_ok=True)
243
+
244
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
245
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
246
+ with open(fname, 'wb') as fout:
247
+ if isinstance(data, str):
248
+ data = data.encode('utf8')
249
+ fout.write(data)
250
+ return dest, folder_write_bytes, lambda: None
251
+
252
+ #----------------------------------------------------------------------------
253
+
254
+ @click.group()
255
+ def cmdline():
256
+ '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
257
+ if os.environ.get('WORLD_SIZE', '1') != '1':
258
+ raise click.ClickException('Distributed execution is not supported.')
259
+
260
+ #----------------------------------------------------------------------------
261
+
262
+ @cmdline.command()
263
+ @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
264
+ @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
265
+ @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
266
+ @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide', 'center-crop-dhariwal']))
267
+ @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
268
+
269
+ def convert(
270
+ source: str,
271
+ dest: str,
272
+ max_images: Optional[int],
273
+ transform: Optional[str],
274
+ resolution: Optional[Tuple[int, int]]
275
+ ):
276
+ """Convert an image dataset into archive format for training.
277
+
278
+ Specifying the input images:
279
+
280
+ \b
281
+ --source path/ Recursively load all images from path/
282
+ --source dataset.zip Load all images from dataset.zip
283
+
284
+ Specifying the output format and path:
285
+
286
+ \b
287
+ --dest /path/to/dir Save output files under /path/to/dir
288
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
289
+
290
+ The output dataset format can be either an image folder or an uncompressed zip archive.
291
+ Zip archives makes it easier to move datasets around file servers and clusters, and may
292
+ offer better training performance on network file systems.
293
+
294
+ Images within the dataset archive will be stored as uncompressed PNG.
295
+ Uncompresed PNGs can be efficiently decoded in the training loop.
296
+
297
+ Class labels are stored in a file called 'dataset.json' that is stored at the
298
+ dataset root folder. This file has the following structure:
299
+
300
+ \b
301
+ {
302
+ "labels": [
303
+ ["00000/img00000000.png",6],
304
+ ["00000/img00000001.png",9],
305
+ ... repeated for every image in the datase
306
+ ["00049/img00049999.png",1]
307
+ ]
308
+ }
309
+
310
+ If the 'dataset.json' file cannot be found, class labels are determined from
311
+ top-level directory names.
312
+
313
+ Image scale/crop and resolution requirements:
314
+
315
+ Output images must be square-shaped and they must all have the same power-of-two
316
+ dimensions.
317
+
318
+ To scale arbitrary input image size to a specific width and height, use the
319
+ --resolution option. Output resolution will be either the original
320
+ input resolution (if resolution was not specified) or the one specified with
321
+ --resolution option.
322
+
323
+ The --transform=center-crop-dhariwal selects a crop/rescale mode that is intended
324
+ to exactly match with results obtained for ImageNet in common diffusion model literature:
325
+
326
+ \b
327
+ python dataset_tool.py convert --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \\
328
+ --dest=datasets/img64.zip --resolution=64x64 --transform=center-crop-dhariwal
329
+ """
330
+ PIL.Image.init()
331
+ if dest == '':
332
+ raise click.ClickException('--dest output filename or directory must not be an empty string')
333
+ print("Begin!!!!!!!!")
334
+ num_files, input_iter = open_dataset(source, max_images=max_images)
335
+ print("open_dataset is over")
336
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
337
+ print("open_dest is over")
338
+ transform_image = make_transform(transform, *resolution if resolution is not None else (None, None))
339
+ dataset_attrs = None
340
+
341
+ labels = []
342
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
343
+ idx_str = f'{idx:08d}'
344
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
345
+
346
+ # Apply crop and resize.
347
+ img = transform_image(image.img)
348
+ if img is None:
349
+ continue
350
+
351
+ # Error check to require uniform image attributes across
352
+ # the whole dataset.
353
+ assert img.ndim == 3
354
+ cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0]}
355
+ if dataset_attrs is None:
356
+ dataset_attrs = cur_image_attrs
357
+ width = dataset_attrs['width']
358
+ height = dataset_attrs['height']
359
+ if width != height:
360
+ raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
361
+ if width != 2 ** int(np.floor(np.log2(width))):
362
+ raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
363
+ elif dataset_attrs != cur_image_attrs:
364
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
365
+ raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
366
+
367
+ # Save the image as an uncompressed PNG.
368
+ img = PIL.Image.fromarray(img)
369
+ image_bits = io.BytesIO()
370
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
371
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
372
+ labels.append([archive_fname, image.label] if image.label is not None else None)
373
+
374
+ metadata = {'labels': labels if all(x is not None for x in labels) else None}
375
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
376
+ close_dest()
377
+
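Both `convert` above and `encode` below shard their outputs with the same naming scheme: an eight-digit zero-padded index whose first five digits name the subfolder, so each folder holds at most 1000 files (only the last three digits vary within a folder). A small standalone restatement:

```python
def archive_name(idx: int) -> str:
    # Eight-digit zero-padded index; the first five digits select the
    # subfolder, matching the f'{idx_str[:5]}/img{idx_str}.png' scheme above.
    idx_str = f'{idx:08d}'
    return f'{idx_str[:5]}/img{idx_str}.png'
```

For example, index 49999 lands in folder `00049`, consistent with the `00049/img00049999.png` entry shown in the dataset.json example of the docstring.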
+ #----------------------------------------------------------------------------
+
+ @cmdline.command()
+ @click.option('--model-url', help='VAE encoder model', metavar='URL', type=str, default='stabilityai/sd-vae-ft-mse', show_default=True)
+ @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
+ @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
+ @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
+ def encode(
+ model_url: str,
+ source: str,
+ dest: str,
+ max_images: Optional[int],
+ ):
+ """Encode pixel data to VAE latents."""
+ PIL.Image.init()
+ if dest == '':
+ raise click.ClickException('--dest output filename or directory must not be an empty string')
+
+ vae = StabilityVAEEncoder(vae_name=model_url, batch_size=1)
+ num_files, input_iter = open_dataset(source, max_images=max_images)
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
+ labels = []
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
+ img_tensor = torch.tensor(image.img).to('cuda').permute(2, 0, 1).unsqueeze(0)
+ mean_std = vae.encode_pixels(img_tensor)[0].cpu()
+ idx_str = f'{idx:08d}'
+ archive_fname = f'{idx_str[:5]}/img-mean-std-{idx_str}.npy'
+
+ f = io.BytesIO()
+ np.save(f, mean_std)
+ save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
+ labels.append([archive_fname, image.label] if image.label is not None else None)
+
+ metadata = {'labels': labels if all(x is not None for x in labels) else None}
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
+ close_dest()
+
+ if __name__ == "__main__":
+ cmdline()
+
+ #----------------------------------------------------------------------------
REG/preprocessing/dnnlib/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ from .util import EasyDict, make_cache_dir_path
REG/preprocessing/dnnlib/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (291 Bytes).
 
REG/preprocessing/dnnlib/__pycache__/util.cpython-312.pyc ADDED
Binary file (22.6 kB).
 
REG/preprocessing/dnnlib/util.py ADDED
@@ -0,0 +1,485 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Miscellaneous utility classes and functions."""
+
+ import ctypes
+ import fnmatch
+ import importlib
+ import inspect
+ import numpy as np
+ import os
+ import shutil
+ import sys
+ import types
+ import io
+ import pickle
+ import re
+ import requests
+ import html
+ import hashlib
+ import glob
+ import tempfile
+ import urllib
+ import urllib.parse
+ import uuid
+
+ from typing import Any, Callable, BinaryIO, List, Tuple, Union, Optional
+
+ # Util classes
+ # ------------------------------------------------------------------------------------------
+
+ class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
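For reference, `EasyDict` as defined above accepts both dict-style and attribute-style access interchangeably. A self-contained restatement with a quick check:

```python
class EasyDict(dict):
    """Dict subclass exposing keys as attributes (restatement of the class above)."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)
    def __setattr__(self, name, value):
        self[name] = value
    def __delattr__(self, name):
        del self[name]

cfg = EasyDict(lr=0.001)   # keyword args populate the underlying dict
cfg.batch = 32             # attribute write goes through __setattr__ into the dict
```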
+ class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+ # Cache directories
+ # ------------------------------------------------------------------------------------------
+
+ _dnnlib_cache_dir = None
+
+ def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+ def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
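The lookup order in `make_cache_dir_path` above is: explicit override set via `set_cache_dir`, then the `DNNLIB_CACHE_DIR` environment variable, then `HOME`/`USERPROFILE`, and finally the system temp dir. A standalone sketch of that precedence chain (the `cache_dir_for` helper below is illustrative only, not part of the module):

```python
import os
import tempfile

def cache_dir_for(*paths, override=None):
    # Mirrors the precedence in make_cache_dir_path: explicit override first,
    # then DNNLIB_CACHE_DIR, then HOME/USERPROFILE (+ .cache/dnnlib), then tempdir.
    if override is not None:
        return os.path.join(override, *paths)
    for var, sub in (('DNNLIB_CACHE_DIR', ()),
                     ('HOME', ('.cache', 'dnnlib')),
                     ('USERPROFILE', ('.cache', 'dnnlib'))):
        if var in os.environ:
            return os.path.join(os.environ[var], *sub, *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
```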
+ # Small util functions
+ # ------------------------------------------------------------------------------------------
+
+
+ def format_time(seconds: Union[int, float]) -> str:
+ """Convert seconds to a human-readable string with days, hours, minutes, and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
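The tiered formatting in `format_time` above can be restated standalone to show the expected outputs (plain `round()` is used here in place of `np.rint`, which agrees for these inputs):

```python
def format_time(seconds):
    # Pick the coarsest unit tier that fits, then format the two or three
    # most significant components, zero-padding the smaller ones.
    s = int(round(seconds))
    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
```

Note that the day tier drops the seconds component entirely, trading precision for brevity.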
+
+ def format_time_brief(seconds: Union[int, float]) -> str:
+ """Convert seconds to a brief human-readable string with at most two components."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+ else:
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+ def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+ _str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+ }
+
+
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+ def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+ # Functionality to import modules/objects by name, and call functions by name
+ # ------------------------------------------------------------------------------------------
+
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
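The search strategy in `get_module_from_obj_name` above — try the longest importable module prefix first, treating the remaining dotted components as an attribute path — can be sketched compactly (the `resolve` helper below is illustrative only, without the np/tf shorthands or the error-diagnosis passes):

```python
import importlib

def resolve(name: str):
    # Try 'a.b.c' as a module, then 'a.b' with attribute 'c', then 'a'
    # with attribute path 'b.c', mirroring the name_pairs loop above.
    parts = name.split('.')
    for i in range(len(parts), 0, -1):
        module_name, attr_path = '.'.join(parts[:i]), parts[i:]
        try:
            obj = importlib.import_module(module_name)
            for part in attr_path:
                obj = getattr(obj, part)
            return obj
        except (ImportError, AttributeError):
            continue
    raise ImportError(name)
```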
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+ def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+ def call_func_by_name(*args, func_name: Union[str, Callable], **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name) if isinstance(func_name, str) else func_name
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+ def construct_class_by_name(*args, class_name: Union[str, type], **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+ def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+ def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ fname = sys.modules[module].__file__
+ assert fname is not None
+ module = os.path.splitext(os.path.basename(fname))[0]
+ return module + "." + obj.__name__
+
+
+ # File system helpers
+ # ------------------------------------------------------------------------------------------
+
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: Optional[List[str]] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ os.makedirs(target_dir_name, exist_ok=True)
+ shutil.copyfile(file[0], file[1])
+
+
+ # URL helpers
+ # ------------------------------------------------------------------------------------------
+
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = urllib.parse.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+ # Note on static typing: a better API would be to split 'open_url' into 'open_url' and
+ # 'download_url' with separate return types (BinaryIO, str). As the `return_filename=True`
+ # case is somewhat uncommon, we just pretend like this function never returns a string
+ # and type-ignore the return value for those cases.
+ def open_url(url: str, cache_dir: Optional[str] = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> BinaryIO:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like a URL scheme, so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb") # type: ignore
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname()
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb") # type: ignore
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb") # type: ignore
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = urllib.parse.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ assert url_data is not None
+
+ # Save to cache.
+ if cache:
+ assert url_name is not None
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ safe_name = safe_name[:min(len(safe_name), 128)]
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file # type: ignore
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
REG/preprocessing/encoders.py ADDED
@@ -0,0 +1,103 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Converting between pixel and latent representations of image data."""
+
+ import os
+ import warnings
+ import numpy as np
+ import torch
+ from torch_utils import persistence
+ from torch_utils import misc
+
+ warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
+ warnings.filterwarnings('ignore', '`resume_download` is deprecated')
+
+ #----------------------------------------------------------------------------
+ # Abstract base class for encoders/decoders that convert back and forth
+ # between pixel and latent representations of image data.
+ #
+ # Logically, "raw pixels" are first encoded into "raw latents" that are
+ # then further encoded into "final latents". Decoding, on the other hand,
+ # goes directly from the final latents to raw pixels. The final latents are
+ # used as inputs and outputs of the model, whereas the raw latents are
+ # stored in the dataset. This separation provides added flexibility in terms
+ # of performing just-in-time adjustments, such as data whitening, without
+ # having to construct a new dataset.
+ #
+ # All image data is represented as PyTorch tensors in NCHW order.
+ # Raw pixels are represented as 3-channel uint8.
+
+ @persistence.persistent_class
+ class Encoder:
+ def __init__(self):
+ pass
+
+ def init(self, device): # force lazy init to happen now
+ pass
+
+ def __getstate__(self):
+ return self.__dict__
+
+ def encode_pixels(self, x): # raw pixels => raw latents
+ raise NotImplementedError # to be overridden by subclass
+
+ #----------------------------------------------------------------------------
+ # Pre-trained VAE encoder from Stability AI.
+
+ @persistence.persistent_class
+ class StabilityVAEEncoder(Encoder):
+ def __init__(self,
+ vae_name = 'stabilityai/sd-vae-ft-mse', # Name of the VAE to use.
+ batch_size = 8, # Batch size to use when running the VAE.
+ ):
+ super().__init__()
+ self.vae_name = vae_name
+ self.batch_size = int(batch_size)
+ self._vae = None
+
+ def init(self, device): # force lazy init to happen now
+ super().init(device)
+ if self._vae is None:
+ self._vae = load_stability_vae(self.vae_name, device=device)
+ else:
+ self._vae.to(device)
+
+ def __getstate__(self):
+ return dict(super().__getstate__(), _vae=None) # do not pickle the vae
+
+ def _run_vae_encoder(self, x):
+ d = self._vae.encode(x)['latent_dist']
+ return torch.cat([d.mean, d.std], dim=1)
+
+ def encode_pixels(self, x): # raw pixels => raw latents
+ self.init(x.device)
+ x = x.to(torch.float32) / 127.5 - 1
+ x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)])
+ return x
+
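`encode_pixels` above first maps uint8 pixels into the [-1, 1] range the VAE expects. The normalization step can be shown in isolation (a pure-Python restatement of the `x / 127.5 - 1` mapping; the actual code applies it elementwise to a float32 tensor):

```python
def normalize_pixel(v: int) -> float:
    # Same mapping as encode_pixels: uint8 [0, 255] -> float [-1, 1].
    # 0 maps to -1.0 exactly and 255 maps to 1.0 exactly (255 / 127.5 == 2.0).
    return v / 127.5 - 1
```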
82
+ #----------------------------------------------------------------------------
83
+
84
+ def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')):
85
+ import dnnlib
86
+ cache_dir = dnnlib.make_cache_dir_path('diffusers')
87
+ os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
88
+ os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
89
+ os.environ['HF_HOME'] = cache_dir
90
+
91
+
92
+ import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
93
+ try:
94
+ # First try with local_files_only to avoid consulting tfhub metadata if the model is already in cache.
95
+ vae = diffusers.models.AutoencoderKL.from_pretrained(
96
+ vae_name, cache_dir=cache_dir, local_files_only=True
97
+ )
98
+ except:
99
+ # Could not load the model from cache; try without local_files_only.
100
+ vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir)
101
+ return vae.eval().requires_grad_(False).to(device)
102
+
103
+ #----------------------------------------------------------------------------
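The pixel normalization performed by `encode_pixels()` maps 3-channel uint8 values in [0, 255] to floats in [-1, 1] via `x / 127.5 - 1` before the VAE runs. The arithmetic can be sketched in plain Python (the helper names below are illustrative, not part of the codebase):

```python
def normalize_pixels(values):
    """Map uint8 pixel values (0..255) to the [-1, 1] range expected by the VAE."""
    return [v / 127.5 - 1.0 for v in values]

def denormalize_pixels(values):
    """Inverse mapping, rounded and clamped back to the 0..255 uint8 range."""
    return [min(255, max(0, round((v + 1.0) * 127.5))) for v in values]

# Endpoints land exactly on -1 and +1; 127 lands just below 0.
print(normalize_pixels([0, 127, 255]))
print(denormalize_pixels([-1.0, 0.0, 1.0]))
```

Note that 127.5 (not 128) is used so that 0 and 255 map symmetrically to the interval endpoints.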
REG/preprocessing/torch_utils/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ # empty
REG/preprocessing/torch_utils/distributed.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ import os
+ import re
+ import socket
+ import torch
+ import torch.distributed
+ from . import training_stats
+
+ _sync_device = None
+
+ #----------------------------------------------------------------------------
+
+ def init():
+     global _sync_device
+
+     if not torch.distributed.is_initialized():
+         # Set up some reasonable defaults for env-based distributed init if
+         # not set by the running environment.
+         if 'MASTER_ADDR' not in os.environ:
+             os.environ['MASTER_ADDR'] = 'localhost'
+         if 'MASTER_PORT' not in os.environ:
+             s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+             s.bind(('', 0))
+             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+             os.environ['MASTER_PORT'] = str(s.getsockname()[1])
+             s.close()
+         if 'RANK' not in os.environ:
+             os.environ['RANK'] = '0'
+         if 'LOCAL_RANK' not in os.environ:
+             os.environ['LOCAL_RANK'] = '0'
+         if 'WORLD_SIZE' not in os.environ:
+             os.environ['WORLD_SIZE'] = '1'
+         backend = 'gloo' if os.name == 'nt' else 'nccl'
+         torch.distributed.init_process_group(backend=backend, init_method='env://')
+         torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
+
+     _sync_device = torch.device('cuda') if get_world_size() > 1 else None
+     training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
+
+ #----------------------------------------------------------------------------
+
+ def get_rank():
+     return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+
+ #----------------------------------------------------------------------------
+
+ def get_world_size():
+     return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+
+ #----------------------------------------------------------------------------
+
+ def should_stop():
+     return False
+
+ #----------------------------------------------------------------------------
+
+ def should_suspend():
+     return False
+
+ #----------------------------------------------------------------------------
+
+ def request_suspend():
+     pass
+
+ #----------------------------------------------------------------------------
+
+ def update_progress(cur, total):
+     pass
+
+ #----------------------------------------------------------------------------
+
+ def print0(*args, **kwargs):
+     if get_rank() == 0:
+         print(*args, **kwargs)
+
+ #----------------------------------------------------------------------------
+
+ class CheckpointIO:
+     def __init__(self, **kwargs):
+         self._state_objs = kwargs
+
+     def save(self, pt_path, verbose=True):
+         if verbose:
+             print0(f'Saving {pt_path} ... ', end='', flush=True)
+         data = dict()
+         for name, obj in self._state_objs.items():
+             if obj is None:
+                 data[name] = None
+             elif isinstance(obj, dict):
+                 data[name] = obj
+             elif hasattr(obj, 'state_dict'):
+                 data[name] = obj.state_dict()
+             elif hasattr(obj, '__getstate__'):
+                 data[name] = obj.__getstate__()
+             elif hasattr(obj, '__dict__'):
+                 data[name] = obj.__dict__
+             else:
+                 raise ValueError(f'Invalid state object of type {type(obj).__name__}')
+         if get_rank() == 0:
+             torch.save(data, pt_path)
+         if verbose:
+             print0('done')
+
+     def load(self, pt_path, verbose=True):
+         if verbose:
+             print0(f'Loading {pt_path} ... ', end='', flush=True)
+         data = torch.load(pt_path, map_location=torch.device('cpu'))
+         for name, obj in self._state_objs.items():
+             if obj is None:
+                 pass
+             elif isinstance(obj, dict):
+                 obj.clear()
+                 obj.update(data[name])
+             elif hasattr(obj, 'load_state_dict'):
+                 obj.load_state_dict(data[name])
+             elif hasattr(obj, '__setstate__'):
+                 obj.__setstate__(data[name])
+             elif hasattr(obj, '__dict__'):
+                 obj.__dict__.clear()
+                 obj.__dict__.update(data[name])
+             else:
+                 raise ValueError(f'Invalid state object of type {type(obj).__name__}')
+         if verbose:
+             print0('done')
+
+     def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
+         fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
+         if len(fnames) == 0:
+             return None
+         pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1))))
+         self.load(pt_path, verbose=verbose)
+         return pt_path
+
+ #----------------------------------------------------------------------------
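`CheckpointIO.load_latest()` resumes from the checkpoint whose captured numeric suffix is largest. The selection step can be exercised in isolation with the standard library only (the file names and the `pick_latest` helper below are illustrative, not part of the codebase):

```python
import re

def pick_latest(fnames, pattern=r'training-state-(\d+).pt'):
    """Return the file name whose captured number is largest, or None.

    Mirrors the selection step of CheckpointIO.load_latest(); the directory
    scan and the actual torch.load() are omitted.
    """
    matches = [f for f in fnames if re.fullmatch(pattern, f)]
    if not matches:
        return None
    return max(matches, key=lambda f: float(re.fullmatch(pattern, f).group(1)))

print(pick_latest(['training-state-0005000.pt', 'training-state-0010000.pt', 'log.txt']))
```

Comparing the captured group numerically (rather than lexicographically) keeps the result correct even if checkpoints are zero-padded to different widths.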
REG/preprocessing/torch_utils/misc.py ADDED
@@ -0,0 +1,277 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ import re
+ import contextlib
+ import functools
+ import numpy as np
+ import torch
+ import warnings
+ import dnnlib
+
+ #----------------------------------------------------------------------------
+ # Re-seed torch & numpy random generators based on the given arguments.
+
+ def set_random_seed(*args):
+     seed = hash(args) % (1 << 31)
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+
+ #----------------------------------------------------------------------------
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+ # same constant is used multiple times.
+
+ _constant_cache = dict()
+
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+     value = np.asarray(value)
+     if shape is not None:
+         shape = tuple(shape)
+     if dtype is None:
+         dtype = torch.get_default_dtype()
+     if device is None:
+         device = torch.device('cpu')
+     if memory_format is None:
+         memory_format = torch.contiguous_format
+
+     key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+     tensor = _constant_cache.get(key, None)
+     if tensor is None:
+         tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+         if shape is not None:
+             tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+         tensor = tensor.contiguous(memory_format=memory_format)
+         _constant_cache[key] = tensor
+     return tensor
+
+ #----------------------------------------------------------------------------
+ # Variant of constant() that inherits dtype and device from the given
+ # reference tensor by default.
+
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
+     if dtype is None:
+         dtype = ref.dtype
+     if device is None:
+         device = ref.device
+     return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
+
+ #----------------------------------------------------------------------------
+ # Cached construction of temporary tensors in pinned CPU memory.
+
+ @functools.lru_cache(None)
+ def pinned_buf(shape, dtype):
+     return torch.empty(shape, dtype=dtype).pin_memory()
+
+ #----------------------------------------------------------------------------
+ # Symbolic assert.
+
+ try:
+     symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+ except AttributeError:
+     symbolic_assert = torch.Assert # 1.7.0
+
+ #----------------------------------------------------------------------------
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+ @contextlib.contextmanager
+ def suppress_tracer_warnings():
+     flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+     warnings.filters.insert(0, flt)
+     yield
+     warnings.filters.remove(flt)
+
+ #----------------------------------------------------------------------------
+ # Assert that the shape of a tensor matches the given list of integers.
+ # None indicates that the size of a dimension is allowed to vary.
+ # Performs symbolic assertion when used in torch.jit.trace().
+
+ def assert_shape(tensor, ref_shape):
+     if tensor.ndim != len(ref_shape):
+         raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+     for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+         if ref_size is None:
+             pass
+         elif isinstance(ref_size, torch.Tensor):
+             with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                 symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+         elif isinstance(size, torch.Tensor):
+             with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                 symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+         elif size != ref_size:
+             raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+ #----------------------------------------------------------------------------
+ # Function decorator that calls torch.autograd.profiler.record_function().
+
+ def profiled_function(fn):
+     def decorator(*args, **kwargs):
+         with torch.autograd.profiler.record_function(fn.__name__):
+             return fn(*args, **kwargs)
+     decorator.__name__ = fn.__name__
+     return decorator
+
+ #----------------------------------------------------------------------------
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
+ # indefinitely, shuffling items as it goes.
+
+ class InfiniteSampler(torch.utils.data.Sampler):
+     def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, start_idx=0):
+         assert len(dataset) > 0
+         assert num_replicas > 0
+         assert 0 <= rank < num_replicas
+         warnings.filterwarnings('ignore', '`data_source` argument is not used and will be removed')
+         super().__init__(dataset)
+         self.dataset_size = len(dataset)
+         self.start_idx = start_idx + rank
+         self.stride = num_replicas
+         self.shuffle = shuffle
+         self.seed = seed
+
+     def __iter__(self):
+         idx = self.start_idx
+         epoch = None
+         while True:
+             if epoch != idx // self.dataset_size:
+                 epoch = idx // self.dataset_size
+                 order = np.arange(self.dataset_size)
+                 if self.shuffle:
+                     np.random.RandomState(hash((self.seed, epoch)) % (1 << 31)).shuffle(order)
+             yield int(order[idx % self.dataset_size])
+             idx += self.stride
+
+ #----------------------------------------------------------------------------
+ # Utilities for operating with torch.nn.Module parameters and buffers.
+
+ def params_and_buffers(module):
+     assert isinstance(module, torch.nn.Module)
+     return list(module.parameters()) + list(module.buffers())
+
+ def named_params_and_buffers(module):
+     assert isinstance(module, torch.nn.Module)
+     return list(module.named_parameters()) + list(module.named_buffers())
+
+ @torch.no_grad()
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
+     assert isinstance(src_module, torch.nn.Module)
+     assert isinstance(dst_module, torch.nn.Module)
+     src_tensors = dict(named_params_and_buffers(src_module))
+     for name, tensor in named_params_and_buffers(dst_module):
+         assert (name in src_tensors) or (not require_all)
+         if name in src_tensors:
+             tensor.copy_(src_tensors[name])
+
+ #----------------------------------------------------------------------------
+ # Context manager for easily enabling/disabling DistributedDataParallel
+ # synchronization.
+
+ @contextlib.contextmanager
+ def ddp_sync(module, sync):
+     assert isinstance(module, torch.nn.Module)
+     if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+         yield
+     else:
+         with module.no_sync():
+             yield
+
+ #----------------------------------------------------------------------------
+ # Check DistributedDataParallel consistency across processes.
+
+ def check_ddp_consistency(module, ignore_regex=None):
+     assert isinstance(module, torch.nn.Module)
+     for name, tensor in named_params_and_buffers(module):
+         fullname = type(module).__name__ + '.' + name
+         if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+             continue
+         tensor = tensor.detach()
+         if tensor.is_floating_point():
+             tensor = torch.nan_to_num(tensor)
+         other = tensor.clone()
+         torch.distributed.broadcast(tensor=other, src=0)
+         assert (tensor == other).all(), fullname
+
+ #----------------------------------------------------------------------------
+ # Print summary table of module hierarchy.
+
+ @torch.no_grad()
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+     assert isinstance(module, torch.nn.Module)
+     assert not isinstance(module, torch.jit.ScriptModule)
+     assert isinstance(inputs, (tuple, list))
+
+     # Register hooks.
+     entries = []
+     nesting = [0]
+     def pre_hook(_mod, _inputs):
+         nesting[0] += 1
+     def post_hook(mod, _inputs, outputs):
+         nesting[0] -= 1
+         if nesting[0] <= max_nesting:
+             outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+             outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+             entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+     hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+     hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+     # Run module.
+     outputs = module(*inputs)
+     for hook in hooks:
+         hook.remove()
+
+     # Identify unique outputs, parameters, and buffers.
+     tensors_seen = set()
+     for e in entries:
+         e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+         e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+         e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+         tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+     # Filter out redundant entries.
+     if skip_redundant:
+         entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+     # Construct table.
+     rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+     rows += [['---'] * len(rows[0])]
+     param_total = 0
+     buffer_total = 0
+     submodule_names = {mod: name for name, mod in module.named_modules()}
+     for e in entries:
+         name = '<top-level>' if e.mod is module else submodule_names[e.mod]
+         param_size = sum(t.numel() for t in e.unique_params)
+         buffer_size = sum(t.numel() for t in e.unique_buffers)
+         output_shapes = [str(list(t.shape)) for t in e.outputs]
+         output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+         rows += [[
+             name + (':0' if len(e.outputs) >= 2 else ''),
+             str(param_size) if param_size else '-',
+             str(buffer_size) if buffer_size else '-',
+             (output_shapes + ['-'])[0],
+             (output_dtypes + ['-'])[0],
+         ]]
+         for idx in range(1, len(e.outputs)):
+             rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+         param_total += param_size
+         buffer_total += buffer_size
+     rows += [['---'] * len(rows[0])]
+     rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+     # Print table.
+     widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+     print()
+     for row in rows:
+         print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+     print()
+
+ #----------------------------------------------------------------------------
+ # Tile a batch of images into a 2D grid.
+
+ def tile_images(x, w, h):
+     assert x.ndim == 4 # NCHW => CHW
+     return x.reshape(h, w, *x.shape[1:]).permute(2, 0, 3, 1, 4).reshape(x.shape[1], h * x.shape[2], w * x.shape[3])
+
+ #----------------------------------------------------------------------------
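`InfiniteSampler` deals out indices round-robin across replicas, starting at `start_idx + rank` and advancing by `num_replicas`, with a reshuffle once per pass over the dataset. The rank/stride arithmetic can be sketched in pure Python (the `infinite_indices` helper is illustrative; the numpy-based per-epoch shuffle is deliberately omitted):

```python
from itertools import islice

def infinite_indices(dataset_size, rank=0, num_replicas=1, start_idx=0):
    """Yield dataset indices forever, strided across replicas.

    Mirrors the idx/stride bookkeeping of InfiniteSampler with shuffle=False.
    """
    idx = start_idx + rank
    while True:
        yield idx % dataset_size   # wrap around at the end of each epoch
        idx += num_replicas        # skip indices owned by other replicas

# Two replicas over a 5-item dataset: each rank sees a disjoint slice
# of the infinite index stream.
r0 = list(islice(infinite_indices(5, rank=0, num_replicas=2), 6))
r1 = list(islice(infinite_indices(5, rank=1, num_replicas=2), 6))
print(r0)
print(r1)
```

Because the stride equals the replica count, every index of the underlying stream is visited exactly once per cycle across all ranks, which is what makes the sampler safe for DistributedDataParallel training.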
REG/preprocessing/torch_utils/persistence.py ADDED
@@ -0,0 +1,257 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Facilities for pickling Python code alongside other data.
+
+ The pickled code is automatically imported into a separate Python module
+ during unpickling. This way, any previously exported pickles will remain
+ usable even if the original code is no longer available, or if the current
+ version of the code is not consistent with what was originally pickled."""
+
+ import sys
+ import pickle
+ import io
+ import inspect
+ import copy
+ import uuid
+ import types
+ import functools
+ import dnnlib
+
+ #----------------------------------------------------------------------------
+
+ _version            = 6      # internal version number
+ _decorators         = set()  # {decorator_class, ...}
+ _import_hooks       = []     # [hook_function, ...]
+ _module_to_src_dict = dict() # {module: src, ...}
+ _src_to_module_dict = dict() # {src: module, ...}
+
+ #----------------------------------------------------------------------------
+
+ def persistent_class(orig_class):
+     r"""Class decorator that extends a given class to save its source code
+     when pickled.
+
+     Example:
+
+         from torch_utils import persistence
+
+         @persistence.persistent_class
+         class MyNetwork(torch.nn.Module):
+             def __init__(self, num_inputs, num_outputs):
+                 super().__init__()
+                 self.fc = MyLayer(num_inputs, num_outputs)
+             ...
+
+         @persistence.persistent_class
+         class MyLayer(torch.nn.Module):
+             ...
+
+     When pickled, any instance of `MyNetwork` and `MyLayer` will save its
+     source code alongside other internal state (e.g., parameters, buffers,
+     and submodules). This way, any previously exported pickle will remain
+     usable even if the class definitions have been modified or are no
+     longer available.
+
+     The decorator saves the source code of the entire Python module
+     containing the decorated class. It does *not* save the source code of
+     any imported modules. Thus, the imported modules must be available
+     during unpickling, also including `torch_utils.persistence` itself.
+
+     It is ok to call functions defined in the same module from the
+     decorated class. However, if the decorated class depends on other
+     classes defined in the same module, they must be decorated as well.
+     This is illustrated in the above example in the case of `MyLayer`.
+
+     It is also possible to employ the decorator just-in-time before
+     calling the constructor. For example:
+
+         cls = MyLayer
+         if want_to_make_it_persistent:
+             cls = persistence.persistent_class(cls)
+         layer = cls(num_inputs, num_outputs)
+
+     As an additional feature, the decorator also keeps track of the
+     arguments that were used to construct each instance of the decorated
+     class. The arguments can be queried via `obj.init_args` and
+     `obj.init_kwargs`, and they are automatically pickled alongside other
+     object state. This feature can be disabled on a per-instance basis
+     by setting `self._record_init_args = False` in the constructor.
+
+     A typical use case is to first unpickle a previous instance of a
+     persistent class, and then upgrade it to use the latest version of
+     the source code:
+
+         with open('old_pickle.pkl', 'rb') as f:
+             old_net = pickle.load(f)
+         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
+         misc.copy_params_and_buffers(old_net, new_net, require_all=True)
+     """
+     assert isinstance(orig_class, type)
+     if is_persistent(orig_class):
+         return orig_class
+
+     assert orig_class.__module__ in sys.modules
+     orig_module = sys.modules[orig_class.__module__]
+     orig_module_src = _module_to_src(orig_module)
+
+     @functools.wraps(orig_class, updated=())
+     class Decorator(orig_class):
+         _orig_module_src = orig_module_src
+         _orig_class_name = orig_class.__name__
+
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+             record_init_args = getattr(self, '_record_init_args', True)
+             self._init_args = copy.deepcopy(args) if record_init_args else None
+             self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
+             assert orig_class.__name__ in orig_module.__dict__
+             _check_pickleable(self.__reduce__())
+
+         @property
+         def init_args(self):
+             assert self._init_args is not None
+             return copy.deepcopy(self._init_args)
+
+         @property
+         def init_kwargs(self):
+             assert self._init_kwargs is not None
+             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
+
+         def __reduce__(self):
+             fields = list(super().__reduce__())
+             fields += [None] * max(3 - len(fields), 0)
+             if fields[0] is not _reconstruct_persistent_obj:
+                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
+                 fields[0] = _reconstruct_persistent_obj # reconstruct func
+                 fields[1] = (meta,) # reconstruct args
+                 fields[2] = None # state dict
+             return tuple(fields)
+
+     _decorators.add(Decorator)
+     return Decorator
+
+ #----------------------------------------------------------------------------
+
+ def is_persistent(obj):
+     r"""Test whether the given object or class is persistent, i.e.,
+     whether it will save its source code when pickled.
+     """
+     try:
+         if obj in _decorators:
+             return True
+     except TypeError:
+         pass
+     return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+
+ #----------------------------------------------------------------------------
+
+ def import_hook(hook):
+     r"""Register an import hook that is called whenever a persistent object
+     is being unpickled. A typical use case is to patch the pickled source
+     code to avoid errors and inconsistencies when the API of some imported
+     module has changed.
+
+     The hook should have the following signature:
+
+         hook(meta) -> modified meta
+
+     `meta` is an instance of `dnnlib.EasyDict` with the following fields:
+
+         type:        Type of the persistent object, e.g. `'class'`.
+         version:     Internal version number of `torch_utils.persistence`.
+         module_src:  Original source code of the Python module.
+         class_name:  Class name in the original Python module.
+         state:       Internal state of the object.
+
+     Example:
+
+         @persistence.import_hook
+         def wreck_my_network(meta):
+             if meta.class_name == 'MyNetwork':
+                 print('MyNetwork is being imported. I will wreck it!')
+                 meta.module_src = meta.module_src.replace("True", "False")
+             return meta
+     """
+     assert callable(hook)
+     _import_hooks.append(hook)
+
+ #----------------------------------------------------------------------------
+
+ def _reconstruct_persistent_obj(meta):
+     r"""Hook that is called internally by the `pickle` module to unpickle
+     a persistent object.
+     """
+     meta = dnnlib.EasyDict(meta)
+     meta.state = dnnlib.EasyDict(meta.state)
+     for hook in _import_hooks:
+         meta = hook(meta)
+         assert meta is not None
+
+     assert meta.version == _version
+     module = _src_to_module(meta.module_src)
+
+     assert meta.type == 'class'
+     orig_class = module.__dict__[meta.class_name]
+     decorator_class = persistent_class(orig_class)
+     obj = decorator_class.__new__(decorator_class)
+
+     setstate = getattr(obj, '__setstate__', None)
+     if callable(setstate):
+         setstate(meta.state) # pylint: disable=not-callable
+     else:
+         obj.__dict__.update(meta.state)
+     return obj
+
+ #----------------------------------------------------------------------------
+
+ def _module_to_src(module):
+     r"""Query the source code of a given Python module.
+     """
+     src = _module_to_src_dict.get(module, None)
+     if src is None:
+         src = inspect.getsource(module)
+         _module_to_src_dict[module] = src
+         _src_to_module_dict[src] = module
+     return src
+
+ def _src_to_module(src):
+     r"""Get or create a Python module for the given source code.
+     """
+     module = _src_to_module_dict.get(src, None)
+     if module is None:
+         module_name = "_imported_module_" + uuid.uuid4().hex
+         module = types.ModuleType(module_name)
+         sys.modules[module_name] = module
+         _module_to_src_dict[module] = src
+         _src_to_module_dict[src] = module
+         exec(src, module.__dict__) # pylint: disable=exec-used
+     return module
+
+ #----------------------------------------------------------------------------
+
+ def _check_pickleable(obj):
+     r"""Check that the given object is pickleable, raising an exception if
+     it is not. This function is expected to be considerably more efficient
+     than actually pickling the object.
+     """
+     def recurse(obj):
+         if isinstance(obj, (list, tuple, set)):
+             return [recurse(x) for x in obj]
+         if isinstance(obj, dict):
+             return [[recurse(x), recurse(y)] for x, y in obj.items()]
+         if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
+             return None # Python primitive types are pickleable.
+         if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
+             return None # NumPy arrays and PyTorch tensors are pickleable.
+         if is_persistent(obj):
+             return None # Persistent objects are pickleable, by virtue of the constructor check.
+         return obj
+     with io.BytesIO() as f:
+         pickle.dump(recurse(obj), f)
+
+ #----------------------------------------------------------------------------
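The revival mechanism at the heart of `_src_to_module()` — compiling a stored source string into a freshly created module object — can be demonstrated with nothing but the standard library (`module_from_source` and the sample source string below are illustrative, not part of the codebase):

```python
import sys
import types
import uuid

def module_from_source(src):
    """Create a fresh module from a source string, the same mechanism
    _src_to_module() uses to revive pickled code during unpickling."""
    name = '_example_module_' + uuid.uuid4().hex
    module = types.ModuleType(name)
    sys.modules[name] = module  # register so the module is importable by name
    exec(src, module.__dict__)  # populate the module from the source string
    return module

mod = module_from_source("def double(x):\n    return 2 * x\n")
print(mod.double(21))  # → 42
```

Registering the module in `sys.modules` under a unique name is what lets objects defined in the revived source be pickled again later, since `pickle` resolves classes by module name.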
REG/preprocessing/torch_utils/training_stats.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Facilities for reporting and collecting training statistics across
+ multiple processes and devices. The interface is designed to minimize
+ synchronization overhead as well as the amount of boilerplate in user
+ code."""
+
+ import re
+ import numpy as np
+ import torch
+ import dnnlib
+
+ from . import misc
+
+ #----------------------------------------------------------------------------
+
+ _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
+ _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
+ _counter_dtype = torch.float64 # Data type to use for the internal counters.
+ _rank = 0 # Rank of the current process.
+ _sync_device = None # Device to use for multiprocess communication. None = single-process.
+ _sync_called = False # Has _sync() been called yet?
+ _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
+ _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
+
+ #----------------------------------------------------------------------------
+
+ def init_multiprocessing(rank, sync_device):
+     r"""Initializes `torch_utils.training_stats` for collecting statistics
+     across multiple processes.
+
+     This function must be called after
+     `torch.distributed.init_process_group()` and before `Collector.update()`.
+     The call is not necessary if multi-process collection is not needed.
+
+     Args:
+         rank:        Rank of the current process.
+         sync_device: PyTorch device to use for inter-process
+                      communication, or None to disable multi-process
+                      collection. Typically `torch.device('cuda', rank)`.
+     """
+     global _rank, _sync_device
+     assert not _sync_called
+     _rank = rank
+     _sync_device = sync_device
+
+ #----------------------------------------------------------------------------
+
+ @misc.profiled_function
+ def report(name, value):
+     r"""Broadcasts the given set of scalars to all interested instances of
+     `Collector`, across device and process boundaries. NaNs and Infs are
+     ignored.
+
+     This function is expected to be extremely cheap and can be safely
+     called from anywhere in the training loop, loss function, or inside a
+     `torch.nn.Module`.
+
+     Warning: The current implementation expects the set of unique names to
+     be consistent across processes. Please make sure that `report()` is
+     called at least once for each unique name by each process, and in the
+     same order. If a given process has no scalars to broadcast, it can do
+     `report(name, [])` (empty list).
+
+     Args:
+         name:  Arbitrary string specifying the name of the statistic.
+                Averages are accumulated separately for each unique name.
+         value: Arbitrary set of scalars. Can be a list, tuple,
+                NumPy array, PyTorch tensor, or Python scalar.
+
+     Returns:
+         The same `value` that was passed in.
+     """
+     if name not in _counters:
+         _counters[name] = dict()
+
+     elems = torch.as_tensor(value)
+     if elems.numel() == 0:
+         return value
+
+     elems = elems.detach().flatten().to(_reduce_dtype)
+     square = elems.square()
+     finite = square.isfinite()
+     moments = torch.stack([
+         finite.sum(dtype=_reduce_dtype),
+         torch.where(finite, elems, 0).sum(),
+         torch.where(finite, square, 0).sum(),
+     ])
+     assert moments.ndim == 1 and moments.shape[0] == _num_moments
+     moments = moments.to(_counter_dtype)
+
+     device = moments.device
+     if device not in _counters[name]:
+         _counters[name][device] = torch.zeros_like(moments)
+     _counters[name][device].add_(moments)
+     return value
+
+ #----------------------------------------------------------------------------
+
+ def report0(name, value):
+     r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
+     but ignores any scalars provided by the other processes.
+     See `report()` for further details.
+     """
+     report(name, value if _rank == 0 else [])
+     return value
+
+ #----------------------------------------------------------------------------
+
+ class Collector:
+     r"""Collects the scalars broadcasted by `report()` and `report0()` and
+     computes their long-term averages (mean and standard deviation) over
+     user-defined periods of time.
+
+     The averages are first collected into internal counters that are not
+     directly visible to the user. They are then copied to the user-visible
+     state as a result of calling `update()` and can then be queried using
+     `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
+     internal counters for the next round, so that the user-visible state
+     effectively reflects averages collected between the last two calls to
+     `update()`.
+
+     Args:
+         regex:         Regular expression defining which statistics to
+                        collect. The default is to collect everything.
+         keep_previous: Whether to retain the previous averages if no
+                        scalars were collected on a given round
+                        (default: False).
+     """
+     def __init__(self, regex='.*', keep_previous=False):
+         self._regex = re.compile(regex)
+         self._keep_previous = keep_previous
+         self._cumulative = dict()
+         self._moments = dict()
+         self.update()
+         self._moments.clear()
+
+     def names(self):
+         r"""Returns the names of all statistics broadcasted so far that
+         match the regular expression specified at construction time.
+         """
+         return [name for name in _counters if self._regex.fullmatch(name)]
+
+     def update(self):
+         r"""Copies current values of the internal counters to the
+         user-visible state and resets them for the next round.
+
+         If `keep_previous=True` was specified at construction time, the
+         operation is skipped for statistics that have received no scalars
+         since the last update, retaining their previous averages.
+
+         This method performs a number of GPU-to-CPU transfers and one
+         `torch.distributed.all_reduce()`. It is intended to be called
+         periodically in the main training loop, typically once every
+         N training steps.
+         """
+         if not self._keep_previous:
+             self._moments.clear()
+         for name, cumulative in _sync(self.names()):
+             if name not in self._cumulative:
+                 self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+             delta = cumulative - self._cumulative[name]
+             self._cumulative[name].copy_(cumulative)
+             if float(delta[0]) != 0:
+                 self._moments[name] = delta
+
+     def _get_delta(self, name):
+         r"""Returns the raw moments that were accumulated for the given
+         statistic between the last two calls to `update()`, or zero if
+         no scalars were collected.
+         """
+         assert self._regex.fullmatch(name)
+         if name not in self._moments:
+             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+         return self._moments[name]
+
+     def num(self, name):
+         r"""Returns the number of scalars that were accumulated for the given
+         statistic between the last two calls to `update()`, or zero if
+         no scalars were collected.
+         """
+         delta = self._get_delta(name)
+         return int(delta[0])
+
+     def mean(self, name):
+         r"""Returns the mean of the scalars that were accumulated for the
+         given statistic between the last two calls to `update()`, or NaN if
+         no scalars were collected.
+         """
+         delta = self._get_delta(name)
+         if int(delta[0]) == 0:
+             return float('nan')
+         return float(delta[1] / delta[0])
+
+     def std(self, name):
+         r"""Returns the standard deviation of the scalars that were
+         accumulated for the given statistic between the last two calls to
+         `update()`, or NaN if no scalars were collected.
+         """
+         delta = self._get_delta(name)
+         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
+             return float('nan')
+         if int(delta[0]) == 1:
+             return float(0)
+         mean = float(delta[1] / delta[0])
+         raw_var = float(delta[2] / delta[0])
+         return np.sqrt(max(raw_var - np.square(mean), 0))
+
+     def as_dict(self):
+         r"""Returns the averages accumulated between the last two calls to
+         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
+
+             dnnlib.EasyDict(
+                 NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
+                 ...
+             )
+         """
+         stats = dnnlib.EasyDict()
+         for name in self.names():
+             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
+         return stats
+
+     def __getitem__(self, name):
+         r"""Convenience getter.
+         `collector[name]` is a synonym for `collector.mean(name)`.
+         """
+         return self.mean(name)
+
+ #----------------------------------------------------------------------------
+
+ def _sync(names):
+     r"""Synchronize the global cumulative counters across devices and
+     processes. Called internally by `Collector.update()`.
+     """
+     if len(names) == 0:
+         return []
+     global _sync_called
+     _sync_called = True
+
+     # Check that all ranks have the same set of names.
+     if _sync_device is not None:
+         value = hash(tuple(tuple(ord(char) for char in name) for name in names))
+         other = torch.as_tensor(value, dtype=torch.int64, device=_sync_device)
+         torch.distributed.broadcast(tensor=other, src=0)
+         if value != int(other.cpu()):
+             raise ValueError('Training statistics are inconsistent between ranks')
+
+     # Collect deltas within current rank.
+     deltas = []
+     device = _sync_device if _sync_device is not None else torch.device('cpu')
+     for name in names:
+         delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
+         for counter in _counters[name].values():
+             delta.add_(counter.to(device))
+             counter.copy_(torch.zeros_like(counter))
+         deltas.append(delta)
+     deltas = torch.stack(deltas)
+
+     # Sum deltas across ranks.
+     if _sync_device is not None:
+         torch.distributed.all_reduce(deltas)
+
+     # Update cumulative values.
+     deltas = deltas.cpu()
+     for idx, name in enumerate(names):
+         if name not in _cumulative:
+             _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+         _cumulative[name].add_(deltas[idx])
+
+     # Return name-value pairs.
+     return [(name, _cumulative[name]) for name in names]
+
+ #----------------------------------------------------------------------------
+ # Convenience.
+
+ default_collector = Collector()
+
+ #----------------------------------------------------------------------------
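
The module above represents each statistic as three moments (count, sum, sum of squares); `mean` and `std` are then recovered from those moments alone, which is why per-scalar synchronization is never needed. A minimal standalone sketch of the same accumulation scheme, without the torch/dnnlib dependencies (`MomentTracker` is an illustrative name, not part of the module):

```python
import math

class MomentTracker:
    """Illustrative sketch: accumulates [count, sum, sum of squares] per
    statistic name, mirroring the three-moment scheme in training_stats."""

    def __init__(self):
        self._moments = {}  # name -> [num, sum, sum_of_squares]

    def report(self, name, values):
        m = self._moments.setdefault(name, [0.0, 0.0, 0.0])
        for v in values:
            if math.isfinite(v):  # NaNs and Infs are ignored, as in report()
                m[0] += 1.0
                m[1] += v
                m[2] += v * v

    def mean(self, name):
        num, s, _ = self._moments.get(name, (0.0, 0.0, 0.0))
        return s / num if num else float('nan')

    def std(self, name):
        num, s, sq = self._moments.get(name, (0.0, 0.0, 0.0))
        if num == 0:
            return float('nan')
        mean = s / num
        raw_var = sq / num
        return math.sqrt(max(raw_var - mean * mean, 0.0))  # E[x^2] - E[x]^2

tracker = MomentTracker()
tracker.report('loss', [1.0, 2.0, 3.0, float('nan')])  # the NaN is dropped
print(tracker.mean('loss'))  # 2.0
print(tracker.std('loss'))
```

Because the three moments sum linearly, per-rank counters can simply be added together (the role of `all_reduce` in `_sync()`) before the mean and variance are computed.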
REG/wandb/debug-internal.log ADDED
@@ -0,0 +1,21 @@
+ {"time":"2026-04-08T18:26:46.552297532+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
+ {"time":"2026-04-08T18:26:47.20146143+08:00","level":"INFO","msg":"stream: created new stream","id":"xtwg5t5s"}
+ {"time":"2026-04-08T18:26:47.201551011+08:00","level":"INFO","msg":"handler: started","stream_id":"xtwg5t5s"}
+ {"time":"2026-04-08T18:26:47.202423643+08:00","level":"INFO","msg":"stream: started","id":"xtwg5t5s"}
+ {"time":"2026-04-08T18:26:47.202450453+08:00","level":"INFO","msg":"writer: started","stream_id":"xtwg5t5s"}
+ {"time":"2026-04-08T18:26:47.202479681+08:00","level":"INFO","msg":"sender: started","stream_id":"xtwg5t5s"}
+ {"time":"2026-04-09T00:59:33.394616937+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
+ {"time":"2026-04-09T15:26:36.673675921+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": read tcp 172.20.98.30:37630->35.186.228.49:443: read: connection reset by peer"}
+ {"time":"2026-04-09T15:32:51.675782111+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": read tcp 172.20.98.30:55710->35.186.228.49:443: read: connection reset by peer"}
+ {"time":"2026-04-09T15:33:36.688517829+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": EOF"}
+ {"time":"2026-04-10T00:33:41.365462236+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
+ {"time":"2026-04-10T06:11:35.438909216+08:00","level":"INFO","msg":"api: retrying HTTP error","status":429,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"{\"error\":\"rate limit exceeded: per_run limit on filestream requests\"}"}
+ {"time":"2026-04-11T02:04:06.260667043+08:00","level":"INFO","msg":"api: retrying HTTP error","status":500,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"{\"error\":\"context deadline exceeded\"}"}
+ {"time":"2026-04-11T10:00:44.531212038+08:00","level":"INFO","msg":"api: retrying HTTP error","status":500,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"{\"error\":\"context deadline exceeded\"}"}
+ {"time":"2026-04-11T10:20:26.360393211+08:00","level":"INFO","msg":"api: retrying HTTP error","status":500,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream","body":"{\"error\":\"context deadline exceeded\"}"}
+ {"time":"2026-04-12T21:59:44.458847327+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
+ {"time":"2026-04-13T00:04:28.494081102+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/xtwg5t5s/file_stream\": read tcp 172.20.98.30:35484->35.186.228.49:443: read: connection reset by peer"}
+ {"time":"2026-04-13T02:39:57.535775934+08:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
+ {"time":"2026-04-13T02:39:58.493368195+08:00","level":"INFO","msg":"handler: closed","stream_id":"xtwg5t5s"}
+ {"time":"2026-04-13T02:39:58.494772782+08:00","level":"INFO","msg":"sender: closed","stream_id":"xtwg5t5s"}
+ {"time":"2026-04-13T02:39:58.49521181+08:00","level":"INFO","msg":"stream: closed","id":"xtwg5t5s"}
REG/wandb/debug.log ADDED
@@ -0,0 +1,22 @@
+ 2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
+ 2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_setup.py:_flush():81] Configure stats pid to 128263
+ 2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_setup.py:_flush():81] Loading settings from environment variables
+ 2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260408_182646-xtwg5t5s/logs/debug.log
+ 2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260408_182646-xtwg5t5s/logs/debug-internal.log
+ 2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:init():844] calling init triggers
+ 2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
+ config: {'_wandb': {}}
+ 2026-04-08 18:26:46,224 INFO MainThread:128263 [wandb_init.py:init():892] starting backend
+ 2026-04-08 18:26:46,532 INFO MainThread:128263 [wandb_init.py:init():895] sending inform_init request
+ 2026-04-08 18:26:46,548 INFO MainThread:128263 [wandb_init.py:init():903] backend started and connected
+ 2026-04-08 18:26:46,551 INFO MainThread:128263 [wandb_init.py:init():973] updated telemetry
+ 2026-04-08 18:26:46,572 INFO MainThread:128263 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
+ 2026-04-08 18:26:47,862 INFO MainThread:128263 [wandb_init.py:init():1042] starting run threads in backend
+ 2026-04-08 18:26:47,956 INFO MainThread:128263 [wandb_run.py:_console_start():2524] atexit reg
+ 2026-04-08 18:26:47,956 INFO MainThread:128263 [wandb_run.py:_redirect():2373] redirect: wrap_raw
+ 2026-04-08 18:26:47,956 INFO MainThread:128263 [wandb_run.py:_redirect():2442] Wrapping output streams.
+ 2026-04-08 18:26:47,956 INFO MainThread:128263 [wandb_run.py:_redirect():2465] Redirects installed.
+ 2026-04-08 18:26:48,108 INFO MainThread:128263 [wandb_init.py:init():1082] run started, returning control to user process
+ 2026-04-08 18:26:48,108 INFO MainThread:128263 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment-0.75-0.01-one-step', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'resume_from_ckpt': '/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75-0.01-one-step/checkpoints/1920000.pt', 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 14000, 'max_train_steps': 10000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.005, 't_c': 0.75, 'ot_cls': True, 'tc_velocity_loss_coeff': 2.0}
+ 2026-04-13 02:35:32,832 INFO wandb-AsyncioManager-main:128263 [service_client.py:_forward_responses():134] Reached EOF.
+ 2026-04-13 02:35:32,833 INFO wandb-AsyncioManager-main:128263 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
REG/wandb/run-20260322_141726-2yw08kz9/files/config.yaml ADDED
@@ -0,0 +1,203 @@
+ _wandb:
+   value:
+     cli_version: 0.25.0
+     e:
+       257k9ot60u1bv0aiwlacsvutj9c72h7y:
+         args:
+         - --report-to
+         - wandb
+         - --allow-tf32
+         - --mixed-precision
+         - bf16
+         - --seed
+         - "0"
+         - --path-type
+         - linear
+         - --prediction
+         - v
+         - --weighting
+         - uniform
+         - --model
+         - SiT-XL/2
+         - --enc-type
+         - dinov2-vit-b
+         - --encoder-depth
+         - "8"
+         - --proj-coeff
+         - "0.5"
+         - --output-dir
+         - exps
+         - --exp-name
+         - jsflow-experiment
+         - --batch-size
+         - "256"
+         - --data-dir
+         - /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
+         - --semantic-features-dir
+         - /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
+         - --learning-rate
+         - "0.00005"
+         - --t-c
+         - "0.5"
+         - --cls
+         - "0.2"
+         - --ot-cls
+         codePath: train.py
+         codePathLocal: train.py
+         cpu_count: 96
+         cpu_count_logical: 192
+         cudaVersion: "13.0"
+         disk:
+           /:
+             total: "3838880616448"
+             used: "357556633600"
+         email: 2365972933@qq.com
+         executable: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
+         git:
+           commit: 021ea2e50c38c5803bd9afff16316958a01fbd1d
+           remote: https://github.com/Martinser/REG.git
+         gpu: NVIDIA H100 80GB HBM3
+         gpu_count: 4
+         gpu_nvidia:
+         - architecture: Hopper
+           cudaCores: 16896
+           memoryTotal: "85520809984"
+           name: NVIDIA H100 80GB HBM3
+           uuid: GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc
+         - architecture: Hopper
+           cudaCores: 16896
+           memoryTotal: "85520809984"
+           name: NVIDIA H100 80GB HBM3
+           uuid: GPU-a09f2421-99e6-a72e-63bd-fd7452510758
+         - architecture: Hopper
+           cudaCores: 16896
+           memoryTotal: "85520809984"
+           name: NVIDIA H100 80GB HBM3
+           uuid: GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d
+         - architecture: Hopper
+           cudaCores: 16896
+           memoryTotal: "85520809984"
+           name: NVIDIA H100 80GB HBM3
+           uuid: GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e
+         host: 24c964746905d416ce09d045f9a06f23-taskrole1-0
+         memory:
+           total: "2164115296256"
+         os: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
+         program: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py
+         python: CPython 3.12.9
+         root: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG
+         startedAt: "2026-03-22T06:17:26.670763Z"
+         writerId: 257k9ot60u1bv0aiwlacsvutj9c72h7y
+     m: []
+     python_version: 3.12.9
+     t:
+       "1":
+       - 1
+       - 5
+       - 11
+       - 41
+       - 49
+       - 53
+       - 63
+       - 71
+       - 83
+       - 98
+       "2":
+       - 1
+       - 5
+       - 11
+       - 41
+       - 49
+       - 53
+       - 63
+       - 71
+       - 83
+       - 98
+       "3":
+       - 13
+       - 61
+       "4": 3.12.9
+       "5": 0.25.0
+       "6": 4.53.2
+       "12": 0.25.0
+       "13": linux-x86_64
+ adam_beta1:
+   value: 0.9
+ adam_beta2:
+   value: 0.999
+ adam_epsilon:
+   value: 1e-08
+ adam_weight_decay:
+   value: 0
+ allow_tf32:
+   value: true
+ batch_size:
+   value: 256
+ cfg_prob:
+   value: 0.1
+ checkpointing_steps:
+   value: 10000
+ cls:
+   value: 0.2
+ data_dir:
+   value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
+ enc_type:
+   value: dinov2-vit-b
+ encoder_depth:
+   value: 8
+ epochs:
+   value: 1400
+ exp_name:
+   value: jsflow-experiment
+ fused_attn:
+   value: true
+ gradient_accumulation_steps:
+   value: 1
+ learning_rate:
+   value: 5e-05
+ legacy:
+   value: false
+ logging_dir:
+   value: logs
+ max_grad_norm:
+   value: 1
+ max_train_steps:
+   value: 1000000
+ mixed_precision:
+   value: bf16
+ model:
+   value: SiT-XL/2
+ num_classes:
+   value: 1000
+ num_workers:
+   value: 4
+ ops_head:
+   value: 16
+ ot_cls:
+   value: true
+ output_dir:
+   value: exps
+ path_type:
+   value: linear
+ prediction:
+   value: v
+ proj_coeff:
+   value: 0.5
+ qk_norm:
+   value: false
+ report_to:
+   value: wandb
+ resolution:
+   value: 256
+ resume_step:
+   value: 0
+ sampling_steps:
+   value: 10000
+ seed:
+   value: 0
+ semantic_features_dir:
+   value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
+ t_c:
+   value: 0.5
+ weighting:
+   value: uniform
REG/wandb/run-20260322_141726-2yw08kz9/files/output.log ADDED
@@ -0,0 +1,27 @@
+ Steps: 0%| | 1/1000000 [00:02<614:34:39, 2.21s/it][2026-03-22 14:17:31] Generating EMA samples done.
+ [2026-03-22 14:17:31] Step: 1, Training Logs: loss_final: 3.278940, loss_mean: 1.706308, proj_loss: 0.001541, loss_mean_cls: 1.571091, grad_norm: 1.481672
+ Steps: 0%| | 2/1000000 [00:02<289:06:04, 1.04s/it, grad_norm=1.48, loss_final=3.28, loss_mean=1.71, loss_mean_cls=1.57, proj_loss=0.001[2026-03-22 14:17:31] Step: 2, Training Logs: loss_final: 3.211831, loss_mean: 1.688932, proj_loss: -0.010287, loss_mean_cls: 1.533185, grad_norm: 1.055476
+ Steps: 0%| | 3/1000000 [00:02<187:48:39, 1.48it/s, grad_norm=1.06, loss_final=3.21, loss_mean=1.69, loss_mean_cls=1.53, proj_loss=-0.01[2026-03-22 14:17:31] Step: 3, Training Logs: loss_final: 3.201248, loss_mean: 1.663205, proj_loss: -0.019184, loss_mean_cls: 1.557227, grad_norm: 1.116387
+ Steps: 0%| | 4/1000000 [00:02<140:12:43, 1.98it/s, grad_norm=1.12, loss_final=3.2, loss_mean=1.66, loss_mean_cls=1.56, proj_loss=-0.019[2026-03-22 14:17:32] Step: 4, Training Logs: loss_final: 3.198367, loss_mean: 1.682051, proj_loss: -0.026376, loss_mean_cls: 1.542691, grad_norm: 0.722294
+ Steps: 0%| | 5/1000000 [00:03<113:52:43, 2.44it/s, grad_norm=0.722, loss_final=3.2, loss_mean=1.68, loss_mean_cls=1.54, proj_loss=-0.02[2026-03-22 14:17:32] Step: 5, Training Logs: loss_final: 3.140483, loss_mean: 1.679105, proj_loss: -0.034564, loss_mean_cls: 1.495943, grad_norm: 0.811589
+ Steps: 0%| | 6/1000000 [00:03<97:59:40, 2.83it/s, grad_norm=0.812, loss_final=3.14, loss_mean=1.68, loss_mean_cls=1.5, proj_loss=-0.034[2026-03-22 14:17:32] Step: 6, Training Logs: loss_final: 2.988440, loss_mean: 1.682339, proj_loss: -0.039506, loss_mean_cls: 1.345606, grad_norm: 0.931524
+ Steps: 0%| | 7/1000000 [00:03<87:55:00, 3.16it/s, grad_norm=0.932, loss_final=2.99, loss_mean=1.68, loss_mean_cls=1.35, proj_loss=-0.03[2026-03-22 14:17:32] Step: 7, Training Logs: loss_final: 3.111949, loss_mean: 1.690802, proj_loss: -0.042757, loss_mean_cls: 1.463904, grad_norm: 0.830852
+ Steps: 0%| | 8/1000000 [00:03<81:19:20, 3.42it/s, grad_norm=0.831, loss_final=3.11, loss_mean=1.69, loss_mean_cls=1.46, proj_loss=-0.04[2026-03-22 14:17:33] Step: 8, Training Logs: loss_final: 3.278931, loss_mean: 1.660797, proj_loss: -0.045011, loss_mean_cls: 1.663145, grad_norm: 0.847438
+ Steps: 0%| | 9/1000000 [00:04<76:56:10, 3.61it/s, grad_norm=0.847, loss_final=3.28, loss_mean=1.66, loss_mean_cls=1.66, proj_loss=-0.04[2026-03-22 14:17:33] Step: 9, Training Logs: loss_final: 3.221569, loss_mean: 1.658834, proj_loss: -0.046031, loss_mean_cls: 1.608767, grad_norm: 0.909827
+ Steps: 0%| | 10/1000000 [00:04<73:57:18, 3.76it/s, grad_norm=0.91, loss_final=3.22, loss_mean=1.66, loss_mean_cls=1.61, proj_loss=-0.04[2026-03-22 14:17:33] Step: 10, Training Logs: loss_final: 3.216744, loss_mean: 1.665229, proj_loss: -0.047761, loss_mean_cls: 1.599277, grad_norm: 1.014574
+ Steps: 0%| | 11/1000000 [00:04<71:52:01, 3.87it/s, grad_norm=1.01, loss_final=3.22, loss_mean=1.67, loss_mean_cls=1.6, proj_loss=-0.047[2026-03-22 14:17:33] Step: 11, Training Logs: loss_final: 3.216658, loss_mean: 1.649915, proj_loss: -0.049347, loss_mean_cls: 1.616090, grad_norm: 1.028789
+ Steps: 0%| | 12/1000000 [00:04<70:26:20, 3.94it/s, grad_norm=1.03, loss_final=3.22, loss_mean=1.65, loss_mean_cls=1.62, proj_loss=-0.04[2026-03-22 14:17:34] Step: 12, Training Logs: loss_final: 3.155676, loss_mean: 1.624463, proj_loss: -0.049856, loss_mean_cls: 1.581069, grad_norm: 1.231291
+ Steps: 0%| | 13/1000000 [00:05<69:25:29, 4.00it/s, grad_norm=1.23, loss_final=3.16, loss_mean=1.62, loss_mean_cls=1.58, proj_loss=-0.04Traceback (most recent call last):
+   File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 527, in <module>
+     main(args)
+   File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 415, in main
+     "loss_final": accelerator.gather(loss).mean().detach().item(),
+      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ KeyboardInterrupt
+ [rank0]: Traceback (most recent call last):
+ [rank0]:   File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 527, in <module>
+ [rank0]:     main(args)
+ [rank0]:   File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 415, in main
+ [rank0]:     "loss_final": accelerator.gather(loss).mean().detach().item(),
+ [rank0]:      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: KeyboardInterrupt
REG/wandb/run-20260322_141726-2yw08kz9/files/requirements.txt ADDED
@@ -0,0 +1,168 @@
+ dill==0.3.8
+ mkl-service==2.4.0
+ mpmath==1.3.0
+ typing_extensions==4.12.2
+ urllib3==2.3.0
+ torch==2.5.1
+ ptyprocess==0.7.0
+ traitlets==5.14.3
+ pyasn1==0.6.1
+ opencv-python-headless==4.12.0.88
+ nest-asyncio==1.6.0
+ kiwisolver==1.4.8
+ click==8.2.1
+ fire==0.7.1
+ diffusers==0.35.1
+ accelerate==1.7.0
+ ipykernel==6.29.5
+ peft==0.17.1
+ attrs==24.3.0
+ six==1.17.0
+ numpy==2.0.1
+ yarl==1.18.0
+ huggingface_hub==0.34.4
+ Bottleneck==1.4.2
+ numexpr==2.11.0
+ dataclasses==0.6
+ typing-inspection==0.4.1
+ safetensors==0.5.3
+ pyparsing==3.2.3
+ psutil==7.0.0
+ imageio==2.37.0
+ debugpy==1.8.14
+ cycler==0.12.1
+ pyasn1_modules==0.4.2
+ matplotlib-inline==0.1.7
+ matplotlib==3.10.3
+ jedi==0.19.2
+ tokenizers==0.21.2
+ seaborn==0.13.2
+ timm==1.0.15
+ aiohappyeyeballs==2.6.1
+ hf-xet==1.1.8
+ multidict==6.1.0
+ tqdm==4.67.1
+ wheel==0.45.1
+ simsimd==6.5.1
+ sentencepiece==0.2.1
+ grpcio==1.74.0
+ asttokens==3.0.0
+ absl-py==2.3.1
+ stack-data==0.6.3
+ pandas==2.3.0
+ importlib_metadata==8.7.0
+ pytorch-image-generation-metrics==0.6.1
+ frozenlist==1.5.0
+ MarkupSafe==3.0.2
+ setuptools==78.1.1
+ multiprocess==0.70.15
+ pip==25.1
+ requests==2.32.3
+ mkl_random==1.2.8
+ tensorboard-plugin-wit==1.8.1
+ ExifRead-nocycle==3.0.1
+ webdataset==0.2.111
+ threadpoolctl==3.6.0
+ pyarrow==21.0.0
+ executing==2.2.0
+ decorator==5.2.1
+ contourpy==1.3.2
+ annotated-types==0.7.0
+ scikit-learn==1.7.1
+ jupyter_client==8.6.3
+ albumentations==1.4.24
+ wandb==0.25.0
+ certifi==2025.8.3
+ idna==3.7
+ xxhash==3.5.0
+ Jinja2==3.1.6
+ python-dateutil==2.9.0.post0
+ aiosignal==1.4.0
+ triton==3.1.0
+ torchvision==0.20.1
+ stringzilla==3.12.6
+ pure_eval==0.2.3
+ braceexpand==0.1.7
+ zipp==3.22.0
+ oauthlib==3.3.1
+ Markdown==3.8.2
+ fsspec==2025.3.0
+ fonttools==4.58.2
+ comm==0.2.2
+ ipython==9.3.0
+ img2dataset==1.47.0
+ networkx==3.4.2
+ PySocks==1.7.1
+ tzdata==2025.2
+ smmap==5.0.2
+ mkl_fft==1.3.11
+ sentry-sdk==2.29.1
+ Pygments==2.19.1
+ pexpect==4.9.0
+ ftfy==6.3.1
+ einops==0.8.1
+ requests-oauthlib==2.0.0
+ gitdb==4.0.12
+ albucore==0.0.23
+ torchdiffeq==0.2.5
+ GitPython==3.1.44
+ bitsandbytes==0.47.0
+ pytorch-fid==0.3.0
+ clean-fid==0.1.35
+ pytorch-gan-metrics==0.5.4
+ Brotli==1.0.9
+ charset-normalizer==3.3.2
+ gmpy2==2.2.1
+ pillow==11.1.0
+ PyYAML==6.0.2
+ tornado==6.5.1
+ termcolor==3.1.0
+ setproctitle==1.3.6
+ scipy==1.15.3
+ regex==2024.11.6
+ protobuf==6.31.1
+ platformdirs==4.3.8
+ joblib==1.5.1
+ cachetools==4.2.4
+ ipython_pygments_lexers==1.1.1
+ google-auth==1.35.0
+ transformers==4.53.2
+ torch-fidelity==0.3.0
+ tensorboard==2.4.0
+ filelock==3.17.0
+ packaging==25.0
+ propcache==0.3.1
+ pytz==2025.2
+ aiohttp==3.11.10
+ wcwidth==0.2.13
+ clip==0.2.0
+ Werkzeug==3.1.3
+ tensorboard-data-server==0.6.1
+ sympy==1.13.1
+ pyzmq==26.4.0
+ pydantic_core==2.33.2
+ prompt_toolkit==3.0.51
+ parso==0.8.4
+ docker-pycreds==0.4.0
+ rsa==4.9.1
+ pydantic==2.11.5
+ jupyter_core==5.8.1
+ google-auth-oauthlib==0.4.6
+ datasets==4.0.0
+ torch-tb-profiler==0.4.3
+ autocommand==2.2.2
+ backports.tarfile==1.2.0
+ importlib_metadata==8.0.0
+ jaraco.collections==5.1.0
+ jaraco.context==5.3.0
+ jaraco.functools==4.0.1
+ more-itertools==10.3.0
+ packaging==24.2
+ platformdirs==4.2.2
+ typeguard==4.3.0
+ inflect==7.3.1
+ jaraco.text==3.12.1
+ tomli==2.0.1
+ typing_extensions==4.12.2
+ wheel==0.45.1
+ zipp==3.19.2
REG/wandb/run-20260322_141726-2yw08kz9/files/wandb-metadata.json ADDED
@@ -0,0 +1,101 @@
+ {
+ "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
+ "python": "CPython 3.12.9",
+ "startedAt": "2026-03-22T06:17:26.670763Z",
+ "args": [
+ "--report-to",
+ "wandb",
+ "--allow-tf32",
+ "--mixed-precision",
+ "bf16",
+ "--seed",
+ "0",
+ "--path-type",
+ "linear",
+ "--prediction",
+ "v",
+ "--weighting",
+ "uniform",
+ "--model",
+ "SiT-XL/2",
+ "--enc-type",
+ "dinov2-vit-b",
+ "--encoder-depth",
+ "8",
+ "--proj-coeff",
+ "0.5",
+ "--output-dir",
+ "exps",
+ "--exp-name",
+ "jsflow-experiment",
+ "--batch-size",
+ "256",
+ "--data-dir",
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
+ "--semantic-features-dir",
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
+ "--learning-rate",
+ "0.00005",
+ "--t-c",
+ "0.5",
+ "--cls",
+ "0.2",
+ "--ot-cls"
+ ],
+ "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
+ "codePath": "train.py",
+ "codePathLocal": "train.py",
+ "git": {
+ "remote": "https://github.com/Martinser/REG.git",
+ "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
+ },
+ "email": "2365972933@qq.com",
+ "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
+ "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
+ "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
+ "cpu_count": 96,
+ "cpu_count_logical": 192,
+ "gpu": "NVIDIA H100 80GB HBM3",
+ "gpu_count": 4,
+ "disk": {
+ "/": {
+ "total": "3838880616448",
+ "used": "357556633600"
+ }
+ },
+ "memory": {
+ "total": "2164115296256"
+ },
+ "gpu_nvidia": [
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
+ }
+ ],
+ "cudaVersion": "13.0",
+ "writerId": "257k9ot60u1bv0aiwlacsvutj9c72h7y"
+ }
REG/wandb/run-20260322_141726-2yw08kz9/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
+ {"loss_mean_cls":1.5810688734054565,"_timestamp":1.7741602540511734e+09,"_runtime":5.247627056,"loss_mean":1.6244629621505737,"proj_loss":-0.04985573887825012,"grad_norm":1.2312908172607422,"_wandb":{"runtime":5},"_step":12,"loss_final":3.1556761264801025}
REG/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log ADDED
@@ -0,0 +1,7 @@
+ {"time":"2026-03-22T14:17:27.013311984+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
+ {"time":"2026-03-22T14:17:28.347732261+08:00","level":"INFO","msg":"stream: created new stream","id":"2yw08kz9"}
+ {"time":"2026-03-22T14:17:28.347960938+08:00","level":"INFO","msg":"handler: started","stream_id":"2yw08kz9"}
+ {"time":"2026-03-22T14:17:28.348671928+08:00","level":"INFO","msg":"stream: started","id":"2yw08kz9"}
+ {"time":"2026-03-22T14:17:28.348731034+08:00","level":"INFO","msg":"sender: started","stream_id":"2yw08kz9"}
+ {"time":"2026-03-22T14:17:28.348748525+08:00","level":"INFO","msg":"writer: started","stream_id":"2yw08kz9"}
+ {"time":"2026-03-22T14:17:34.316421629+08:00","level":"INFO","msg":"stream: closing","id":"2yw08kz9"}
REG/wandb/run-20260322_141726-2yw08kz9/logs/debug.log ADDED
@@ -0,0 +1,22 @@
+ 2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
+ 2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Configure stats pid to 316313
+ 2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_setup.py:_flush():81] Loading settings from environment variables
+ 2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141726-2yw08kz9/logs/debug.log
+ 2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141726-2yw08kz9/logs/debug-internal.log
+ 2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():844] calling init triggers
+ 2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
+ config: {'_wandb': {}}
+ 2026-03-22 14:17:26,691 INFO MainThread:316313 [wandb_init.py:init():892] starting backend
+ 2026-03-22 14:17:26,994 INFO MainThread:316313 [wandb_init.py:init():895] sending inform_init request
+ 2026-03-22 14:17:27,008 INFO MainThread:316313 [wandb_init.py:init():903] backend started and connected
+ 2026-03-22 14:17:27,011 INFO MainThread:316313 [wandb_init.py:init():973] updated telemetry
+ 2026-03-22 14:17:27,025 INFO MainThread:316313 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
+ 2026-03-22 14:17:29,067 INFO MainThread:316313 [wandb_init.py:init():1042] starting run threads in backend
+ 2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_console_start():2524] atexit reg
+ 2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_redirect():2373] redirect: wrap_raw
+ 2026-03-22 14:17:29,158 INFO MainThread:316313 [wandb_run.py:_redirect():2442] Wrapping output streams.
+ 2026-03-22 14:17:29,159 INFO MainThread:316313 [wandb_run.py:_redirect():2465] Redirects installed.
+ 2026-03-22 14:17:29,163 INFO MainThread:316313 [wandb_init.py:init():1082] run started, returning control to user process
+ 2026-03-22 14:17:29,163 INFO MainThread:316313 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 10000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
+ 2026-03-22 14:17:34,316 INFO wandb-AsyncioManager-main:316313 [service_client.py:_forward_responses():134] Reached EOF.
+ 2026-03-22 14:17:34,316 INFO wandb-AsyncioManager-main:316313 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
REG/wandb/run-20260322_141726-2yw08kz9/run-2yw08kz9.wandb ADDED
Binary file (7 Bytes). View file
 
REG/wandb/run-20260322_141833-vm0y8t9t/files/output.log ADDED
The diff for this file is too large to render. See raw diff
 
REG/wandb/run-20260322_141833-vm0y8t9t/files/requirements.txt ADDED
@@ -0,0 +1,168 @@
+ dill==0.3.8
+ mkl-service==2.4.0
+ mpmath==1.3.0
+ typing_extensions==4.12.2
+ urllib3==2.3.0
+ torch==2.5.1
+ ptyprocess==0.7.0
+ traitlets==5.14.3
+ pyasn1==0.6.1
+ opencv-python-headless==4.12.0.88
+ nest-asyncio==1.6.0
+ kiwisolver==1.4.8
+ click==8.2.1
+ fire==0.7.1
+ diffusers==0.35.1
+ accelerate==1.7.0
+ ipykernel==6.29.5
+ peft==0.17.1
+ attrs==24.3.0
+ six==1.17.0
+ numpy==2.0.1
+ yarl==1.18.0
+ huggingface_hub==0.34.4
+ Bottleneck==1.4.2
+ numexpr==2.11.0
+ dataclasses==0.6
+ typing-inspection==0.4.1
+ safetensors==0.5.3
+ pyparsing==3.2.3
+ psutil==7.0.0
+ imageio==2.37.0
+ debugpy==1.8.14
+ cycler==0.12.1
+ pyasn1_modules==0.4.2
+ matplotlib-inline==0.1.7
+ matplotlib==3.10.3
+ jedi==0.19.2
+ tokenizers==0.21.2
+ seaborn==0.13.2
+ timm==1.0.15
+ aiohappyeyeballs==2.6.1
+ hf-xet==1.1.8
+ multidict==6.1.0
+ tqdm==4.67.1
+ wheel==0.45.1
+ simsimd==6.5.1
+ sentencepiece==0.2.1
+ grpcio==1.74.0
+ asttokens==3.0.0
+ absl-py==2.3.1
+ stack-data==0.6.3
+ pandas==2.3.0
+ importlib_metadata==8.7.0
+ pytorch-image-generation-metrics==0.6.1
+ frozenlist==1.5.0
+ MarkupSafe==3.0.2
+ setuptools==78.1.1
+ multiprocess==0.70.15
+ pip==25.1
+ requests==2.32.3
+ mkl_random==1.2.8
+ tensorboard-plugin-wit==1.8.1
+ ExifRead-nocycle==3.0.1
+ webdataset==0.2.111
+ threadpoolctl==3.6.0
+ pyarrow==21.0.0
+ executing==2.2.0
+ decorator==5.2.1
+ contourpy==1.3.2
+ annotated-types==0.7.0
+ scikit-learn==1.7.1
+ jupyter_client==8.6.3
+ albumentations==1.4.24
+ wandb==0.25.0
+ certifi==2025.8.3
+ idna==3.7
+ xxhash==3.5.0
+ Jinja2==3.1.6
+ python-dateutil==2.9.0.post0
+ aiosignal==1.4.0
+ triton==3.1.0
+ torchvision==0.20.1
+ stringzilla==3.12.6
+ pure_eval==0.2.3
+ braceexpand==0.1.7
+ zipp==3.22.0
+ oauthlib==3.3.1
+ Markdown==3.8.2
+ fsspec==2025.3.0
+ fonttools==4.58.2
+ comm==0.2.2
+ ipython==9.3.0
+ img2dataset==1.47.0
+ networkx==3.4.2
+ PySocks==1.7.1
+ tzdata==2025.2
+ smmap==5.0.2
+ mkl_fft==1.3.11
+ sentry-sdk==2.29.1
+ Pygments==2.19.1
+ pexpect==4.9.0
+ ftfy==6.3.1
+ einops==0.8.1
+ requests-oauthlib==2.0.0
+ gitdb==4.0.12
+ albucore==0.0.23
+ torchdiffeq==0.2.5
+ GitPython==3.1.44
+ bitsandbytes==0.47.0
+ pytorch-fid==0.3.0
+ clean-fid==0.1.35
+ pytorch-gan-metrics==0.5.4
+ Brotli==1.0.9
+ charset-normalizer==3.3.2
+ gmpy2==2.2.1
+ pillow==11.1.0
+ PyYAML==6.0.2
+ tornado==6.5.1
+ termcolor==3.1.0
+ setproctitle==1.3.6
+ scipy==1.15.3
+ regex==2024.11.6
+ protobuf==6.31.1
+ platformdirs==4.3.8
+ joblib==1.5.1
+ cachetools==4.2.4
+ ipython_pygments_lexers==1.1.1
+ google-auth==1.35.0
+ transformers==4.53.2
+ torch-fidelity==0.3.0
+ tensorboard==2.4.0
+ filelock==3.17.0
+ packaging==25.0
+ propcache==0.3.1
+ pytz==2025.2
+ aiohttp==3.11.10
+ wcwidth==0.2.13
+ clip==0.2.0
+ Werkzeug==3.1.3
+ tensorboard-data-server==0.6.1
+ sympy==1.13.1
+ pyzmq==26.4.0
+ pydantic_core==2.33.2
+ prompt_toolkit==3.0.51
+ parso==0.8.4
+ docker-pycreds==0.4.0
+ rsa==4.9.1
+ pydantic==2.11.5
+ jupyter_core==5.8.1
+ google-auth-oauthlib==0.4.6
+ datasets==4.0.0
+ torch-tb-profiler==0.4.3
+ autocommand==2.2.2
+ backports.tarfile==1.2.0
+ importlib_metadata==8.0.0
+ jaraco.collections==5.1.0
+ jaraco.context==5.3.0
+ jaraco.functools==4.0.1
+ more-itertools==10.3.0
+ packaging==24.2
+ platformdirs==4.2.2
+ typeguard==4.3.0
+ inflect==7.3.1
+ jaraco.text==3.12.1
+ tomli==2.0.1
+ typing_extensions==4.12.2
+ wheel==0.45.1
+ zipp==3.19.2
REG/wandb/run-20260322_141833-vm0y8t9t/files/wandb-metadata.json ADDED
@@ -0,0 +1,101 @@
+ {
+ "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
+ "python": "CPython 3.12.9",
+ "startedAt": "2026-03-22T06:18:33.208941Z",
+ "args": [
+ "--report-to",
+ "wandb",
+ "--allow-tf32",
+ "--mixed-precision",
+ "bf16",
+ "--seed",
+ "0",
+ "--path-type",
+ "linear",
+ "--prediction",
+ "v",
+ "--weighting",
+ "uniform",
+ "--model",
+ "SiT-XL/2",
+ "--enc-type",
+ "dinov2-vit-b",
+ "--encoder-depth",
+ "8",
+ "--proj-coeff",
+ "0.5",
+ "--output-dir",
+ "exps",
+ "--exp-name",
+ "jsflow-experiment",
+ "--batch-size",
+ "256",
+ "--data-dir",
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
+ "--semantic-features-dir",
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
+ "--learning-rate",
+ "0.00005",
+ "--t-c",
+ "0.5",
+ "--cls",
+ "0.2",
+ "--ot-cls"
+ ],
+ "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
+ "codePath": "train.py",
+ "codePathLocal": "train.py",
+ "git": {
+ "remote": "https://github.com/Martinser/REG.git",
+ "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
+ },
+ "email": "2365972933@qq.com",
+ "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
+ "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
+ "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
+ "cpu_count": 96,
+ "cpu_count_logical": 192,
+ "gpu": "NVIDIA H100 80GB HBM3",
+ "gpu_count": 4,
+ "disk": {
+ "/": {
+ "total": "3838880616448",
+ "used": "357556703232"
+ }
+ },
+ "memory": {
+ "total": "2164115296256"
+ },
+ "gpu_nvidia": [
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
+ }
+ ],
+ "cudaVersion": "13.0",
+ "writerId": "gklxguwapb72cxij4696gj37bh1rbthi"
+ }
REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log ADDED
@@ -0,0 +1,6 @@
+ {"time":"2026-03-22T14:18:33.472940651+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
+ {"time":"2026-03-22T14:18:35.380852704+08:00","level":"INFO","msg":"stream: created new stream","id":"vm0y8t9t"}
+ {"time":"2026-03-22T14:18:35.381056887+08:00","level":"INFO","msg":"handler: started","stream_id":"vm0y8t9t"}
+ {"time":"2026-03-22T14:18:35.382108345+08:00","level":"INFO","msg":"writer: started","stream_id":"vm0y8t9t"}
+ {"time":"2026-03-22T14:18:35.382119604+08:00","level":"INFO","msg":"stream: started","id":"vm0y8t9t"}
+ {"time":"2026-03-22T14:18:35.382161533+08:00","level":"INFO","msg":"sender: started","stream_id":"vm0y8t9t"}
REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log ADDED
@@ -0,0 +1,20 @@
+ 2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
+ 2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Configure stats pid to 318585
+ 2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_setup.py:_flush():81] Loading settings from environment variables
+ 2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug.log
+ 2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_141833-vm0y8t9t/logs/debug-internal.log
+ 2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():844] calling init triggers
+ 2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
+ config: {'_wandb': {}}
+ 2026-03-22 14:18:33,237 INFO MainThread:318585 [wandb_init.py:init():892] starting backend
+ 2026-03-22 14:18:33,460 INFO MainThread:318585 [wandb_init.py:init():895] sending inform_init request
+ 2026-03-22 14:18:33,470 INFO MainThread:318585 [wandb_init.py:init():903] backend started and connected
+ 2026-03-22 14:18:33,472 INFO MainThread:318585 [wandb_init.py:init():973] updated telemetry
+ 2026-03-22 14:18:33,485 INFO MainThread:318585 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
+ 2026-03-22 14:18:36,829 INFO MainThread:318585 [wandb_init.py:init():1042] starting run threads in backend
+ 2026-03-22 14:18:36,920 INFO MainThread:318585 [wandb_run.py:_console_start():2524] atexit reg
+ 2026-03-22 14:18:36,920 INFO MainThread:318585 [wandb_run.py:_redirect():2373] redirect: wrap_raw
+ 2026-03-22 14:18:36,921 INFO MainThread:318585 [wandb_run.py:_redirect():2442] Wrapping output streams.
+ 2026-03-22 14:18:36,921 INFO MainThread:318585 [wandb_run.py:_redirect():2465] Redirects installed.
+ 2026-03-22 14:18:36,924 INFO MainThread:318585 [wandb_init.py:init():1082] run started, returning control to user process
+ 2026-03-22 14:18:36,924 INFO MainThread:318585 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 10000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
REG/wandb/run-20260322_150022-yhxc5cgu/files/output.log ADDED
@@ -0,0 +1,19 @@
+ Steps: 0%| | 1/1000000 [00:02<652:30:07, 2.35s/it][2026-03-22 15:00:28] Generating EMA samples for evaluation (t=1→0 and t=0.5)...
+ Traceback (most recent call last):
+ File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 628, in <module>
+ main(args)
+ File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 425, in main
+ cls_init = torch.randn(n_samples, base_model.semantic_channels, device=device)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
+ raise AttributeError(
+ AttributeError: 'SiT' object has no attribute 'semantic_channels'
+ [rank0]: Traceback (most recent call last):
+ [rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 628, in <module>
+ [rank0]: main(args)
+ [rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 425, in main
+ [rank0]: cls_init = torch.randn(n_samples, base_model.semantic_channels, device=device)
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
+ [rank0]: raise AttributeError(
+ [rank0]: AttributeError: 'SiT' object has no attribute 'semantic_channels'
REG/wandb/run-20260322_150022-yhxc5cgu/files/requirements.txt ADDED
@@ -0,0 +1,168 @@
+ (168 entries, identical to the requirements.txt listed for run-20260322_141833-vm0y8t9t above)
REG/wandb/run-20260322_150022-yhxc5cgu/files/wandb-metadata.json ADDED
@@ -0,0 +1,101 @@
+ {
+ "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
+ "python": "CPython 3.12.9",
+ "startedAt": "2026-03-22T07:00:22.092510Z",
+ "args": [
+ "--report-to",
+ "wandb",
+ "--allow-tf32",
+ "--mixed-precision",
+ "bf16",
+ "--seed",
+ "0",
+ "--path-type",
+ "linear",
+ "--prediction",
+ "v",
+ "--weighting",
+ "uniform",
+ "--model",
+ "SiT-XL/2",
+ "--enc-type",
+ "dinov2-vit-b",
+ "--encoder-depth",
+ "8",
+ "--proj-coeff",
+ "0.5",
+ "--output-dir",
+ "exps",
+ "--exp-name",
+ "jsflow-experiment",
+ "--batch-size",
+ "256",
+ "--data-dir",
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
+ "--semantic-features-dir",
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
+ "--learning-rate",
+ "0.00005",
+ "--t-c",
+ "0.5",
+ "--cls",
+ "0.2",
+ "--ot-cls"
+ ],
+ "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
+ "codePath": "train.py",
+ "codePathLocal": "train.py",
+ "git": {
+ "remote": "https://github.com/Martinser/REG.git",
+ "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
+ },
+ "email": "2365972933@qq.com",
+ "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
+ "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
+ "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
+ "cpu_count": 96,
+ "cpu_count_logical": 192,
+ "gpu": "NVIDIA H100 80GB HBM3",
+ "gpu_count": 4,
+ "disk": {
+ "/": {
+ "total": "3838880616448",
+ "used": "357557354496"
+ }
+ },
+ "memory": {
+ "total": "2164115296256"
+ },
+ "gpu_nvidia": [
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
+ },
+ {
+ "name": "NVIDIA H100 80GB HBM3",
+ "memoryTotal": "85520809984",
+ "cudaCores": 16896,
+ "architecture": "Hopper",
+ "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
+ }
+ ],
+ "cudaVersion": "13.0",
+ "writerId": "ucanic8s891x6sl28vnbha78lzoecw66"
+ }
REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug-internal.log ADDED
@@ -0,0 +1,7 @@
+ {"time":"2026-03-22T15:00:22.432399726+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
+ {"time":"2026-03-22T15:00:25.799578446+08:00","level":"INFO","msg":"stream: created new stream","id":"yhxc5cgu"}
+ {"time":"2026-03-22T15:00:25.799734466+08:00","level":"INFO","msg":"handler: started","stream_id":"yhxc5cgu"}
+ {"time":"2026-03-22T15:00:25.80075778+08:00","level":"INFO","msg":"stream: started","id":"yhxc5cgu"}
+ {"time":"2026-03-22T15:00:25.800786229+08:00","level":"INFO","msg":"writer: started","stream_id":"yhxc5cgu"}
+ {"time":"2026-03-22T15:00:25.800837858+08:00","level":"INFO","msg":"sender: started","stream_id":"yhxc5cgu"}
+ {"time":"2026-03-22T15:00:28.913273863+08:00","level":"INFO","msg":"stream: closing","id":"yhxc5cgu"}
REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug.log ADDED
@@ -0,0 +1,22 @@
+ 2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
+ 2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_setup.py:_flush():81] Configure stats pid to 323629
+ 2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_setup.py:_flush():81] Loading settings from environment variables
+ 2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug.log
+ 2026-03-22 15:00:22,124 INFO MainThread:323629 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150022-yhxc5cgu/logs/debug-internal.log
+ 2026-03-22 15:00:22,125 INFO MainThread:323629 [wandb_init.py:init():844] calling init triggers
+ 2026-03-22 15:00:22,125 INFO MainThread:323629 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
+ config: {'_wandb': {}}
+ 2026-03-22 15:00:22,125 INFO MainThread:323629 [wandb_init.py:init():892] starting backend
+ 2026-03-22 15:00:22,416 INFO MainThread:323629 [wandb_init.py:init():895] sending inform_init request
+ 2026-03-22 15:00:22,429 INFO MainThread:323629 [wandb_init.py:init():903] backend started and connected
+ 2026-03-22 15:00:22,431 INFO MainThread:323629 [wandb_init.py:init():973] updated telemetry
+ 2026-03-22 15:00:22,447 INFO MainThread:323629 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
+ 2026-03-22 15:00:26,403 INFO MainThread:323629 [wandb_init.py:init():1042] starting run threads in backend
+ 2026-03-22 15:00:26,494 INFO MainThread:323629 [wandb_run.py:_console_start():2524] atexit reg
+ 2026-03-22 15:00:26,494 INFO MainThread:323629 [wandb_run.py:_redirect():2373] redirect: wrap_raw
+ 2026-03-22 15:00:26,494 INFO MainThread:323629 [wandb_run.py:_redirect():2442] Wrapping output streams.
+ 2026-03-22 15:00:26,495 INFO MainThread:323629 [wandb_run.py:_redirect():2465] Redirects installed.
+ 2026-03-22 15:00:26,500 INFO MainThread:323629 [wandb_init.py:init():1082] run started, returning control to user process
+ 2026-03-22 15:00:26,500 INFO MainThread:323629 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
+ 2026-03-22 15:00:28,913 INFO wandb-AsyncioManager-main:323629 [service_client.py:_forward_responses():134] Reached EOF.
+ 2026-03-22 15:00:28,913 INFO wandb-AsyncioManager-main:323629 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
REG/wandb/run-20260322_150022-yhxc5cgu/run-yhxc5cgu.wandb ADDED
Binary file (7 Bytes). View file
 
REG/wandb/run-20260322_150443-e3yw9ii4/run-e3yw9ii4.wandb ADDED
Binary file (7 Bytes). View file