xiangzai commited on
Commit
b65e56d
·
verified ·
1 Parent(s): b5e1f6d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. REG copy/LICENSE +21 -0
  2. REG copy/README.md +156 -0
  3. REG copy/dataset.py +80 -0
  4. REG copy/eval.sh +52 -0
  5. REG copy/generate.py +243 -0
  6. REG copy/loss.py +102 -0
  7. REG copy/requirements.txt +97 -0
  8. REG copy/samplers.py +169 -0
  9. REG copy/utils.py +225 -0
  10. REG/LICENSE +21 -0
  11. REG/README.md +156 -0
  12. REG/dataset.py +149 -0
  13. REG/eval.sh +52 -0
  14. REG/eval_custom_0.25.log +1 -0
  15. REG/generate.py +227 -0
  16. REG/loss.py +193 -0
  17. REG/requirements.txt +97 -0
  18. REG/sample_from_checkpoint.py +611 -0
  19. REG/sample_from_checkpoint_ddp.py +416 -0
  20. REG/samplers.py +840 -0
  21. REG/samples.sh +15 -0
  22. REG/samples_0.25_new.log +43 -0
  23. REG/samples_0.5.log +0 -0
  24. REG/samples_0.75.log +0 -0
  25. REG/samples_0.75_new.log +46 -0
  26. REG/samples_ddp.sh +32 -0
  27. REG/train.py +708 -0
  28. REG/train.sh +43 -0
  29. REG/train_resume_tc_velocity.sh +41 -0
  30. REG/utils.py +225 -0
  31. REG/wandb/run-20260322_150022-yhxc5cgu/files/config.yaml +202 -0
  32. REG/wandb/run-20260322_150022-yhxc5cgu/files/wandb-summary.json +1 -0
  33. REG/wandb/run-20260322_150443-e3yw9ii4/files/config.yaml +202 -0
  34. REG/wandb/run-20260322_150443-e3yw9ii4/files/output.log +15 -0
  35. REG/wandb/run-20260322_150443-e3yw9ii4/files/requirements.txt +168 -0
  36. REG/wandb/run-20260322_150443-e3yw9ii4/files/wandb-metadata.json +101 -0
  37. REG/wandb/run-20260322_150443-e3yw9ii4/files/wandb-summary.json +1 -0
  38. REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug-internal.log +7 -0
  39. REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug.log +22 -0
  40. REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug-internal.log +6 -0
  41. REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug.log +20 -0
  42. REG/wandb/run-20260323_135607-zue1y2ba/files/output.log +289 -0
  43. REG/wandb/run-20260323_135607-zue1y2ba/files/requirements.txt +168 -0
  44. REG/wandb/run-20260323_135607-zue1y2ba/files/wandb-metadata.json +101 -0
  45. REG/wandb/run-20260323_135607-zue1y2ba/logs/debug-internal.log +6 -0
  46. REG/wandb/run-20260323_135607-zue1y2ba/logs/debug.log +20 -0
  47. REG/wandb/run-20260323_135841-w9holkos/files/requirements.txt +168 -0
  48. REG/wandb/run-20260323_135841-w9holkos/files/wandb-metadata.json +101 -0
  49. REG/wandb/run-20260323_135841-w9holkos/logs/debug-internal.log +19 -0
  50. REG/wandb/run-20260323_135841-w9holkos/logs/debug.log +20 -0
REG copy/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Sihyun Yu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
REG copy/README.md ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <h1 align="center">Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think (NeurIPS 2025 Oral)
3
+ </h1>
4
+ <p align="center">
5
+ <a href='https://github.com/Martinser' style='text-decoration: none' >Ge Wu</a><sup>1</sup>&emsp;
6
+ <a href='https://github.com/ShenZhang-Shin' style='text-decoration: none' >Shen Zhang</a><sup>3</sup>&emsp;
7
+ <a href='' style='text-decoration: none' >Ruijing Shi</a><sup>1</sup>&emsp;
8
+ <a href='https://shgao.site/' style='text-decoration: none' >Shanghua Gao</a><sup>4</sup>&emsp;
9
+ <a href='https://zhenyuanchenai.github.io/' style='text-decoration: none' >Zhenyuan Chen</a><sup>1</sup>&emsp;
10
+ <a href='https://scholar.google.com/citations?user=6Z66DAwAAAAJ&hl=en' style='text-decoration: none' >Lei Wang</a><sup>1</sup>&emsp;
11
+ <a href='https://www.zhihu.com/people/chen-zhao-wei-16-2' style='text-decoration: none' >Zhaowei Chen</a><sup>3</sup>&emsp;
12
+ <a href='https://gao-hongcheng.github.io/' style='text-decoration: none' >Hongcheng Gao</a><sup>5</sup>&emsp;
13
+ <a href='https://scholar.google.com/citations?view_op=list_works&hl=zh-CN&hl=zh-CN&user=0xP6bxcAAAAJ' style='text-decoration: none' >Yao Tang</a><sup>3</sup>&emsp;
14
+ <a href='https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en' style='text-decoration: none' >Jian Yang</a><sup>1</sup>&emsp;
15
+ <a href='https://mmcheng.net/cmm/' style='text-decoration: none' >Ming-Ming Cheng</a><sup>1,2</sup>&emsp;
16
+ <a href='https://implus.github.io/' style='text-decoration: none' >Xiang Li</a><sup>1,2*</sup>&emsp;
17
+ <p align="center">
18
+ $^{1}$ VCIP, CS, Nankai University, $^{2}$ NKIARI, Shenzhen Futian, $^{3}$ JIIOV Technology,
19
+ $^{4}$ Harvard University, $^{5}$ University of Chinese Academy of Sciences
20
+ <p align='center'>
21
+ <div align="center">
22
+ <a href='https://arxiv.org/abs/2507.01467v2'><img src='https://img.shields.io/badge/arXiv-2507.01467v2-brown.svg?logo=arxiv&logoColor=white'></a>
23
+ <a href='https://huggingface.co/Martinser/REG/tree/main'><img src='https://img.shields.io/badge/🤗-Model-blue.svg'></a>
24
+ <a href='https://zhuanlan.zhihu.com/p/1952346823168595518'><img src='https://img.shields.io/badge/Zhihu-chinese_article-blue.svg?logo=zhihu&logoColor=white'></a>
25
+ </div>
26
+ <p align='center'>
27
+ </p>
28
+ </p>
29
+ </p>
30
+
31
+
32
+ ## 🚩 Overview
33
+
34
+ ![overview](fig/reg.png)
35
+
36
+ REPA and its variants effectively mitigate training challenges in diffusion models by incorporating external visual representations from pretrained models, through alignment between the noisy hidden projections of denoising networks and foundational clean image representations.
37
+ We argue that the external alignment, which is absent during the entire denoising inference process, falls short of fully harnessing the potential of discriminative representations.
38
+
39
+ In this work, we propose a straightforward method called Representation Entanglement for Generation (REG), which entangles low-level image latents with a single high-level class token from pretrained foundation models for denoising.
40
+ REG acquires the capability to produce coherent image-class pairs directly from pure noise,
41
+ substantially improving both generation quality and training efficiency.
42
+ This is accomplished with negligible additional inference overhead, **requiring only one single additional token for denoising (<0.5\% increase in FLOPs and latency).**
43
+ The inference process concurrently reconstructs both image latents and their corresponding global semantics, where the acquired semantic knowledge actively guides and enhances the image generation process.
44
+
45
+ On ImageNet $256{\times}256$, SiT-XL/2 + REG demonstrates remarkable convergence acceleration, **achieving $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA, respectively.**
46
+ More impressively, SiT-L/2 + REG trained for merely 400K iterations outperforms SiT-XL/2 + REPA trained for 4M iterations ($\textbf{10}\times$ longer).
47
+
48
+
49
+
50
+ ## 📰 News
51
+
52
+ - **[2025.08.05]** We have released the pre-trained weights of REG + SiT-XL/2 in 4M (800 epochs).
53
+
54
+
55
+ ## 📝 Results
56
+
57
+ - Performance on ImageNet $256{\times}256$ with FID=1.36 by introducing a single class token.
58
+ - $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA.
59
+
60
+ <div align="center">
61
+ <img src="fig/img.png" alt="Results">
62
+ </div>
63
+
64
+
65
+ ## 📋 Plan
66
+ - More training steps on ImageNet 256&512 and T2I.
67
+
68
+
69
+ ## 👊 Usage
70
+
71
+ ### 1. Environment setup
72
+
73
+ ```bash
74
+ conda create -n reg python=3.10.16 -y
75
+ conda activate reg
76
+ pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
77
+ pip install -r requirements.txt
78
+ ```
79
+
80
+ ### 2. Dataset
81
+
82
+ #### Dataset download
83
+
84
+ Currently, we provide experiments for ImageNet. You can place the data that you want and can specifiy it via `--data-dir` arguments in training scripts.
85
+
86
+ #### Preprocessing data
87
+ Please refer to the preprocessing guide. And you can directly download our processed data, ImageNet data [link](https://huggingface.co/WindATree/ImageNet-256-VAE/tree/main), and ImageNet data after VAE encoder [link]( https://huggingface.co/WindATree/vae-sd/tree/main)
88
+
89
+ ### 3. Training
90
+ Run train.sh
91
+ ```bash
92
+ bash train.sh
93
+ ```
94
+
95
+ train.sh contains the following content.
96
+ ```bash
97
+ accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
98
+ --report-to="wandb" \
99
+ --allow-tf32 \
100
+ --mixed-precision="fp16" \
101
+ --seed=0 \
102
+ --path-type="linear" \
103
+ --prediction="v" \
104
+ --weighting="uniform" \
105
+ --model="SiT-B/2" \
106
+ --enc-type="dinov2-vit-b" \
107
+ --proj-coeff=0.5 \
108
+ --encoder-depth=4 \ #SiT-L/XL use 8, SiT-B use 4
109
+ --output-dir="your_path" \
110
+ --exp-name="linear-dinov2-b-enc4" \
111
+ --batch-size=256 \
112
+ --data-dir="data_path/imagenet_vae" \
113
+ --cls=0.03
114
+ ```
115
+
116
+ Then this script will automatically create the folder in `exps` to save logs and checkpoints. You can adjust the following options:
117
+
118
+ - `--models`: `[SiT-B/2, SiT-L/2, SiT-XL/2]`
119
+ - `--enc-type`: `[dinov2-vit-b, clip-vit-L]`
120
+ - `--proj-coeff`: Any values larger than 0
121
+ - `--encoder-depth`: Any values between 1 to the depth of the model
122
+ - `--output-dir`: Any directory that you want to save checkpoints and logs
123
+ - `--exp-name`: Any string name (the folder will be created under `output-dir`)
124
+ - `--cls`: Weight coefficients of REG loss
125
+
126
+
127
+ ### 4. Generate images and evaluation
128
+ You can generate images and get the final results through the following script.
129
+ The weight of REG can be found in this [link](https://pan.baidu.com/s/1QX2p3ybh1KfNU7wsp5McWw?pwd=khpp) or [HF](https://huggingface.co/Martinser/REG/tree/main).
130
+
131
+ ```bash
132
+ bash eval.sh
133
+ ```
134
+
135
+
136
+ ## Citation
137
+ If you find our work, this repository, or pretrained models useful, please consider giving a star and citation.
138
+ ```
139
+ @article{wu2025representation,
140
+ title={Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think},
141
+ author={Wu, Ge and Zhang, Shen and Shi, Ruijing and Gao, Shanghua and Chen, Zhenyuan and Wang, Lei and Chen, Zhaowei and Gao, Hongcheng and Tang, Yao and Yang, Jian and others},
142
+ journal={arXiv preprint arXiv:2507.01467},
143
+ year={2025}
144
+ }
145
+ ```
146
+
147
+ ## Contact
148
+ If you have any questions, please create an issue on this repository, contact at gewu.nku@gmail.com or wechat(wg1158848).
149
+
150
+
151
+ ## Acknowledgements
152
+
153
+ Our code is based on [REPA](https://github.com/sihyun-yu/REPA), along with [SiT](https://github.com/willisma/SiT), [DINOv2](https://github.com/facebookresearch/dinov2), [ADM](https://github.com/openai/guided-diffusion) and [U-ViT](https://github.com/baofff/U-ViT) repositories. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well.
154
+
155
+
156
+
REG copy/dataset.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ import PIL.Image
10
+ try:
11
+ import pyspng
12
+ except ImportError:
13
+ pyspng = None
14
+
15
+
16
class CustomDataset(Dataset):
    """Paired dataset of raw image arrays and precomputed VAE features.

    Layout expected under ``data_dir``:
      - ``imagenet_256_vae/``: image files (.npy / .png / any PIL-readable format)
      - ``vae-sd/``: feature files plus a ``dataset.json`` holding
        ``{"labels": [[fname, label], ...]}``

    Pairing between the two trees relies on both file lists sorting into the
    same order; ``__len__`` asserts the counts match.
    """

    def __init__(self, data_dir):
        # Populate PIL.Image.EXTENSION so we know every extension PIL can read.
        PIL.Image.init()
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}

        self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')
        self.features_dir = os.path.join(data_dir, 'vae-sd')

        # images
        self._image_fnames = {
            os.path.relpath(os.path.join(root, fname), start=self.images_dir)
            for root, _dirs, files in os.walk(self.images_dir) for fname in files
        }
        self.image_fnames = sorted(
            fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext
        )
        # features
        self._feature_fnames = {
            os.path.relpath(os.path.join(root, fname), start=self.features_dir)
            for root, _dirs, files in os.walk(self.features_dir) for fname in files
        }
        self.feature_fnames = sorted(
            fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext
        )

        # labels
        fname = os.path.join(self.features_dir, 'dataset.json')
        if os.path.exists(fname):
            print(f"Using {fname}.")
        else:
            # Fixed: the old message ("Neither of the specified files exists.")
            # implied several candidates were checked; only this one is, and the
            # message did not say which path was missing.
            raise FileNotFoundError(f"Label file not found: {fname}")

        with open(fname, 'rb') as f:
            labels = json.load(f)['labels']
        labels = dict(labels)
        # dataset.json stores forward-slash paths; normalize Windows separators.
        labels = [labels[fname.replace('\\', '/')] for fname in self.feature_fnames]
        labels = np.array(labels)
        # 1-D label array -> integer class ids; 2-D -> float vectors.
        self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    def _file_ext(self, fname):
        """Return the lower-cased extension of ``fname`` (including the dot)."""
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of feature files and label files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        """Return (image, features, label) tensors for sample ``idx``."""
        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if image_ext == '.npy':
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            elif image_ext == '.png' and pyspng is not None:
                # pyspng is a faster PNG decoder when available (optional import).
                image = pyspng.load(f.read())
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
            else:
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])
REG copy/eval.sh ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ random_number=$((RANDOM % 100 + 1200))
3
+ NUM_GPUS=8
4
+ STEP="4000000"
5
+ SAVE_PATH="your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8"
6
+ VAE_PATH="your_vae_path/"
7
+ NUM_STEP=250
8
+ MODEL_SIZE='XL'
9
+ CFG_SCALE=2.3
10
+ CLS_CFG_SCALE=2.3
11
+ GH=0.85
12
+
13
+ export NCCL_P2P_DISABLE=1
14
+
15
+ python -m torch.distributed.launch --master_port=$random_number --nproc_per_node=$NUM_GPUS generate.py \
16
+ --model SiT-XL/2 \
17
+ --num-fid-samples 50000 \
18
+ --ckpt ${SAVE_PATH}/checkpoints/${STEP}.pt \
19
+ --path-type=linear \
20
+ --encoder-depth=8 \
21
+ --projector-embed-dims=768 \
22
+ --per-proc-batch-size=64 \
23
+ --mode=sde \
24
+ --num-steps=${NUM_STEP} \
25
+ --cfg-scale=${CFG_SCALE} \
26
+ --cls-cfg-scale=${CLS_CFG_SCALE} \
27
+ --guidance-high=${GH} \
28
+ --sample-dir ${SAVE_PATH}/checkpoints \
29
+ --cls=768
30
+
31
+
32
+ python ./evaluations/evaluator.py \
33
+ --ref_batch your_path/VIRTUAL_imagenet256_labeled.npz \
34
+ --sample_batch ${SAVE_PATH}/checkpoints/SiT-${MODEL_SIZE}-2-${STEP}-size-256-vae-ema-cfg-${CFG_SCALE}-seed-0-sde-${GH}-${CLS_CFG_SCALE}.npz \
35
+ --save_path ${SAVE_PATH}/checkpoints \
36
+ --cfg_cond 1 \
37
+ --step ${STEP} \
38
+ --num_steps ${NUM_STEP} \
39
+ --cfg ${CFG_SCALE} \
40
+ --cls_cfg ${CLS_CFG_SCALE} \
41
+ --gh ${GH}
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
REG copy/generate.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Samples a large number of images from a pre-trained SiT model using DDP.
9
+ Subsequently saves a .npz file that can be used to compute FID and other
10
+ evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
11
+
12
+ For a simple single-GPU/CPU sampling script, see sample.py.
13
+ """
14
+ import torch
15
+ import torch.distributed as dist
16
+ from models.sit import SiT_models
17
+ from diffusers.models import AutoencoderKL
18
+ from tqdm import tqdm
19
+ import os
20
+ from PIL import Image
21
+ import numpy as np
22
+ import math
23
+ import argparse
24
+ import socket
25
+ from samplers import euler_maruyama_sampler
26
+ from utils import load_legacy_checkpoints, download_model
27
+
28
+
29
def setup_distributed():
    """Init NCCL process group. If not launched via torchrun, use a single local process."""
    env = os.environ
    if "RANK" not in env:
        # Stand-alone invocation: synthesize the rendezvous variables torchrun
        # would normally provide, so a plain `python generate.py` still works.
        env.setdefault("MASTER_ADDR", "127.0.0.1")
        if "MASTER_PORT" not in env:
            # Bind to port 0 to let the OS pick a free port, then release it.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(("", 0))
                env["MASTER_PORT"] = str(probe.getsockname()[1])
        env["RANK"] = "0"
        env["WORLD_SIZE"] = "1"
        env["LOCAL_RANK"] = "0"
    dist.init_process_group("nccl")
42
+
43
+
44
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    # Samples are named 000000.png .. (num-1).png; load and stack them all.
    stacked = np.stack([
        np.asarray(Image.open(f"{sample_dir}/{idx:06d}.png")).astype(np.uint8)
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ])
    # Sanity check: every sample decoded to an RGB image of a common size.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
59
+
60
+
61
def main(args):
    """
    Run sampling.

    Joins (or bootstraps) a NCCL process group, loads the SiT model and the
    SD-VAE decoder, samples images across all ranks in per-process batches,
    saves each image as a .png, and on rank 0 packs them into a single .npz.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP (works with plain `python` or torchrun)
    setup_distributed()
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Per-rank seed derived from the global seed so ranks draw different noise.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Load model:
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    latent_size = args.resolution // 8  # SD-VAE downsamples pixels by 8x
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = True,
        z_dims = [int(z_dim) for z_dim in args.projector_embed_dims.split(',')],
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt

    # --- checkpoint loading ---
    if ckpt_path is None:
        # No checkpoint given: fall back to the published SiT-XL/2 weights.
        args.ckpt = 'SiT-XL-2-256x256.pt'
        assert args.model == 'SiT-XL/2'
        assert len(args.projector_embed_dims.split(',')) == 1
        assert int(args.projector_embed_dims.split(',')[0]) == 768
        state_dict = download_model('last.pt')
    else:
        # train.py checkpoints store multiple variants; sample from the EMA weights.
        state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']

    if args.legacy:
        # Remap parameter names from older checkpoint layouts.
        state_dict = load_legacy_checkpoints(
            state_dict=state_dict, encoder_depth=args.encoder_depth
        )
    model.load_state_dict(state_dict)
    # --- end checkpoint loading ---

    model.eval()  # important!
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    #vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path="your_local_path/weight/").to(device)

    # Create folder to save samples:
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    # Folder name encodes the sampling hyperparameters (eval.sh reconstructs it).
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}-{args.guidance_high}-{args.cls_cfg_scale}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Sample inputs: image-latent noise, random class labels, class-token noise.
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        cls_z = torch.randn(n, args.cls, device=device)

        # Sample images:
        sampling_kwargs = dict(
            model=model,
            latents=z,
            y=y,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
            cls_latents=cls_z,
            args=args
        )
        with torch.no_grad():
            if args.mode == "sde":
                samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
            elif args.mode == "ode":  # ODE sampling not wired up yet — exits the process
                exit()
                #samples = euler_sampler(**sampling_kwargs).to(torch.float32)
            else:
                raise NotImplementedError()

        # Undo latent scaling before decoding. NOTE(review): 0.18215 is
        # presumably the SD-VAE latent scaling factor — confirm against the VAE config.
        latents_scale = torch.tensor(
            [0.18215, 0.18215, 0.18215, 0.18215, ]
        ).view(1, 4, 1, 1).to(device)
        latents_bias = -torch.tensor(
            [0., 0., 0., 0.,]
        ).view(1, 4, 1, 1).to(device)
        samples = vae.decode((samples - latents_bias) / latents_scale).sample
        # Map from [-1, 1] to uint8 pixel values in HWC layout.
        samples = (samples + 1) / 2.
        samples = torch.clamp(
            255. * samples, 0, 255
        ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            # Interleave ranks so the global index space 0..total_samples-1 is dense.
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
198
+
199
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # seed
    parser.add_argument("--global-seed", type=int, default=0)

    # precision
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")

    # logging/saving:
    # NOTE(review): machine-specific absolute default path — override --ckpt in practice.
    parser.add_argument("--ckpt", type=str, default="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/test.pt", help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sample-dir", type=str, default="samples_50")

    # model
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    # vae
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")

    # number of samples
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=100)

    # sampling related hyperparameters
    # NOTE(review): main() calls exit() for mode "ode" — pass --mode=sde (as eval.sh does).
    parser.add_argument("--mode", type=str, default="ode")
    parser.add_argument("--cfg-scale", type=float, default=1)
    parser.add_argument("--cls-cfg-scale", type=float, default=1)
    parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode
    parser.add_argument("--guidance-low", type=float, default=0.)
    parser.add_argument("--guidance-high", type=float, default=1.)
    # accepted for torch.distributed.launch compatibility (see eval.sh)
    parser.add_argument('--local-rank', default=-1, type=int)
    # dimensionality of the class-token latent sampled in main()
    parser.add_argument('--cls', default=768, type=int)
    # will be deprecated
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # only for ode


    args = parser.parse_args()
    main(args)
REG copy/loss.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    reduce_dims = list(range(1, x.dim()))
    return x.mean(dim=reduce_dims)
10
+
11
def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.
    """
    # Fixed docstring: this reduces with sum, not mean (copy-paste from mean_flat).
    return torch.sum(x, dim=list(range(1, len(x.size()))))
16
+
17
class SILoss:
    """Stochastic-interpolant training loss with an extra class-token term (REG).

    For a batch of clean latents ``images`` and their pretrained-encoder class
    token ``cls_token``, computes:
      - a velocity-prediction denoising loss on the image latents,
      - the same denoising loss on the noised class token, and
      - a negative-cosine projection loss between the model's intermediate
        projections ``zs_tilde`` and external representations ``zs``.
    """

    def __init__(
            self,
            prediction='v',
            path_type="linear",
            weighting="uniform",
            encoders=None,
            accelerator=None,
            latents_scale=None,
            latents_bias=None,
            ):
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        # Fixed: mutable default argument (`encoders=[]`) replaced with the
        # None sentinel; the effective default is still an empty list.
        self.encoders = encoders if encoders is not None else []
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias

    def interpolant(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for timestep tensor ``t``."""
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()

        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
                 time_input=None, noises=None,):
        """Compute the REG losses for one batch.

        Returns (denoising_loss, proj_loss, time_input, noises,
        denoising_loss_cls); the returned time_input/noises allow callers to
        replay the same corruption on another forward pass.
        """
        if model_kwargs is None:  # fixed: `is None` instead of `== None`
            model_kwargs = {}
        # sample timesteps
        if time_input is None:
            if self.weighting == "uniform":
                time_input = torch.rand((images.shape[0], 1, 1, 1))
            elif self.weighting == "lognormal":
                # sample timestep according to log-normal distribution of sigmas following EDM
                rnd_normal = torch.randn((images.shape[0], 1, 1, 1))
                sigma = rnd_normal.exp()
                if self.path_type == "linear":
                    time_input = sigma / (1 + sigma)
                elif self.path_type == "cosine":
                    time_input = 2 / np.pi * torch.atan(sigma)

        time_input = time_input.to(device=images.device, dtype=images.dtype)

        if noises is None:
            noises = torch.randn_like(images)
        # Fresh noise for the class token on every call (never replayed).
        noises_cls = torch.randn_like(cls_token)

        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)

        model_input = alpha_t * images + sigma_t * noises
        # Drop the trailing singleton dims of the (B,1,1,1) schedule so it
        # broadcasts against cls_token — assumes cls_token is (B, D); TODO confirm.
        cls_input = alpha_t.squeeze(-1).squeeze(-1) * cls_token + sigma_t.squeeze(-1).squeeze(-1) * noises_cls
        if self.prediction == 'v':
            model_target = d_alpha_t * images + d_sigma_t * noises
            cls_target = d_alpha_t * cls_token + d_sigma_t * noises_cls
        else:
            raise NotImplementedError()

        model_output, zs_tilde, cls_output = model(model_input, time_input.flatten(), **model_kwargs,
                                                   cls_token=cls_input)

        # denoising losses (mean squared error over non-batch dims)
        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss_cls = mean_flat((cls_output - cls_target) ** 2)

        # projection loss: negative cosine similarity, averaged over depths and batch
        proj_loss = 0.
        bsz = zs[0].shape[0]
        for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
            for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
                z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
                z_j = torch.nn.functional.normalize(z_j, dim=-1)
                proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        proj_loss /= (len(zs) * bsz)

        return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls
REG copy/requirements.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip requirements (exported from a conda environment; the stray "- pip:" header was removed so this file parses with `pip install -r`)
2
+ absl-py==2.2.2
3
+ accelerate==1.2.1
4
+ aiohappyeyeballs==2.6.1
5
+ aiohttp==3.11.16
6
+ aiosignal==1.3.2
7
+ astunparse==1.6.3
8
+ async-timeout==5.0.1
9
+ attrs==25.3.0
10
+ certifi==2022.12.7
11
+ charset-normalizer==2.1.1
12
+ click==8.1.8
13
+ datasets==2.20.0
14
+ diffusers==0.32.1
15
+ dill==0.3.8
16
+ docker-pycreds==0.4.0
17
+ einops==0.8.1
18
+ filelock==3.13.1
19
+ flatbuffers==25.2.10
20
+ frozenlist==1.5.0
21
+ fsspec==2024.5.0
22
+ ftfy==6.3.1
23
+ gast==0.6.0
24
+ gitdb==4.0.12
25
+ gitpython==3.1.44
26
+ google-pasta==0.2.0
27
+ grpcio==1.71.0
28
+ h5py==3.13.0
29
+ huggingface-hub==0.27.1
30
+ idna==3.4
31
+ importlib-metadata==8.6.1
32
+ jinja2==3.1.4
33
+ joblib==1.4.2
34
+ keras==3.9.2
35
+ libclang==18.1.1
36
+ markdown==3.8
37
+ markdown-it-py==3.0.0
38
+ markupsafe==2.1.5
39
+ mdurl==0.1.2
40
+ ml-dtypes==0.3.2
41
+ mpmath==1.3.0
42
+ multidict==6.4.3
43
+ multiprocess==0.70.16
44
+ namex==0.0.8
45
+ networkx==3.3
46
+ numpy==1.26.4
47
+ opt-einsum==3.4.0
48
+ optree==0.15.0
49
+ packaging==24.2
50
+ pandas==2.2.3
51
+ pillow==11.0.0
52
+ platformdirs==4.3.7
53
+ propcache==0.3.1
54
+ protobuf==4.25.6
55
+ psutil==7.0.0
56
+ pyarrow==19.0.1
57
+ pyarrow-hotfix==0.6
58
+ pygments==2.19.1
59
+ python-dateutil==2.9.0.post0
60
+ pytz==2025.2
61
+ pyyaml==6.0.2
62
+ regex==2024.11.6
63
+ requests==2.32.3
64
+ rich==14.0.0
65
+ safetensors==0.5.3
66
+ scikit-learn==1.5.1
67
+ scipy==1.15.2
68
+ sentry-sdk==2.26.1
69
+ setproctitle==1.3.5
70
+ six==1.17.0
71
+ smmap==5.0.2
72
+ sympy==1.13.1
73
+ tensorboard==2.16.1
74
+ tensorboard-data-server==0.7.2
75
+ tensorflow==2.16.1
76
+ tensorflow-io-gcs-filesystem==0.37.1
77
+ termcolor==3.0.1
78
+ tf-keras==2.16.0
79
+ threadpoolctl==3.6.0
80
+ timm==1.0.12
81
+ tokenizers==0.21.0
82
+ tqdm==4.67.1
83
+ transformers==4.47.0
84
+ triton==2.1.0
85
+ typing-extensions==4.12.2
86
+ tzdata==2025.2
87
+ urllib3==1.26.13
88
+ wandb==0.17.6
89
+ wcwidth==0.2.13
90
+ werkzeug==3.1.3
91
+ wrapt==1.17.2
92
+ xformer==1.0.1
93
+ xformers==0.0.23
94
+ xxhash==3.5.0
95
+ yarl==1.20.0
96
+ zipp==3.21.0
97
+
REG copy/samplers.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def expand_t_like_x(t, x_cur):
    """Reshape a time vector so it broadcasts against a batch of data.

    Args:
        t: [batch_dim,] time vector.
        x_cur: [batch_dim, ...] data point whose trailing shape is broadcast over.

    Returns:
        ``t`` viewed as [batch_dim, 1, ..., 1] with one singleton axis per
        trailing dimension of ``x_cur``.
    """
    trailing = len(x_cur.size()) - 1
    return t.view(t.size(0), *([1] * trailing))
14
+
15
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Convert a velocity-model prediction into a score-function estimate.

    Args:
        vt: [batch, ...] velocity model output.
        xt: [batch, ...] noisy data point x_t.
        t: [batch,] time tensor.
        path_type: interpolant schedule, "linear" or "cosine".

    Returns:
        Score estimate with the same shape as ``xt``.
    """
    # Broadcast t over the trailing dimensions of xt (inlined expand_t_like_x).
    t = t.view(t.size(0), *([1] * (len(xt.size()) - 1)))

    if path_type == "linear":
        alpha_t = 1 - t
        sigma_t = t
        d_alpha_t = torch.ones_like(xt, device=xt.device) * -1
        d_sigma_t = torch.ones_like(xt, device=xt.device)
    elif path_type == "cosine":
        half_pi = np.pi / 2
        alpha_t = torch.cos(t * half_pi)
        sigma_t = torch.sin(t * half_pi)
        d_alpha_t = -half_pi * torch.sin(t * half_pi)
        d_sigma_t = half_pi * torch.cos(t * half_pi)
    else:
        raise NotImplementedError

    # score = ((alpha_t / d_alpha_t) * v - x_t) / var, where var is the
    # conditional variance implied by the interpolant coefficients.
    ratio = alpha_t / d_alpha_t
    var = sigma_t ** 2 - ratio * d_sigma_t * sigma_t
    return (ratio * vt - xt) / var
40
+
41
+
42
def compute_diffusion(t_cur):
    """Time-dependent diffusion coefficient w(t) = 2t used by the SDE sampler."""
    return t_cur * 2
44
+
45
+
46
def euler_maruyama_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, kept only for interface compatibility with other samplers
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None
):
    """Euler–Maruyama SDE sampler that jointly denoises image latents and a class token.

    Integrates the reverse-time SDE from t=1 down to t=0.04 with stochastic
    Euler–Maruyama steps, then takes one final deterministic (noise-free)
    Euler step to t=0. Classifier-free guidance (CFG) is applied only inside
    the [guidance_low, guidance_high] time window.

    Args:
        model: denoising network; called as ``model(x, t, y=..., cls_token=...)``
            and expected to return ``(velocity, aux, cls_velocity)``.
        latents: [batch, ...] initial image-latent noise.
        y: [batch,] class labels for conditioning.
        num_steps: number of stochastic integration steps before the last step.
        heun: unused (compatibility placeholder).
        cfg_scale: CFG strength; guidance is active only when > 1.0.
        guidance_low, guidance_high: time interval in which CFG is applied.
        path_type: interpolant schedule passed to ``get_score_from_velocity``.
        cls_latents: [batch, ...] initial noise for the class token; required
            (``.to`` is called on it unconditionally).
        args: namespace providing ``cls_cfg_scale`` for class-token guidance.

    Returns:
        The final denoised image latents (``mean_x``). The class-token mean
        is computed but intentionally not returned.
    """
    # setup conditioning: class id 1000 is the "null" (unconditional) class
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype

    # Time grid from 1.0 down to 0.04, plus a final 0.0 for the last step.
    t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64)
    t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)])
    # Integrate in float64 for numerical stability; cast to model dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device

    with torch.no_grad():
        # Stochastic Euler–Maruyama steps over all but the final interval.
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Duplicate the batch (cond + uncond) only while CFG is active.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_x_cur] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                cls_model_input = cls_x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Brownian increments scaled by sqrt(|dt|).
            eps_i = torch.randn_like(x_cur).to(device)
            cls_eps_i = torch.randn_like(cls_x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))
            cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))

            # compute drift: model runs in its native dtype, results promoted to float64
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Reverse-SDE drift: d = v - 0.5 * w(t) * score.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                # The class token gets its own guidance scale; <= 0 disables it
                # and uses the conditional prediction directly.
                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if args.cls_cfg_scale > 0:
                    cls_d_cur = cls_d_cur_uncond + args.cls_cfg_scale * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

        # last step: deterministic Euler step to t=0 (no Brownian noise)
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_x_cur] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            cls_model_input = cls_x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        # compute drift
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
        cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)

        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur  # d_cur e.g. [b, 4, 32, 32]

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if args.cls_cfg_scale > 0:
                cls_d_cur = cls_d_cur_uncond + args.cls_cfg_scale * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        # Final mean estimates; cls_mean_x is discarded — only the image is returned.
        mean_x = x_cur + dt * d_cur
        cls_mean_x = cls_x_cur + dt * cls_d_cur

    return mean_x
REG copy/utils.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torchvision.datasets.utils import download_url
3
+ import torch
4
+ import torchvision.models as torchvision_models
5
+ import timm
6
+ from models import mocov3_vit
7
+ import math
8
+ import warnings
9
+
10
+
11
+ # code from SiT repository
12
+ pretrained_models = {'last.pt'}
13
+
14
def download_model(model_name):
    """
    Downloads a pre-trained SiT model from the web and loads it.

    Args:
        model_name: must be a member of the module-level ``pretrained_models``
            set (currently only ``'last.pt'``).

    Returns:
        The deserialized checkpoint, loaded onto CPU.
    """
    assert model_name in pretrained_models
    local_path = f'pretrained_models/{model_name}'
    if not os.path.isfile(local_path):
        os.makedirs('pretrained_models', exist_ok=True)
        # NOTE(review): this URL is hard-coded to 'last.pt' regardless of
        # model_name — fine while the set has one entry, revisit if extended.
        web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
        download_url(web_path, 'pretrained_models', filename=model_name)
    # map_location keeps tensors on CPU so loading works without a GPU.
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model
26
+
27
def fix_mocov3_state_dict(state_dict):
    """Convert a MoCo-v3 checkpoint into a plain ViT backbone state dict (in place).

    Keeps only ``module.base_encoder.*`` weights (excluding the projection
    head / fc), strips that prefix, repairs mis-named keys present in the
    released checkpoint, and resamples the positional embedding to a 16x16
    grid. All original keys are deleted from the input dict.
    """
    for k in list(state_dict.keys()):
        # retain only base_encoder up to before the embedding layer
        if k.startswith('module.base_encoder'):
            # fix naming bug in checkpoint
            new_k = k[len("module.base_encoder."):]
            # NOTE(review): the first check tests new_k while the rest test k;
            # equivalent here since stripping the prefix does not affect these
            # substrings, but worth unifying.
            if "blocks.13.norm13" in new_k:
                new_k = new_k.replace("norm13", "norm1")
            if "blocks.13.mlp.fc13" in k:
                new_k = new_k.replace("fc13", "fc1")
            if "blocks.14.norm14" in k:
                new_k = new_k.replace("norm14", "norm2")
            if "blocks.14.mlp.fc14" in k:
                new_k = new_k.replace("fc14", "fc2")
            # remove prefix; drop classifier head / fc weights
            if 'head' not in new_k and new_k.split('.')[0] != 'fc':
                state_dict[new_k] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]
    if 'pos_embed' in state_dict.keys():
        # Resample the absolute positional embedding to a 16x16 patch grid.
        state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
            state_dict['pos_embed'], [16, 16],
        )
    return state_dict
51
+
52
@torch.no_grad()
def load_encoders(enc_type, device, resolution=256):
    """Load one or more frozen representation encoders for alignment.

    Args:
        enc_type: comma-separated specs of the form
            ``<encoder_type>-<architecture>-<model_config>``,
            e.g. ``"dinov2-vit-b"`` or ``"mocov3-vit-l"``.
        device: torch device the encoders are moved to.
        resolution: input image resolution; must be 256 or 512
            (512 is supported only for DINOv2).

    Returns:
        Tuple ``(encoders, encoder_types, architectures)`` of parallel lists.
    """
    assert (resolution == 256) or (resolution == 512)

    enc_names = enc_type.split(',')
    encoders, architectures, encoder_types = [], [], []
    for enc_name in enc_names:
        encoder_type, architecture, model_config = enc_name.split('-')
        # Currently, we only support 512x512 experiments with DINOv2 encoders.
        if resolution == 512:
            if encoder_type != 'dinov2':
                raise NotImplementedError(
                    "Currently, we only support 512x512 experiments with DINOv2 encoders."
                )

        architectures.append(architecture)
        encoder_types.append(encoder_type)
        if encoder_type == 'mocov3':
            if architecture == 'vit':
                if model_config == 's':
                    encoder = mocov3_vit.vit_small()
                elif model_config == 'b':
                    encoder = mocov3_vit.vit_base()
                elif model_config == 'l':
                    encoder = mocov3_vit.vit_large()
                ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
                state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
                # Drop the projection head before strict loading, then replace
                # it with an identity so forward returns features.
                del encoder.head
                encoder.load_state_dict(state_dict, strict=True)
                encoder.head = torch.nn.Identity()
            elif architecture == 'resnet':
                raise NotImplementedError()

            encoder = encoder.to(device)
            encoder.eval()

        elif 'dinov2' in encoder_type:
            import timm
            # Prefer a local torch.hub cache; fall back to downloading.
            # NOTE(review): bare except swallows all errors from the local load.
            if 'reg' in encoder_type:
                try:
                    encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
                                             f'dinov2_vit{model_config}14_reg', source='local')
                except:
                    encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
            else:
                try:
                    encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
                                             f'dinov2_vit{model_config}14', source='local')
                except:
                    encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')

            print(f"Now you are using the {enc_name} as the aligning model")
            del encoder.head
            # 16 tokens per side at 256px, scaled with resolution.
            patch_resolution = 16 * (resolution // 256)
            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [patch_resolution, patch_resolution],
            )
            encoder.head = torch.nn.Identity()
            encoder = encoder.to(device)
            encoder.eval()

        elif 'dinov1' == encoder_type:
            import timm
            from models import dinov1
            encoder = dinov1.vit_base()
            ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
            if 'pos_embed' in ckpt.keys():
                ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
                    ckpt['pos_embed'], [16, 16],
                )
            del encoder.head
            encoder.head = torch.nn.Identity()
            encoder.load_state_dict(ckpt, strict=True)
            encoder = encoder.to(device)
            # Alias so callers can use forward_features uniformly.
            encoder.forward_features = encoder.forward
            encoder.eval()

        elif encoder_type == 'clip':
            import clip
            from models.clip_vit import UpdatedVisionTransformer
            encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
            encoder = UpdatedVisionTransformer(encoder_).to(device)
            #.to(device)
            encoder.embed_dim = encoder.model.transformer.width
            encoder.forward_features = encoder.forward
            encoder.eval()

        elif encoder_type == 'mae':
            from models.mae_vit import vit_large_patch16
            import timm
            kwargs = dict(img_size=256)
            encoder = vit_large_patch16(**kwargs).to(device)
            with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
                state_dict = torch.load(f)
            if 'pos_embed' in state_dict["model"].keys():
                state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
                    state_dict["model"]['pos_embed'], [16, 16],
                )
            encoder.load_state_dict(state_dict["model"])

            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [16, 16],
            )
            # NOTE(review): unlike the other branches, mae is not set to eval() here — confirm intended.

        elif encoder_type == 'jepa':
            from models.jepa import vit_huge
            kwargs = dict(img_size=[224, 224], patch_size=14)
            encoder = vit_huge(**kwargs).to(device)
            with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
                state_dict = torch.load(f, map_location=device)
            # Strip the leading 'module.' (7 chars) from DistributedDataParallel keys.
            new_state_dict = dict()
            for key, value in state_dict['encoder'].items():
                new_state_dict[key[7:]] = value
            encoder.load_state_dict(new_state_dict)
            encoder.forward_features = encoder.forward

        encoders.append(encoder)

    return encoders, encoder_types, architectures
171
+
172
+
173
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
174
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
175
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
176
+ def norm_cdf(x):
177
+ # Computes standard normal cumulative distribution function
178
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
179
+
180
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
181
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
182
+ "The distribution of values may be incorrect.",
183
+ stacklevel=2)
184
+
185
+ with torch.no_grad():
186
+ # Values are generated by using a truncated uniform distribution and
187
+ # then using the inverse CDF for the normal distribution.
188
+ # Get upper and lower cdf values
189
+ l = norm_cdf((a - mean) / std)
190
+ u = norm_cdf((b - mean) / std)
191
+
192
+ # Uniformly fill tensor with values from [l, u], then translate to
193
+ # [2l-1, 2u-1].
194
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
195
+
196
+ # Use inverse cdf transform for normal distribution to get truncated
197
+ # standard normal
198
+ tensor.erfinv_()
199
+
200
+ # Transform to proper mean, std
201
+ tensor.mul_(std * math.sqrt(2.))
202
+ tensor.add_(mean)
203
+
204
+ # Clamp to ensure it's in the proper range
205
+ tensor.clamp_(min=a, max=b)
206
+ return tensor
207
+
208
+
209
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Public wrapper: fill ``tensor`` in-place with a truncated normal N(mean, std) clipped to [a, b]."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
211
+
212
+
213
def load_legacy_checkpoints(state_dict, encoder_depth):
    """Remap legacy checkpoint keys to the unified block numbering.

    Keys of the form ``decoder_blocks.N.*`` become
    ``blocks.(N + encoder_depth).*``; all other entries pass through
    unchanged.

    Args:
        state_dict: checkpoint mapping of parameter name -> tensor.
        encoder_depth: number of encoder blocks to offset decoder indices by.

    Returns:
        A new dict with remapped keys (input is not modified).
    """
    remapped = dict()
    for key, value in state_dict.items():
        if 'decoder_blocks' not in key:
            remapped[key] = value
            continue
        parts = key.split('.')
        parts[0] = 'blocks'
        parts[1] = str(int(parts[1]) + encoder_depth)
        remapped['.'.join(parts)] = value
    return remapped
REG/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Sihyun Yu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
REG/README.md ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <h1 align="center">Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think (NeurIPS 2025 Oral)
3
+ </h1>
4
+ <p align="center">
5
+ <a href='https://github.com/Martinser' style='text-decoration: none' >Ge Wu</a><sup>1</sup>&emsp;
6
+ <a href='https://github.com/ShenZhang-Shin' style='text-decoration: none' >Shen Zhang</a><sup>3</sup>&emsp;
7
+ <a href='' style='text-decoration: none' >Ruijing Shi</a><sup>1</sup>&emsp;
8
+ <a href='https://shgao.site/' style='text-decoration: none' >Shanghua Gao</a><sup>4</sup>&emsp;
9
+ <a href='https://zhenyuanchenai.github.io/' style='text-decoration: none' >Zhenyuan Chen</a><sup>1</sup>&emsp;
10
+ <a href='https://scholar.google.com/citations?user=6Z66DAwAAAAJ&hl=en' style='text-decoration: none' >Lei Wang</a><sup>1</sup>&emsp;
11
+ <a href='https://www.zhihu.com/people/chen-zhao-wei-16-2' style='text-decoration: none' >Zhaowei Chen</a><sup>3</sup>&emsp;
12
+ <a href='https://gao-hongcheng.github.io/' style='text-decoration: none' >Hongcheng Gao</a><sup>5</sup>&emsp;
13
+ <a href='https://scholar.google.com/citations?view_op=list_works&hl=zh-CN&hl=zh-CN&user=0xP6bxcAAAAJ' style='text-decoration: none' >Yao Tang</a><sup>3</sup>&emsp;
14
+ <a href='https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en' style='text-decoration: none' >Jian Yang</a><sup>1</sup>&emsp;
15
+ <a href='https://mmcheng.net/cmm/' style='text-decoration: none' >Ming-Ming Cheng</a><sup>1,2</sup>&emsp;
16
+ <a href='https://implus.github.io/' style='text-decoration: none' >Xiang Li</a><sup>1,2*</sup>&emsp;
17
+ <p align="center">
18
+ $^{1}$ VCIP, CS, Nankai University, $^{2}$ NKIARI, Shenzhen Futian, $^{3}$ JIIOV Technology,
19
+ $^{4}$ Harvard University, $^{5}$ University of Chinese Academy of Sciences
20
+ <p align='center'>
21
+ <div align="center">
22
+ <a href='https://arxiv.org/abs/2507.01467v2'><img src='https://img.shields.io/badge/arXiv-2507.01467v2-brown.svg?logo=arxiv&logoColor=white'></a>
23
+ <a href='https://huggingface.co/Martinser/REG/tree/main'><img src='https://img.shields.io/badge/🤗-Model-blue.svg'></a>
24
+ <a href='https://zhuanlan.zhihu.com/p/1952346823168595518'><img src='https://img.shields.io/badge/Zhihu-chinese_article-blue.svg?logo=zhihu&logoColor=white'></a>
25
+ </div>
26
+ <p align='center'>
27
+ </p>
28
+ </p>
29
+ </p>
30
+
31
+
32
+ ## 🚩 Overview
33
+
34
+ ![overview](fig/reg.png)
35
+
36
+ REPA and its variants effectively mitigate training challenges in diffusion models by incorporating external visual representations from pretrained models, through alignment between the noisy hidden projections of denoising networks and foundational clean image representations.
37
+ We argue that the external alignment, which is absent during the entire denoising inference process, falls short of fully harnessing the potential of discriminative representations.
38
+
39
+ In this work, we propose a straightforward method called Representation Entanglement for Generation (REG), which entangles low-level image latents with a single high-level class token from pretrained foundation models for denoising.
40
+ REG acquires the capability to produce coherent image-class pairs directly from pure noise,
41
+ substantially improving both generation quality and training efficiency.
42
+ This is accomplished with negligible additional inference overhead, **requiring only one single additional token for denoising (<0.5\% increase in FLOPs and latency).**
43
+ The inference process concurrently reconstructs both image latents and their corresponding global semantics, where the acquired semantic knowledge actively guides and enhances the image generation process.
44
+
45
+ On ImageNet $256{\times}256$, SiT-XL/2 + REG demonstrates remarkable convergence acceleration, **achieving $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA, respectively.**
46
+ More impressively, SiT-L/2 + REG trained for merely 400K iterations outperforms SiT-XL/2 + REPA trained for 4M iterations ($\textbf{10}\times$ longer).
47
+
48
+
49
+
50
+ ## 📰 News
51
+
52
+ - **[2025.08.05]** We have released the pre-trained weights of REG + SiT-XL/2 in 4M (800 epochs).
53
+
54
+
55
+ ## 📝 Results
56
+
57
+ - Performance on ImageNet $256{\times}256$ with FID=1.36 by introducing a single class token.
58
+ - $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA.
59
+
60
+ <div align="center">
61
+ <img src="fig/img.png" alt="Results">
62
+ </div>
63
+
64
+
65
+ ## 📋 Plan
66
+ - More training steps on ImageNet 256&512 and T2I.
67
+
68
+
69
+ ## 👊 Usage
70
+
71
+ ### 1. Environment setup
72
+
73
+ ```bash
74
+ conda create -n reg python=3.10.16 -y
75
+ conda activate reg
76
+ pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
77
+ pip install -r requirements.txt
78
+ ```
79
+
80
+ ### 2. Dataset
81
+
82
+ #### Dataset download
83
+
84
+ Currently, we provide experiments for ImageNet. You can place the data wherever you want and specify its location via the `--data-dir` argument in the training scripts.
85
+
86
+ #### Preprocessing data
87
+ Please refer to the preprocessing guide. And you can directly download our processed data, ImageNet data [link](https://huggingface.co/WindATree/ImageNet-256-VAE/tree/main), and ImageNet data after VAE encoder [link]( https://huggingface.co/WindATree/vae-sd/tree/main)
88
+
89
+ ### 3. Training
90
+ Run train.sh
91
+ ```bash
92
+ bash train.sh
93
+ ```
94
+
95
+ train.sh contains the following content.
96
+ ```bash
97
+ accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
98
+ --report-to="wandb" \
99
+ --allow-tf32 \
100
+ --mixed-precision="fp16" \
101
+ --seed=0 \
102
+ --path-type="linear" \
103
+ --prediction="v" \
104
+ --weighting="uniform" \
105
+ --model="SiT-B/2" \
106
+ --enc-type="dinov2-vit-b" \
107
+ --proj-coeff=0.5 \
108
+ --encoder-depth=4 \ #SiT-L/XL use 8, SiT-B use 4
109
+ --output-dir="your_path" \
110
+ --exp-name="linear-dinov2-b-enc4" \
111
+ --batch-size=256 \
112
+ --data-dir="data_path/imagenet_vae" \
113
+ --cls=0.03
114
+ ```
115
+
116
+ Then this script will automatically create the folder in `exps` to save logs and checkpoints. You can adjust the following options:
117
+
118
+ - `--models`: `[SiT-B/2, SiT-L/2, SiT-XL/2]`
119
+ - `--enc-type`: `[dinov2-vit-b, clip-vit-L]`
120
+ - `--proj-coeff`: Any values larger than 0
121
+ - `--encoder-depth`: Any values between 1 to the depth of the model
122
+ - `--output-dir`: Any directory that you want to save checkpoints and logs
123
+ - `--exp-name`: Any string name (the folder will be created under `output-dir`)
124
+ - `--cls`: Weight coefficients of REG loss
125
+
126
+
127
+ ### 4. Generate images and evaluation
128
+ You can generate images and get the final results through the following script.
129
+ The weight of REG can be found in this [link](https://pan.baidu.com/s/1QX2p3ybh1KfNU7wsp5McWw?pwd=khpp) or [HF](https://huggingface.co/Martinser/REG/tree/main).
130
+
131
+ ```bash
132
+ bash eval.sh
133
+ ```
134
+
135
+
136
+ ## Citation
137
+ If you find our work, this repository, or pretrained models useful, please consider giving a star and citation.
138
+ ```
139
+ @article{wu2025representation,
140
+ title={Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think},
141
+ author={Wu, Ge and Zhang, Shen and Shi, Ruijing and Gao, Shanghua and Chen, Zhenyuan and Wang, Lei and Chen, Zhaowei and Gao, Hongcheng and Tang, Yao and Yang, Jian and others},
142
+ journal={arXiv preprint arXiv:2507.01467},
143
+ year={2025}
144
+ }
145
+ ```
146
+
147
+ ## Contact
148
+ If you have any questions, please create an issue on this repository, contact at gewu.nku@gmail.com or wechat(wg1158848).
149
+
150
+
151
+ ## Acknowledgements
152
+
153
+ Our code is based on [REPA](https://github.com/sihyun-yu/REPA), along with [SiT](https://github.com/willisma/SiT), [DINOv2](https://github.com/facebookresearch/dinov2), [ADM](https://github.com/openai/guided-diffusion) and [U-ViT](https://github.com/baofff/U-ViT) repositories. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well.
154
+
155
+
156
+
REG/dataset.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ import PIL.Image
10
+ try:
11
+ import pyspng
12
+ except ImportError:
13
+ pyspng = None
14
+
15
+
16
+ class CustomDataset(Dataset):
17
+ """
18
+ data_dir 下 VAE latent:imagenet_256_vae/
19
+ 无预处理语义时:VAE 统计量/配对文件在 vae-sd/(与原 REG 一致)。
20
+ 有 semantic_features_dir 时:与主仓库 dataset 一致,从该目录 dataset.json 索引,
21
+ 按特征文件名推断 imagenet_256_vae 中对应 npy。
22
+ """
23
+
24
+ def __init__(self, data_dir, semantic_features_dir=None):
25
+ PIL.Image.init()
26
+ supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}
27
+
28
+ self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')
29
+
30
+ if semantic_features_dir is None:
31
+ potential_semantic_dir = os.path.join(
32
+ data_dir, 'imagenet_256_features', 'dinov2-vit-b_tmp', 'gpu0'
33
+ )
34
+ if os.path.exists(potential_semantic_dir):
35
+ self.semantic_features_dir = potential_semantic_dir
36
+ self.use_preprocessed_semantic = True
37
+ print(f"Found preprocessed semantic features at: {self.semantic_features_dir}")
38
+ else:
39
+ self.semantic_features_dir = None
40
+ self.use_preprocessed_semantic = False
41
+ else:
42
+ self.semantic_features_dir = semantic_features_dir
43
+ self.use_preprocessed_semantic = True
44
+ print(f"Using preprocessed semantic features from: {self.semantic_features_dir}")
45
+
46
+ if self.use_preprocessed_semantic:
47
+ label_fname = os.path.join(self.semantic_features_dir, 'dataset.json')
48
+ if not os.path.exists(label_fname):
49
+ raise FileNotFoundError(f"Label file not found: {label_fname}")
50
+
51
+ print(f"Using {label_fname}.")
52
+ with open(label_fname, 'rb') as f:
53
+ data = json.load(f)
54
+ labels_list = data.get('labels', None)
55
+ if labels_list is None:
56
+ raise ValueError(f"'labels' field is missing in {label_fname}")
57
+
58
+ semantic_fnames = []
59
+ labels = []
60
+ for entry in labels_list:
61
+ if entry is None:
62
+ continue
63
+ fname, lab = entry
64
+ semantic_fnames.append(fname)
65
+ labels.append(0 if lab is None else lab)
66
+
67
+ self.semantic_fnames = semantic_fnames
68
+ self.labels = np.array(labels, dtype=np.int64)
69
+ self.num_samples = len(self.semantic_fnames)
70
+ print(f"Loaded {self.num_samples} semantic entries from dataset.json")
71
+ else:
72
+ self.features_dir = os.path.join(data_dir, 'vae-sd')
73
+
74
+ self._image_fnames = {
75
+ os.path.relpath(os.path.join(root, fname), start=self.images_dir)
76
+ for root, _dirs, files in os.walk(self.images_dir) for fname in files
77
+ }
78
+ self.image_fnames = sorted(
79
+ fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext
80
+ )
81
+ self._feature_fnames = {
82
+ os.path.relpath(os.path.join(root, fname), start=self.features_dir)
83
+ for root, _dirs, files in os.walk(self.features_dir) for fname in files
84
+ }
85
+ self.feature_fnames = sorted(
86
+ fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext
87
+ )
88
+
89
+ fname = os.path.join(self.features_dir, 'dataset.json')
90
+ if os.path.exists(fname):
91
+ print(f"Using {fname}.")
92
+ else:
93
+ raise FileNotFoundError("Neither of the specified files exists.")
94
+
95
+ with open(fname, 'rb') as f:
96
+ labels = json.load(f)['labels']
97
+ labels = dict(labels)
98
+ labels = [labels[fname.replace('\\', '/')] for fname in self.feature_fnames]
99
+ labels = np.array(labels)
100
+ self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
101
+
102
+ def _file_ext(self, fname):
103
+ return os.path.splitext(fname)[1].lower()
104
+
105
+ def __len__(self):
106
+ if self.use_preprocessed_semantic:
107
+ return self.num_samples
108
+ assert len(self.image_fnames) == len(self.feature_fnames), \
109
+ "Number of feature files and label files should be same"
110
+ return len(self.feature_fnames)
111
+
112
+ def __getitem__(self, idx):
113
+ if self.use_preprocessed_semantic:
114
+ semantic_fname = self.semantic_fnames[idx]
115
+ basename = os.path.basename(semantic_fname)
116
+ idx_str = basename.split('-')[-1].split('.')[0]
117
+ subdir = idx_str[:5]
118
+ vae_relpath = os.path.join(subdir, f"img-mean-std-{idx_str}.npy")
119
+ vae_path = os.path.join(self.images_dir, vae_relpath)
120
+
121
+ with open(vae_path, 'rb') as f:
122
+ image = np.load(f)
123
+
124
+ semantic_path = os.path.join(self.semantic_features_dir, semantic_fname)
125
+ semantic_features = np.load(semantic_path)
126
+
127
+ return (
128
+ torch.from_numpy(image).float(),
129
+ torch.from_numpy(image).float(),
130
+ torch.from_numpy(semantic_features).float(),
131
+ torch.tensor(self.labels[idx]),
132
+ )
133
+
134
+ image_fname = self.image_fnames[idx]
135
+ feature_fname = self.feature_fnames[idx]
136
+ image_ext = self._file_ext(image_fname)
137
+ with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
138
+ if image_ext == '.npy':
139
+ image = np.load(f)
140
+ image = image.reshape(-1, *image.shape[-2:])
141
+ elif image_ext == '.png' and pyspng is not None:
142
+ image = pyspng.load(f.read())
143
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
144
+ else:
145
+ image = np.array(PIL.Image.open(f))
146
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
147
+
148
+ features = np.load(os.path.join(self.features_dir, feature_fname))
149
+ return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])
REG/eval.sh ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ random_number=$((RANDOM % 100 + 1200))
3
+ NUM_GPUS=8
4
+ STEP="4000000"
5
+ SAVE_PATH="your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8"
6
+ VAE_PATH="your_vae_path/"
7
+ NUM_STEP=250
8
+ MODEL_SIZE='XL'
9
+ CFG_SCALE=2.3
10
+ CLS_CFG_SCALE=2.3
11
+ GH=0.85
12
+
13
+ export NCCL_P2P_DISABLE=1
14
+
15
+ python -m torch.distributed.launch --master_port=$random_number --nproc_per_node=$NUM_GPUS generate.py \
16
+ --model SiT-XL/2 \
17
+ --num-fid-samples 50000 \
18
+ --ckpt ${SAVE_PATH}/checkpoints/${STEP}.pt \
19
+ --path-type=linear \
20
+ --encoder-depth=8 \
21
+ --projector-embed-dims=768 \
22
+ --per-proc-batch-size=64 \
23
+ --mode=sde \
24
+ --num-steps=${NUM_STEP} \
25
+ --cfg-scale=${CFG_SCALE} \
26
+ --cls-cfg-scale=${CLS_CFG_SCALE} \
27
+ --guidance-high=${GH} \
28
+ --sample-dir ${SAVE_PATH}/checkpoints \
29
+ --cls=768
30
+
31
+
32
+ python ./evaluations/evaluator.py \
33
+ --ref_batch your_path/VIRTUAL_imagenet256_labeled.npz \
34
+ --sample_batch ${SAVE_PATH}/checkpoints/SiT-${MODEL_SIZE}-2-${STEP}-size-256-vae-ema-cfg-${CFG_SCALE}-seed-0-sde-${GH}-${CLS_CFG_SCALE}.npz \
35
+ --save_path ${SAVE_PATH}/checkpoints \
36
+ --cfg_cond 1 \
37
+ --step ${STEP} \
38
+ --num_steps ${NUM_STEP} \
39
+ --cfg ${CFG_SCALE} \
40
+ --cls_cfg ${CLS_CFG_SCALE} \
41
+ --gh ${GH}
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
REG/eval_custom_0.25.log ADDED
@@ -0,0 +1 @@
 
 
1
+ python: can't open file 'fid_custom.py': [Errno 2] No such file or directory
REG/generate.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Samples a large number of images from a pre-trained SiT model using DDP.
9
+ Subsequently saves a .npz file that can be used to compute FID and other
10
+ evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
11
+
12
+ For a simple single-GPU/CPU sampling script, see sample.py.
13
+ """
14
+ import torch
15
+ import torch.distributed as dist
16
+ from models.sit import SiT_models
17
+ from diffusers.models import AutoencoderKL
18
+ from tqdm import tqdm
19
+ import os
20
+ from PIL import Image
21
+ import numpy as np
22
+ import math
23
+ import argparse
24
+ from samplers import euler_maruyama_sampler
25
+ from utils import load_legacy_checkpoints, download_model
26
+
27
+
28
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Collect ``num`` PNG samples named ``{i:06d}.png`` from ``sample_dir`` into
    a single uint8 .npz archive (key ``arr_0``) suitable for ADM evaluation.
    """
    stacked = np.stack([
        np.asarray(Image.open(f"{sample_dir}/{i:06d}.png")).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ])
    # All samples must be RGB and share one resolution.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
43
+
44
+
45
def main(args):
    """
    Run DDP sampling: every rank generates its share of latents, decodes them
    with the SD VAE, writes PNGs, and rank 0 finally packs them into a .npz.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Set up DDP; each rank gets a distinct seed so samples differ per process.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Load model:
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    latent_size = args.resolution // 8  # SD VAE downsamples by 8x
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = True,
        z_dims = [int(z_dim) for z_dim in args.projector_embed_dims.split(',')],
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt

    # When no checkpoint is given, fall back to the released SiT-XL/2 weights
    # (only valid for the default 768-dim projector configuration).
    if ckpt_path is None:
        args.ckpt = 'SiT-XL-2-256x256.pt'
        assert args.model == 'SiT-XL/2'
        assert len(args.projector_embed_dims.split(',')) == 1
        assert int(args.projector_embed_dims.split(',')[0]) == 768
        state_dict = download_model('last.pt')
    else:
        # Sampling always uses the EMA weights saved by train.py.
        state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']

    if args.legacy:
        # Remap key names from older checkpoints.
        state_dict = load_legacy_checkpoints(
            state_dict=state_dict, encoder_depth=args.encoder_depth
        )
    model.load_state_dict(state_dict)


    model.eval()  # important!
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    #vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path="your_local_path/weight/").to(device)


    # Create folder to save samples; the folder name encodes every sampling
    # hyperparameter so eval.sh can reconstruct it.
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}-{args.guidance_high}-{args.cls_cfg_scale}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Sample inputs: image latents, random class labels, and the semantic
        # class-token noise (dimension args.cls).
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        cls_z = torch.randn(n, args.cls, device=device)

        # Sample images:
        sampling_kwargs = dict(
            model=model,
            latents=z,
            y=y,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
            cls_latents=cls_z,
            args=args
        )
        with torch.no_grad():
            if args.mode == "sde":
                samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
            elif args.mode == "ode":# will support
                exit()
                #samples = euler_sampler(**sampling_kwargs).to(torch.float32)
            else:
                raise NotImplementedError()

        # Undo the SD latent scaling (z = 0.18215 * encoder output) before
        # decoding; bias is zero for this VAE.
        latents_scale = torch.tensor(
            [0.18215, 0.18215, 0.18215, 0.18215, ]
        ).view(1, 4, 1, 1).to(device)
        latents_bias = -torch.tensor(
            [0., 0., 0., 0.,]
        ).view(1, 4, 1, 1).to(device)
        samples = vae.decode((samples - latents_bias) / latents_scale).sample
        samples = (samples + 1) / 2.  # [-1, 1] -> [0, 1]
        samples = torch.clamp(
            255. * samples, 0, 255
        ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files; indices are strided
        # across ranks so filenames never collide between processes.
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
181
+
182
+
183
if __name__ == "__main__":
    # CLI mirrors train.py where flags overlap; see eval.sh for typical usage.
    parser = argparse.ArgumentParser()
    # seed
    parser.add_argument("--global-seed", type=int, default=0)

    # precision
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")

    # logging/saving:
    parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sample-dir", type=str, default="samples")

    # model
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    # vae
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")

    # number of samples
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)

    # sampling related hyperparameters
    parser.add_argument("--mode", type=str, default="ode")
    parser.add_argument("--cfg-scale", type=float, default=1.5)
    parser.add_argument("--cls-cfg-scale", type=float, default=1.5)
    parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode
    parser.add_argument("--guidance-low", type=float, default=0.)
    parser.add_argument("--guidance-high", type=float, default=1.)
    # --local-rank is injected by torch.distributed.launch; unused directly.
    parser.add_argument('--local-rank', default=-1, type=int)
    # Dimension of the semantic class-token noise fed to the sampler.
    parser.add_argument('--cls', default=768, type=int)
    # will be deprecated
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # only for ode


    args = parser.parse_args()
    main(args)
REG/loss.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+ try:
6
+ from scipy.optimize import linear_sum_assignment
7
+ except ImportError:
8
+ linear_sum_assignment = None
9
+
10
+
11
def ot_pair_noise_to_cls(noise_cls, cls_gt):
    """
    Minibatch optimal-transport pairing (matches sample_plan_with_scipy from
    conditional-flow-matching / torchcfm): permute the noise rows within the
    batch so that noise_ot[i] is near-optimally matched to cls_gt[i] under a
    squared-Euclidean cost. Inputs are (N, D) or flattenable to that shape.
    Falls back to the identity pairing when N <= 1 or scipy is unavailable.
    """
    batch = noise_cls.shape[0]
    if batch <= 1 or linear_sum_assignment is None:
        return noise_cls, cls_gt
    src = noise_cls.detach().float().reshape(batch, -1)
    dst = cls_gt.detach().float().reshape(batch, -1)
    cost = torch.cdist(src, dst).pow(2)
    _, perm = linear_sum_assignment(cost.cpu().numpy())
    perm = torch.as_tensor(perm, device=noise_cls.device, dtype=torch.long)
    return noise_cls[perm], cls_gt
28
+
29
+
30
def mean_flat(x):
    """
    Average over every dimension except the leading (batch) dimension.
    """
    non_batch_dims = list(range(1, x.dim()))
    return x.mean(dim=non_batch_dims)
35
+
36
def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.

    (Docstring fixed: it previously said "mean", copy-pasted from mean_flat,
    while the implementation sums.)
    """
    return torch.sum(x, dim=list(range(1, len(x.size()))))
41
+
42
class SILoss:
    """
    Stochastic-interpolant training loss with an auxiliary semantic (cls) flow.

    Produces (a) the image velocity-matching loss, (b) a velocity loss for a
    semantic class-token channel that is only "generated" at high noise levels
    (t > t_c), and (c) a representation-alignment projection loss between
    encoder features ``zs`` and model features ``zs_tilde``.
    """

    def __init__(
        self,
        prediction='v',
        path_type="linear",
        weighting="uniform",
        encoders=[],  # NOTE(review): mutable default argument; only safe if never mutated
        accelerator=None,
        latents_scale=None,
        latents_bias=None,
        t_c=0.5,
        ot_cls=True,
    ):
        # Only 'v' (velocity) prediction is implemented in __call__.
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        self.encoders = encoders
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias
        # Time convention matches train.py / JsFlow: t=0 is the clean latent,
        # t=1 is pure noise.
        # t in (t_c, 1]: the semantic cls evolves from noise toward cls_gt
        #   along the OT-paired path (the semantic channel is being generated);
        # t in [0, t_c]: cls is pinned to the ground-truth cls_gt and the
        #   target velocity is 0 (the channel is no longer interpolated).
        tc = float(t_c)
        # Clamp t_c away from {0, 1} so the 1/(1 - t_c) factor stays finite.
        self.t_c = min(max(tc, 1e-4), 1.0 - 1e-4)
        self.ot_cls = bool(ot_cls)

    def interpolant(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for path x_t =
        alpha_t * x0 + sigma_t * noise under the configured path type."""
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()

        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
                 time_input=None, noises=None,):
        """
        Compute all training losses for one batch.

        Returns (denoising_loss, proj_loss, time_input, noises,
        denoising_loss_cls); time_input/noises are returned so callers can
        reuse the same draw (e.g. for tc_velocity_loss).
        """
        if model_kwargs == None:
            model_kwargs = {}
        # sample timesteps
        if time_input is None:
            if self.weighting == "uniform":
                time_input = torch.rand((images.shape[0], 1, 1, 1))
            elif self.weighting == "lognormal":
                # sample timestep according to log-normal distribution of sigmas following EDM
                rnd_normal = torch.randn((images.shape[0], 1 ,1, 1))
                sigma = rnd_normal.exp()
                if self.path_type == "linear":
                    time_input = sigma / (1 + sigma)
                elif self.path_type == "cosine":
                    time_input = 2 / np.pi * torch.atan(sigma)

        time_input = time_input.to(device=images.device, dtype=torch.float32)
        cls_token = cls_token.to(device=images.device, dtype=torch.float32)

        # noises may be None (fresh draw), a single image-noise tensor, or a
        # (image_noise, cls_noise) pair from a previous call.
        if noises is None:
            noises = torch.randn_like(images)
            noises_cls = torch.randn_like(cls_token)
        else:
            if isinstance(noises, (tuple, list)) and len(noises) == 2:
                noises, noises_cls = noises
            else:
                noises_cls = torch.randn_like(cls_token)

        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)

        # Image branch: interpolated input and velocity target.
        model_input = alpha_t * images + sigma_t * noises
        if self.prediction == 'v':
            model_target = d_alpha_t * images + d_sigma_t * noises
        else:
            raise NotImplementedError()

        # Per-sample masks splitting the batch at t_c (broadcast over the
        # non-batch dims of cls_token).
        N = images.shape[0]
        t_flat = time_input.view(-1).float()
        high_noise_mask = (t_flat > self.t_c).float().view(N, *([1] * (cls_token.dim() - 1)))
        low_noise_mask = 1.0 - high_noise_mask

        # Optionally re-pair cls noise to targets with minibatch OT.
        noise_cls_raw = noises_cls
        if self.ot_cls:
            noise_cls_paired, cls_gt_paired = ot_pair_noise_to_cls(noise_cls_raw, cls_token)
        else:
            noise_cls_paired, cls_gt_paired = noise_cls_raw, cls_token

        # tau rescales t from (t_c, 1] onto [0, 1] for the semantic path.
        tau_shape = (N,) + (1,) * max(0, cls_token.dim() - 1)
        tau = (time_input.reshape(tau_shape) - self.t_c) / (1.0 - self.t_c + 1e-8)
        tau = torch.clamp(tau, 0.0, 1.0)
        alpha_sem = 1.0 - tau
        sigma_sem = tau

        # Semantic channel input: interpolated for t > t_c, ground truth below;
        # sanitized/clamped to keep the model input numerically bounded.
        cls_t_high = alpha_sem * cls_gt_paired + sigma_sem * noise_cls_paired
        cls_t = high_noise_mask * cls_t_high + low_noise_mask * cls_token
        cls_t = torch.nan_to_num(cls_t, nan=0.0, posinf=1e4, neginf=-1e4)
        cls_t = torch.clamp(cls_t, -1e4, 1e4)

        # Detach the low-noise (ground-truth) part so no gradient flows into
        # cls_token through the frozen region.
        cls_for_model = cls_t * high_noise_mask + cls_t.detach() * low_noise_mask

        # Semantic velocity target: d(cls_t)/dt on the rescaled path for
        # t > t_c, and exactly 0 below t_c.
        inv_scale = 1.0 / (1.0 - self.t_c + 1e-8)
        v_cls_high = (noise_cls_paired - cls_gt_paired) * inv_scale
        v_cls_target = high_noise_mask * v_cls_high

        model_output, zs_tilde, cls_output = model(
            model_input, time_input.flatten(), **model_kwargs, cls_token=cls_for_model
        )

        #denoising_loss
        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss_cls = mean_flat((cls_output - v_cls_target) ** 2)

        # projection loss: negative cosine similarity between encoder features
        # and the model's projected features, averaged over encoders and batch.
        proj_loss = 0.
        bsz = zs[0].shape[0]
        for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
            for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
                z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
                z_j = torch.nn.functional.normalize(z_j, dim=-1)
                proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        proj_loss /= (len(zs) * bsz)

        return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls

    def tc_velocity_loss(self, model, images, model_kwargs=None, cls_token=None, noises=None):
        """
        Extra constraint: directly supervise the image velocity field at
        t = t_c to stabilize the single-step (t_c -> 0) jump. Acts only on the
        image branch; it does not change the main cls/projection losses.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if cls_token is None:
            raise ValueError("tc_velocity_loss requires cls_token")
        if noises is None:
            noises = torch.randn_like(images)

        bsz = images.shape[0]
        # Every sample is evaluated exactly at t = t_c.
        time_input = torch.full(
            (bsz, 1, 1, 1), float(self.t_c), device=images.device, dtype=torch.float32
        )
        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)
        model_input = alpha_t * images + sigma_t * noises
        model_target = d_alpha_t * images + d_sigma_t * noises

        model_output, _, _ = model(
            model_input, time_input.flatten(), **model_kwargs, cls_token=cls_token
        )
        return mean_flat((model_output - model_target) ** 2)
REG/requirements.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - pip:
2
+ absl-py==2.2.2
3
+ accelerate==1.2.1
4
+ aiohappyeyeballs==2.6.1
5
+ aiohttp==3.11.16
6
+ aiosignal==1.3.2
7
+ astunparse==1.6.3
8
+ async-timeout==5.0.1
9
+ attrs==25.3.0
10
+ certifi==2022.12.7
11
+ charset-normalizer==2.1.1
12
+ click==8.1.8
13
+ datasets==2.20.0
14
+ diffusers==0.32.1
15
+ dill==0.3.8
16
+ docker-pycreds==0.4.0
17
+ einops==0.8.1
18
+ filelock==3.13.1
19
+ flatbuffers==25.2.10
20
+ frozenlist==1.5.0
21
+ fsspec==2024.5.0
22
+ ftfy==6.3.1
23
+ gast==0.6.0
24
+ gitdb==4.0.12
25
+ gitpython==3.1.44
26
+ google-pasta==0.2.0
27
+ grpcio==1.71.0
28
+ h5py==3.13.0
29
+ huggingface-hub==0.27.1
30
+ idna==3.4
31
+ importlib-metadata==8.6.1
32
+ jinja2==3.1.4
33
+ joblib==1.4.2
34
+ keras==3.9.2
35
+ libclang==18.1.1
36
+ markdown==3.8
37
+ markdown-it-py==3.0.0
38
+ markupsafe==2.1.5
39
+ mdurl==0.1.2
40
+ ml-dtypes==0.3.2
41
+ mpmath==1.3.0
42
+ multidict==6.4.3
43
+ multiprocess==0.70.16
44
+ namex==0.0.8
45
+ networkx==3.3
46
+ numpy==1.26.4
47
+ opt-einsum==3.4.0
48
+ optree==0.15.0
49
+ packaging==24.2
50
+ pandas==2.2.3
51
+ pillow==11.0.0
52
+ platformdirs==4.3.7
53
+ propcache==0.3.1
54
+ protobuf==4.25.6
55
+ psutil==7.0.0
56
+ pyarrow==19.0.1
57
+ pyarrow-hotfix==0.6
58
+ pygments==2.19.1
59
+ python-dateutil==2.9.0.post0
60
+ pytz==2025.2
61
+ pyyaml==6.0.2
62
+ regex==2024.11.6
63
+ requests==2.32.3
64
+ rich==14.0.0
65
+ safetensors==0.5.3
66
+ scikit-learn==1.5.1
67
+ scipy==1.15.2
68
+ sentry-sdk==2.26.1
69
+ setproctitle==1.3.5
70
+ six==1.17.0
71
+ smmap==5.0.2
72
+ sympy==1.13.1
73
+ tensorboard==2.16.1
74
+ tensorboard-data-server==0.7.2
75
+ tensorflow==2.16.1
76
+ tensorflow-io-gcs-filesystem==0.37.1
77
+ termcolor==3.0.1
78
+ tf-keras==2.16.0
79
+ threadpoolctl==3.6.0
80
+ timm==1.0.12
81
+ tokenizers==0.21.0
82
+ tqdm==4.67.1
83
+ transformers==4.47.0
84
+ triton==2.1.0
85
+ typing-extensions==4.12.2
86
+ tzdata==2025.2
87
+ urllib3==1.26.13
88
+ wandb==0.17.6
89
+ wcwidth==0.2.13
90
+ werkzeug==3.1.3
91
+ wrapt==1.17.2
92
+ xformer==1.0.1
93
+ xformers==0.0.23
94
+ xxhash==3.5.0
95
+ yarl==1.20.0
96
+ zipp==3.21.0
97
+
REG/sample_from_checkpoint.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 从 REG/train.py 保存的检查点加载权重,在指定目录生成若干 PNG。
4
+
5
+ 示例:
6
+ python sample_from_checkpoint.py \\
7
+ --ckpt exps/jsflow-experiment/checkpoints/0050000.pt \\
8
+ --out-dir ./samples_gen \\
9
+ --num-images 64 \\
10
+ --batch-size 8
11
+
12
+ # 按训练 t_c 分段分配步数(t=1→t_c 与 t_c→0;--t-c 可省略若检查点含 t_c):
13
+ python sample_from_checkpoint.py ... \\
14
+ --steps-before-tc 150 --steps-after-tc 100 --t-c 0.5
15
+
16
+ # 同一批初始噪声连跑两种 t_c 后段步数(输出到 out-dir 下子目录):
17
+ python sample_from_checkpoint.py ... \\
18
+ --steps-before-tc 150 --steps-after-tc 5 --dual-compare-after
19
+ # 分段时会在 at_tc/(或 at_tc/after_input、at_tc/after_equal_before)额外保存 t≈t_c 的解码图。
20
+
21
+ 检查点需包含 train.py 写入的键:ema(或 model)、args(推荐,用于自动还原结构)。
22
+ 若缺少 args,需通过命令行显式传入 --model、--resolution、--enc-type 等。
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import os
29
+ import sys
30
+ import types
31
+ import numpy as np
32
+ import torch
33
+ from diffusers.models import AutoencoderKL
34
+ from PIL import Image
35
+ from tqdm import tqdm
36
+
37
+ from models.sit import SiT_models
38
+ from samplers import (
39
+ euler_maruyama_image_noise_before_tc_sampler,
40
+ euler_maruyama_image_noise_sampler,
41
+ euler_maruyama_sampler,
42
+ euler_ode_sampler,
43
+ )
44
+
45
+
46
def semantic_dim_from_enc_type(enc_type):
    """Consistent with train.py: infer the semantic/class-token dimension
    from the encoder type string (default 768, e.g. for ViT-B)."""
    if enc_type is None:
        return 768
    name = str(enc_type).lower()
    size_table = (
        (("vit-g", "vitg"), 1536),
        (("vit-l", "vitl"), 1024),
        (("vit-s", "vits"), 384),
    )
    for markers, dim in size_table:
        if any(marker in name for marker in markers):
            return dim
    return 768
58
+
59
+
60
def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
    """Recover the training configuration stored under ckpt['args'] as an
    argparse.Namespace, or None when absent/unrecognized."""
    raw = ckpt.get("args")
    if isinstance(raw, argparse.Namespace):
        return raw
    if isinstance(raw, dict):
        return argparse.Namespace(**raw)
    if isinstance(raw, types.SimpleNamespace):
        return argparse.Namespace(**vars(raw))
    # Covers both a missing key and any unsupported container type.
    return None
71
+
72
+
73
def load_vae(device: torch.device):
    """Same strategy as train.py: prefer sd-vae-ft-mse from the local
    diffusers cache, then a filesystem search over known cache roots, and
    only fall back to a Hub download as a last resort."""
    try:
        from preprocessing import dnnlib

        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        os.environ["HF_HOME"] = cache_dir
        # Attempt 1: standard HF cache layout under the project cache dir.
        try:
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            print(f"Loaded VAE from local cache: {cache_dir}")
            return vae
        except Exception:
            pass
        # Attempt 2: walk likely cache roots for any directory that holds a
        # config.json and mentions "sd-vae-ft-mse" in its path.
        candidate_dir = None
        for root_dir in [
            cache_dir,
            os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
        ]:
            if not os.path.isdir(root_dir):
                continue
            for root, _, files in os.walk(root_dir):
                if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                    candidate_dir = root
                    break
            if candidate_dir is not None:
                break
        if candidate_dir is not None:
            vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
            vae.eval()
            print(f"Loaded VAE from {candidate_dir}")
            return vae
    except Exception as e:
        # Any failure in the local-cache path (including a missing
        # preprocessing.dnnlib import) just falls through to the Hub.
        print(f"VAE local cache search failed: {e}", file=sys.stderr)
    # Attempt 3: direct Hub download (requires network access).
    try:
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
        vae.eval()
        print("Loaded VAE from Hub: stabilityai/sd-vae-ft-mse")
        return vae
    except Exception as e:
        raise RuntimeError(
            "无法加载 VAE stabilityai/sd-vae-ft-mse,请确认已下载或网络可用。"
        ) from e
124
+
125
+
126
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
    """Instantiate the SiT model exactly as train.py configured it, from the
    checkpoint's stored training args. Returns (model, semantic_dim)."""
    res = int(getattr(ta, "resolution", 256))
    latent_size = res // 8  # SD VAE downsamples by 8x
    enc_type = getattr(ta, "enc_type", "dinov2-vit-b")
    z_dims = [semantic_dim_from_enc_type(enc_type)]
    block_kwargs = {
        "fused_attn": getattr(ta, "fused_attn", True),
        "qk_norm": getattr(ta, "qk_norm", False),
    }
    # CFG dropout probability > 0 during training implies the model has the
    # extra null-class embedding needed for classifier-free guidance.
    cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
    if ta.model not in SiT_models:
        raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
    model = SiT_models[ta.model](
        input_size=latent_size,
        num_classes=int(getattr(ta, "num_classes", 1000)),
        use_cfg=(cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=int(getattr(ta, "encoder_depth", 8)),
        **block_kwargs,
    ).to(device)
    return model, z_dims[0]
147
+
148
+
149
def resolve_tc_schedule(cli, ta):
    """
    When both --steps-before-tc and --steps-after-tc are supplied, split the
    schedule at t_c (falling back to the checkpoint's args.t_c if --t-c is
    omitted) and return (t_c, steps_before, steps_after). Otherwise return
    (None, None, None) so the caller uses the uniform --num-steps grid.
    Exits with status 1 on an inconsistent combination.
    """
    before = cli.steps_before_tc
    after = cli.steps_after_tc
    if before is None and after is None:
        return None, None, None
    if before is None or after is None:
        print(
            "使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。",
            file=sys.stderr,
        )
        sys.exit(1)
    split = cli.t_c
    if split is None:
        split = getattr(ta, "t_c", None) if ta is not None else None
    if split is None:
        print(
            "分段采样需要 --t-c,或检查点 args 中含 t_c。",
            file=sys.stderr,
        )
        sys.exit(1)
    return float(split), int(before), int(after)
174
+
175
+
176
def parse_cli():
    """Build and parse the sampling CLI. Help texts are user-facing (Chinese)
    and intentionally left untouched; see the module docstring for examples."""
    p = argparse.ArgumentParser(description="REG 检查点采样出图(可选 ODE/EM/EM-图像噪声)")
    p.add_argument("--ckpt", type=str, required=True, help="train.py 保存的 .pt 路径")
    p.add_argument("--out-dir", type=str, required=True, help="输出 PNG 目录(会创建)")
    p.add_argument("--num-images", type=int, required=True, help="生成图片总数")
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--weights",
        type=str,
        choices=("ema", "model"),
        default="ema",
        help="使用检查点中的 ema 或 model 权重",
    )
    p.add_argument("--device", type=str, default="cuda", help="如 cuda 或 cuda:0")
    # Uniform-grid step count; ignored when the split schedule below is used.
    p.add_argument(
        "--num-steps",
        type=int,
        default=50,
        help="均匀时间网格时的欧拉步数(未使用 --steps-before-tc/--steps-after-tc 时生效)",
    )
    # Split schedule: the three flags below are resolved by resolve_tc_schedule.
    p.add_argument(
        "--t-c",
        type=float,
        default=None,
        help="分段时刻:t∈(t_c,1] 与 t∈[0,t_c] 两段;缺省可用检查点 args.t_c(需配合两段步数)",
    )
    p.add_argument(
        "--steps-before-tc",
        type=int,
        default=None,
        help="从 t=1 积分到 t=t_c 的步数(与 --steps-after-tc 成对使用)",
    )
    p.add_argument(
        "--steps-after-tc",
        type=int,
        default=None,
        help="从 t=t_c 积分到 t=0(经 t_floor=0.04)的步数",
    )
    # Classifier-free guidance controls.
    p.add_argument("--cfg-scale", type=float, default=1.0)
    p.add_argument("--cls-cfg-scale", type=float, default=0.0, help="cls 分支 CFG(>0 时需 cfg-scale>1)")
    p.add_argument("--guidance-low", type=float, default=0.0)
    p.add_argument("--guidance-high", type=float, default=1.0)
    p.add_argument(
        "--path-type",
        type=str,
        default=None,
        choices=["linear", "cosine"],
        help="默认从检查点 args 读取;可覆盖",
    )
    p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    # Fallbacks used only when the checkpoint does not embed its training args.
    p.add_argument("--model", type=str, default=None, help="无检查点 args 时必填;与 SiT_models 键一致,如 SiT-XL/2")
    p.add_argument("--resolution", type=int, default=None, choices=[256, 512])
    p.add_argument("--num-classes", type=int, default=None)
    p.add_argument("--encoder-depth", type=int, default=None)
    p.add_argument("--enc-type", type=str, default=None)
    p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--cfg-prob", type=float, default=None)
    p.add_argument(
        "--sampler",
        type=str,
        default="em_image_noise",
        choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
        help="采样器:ode=euler_sampler 确定性漂移(linspace 1→0 或 t_c 分段直连 0,无 t_floor;与 EM 网格不同),"
        "em=标准EM(含图像+cls噪声),em_image_noise=仅图像噪声,"
        "em_image_noise_before_tc=t<=t_c时图像去随机+cls全程去随机",
    )
    p.add_argument(
        "--dual-compare-after",
        action="store_true",
        help="需配合分段步数:同批 z/y/cls 连跑两次;after_input 用 --steps-after-tc,"
        "after_equal_before 将 after 步数设为与 --steps-before-tc 相同",
    )
    p.add_argument(
        "--save-fixed-trajectory",
        action="store_true",
        help="保存固定步采样轨迹(npy);仅对非 em 采样器启用,输出在 out-dir/trajectory",
    )
    return p.parse_args()
257
+
258
+
259
def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
    """Decode VAE latents into uint8 images with (N, H, W, C) layout on CPU."""
    decoded = vae.decode((latents - latents_bias) / latents_scale).sample
    # [-1, 1] -> [0, 1], then quantize to bytes.
    pixels = torch.clamp((decoded + 1) / 2.0, 0, 1)
    quantized = (pixels * 255.0).round().to(torch.uint8)
    return quantized.permute(0, 2, 3, 1).cpu().numpy()
271
+
272
+
273
def main():
    """Single-process sampling entry point.

    Loads a checkpoint (plus its stored training args), rebuilds the SiT
    model and the VAE, then samples `--num-images` images with the chosen
    sampler. Supports an optional segmented 1→t_c→0 time grid and a
    dual-comparison mode that runs the same batch twice with different
    post-t_c step counts.
    """
    cli = parse_cli()
    device = torch.device(cli.device if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True

    # weights_only is not accepted by older torch versions; fall back.
    try:
        ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(cli.ckpt, map_location="cpu")
    ta = load_train_args_from_ckpt(ckpt)
    if ta is None:
        # Checkpoint carries no training args: require the minimum set on
        # the CLI and fill the rest with training-time defaults.
        if cli.model is None or cli.resolution is None or cli.enc_type is None:
            print(
                "检查点中无 args,请至少指定:--model --resolution --enc-type "
                "(以及按需 --num-classes --encoder-depth)",
                file=sys.stderr,
            )
            sys.exit(1)
        ta = argparse.Namespace(
            model=cli.model,
            resolution=cli.resolution,
            num_classes=cli.num_classes if cli.num_classes is not None else 1000,
            encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
            enc_type=cli.enc_type,
            fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
            qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
            cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
        )
    else:
        # CLI flags, when given, override the checkpoint's training args.
        if cli.model is not None:
            ta.model = cli.model
        if cli.resolution is not None:
            ta.resolution = cli.resolution
        if cli.num_classes is not None:
            ta.num_classes = cli.num_classes
        if cli.encoder_depth is not None:
            ta.encoder_depth = cli.encoder_depth
        if cli.enc_type is not None:
            ta.enc_type = cli.enc_type
        if cli.fused_attn is not None:
            ta.fused_attn = cli.fused_attn
        if cli.qk_norm is not None:
            ta.qk_norm = cli.qk_norm
        if cli.cfg_prob is not None:
            ta.cfg_prob = cli.cfg_prob

    path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")

    # tc_split is (t_c, steps_before, steps_after) or (None, None, None).
    tc_split = resolve_tc_schedule(cli, ta)
    if cli.dual_compare_after and tc_split[0] is None:
        print("--dual-compare-after 必须配合 --steps-before-tc 与 --steps-after-tc(分段采样)", file=sys.stderr)
        sys.exit(1)
    if tc_split[0] is not None:
        if cli.dual_compare_after:
            print(
                f"双次对比:t_c={tc_split[0]}, before={tc_split[1]}, "
                f"after_input={tc_split[2]}, after_equal_before={tc_split[1]}"
            )
        else:
            print(
                f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]} "
                f"(总模型前向约 {tc_split[1] + tc_split[2] + 1} 次)"
            )
    else:
        print(f"时间网格:均匀 num_steps={cli.num_steps}")

    # Map the --sampler choice to the corresponding sampler function.
    if cli.sampler == "ode":
        sampler_fn = euler_ode_sampler
    elif cli.sampler == "em":
        sampler_fn = euler_maruyama_sampler
    elif cli.sampler == "em_image_noise_before_tc":
        sampler_fn = euler_maruyama_image_noise_before_tc_sampler
    else:
        sampler_fn = euler_maruyama_image_noise_sampler

    model, cls_dim = build_model_from_train_args(ta, device)
    wkey = cli.weights
    if wkey not in ckpt:
        raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
    state = ckpt[wkey]
    if cli.legacy:
        from utils import load_legacy_checkpoints

        state = load_legacy_checkpoints(
            state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
        )
    model.load_state_dict(state, strict=True)
    model.eval()

    vae = load_vae(device)
    # SD-VAE scaling constants (scale 0.18215, zero bias) for 4-ch latents.
    latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
    latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)

    sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))

    # Output directory layout; dual-compare mode writes two parallel runs
    # plus side-by-side "pair" images.
    at_tc_dir = at_tc_a = at_tc_b = None
    pair_dir = None
    traj_dir = traj_a = traj_b = None
    if cli.dual_compare_after:
        out_a = os.path.join(cli.out_dir, "after_input")
        out_b = os.path.join(cli.out_dir, "after_equal_before")
        pair_dir = os.path.join(cli.out_dir, "pair")
        os.makedirs(out_a, exist_ok=True)
        os.makedirs(out_b, exist_ok=True)
        os.makedirs(pair_dir, exist_ok=True)
        if tc_split[0] is not None:
            at_tc_a = os.path.join(cli.out_dir, "at_tc", "after_input")
            at_tc_b = os.path.join(cli.out_dir, "at_tc", "after_equal_before")
            os.makedirs(at_tc_a, exist_ok=True)
            os.makedirs(at_tc_b, exist_ok=True)
        if cli.save_fixed_trajectory and cli.sampler != "em":
            traj_a = os.path.join(cli.out_dir, "trajectory", "after_input")
            traj_b = os.path.join(cli.out_dir, "trajectory", "after_equal_before")
            os.makedirs(traj_a, exist_ok=True)
            os.makedirs(traj_b, exist_ok=True)
    else:
        os.makedirs(cli.out_dir, exist_ok=True)
        if tc_split[0] is not None:
            at_tc_dir = os.path.join(cli.out_dir, "at_tc")
            os.makedirs(at_tc_dir, exist_ok=True)
        if cli.save_fixed_trajectory and cli.sampler != "em":
            traj_dir = os.path.join(cli.out_dir, "trajectory")
            os.makedirs(traj_dir, exist_ok=True)
    latent_size = int(getattr(ta, "resolution", 256)) // 8
    n_total = int(cli.num_images)
    b = max(1, int(cli.batch_size))

    torch.manual_seed(cli.seed)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(cli.seed)

    written = 0
    pbar = tqdm(total=n_total, desc="sampling")
    while written < n_total:
        cur = min(b, n_total - written)
        z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
        cls_z = torch.randn(cur, cls_dim, device=device)

        with torch.no_grad():
            base_kw = dict(
                num_steps=cli.num_steps,
                cfg_scale=cli.cfg_scale,
                guidance_low=cli.guidance_low,
                guidance_high=cli.guidance_high,
                path_type=path_type,
                cls_latents=cls_z,
                args=sampler_args,
            )
            if cli.dual_compare_after:
                tc_v, sb, sa_in = tc_split
                # Two full sampling runs each consume RNG state; without a
                # reset the second run's 1→t_c noise would differ from the
                # first, so z_tc / at_tc would not be comparable. Snapshot
                # after fixing z/y/cls_z and restore before the second run so
                # the t_c intermediate states match (only the post-t_c step
                # count differs).
                _rng_cpu_dual = torch.random.get_rng_state()
                _rng_cuda_dual = (
                    torch.cuda.get_rng_state_all()
                    if device.type == "cuda"
                    else None
                )
                batch_imgs = {}
                for _run_i, (subdir, sa, tc_save_dir) in enumerate(
                    (
                        (out_a, sa_in, at_tc_a),
                        (out_b, sb, at_tc_b),
                    )
                ):
                    if _run_i > 0:
                        torch.random.set_rng_state(_rng_cpu_dual)
                        if _rng_cuda_dual is not None:
                            torch.cuda.set_rng_state_all(_rng_cuda_dual)
                    em_kw = dict(base_kw)
                    em_kw["t_c"] = tc_v
                    em_kw["num_steps_before_tc"] = sb
                    em_kw["num_steps_after_tc"] = sa
                    # The before_tc sampler also returns the final cls state.
                    if cli.sampler == "em_image_noise_before_tc":
                        if cli.save_fixed_trajectory and cli.sampler != "em":
                            latents, z_tc, cls_tc, cls_t0, traj = sampler_fn(
                                model,
                                z,
                                y,
                                **em_kw,
                                return_mid_state=True,
                                t_mid=float(tc_v),
                                return_cls_final=True,
                                return_trajectory=True,
                            )
                        else:
                            latents, z_tc, cls_tc, cls_t0 = sampler_fn(
                                model,
                                z,
                                y,
                                **em_kw,
                                return_mid_state=True,
                                t_mid=float(tc_v),
                                return_cls_final=True,
                            )
                            traj = None
                    else:
                        if cli.save_fixed_trajectory and cli.sampler != "em":
                            latents, z_tc, cls_tc, traj = sampler_fn(
                                model,
                                z,
                                y,
                                **em_kw,
                                return_mid_state=True,
                                t_mid=float(tc_v),
                                return_trajectory=True,
                            )
                        else:
                            latents, z_tc, cls_tc = sampler_fn(
                                model,
                                z,
                                y,
                                **em_kw,
                                return_mid_state=True,
                                t_mid=float(tc_v),
                            )
                            traj = None
                        cls_t0 = None
                    latents = latents.to(torch.float32)
                    imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
                    batch_imgs[subdir] = imgs
                    for i in range(cur):
                        Image.fromarray(imgs[i]).save(
                            os.path.join(subdir, f"{written + i:06d}.png")
                        )
                    # Decode and save the intermediate state captured at t_c.
                    if tc_save_dir is not None and z_tc is not None:
                        imgs_tc = _decode_to_uint8_hwc(
                            z_tc.to(torch.float32), latents_bias, latents_scale, vae
                        )
                        for i in range(cur):
                            Image.fromarray(imgs_tc[i]).save(
                                os.path.join(tc_save_dir, f"{written + i:06d}.png")
                            )
                    if traj is not None:
                        traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
                        save_traj_dir = traj_a if subdir == out_a else traj_b
                        np.save(os.path.join(save_traj_dir, f"{written:06d}_traj.npy"), traj_np)
                # Side-by-side comparison images for the two runs.
                imgs_a = batch_imgs.get(out_a)
                imgs_b = batch_imgs.get(out_b)
                if pair_dir is not None and imgs_a is not None and imgs_b is not None:
                    for i in range(cur):
                        pair_img = np.concatenate([imgs_a[i], imgs_b[i]], axis=1)
                        Image.fromarray(pair_img).save(
                            os.path.join(pair_dir, f"{written + i:06d}.png")
                        )
            else:
                em_kw = dict(base_kw)
                if tc_split[0] is not None:
                    em_kw["t_c"] = tc_split[0]
                    em_kw["num_steps_before_tc"] = tc_split[1]
                    em_kw["num_steps_after_tc"] = tc_split[2]
                    if cli.sampler == "em_image_noise_before_tc":
                        if cli.save_fixed_trajectory and cli.sampler != "em":
                            latents, z_tc, cls_tc, cls_t0, traj = sampler_fn(
                                model,
                                z,
                                y,
                                **em_kw,
                                return_mid_state=True,
                                t_mid=float(tc_split[0]),
                                return_cls_final=True,
                                return_trajectory=True,
                            )
                        else:
                            latents, z_tc, cls_tc, cls_t0 = sampler_fn(
                                model,
                                z,
                                y,
                                **em_kw,
                                return_mid_state=True,
                                t_mid=float(tc_split[0]),
                                return_cls_final=True,
                            )
                            traj = None
                    else:
                        if cli.save_fixed_trajectory and cli.sampler != "em":
                            latents, z_tc, cls_tc, traj = sampler_fn(
                                model,
                                z,
                                y,
                                **em_kw,
                                return_mid_state=True,
                                t_mid=float(tc_split[0]),
                                return_trajectory=True,
                            )
                        else:
                            latents, z_tc, cls_tc = sampler_fn(
                                model,
                                z,
                                y,
                                **em_kw,
                                return_mid_state=True,
                                t_mid=float(tc_split[0]),
                            )
                            traj = None
                        cls_t0 = None
                    latents = latents.to(torch.float32)
                    if z_tc is not None and at_tc_dir is not None:
                        imgs_tc = _decode_to_uint8_hwc(
                            z_tc.to(torch.float32), latents_bias, latents_scale, vae
                        )
                        for i in range(cur):
                            Image.fromarray(imgs_tc[i]).save(
                                os.path.join(at_tc_dir, f"{written + i:06d}.png")
                            )
                    if traj is not None and traj_dir is not None:
                        traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
                        np.save(os.path.join(traj_dir, f"{written:06d}_traj.npy"), traj_np)
                else:
                    # Uniform grid: plain sampling, no mid-state capture.
                    latents = sampler_fn(model, z, y, **em_kw).to(torch.float32)
                imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
                for i in range(cur):
                    Image.fromarray(imgs[i]).save(
                        os.path.join(cli.out_dir, f"{written + i:06d}.png")
                    )
        written += cur
        pbar.update(cur)
    pbar.close()
    if cli.dual_compare_after:
        msg = (
            f"Done. Saved {written} images per run under {out_a} and {out_b} "
            f"(parent: {cli.out_dir})"
        )
        if pair_dir is not None:
            msg += f"; paired comparisons under {pair_dir}"
        if tc_split[0] is not None and at_tc_a is not None:
            msg += f"; t≈t_c decoded under {at_tc_a} and {at_tc_b}"
        print(msg)
    else:
        msg = f"Done. Saved {written} images under {cli.out_dir}"
        if tc_split[0] is not None and at_tc_dir is not None:
            msg += f"; t≈t_c decoded under {at_tc_dir}"
        print(msg)
608
+
609
+
610
# Script entry point.
if __name__ == "__main__":
    main()
REG/sample_from_checkpoint_ddp.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ DDP 多卡采样脚本(单路径,不做 dual-compare,不保存 t_c 中间态图)。
4
+
5
+ 用法(4 卡示例):
6
+ torchrun --nproc_per_node=4 sample_from_checkpoint_ddp.py \
7
+ --ckpt exps/jsflow-experiment/checkpoints/0290000.pt \
8
+ --out-dir ./my_samples_ddp \
9
+ --num-images 50000 \
10
+ --batch-size 16 \
11
+ --t-c 0.75 --steps-before-tc 100 --steps-after-tc 5 \
12
+ --sampler em_image_noise_before_tc
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import math
19
+ import os
20
+ import sys
21
+ import types
22
+ import numpy as np
23
+
24
+ import torch
25
+ import torch.distributed as dist
26
+ from diffusers.models import AutoencoderKL
27
+ from PIL import Image
28
+ from tqdm import tqdm
29
+
30
+ from models.sit import SiT_models
31
+ from samplers import (
32
+ euler_maruyama_image_noise_before_tc_sampler,
33
+ euler_maruyama_image_noise_sampler,
34
+ euler_maruyama_sampler,
35
+ euler_ode_sampler,
36
+ )
37
+
38
+
39
def create_npz_from_sample_folder(sample_dir: str, num: int):
    """Pack sample_dir/000000.png … into a single .npz file (key arr_0).

    Returns the path of the written .npz file.
    """
    frames = [
        np.asarray(
            Image.open(os.path.join(sample_dir, f"{i:06d}.png"))
        ).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ]
    stacked = np.stack(frames)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
53
+
54
+
55
def semantic_dim_from_enc_type(enc_type):
    """Map an encoder type string to its semantic feature dimension.

    Recognizes ViT-G (1536), ViT-L (1024), and ViT-S (384) spellings in a
    case-insensitive substring match; everything else (including None)
    defaults to 768 (ViT-B sized).
    """
    if enc_type is None:
        return 768
    name = str(enc_type).lower()
    size_table = (
        (("vit-g", "vitg"), 1536),
        (("vit-l", "vitl"), 1024),
        (("vit-s", "vits"), 384),
    )
    for tokens, dim in size_table:
        if any(tok in name for tok in tokens):
            return dim
    return 768
66
+
67
+
68
+ def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
69
+ a = ckpt.get("args")
70
+ if a is None:
71
+ return None
72
+ if isinstance(a, argparse.Namespace):
73
+ return a
74
+ if isinstance(a, dict):
75
+ return argparse.Namespace(**a)
76
+ if isinstance(a, types.SimpleNamespace):
77
+ return argparse.Namespace(**vars(a))
78
+ return None
79
+
80
+
81
def load_vae(device: torch.device):
    """Load the stabilityai/sd-vae-ft-mse VAE, preferring local caches.

    Resolution order:
      1. Project cache dir via preprocessing.dnnlib, local files only.
      2. Scan well-known cache roots for an sd-vae-ft-mse snapshot dir.
      3. Plain from_pretrained (may hit the network).
    Returns the VAE on `device`, in eval mode.
    """
    try:
        from preprocessing import dnnlib

        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        # Quiet HF hub noise and point HF_HOME at the project cache.
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        os.environ["HF_HOME"] = cache_dir
        try:
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            return vae
        except Exception:
            # Not in the project cache; fall through to a filesystem scan.
            pass
        candidate_dir = None
        for root_dir in [
            cache_dir,
            os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
        ]:
            if not os.path.isdir(root_dir):
                continue
            # A usable snapshot dir contains config.json and mentions the
            # model name in its path (normalize Windows separators).
            for root, _, files in os.walk(root_dir):
                if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                    candidate_dir = root
                    break
            if candidate_dir is not None:
                break
        if candidate_dir is not None:
            vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
            vae.eval()
            return vae
    except Exception:
        # dnnlib unavailable or local loading failed entirely; use the
        # default loader below (may download).
        pass
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    vae.eval()
    return vae
123
+
124
+
125
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
    """Instantiate the SiT model described by training args.

    Returns (model_on_device, cls_dim) where cls_dim is the semantic
    feature dimension implied by the encoder type.
    """
    resolution = int(getattr(ta, "resolution", 256))
    enc_type = getattr(ta, "enc_type", "dinov2-vit-b")
    cls_dim = semantic_dim_from_enc_type(enc_type)
    cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
    if ta.model not in SiT_models:
        raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
    # Latent grid is 1/8 of pixel resolution (SD-VAE downsampling factor).
    net = SiT_models[ta.model](
        input_size=resolution // 8,
        num_classes=int(getattr(ta, "num_classes", 1000)),
        use_cfg=(cfg_prob > 0),
        z_dims=[cls_dim],
        encoder_depth=int(getattr(ta, "encoder_depth", 8)),
        fused_attn=getattr(ta, "fused_attn", True),
        qk_norm=getattr(ta, "qk_norm", False),
    ).to(device)
    return net, cls_dim
146
+
147
+
148
def resolve_tc_schedule(cli, ta):
    """Resolve the segmented time-grid spec from CLI with checkpoint fallback.

    Returns (t_c, steps_before, steps_after) as (float, int, int), or
    (None, None, None) when no segmented schedule was requested. Exits with
    status 1 when only one step count is given, or when t_c can be found
    neither on the CLI nor in the checkpoint args.
    """
    before = cli.steps_before_tc
    after = cli.steps_after_tc
    if before is None and after is None:
        return None, None, None
    if before is None or after is None:
        print("使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。", file=sys.stderr)
        sys.exit(1)
    split_time = cli.t_c
    if split_time is None and ta is not None:
        split_time = getattr(ta, "t_c", None)
    if split_time is None:
        print("分段采样需要 --t-c,或检查点 args 中含 t_c。", file=sys.stderr)
        sys.exit(1)
    return float(split_time), int(before), int(after)
163
+
164
+
165
def parse_cli():
    """Build and parse the DDP sampling CLI (single path, no at_tc images)."""
    p = argparse.ArgumentParser(description="REG DDP 检查点采样(单路径,无 at_tc 图)")
    # Required I/O.
    p.add_argument("--ckpt", type=str, required=True)
    p.add_argument("--out-dir", type=str, required=True)
    p.add_argument("--num-images", type=int, required=True)
    # Sampling batch / reproducibility / device.
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--weights", type=str, choices=("ema", "model"), default="ema")
    p.add_argument("--device", type=str, default="cuda")
    # Time grid: uniform num_steps, or a segmented 1→t_c→0 schedule.
    p.add_argument("--num-steps", type=int, default=50)
    p.add_argument("--t-c", type=float, default=None)
    p.add_argument("--steps-before-tc", type=int, default=None)
    p.add_argument("--steps-after-tc", type=int, default=None)
    # Guidance configuration.
    p.add_argument("--cfg-scale", type=float, default=1.0)
    p.add_argument("--cls-cfg-scale", type=float, default=0.0)
    p.add_argument("--guidance-low", type=float, default=0.0)
    p.add_argument("--guidance-high", type=float, default=1.0)
    p.add_argument("--path-type", type=str, default=None, choices=["linear", "cosine"])
    p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    # Fallbacks for checkpoints that carry no training args.
    p.add_argument("--model", type=str, default=None)
    p.add_argument("--resolution", type=int, default=None, choices=[256, 512])
    p.add_argument("--num-classes", type=int, default=1000)
    p.add_argument("--encoder-depth", type=int, default=None)
    p.add_argument("--enc-type", type=str, default=None)
    p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--cfg-prob", type=float, default=None)
    p.add_argument(
        "--sampler",
        type=str,
        default="em_image_noise_before_tc",
        choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
    )
    p.add_argument(
        "--save-fixed-trajectory",
        action="store_true",
        help="保存本 rank 轨迹(npy)到 out-dir/trajectory_rank{rank}",
    )
    p.add_argument(
        "--save-npz",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="采样结束后由 rank0 汇总 PNG 并保存 out-dir.npz",
    )
    return p.parse_args()
210
+
211
+
212
+ def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
213
+ imgs = vae.decode((latents - latents_bias) / latents_scale).sample
214
+ imgs = (imgs + 1) / 2.0
215
+ imgs = torch.clamp(imgs, 0, 1)
216
+ return (
217
+ (imgs * 255.0)
218
+ .round()
219
+ .to(torch.uint8)
220
+ .permute(0, 2, 3, 1)
221
+ .cpu()
222
+ .numpy()
223
+ )
224
+
225
+
226
def init_ddp():
    """Initialise the NCCL process group when launched under torchrun.

    Returns (use_ddp, rank, world_size, local_rank). Without RANK and
    WORLD_SIZE in the environment, returns the single-process fallback
    (False, 0, 1, 0) and touches nothing.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return False, 0, 1, 0
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
235
+
236
+
237
def main():
    """DDP sampling entry point: each rank samples its own shard of images.

    Uses an interleaved global index (i * world_size + rank + total) so the
    PNG numbering 000000..n_total-1 is complete across ranks; rank 0 then
    optionally packs the folder into a single .npz.
    """
    cli = parse_cli()
    use_ddp, rank, world_size, local_rank = init_ddp()

    if torch.cuda.is_available():
        # Under torchrun, pin each process to its LOCAL_RANK device.
        device = torch.device(f"cuda:{local_rank}" if use_ddp else cli.device)
        torch.backends.cuda.matmul.allow_tf32 = True
    else:
        device = torch.device("cpu")

    # weights_only is not accepted by older torch versions; fall back.
    try:
        ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(cli.ckpt, map_location="cpu")
    ta = load_train_args_from_ckpt(ckpt)
    if ta is None:
        # No stored training args: require the minimum set on the CLI.
        if cli.model is None or cli.resolution is None or cli.enc_type is None:
            print("检查点中无 args,请至少指定:--model --resolution --enc-type", file=sys.stderr)
            sys.exit(1)
        ta = argparse.Namespace(
            model=cli.model,
            resolution=cli.resolution,
            num_classes=cli.num_classes if cli.num_classes is not None else 1000,
            encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
            enc_type=cli.enc_type,
            fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
            qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
            cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
        )
    else:
        # CLI flags, when given, override the checkpoint's training args.
        if cli.model is not None:
            ta.model = cli.model
        if cli.resolution is not None:
            ta.resolution = cli.resolution
        if cli.num_classes is not None:
            ta.num_classes = cli.num_classes
        if cli.encoder_depth is not None:
            ta.encoder_depth = cli.encoder_depth
        if cli.enc_type is not None:
            ta.enc_type = cli.enc_type
        if cli.fused_attn is not None:
            ta.fused_attn = cli.fused_attn
        if cli.qk_norm is not None:
            ta.qk_norm = cli.qk_norm
        if cli.cfg_prob is not None:
            ta.cfg_prob = cli.cfg_prob

    path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
    # (t_c, steps_before, steps_after) or (None, None, None).
    tc_split = resolve_tc_schedule(cli, ta)

    if rank == 0:
        if tc_split[0] is not None:
            print(
                f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]}"
            )
        else:
            print(f"时间网格:均匀 num_steps={cli.num_steps}")

    # Map the --sampler choice to the corresponding sampler function.
    if cli.sampler == "ode":
        sampler_fn = euler_ode_sampler
    elif cli.sampler == "em":
        sampler_fn = euler_maruyama_sampler
    elif cli.sampler == "em_image_noise_before_tc":
        sampler_fn = euler_maruyama_image_noise_before_tc_sampler
    else:
        sampler_fn = euler_maruyama_image_noise_sampler

    model, cls_dim = build_model_from_train_args(ta, device)
    wkey = cli.weights
    if wkey not in ckpt:
        raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
    state = ckpt[wkey]
    if cli.legacy:
        from utils import load_legacy_checkpoints

        state = load_legacy_checkpoints(
            state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
        )
    model.load_state_dict(state, strict=True)
    model.eval()

    vae = load_vae(device)
    # SD-VAE scaling constants (scale 0.18215, zero bias) for 4-ch latents.
    latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
    latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
    sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))

    os.makedirs(cli.out_dir, exist_ok=True)
    traj_dir = None
    if cli.save_fixed_trajectory and cli.sampler != "em":
        traj_dir = os.path.join(cli.out_dir, f"trajectory_rank{rank}")
        os.makedirs(traj_dir, exist_ok=True)

    latent_size = int(getattr(ta, "resolution", 256)) // 8
    n_total = int(cli.num_images)
    b = max(1, int(cli.batch_size))
    # Round the workload up so every rank runs the same number of
    # equal-sized batches; surplus indices >= n_total are skipped on save.
    global_batch_size = b * world_size
    total_samples = int(math.ceil(n_total / global_batch_size) * global_batch_size)
    samples_needed_this_gpu = int(total_samples // world_size)
    if samples_needed_this_gpu % b != 0:
        raise ValueError("samples_needed_this_gpu must be divisible by per-rank batch size")
    iterations = int(samples_needed_this_gpu // b)

    # Per-rank seed so ranks draw independent noise.
    seed_rank = int(cli.seed) + int(rank)
    torch.manual_seed(seed_rank)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed_rank)

    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    pbar = range(iterations)
    pbar = tqdm(pbar, desc="sampling") if rank == 0 else pbar
    total = 0
    written_local = 0
    for _ in pbar:
        cur = b
        z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
        cls_z = torch.randn(cur, cls_dim, device=device)

        with torch.no_grad():
            em_kw = dict(
                num_steps=cli.num_steps,
                cfg_scale=cli.cfg_scale,
                guidance_low=cli.guidance_low,
                guidance_high=cli.guidance_high,
                path_type=path_type,
                cls_latents=cls_z,
                args=sampler_args,
            )
            if tc_split[0] is not None:
                em_kw["t_c"] = tc_split[0]
                em_kw["num_steps_before_tc"] = tc_split[1]
                em_kw["num_steps_after_tc"] = tc_split[2]

            if cli.save_fixed_trajectory and cli.sampler != "em":
                # NOTE(review): both branches make the same call; kept for
                # symmetry with the non-DDP script's sampler-specific paths.
                if cli.sampler == "em_image_noise_before_tc":
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
                else:
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
            else:
                latents = sampler_fn(model, z, y, **em_kw)
                traj = None

            latents = latents.to(torch.float32)
            imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
            for i, img in enumerate(imgs):
                # Interleaved global index: contiguous numbering across ranks.
                gidx = i * world_size + rank + total
                if gidx < n_total:
                    Image.fromarray(img).save(os.path.join(cli.out_dir, f"{gidx:06d}.png"))
                    written_local += 1

            if traj is not None and traj_dir is not None:
                traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
                # Trajectory file is named after the batch's first global index.
                first_idx = rank + total
                if first_idx < n_total:
                    np.save(os.path.join(traj_dir, f"{first_idx:06d}_traj.npy"), traj_np)

        total += global_batch_size
        if use_ddp:
            dist.barrier()
    if rank == 0 and hasattr(pbar, "close"):
        pbar.close()

    if use_ddp:
        dist.barrier()
    if rank == 0:
        if cli.save_npz:
            create_npz_from_sample_folder(cli.out_dir, n_total)
        print(f"Done. Saved {n_total} images under {cli.out_dir} (world_size={world_size}).")
    if use_ddp:
        dist.destroy_process_group()
413
+
414
# Script entry point.
if __name__ == "__main__":
    main()
416
+
REG/samplers.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def expand_t_like_x(t, x_cur):
    """Reshape a [batch] time vector so it broadcasts against x_cur.

    Args:
        t: [batch_dim,] time vector.
        x_cur: [batch_dim, ...] data point whose trailing dims are matched
            with singleton axes.
    """
    trailing = [1] * (x_cur.dim() - 1)
    return t.view(t.size(0), *trailing)
14
+
15
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Convert a velocity-model prediction into a score estimate.

    Args:
        vt: [batch, ...] velocity model output at x_t.
        xt: [batch, ...] current data point x_t.
        t: [batch,] time tensor.
        path_type: interpolation schedule, "linear" or "cosine".

    Returns:
        Score tensor with the same shape as xt.
    """
    # Broadcast t over the trailing dims of xt (inlined expand_t_like_x).
    t = t.view(t.size(0), *([1] * (xt.dim() - 1)))
    if path_type == "linear":
        alpha_t = 1 - t
        sigma_t = t
        d_alpha_t = -torch.ones_like(xt)
        d_sigma_t = torch.ones_like(xt)
    elif path_type == "cosine":
        half_pi_t = t * np.pi / 2
        alpha_t = torch.cos(half_pi_t)
        sigma_t = torch.sin(half_pi_t)
        d_alpha_t = -np.pi / 2 * torch.sin(half_pi_t)
        d_sigma_t = np.pi / 2 * torch.cos(half_pi_t)
    else:
        raise NotImplementedError
    # score = (alpha/alpha' * v - x) / (sigma^2 - alpha/alpha' * sigma' * sigma)
    ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - ratio * d_sigma_t * sigma_t
    return (ratio * vt - xt) / var
40
+
41
+
42
def compute_diffusion(t_cur):
    """Diffusion coefficient w(t) = 2 * t used by the SDE samplers."""
    return t_cur + t_cur
44
+
45
+
46
def build_sampling_time_steps(
    num_steps=50,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    t_floor=0.04,
):
    """Build the t=1 → t=0 sampling grid (float64), ending with an explicit 0.

    Uniform mode (any of the t_c arguments missing): linspace(1, t_floor,
    num_steps) followed by a final 0.  Segmented mode: 1 → t_c in
    num_steps_before_tc steps, then t_c → t_floor in num_steps_after_tc
    steps, followed by a final 0.
    """
    floor = float(t_floor)
    zero = torch.tensor([0.0], dtype=torch.float64)
    segmented = not (
        t_c is None or num_steps_before_tc is None or num_steps_after_tc is None
    )
    if not segmented:
        steps = int(num_steps)
        if steps < 1:
            raise ValueError("num_steps must be >= 1")
        grid = torch.linspace(1.0, floor, steps, dtype=torch.float64)
        return torch.cat([grid, zero])

    mid = float(t_c)
    before = int(num_steps_before_tc)
    after = int(num_steps_after_tc)
    if before < 1 or after < 1:
        raise ValueError("num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c")
    if not (0.0 < mid < 1.0):
        raise ValueError("t_c must be in (0, 1)")
    if mid <= floor:
        raise ValueError(f"t_c ({mid}) must be > t_floor ({floor})")

    head = torch.linspace(1.0, mid, before + 1, dtype=torch.float64)
    tail = torch.linspace(mid, floor, after + 1, dtype=torch.float64)
    # Drop the duplicated t_c at the head of the second segment.
    return torch.cat([head, tail[1:], zero])
82
+
83
+
84
+ def _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc):
85
+ """仅在 1→t_c→0 分段网格下启用:t∈[0,t_c] 段固定使用到达 t_c 时的 cls。"""
86
+ return (
87
+ t_c is not None
88
+ and num_steps_before_tc is not None
89
+ and num_steps_after_tc is not None
90
+ )
91
+
92
+
93
+ def _cls_effective_and_freeze(
94
+ cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
95
+ ):
96
+ """
97
+ 时间从 1 减到 0:当 t_cur <= t_c 时冻结 cls(取首次进入该段时的 cls_x_cur)。
98
+ 返回 (用于前向的 cls, 更新后的 cls_frozen)。
99
+ """
100
+ if not freeze_after_tc or t_c_v is None:
101
+ return cls_x_cur, cls_frozen
102
+ if float(t_cur) <= float(t_c_v) + 1e-9:
103
+ if cls_frozen is None:
104
+ cls_frozen = cls_x_cur.clone()
105
+ return cls_frozen, cls_frozen
106
+ return cls_x_cur, cls_frozen
107
+
108
+
109
+ def _build_euler_sampler_time_steps(
110
+ num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
111
+ ):
112
+ """
113
+ euler_sampler / REG ODE 用时间网格:默认 linspace(1,0);分段时为 1→t_c→0 直连,无 t_floor。
114
+ """
115
+ if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None:
116
+ ns = int(num_steps)
117
+ if ns < 1:
118
+ raise ValueError("num_steps must be >= 1")
119
+ return torch.linspace(1.0, 0.0, ns + 1, dtype=torch.float64, device=device)
120
+ t_c = float(t_c)
121
+ nb = int(num_steps_before_tc)
122
+ na = int(num_steps_after_tc)
123
+ if nb < 1 or na < 1:
124
+ raise ValueError(
125
+ "num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c"
126
+ )
127
+ if not (0.0 < t_c < 1.0):
128
+ raise ValueError("t_c must be in (0, 1)")
129
+ p1 = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64, device=device)
130
+ p2 = torch.linspace(t_c, 0.0, na + 1, dtype=torch.float64, device=device)
131
+ return torch.cat([p1, p2[1:]])
132
+
133
+
134
def euler_maruyama_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    deterministic=False,
    return_trajectory=False,
):
    """
    Euler–Maruyama sampler: the drift term and the score/velocity conversion
    match euler_ode_sampler (euler_sampler); with deterministic=True the
    stochastic diffusion term is disabled. Note the ODE path (euler_sampler)
    uses a linspace(1 -> 0) / t_c segmented grid with no t_floor, while this
    function uses build_sampling_time_steps (WITH t_floor), aligned with the
    EM/SDE convention.

    Args:
        model: callable (x, t, y=..., cls_token=...) -> (v, _, cls_v).
        latents: initial image latent at t=1.
        y: class labels; label 1000 is used as the CFG null class.
        num_steps: step count for the uniform grid (when t_c is unset).
        heun: unused; kept for interface compatibility.
        cfg_scale: classifier-free-guidance scale (>1 enables CFG).
        guidance_low / guidance_high: time window where CFG is applied.
        path_type: interpolation path forwarded to get_score_from_velocity.
        cls_latents: initial cls-token latent (required; .to() is called on it).
        args: optional namespace; args.cls_cfg_scale drives cls-side CFG.
        return_mid_state: also capture the state nearest to t_mid.
        t_mid: time at which the mid state is captured.
        t_c / num_steps_before_tc / num_steps_after_tc: optional segmented
            1 -> t_c -> t_floor grid; when all are set, cls is frozen for t <= t_c.
        deterministic: if True, zero out the stochastic increments (pure drift).
        return_trajectory: also return the per-step image latents.

    Returns:
        mean_x, optionally with (z_mid, cls_mid) and/or the trajectory list,
        depending on the return_* flags.
    """
    # setup conditioning
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    # e.g. y_null == [1000, 1000] for a batch of two
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Hold cls fixed once t_cur <= t_c (segmented grid only).
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            # Capture the pre-step state when t_cur is closer to t_mid than t_next.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Double the batch for CFG inside the guidance window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Stochastic increments (zeroed when deterministic=True).
            if deterministic:
                deps = torch.zeros_like(x_cur)
                cls_deps = torch.zeros_like(cls_model_input[: x_cur.size(0)])
            else:
                eps_i = torch.randn_like(x_cur).to(device)
                cls_eps_i = torch.randn_like(cls_model_input[: x_cur.size(0)]).to(device)
                deps = eps_i * torch.sqrt(torch.abs(dt))
                cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))

            # compute drift
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Reverse-SDE drift: d = v - 0.5 * g(t)^2 * score.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep only the conditional half.
                    cls_d_cur = cls_d_cur_cond
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

            if return_trajectory:
                traj.append(x_next.clone())

            # Fallback mid capture: post-step state (when t_next is closer to t_mid).
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # last step (t_floor -> 0): deterministic mean step, no stochastic term.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        # compute drift
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
        cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)

        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur  # d_cur [b, 4, 32 ,32]

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        # NOTE(review): cls_mean_x is computed here for symmetry but is never
        # returned by this function (the *_before_tc variant does return it).
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

        if return_trajectory:
            traj.append(mean_x.clone())

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
317
+
318
+
319
def euler_maruyama_image_noise_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """
    EM sampling variant: only the image latent receives stochastic diffusion
    noise; the cls/token channel is integrated deterministically (drift only).

    Args:
        model: callable (x, t, y=..., cls_token=...) -> (v, _, cls_v).
        latents: initial image latent at t=1.
        y: class labels; label 1000 serves as the CFG null class.
        num_steps: step count for the uniform grid (when t_c is unset).
        heun: unused; kept for interface compatibility.
        cfg_scale: classifier-free-guidance scale (>1 enables CFG).
        guidance_low / guidance_high: time window where CFG is applied.
        path_type: interpolation path forwarded to get_score_from_velocity.
        cls_latents: initial cls-token latent (required).
        args: optional namespace; args.cls_cfg_scale drives cls-side CFG.
        return_mid_state: also capture the state nearest to t_mid.
        t_mid: time at which the mid state is captured.
        t_c / num_steps_before_tc / num_steps_after_tc: optional segmented
            1 -> t_c -> t_floor grid; when all set, cls is frozen for t <= t_c.
        return_trajectory: also return the per-step image latents.

    Returns:
        mean_x, optionally with (z_mid, cls_mid) and/or the trajectory list,
        depending on the return_* flags.
    """
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Hold cls fixed once t_cur <= t_c (segmented grid only).
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Double the batch for CFG inside the guidance window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Stochastic increment for the image latent only.
            eps_i = torch.randn_like(x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # BUGFIX: the original branched on an undefined name `add_img_noise`
            # (copy-paste remnant from the *_before_tc variant), which raised
            # NameError on the first iteration. In this variant image noise is
            # always on, so the score-corrected reverse-SDE drift is always used:
            # d = v - 0.5 * g(t)^2 * score.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep only the conditional half.
                    cls_d_cur = cls_d_cur_cond

            # Image latent carries the stochastic diffusion term; the cls/token
            # channel advances by drift alone (no random increment).
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Last step (t_floor -> 0): no stochastic term; use d = v, aligned with the ODE.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        # NOTE(review): cls_mean_x is computed for symmetry but not returned by
        # this function (the *_before_tc variant does return it).
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
487
+
488
+
489
def euler_maruyama_image_noise_before_tc_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_cls_final=False,
    return_trajectory=False,
):
    """
    EM sampling variant:
      - the image latent receives stochastic diffusion noise while t > t_c;
      - the image latent becomes deterministic (drift only) once t <= t_c;
      - the cls/token channel never receives a stochastic term.

    Args:
        model: callable (x, t, y=..., cls_token=...) -> (v, _, cls_v).
        latents: initial image latent at t=1.
        y: class labels; label 1000 serves as the CFG null class.
        num_steps: step count for the uniform grid (when t_c is unset).
        heun: unused; kept for interface compatibility.
        cfg_scale: classifier-free-guidance scale (>1 enables CFG).
        guidance_low / guidance_high: time window where CFG is applied.
        path_type: interpolation path forwarded to get_score_from_velocity.
        cls_latents: initial cls-token latent (required).
        args: optional namespace; args.cls_cfg_scale drives cls-side CFG.
        return_mid_state: also capture the state nearest to t_mid.
        t_mid: time at which the mid state is captured.
        t_c / num_steps_before_tc / num_steps_after_tc: switch time and the
            segmented 1 -> t_c -> t_floor grid; cls freezing requires all three.
        return_cls_final: also return the final cls latent (cls_mean_x).
        return_trajectory: also return the per-step image latents.

    Returns:
        mean_x, optionally augmented with (z_mid, cls_mid), cls_mean_x and/or
        the trajectory list, depending on the return_* flags.
    """
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_freeze = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    # t_c_v drives only the image-noise on/off switch; unlike t_c_freeze it is
    # set whenever t_c alone is given (no segmented grid required).
    t_c_v = None if t_c is None else float(t_c)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Hold cls fixed once t_cur <= t_c (segmented grid only).
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Double the batch for CFG inside the guidance window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Turn off image stochasticity once the step crosses/enters t_c;
            # keep image noise while t > t_c.
            add_img_noise = True
            if t_c_v is not None and float(t_next) <= t_c_v:
                add_img_noise = False

            eps_i = torch.randn_like(x_cur).to(device) if add_img_noise else torch.zeros_like(x_cur)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            if add_img_noise:
                # Stochastic segment: reverse-SDE drift d = v - 0.5 * g(t)^2 * score.
                s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
                d_cur = v_cur - 0.5 * diffusion * s_cur

                cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
                cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
            else:
                # De-randomized segment (t <= t_c): explicit Euler with plain
                # velocity drift (no score correction), matching the ODE logic.
                d_cur = v_cur
                cls_d_cur = cls_v_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep only the conditional half.
                    cls_d_cur = cls_d_cur_cond

            # Image latent may carry a stochastic term (deps is zero past t_c);
            # cls/token advances by drift alone.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Last step (t_floor -> 0): no stochastic term; d = v, consistent with the ODE.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    if return_trajectory and return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x, traj
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory and return_cls_final:
        return mean_x, cls_mean_x, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    if return_cls_final:
        return mean_x, cls_mean_x
    return mean_x
674
+
675
+
676
def euler_ode_sampler(
    model,
    latents,
    y,
    num_steps=20,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """
    ODE entry point for REG: decoupled from the SDE samplers, it simply
    delegates to euler_sampler (linspace 1 -> 0, or the t_c segmented grid,
    with no t_floor).
    """
    forwarded = dict(
        num_steps=num_steps,
        heun=False,
        cfg_scale=cfg_scale,
        guidance_low=guidance_low,
        guidance_high=guidance_high,
        path_type=path_type,
        cls_latents=cls_latents,
        args=args,
        return_mid_state=return_mid_state,
        t_mid=t_mid,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
        return_trajectory=return_trajectory,
    )
    return euler_sampler(model, latents, y, **forwarded)
716
+
717
+
718
def euler_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """
    Lightweight deterministic drift sampler (prefix-compatible with the glflow
    sampler of the same name: model, latents, y, num_steps, heun, cfg, guidance,
    path_type, cls_latents, args).

    - Default: linspace(1, 0, num_steps+1), no t_floor (matches the original
      standalone ODE).
    - Optional: passing t_c, num_steps_before_tc and num_steps_after_tc
      together switches to the 1 -> t_c -> 0 grid and, consistently with the EM
      samplers, freezes the cls token for t <= t_c.
    - Optional: return_mid_state / return_trajectory for train.py and
      sample_from_checkpoint.

    The REG SiT requires a cls_token; cls_latents must not be None. heun is a
    placeholder and unused.

    Args:
        model: callable (x, t, y, cls_token=...) -> (v, _, cls_v).
        latents: initial image latent at t=1.
        y: class labels; label 1000 serves as the CFG null class.
        cfg_scale: classifier-free-guidance scale (>1 enables CFG).
        guidance_low / guidance_high: time window where CFG is applied.
        path_type: kept for interface parity (the ODE uses velocity directly).
        args: optional namespace; args.cls_cfg_scale drives cls-side CFG.

    Returns:
        x_next, optionally with (z_mid, cls_mid) and/or the trajectory list,
        depending on the return_* flags.
    """
    if cls_latents is None:
        raise ValueError(
            "euler_sampler: 本仓库 REG SiT 需要 cls_token,请传入 cls_latents(例如高斯噪声或训练中的 cls 初值)。"
        )
    if cfg_scale > 1.0:
        y_null = torch.full((y.size(0),), 1000, device=y.device, dtype=y.dtype)
    else:
        y_null = None
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
    device = latents.device

    t_steps = _build_euler_sampler_time_steps(
        num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None

    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Hold cls fixed once t_cur <= t_c (segmented grid only).
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            # Capture the pre-step state when t_cur is closer to t_mid.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Double the batch for CFG inside the guidance window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype),
                time_input.to(dtype=_dtype),
                y_cur,
                cls_token=cls_model_input.to(dtype=_dtype),
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # ODE: follow velocity parameterization directly (d/dt x_t = v_t).
            # This aligns with velocity training target and avoids extra v->score->drift conversion.
            d_cur = v_cur
            cls_d_cur = cls_v_cur

            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep only the conditional half.
                    cls_d_cur = cls_d_cur_cond

            x_next = x_cur + dt * d_cur
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + dt * cls_d_cur

            if return_trajectory:
                traj.append(x_next.clone())

            # Fallback mid capture: post-step state.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

    if return_trajectory and return_mid_state:
        return x_next, z_mid, cls_mid, traj
    if return_trajectory:
        return x_next, traj
    if return_mid_state:
        return x_next, z_mid, cls_mid
    return x_next
REG/samples.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # 双次对比步数请用 --dual-compare-after(见 sample_from_checkpoint.py),输出在 out-dir 子目录。
3
+
4
+ CUDA_VISIBLE_DEVICES=1 python sample_from_checkpoint.py \
5
+ --ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75-0.01-one-step/checkpoints/1970000.pt \
6
+ --out-dir ./my_samples_test \
7
+ --num-images 100 \
8
+ --batch-size 4 \
9
+ --seed 0 \
10
+ --t-c 0.75 \
11
+ --steps-before-tc 100\
12
+ --steps-after-tc 2 \
13
+ --sampler ode \
14
+ --cfg-scale 1.0 \
15
+ --dual-compare-after \
REG/samples_0.25_new.log ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ W0408 11:05:05.680000 100112 site-packages/torch/distributed/run.py:793]
2
+ W0408 11:05:05.680000 100112 site-packages/torch/distributed/run.py:793] *****************************************
3
+ W0408 11:05:05.680000 100112 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4
+ W0408 11:05:05.680000 100112 site-packages/torch/distributed/run.py:793] *****************************************
5
+ 时间网格:t_c=0.25, 步数 (1→t_c)=100, (t_c→0)=2
6
+ Total number of images that will be sampled: 5120
7
+
8
+ [rank0]:[W408 11:06:31.621528015 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
9
+ [rank2]:[W408 11:06:31.627182760 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
10
+ [rank1]:[W408 11:06:32.966218444 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
11
+
12
+ W0408 11:11:04.746000 100112 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 100158 closing signal SIGTERM
13
+ W0408 11:11:04.748000 100112 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 100159 closing signal SIGTERM
14
+ W0408 11:11:04.749000 100112 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 100160 closing signal SIGTERM
15
+ W0408 11:11:04.749000 100112 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 100161 closing signal SIGTERM
16
+ Traceback (most recent call last):
17
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/torchrun", line 33, in <module>
18
+ sys.exit(load_entry_point('torch==2.5.1', 'console_scripts', 'torchrun')())
19
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
21
+ return f(*args, **kwargs)
22
+ ^^^^^^^^^^^^^^^^^^
23
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 919, in main
24
+ run(args)
25
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 910, in run
26
+ elastic_launch(
27
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
28
+ return launch_agent(self._config, self._entrypoint, list(args))
29
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
30
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 260, in launch_agent
31
+ result = agent.run()
32
+ ^^^^^^^^^^^
33
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
34
+ result = f(*args, **kwargs)
35
+ ^^^^^^^^^^^^^^^^^^
36
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 696, in run
37
+ result = self._invoke_run(role)
38
+ ^^^^^^^^^^^^^^^^^^^^^^
39
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 855, in _invoke_run
40
+ time.sleep(monitor_interval)
41
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
42
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
43
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 100112 got signal: 15
REG/samples_0.5.log ADDED
The diff for this file is too large to render. See raw diff
 
REG/samples_0.75.log ADDED
The diff for this file is too large to render. See raw diff
 
REG/samples_0.75_new.log ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793]
2
+ W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] *****************************************
3
+ W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
4
+ W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] *****************************************
5
+ 时间网格:t_c=0.75, 步数 (1→t_c)=100, (t_c→0)=50
6
+ Total number of images that will be sampled: 40192
7
+
8
+ [rank3]:[W325 16:57:00.799344818 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
9
+ [rank1]:[W325 16:57:00.847229448 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
10
+ [rank0]:[W325 16:57:01.326116049 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
11
+
12
+ W0325 18:03:18.212000 538513 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 538677 closing signal SIGTERM
13
+ W0325 18:03:18.212000 538513 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 538678 closing signal SIGTERM
14
+ E0325 18:03:18.554000 538513 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -9) local_rank: 1 (pid: 538676) of binary: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
15
+ Traceback (most recent call last):
16
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/torchrun", line 33, in <module>
17
+ sys.exit(load_entry_point('torch==2.5.1', 'console_scripts', 'torchrun')())
18
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
19
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
20
+ return f(*args, **kwargs)
21
+ ^^^^^^^^^^^^^^^^^^
22
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 919, in main
23
+ run(args)
24
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 910, in run
25
+ elastic_launch(
26
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
27
+ return launch_agent(self._config, self._entrypoint, list(args))
28
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
29
+ File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
30
+ raise ChildFailedError(
31
+ torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
32
+ ==========================================================
33
+ sample_from_checkpoint_ddp.py FAILED
34
+ ----------------------------------------------------------
35
+ Failures:
36
+ <NO_OTHER_FAILURES>
37
+ ----------------------------------------------------------
38
+ Root Cause (first observed failure):
39
+ [0]:
40
+ time : 2026-03-25_18:03:18
41
+ host : 24c964746905d416ce09d045f9a06f23-taskrole1-0
42
+ rank : 1 (local_rank: 1)
43
+ exitcode : -9 (pid: 538676)
44
+ error_file: <N/A>
45
+ traceback : Signal 9 (SIGKILL) received by PID 538676
46
+ ==========================================================
REG/samples_ddp.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # 4 卡 DDP 单路径采样(不做 dual-compare,不保存 at_tc 中间图)
4
+ CUDA_VISIBLE_DEVICES=0,1,2,3 nohup nohup torchrun \
5
+ --nnodes=1 \
6
+ --nproc_per_node=4 \
7
+ --rdzv_endpoint=localhost:29110 \
8
+ sample_from_checkpoint_ddp.py \
9
+ --ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.25/checkpoints/0230000.pt \
10
+ --out-dir ./my_samples_25 \
11
+ --num-images 5000 \
12
+ --batch-size 64 \
13
+ --seed 0 \
14
+ --t-c 0.25 \
15
+ --steps-before-tc 100 \
16
+ --steps-after-tc 2 \
17
+ --sampler em_image_noise_before_tc \
18
+ --cfg-scale 1.0 \
19
+ > samples_0.25_new.log 2>&1 &
20
+
21
+ # nohup python sample_from_checkpoint_ddp.py \
22
+ # --ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.5/checkpoints/0250000.pt \
23
+ # --out-dir ./my_samples_5 \
24
+ # --num-images 20000 \
25
+ # --batch-size 16 \
26
+ # --seed 0 \
27
+ # --t-c 0.5 \
28
+ # --steps-before-tc 100 \
29
+ # --steps-after-tc 50 \
30
+ # --sampler em_image_noise_before_tc \
31
+ # --cfg-scale 1.0 \
32
+ # > samples_0.5.log 2>&1 &
REG/train.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ from copy import deepcopy
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ from collections import OrderedDict
8
+ import json
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from tqdm.auto import tqdm
15
+ from torch.utils.data import DataLoader
16
+
17
+ from accelerate import Accelerator, DistributedDataParallelKwargs
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration, set_seed
20
+
21
+ from models.sit import SiT_models
22
+ from loss import SILoss
23
+ from utils import load_encoders
24
+
25
+ from dataset import CustomDataset
26
+ from diffusers.models import AutoencoderKL
27
+ # import wandb_utils
28
+ import wandb
29
+ import math
30
+ from torchvision.utils import make_grid
31
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
32
+ from torchvision.transforms import Normalize
33
+ from PIL import Image
34
+
35
+ logger = get_logger(__name__)
36
+
37
+
38
+ def semantic_dim_from_enc_type(enc_type):
39
+ """DINOv2 等 enc_type 字符串推断 class token 维度(与预处理特征一致)。"""
40
+ if enc_type is None:
41
+ return 768
42
+ s = str(enc_type).lower()
43
+ if "vit-g" in s or "vitg" in s:
44
+ return 1536
45
+ if "vit-l" in s or "vitl" in s:
46
+ return 1024
47
+ if "vit-s" in s or "vits" in s:
48
+ return 384
49
+ return 768
50
+
51
+
52
+ CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
53
+ CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
54
+
55
+
56
+
57
+ def preprocess_raw_image(x, enc_type):
58
+ resolution = x.shape[-1]
59
+ if 'clip' in enc_type:
60
+ x = x / 255.
61
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
62
+ x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
63
+ elif 'mocov3' in enc_type or 'mae' in enc_type:
64
+ x = x / 255.
65
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
66
+ elif 'dinov2' in enc_type:
67
+ x = x / 255.
68
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
69
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
70
+ elif 'dinov1' in enc_type:
71
+ x = x / 255.
72
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
73
+ elif 'jepa' in enc_type:
74
+ x = x / 255.
75
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
76
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
77
+
78
+ return x
79
+
80
+
81
+ def array2grid(x):
82
+ nrow = round(math.sqrt(x.size(0)))
83
+ x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1))
84
+ x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
85
+ return x
86
+
87
+
88
+ @torch.no_grad()
89
+ def sample_posterior(moments, latents_scale=1., latents_bias=0.):
90
+ device = moments.device
91
+
92
+ mean, std = torch.chunk(moments, 2, dim=1)
93
+ z = mean + std * torch.randn_like(mean)
94
+ z = (z * latents_scale + latents_bias)
95
+ return z
96
+
97
+
98
+ @torch.no_grad()
99
+ def update_ema(ema_model, model, decay=0.9999):
100
+ """
101
+ Step the EMA model towards the current model.
102
+ """
103
+ ema_params = OrderedDict(ema_model.named_parameters())
104
+ model_params = OrderedDict(model.named_parameters())
105
+
106
+ for name, param in model_params.items():
107
+ name = name.replace("module.", "")
108
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
109
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
110
+
111
+
112
+ def create_logger(logging_dir):
113
+ """
114
+ Create a logger that writes to a log file and stdout.
115
+ """
116
+ logging.basicConfig(
117
+ level=logging.INFO,
118
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
119
+ datefmt='%Y-%m-%d %H:%M:%S',
120
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
121
+ )
122
+ logger = logging.getLogger(__name__)
123
+ return logger
124
+
125
+
126
+ def requires_grad(model, flag=True):
127
+ """
128
+ Set requires_grad flag for all parameters in a model.
129
+ """
130
+ for p in model.parameters():
131
+ p.requires_grad = flag
132
+
133
+
134
+ #################################################################################
135
+ # Training Loop #
136
+ #################################################################################
137
+
138
+ def main(args):
139
+ # set accelerator
140
+ logging_dir = Path(args.output_dir, args.logging_dir)
141
+ accelerator_project_config = ProjectConfiguration(
142
+ project_dir=args.output_dir, logging_dir=logging_dir
143
+ )
144
+
145
+ accelerator = Accelerator(
146
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
147
+ mixed_precision=args.mixed_precision,
148
+ log_with=args.report_to,
149
+ project_config=accelerator_project_config,
150
+ kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
151
+ )
152
+
153
+ if accelerator.is_main_process:
154
+ os.makedirs(args.output_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
155
+ save_dir = os.path.join(args.output_dir, args.exp_name)
156
+ os.makedirs(save_dir, exist_ok=True)
157
+ args_dict = vars(args)
158
+ # Save to a JSON file
159
+ json_dir = os.path.join(save_dir, "args.json")
160
+ with open(json_dir, 'w') as f:
161
+ json.dump(args_dict, f, indent=4)
162
+ checkpoint_dir = f"{save_dir}/checkpoints" # Stores saved model checkpoints
163
+ os.makedirs(checkpoint_dir, exist_ok=True)
164
+ logger = create_logger(save_dir)
165
+ logger.info(f"Experiment directory created at {save_dir}")
166
+ device = accelerator.device
167
+ if torch.backends.mps.is_available():
168
+ accelerator.native_amp = False
169
+ if args.seed is not None:
170
+ set_seed(args.seed + accelerator.process_index)
171
+
172
+ # Create model:
173
+ assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
174
+ latent_size = args.resolution // 8
175
+
176
+ train_dataset = CustomDataset(
177
+ args.data_dir, semantic_features_dir=args.semantic_features_dir
178
+ )
179
+ use_preprocessed_semantic = train_dataset.use_preprocessed_semantic
180
+
181
+ if use_preprocessed_semantic:
182
+ encoders, encoder_types, architectures = [], [], []
183
+ z_dims = [semantic_dim_from_enc_type(args.enc_type)]
184
+ if accelerator.is_main_process:
185
+ logger.info(
186
+ f"Preprocessed semantic features: skip loading online encoder, z_dims={z_dims}"
187
+ )
188
+ elif args.enc_type is not None:
189
+ encoders, encoder_types, architectures = load_encoders(
190
+ args.enc_type, device, args.resolution
191
+ )
192
+ z_dims = [encoder.embed_dim for encoder in encoders]
193
+ else:
194
+ raise NotImplementedError()
195
+ block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
196
+ model = SiT_models[args.model](
197
+ input_size=latent_size,
198
+ num_classes=args.num_classes,
199
+ use_cfg = (args.cfg_prob > 0),
200
+ z_dims = z_dims,
201
+ encoder_depth=args.encoder_depth,
202
+ **block_kwargs
203
+ )
204
+
205
+ model = model.to(device)
206
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
207
+ requires_grad(ema, False)
208
+
209
+ latents_scale = torch.tensor(
210
+ [0.18215, 0.18215, 0.18215, 0.18215]
211
+ ).view(1, 4, 1, 1).to(device)
212
+ latents_bias = torch.tensor(
213
+ [0., 0., 0., 0.]
214
+ ).view(1, 4, 1, 1).to(device)
215
+
216
+ # VAE decoder:采样阶段将 latent 解码为图像(与根目录 train.py / 预处理一致:sd-vae-ft-mse)
217
+ try:
218
+ from preprocessing import dnnlib
219
+ cache_dir = dnnlib.make_cache_dir_path("diffusers")
220
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
221
+ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
222
+ os.environ["HF_HOME"] = cache_dir
223
+ try:
224
+ vae = AutoencoderKL.from_pretrained(
225
+ "stabilityai/sd-vae-ft-mse",
226
+ cache_dir=cache_dir,
227
+ local_files_only=True,
228
+ ).to(device)
229
+ vae.eval()
230
+ if accelerator.is_main_process:
231
+ logger.info(
232
+ "Loaded VAE 'stabilityai/sd-vae-ft-mse' from local diffusers cache "
233
+ f"at '{cache_dir}' for intermediate sampling."
234
+ )
235
+ except Exception as e_main:
236
+ vae = None
237
+ candidate_dir = None
238
+ possible_roots = [
239
+ cache_dir,
240
+ os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
241
+ os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
242
+ os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
243
+ ]
244
+ checked_roots = []
245
+ for root_dir in possible_roots:
246
+ if not os.path.isdir(root_dir):
247
+ continue
248
+ checked_roots.append(root_dir)
249
+ for root, dirs, files in os.walk(root_dir):
250
+ if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
251
+ candidate_dir = root
252
+ break
253
+ if candidate_dir is not None:
254
+ break
255
+ if candidate_dir is not None:
256
+ try:
257
+ vae = AutoencoderKL.from_pretrained(
258
+ candidate_dir,
259
+ local_files_only=True,
260
+ ).to(device)
261
+ vae.eval()
262
+ if accelerator.is_main_process:
263
+ logger.info(
264
+ "Loaded VAE 'stabilityai/sd-vae-ft-mse' from discovered local path "
265
+ f"'{candidate_dir}'. Searched roots: {checked_roots}"
266
+ )
267
+ except Exception as e_fallback:
268
+ if accelerator.is_main_process:
269
+ logger.warning(
270
+ "Tried to load VAE from discovered local path "
271
+ f"'{candidate_dir}' but failed: {e_fallback}"
272
+ )
273
+ if vae is None and accelerator.is_main_process:
274
+ logger.warning(
275
+ "Could not load VAE 'stabilityai/sd-vae-ft-mse' via repo name or local search. "
276
+ f"Last repo-level error: {e_main}"
277
+ )
278
+ except Exception as e:
279
+ vae = None
280
+ if accelerator.is_main_process:
281
+ logger.warning(
282
+ f"Failed to initialize VAE loading logic (will skip image decoding): {e}"
283
+ )
284
+
285
+ # create loss function
286
+ loss_fn = SILoss(
287
+ prediction=args.prediction,
288
+ path_type=args.path_type,
289
+ encoders=encoders,
290
+ accelerator=accelerator,
291
+ latents_scale=latents_scale,
292
+ latents_bias=latents_bias,
293
+ weighting=args.weighting,
294
+ t_c=args.t_c,
295
+ ot_cls=args.ot_cls,
296
+ )
297
+ if accelerator.is_main_process:
298
+ logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
299
+
300
+ # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
301
+ if args.allow_tf32:
302
+ torch.backends.cuda.matmul.allow_tf32 = True
303
+ torch.backends.cudnn.allow_tf32 = True
304
+
305
+ optimizer = torch.optim.AdamW(
306
+ model.parameters(),
307
+ lr=args.learning_rate,
308
+ betas=(args.adam_beta1, args.adam_beta2),
309
+ weight_decay=args.adam_weight_decay,
310
+ eps=args.adam_epsilon,
311
+ )
312
+
313
+ # Setup data(train_dataset 已在上方创建)
314
+ local_batch_size = int(args.batch_size // accelerator.num_processes)
315
+ train_dataloader = DataLoader(
316
+ train_dataset,
317
+ batch_size=local_batch_size,
318
+ shuffle=True,
319
+ num_workers=args.num_workers,
320
+ pin_memory=True,
321
+ drop_last=True
322
+ )
323
+ if accelerator.is_main_process:
324
+ logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")
325
+
326
+ # Prepare models for training:
327
+ update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
328
+ model.train() # important! This enables embedding dropout for classifier-free guidance
329
+ ema.eval() # EMA model should always be in eval mode
330
+
331
+ # resume:
332
+ global_step = 0
333
+ if args.resume_from_ckpt is not None:
334
+ ckpt = torch.load(args.resume_from_ckpt, map_location="cpu")
335
+ model.load_state_dict(ckpt["model"])
336
+ ema.load_state_dict(ckpt["ema"])
337
+ if "opt" in ckpt:
338
+ optimizer.load_state_dict(ckpt["opt"])
339
+ global_step = int(ckpt.get("steps", 0))
340
+ if accelerator.is_main_process:
341
+ logger.info(
342
+ f"Resumed from ckpt: {args.resume_from_ckpt} (global_step={global_step})"
343
+ )
344
+ elif args.resume_step > 0:
345
+ ckpt_name = str(args.resume_step).zfill(7) +'.pt'
346
+ ckpt = torch.load(
347
+ f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
348
+ map_location='cpu',
349
+ )
350
+ model.load_state_dict(ckpt['model'])
351
+ ema.load_state_dict(ckpt['ema'])
352
+ optimizer.load_state_dict(ckpt['opt'])
353
+ global_step = ckpt['steps']
354
+
355
+ model, optimizer, train_dataloader = accelerator.prepare(
356
+ model, optimizer, train_dataloader
357
+ )
358
+
359
+ if accelerator.is_main_process:
360
+ tracker_config = vars(copy.deepcopy(args))
361
+ accelerator.init_trackers(
362
+ project_name="REG",
363
+ config=tracker_config,
364
+ init_kwargs={
365
+ "wandb": {"name": f"{args.exp_name}"}
366
+ },
367
+ )
368
+
369
+
370
+ progress_bar = tqdm(
371
+ range(0, args.max_train_steps),
372
+ initial=global_step,
373
+ desc="Steps",
374
+ # Only show the progress bar once on each machine.
375
+ disable=not accelerator.is_local_main_process,
376
+ )
377
+
378
+ # Labels to condition the model with (feel free to change):
379
+ sample_batch_size = 64 // accelerator.num_processes
380
+ first_batch = next(iter(train_dataloader))
381
+ if len(first_batch) == 4:
382
+ gt_raw_images, gt_xs, _, _ = first_batch
383
+ else:
384
+ gt_raw_images, gt_xs, _ = first_batch
385
+ assert gt_raw_images.shape[-1] == args.resolution
386
+ gt_xs = gt_xs[:sample_batch_size]
387
+ gt_xs = sample_posterior(
388
+ gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
389
+ )
390
+ ys = torch.randint(1000, size=(sample_batch_size,), device=device)
391
+ ys = ys.to(device)
392
+ # Create sampling noise:
393
+ n = ys.size(0)
394
+ xT = torch.randn((n, 4, latent_size, latent_size), device=device)
395
+
396
+ for epoch in range(args.epochs):
397
+ model.train()
398
+ for batch in train_dataloader:
399
+ if len(batch) == 4:
400
+ raw_image, x, r_preprocessed, y = batch
401
+ use_sem_file = True
402
+ else:
403
+ raw_image, x, y = batch
404
+ r_preprocessed = None
405
+ use_sem_file = False
406
+
407
+ raw_image = raw_image.to(device)
408
+ x = x.squeeze(dim=1).to(device).float()
409
+ y = y.to(device)
410
+ if args.legacy:
411
+ # In our early experiments, we accidentally apply label dropping twice:
412
+ # once in train.py and once in sit.py.
413
+ # We keep this option for exact reproducibility with previous runs.
414
+ drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
415
+ labels = torch.where(drop_ids, args.num_classes, y)
416
+ else:
417
+ labels = y
418
+ with torch.no_grad():
419
+ x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
420
+ zs = []
421
+ if use_sem_file and r_preprocessed is not None:
422
+ cls_token = r_preprocessed.to(device).float()
423
+ if cls_token.dim() == 1:
424
+ cls_token = cls_token.unsqueeze(0)
425
+ while cls_token.dim() > 2:
426
+ cls_token = cls_token.squeeze(1)
427
+ base_m = model.module if hasattr(model, "module") else model
428
+ n_pad = base_m.x_embedder.num_patches
429
+ zs = [
430
+ torch.cat(
431
+ [
432
+ cls_token.unsqueeze(1),
433
+ cls_token.unsqueeze(1).expand(-1, n_pad, -1),
434
+ ],
435
+ dim=1,
436
+ )
437
+ ]
438
+ else:
439
+ with accelerator.autocast():
440
+ for encoder, encoder_type, arch in zip(
441
+ encoders, encoder_types, architectures
442
+ ):
443
+ raw_image_ = preprocess_raw_image(raw_image, encoder_type)
444
+ z = encoder.forward_features(raw_image_)
445
+ if 'dinov2' in encoder_type:
446
+ dense_z = z['x_norm_patchtokens']
447
+ cls_token = z['x_norm_clstoken']
448
+ dense_z = torch.cat([cls_token.unsqueeze(1), dense_z], dim=1)
449
+ else:
450
+ exit()
451
+ zs.append(dense_z)
452
+
453
+ with accelerator.accumulate(model):
454
+ model_kwargs = dict(y=labels)
455
+ loss1, proj_loss1, time_input, noises, loss2 = loss_fn(model, x, model_kwargs, zs=zs,
456
+ cls_token=cls_token,
457
+ time_input=None, noises=None)
458
+ loss_mean = loss1.mean()
459
+ loss_mean_cls = loss2.mean() * args.cls
460
+ proj_loss_mean = proj_loss1.mean() * args.proj_coeff
461
+ tc_vel_loss = torch.tensor(0.0, device=device)
462
+ if args.tc_velocity_loss_coeff > 0:
463
+ tc_vel_loss = loss_fn.tc_velocity_loss(
464
+ model,
465
+ x,
466
+ model_kwargs=model_kwargs,
467
+ cls_token=cls_token,
468
+ noises=noises,
469
+ ).mean()
470
+ loss = (
471
+ loss_mean
472
+ + proj_loss_mean
473
+ + loss_mean_cls
474
+ + args.tc_velocity_loss_coeff * tc_vel_loss
475
+ )
476
+
477
+
478
+ ## optimization
479
+ accelerator.backward(loss)
480
+ if accelerator.sync_gradients:
481
+ params_to_clip = model.parameters()
482
+ grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
483
+ optimizer.step()
484
+ optimizer.zero_grad(set_to_none=True)
485
+
486
+ if accelerator.sync_gradients:
487
+ update_ema(ema, model) # change ema function
488
+
489
+ ### enter
490
+ if accelerator.sync_gradients:
491
+ progress_bar.update(1)
492
+ global_step += 1
493
+ if global_step % args.checkpointing_steps == 0 and global_step > 0:
494
+ if accelerator.is_main_process:
495
+ checkpoint = {
496
+ "model": model.module.state_dict(),
497
+ "ema": ema.state_dict(),
498
+ "opt": optimizer.state_dict(),
499
+ "args": args,
500
+ "steps": global_step,
501
+ }
502
+ checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
503
+ torch.save(checkpoint, checkpoint_path)
504
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
505
+
506
+ if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
507
+ t_mid_vis = float(args.t_c)
508
+ tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
509
+ logging.info(
510
+ f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} → t=0)..."
511
+ )
512
+ ema.eval()
513
+ with torch.no_grad():
514
+ latent_size = args.resolution // 8
515
+ n_samples = min(16, args.batch_size)
516
+ base_model = model.module if hasattr(model, "module") else model
517
+ cls_dim = base_model.z_dims[0]
518
+ shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
519
+ torch.manual_seed(shared_seed)
520
+ z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device)
521
+ torch.manual_seed(shared_seed)
522
+ cls_init = torch.randn(n_samples, cls_dim, device=device)
523
+ y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)
524
+
525
+ from samplers import euler_maruyama_sampler
526
+ z_0, z_mid, _ = euler_maruyama_sampler(
527
+ ema,
528
+ z_init,
529
+ y_samples,
530
+ num_steps=50,
531
+ cfg_scale=1.0,
532
+ guidance_low=0.0,
533
+ guidance_high=1.0,
534
+ path_type=args.path_type,
535
+ cls_latents=cls_init,
536
+ args=args,
537
+ return_mid_state=True,
538
+ t_mid=t_mid_vis,
539
+ )
540
+
541
+ samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
542
+ t0_dir = os.path.join(samples_root, "t0")
543
+ t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
544
+ os.makedirs(t0_dir, exist_ok=True)
545
+ os.makedirs(t_mid_dir, exist_ok=True)
546
+
547
+ if vae is not None:
548
+ z_f = z_0.to(dtype=torch.float32)
549
+ samples_final = vae.decode((z_f - latents_bias) / latents_scale).sample
550
+ samples_final = (samples_final + 1) / 2.0
551
+ samples_final = samples_final.clamp(0, 1)
552
+ grid_final = array2grid(samples_final)
553
+ Image.fromarray(grid_final).save(
554
+ os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
555
+ )
556
+
557
+ if z_mid is not None:
558
+ z_m = z_mid.to(dtype=torch.float32)
559
+ samples_mid = vae.decode((z_m - latents_bias) / latents_scale).sample
560
+ samples_mid = (samples_mid + 1) / 2.0
561
+ samples_mid = samples_mid.clamp(0, 1)
562
+ grid_mid = array2grid(samples_mid)
563
+ Image.fromarray(grid_mid).save(
564
+ os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
565
+ )
566
+ else:
567
+ logging.warning(
568
+ f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
569
+ f"skip t0_{tc_tag} image this step."
570
+ )
571
+
572
+ del z_init, cls_init, y_samples, z_0
573
+ if z_mid is not None:
574
+ del z_mid
575
+ if vae is not None:
576
+ del samples_final, grid_final
577
+ if "samples_mid" in locals():
578
+ del samples_mid, grid_mid
579
+ torch.cuda.empty_cache()
580
+
581
+
582
+ logs = {
583
+ "loss_final": accelerator.gather(loss).mean().detach().item(),
584
+ "loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
585
+ "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
586
+ "loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
587
+ "loss_tc_vel": accelerator.gather(tc_vel_loss).mean().detach().item(),
588
+ "grad_norm": accelerator.gather(grad_norm).mean().detach().item()
589
+ }
590
+
591
+ log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
592
+ logging.info(f"Step: {global_step}, Training Logs: {log_message}")
593
+
594
+ progress_bar.set_postfix(**logs)
595
+ accelerator.log(logs, step=global_step)
596
+
597
+ if global_step >= args.max_train_steps:
598
+ break
599
+ if global_step >= args.max_train_steps:
600
+ break
601
+
602
+ model.eval() # important! This disables randomized embedding dropout
603
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
604
+
605
+ accelerator.wait_for_everyone()
606
+ if accelerator.is_main_process:
607
+ logger.info("Done!")
608
+ accelerator.end_training()
609
+
610
+ def parse_args(input_args=None):
611
+ parser = argparse.ArgumentParser(description="Training")
612
+
613
+ # logging:
614
+ parser.add_argument("--output-dir", type=str, default="exps")
615
+ parser.add_argument("--exp-name", type=str, required=True)
616
+ parser.add_argument("--logging-dir", type=str, default="logs")
617
+ parser.add_argument("--report-to", type=str, default="wandb")
618
+ parser.add_argument("--sampling-steps", type=int, default=2000)
619
+ parser.add_argument("--resume-step", type=int, default=0)
620
+ parser.add_argument(
621
+ "--resume-from-ckpt",
622
+ type=str,
623
+ default=None,
624
+ help="直接从指定 checkpoint 路径续训(优先于 --resume-step)。",
625
+ )
626
+
627
+ # model
628
+ parser.add_argument("--model", type=str)
629
+ parser.add_argument("--num-classes", type=int, default=1000)
630
+ parser.add_argument("--encoder-depth", type=int, default=8)
631
+ parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
632
+ parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
633
+ parser.add_argument("--ops-head", type=int, default=16)
634
+
635
+ # dataset
636
+ parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
637
+ parser.add_argument(
638
+ "--semantic-features-dir",
639
+ type=str,
640
+ default=None,
641
+ help="预处理 DINOv2 class token 等特征目录(含 dataset.json)。"
642
+ "默认 None 时若存在 data-dir/imagenet_256_features/dinov2-vit-b_tmp/gpu0 则自动使用。",
643
+ )
644
+ parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
645
+ parser.add_argument("--batch-size", type=int, default=256)#256
646
+
647
+ # precision
648
+ parser.add_argument("--allow-tf32", action="store_true")
649
+ parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
650
+
651
+ # optimization
652
+ parser.add_argument("--epochs", type=int, default=14000)
653
+ parser.add_argument("--max-train-steps", type=int, default=10000000)
654
+ parser.add_argument("--checkpointing-steps", type=int, default=10000)
655
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
656
+ parser.add_argument("--learning-rate", type=float, default=1e-4)
657
+ parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
658
+ parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
659
+ parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
660
+ parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
661
+ parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
662
+
663
+ # seed
664
+ parser.add_argument("--seed", type=int, default=0)
665
+
666
+ # cpu
667
+ parser.add_argument("--num-workers", type=int, default=4)
668
+
669
+ # loss
670
+ parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
671
+ parser.add_argument("--prediction", type=str, default="v", choices=["v"]) # currently we only support v-prediction
672
+ parser.add_argument("--cfg-prob", type=float, default=0.1)
673
+ parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
674
+ parser.add_argument("--proj-coeff", type=float, default=0.5)
675
+ parser.add_argument("--weighting", default="uniform", type=str, help="Max gradient norm.")
676
+ parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
677
+ parser.add_argument("--cls", type=float, default=0.03)
678
+ parser.add_argument(
679
+ "--t-c",
680
+ type=float,
681
+ default=0.5,
682
+ help="语义分界时刻(与脚本内 t 约定一致:t=1 噪声→t=0 数据)。"
683
+ "t∈(t_c,1]:cls 沿 OT 配对后的路径插值(CFM/OT-CFM 式 minibatch OT);"
684
+ "t∈[0,t_c]:cls 固定为真实 encoder cls,目标 cls 速度为 0。",
685
+ )
686
+ parser.add_argument(
687
+ "--ot-cls",
688
+ action=argparse.BooleanOptionalAction,
689
+ default=True,
690
+ help="在 t>t_c 段对 cls 噪声与 batch 内 cls_gt 做 minibatch 最优传输配对(需 scipy);关闭则退化为独立高斯噪声配对。",
691
+ )
692
+ parser.add_argument(
693
+ "--tc-velocity-loss-coeff",
694
+ type=float,
695
+ default=0.0,
696
+ help="额外 t=t_c 图像速度场监督项权重(>0 启用,用于增强单步性)。",
697
+ )
698
+ if input_args is not None:
699
+ args = parser.parse_args(input_args)
700
+ else:
701
+ args = parser.parse_args()
702
+
703
+ return args
704
+
705
+ if __name__ == "__main__":
706
+ args = parse_args()
707
+
708
+ main(args)
REG/train.sh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# REG/train.py: as in the main repo, the data root and the precomputed
# cls-feature directory can be specified independently.
# Data layout: VAE latents under ${DATA_DIR}/imagenet_256_vae/;
# ${SEMANTIC_FEATURES_DIR}/ holds img-feature-*.npy + dataset.json
# (same layout as produced by parallel_encode).

NUM_GPUS=4

# ------------ adjust these paths for your machine ------------
DATA_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256"
SEMANTIC_FEATURES_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"

# Background-launch example (same style as the main experiment scripts):
# nohup bash train.sh > jsflow-experiment.log 2>&1 &

nohup accelerate launch --multi_gpu --num_processes "${NUM_GPUS}" --mixed_precision bf16 train.py \
  --report-to wandb \
  --allow-tf32 \
  --mixed-precision bf16 \
  --seed 0 \
  --path-type linear \
  --prediction v \
  --weighting uniform \
  --model SiT-XL/2 \
  --enc-type dinov2-vit-b \
  --encoder-depth 8 \
  --proj-coeff 0.5 \
  --output-dir exps \
  --exp-name jsflow-experiment-0.75-0.01 \
  --batch-size 256 \
  --data-dir "${DATA_DIR}" \
  --semantic-features-dir "${SEMANTIC_FEATURES_DIR}" \
  --learning-rate 0.00005 \
  --t-c 0.75 \
  --cls 0.01 \
  --ot-cls \
  > jsflow-experiment.log 2>&1 &

# Notes:
# - To extract DINO features online instead of using precomputed ones: drop
#   --semantic-features-dir and make sure data-dir follows the original REG
#   layout (imagenet_256_vae + vae-sd).
# - To disable minibatch OT pairing: append --no-ot-cls.
# - --weight-ratio / --semantic-reg-coeff / --repa-* from the main repo's
#   train.py are not implemented in this REG script; use --proj-coeff for the
#   projection strength and --cls for the cls flow-loss weight.
REG/train_resume_tc_velocity.sh ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Resume training from a given checkpoint, adding an extra velocity-field
# supervision term at t = t_c (intended to strengthen one-step generation).
set -euo pipefail

NUM_GPUS=4

DATA_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256"
SEMANTIC_FEATURES_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"

# User-specified checkpoint to resume from.
RESUME_CKPT="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75-0.01-one-step/checkpoints/1920000.pt"

# Weight of the added t_c velocity-field loss (tune up/down as needed).
TC_VEL_COEFF=2

nohup accelerate launch --multi_gpu --num_processes "${NUM_GPUS}" --mixed_precision bf16 train.py \
  --report-to wandb \
  --allow-tf32 \
  --mixed-precision bf16 \
  --seed 0 \
  --path-type linear \
  --prediction v \
  --weighting uniform \
  --model SiT-XL/2 \
  --enc-type dinov2-vit-b \
  --encoder-depth 8 \
  --proj-coeff 0.5 \
  --output-dir exps \
  --exp-name jsflow-experiment-0.75-0.01-one-step \
  --batch-size 256 \
  --data-dir "${DATA_DIR}" \
  --semantic-features-dir "${SEMANTIC_FEATURES_DIR}" \
  --learning-rate 0.00005 \
  --t-c 0.75 \
  --cls 0.005 \
  --ot-cls \
  --resume-from-ckpt "${RESUME_CKPT}" \
  --tc-velocity-loss-coeff "${TC_VEL_COEFF}" \
  > jsflow-experiment-0.75-0.01-tcvel.log 2>&1 &

echo "Launched resume training with tc velocity loss."
REG/utils.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torchvision.datasets.utils import download_url
3
+ import torch
4
+ import torchvision.models as torchvision_models
5
+ import timm
6
+ from models import mocov3_vit
7
+ import math
8
+ import warnings
9
+
10
+
11
# code from SiT repository
pretrained_models = {'last.pt'}


def download_model(model_name):
    """Return a pre-trained SiT checkpoint, downloading it on first use.

    The checkpoint is cached under ./pretrained_models/; later calls load
    the local copy directly. ``model_name`` must be one of
    ``pretrained_models``.
    """
    assert model_name in pretrained_models
    local_path = f'pretrained_models/{model_name}'
    if not os.path.isfile(local_path):
        os.makedirs('pretrained_models', exist_ok=True)
        web_path = 'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
        download_url(web_path, 'pretrained_models', filename=model_name)
    # map_location keeps the weights on CPU regardless of where they were saved.
    return torch.load(local_path, map_location=lambda storage, loc: storage)
26
+
27
def fix_mocov3_state_dict(state_dict):
    """Normalize a MoCo-v3 ViT checkpoint for loading into our encoder.

    Keeps only the ``module.base_encoder.*`` weights (prefix stripped),
    repairs a naming bug present in the released checkpoint
    (norm13/fc13/norm14/fc14), drops projection-head / final-fc entries,
    and resamples ``pos_embed`` to a 16x16 token grid.
    Mutates and returns ``state_dict``.
    """
    for key in list(state_dict.keys()):
        # Retain only base_encoder weights, up to before the embedding layer.
        if key.startswith('module.base_encoder'):
            renamed = key[len("module.base_encoder."):]
            # Fix the naming bug baked into the released checkpoint.
            if "blocks.13.norm13" in renamed:
                renamed = renamed.replace("norm13", "norm1")
            if "blocks.13.mlp.fc13" in key:
                renamed = renamed.replace("fc13", "fc1")
            if "blocks.14.norm14" in key:
                renamed = renamed.replace("norm14", "norm2")
            if "blocks.14.mlp.fc14" in key:
                renamed = renamed.replace("fc14", "fc2")
            # Skip projection-head and final-fc weights.
            if 'head' not in renamed and renamed.split('.')[0] != 'fc':
                state_dict[renamed] = state_dict[key]
        # Drop the original (renamed or unused) entry.
        del state_dict[key]
    if 'pos_embed' in state_dict.keys():
        # Resample absolute position embeddings to the 16x16 grid.
        state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
            state_dict['pos_embed'], [16, 16],
        )
    return state_dict
51
+
52
@torch.no_grad()
def load_encoders(enc_type, device, resolution=256):
    """Instantiate the frozen representation encoder(s) used for alignment.

    Args:
        enc_type: comma-separated list of encoder specs, each formatted as
            '<type>-<architecture>-<size>', e.g. 'dinov2-vit-b'.
        device: device the encoders are moved to.
        resolution: input image resolution; must be 256 or 512
            (512 is only supported for DINOv2).

    Returns:
        (encoders, encoder_types, architectures) — three parallel lists,
        one entry per spec in ``enc_type``.
    """
    assert (resolution == 256) or (resolution == 512)

    enc_names = enc_type.split(',')
    encoders, architectures, encoder_types = [], [], []
    for enc_name in enc_names:
        encoder_type, architecture, model_config = enc_name.split('-')
        # Currently, we only support 512x512 experiments with DINOv2 encoders.
        if resolution == 512:
            if encoder_type != 'dinov2':
                raise NotImplementedError(
                    "Currently, we only support 512x512 experiments with DINOv2 encoders."
                )

        architectures.append(architecture)
        encoder_types.append(encoder_type)
        if encoder_type == 'mocov3':
            if architecture == 'vit':
                if model_config == 's':
                    encoder = mocov3_vit.vit_small()
                elif model_config == 'b':
                    encoder = mocov3_vit.vit_base()
                elif model_config == 'l':
                    encoder = mocov3_vit.vit_large()
                ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
                # The released checkpoint has mis-named layers; fix before loading.
                state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
                # Drop the projection head — only backbone features are needed.
                del encoder.head
                encoder.load_state_dict(state_dict, strict=True)
                encoder.head = torch.nn.Identity()
            elif architecture == 'resnet':
                raise NotImplementedError()

            encoder = encoder.to(device)
            encoder.eval()

        elif 'dinov2' in encoder_type:
            import timm
            # Prefer a local hub cache; fall back to downloading from GitHub.
            if 'reg' in encoder_type:
                try:
                    encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
                                             f'dinov2_vit{model_config}14_reg', source='local')
                except:
                    encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
            else:
                try:
                    encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
                                             f'dinov2_vit{model_config}14', source='local')
                except:
                    encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')

            print(f"Now you are using the {enc_name} as the aligning model")
            del encoder.head
            # 16 tokens per side at 256px; scales with resolution (32 at 512px).
            patch_resolution = 16 * (resolution // 256)
            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [patch_resolution, patch_resolution],
            )
            encoder.head = torch.nn.Identity()
            encoder = encoder.to(device)
            encoder.eval()

        elif 'dinov1' == encoder_type:
            import timm
            from models import dinov1
            encoder = dinov1.vit_base()
            ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
            if 'pos_embed' in ckpt.keys():
                # Resample checkpoint position embeddings to the 16x16 grid.
                ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
                    ckpt['pos_embed'], [16, 16],
                )
            del encoder.head
            encoder.head = torch.nn.Identity()
            encoder.load_state_dict(ckpt, strict=True)
            encoder = encoder.to(device)
            encoder.forward_features = encoder.forward
            encoder.eval()

        elif encoder_type == 'clip':
            import clip
            from models.clip_vit import UpdatedVisionTransformer
            # Use only the vision tower of the CLIP model.
            encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
            encoder = UpdatedVisionTransformer(encoder_).to(device)
            encoder.embed_dim = encoder.model.transformer.width
            encoder.forward_features = encoder.forward
            encoder.eval()

        elif encoder_type == 'mae':
            from models.mae_vit import vit_large_patch16
            import timm
            kwargs = dict(img_size=256)
            encoder = vit_large_patch16(**kwargs).to(device)
            with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
                state_dict = torch.load(f)
            if 'pos_embed' in state_dict["model"].keys():
                state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
                    state_dict["model"]['pos_embed'], [16, 16],
                )
            encoder.load_state_dict(state_dict["model"])

            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [16, 16],
            )
            # NOTE(review): unlike the branches above, 'mae' never calls
            # encoder.eval() — confirm whether that is intentional.

        elif encoder_type == 'jepa':
            from models.jepa import vit_huge
            kwargs = dict(img_size=[224, 224], patch_size=14)
            encoder = vit_huge(**kwargs).to(device)
            with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
                state_dict = torch.load(f, map_location=device)
            new_state_dict = dict()
            # Strip the 'module.' prefix (7 chars) added by DDP when saving.
            for key, value in state_dict['encoder'].items():
                new_state_dict[key[7:]] = value
            encoder.load_state_dict(new_state_dict)
            encoder.forward_features = encoder.forward
            # NOTE(review): 'jepa' also skips encoder.eval() — confirm intent.

        encoders.append(encoder)

    return encoders, encoder_types, architectures
171
+
172
+
173
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
174
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
175
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
176
+ def norm_cdf(x):
177
+ # Computes standard normal cumulative distribution function
178
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
179
+
180
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
181
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
182
+ "The distribution of values may be incorrect.",
183
+ stacklevel=2)
184
+
185
+ with torch.no_grad():
186
+ # Values are generated by using a truncated uniform distribution and
187
+ # then using the inverse CDF for the normal distribution.
188
+ # Get upper and lower cdf values
189
+ l = norm_cdf((a - mean) / std)
190
+ u = norm_cdf((b - mean) / std)
191
+
192
+ # Uniformly fill tensor with values from [l, u], then translate to
193
+ # [2l-1, 2u-1].
194
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
195
+
196
+ # Use inverse cdf transform for normal distribution to get truncated
197
+ # standard normal
198
+ tensor.erfinv_()
199
+
200
+ # Transform to proper mean, std
201
+ tensor.mul_(std * math.sqrt(2.))
202
+ tensor.add_(mean)
203
+
204
+ # Clamp to ensure it's in the proper range
205
+ tensor.clamp_(min=a, max=b)
206
+ return tensor
207
+
208
+
209
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
210
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
211
+
212
+
213
def load_legacy_checkpoints(state_dict, encoder_depth):
    """Remap a legacy checkpoint to the merged-trunk block layout.

    Legacy checkpoints store decoder layers as 'decoder_blocks.<i>.*'; the
    current model numbers them 'blocks.<i + encoder_depth>.*'. All other
    entries are copied through unchanged. Returns a new dict; the input is
    not modified.
    """
    remapped = dict()
    for name, value in state_dict.items():
        if 'decoder_blocks' not in name:
            remapped[name] = value
            continue
        pieces = name.split('.')
        # Shift the decoder index past the encoder blocks and rename the prefix.
        pieces[0] = 'blocks'
        pieces[1] = str(int(pieces[1]) + encoder_depth)
        remapped['.'.join(pieces)] = value
    return remapped
REG/wandb/run-20260322_150022-yhxc5cgu/files/config.yaml ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.25.0
4
+ e:
5
+ ucanic8s891x6sl28vnbha78lzoecw66:
6
+ args:
7
+ - --report-to
8
+ - wandb
9
+ - --allow-tf32
10
+ - --mixed-precision
11
+ - bf16
12
+ - --seed
13
+ - "0"
14
+ - --path-type
15
+ - linear
16
+ - --prediction
17
+ - v
18
+ - --weighting
19
+ - uniform
20
+ - --model
21
+ - SiT-XL/2
22
+ - --enc-type
23
+ - dinov2-vit-b
24
+ - --encoder-depth
25
+ - "8"
26
+ - --proj-coeff
27
+ - "0.5"
28
+ - --output-dir
29
+ - exps
30
+ - --exp-name
31
+ - jsflow-experiment
32
+ - --batch-size
33
+ - "256"
34
+ - --data-dir
35
+ - /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
36
+ - --semantic-features-dir
37
+ - /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
38
+ - --learning-rate
39
+ - "0.00005"
40
+ - --t-c
41
+ - "0.5"
42
+ - --cls
43
+ - "0.2"
44
+ - --ot-cls
45
+ codePath: train.py
46
+ codePathLocal: train.py
47
+ cpu_count: 96
48
+ cpu_count_logical: 192
49
+ cudaVersion: "13.0"
50
+ disk:
51
+ /:
52
+ total: "3838880616448"
53
+ used: "357557354496"
54
+ email: 2365972933@qq.com
55
+ executable: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
56
+ git:
57
+ commit: 021ea2e50c38c5803bd9afff16316958a01fbd1d
58
+ remote: https://github.com/Martinser/REG.git
59
+ gpu: NVIDIA H100 80GB HBM3
60
+ gpu_count: 4
61
+ gpu_nvidia:
62
+ - architecture: Hopper
63
+ cudaCores: 16896
64
+ memoryTotal: "85520809984"
65
+ name: NVIDIA H100 80GB HBM3
66
+ uuid: GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc
67
+ - architecture: Hopper
68
+ cudaCores: 16896
69
+ memoryTotal: "85520809984"
70
+ name: NVIDIA H100 80GB HBM3
71
+ uuid: GPU-a09f2421-99e6-a72e-63bd-fd7452510758
72
+ - architecture: Hopper
73
+ cudaCores: 16896
74
+ memoryTotal: "85520809984"
75
+ name: NVIDIA H100 80GB HBM3
76
+ uuid: GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d
77
+ - architecture: Hopper
78
+ cudaCores: 16896
79
+ memoryTotal: "85520809984"
80
+ name: NVIDIA H100 80GB HBM3
81
+ uuid: GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e
82
+ host: 24c964746905d416ce09d045f9a06f23-taskrole1-0
83
+ memory:
84
+ total: "2164115296256"
85
+ os: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
86
+ program: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py
87
+ python: CPython 3.12.9
88
+ root: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG
89
+ startedAt: "2026-03-22T07:00:22.092510Z"
90
+ writerId: ucanic8s891x6sl28vnbha78lzoecw66
91
+ m: []
92
+ python_version: 3.12.9
93
+ t:
94
+ "1":
95
+ - 1
96
+ - 5
97
+ - 11
98
+ - 41
99
+ - 49
100
+ - 53
101
+ - 63
102
+ - 71
103
+ - 83
104
+ - 98
105
+ "2":
106
+ - 1
107
+ - 5
108
+ - 11
109
+ - 41
110
+ - 49
111
+ - 53
112
+ - 63
113
+ - 71
114
+ - 83
115
+ - 98
116
+ "3":
117
+ - 13
118
+ "4": 3.12.9
119
+ "5": 0.25.0
120
+ "6": 4.53.2
121
+ "12": 0.25.0
122
+ "13": linux-x86_64
123
+ adam_beta1:
124
+ value: 0.9
125
+ adam_beta2:
126
+ value: 0.999
127
+ adam_epsilon:
128
+ value: 1e-08
129
+ adam_weight_decay:
130
+ value: 0
131
+ allow_tf32:
132
+ value: true
133
+ batch_size:
134
+ value: 256
135
+ cfg_prob:
136
+ value: 0.1
137
+ checkpointing_steps:
138
+ value: 10000
139
+ cls:
140
+ value: 0.2
141
+ data_dir:
142
+ value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
143
+ enc_type:
144
+ value: dinov2-vit-b
145
+ encoder_depth:
146
+ value: 8
147
+ epochs:
148
+ value: 1400
149
+ exp_name:
150
+ value: jsflow-experiment
151
+ fused_attn:
152
+ value: true
153
+ gradient_accumulation_steps:
154
+ value: 1
155
+ learning_rate:
156
+ value: 5e-05
157
+ legacy:
158
+ value: false
159
+ logging_dir:
160
+ value: logs
161
+ max_grad_norm:
162
+ value: 1
163
+ max_train_steps:
164
+ value: 1000000
165
+ mixed_precision:
166
+ value: bf16
167
+ model:
168
+ value: SiT-XL/2
169
+ num_classes:
170
+ value: 1000
171
+ num_workers:
172
+ value: 4
173
+ ops_head:
174
+ value: 16
175
+ ot_cls:
176
+ value: true
177
+ output_dir:
178
+ value: exps
179
+ path_type:
180
+ value: linear
181
+ prediction:
182
+ value: v
183
+ proj_coeff:
184
+ value: 0.5
185
+ qk_norm:
186
+ value: false
187
+ report_to:
188
+ value: wandb
189
+ resolution:
190
+ value: 256
191
+ resume_step:
192
+ value: 0
193
+ sampling_steps:
194
+ value: 2000
195
+ seed:
196
+ value: 0
197
+ semantic_features_dir:
198
+ value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
199
+ t_c:
200
+ value: 0.5
201
+ weighting:
202
+ value: uniform
REG/wandb/run-20260322_150022-yhxc5cgu/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"_wandb":{"runtime":2},"_runtime":2}
REG/wandb/run-20260322_150443-e3yw9ii4/files/config.yaml ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.25.0
4
+ e:
5
+ q63x26q8nhayytv8q2rrmj9j9uy9kvub:
6
+ args:
7
+ - --report-to
8
+ - wandb
9
+ - --allow-tf32
10
+ - --mixed-precision
11
+ - bf16
12
+ - --seed
13
+ - "0"
14
+ - --path-type
15
+ - linear
16
+ - --prediction
17
+ - v
18
+ - --weighting
19
+ - uniform
20
+ - --model
21
+ - SiT-XL/2
22
+ - --enc-type
23
+ - dinov2-vit-b
24
+ - --encoder-depth
25
+ - "8"
26
+ - --proj-coeff
27
+ - "0.5"
28
+ - --output-dir
29
+ - exps
30
+ - --exp-name
31
+ - jsflow-experiment
32
+ - --batch-size
33
+ - "256"
34
+ - --data-dir
35
+ - /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
36
+ - --semantic-features-dir
37
+ - /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
38
+ - --learning-rate
39
+ - "0.00005"
40
+ - --t-c
41
+ - "0.5"
42
+ - --cls
43
+ - "0.2"
44
+ - --ot-cls
45
+ codePath: train.py
46
+ codePathLocal: train.py
47
+ cpu_count: 96
48
+ cpu_count_logical: 192
49
+ cudaVersion: "13.0"
50
+ disk:
51
+ /:
52
+ total: "3838880616448"
53
+ used: "357557714944"
54
+ email: 2365972933@qq.com
55
+ executable: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
56
+ git:
57
+ commit: 021ea2e50c38c5803bd9afff16316958a01fbd1d
58
+ remote: https://github.com/Martinser/REG.git
59
+ gpu: NVIDIA H100 80GB HBM3
60
+ gpu_count: 4
61
+ gpu_nvidia:
62
+ - architecture: Hopper
63
+ cudaCores: 16896
64
+ memoryTotal: "85520809984"
65
+ name: NVIDIA H100 80GB HBM3
66
+ uuid: GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc
67
+ - architecture: Hopper
68
+ cudaCores: 16896
69
+ memoryTotal: "85520809984"
70
+ name: NVIDIA H100 80GB HBM3
71
+ uuid: GPU-a09f2421-99e6-a72e-63bd-fd7452510758
72
+ - architecture: Hopper
73
+ cudaCores: 16896
74
+ memoryTotal: "85520809984"
75
+ name: NVIDIA H100 80GB HBM3
76
+ uuid: GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d
77
+ - architecture: Hopper
78
+ cudaCores: 16896
79
+ memoryTotal: "85520809984"
80
+ name: NVIDIA H100 80GB HBM3
81
+ uuid: GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e
82
+ host: 24c964746905d416ce09d045f9a06f23-taskrole1-0
83
+ memory:
84
+ total: "2164115296256"
85
+ os: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
86
+ program: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py
87
+ python: CPython 3.12.9
88
+ root: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG
89
+ startedAt: "2026-03-22T07:04:43.133739Z"
90
+ writerId: q63x26q8nhayytv8q2rrmj9j9uy9kvub
91
+ m: []
92
+ python_version: 3.12.9
93
+ t:
94
+ "1":
95
+ - 1
96
+ - 5
97
+ - 11
98
+ - 41
99
+ - 49
100
+ - 53
101
+ - 63
102
+ - 71
103
+ - 83
104
+ - 98
105
+ "2":
106
+ - 1
107
+ - 5
108
+ - 11
109
+ - 41
110
+ - 49
111
+ - 53
112
+ - 63
113
+ - 71
114
+ - 83
115
+ - 98
116
+ "3":
117
+ - 13
118
+ "4": 3.12.9
119
+ "5": 0.25.0
120
+ "6": 4.53.2
121
+ "12": 0.25.0
122
+ "13": linux-x86_64
123
+ adam_beta1:
124
+ value: 0.9
125
+ adam_beta2:
126
+ value: 0.999
127
+ adam_epsilon:
128
+ value: 1e-08
129
+ adam_weight_decay:
130
+ value: 0
131
+ allow_tf32:
132
+ value: true
133
+ batch_size:
134
+ value: 256
135
+ cfg_prob:
136
+ value: 0.1
137
+ checkpointing_steps:
138
+ value: 10000
139
+ cls:
140
+ value: 0.2
141
+ data_dir:
142
+ value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
143
+ enc_type:
144
+ value: dinov2-vit-b
145
+ encoder_depth:
146
+ value: 8
147
+ epochs:
148
+ value: 1400
149
+ exp_name:
150
+ value: jsflow-experiment
151
+ fused_attn:
152
+ value: true
153
+ gradient_accumulation_steps:
154
+ value: 1
155
+ learning_rate:
156
+ value: 5e-05
157
+ legacy:
158
+ value: false
159
+ logging_dir:
160
+ value: logs
161
+ max_grad_norm:
162
+ value: 1
163
+ max_train_steps:
164
+ value: 1000000
165
+ mixed_precision:
166
+ value: bf16
167
+ model:
168
+ value: SiT-XL/2
169
+ num_classes:
170
+ value: 1000
171
+ num_workers:
172
+ value: 4
173
+ ops_head:
174
+ value: 16
175
+ ot_cls:
176
+ value: true
177
+ output_dir:
178
+ value: exps
179
+ path_type:
180
+ value: linear
181
+ prediction:
182
+ value: v
183
+ proj_coeff:
184
+ value: 0.5
185
+ qk_norm:
186
+ value: false
187
+ report_to:
188
+ value: wandb
189
+ resolution:
190
+ value: 256
191
+ resume_step:
192
+ value: 0
193
+ sampling_steps:
194
+ value: 2000
195
+ seed:
196
+ value: 0
197
+ semantic_features_dir:
198
+ value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
199
+ t_c:
200
+ value: 0.5
201
+ weighting:
202
+ value: uniform
REG/wandb/run-20260322_150443-e3yw9ii4/files/output.log ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Steps: 0%| | 1/1000000 [00:02<588:29:38, 2.12s/it][2026-03-22 15:04:48] Generating EMA samples for evaluation (SDE → t=0)...
2
+ Traceback (most recent call last):
3
+ File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 572, in <module>
4
+ main(args)
5
+ File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 444, in main
6
+ if vae is not None:
7
+ ^^^
8
+ NameError: name 'vae' is not defined
9
+ [rank0]: Traceback (most recent call last):
10
+ [rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 572, in <module>
11
+ [rank0]: main(args)
12
+ [rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 444, in main
13
+ [rank0]: if vae is not None:
14
+ [rank0]: ^^^
15
+ [rank0]: NameError: name 'vae' is not defined
REG/wandb/run-20260322_150443-e3yw9ii4/files/requirements.txt ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dill==0.3.8
2
+ mkl-service==2.4.0
3
+ mpmath==1.3.0
4
+ typing_extensions==4.12.2
5
+ urllib3==2.3.0
6
+ torch==2.5.1
7
+ ptyprocess==0.7.0
8
+ traitlets==5.14.3
9
+ pyasn1==0.6.1
10
+ opencv-python-headless==4.12.0.88
11
+ nest-asyncio==1.6.0
12
+ kiwisolver==1.4.8
13
+ click==8.2.1
14
+ fire==0.7.1
15
+ diffusers==0.35.1
16
+ accelerate==1.7.0
17
+ ipykernel==6.29.5
18
+ peft==0.17.1
19
+ attrs==24.3.0
20
+ six==1.17.0
21
+ numpy==2.0.1
22
+ yarl==1.18.0
23
+ huggingface_hub==0.34.4
24
+ Bottleneck==1.4.2
25
+ numexpr==2.11.0
26
+ dataclasses==0.6
27
+ typing-inspection==0.4.1
28
+ safetensors==0.5.3
29
+ pyparsing==3.2.3
30
+ psutil==7.0.0
31
+ imageio==2.37.0
32
+ debugpy==1.8.14
33
+ cycler==0.12.1
34
+ pyasn1_modules==0.4.2
35
+ matplotlib-inline==0.1.7
36
+ matplotlib==3.10.3
37
+ jedi==0.19.2
38
+ tokenizers==0.21.2
39
+ seaborn==0.13.2
40
+ timm==1.0.15
41
+ aiohappyeyeballs==2.6.1
42
+ hf-xet==1.1.8
43
+ multidict==6.1.0
44
+ tqdm==4.67.1
45
+ wheel==0.45.1
46
+ simsimd==6.5.1
47
+ sentencepiece==0.2.1
48
+ grpcio==1.74.0
49
+ asttokens==3.0.0
50
+ absl-py==2.3.1
51
+ stack-data==0.6.3
52
+ pandas==2.3.0
53
+ importlib_metadata==8.7.0
54
+ pytorch-image-generation-metrics==0.6.1
55
+ frozenlist==1.5.0
56
+ MarkupSafe==3.0.2
57
+ setuptools==78.1.1
58
+ multiprocess==0.70.15
59
+ pip==25.1
60
+ requests==2.32.3
61
+ mkl_random==1.2.8
62
+ tensorboard-plugin-wit==1.8.1
63
+ ExifRead-nocycle==3.0.1
64
+ webdataset==0.2.111
65
+ threadpoolctl==3.6.0
66
+ pyarrow==21.0.0
67
+ executing==2.2.0
68
+ decorator==5.2.1
69
+ contourpy==1.3.2
70
+ annotated-types==0.7.0
71
+ scikit-learn==1.7.1
72
+ jupyter_client==8.6.3
73
+ albumentations==1.4.24
74
+ wandb==0.25.0
75
+ certifi==2025.8.3
76
+ idna==3.7
77
+ xxhash==3.5.0
78
+ Jinja2==3.1.6
79
+ python-dateutil==2.9.0.post0
80
+ aiosignal==1.4.0
81
+ triton==3.1.0
82
+ torchvision==0.20.1
83
+ stringzilla==3.12.6
84
+ pure_eval==0.2.3
85
+ braceexpand==0.1.7
86
+ zipp==3.22.0
87
+ oauthlib==3.3.1
88
+ Markdown==3.8.2
89
+ fsspec==2025.3.0
90
+ fonttools==4.58.2
91
+ comm==0.2.2
92
+ ipython==9.3.0
93
+ img2dataset==1.47.0
94
+ networkx==3.4.2
95
+ PySocks==1.7.1
96
+ tzdata==2025.2
97
+ smmap==5.0.2
98
+ mkl_fft==1.3.11
99
+ sentry-sdk==2.29.1
100
+ Pygments==2.19.1
101
+ pexpect==4.9.0
102
+ ftfy==6.3.1
103
+ einops==0.8.1
104
+ requests-oauthlib==2.0.0
105
+ gitdb==4.0.12
106
+ albucore==0.0.23
107
+ torchdiffeq==0.2.5
108
+ GitPython==3.1.44
109
+ bitsandbytes==0.47.0
110
+ pytorch-fid==0.3.0
111
+ clean-fid==0.1.35
112
+ pytorch-gan-metrics==0.5.4
113
+ Brotli==1.0.9
114
+ charset-normalizer==3.3.2
115
+ gmpy2==2.2.1
116
+ pillow==11.1.0
117
+ PyYAML==6.0.2
118
+ tornado==6.5.1
119
+ termcolor==3.1.0
120
+ setproctitle==1.3.6
121
+ scipy==1.15.3
122
+ regex==2024.11.6
123
+ protobuf==6.31.1
124
+ platformdirs==4.3.8
125
+ joblib==1.5.1
126
+ cachetools==4.2.4
127
+ ipython_pygments_lexers==1.1.1
128
+ google-auth==1.35.0
129
+ transformers==4.53.2
130
+ torch-fidelity==0.3.0
131
+ tensorboard==2.4.0
132
+ filelock==3.17.0
133
+ packaging==25.0
134
+ propcache==0.3.1
135
+ pytz==2025.2
136
+ aiohttp==3.11.10
137
+ wcwidth==0.2.13
138
+ clip==0.2.0
139
+ Werkzeug==3.1.3
140
+ tensorboard-data-server==0.6.1
141
+ sympy==1.13.1
142
+ pyzmq==26.4.0
143
+ pydantic_core==2.33.2
144
+ prompt_toolkit==3.0.51
145
+ parso==0.8.4
146
+ docker-pycreds==0.4.0
147
+ rsa==4.9.1
148
+ pydantic==2.11.5
149
+ jupyter_core==5.8.1
150
+ google-auth-oauthlib==0.4.6
151
+ datasets==4.0.0
152
+ torch-tb-profiler==0.4.3
153
+ autocommand==2.2.2
154
+ backports.tarfile==1.2.0
155
+ importlib_metadata==8.0.0
156
+ jaraco.collections==5.1.0
157
+ jaraco.context==5.3.0
158
+ jaraco.functools==4.0.1
159
+ more-itertools==10.3.0
160
+ packaging==24.2
161
+ platformdirs==4.2.2
162
+ typeguard==4.3.0
163
+ inflect==7.3.1
164
+ jaraco.text==3.12.1
165
+ tomli==2.0.1
166
+ typing_extensions==4.12.2
167
+ wheel==0.45.1
168
+ zipp==3.19.2
REG/wandb/run-20260322_150443-e3yw9ii4/files/wandb-metadata.json ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.12.9",
4
+ "startedAt": "2026-03-22T07:04:43.133739Z",
5
+ "args": [
6
+ "--report-to",
7
+ "wandb",
8
+ "--allow-tf32",
9
+ "--mixed-precision",
10
+ "bf16",
11
+ "--seed",
12
+ "0",
13
+ "--path-type",
14
+ "linear",
15
+ "--prediction",
16
+ "v",
17
+ "--weighting",
18
+ "uniform",
19
+ "--model",
20
+ "SiT-XL/2",
21
+ "--enc-type",
22
+ "dinov2-vit-b",
23
+ "--encoder-depth",
24
+ "8",
25
+ "--proj-coeff",
26
+ "0.5",
27
+ "--output-dir",
28
+ "exps",
29
+ "--exp-name",
30
+ "jsflow-experiment",
31
+ "--batch-size",
32
+ "256",
33
+ "--data-dir",
34
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
35
+ "--semantic-features-dir",
36
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
37
+ "--learning-rate",
38
+ "0.00005",
39
+ "--t-c",
40
+ "0.5",
41
+ "--cls",
42
+ "0.2",
43
+ "--ot-cls"
44
+ ],
45
+ "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
46
+ "codePath": "train.py",
47
+ "codePathLocal": "train.py",
48
+ "git": {
49
+ "remote": "https://github.com/Martinser/REG.git",
50
+ "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
51
+ },
52
+ "email": "2365972933@qq.com",
53
+ "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
54
+ "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
55
+ "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
56
+ "cpu_count": 96,
57
+ "cpu_count_logical": 192,
58
+ "gpu": "NVIDIA H100 80GB HBM3",
59
+ "gpu_count": 4,
60
+ "disk": {
61
+ "/": {
62
+ "total": "3838880616448",
63
+ "used": "357557714944"
64
+ }
65
+ },
66
+ "memory": {
67
+ "total": "2164115296256"
68
+ },
69
+ "gpu_nvidia": [
70
+ {
71
+ "name": "NVIDIA H100 80GB HBM3",
72
+ "memoryTotal": "85520809984",
73
+ "cudaCores": 16896,
74
+ "architecture": "Hopper",
75
+ "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
76
+ },
77
+ {
78
+ "name": "NVIDIA H100 80GB HBM3",
79
+ "memoryTotal": "85520809984",
80
+ "cudaCores": 16896,
81
+ "architecture": "Hopper",
82
+ "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
83
+ },
84
+ {
85
+ "name": "NVIDIA H100 80GB HBM3",
86
+ "memoryTotal": "85520809984",
87
+ "cudaCores": 16896,
88
+ "architecture": "Hopper",
89
+ "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
90
+ },
91
+ {
92
+ "name": "NVIDIA H100 80GB HBM3",
93
+ "memoryTotal": "85520809984",
94
+ "cudaCores": 16896,
95
+ "architecture": "Hopper",
96
+ "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
97
+ }
98
+ ],
99
+ "cudaVersion": "13.0",
100
+ "writerId": "q63x26q8nhayytv8q2rrmj9j9uy9kvub"
101
+ }
REG/wandb/run-20260322_150443-e3yw9ii4/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"_wandb":{"runtime":3},"_runtime":3}
REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug-internal.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2026-03-22T15:04:43.390486873+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
2
+ {"time":"2026-03-22T15:04:44.970687851+08:00","level":"INFO","msg":"stream: created new stream","id":"e3yw9ii4"}
3
+ {"time":"2026-03-22T15:04:44.970802178+08:00","level":"INFO","msg":"handler: started","stream_id":"e3yw9ii4"}
4
+ {"time":"2026-03-22T15:04:44.971744065+08:00","level":"INFO","msg":"stream: started","id":"e3yw9ii4"}
5
+ {"time":"2026-03-22T15:04:44.97174913+08:00","level":"INFO","msg":"writer: started","stream_id":"e3yw9ii4"}
6
+ {"time":"2026-03-22T15:04:44.971758857+08:00","level":"INFO","msg":"sender: started","stream_id":"e3yw9ii4"}
7
+ {"time":"2026-03-22T15:04:50.286711145+08:00","level":"INFO","msg":"stream: closing","id":"e3yw9ii4"}
REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug.log ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-03-22 15:04:43,155 INFO MainThread:326012 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
2
+ 2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_setup.py:_flush():81] Configure stats pid to 326012
3
+ 2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug.log
5
+ 2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug-internal.log
6
+ 2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:init():844] calling init triggers
7
+ 2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
8
+ config: {'_wandb': {}}
9
+ 2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:init():892] starting backend
10
+ 2026-03-22 15:04:43,378 INFO MainThread:326012 [wandb_init.py:init():895] sending inform_init request
11
+ 2026-03-22 15:04:43,388 INFO MainThread:326012 [wandb_init.py:init():903] backend started and connected
12
+ 2026-03-22 15:04:43,389 INFO MainThread:326012 [wandb_init.py:init():973] updated telemetry
13
+ 2026-03-22 15:04:43,402 INFO MainThread:326012 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
14
+ 2026-03-22 15:04:46,450 INFO MainThread:326012 [wandb_init.py:init():1042] starting run threads in backend
15
+ 2026-03-22 15:04:46,541 INFO MainThread:326012 [wandb_run.py:_console_start():2524] atexit reg
16
+ 2026-03-22 15:04:46,541 INFO MainThread:326012 [wandb_run.py:_redirect():2373] redirect: wrap_raw
17
+ 2026-03-22 15:04:46,541 INFO MainThread:326012 [wandb_run.py:_redirect():2442] Wrapping output streams.
18
+ 2026-03-22 15:04:46,541 INFO MainThread:326012 [wandb_run.py:_redirect():2465] Redirects installed.
19
+ 2026-03-22 15:04:46,545 INFO MainThread:326012 [wandb_init.py:init():1082] run started, returning control to user process
20
+ 2026-03-22 15:04:46,545 INFO MainThread:326012 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
21
+ 2026-03-22 15:04:50,286 INFO wandb-AsyncioManager-main:326012 [service_client.py:_forward_responses():134] Reached EOF.
22
+ 2026-03-22 15:04:50,286 INFO wandb-AsyncioManager-main:326012 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug-internal.log ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"time":"2026-03-22T15:06:35.605542246+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
2
+ {"time":"2026-03-22T15:06:37.585674571+08:00","level":"INFO","msg":"stream: created new stream","id":"o2w3z8rq"}
3
+ {"time":"2026-03-22T15:06:37.585934805+08:00","level":"INFO","msg":"handler: started","stream_id":"o2w3z8rq"}
4
+ {"time":"2026-03-22T15:06:37.586954142+08:00","level":"INFO","msg":"stream: started","id":"o2w3z8rq"}
5
+ {"time":"2026-03-22T15:06:37.587002572+08:00","level":"INFO","msg":"sender: started","stream_id":"o2w3z8rq"}
6
+ {"time":"2026-03-22T15:06:37.58696296+08:00","level":"INFO","msg":"writer: started","stream_id":"o2w3z8rq"}
REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug.log ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
2
+ 2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_setup.py:_flush():81] Configure stats pid to 328110
3
+ 2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug.log
5
+ 2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug-internal.log
6
+ 2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:init():844] calling init triggers
7
+ 2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
8
+ config: {'_wandb': {}}
9
+ 2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:init():892] starting backend
10
+ 2026-03-22 15:06:35,590 INFO MainThread:328110 [wandb_init.py:init():895] sending inform_init request
11
+ 2026-03-22 15:06:35,601 INFO MainThread:328110 [wandb_init.py:init():903] backend started and connected
12
+ 2026-03-22 15:06:35,604 INFO MainThread:328110 [wandb_init.py:init():973] updated telemetry
13
+ 2026-03-22 15:06:35,618 INFO MainThread:328110 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
14
+ 2026-03-22 15:06:38,166 INFO MainThread:328110 [wandb_init.py:init():1042] starting run threads in backend
15
+ 2026-03-22 15:06:38,258 INFO MainThread:328110 [wandb_run.py:_console_start():2524] atexit reg
16
+ 2026-03-22 15:06:38,258 INFO MainThread:328110 [wandb_run.py:_redirect():2373] redirect: wrap_raw
17
+ 2026-03-22 15:06:38,259 INFO MainThread:328110 [wandb_run.py:_redirect():2442] Wrapping output streams.
18
+ 2026-03-22 15:06:38,259 INFO MainThread:328110 [wandb_run.py:_redirect():2465] Redirects installed.
19
+ 2026-03-22 15:06:38,262 INFO MainThread:328110 [wandb_init.py:init():1082] run started, returning control to user process
20
+ 2026-03-22 15:06:38,263 INFO MainThread:328110 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
REG/wandb/run-20260323_135607-zue1y2ba/files/output.log ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Steps: 0%| | 1/1000000 [00:02<603:07:37, 2.17s/it][2026-03-23 13:56:13] Generating EMA samples (Euler-Maruyama; t≈0.75 → t=0)...
2
+ [2026-03-23 13:56:16] Step: 1, Training Logs: loss_final: 4.886707, loss_mean: 1.706308, proj_loss: 0.001541, loss_mean_cls: 3.178859, grad_norm: 1.484174
3
+ Steps: 0%| | 2/1000000 [00:04<669:29:28, 2.41s/it, grad_norm=1.48, loss_final=4.89, loss_mean=1.71, loss_mean_cls=3.18, proj_loss=0.00154][2026-03-23 13:56:16] Step: 2, Training Logs: loss_final: 4.294325, loss_mean: 1.689698, proj_loss: -0.010266, loss_mean_cls: 2.614893, grad_norm: 1.062257
4
+ Steps: 0%| | 3/1000000 [00:04<394:40:41, 1.42s/it, grad_norm=1.06, loss_final=4.29, loss_mean=1.69, loss_mean_cls=2.61, proj_loss=-0.0103][2026-03-23 13:56:16] Step: 3, Training Logs: loss_final: 4.829489, loss_mean: 1.666703, proj_loss: -0.019215, loss_mean_cls: 3.182001, grad_norm: 1.115940
5
+ Steps: 0%| | 4/1000000 [00:05<265:36:26, 1.05it/s, grad_norm=1.12, loss_final=4.83, loss_mean=1.67, loss_mean_cls=3.18, proj_loss=-0.0192][2026-03-23 13:56:16] Step: 4, Training Logs: loss_final: 4.838729, loss_mean: 1.683388, proj_loss: -0.026348, loss_mean_cls: 3.181689, grad_norm: 0.753645
6
+ Steps: 0%| | 5/1000000 [00:05<194:13:56, 1.43it/s, grad_norm=0.754, loss_final=4.84, loss_mean=1.68, loss_mean_cls=3.18, proj_loss=-0.0263][2026-03-23 13:56:17] Step: 5, Training Logs: loss_final: 4.434019, loss_mean: 1.678989, proj_loss: -0.034618, loss_mean_cls: 2.789647, grad_norm: 0.828005
7
+ Steps: 0%| | 6/1000000 [00:05<151:13:33, 1.84it/s, grad_norm=0.828, loss_final=4.43, loss_mean=1.68, loss_mean_cls=2.79, proj_loss=-0.0346][2026-03-23 13:56:17] Step: 6, Training Logs: loss_final: 4.717834, loss_mean: 1.685299, proj_loss: -0.039478, loss_mean_cls: 3.072013, grad_norm: 0.944329
8
+ Steps: 0%| | 7/1000000 [00:05<123:56:37, 2.24it/s, grad_norm=0.944, loss_final=4.72, loss_mean=1.69, loss_mean_cls=3.07, proj_loss=-0.0395][2026-03-23 13:56:17] Step: 7, Training Logs: loss_final: 4.768410, loss_mean: 1.687585, proj_loss: -0.042706, loss_mean_cls: 3.123530, grad_norm: 0.827338
9
+ Steps: 0%| | 8/1000000 [00:06<106:03:40, 2.62it/s, grad_norm=0.827, loss_final=4.77, loss_mean=1.69, loss_mean_cls=3.12, proj_loss=-0.0427][2026-03-23 13:56:17] Step: 8, Training Logs: loss_final: 4.815378, loss_mean: 1.655840, proj_loss: -0.044987, loss_mean_cls: 3.204525, grad_norm: 0.854108
10
+ Steps: 0%| | 9/1000000 [00:06<94:05:27, 2.95it/s, grad_norm=0.854, loss_final=4.82, loss_mean=1.66, loss_mean_cls=3.2, proj_loss=-0.045] [2026-03-23 13:56:18] Step: 9, Training Logs: loss_final: 4.777181, loss_mean: 1.657609, proj_loss: -0.046041, loss_mean_cls: 3.165613, grad_norm: 0.939216
11
+ Steps: 0%| | 10/1000000 [00:06<86:16:44, 3.22it/s, grad_norm=0.939, loss_final=4.78, loss_mean=1.66, loss_mean_cls=3.17, proj_loss=-0.046][2026-03-23 13:56:18] Step: 10, Training Logs: loss_final: 4.966327, loss_mean: 1.662955, proj_loss: -0.047787, loss_mean_cls: 3.351158, grad_norm: 0.991139
12
+ Steps: 0%| | 11/1000000 [00:06<80:53:45, 3.43it/s, grad_norm=0.991, loss_final=4.97, loss_mean=1.66, loss_mean_cls=3.35, proj_loss=-0.0478][2026-03-23 13:56:18] Step: 11, Training Logs: loss_final: 5.292466, loss_mean: 1.650309, proj_loss: -0.049361, loss_mean_cls: 3.691518, grad_norm: 0.968609
13
+ Steps: 0%| | 12/1000000 [00:07<76:53:52, 3.61it/s, grad_norm=0.969, loss_final=5.29, loss_mean=1.65, loss_mean_cls=3.69, proj_loss=-0.0494][2026-03-23 13:56:18] Step: 12, Training Logs: loss_final: 5.088713, loss_mean: 1.626424, proj_loss: -0.049838, loss_mean_cls: 3.512127, grad_norm: 1.162980
14
+ Steps: 0%| | 13/1000000 [00:07<74:06:27, 3.75it/s, grad_norm=1.16, loss_final=5.09, loss_mean=1.63, loss_mean_cls=3.51, proj_loss=-0.0498][2026-03-23 13:56:19] Step: 13, Training Logs: loss_final: 4.928039, loss_mean: 1.623318, proj_loss: -0.051428, loss_mean_cls: 3.356148, grad_norm: 1.370081
15
+ Steps: 0%| | 14/1000000 [00:07<72:19:32, 3.84it/s, grad_norm=1.37, loss_final=4.93, loss_mean=1.62, loss_mean_cls=3.36, proj_loss=-0.0514][2026-03-23 13:56:19] Step: 14, Training Logs: loss_final: 4.342262, loss_mean: 1.599151, proj_loss: -0.051557, loss_mean_cls: 2.794668, grad_norm: 1.247963
16
+ Steps: 0%| | 15/1000000 [00:07<70:56:55, 3.92it/s, grad_norm=1.25, loss_final=4.34, loss_mean=1.6, loss_mean_cls=2.79, proj_loss=-0.0516][2026-03-23 13:56:19] Step: 15, Training Logs: loss_final: 5.220107, loss_mean: 1.589133, proj_loss: -0.054820, loss_mean_cls: 3.685793, grad_norm: 1.207489
17
+ Steps: 0%| | 16/1000000 [00:08<69:57:27, 3.97it/s, grad_norm=1.21, loss_final=5.22, loss_mean=1.59, loss_mean_cls=3.69, proj_loss=-0.0548][2026-03-23 13:56:19] Step: 16, Training Logs: loss_final: 4.653113, loss_mean: 1.599458, proj_loss: -0.052329, loss_mean_cls: 3.105984, grad_norm: 1.168053
18
+ Steps: 0%| | 17/1000000 [00:08<69:18:40, 4.01it/s, grad_norm=1.17, loss_final=4.65, loss_mean=1.6, loss_mean_cls=3.11, proj_loss=-0.0523][2026-03-23 13:56:20] Step: 17, Training Logs: loss_final: 4.984623, loss_mean: 1.639365, proj_loss: -0.051690, loss_mean_cls: 3.396948, grad_norm: 2.396022
19
+ Steps: 0%| | 18/1000000 [00:08<69:00:10, 4.03it/s, grad_norm=2.4, loss_final=4.98, loss_mean=1.64, loss_mean_cls=3.4, proj_loss=-0.0517][2026-03-23 13:56:20] Step: 18, Training Logs: loss_final: 4.356441, loss_mean: 1.560717, proj_loss: -0.054883, loss_mean_cls: 2.850607, grad_norm: 0.931938
20
+ Steps: 0%| | 19/1000000 [00:08<68:37:45, 4.05it/s, grad_norm=0.932, loss_final=4.36, loss_mean=1.56, loss_mean_cls=2.85, proj_loss=-0.0549][2026-03-23 13:56:20] Step: 19, Training Logs: loss_final: 4.922860, loss_mean: 1.625848, proj_loss: -0.055209, loss_mean_cls: 3.352221, grad_norm: 2.649956
21
+ Steps: 0%| | 20/1000000 [00:09<68:31:08, 4.05it/s, grad_norm=2.65, loss_final=4.92, loss_mean=1.63, loss_mean_cls=3.35, proj_loss=-0.0552][2026-03-23 13:56:20] Step: 20, Training Logs: loss_final: 4.693020, loss_mean: 1.635111, proj_loss: -0.053463, loss_mean_cls: 3.111372, grad_norm: 2.886806
22
+ Steps: 0%| | 21/1000000 [00:09<68:17:30, 4.07it/s, grad_norm=2.89, loss_final=4.69, loss_mean=1.64, loss_mean_cls=3.11, proj_loss=-0.0535][2026-03-23 13:56:21] Step: 21, Training Logs: loss_final: 4.847960, loss_mean: 1.608342, proj_loss: -0.053926, loss_mean_cls: 3.293545, grad_norm: 1.804076
23
+ Steps: 0%| | 22/1000000 [00:09<68:07:59, 4.08it/s, grad_norm=1.8, loss_final=4.85, loss_mean=1.61, loss_mean_cls=3.29, proj_loss=-0.0539][2026-03-23 13:56:21] Step: 22, Training Logs: loss_final: 4.531857, loss_mean: 1.556872, proj_loss: -0.053500, loss_mean_cls: 3.028484, grad_norm: 1.136153
24
+ Steps: 0%| | 23/1000000 [00:09<68:01:00, 4.08it/s, grad_norm=1.14, loss_final=4.53, loss_mean=1.56, loss_mean_cls=3.03, proj_loss=-0.0535][2026-03-23 13:56:21] Step: 23, Training Logs: loss_final: 4.347858, loss_mean: 1.571953, proj_loss: -0.051053, loss_mean_cls: 2.826958, grad_norm: 1.379415
25
+ Steps: 0%| | 24/1000000 [00:10<67:55:02, 4.09it/s, grad_norm=1.38, loss_final=4.35, loss_mean=1.57, loss_mean_cls=2.83, proj_loss=-0.0511][2026-03-23 13:56:21] Step: 24, Training Logs: loss_final: 4.812301, loss_mean: 1.573597, proj_loss: -0.054251, loss_mean_cls: 3.292954, grad_norm: 1.536997
26
+ Steps: 0%| | 25/1000000 [00:10<67:50:56, 4.09it/s, grad_norm=1.54, loss_final=4.81, loss_mean=1.57, loss_mean_cls=3.29, proj_loss=-0.0543][2026-03-23 13:56:22] Step: 25, Training Logs: loss_final: 5.140490, loss_mean: 1.572134, proj_loss: -0.055976, loss_mean_cls: 3.624332, grad_norm: 1.501310
27
+ Steps: 0%| | 26/1000000 [00:10<67:49:32, 4.10it/s, grad_norm=1.5, loss_final=5.14, loss_mean=1.57, loss_mean_cls=3.62, proj_loss=-0.056][2026-03-23 13:56:22] Step: 26, Training Logs: loss_final: 4.614564, loss_mean: 1.575298, proj_loss: -0.055522, loss_mean_cls: 3.094787, grad_norm: 1.374379
28
+ Steps: 0%| | 27/1000000 [00:10<67:48:51, 4.10it/s, grad_norm=1.37, loss_final=4.61, loss_mean=1.58, loss_mean_cls=3.09, proj_loss=-0.0555][2026-03-23 13:56:22] Step: 27, Training Logs: loss_final: 4.428196, loss_mean: 1.559373, proj_loss: -0.054907, loss_mean_cls: 2.923730, grad_norm: 1.247505
29
+ Steps: 0%| | 28/1000000 [00:11<67:51:05, 4.09it/s, grad_norm=1.25, loss_final=4.43, loss_mean=1.56, loss_mean_cls=2.92, proj_loss=-0.0549][2026-03-23 13:56:22] Step: 28, Training Logs: loss_final: 4.514633, loss_mean: 1.561718, proj_loss: -0.054182, loss_mean_cls: 3.007097, grad_norm: 1.272959
30
+ Steps: 0%| | 29/1000000 [00:11<67:50:57, 4.09it/s, grad_norm=1.27, loss_final=4.51, loss_mean=1.56, loss_mean_cls=3.01, proj_loss=-0.0542][2026-03-23 13:56:23] Step: 29, Training Logs: loss_final: 4.064789, loss_mean: 1.520746, proj_loss: -0.055004, loss_mean_cls: 2.599047, grad_norm: 1.213601
31
+ Steps: 0%| | 30/1000000 [00:11<67:49:21, 4.10it/s, grad_norm=1.21, loss_final=4.06, loss_mean=1.52, loss_mean_cls=2.6, proj_loss=-0.055][2026-03-23 13:56:23] Step: 30, Training Logs: loss_final: 4.343926, loss_mean: 1.523063, proj_loss: -0.055674, loss_mean_cls: 2.876538, grad_norm: 1.148791
32
+ Steps: 0%| | 31/1000000 [00:11<67:50:28, 4.09it/s, grad_norm=1.15, loss_final=4.34, loss_mean=1.52, loss_mean_cls=2.88, proj_loss=-0.0557][2026-03-23 13:56:23] Step: 31, Training Logs: loss_final: 4.943371, loss_mean: 1.500860, proj_loss: -0.056089, loss_mean_cls: 3.498600, grad_norm: 1.126629
33
+ Steps: 0%| | 32/1000000 [00:12<67:51:49, 4.09it/s, grad_norm=1.13, loss_final=4.94, loss_mean=1.5, loss_mean_cls=3.5, proj_loss=-0.0561][2026-03-23 13:56:23] Step: 32, Training Logs: loss_final: 4.913333, loss_mean: 1.487980, proj_loss: -0.053908, loss_mean_cls: 3.479261, grad_norm: 1.111562
34
+ Steps: 0%| | 33/1000000 [00:12<67:49:26, 4.10it/s, grad_norm=1.11, loss_final=4.91, loss_mean=1.49, loss_mean_cls=3.48, proj_loss=-0.0539][2026-03-23 13:56:24] Step: 33, Training Logs: loss_final: 5.135122, loss_mean: 1.485391, proj_loss: -0.058625, loss_mean_cls: 3.708356, grad_norm: 1.166992
35
+ Steps: 0%| | 34/1000000 [00:12<67:49:31, 4.10it/s, grad_norm=1.17, loss_final=5.14, loss_mean=1.49, loss_mean_cls=3.71, proj_loss=-0.0586][2026-03-23 13:56:24] Step: 34, Training Logs: loss_final: 4.094241, loss_mean: 1.475285, proj_loss: -0.056888, loss_mean_cls: 2.675844, grad_norm: 1.037110
36
+ Steps: 0%| | 35/1000000 [00:12<67:47:41, 4.10it/s, grad_norm=1.04, loss_final=4.09, loss_mean=1.48, loss_mean_cls=2.68, proj_loss=-0.0569][2026-03-23 13:56:24] Step: 35, Training Logs: loss_final: 4.305937, loss_mean: 1.460218, proj_loss: -0.052705, loss_mean_cls: 2.898424, grad_norm: 1.091828
37
+ Steps: 0%| | 36/1000000 [00:13<67:50:00, 4.09it/s, grad_norm=1.09, loss_final=4.31, loss_mean=1.46, loss_mean_cls=2.9, proj_loss=-0.0527][2026-03-23 13:56:24] Step: 36, Training Logs: loss_final: 4.382219, loss_mean: 1.466251, proj_loss: -0.054344, loss_mean_cls: 2.970313, grad_norm: 1.093016
38
+ Steps: 0%| | 37/1000000 [00:13<68:14:38, 4.07it/s, grad_norm=1.09, loss_final=4.38, loss_mean=1.47, loss_mean_cls=2.97, proj_loss=-0.0543][2026-03-23 13:56:25] Step: 37, Training Logs: loss_final: 4.428336, loss_mean: 1.469206, proj_loss: -0.057364, loss_mean_cls: 3.016495, grad_norm: 1.468433
39
+ Steps: 0%| | 38/1000000 [00:13<68:07:38, 4.08it/s, grad_norm=1.47, loss_final=4.43, loss_mean=1.47, loss_mean_cls=3.02, proj_loss=-0.0574][2026-03-23 13:56:25] Step: 38, Training Logs: loss_final: 4.412411, loss_mean: 1.453988, proj_loss: -0.056726, loss_mean_cls: 3.015149, grad_norm: 0.825492
40
+ Steps: 0%| | 39/1000000 [00:13<68:01:05, 4.08it/s, grad_norm=0.825, loss_final=4.41, loss_mean=1.45, loss_mean_cls=3.02, proj_loss=-0.0567][2026-03-23 13:56:25] Step: 39, Training Logs: loss_final: 4.554020, loss_mean: 1.450588, proj_loss: -0.055717, loss_mean_cls: 3.159149, grad_norm: 1.281297
41
+ Steps: 0%| | 40/1000000 [00:14<68:02:46, 4.08it/s, grad_norm=1.28, loss_final=4.55, loss_mean=1.45, loss_mean_cls=3.16, proj_loss=-0.0557][2026-03-23 13:56:25] Step: 40, Training Logs: loss_final: 4.900630, loss_mean: 1.421538, proj_loss: -0.056336, loss_mean_cls: 3.535427, grad_norm: 1.096009
42
+ Steps: 0%| | 41/1000000 [00:14<67:56:56, 4.09it/s, grad_norm=1.1, loss_final=4.9, loss_mean=1.42, loss_mean_cls=3.54, proj_loss=-0.0563][2026-03-23 13:56:26] Step: 41, Training Logs: loss_final: 5.000130, loss_mean: 1.418576, proj_loss: -0.056068, loss_mean_cls: 3.637622, grad_norm: 1.111790
43
+ Steps: 0%| | 42/1000000 [00:14<67:53:46, 4.09it/s, grad_norm=1.11, loss_final=5, loss_mean=1.42, loss_mean_cls=3.64, proj_loss=-0.0561][2026-03-23 13:56:26] Step: 42, Training Logs: loss_final: 4.223974, loss_mean: 1.442425, proj_loss: -0.054436, loss_mean_cls: 2.835985, grad_norm: 1.223004
44
+ Steps: 0%| | 43/1000000 [00:14<67:51:37, 4.09it/s, grad_norm=1.22, loss_final=4.22, loss_mean=1.44, loss_mean_cls=2.84, proj_loss=-0.0544][2026-03-23 13:56:26] Step: 43, Training Logs: loss_final: 5.023710, loss_mean: 1.441718, proj_loss: -0.055548, loss_mean_cls: 3.637539, grad_norm: 1.391159
45
+ Steps: 0%| | 44/1000000 [00:15<67:53:09, 4.09it/s, grad_norm=1.39, loss_final=5.02, loss_mean=1.44, loss_mean_cls=3.64, proj_loss=-0.0555][2026-03-23 13:56:26] Step: 44, Training Logs: loss_final: 5.000957, loss_mean: 1.414230, proj_loss: -0.055362, loss_mean_cls: 3.642089, grad_norm: 1.162115
46
+ Steps: 0%| | 45/1000000 [00:15<67:51:51, 4.09it/s, grad_norm=1.16, loss_final=5, loss_mean=1.41, loss_mean_cls=3.64, proj_loss=-0.0554][2026-03-23 13:56:27] Step: 45, Training Logs: loss_final: 4.689414, loss_mean: 1.389372, proj_loss: -0.054793, loss_mean_cls: 3.354835, grad_norm: 0.814496
47
+ Steps: 0%| | 46/1000000 [00:15<67:51:06, 4.09it/s, grad_norm=0.814, loss_final=4.69, loss_mean=1.39, loss_mean_cls=3.35, proj_loss=-0.0548][2026-03-23 13:56:27] Step: 46, Training Logs: loss_final: 4.452005, loss_mean: 1.403370, proj_loss: -0.056346, loss_mean_cls: 3.104981, grad_norm: 1.062373
48
+ Steps: 0%| | 47/1000000 [00:15<67:50:59, 4.09it/s, grad_norm=1.06, loss_final=4.45, loss_mean=1.4, loss_mean_cls=3.1, proj_loss=-0.0563][2026-03-23 13:56:27] Step: 47, Training Logs: loss_final: 4.638161, loss_mean: 1.413495, proj_loss: -0.055944, loss_mean_cls: 3.280609, grad_norm: 0.873316
49
+ Steps: 0%| | 48/1000000 [00:15<67:51:08, 4.09it/s, grad_norm=0.873, loss_final=4.64, loss_mean=1.41, loss_mean_cls=3.28, proj_loss=-0.0559][2026-03-23 13:56:27] Step: 48, Training Logs: loss_final: 4.429680, loss_mean: 1.411911, proj_loss: -0.054931, loss_mean_cls: 3.072701, grad_norm: 0.759066
50
+ Steps: 0%| | 49/1000000 [00:16<67:51:48, 4.09it/s, grad_norm=0.759, loss_final=4.43, loss_mean=1.41, loss_mean_cls=3.07, proj_loss=-0.0549][2026-03-23 13:56:27] Step: 49, Training Logs: loss_final: 4.901843, loss_mean: 1.388557, proj_loss: -0.055650, loss_mean_cls: 3.568936, grad_norm: 0.750354
51
+ Steps: 0%| | 50/1000000 [00:16<67:49:39, 4.10it/s, grad_norm=0.75, loss_final=4.9, loss_mean=1.39, loss_mean_cls=3.57, proj_loss=-0.0557][2026-03-23 13:56:28] Step: 50, Training Logs: loss_final: 4.100002, loss_mean: 1.412004, proj_loss: -0.052795, loss_mean_cls: 2.740793, grad_norm: 0.767011
52
+ Steps: 0%| | 51/1000000 [00:16<67:51:25, 4.09it/s, grad_norm=0.767, loss_final=4.1, loss_mean=1.41, loss_mean_cls=2.74, proj_loss=-0.0528][2026-03-23 13:56:28] Step: 51, Training Logs: loss_final: 4.759223, loss_mean: 1.391587, proj_loss: -0.054960, loss_mean_cls: 3.422596, grad_norm: 0.719157
53
+ Steps: 0%| | 52/1000000 [00:16<67:51:11, 4.09it/s, grad_norm=0.719, loss_final=4.76, loss_mean=1.39, loss_mean_cls=3.42, proj_loss=-0.055][2026-03-23 13:56:28] Step: 52, Training Logs: loss_final: 4.482382, loss_mean: 1.395736, proj_loss: -0.056075, loss_mean_cls: 3.142721, grad_norm: 0.726095
54
+ Steps: 0%| | 53/1000000 [00:17<67:50:55, 4.09it/s, grad_norm=0.726, loss_final=4.48, loss_mean=1.4, loss_mean_cls=3.14, proj_loss=-0.0561][2026-03-23 13:56:28] Step: 53, Training Logs: loss_final: 4.466825, loss_mean: 1.373211, proj_loss: -0.055033, loss_mean_cls: 3.148648, grad_norm: 0.663936
55
+ Steps: 0%| | 54/1000000 [00:17<67:49:43, 4.10it/s, grad_norm=0.664, loss_final=4.47, loss_mean=1.37, loss_mean_cls=3.15, proj_loss=-0.055][2026-03-23 13:56:29] Step: 54, Training Logs: loss_final: 4.687413, loss_mean: 1.367031, proj_loss: -0.056856, loss_mean_cls: 3.377239, grad_norm: 0.706146
56
+ Steps: 0%| | 55/1000000 [00:17<67:47:50, 4.10it/s, grad_norm=0.706, loss_final=4.69, loss_mean=1.37, loss_mean_cls=3.38, proj_loss=-0.0569][2026-03-23 13:56:29] Step: 55, Training Logs: loss_final: 4.724741, loss_mean: 1.355043, proj_loss: -0.056556, loss_mean_cls: 3.426254, grad_norm: 0.891681
57
+ Steps: 0%| | 56/1000000 [00:17<67:49:45, 4.10it/s, grad_norm=0.892, loss_final=4.72, loss_mean=1.36, loss_mean_cls=3.43, proj_loss=-0.0566][2026-03-23 13:56:29] Step: 56, Training Logs: loss_final: 4.598913, loss_mean: 1.357820, proj_loss: -0.053707, loss_mean_cls: 3.294799, grad_norm: 0.688579
58
+ Steps: 0%| | 57/1000000 [00:18<67:48:24, 4.10it/s, grad_norm=0.689, loss_final=4.6, loss_mean=1.36, loss_mean_cls=3.29, proj_loss=-0.0537][2026-03-23 13:56:29] Step: 57, Training Logs: loss_final: 4.647234, loss_mean: 1.346388, proj_loss: -0.055346, loss_mean_cls: 3.356192, grad_norm: 0.643795
59
+ Steps: 0%| | 58/1000000 [00:18<67:48:01, 4.10it/s, grad_norm=0.644, loss_final=4.65, loss_mean=1.35, loss_mean_cls=3.36, proj_loss=-0.0553][2026-03-23 13:56:30] Step: 58, Training Logs: loss_final: 4.629024, loss_mean: 1.361648, proj_loss: -0.056354, loss_mean_cls: 3.323730, grad_norm: 0.640788
60
+ Steps: 0%| | 59/1000000 [00:18<67:48:29, 4.10it/s, grad_norm=0.641, loss_final=4.63, loss_mean=1.36, loss_mean_cls=3.32, proj_loss=-0.0564][2026-03-23 13:56:30] Step: 59, Training Logs: loss_final: 4.318052, loss_mean: 1.352944, proj_loss: -0.054974, loss_mean_cls: 3.020082, grad_norm: 1.103365
61
+ Steps: 0%| | 60/1000000 [00:18<67:52:13, 4.09it/s, grad_norm=1.1, loss_final=4.32, loss_mean=1.35, loss_mean_cls=3.02, proj_loss=-0.055][2026-03-23 13:56:30] Step: 60, Training Logs: loss_final: 4.558999, loss_mean: 1.356349, proj_loss: -0.056304, loss_mean_cls: 3.258954, grad_norm: 2.178408
62
+ Steps: 0%| | 61/1000000 [00:19<67:50:06, 4.09it/s, grad_norm=2.18, loss_final=4.56, loss_mean=1.36, loss_mean_cls=3.26, proj_loss=-0.0563][2026-03-23 13:56:30] Step: 61, Training Logs: loss_final: 4.405947, loss_mean: 1.345089, proj_loss: -0.054088, loss_mean_cls: 3.114946, grad_norm: 0.833812
63
+ Steps: 0%| | 62/1000000 [00:19<67:49:48, 4.09it/s, grad_norm=0.834, loss_final=4.41, loss_mean=1.35, loss_mean_cls=3.11, proj_loss=-0.0541][2026-03-23 13:56:31] Step: 62, Training Logs: loss_final: 4.248630, loss_mean: 1.369897, proj_loss: -0.055107, loss_mean_cls: 2.933840, grad_norm: 2.604588
64
+ Steps: 0%| | 63/1000000 [00:19<67:49:04, 4.10it/s, grad_norm=2.6, loss_final=4.25, loss_mean=1.37, loss_mean_cls=2.93, proj_loss=-0.0551][2026-03-23 13:56:31] Step: 63, Training Logs: loss_final: 4.370800, loss_mean: 1.356826, proj_loss: -0.054216, loss_mean_cls: 3.068191, grad_norm: 1.995899
65
+ Steps: 0%| | 64/1000000 [00:19<67:52:12, 4.09it/s, grad_norm=2, loss_final=4.37, loss_mean=1.36, loss_mean_cls=3.07, proj_loss=-0.0542][2026-03-23 13:56:31] Step: 64, Training Logs: loss_final: 3.620857, loss_mean: 1.371637, proj_loss: -0.056188, loss_mean_cls: 2.305408, grad_norm: 1.270011
66
+ Steps: 0%| | 65/1000000 [00:20<67:52:26, 4.09it/s, grad_norm=1.27, loss_final=3.62, loss_mean=1.37, loss_mean_cls=2.31, proj_loss=-0.0562][2026-03-23 13:56:31] Step: 65, Training Logs: loss_final: 3.841930, loss_mean: 1.345907, proj_loss: -0.057290, loss_mean_cls: 2.553313, grad_norm: 1.103519
67
+ Steps: 0%| | 66/1000000 [00:20<67:50:32, 4.09it/s, grad_norm=1.1, loss_final=3.84, loss_mean=1.35, loss_mean_cls=2.55, proj_loss=-0.0573][2026-03-23 13:56:32] Step: 66, Training Logs: loss_final: 4.823415, loss_mean: 1.309009, proj_loss: -0.057261, loss_mean_cls: 3.571667, grad_norm: 1.146726
68
+ Steps: 0%| | 67/1000000 [00:20<67:51:11, 4.09it/s, grad_norm=1.15, loss_final=4.82, loss_mean=1.31, loss_mean_cls=3.57, proj_loss=-0.0573][2026-03-23 13:56:32] Step: 67, Training Logs: loss_final: 4.654735, loss_mean: 1.315030, proj_loss: -0.055812, loss_mean_cls: 3.395517, grad_norm: 0.900543
69
+ Steps: 0%| | 68/1000000 [00:20<67:53:43, 4.09it/s, grad_norm=0.901, loss_final=4.65, loss_mean=1.32, loss_mean_cls=3.4, proj_loss=-0.0558][2026-03-23 13:56:32] Step: 68, Training Logs: loss_final: 3.865313, loss_mean: 1.311920, proj_loss: -0.055171, loss_mean_cls: 2.608564, grad_norm: 0.968363
70
+ Steps: 0%| | 69/1000000 [00:21<67:51:01, 4.09it/s, grad_norm=0.968, loss_final=3.87, loss_mean=1.31, loss_mean_cls=2.61, proj_loss=-0.0552][2026-03-23 13:56:32] Step: 69, Training Logs: loss_final: 4.332185, loss_mean: 1.287292, proj_loss: -0.055524, loss_mean_cls: 3.100417, grad_norm: 0.992981
71
+ Steps: 0%| | 70/1000000 [00:21<67:50:34, 4.09it/s, grad_norm=0.993, loss_final=4.33, loss_mean=1.29, loss_mean_cls=3.1, proj_loss=-0.0555][2026-03-23 13:56:33] Step: 70, Training Logs: loss_final: 5.214293, loss_mean: 1.271327, proj_loss: -0.056000, loss_mean_cls: 3.998966, grad_norm: 0.950980
72
+ Steps: 0%| | 71/1000000 [00:21<67:51:17, 4.09it/s, grad_norm=0.951, loss_final=5.21, loss_mean=1.27, loss_mean_cls=4, proj_loss=-0.056][2026-03-23 13:56:33] Step: 71, Training Logs: loss_final: 4.271416, loss_mean: 1.298077, proj_loss: -0.058767, loss_mean_cls: 3.032106, grad_norm: 0.897326
73
+ Steps: 0%| | 72/1000000 [00:21<67:53:03, 4.09it/s, grad_norm=0.897, loss_final=4.27, loss_mean=1.3, loss_mean_cls=3.03, proj_loss=-0.0588][2026-03-23 13:56:33] Step: 72, Training Logs: loss_final: 4.944882, loss_mean: 1.275258, proj_loss: -0.054986, loss_mean_cls: 3.724611, grad_norm: 1.406698
74
+ Steps: 0%| | 73/1000000 [00:22<67:51:07, 4.09it/s, grad_norm=1.41, loss_final=4.94, loss_mean=1.28, loss_mean_cls=3.72, proj_loss=-0.055][2026-03-23 13:56:33] Step: 73, Training Logs: loss_final: 4.262454, loss_mean: 1.391114, proj_loss: -0.054645, loss_mean_cls: 2.925985, grad_norm: 4.082608
75
+ Steps: 0%| | 74/1000000 [00:22<67:49:45, 4.09it/s, grad_norm=4.08, loss_final=4.26, loss_mean=1.39, loss_mean_cls=2.93, proj_loss=-0.0546][2026-03-23 13:56:34] Step: 74, Training Logs: loss_final: 4.441927, loss_mean: 1.368322, proj_loss: -0.056059, loss_mean_cls: 3.129665, grad_norm: 3.740271
76
+ Steps: 0%| | 75/1000000 [00:22<67:48:41, 4.10it/s, grad_norm=3.74, loss_final=4.44, loss_mean=1.37, loss_mean_cls=3.13, proj_loss=-0.0561][2026-03-23 13:56:34] Step: 75, Training Logs: loss_final: 4.786633, loss_mean: 1.323001, proj_loss: -0.054923, loss_mean_cls: 3.518555, grad_norm: 1.360202
77
+ Steps: 0%| | 76/1000000 [00:22<67:47:44, 4.10it/s, grad_norm=1.36, loss_final=4.79, loss_mean=1.32, loss_mean_cls=3.52, proj_loss=-0.0549][2026-03-23 13:56:34] Step: 76, Training Logs: loss_final: 4.880804, loss_mean: 1.312154, proj_loss: -0.058812, loss_mean_cls: 3.627462, grad_norm: 2.462873
78
+ Steps: 0%| | 77/1000000 [00:23<67:48:38, 4.10it/s, grad_norm=2.46, loss_final=4.88, loss_mean=1.31, loss_mean_cls=3.63, proj_loss=-0.0588][2026-03-23 13:56:34] Step: 77, Training Logs: loss_final: 4.347323, loss_mean: 1.277282, proj_loss: -0.056339, loss_mean_cls: 3.126381, grad_norm: 1.089468
79
+ Steps: 0%| | 78/1000000 [00:23<67:47:33, 4.10it/s, grad_norm=1.09, loss_final=4.35, loss_mean=1.28, loss_mean_cls=3.13, proj_loss=-0.0563][2026-03-23 13:56:35] Step: 78, Training Logs: loss_final: 4.452549, loss_mean: 1.308045, proj_loss: -0.057751, loss_mean_cls: 3.202255, grad_norm: 1.409237
80
+ Steps: 0%| | 79/1000000 [00:23<67:48:01, 4.10it/s, grad_norm=1.41, loss_final=4.45, loss_mean=1.31, loss_mean_cls=3.2, proj_loss=-0.0578][2026-03-23 13:56:35] Step: 79, Training Logs: loss_final: 4.300065, loss_mean: 1.256680, proj_loss: -0.057729, loss_mean_cls: 3.101114, grad_norm: 1.063736
81
+ Steps: 0%| | 80/1000000 [00:23<67:46:24, 4.10it/s, grad_norm=1.06, loss_final=4.3, loss_mean=1.26, loss_mean_cls=3.1, proj_loss=-0.0577][2026-03-23 13:56:35] Step: 80, Training Logs: loss_final: 5.255699, loss_mean: 1.230699, proj_loss: -0.055231, loss_mean_cls: 4.080230, grad_norm: 1.159766
82
+ Steps: 0%| | 81/1000000 [00:24<67:45:22, 4.10it/s, grad_norm=1.16, loss_final=5.26, loss_mean=1.23, loss_mean_cls=4.08, proj_loss=-0.0552][2026-03-23 13:56:35] Step: 81, Training Logs: loss_final: 4.184104, loss_mean: 1.256977, proj_loss: -0.057558, loss_mean_cls: 2.984685, grad_norm: 1.074197
83
+ Steps: 0%| | 82/1000000 [00:24<67:46:48, 4.10it/s, grad_norm=1.07, loss_final=4.18, loss_mean=1.26, loss_mean_cls=2.98, proj_loss=-0.0576][2026-03-23 13:56:36] Step: 82, Training Logs: loss_final: 4.368578, loss_mean: 1.246449, proj_loss: -0.055964, loss_mean_cls: 3.178093, grad_norm: 1.212082
84
+ Steps: 0%| | 83/1000000 [00:24<67:46:42, 4.10it/s, grad_norm=1.21, loss_final=4.37, loss_mean=1.25, loss_mean_cls=3.18, proj_loss=-0.056][2026-03-23 13:56:36] Step: 83, Training Logs: loss_final: 4.066692, loss_mean: 1.279943, proj_loss: -0.057100, loss_mean_cls: 2.843850, grad_norm: 1.081535
85
+ Steps: 0%| | 84/1000000 [00:24<67:45:25, 4.10it/s, grad_norm=1.08, loss_final=4.07, loss_mean=1.28, loss_mean_cls=2.84, proj_loss=-0.0571][2026-03-23 13:56:36] Step: 84, Training Logs: loss_final: 4.419612, loss_mean: 1.222443, proj_loss: -0.056123, loss_mean_cls: 3.253291, grad_norm: 1.017672
86
+ Steps: 0%| | 85/1000000 [00:25<67:46:48, 4.10it/s, grad_norm=1.02, loss_final=4.42, loss_mean=1.22, loss_mean_cls=3.25, proj_loss=-0.0561][2026-03-23 13:56:36] Step: 85, Training Logs: loss_final: 4.330603, loss_mean: 1.217940, proj_loss: -0.055811, loss_mean_cls: 3.168474, grad_norm: 1.048292
87
+ Steps: 0%| | 86/1000000 [00:25<67:46:18, 4.10it/s, grad_norm=1.05, loss_final=4.33, loss_mean=1.22, loss_mean_cls=3.17, proj_loss=-0.0558][2026-03-23 13:56:37] Step: 86, Training Logs: loss_final: 4.073160, loss_mean: 1.261520, proj_loss: -0.057447, loss_mean_cls: 2.869087, grad_norm: 2.760909
88
+ Steps: 0%| | 87/1000000 [00:25<67:46:56, 4.10it/s, grad_norm=2.76, loss_final=4.07, loss_mean=1.26, loss_mean_cls=2.87, proj_loss=-0.0574][2026-03-23 13:56:37] Step: 87, Training Logs: loss_final: 4.349278, loss_mean: 1.225316, proj_loss: -0.055978, loss_mean_cls: 3.179941, grad_norm: 1.415633
89
+ Steps: 0%| | 88/1000000 [00:25<67:48:19, 4.10it/s, grad_norm=1.42, loss_final=4.35, loss_mean=1.23, loss_mean_cls=3.18, proj_loss=-0.056][2026-03-23 13:56:37] Step: 88, Training Logs: loss_final: 4.191838, loss_mean: 1.244159, proj_loss: -0.053158, loss_mean_cls: 3.000837, grad_norm: 2.142123
90
+ Steps: 0%| | 89/1000000 [00:26<67:48:47, 4.10it/s, grad_norm=2.14, loss_final=4.19, loss_mean=1.24, loss_mean_cls=3, proj_loss=-0.0532][2026-03-23 13:56:37] Step: 89, Training Logs: loss_final: 3.810838, loss_mean: 1.241034, proj_loss: -0.056567, loss_mean_cls: 2.626369, grad_norm: 1.977536
91
+ Steps: 0%| | 90/1000000 [00:26<67:49:26, 4.10it/s, grad_norm=1.98, loss_final=3.81, loss_mean=1.24, loss_mean_cls=2.63, proj_loss=-0.0566][2026-03-23 13:56:37] Step: 90, Training Logs: loss_final: 4.302361, loss_mean: 1.235736, proj_loss: -0.055757, loss_mean_cls: 3.122383, grad_norm: 1.465364
92
+ Steps: 0%| | 91/1000000 [00:26<67:49:17, 4.10it/s, grad_norm=1.47, loss_final=4.3, loss_mean=1.24, loss_mean_cls=3.12, proj_loss=-0.0558][2026-03-23 13:56:38] Step: 91, Training Logs: loss_final: 4.793566, loss_mean: 1.295077, proj_loss: -0.054678, loss_mean_cls: 3.553167, grad_norm: 4.777657
93
+ Steps: 0%| | 92/1000000 [00:26<67:50:23, 4.09it/s, grad_norm=4.78, loss_final=4.79, loss_mean=1.3, loss_mean_cls=3.55, proj_loss=-0.0547][2026-03-23 13:56:38] Step: 92, Training Logs: loss_final: 4.318158, loss_mean: 1.268673, proj_loss: -0.056761, loss_mean_cls: 3.106246, grad_norm: 3.593853
94
+ Steps: 0%| | 93/1000000 [00:26<67:49:16, 4.10it/s, grad_norm=3.59, loss_final=4.32, loss_mean=1.27, loss_mean_cls=3.11, proj_loss=-0.0568][2026-03-23 13:56:38] Step: 93, Training Logs: loss_final: 4.038731, loss_mean: 1.237636, proj_loss: -0.058346, loss_mean_cls: 2.859441, grad_norm: 2.444315
95
+ Steps: 0%| | 94/1000000 [00:27<67:49:30, 4.10it/s, grad_norm=2.44, loss_final=4.04, loss_mean=1.24, loss_mean_cls=2.86, proj_loss=-0.0583][2026-03-23 13:56:38] Step: 94, Training Logs: loss_final: 5.108729, loss_mean: 1.212993, proj_loss: -0.058394, loss_mean_cls: 3.954130, grad_norm: 1.895388
96
+ Steps: 0%| | 95/1000000 [00:27<67:48:34, 4.10it/s, grad_norm=1.9, loss_final=5.11, loss_mean=1.21, loss_mean_cls=3.95, proj_loss=-0.0584][2026-03-23 13:56:39] Step: 95, Training Logs: loss_final: 4.209163, loss_mean: 1.207963, proj_loss: -0.055074, loss_mean_cls: 3.056275, grad_norm: 1.442985
97
+ Steps: 0%| | 96/1000000 [00:27<67:50:56, 4.09it/s, grad_norm=1.44, loss_final=4.21, loss_mean=1.21, loss_mean_cls=3.06, proj_loss=-0.0551][2026-03-23 13:56:39] Step: 96, Training Logs: loss_final: 4.601864, loss_mean: 1.191701, proj_loss: -0.057576, loss_mean_cls: 3.467739, grad_norm: 1.299760
98
+ Steps: 0%| | 97/1000000 [00:27<67:51:29, 4.09it/s, grad_norm=1.3, loss_final=4.6, loss_mean=1.19, loss_mean_cls=3.47, proj_loss=-0.0576][2026-03-23 13:56:39] Step: 97, Training Logs: loss_final: 4.614489, loss_mean: 1.185281, proj_loss: -0.056098, loss_mean_cls: 3.485306, grad_norm: 1.128512
99
+ Steps: 0%| | 98/1000000 [00:28<67:50:38, 4.09it/s, grad_norm=1.13, loss_final=4.61, loss_mean=1.19, loss_mean_cls=3.49, proj_loss=-0.0561][2026-03-23 13:56:39] Step: 98, Training Logs: loss_final: 4.294365, loss_mean: 1.196555, proj_loss: -0.058301, loss_mean_cls: 3.156111, grad_norm: 1.399289
100
+ Steps: 0%| | 99/1000000 [00:28<67:49:41, 4.09it/s, grad_norm=1.4, loss_final=4.29, loss_mean=1.2, loss_mean_cls=3.16, proj_loss=-0.0583][2026-03-23 13:56:40] Step: 99, Training Logs: loss_final: 4.215539, loss_mean: 1.199367, proj_loss: -0.057493, loss_mean_cls: 3.073665, grad_norm: 1.343760
101
+ Steps: 0%| | 100/1000000 [00:28<67:51:40, 4.09it/s, grad_norm=1.34, loss_final=4.22, loss_mean=1.2, loss_mean_cls=3.07, proj_loss=-0.0575][2026-03-23 13:56:40] Step: 100, Training Logs: loss_final: 4.488984, loss_mean: 1.149765, proj_loss: -0.056815, loss_mean_cls: 3.396034, grad_norm: 1.045582
102
+ Steps: 0%| | 101/1000000 [00:28<67:50:50, 4.09it/s, grad_norm=1.05, loss_final=4.49, loss_mean=1.15, loss_mean_cls=3.4, proj_loss=-0.0568][2026-03-23 13:56:40] Step: 101, Training Logs: loss_final: 4.368436, loss_mean: 1.178788, proj_loss: -0.058116, loss_mean_cls: 3.247765, grad_norm: 1.342256
103
+ Steps: 0%| | 102/1000000 [00:29<67:50:14, 4.09it/s, grad_norm=1.34, loss_final=4.37, loss_mean=1.18, loss_mean_cls=3.25, proj_loss=-0.0581][2026-03-23 13:56:40] Step: 102, Training Logs: loss_final: 4.096940, loss_mean: 1.195665, proj_loss: -0.053437, loss_mean_cls: 2.954712, grad_norm: 1.167338
104
+ Steps: 0%| | 103/1000000 [00:29<67:52:23, 4.09it/s, grad_norm=1.17, loss_final=4.1, loss_mean=1.2, loss_mean_cls=2.95, proj_loss=-0.0534][2026-03-23 13:56:41] Step: 103, Training Logs: loss_final: 4.406911, loss_mean: 1.123473, proj_loss: -0.055880, loss_mean_cls: 3.339318, grad_norm: 0.934016
105
+ Steps: 0%| | 104/1000000 [00:29<67:52:00, 4.09it/s, grad_norm=0.934, loss_final=4.41, loss_mean=1.12, loss_mean_cls=3.34, proj_loss=-0.0559][2026-03-23 13:56:41] Step: 104, Training Logs: loss_final: 3.831389, loss_mean: 1.127624, proj_loss: -0.056442, loss_mean_cls: 2.760207, grad_norm: 1.079511
106
+ Steps: 0%| | 105/1000000 [00:29<67:50:29, 4.09it/s, grad_norm=1.08, loss_final=3.83, loss_mean=1.13, loss_mean_cls=2.76, proj_loss=-0.0564][2026-03-23 13:56:41] Step: 105, Training Logs: loss_final: 3.949926, loss_mean: 1.169529, proj_loss: -0.056276, loss_mean_cls: 2.836674, grad_norm: 0.823243
107
+ Steps: 0%| | 106/1000000 [00:30<67:49:55, 4.09it/s, grad_norm=0.823, loss_final=3.95, loss_mean=1.17, loss_mean_cls=2.84, proj_loss=-0.0563][2026-03-23 13:56:41] Step: 106, Training Logs: loss_final: 4.447761, loss_mean: 1.131930, proj_loss: -0.057379, loss_mean_cls: 3.373209, grad_norm: 0.780265
108
+ Steps: 0%| | 107/1000000 [00:30<67:50:26, 4.09it/s, grad_norm=0.78, loss_final=4.45, loss_mean=1.13, loss_mean_cls=3.37, proj_loss=-0.0574][2026-03-23 13:56:42] Step: 107, Training Logs: loss_final: 4.185682, loss_mean: 1.137062, proj_loss: -0.056792, loss_mean_cls: 3.105412, grad_norm: 0.896073
109
+ Steps: 0%| | 108/1000000 [00:30<67:51:47, 4.09it/s, grad_norm=0.896, loss_final=4.19, loss_mean=1.14, loss_mean_cls=3.11, proj_loss=-0.0568][2026-03-23 13:56:42] Step: 108, Training Logs: loss_final: 4.381330, loss_mean: 1.127375, proj_loss: -0.057231, loss_mean_cls: 3.311185, grad_norm: 0.791624
110
+ Steps: 0%| | 109/1000000 [00:30<67:50:27, 4.09it/s, grad_norm=0.792, loss_final=4.38, loss_mean=1.13, loss_mean_cls=3.31, proj_loss=-0.0572][2026-03-23 13:56:42] Step: 109, Training Logs: loss_final: 4.101658, loss_mean: 1.133106, proj_loss: -0.058366, loss_mean_cls: 3.026918, grad_norm: 1.389539
111
+ Steps: 0%| | 110/1000000 [00:31<67:50:56, 4.09it/s, grad_norm=1.39, loss_final=4.1, loss_mean=1.13, loss_mean_cls=3.03, proj_loss=-0.0584][2026-03-23 13:56:42] Step: 110, Training Logs: loss_final: 4.222694, loss_mean: 1.125139, proj_loss: -0.056927, loss_mean_cls: 3.154482, grad_norm: 1.616820
112
+ Steps: 0%| | 111/1000000 [00:31<67:52:18, 4.09it/s, grad_norm=1.62, loss_final=4.22, loss_mean=1.13, loss_mean_cls=3.15, proj_loss=-0.0569][2026-03-23 13:56:43] Step: 111, Training Logs: loss_final: 3.836849, loss_mean: 1.123032, proj_loss: -0.057576, loss_mean_cls: 2.771394, grad_norm: 1.516640
113
+ Steps: 0%| | 112/1000000 [00:31<67:50:27, 4.09it/s, grad_norm=1.52, loss_final=3.84, loss_mean=1.12, loss_mean_cls=2.77, proj_loss=-0.0576][2026-03-23 13:56:43] Step: 112, Training Logs: loss_final: 4.385750, loss_mean: 1.170239, proj_loss: -0.057827, loss_mean_cls: 3.273338, grad_norm: 3.234420
114
+ Steps: 0%| | 113/1000000 [00:31<67:50:21, 4.09it/s, grad_norm=3.23, loss_final=4.39, loss_mean=1.17, loss_mean_cls=3.27, proj_loss=-0.0578][2026-03-23 13:56:43] Step: 113, Training Logs: loss_final: 4.791656, loss_mean: 1.139730, proj_loss: -0.057919, loss_mean_cls: 3.709845, grad_norm: 2.920728
115
+ Steps: 0%| | 114/1000000 [00:32<67:52:04, 4.09it/s, grad_norm=2.92, loss_final=4.79, loss_mean=1.14, loss_mean_cls=3.71, proj_loss=-0.0579][2026-03-23 13:56:43] Step: 114, Training Logs: loss_final: 4.198164, loss_mean: 1.112274, proj_loss: -0.054186, loss_mean_cls: 3.140076, grad_norm: 1.197350
116
+ Steps: 0%| | 115/1000000 [00:32<67:52:07, 4.09it/s, grad_norm=1.2, loss_final=4.2, loss_mean=1.11, loss_mean_cls=3.14, proj_loss=-0.0542][2026-03-23 13:56:44] Step: 115, Training Logs: loss_final: 4.495673, loss_mean: 1.149152, proj_loss: -0.056316, loss_mean_cls: 3.402837, grad_norm: 1.762717
117
+ Steps: 0%| | 116/1000000 [00:32<67:52:00, 4.09it/s, grad_norm=1.76, loss_final=4.5, loss_mean=1.15, loss_mean_cls=3.4, proj_loss=-0.0563][2026-03-23 13:56:44] Step: 116, Training Logs: loss_final: 3.863493, loss_mean: 1.173006, proj_loss: -0.057397, loss_mean_cls: 2.747884, grad_norm: 2.976036
118
+ Steps: 0%| | 117/1000000 [00:32<67:52:53, 4.09it/s, grad_norm=2.98, loss_final=3.86, loss_mean=1.17, loss_mean_cls=2.75, proj_loss=-0.0574][2026-03-23 13:56:44] Step: 117, Training Logs: loss_final: 4.637027, loss_mean: 1.119735, proj_loss: -0.056042, loss_mean_cls: 3.573334, grad_norm: 2.094532
119
+ Steps: 0%| | 118/1000000 [00:33<67:52:15, 4.09it/s, grad_norm=2.09, loss_final=4.64, loss_mean=1.12, loss_mean_cls=3.57, proj_loss=-0.056][2026-03-23 13:56:44] Step: 118, Training Logs: loss_final: 3.896062, loss_mean: 1.133374, proj_loss: -0.057744, loss_mean_cls: 2.820432, grad_norm: 2.091462
120
+ Steps: 0%| | 119/1000000 [00:33<67:52:19, 4.09it/s, grad_norm=2.09, loss_final=3.9, loss_mean=1.13, loss_mean_cls=2.82, proj_loss=-0.0577][2026-03-23 13:56:45] Step: 119, Training Logs: loss_final: 4.494811, loss_mean: 1.127229, proj_loss: -0.056165, loss_mean_cls: 3.423747, grad_norm: 2.573430
121
+ Steps: 0%| | 120/1000000 [00:33<67:52:07, 4.09it/s, grad_norm=2.57, loss_final=4.49, loss_mean=1.13, loss_mean_cls=3.42, proj_loss=-0.0562][2026-03-23 13:56:45] Step: 120, Training Logs: loss_final: 4.395417, loss_mean: 1.094057, proj_loss: -0.057758, loss_mean_cls: 3.359118, grad_norm: 1.652974
122
+ Steps: 0%| | 121/1000000 [00:33<67:49:02, 4.10it/s, grad_norm=1.65, loss_final=4.4, loss_mean=1.09, loss_mean_cls=3.36, proj_loss=-0.0578][2026-03-23 13:56:45] Step: 121, Training Logs: loss_final: 3.780295, loss_mean: 1.143866, proj_loss: -0.057413, loss_mean_cls: 2.693842, grad_norm: 2.330424
123
+ Steps: 0%| | 122/1000000 [00:34<67:51:29, 4.09it/s, grad_norm=2.33, loss_final=3.78, loss_mean=1.14, loss_mean_cls=2.69, proj_loss=-0.0574][2026-03-23 13:56:45] Step: 122, Training Logs: loss_final: 5.011144, loss_mean: 1.122681, proj_loss: -0.058837, loss_mean_cls: 3.947300, grad_norm: 1.509129
124
+ Steps: 0%| | 123/1000000 [00:34<67:51:03, 4.09it/s, grad_norm=1.51, loss_final=5.01, loss_mean=1.12, loss_mean_cls=3.95, proj_loss=-0.0588][2026-03-23 13:56:46] Step: 123, Training Logs: loss_final: 4.172338, loss_mean: 1.115492, proj_loss: -0.056144, loss_mean_cls: 3.112991, grad_norm: 1.528705
125
+ Steps: 0%| | 124/1000000 [00:34<67:50:03, 4.09it/s, grad_norm=1.53, loss_final=4.17, loss_mean=1.12, loss_mean_cls=3.11, proj_loss=-0.0561][2026-03-23 13:56:46] Step: 124, Training Logs: loss_final: 4.010628, loss_mean: 1.117229, proj_loss: -0.054621, loss_mean_cls: 2.948020, grad_norm: 1.329769
126
+ Steps: 0%| | 125/1000000 [00:34<67:48:41, 4.10it/s, grad_norm=1.33, loss_final=4.01, loss_mean=1.12, loss_mean_cls=2.95, proj_loss=-0.0546][2026-03-23 13:56:46] Step: 125, Training Logs: loss_final: 3.764182, loss_mean: 1.127925, proj_loss: -0.057976, loss_mean_cls: 2.694233, grad_norm: 1.674507
127
+ Steps: 0%| | 126/1000000 [00:35<67:47:01, 4.10it/s, grad_norm=1.67, loss_final=3.76, loss_mean=1.13, loss_mean_cls=2.69, proj_loss=-0.058][2026-03-23 13:56:46] Step: 126, Training Logs: loss_final: 4.371668, loss_mean: 1.097472, proj_loss: -0.055551, loss_mean_cls: 3.329747, grad_norm: 1.962917
128
+ Steps: 0%| | 127/1000000 [00:35<67:47:01, 4.10it/s, grad_norm=1.96, loss_final=4.37, loss_mean=1.1, loss_mean_cls=3.33, proj_loss=-0.0556][2026-03-23 13:56:47] Step: 127, Training Logs: loss_final: 4.417885, loss_mean: 1.097842, proj_loss: -0.057305, loss_mean_cls: 3.377348, grad_norm: 1.860904
129
+ Steps: 0%| | 128/1000000 [00:35<67:46:39, 4.10it/s, grad_norm=1.86, loss_final=4.42, loss_mean=1.1, loss_mean_cls=3.38, proj_loss=-0.0573][2026-03-23 13:56:47] Step: 128, Training Logs: loss_final: 4.017467, loss_mean: 1.103394, proj_loss: -0.057005, loss_mean_cls: 2.971077, grad_norm: 1.122871
130
+ Steps: 0%| | 129/1000000 [00:35<67:55:50, 4.09it/s, grad_norm=1.12, loss_final=4.02, loss_mean=1.1, loss_mean_cls=2.97, proj_loss=-0.057][2026-03-23 13:56:47] Step: 129, Training Logs: loss_final: 4.208836, loss_mean: 1.102054, proj_loss: -0.056320, loss_mean_cls: 3.163102, grad_norm: 1.851275
131
+ Steps: 0%| | 130/1000000 [00:36<67:52:57, 4.09it/s, grad_norm=1.85, loss_final=4.21, loss_mean=1.1, loss_mean_cls=3.16, proj_loss=-0.0563][2026-03-23 13:56:47] Step: 130, Training Logs: loss_final: 4.207572, loss_mean: 1.127353, proj_loss: -0.057998, loss_mean_cls: 3.138218, grad_norm: 1.432804
132
+ Steps: 0%| | 131/1000000 [00:36<67:51:54, 4.09it/s, grad_norm=1.43, loss_final=4.21, loss_mean=1.13, loss_mean_cls=3.14, proj_loss=-0.058][2026-03-23 13:56:48] Step: 131, Training Logs: loss_final: 4.999227, loss_mean: 1.065102, proj_loss: -0.057811, loss_mean_cls: 3.991935, grad_norm: 1.260865
133
+ Steps: 0%| | 132/1000000 [00:36<67:50:15, 4.09it/s, grad_norm=1.26, loss_final=5, loss_mean=1.07, loss_mean_cls=3.99, proj_loss=-0.0578][2026-03-23 13:56:48] Step: 132, Training Logs: loss_final: 4.958624, loss_mean: 1.080354, proj_loss: -0.055375, loss_mean_cls: 3.933645, grad_norm: 1.087305
134
+ Steps: 0%| | 133/1000000 [00:36<67:49:24, 4.10it/s, grad_norm=1.09, loss_final=4.96, loss_mean=1.08, loss_mean_cls=3.93, proj_loss=-0.0554][2026-03-23 13:56:48] Step: 133, Training Logs: loss_final: 4.336567, loss_mean: 1.074137, proj_loss: -0.057906, loss_mean_cls: 3.320336, grad_norm: 0.845532
135
+ Steps: 0%| | 134/1000000 [00:36<67:49:27, 4.09it/s, grad_norm=0.846, loss_final=4.34, loss_mean=1.07, loss_mean_cls=3.32, proj_loss=-0.0579][2026-03-23 13:56:48] Step: 134, Training Logs: loss_final: 4.113417, loss_mean: 1.101781, proj_loss: -0.055828, loss_mean_cls: 3.067464, grad_norm: 0.853130
136
+ Steps: 0%| | 135/1000000 [00:37<67:48:26, 4.10it/s, grad_norm=0.853, loss_final=4.11, loss_mean=1.1, loss_mean_cls=3.07, proj_loss=-0.0558][2026-03-23 13:56:48] Step: 135, Training Logs: loss_final: 4.451187, loss_mean: 1.071875, proj_loss: -0.057522, loss_mean_cls: 3.436834, grad_norm: 1.211643
137
+ Steps: 0%| | 136/1000000 [00:37<67:48:00, 4.10it/s, grad_norm=1.21, loss_final=4.45, loss_mean=1.07, loss_mean_cls=3.44, proj_loss=-0.0575][2026-03-23 13:56:49] Step: 136, Training Logs: loss_final: 3.470972, loss_mean: 1.105708, proj_loss: -0.057636, loss_mean_cls: 2.422900, grad_norm: 1.028398
138
+ Steps: 0%| | 137/1000000 [00:37<67:48:46, 4.10it/s, grad_norm=1.03, loss_final=3.47, loss_mean=1.11, loss_mean_cls=2.42, proj_loss=-0.0576][2026-03-23 13:56:49] Step: 137, Training Logs: loss_final: 4.167720, loss_mean: 1.116012, proj_loss: -0.056297, loss_mean_cls: 3.108005, grad_norm: 1.331365
139
+ Steps: 0%| | 138/1000000 [00:37<67:49:54, 4.09it/s, grad_norm=1.33, loss_final=4.17, loss_mean=1.12, loss_mean_cls=3.11, proj_loss=-0.0563][2026-03-23 13:56:49] Step: 138, Training Logs: loss_final: 3.969088, loss_mean: 1.080545, proj_loss: -0.058032, loss_mean_cls: 2.946575, grad_norm: 1.159432
140
+ Steps: 0%| | 139/1000000 [00:38<67:49:15, 4.10it/s, grad_norm=1.16, loss_final=3.97, loss_mean=1.08, loss_mean_cls=2.95, proj_loss=-0.058][2026-03-23 13:56:49] Step: 139, Training Logs: loss_final: 4.553603, loss_mean: 1.100756, proj_loss: -0.057699, loss_mean_cls: 3.510547, grad_norm: 1.379237
141
+ Steps: 0%| | 140/1000000 [00:38<67:48:29, 4.10it/s, grad_norm=1.38, loss_final=4.55, loss_mean=1.1, loss_mean_cls=3.51, proj_loss=-0.0577][2026-03-23 13:56:50] Step: 140, Training Logs: loss_final: 3.823996, loss_mean: 1.091382, proj_loss: -0.055826, loss_mean_cls: 2.788440, grad_norm: 1.269147
142
+ Steps: 0%| | 141/1000000 [00:38<67:49:20, 4.10it/s, grad_norm=1.27, loss_final=3.82, loss_mean=1.09, loss_mean_cls=2.79, proj_loss=-0.0558][2026-03-23 13:56:50] Step: 141, Training Logs: loss_final: 4.244388, loss_mean: 1.086618, proj_loss: -0.060178, loss_mean_cls: 3.217947, grad_norm: 1.224737
143
+ Steps: 0%| | 142/1000000 [00:38<67:48:47, 4.10it/s, grad_norm=1.22, loss_final=4.24, loss_mean=1.09, loss_mean_cls=3.22, proj_loss=-0.0602][2026-03-23 13:56:50] Step: 142, Training Logs: loss_final: 4.555122, loss_mean: 1.074355, proj_loss: -0.056447, loss_mean_cls: 3.537215, grad_norm: 1.288044
144
+ Steps: 0%| | 143/1000000 [00:39<67:49:15, 4.10it/s, grad_norm=1.29, loss_final=4.56, loss_mean=1.07, loss_mean_cls=3.54, proj_loss=-0.0564][2026-03-23 13:56:50] Step: 143, Training Logs: loss_final: 4.631051, loss_mean: 1.031162, proj_loss: -0.056307, loss_mean_cls: 3.656196, grad_norm: 0.875154
145
+ Steps: 0%| | 144/1000000 [00:39<67:47:56, 4.10it/s, grad_norm=0.875, loss_final=4.63, loss_mean=1.03, loss_mean_cls=3.66, proj_loss=-0.0563][2026-03-23 13:56:51] Step: 144, Training Logs: loss_final: 4.039268, loss_mean: 1.089356, proj_loss: -0.056316, loss_mean_cls: 3.006228, grad_norm: 1.155283
146
+ Steps: 0%| | 145/1000000 [00:39<67:55:59, 4.09it/s, grad_norm=1.16, loss_final=4.04, loss_mean=1.09, loss_mean_cls=3.01, proj_loss=-0.0563][2026-03-23 13:56:51] Step: 145, Training Logs: loss_final: 3.760608, loss_mean: 1.070431, proj_loss: -0.059299, loss_mean_cls: 2.749476, grad_norm: 1.032837
147
+ Steps: 0%| | 146/1000000 [00:39<67:53:40, 4.09it/s, grad_norm=1.03, loss_final=3.76, loss_mean=1.07, loss_mean_cls=2.75, proj_loss=-0.0593][2026-03-23 13:56:51] Step: 146, Training Logs: loss_final: 4.181468, loss_mean: 1.069318, proj_loss: -0.057231, loss_mean_cls: 3.169381, grad_norm: 1.739599
148
+ Steps: 0%| | 147/1000000 [00:40<87:04:35, 3.19it/s, grad_norm=1.74, loss_final=4.18, loss_mean=1.07, loss_mean_cls=3.17, proj_loss=-0.0572][2026-03-23 13:56:52] Step: 147, Training Logs: loss_final: 4.397898, loss_mean: 1.077711, proj_loss: -0.057317, loss_mean_cls: 3.377505, grad_norm: 1.324051
149
+ Steps: 0%| | 148/1000000 [00:40<81:18:29, 3.42it/s, grad_norm=1.32, loss_final=4.4, loss_mean=1.08, loss_mean_cls=3.38, proj_loss=-0.0573][2026-03-23 13:56:52] Step: 148, Training Logs: loss_final: 3.678877, loss_mean: 1.095331, proj_loss: -0.056643, loss_mean_cls: 2.640189, grad_norm: 0.957576
150
+ Steps: 0%| | 149/1000000 [00:40<77:21:07, 3.59it/s, grad_norm=0.958, loss_final=3.68, loss_mean=1.1, loss_mean_cls=2.64, proj_loss=-0.0566][2026-03-23 13:56:52] Step: 149, Training Logs: loss_final: 3.970488, loss_mean: 1.096237, proj_loss: -0.057982, loss_mean_cls: 2.932232, grad_norm: 0.751983
151
+ Steps: 0%| | 150/1000000 [00:41<74:28:13, 3.73it/s, grad_norm=0.752, loss_final=3.97, loss_mean=1.1, loss_mean_cls=2.93, proj_loss=-0.058][2026-03-23 13:56:52] Step: 150, Training Logs: loss_final: 3.589296, loss_mean: 1.085876, proj_loss: -0.059221, loss_mean_cls: 2.562641, grad_norm: 1.001571
152
+ Steps: 0%| | 151/1000000 [00:41<72:29:17, 3.83it/s, grad_norm=1, loss_final=3.59, loss_mean=1.09, loss_mean_cls=2.56, proj_loss=-0.0592][2026-03-23 13:56:53] Step: 151, Training Logs: loss_final: 3.809782, loss_mean: 1.041376, proj_loss: -0.057593, loss_mean_cls: 2.825998, grad_norm: 0.772958
153
+ Steps: 0%| | 152/1000000 [00:41<71:06:45, 3.91it/s, grad_norm=0.773, loss_final=3.81, loss_mean=1.04, loss_mean_cls=2.83, proj_loss=-0.0576][2026-03-23 13:56:53] Step: 152, Training Logs: loss_final: 3.756772, loss_mean: 1.077754, proj_loss: -0.055612, loss_mean_cls: 2.734629, grad_norm: 0.942414
154
+ Steps: 0%| | 153/1000000 [00:41<70:15:13, 3.95it/s, grad_norm=0.942, loss_final=3.76, loss_mean=1.08, loss_mean_cls=2.73, proj_loss=-0.0556][2026-03-23 13:56:53] Step: 153, Training Logs: loss_final: 3.815321, loss_mean: 1.091242, proj_loss: -0.055186, loss_mean_cls: 2.779264, grad_norm: 1.427716
155
+ Steps: 0%| | 154/1000000 [00:42<69:30:56, 4.00it/s, grad_norm=1.43, loss_final=3.82, loss_mean=1.09, loss_mean_cls=2.78, proj_loss=-0.0552][2026-03-23 13:56:53] Step: 154, Training Logs: loss_final: 4.484177, loss_mean: 1.070384, proj_loss: -0.056569, loss_mean_cls: 3.470362, grad_norm: 0.890748
156
+ Steps: 0%| | 155/1000000 [00:42<68:58:59, 4.03it/s, grad_norm=0.891, loss_final=4.48, loss_mean=1.07, loss_mean_cls=3.47, proj_loss=-0.0566][2026-03-23 13:56:54] Step: 155, Training Logs: loss_final: 3.992099, loss_mean: 1.094668, proj_loss: -0.056406, loss_mean_cls: 2.953837, grad_norm: 1.103738
157
+ Steps: 0%| | 156/1000000 [00:42<68:38:04, 4.05it/s, grad_norm=1.1, loss_final=3.99, loss_mean=1.09, loss_mean_cls=2.95, proj_loss=-0.0564][2026-03-23 13:56:54] Step: 156, Training Logs: loss_final: 4.835541, loss_mean: 1.066255, proj_loss: -0.057052, loss_mean_cls: 3.826337, grad_norm: 1.572935
158
+ Steps: 0%| | 157/1000000 [00:42<68:32:39, 4.05it/s, grad_norm=1.57, loss_final=4.84, loss_mean=1.07, loss_mean_cls=3.83, proj_loss=-0.0571][2026-03-23 13:56:54] Step: 157, Training Logs: loss_final: 4.422526, loss_mean: 1.037669, proj_loss: -0.057450, loss_mean_cls: 3.442307, grad_norm: 0.881763
159
+ Steps: 0%| | 158/1000000 [00:43<68:18:41, 4.07it/s, grad_norm=0.882, loss_final=4.42, loss_mean=1.04, loss_mean_cls=3.44, proj_loss=-0.0574][2026-03-23 13:56:54] Step: 158, Training Logs: loss_final: 3.933815, loss_mean: 1.094053, proj_loss: -0.058011, loss_mean_cls: 2.897773, grad_norm: 1.524001
160
+ Steps: 0%| | 159/1000000 [00:43<68:08:39, 4.08it/s, grad_norm=1.52, loss_final=3.93, loss_mean=1.09, loss_mean_cls=2.9, proj_loss=-0.058][2026-03-23 13:56:55] Step: 159, Training Logs: loss_final: 3.294493, loss_mean: 1.101539, proj_loss: -0.057312, loss_mean_cls: 2.250265, grad_norm: 0.862417
161
+ Steps: 0%| | 160/1000000 [00:43<68:00:56, 4.08it/s, grad_norm=0.862, loss_final=3.29, loss_mean=1.1, loss_mean_cls=2.25, proj_loss=-0.0573][2026-03-23 13:56:55] Step: 160, Training Logs: loss_final: 4.003653, loss_mean: 1.049939, proj_loss: -0.056768, loss_mean_cls: 3.010482, grad_norm: 1.406133
162
+ Steps: 0%| | 161/1000000 [00:43<68:05:23, 4.08it/s, grad_norm=1.41, loss_final=4, loss_mean=1.05, loss_mean_cls=3.01, proj_loss=-0.0568][2026-03-23 13:56:55] Step: 161, Training Logs: loss_final: 3.735500, loss_mean: 1.059312, proj_loss: -0.055975, loss_mean_cls: 2.732163, grad_norm: 1.257030
163
+ Steps: 0%| | 162/1000000 [00:44<67:59:39, 4.08it/s, grad_norm=1.26, loss_final=3.74, loss_mean=1.06, loss_mean_cls=2.73, proj_loss=-0.056][2026-03-23 13:56:55] Step: 162, Training Logs: loss_final: 4.134730, loss_mean: 1.051557, proj_loss: -0.055925, loss_mean_cls: 3.139099, grad_norm: 1.383831
164
+ Steps: 0%| | 163/1000000 [00:44<67:54:57, 4.09it/s, grad_norm=1.38, loss_final=4.13, loss_mean=1.05, loss_mean_cls=3.14, proj_loss=-0.0559][2026-03-23 13:56:56] Step: 163, Training Logs: loss_final: 3.976990, loss_mean: 1.080228, proj_loss: -0.055570, loss_mean_cls: 2.952332, grad_norm: 2.321836
165
+ Steps: 0%| | 164/1000000 [00:44<67:52:28, 4.09it/s, grad_norm=2.32, loss_final=3.98, loss_mean=1.08, loss_mean_cls=2.95, proj_loss=-0.0556][2026-03-23 13:56:56] Step: 164, Training Logs: loss_final: 4.920438, loss_mean: 1.053734, proj_loss: -0.057831, loss_mean_cls: 3.924535, grad_norm: 1.427572
166
+ Steps: 0%| | 165/1000000 [00:44<67:51:20, 4.09it/s, grad_norm=1.43, loss_final=4.92, loss_mean=1.05, loss_mean_cls=3.92, proj_loss=-0.0578][2026-03-23 13:56:56] Step: 165, Training Logs: loss_final: 4.075354, loss_mean: 1.102409, proj_loss: -0.055738, loss_mean_cls: 3.028682, grad_norm: 1.554521
167
+ Steps: 0%| | 166/1000000 [00:45<67:50:22, 4.09it/s, grad_norm=1.55, loss_final=4.08, loss_mean=1.1, loss_mean_cls=3.03, proj_loss=-0.0557][2026-03-23 13:56:56] Step: 166, Training Logs: loss_final: 3.970260, loss_mean: 1.052696, proj_loss: -0.054281, loss_mean_cls: 2.971844, grad_norm: 1.398542
168
+ Steps: 0%| | 167/1000000 [00:45<67:49:03, 4.10it/s, grad_norm=1.4, loss_final=3.97, loss_mean=1.05, loss_mean_cls=2.97, proj_loss=-0.0543][2026-03-23 13:56:57] Step: 167, Training Logs: loss_final: 4.506516, loss_mean: 1.058694, proj_loss: -0.055691, loss_mean_cls: 3.503513, grad_norm: 1.465548
169
+ Steps: 0%| | 168/1000000 [00:45<67:48:15, 4.10it/s, grad_norm=1.47, loss_final=4.51, loss_mean=1.06, loss_mean_cls=3.5, proj_loss=-0.0557][2026-03-23 13:56:57] Step: 168, Training Logs: loss_final: 3.774779, loss_mean: 1.070169, proj_loss: -0.056798, loss_mean_cls: 2.761408, grad_norm: 1.063631
170
+ Steps: 0%| | 169/1000000 [00:45<67:55:41, 4.09it/s, grad_norm=1.06, loss_final=3.77, loss_mean=1.07, loss_mean_cls=2.76, proj_loss=-0.0568][2026-03-23 13:56:57] Step: 169, Training Logs: loss_final: 3.799341, loss_mean: 1.072764, proj_loss: -0.057054, loss_mean_cls: 2.783631, grad_norm: 1.432508
171
+ Steps: 0%| | 170/1000000 [00:46<67:53:00, 4.09it/s, grad_norm=1.43, loss_final=3.8, loss_mean=1.07, loss_mean_cls=2.78, proj_loss=-0.0571][2026-03-23 13:56:57] Step: 170, Training Logs: loss_final: 3.950215, loss_mean: 1.050979, proj_loss: -0.058352, loss_mean_cls: 2.957588, grad_norm: 1.403183
172
+ Steps: 0%| | 171/1000000 [00:46<67:51:12, 4.09it/s, grad_norm=1.4, loss_final=3.95, loss_mean=1.05, loss_mean_cls=2.96, proj_loss=-0.0584][2026-03-23 13:56:58] Step: 171, Training Logs: loss_final: 4.198855, loss_mean: 1.049812, proj_loss: -0.056677, loss_mean_cls: 3.205721, grad_norm: 0.921758
173
+ Steps: 0%| | 172/1000000 [00:46<67:49:22, 4.09it/s, grad_norm=0.922, loss_final=4.2, loss_mean=1.05, loss_mean_cls=3.21, proj_loss=-0.0567][2026-03-23 13:56:58] Step: 172, Training Logs: loss_final: 3.441348, loss_mean: 1.076280, proj_loss: -0.057867, loss_mean_cls: 2.422936, grad_norm: 0.854744
174
+ Steps: 0%| | 173/1000000 [00:46<67:50:30, 4.09it/s, grad_norm=0.855, loss_final=3.44, loss_mean=1.08, loss_mean_cls=2.42, proj_loss=-0.0579][2026-03-23 13:56:58] Step: 173, Training Logs: loss_final: 3.798677, loss_mean: 1.085783, proj_loss: -0.057882, loss_mean_cls: 2.770777, grad_norm: 1.453923
175
+ Steps: 0%| | 174/1000000 [00:46<67:50:14, 4.09it/s, grad_norm=1.45, loss_final=3.8, loss_mean=1.09, loss_mean_cls=2.77, proj_loss=-0.0579][2026-03-23 13:56:58] Step: 174, Training Logs: loss_final: 3.112827, loss_mean: 1.064432, proj_loss: -0.058472, loss_mean_cls: 2.106868, grad_norm: 1.250231
176
+ Steps: 0%| | 175/1000000 [00:47<67:49:59, 4.09it/s, grad_norm=1.25, loss_final=3.11, loss_mean=1.06, loss_mean_cls=2.11, proj_loss=-0.0585][2026-03-23 13:56:58] Step: 175, Training Logs: loss_final: 3.708531, loss_mean: 1.080565, proj_loss: -0.055104, loss_mean_cls: 2.683070, grad_norm: 1.818153
177
+ Steps: 0%| | 176/1000000 [00:47<67:49:26, 4.09it/s, grad_norm=1.82, loss_final=3.71, loss_mean=1.08, loss_mean_cls=2.68, proj_loss=-0.0551][2026-03-23 13:56:59] Step: 176, Training Logs: loss_final: 4.057882, loss_mean: 1.041325, proj_loss: -0.057044, loss_mean_cls: 3.073601, grad_norm: 1.056386
178
+ Steps: 0%| | 177/1000000 [00:47<67:50:14, 4.09it/s, grad_norm=1.06, loss_final=4.06, loss_mean=1.04, loss_mean_cls=3.07, proj_loss=-0.057][2026-03-23 13:56:59] Step: 177, Training Logs: loss_final: 3.843677, loss_mean: 1.082743, proj_loss: -0.056343, loss_mean_cls: 2.817278, grad_norm: 1.454900
179
+ Steps: 0%| | 178/1000000 [00:47<67:50:45, 4.09it/s, grad_norm=1.45, loss_final=3.84, loss_mean=1.08, loss_mean_cls=2.82, proj_loss=-0.0563][2026-03-23 13:56:59] Step: 178, Training Logs: loss_final: 4.592722, loss_mean: 1.051988, proj_loss: -0.056379, loss_mean_cls: 3.597114, grad_norm: 1.553295
180
+ Steps: 0%| | 179/1000000 [00:48<67:51:26, 4.09it/s, grad_norm=1.55, loss_final=4.59, loss_mean=1.05, loss_mean_cls=3.6, proj_loss=-0.0564][2026-03-23 13:56:59] Step: 179, Training Logs: loss_final: 4.083874, loss_mean: 1.044562, proj_loss: -0.057760, loss_mean_cls: 3.097073, grad_norm: 1.207282
181
+ Steps: 0%| | 180/1000000 [00:48<67:50:22, 4.09it/s, grad_norm=1.21, loss_final=4.08, loss_mean=1.04, loss_mean_cls=3.1, proj_loss=-0.0578][2026-03-23 13:57:00] Step: 180, Training Logs: loss_final: 3.533881, loss_mean: 1.084617, proj_loss: -0.055471, loss_mean_cls: 2.504735, grad_norm: 1.384633
182
+ Steps: 0%| | 181/1000000 [00:48<67:49:31, 4.09it/s, grad_norm=1.38, loss_final=3.53, loss_mean=1.08, loss_mean_cls=2.5, proj_loss=-0.0555][2026-03-23 13:57:00] Step: 181, Training Logs: loss_final: 4.409495, loss_mean: 1.036472, proj_loss: -0.057444, loss_mean_cls: 3.430467, grad_norm: 1.297928
183
+ Steps: 0%| | 182/1000000 [00:48<67:48:57, 4.10it/s, grad_norm=1.3, loss_final=4.41, loss_mean=1.04, loss_mean_cls=3.43, proj_loss=-0.0574][2026-03-23 13:57:00] Step: 182, Training Logs: loss_final: 3.176256, loss_mean: 1.091783, proj_loss: -0.057242, loss_mean_cls: 2.141715, grad_norm: 1.965931
184
+ Steps: 0%| | 183/1000000 [00:49<67:49:28, 4.09it/s, grad_norm=1.97, loss_final=3.18, loss_mean=1.09, loss_mean_cls=2.14, proj_loss=-0.0572][2026-03-23 13:57:00] Step: 183, Training Logs: loss_final: 3.651900, loss_mean: 1.063701, proj_loss: -0.057272, loss_mean_cls: 2.645471, grad_norm: 1.260781
185
+ Steps: 0%| | 184/1000000 [00:49<67:49:34, 4.09it/s, grad_norm=1.26, loss_final=3.65, loss_mean=1.06, loss_mean_cls=2.65, proj_loss=-0.0573][2026-03-23 13:57:01] Step: 184, Training Logs: loss_final: 4.134811, loss_mean: 1.052191, proj_loss: -0.056823, loss_mean_cls: 3.139443, grad_norm: 1.083165
186
+ Steps: 0%| | 185/1000000 [00:49<67:54:27, 4.09it/s, grad_norm=1.08, loss_final=4.13, loss_mean=1.05, loss_mean_cls=3.14, proj_loss=-0.0568][2026-03-23 13:57:01] Step: 185, Training Logs: loss_final: 3.952762, loss_mean: 1.034234, proj_loss: -0.056918, loss_mean_cls: 2.975446, grad_norm: 1.370537
187
+ Steps: 0%| | 186/1000000 [00:49<67:53:04, 4.09it/s, grad_norm=1.37, loss_final=3.95, loss_mean=1.03, loss_mean_cls=2.98, proj_loss=-0.0569][2026-03-23 13:57:01] Step: 186, Training Logs: loss_final: 4.377052, loss_mean: 1.068180, proj_loss: -0.055129, loss_mean_cls: 3.364000, grad_norm: 2.047776
188
+ Steps: 0%| | 187/1000000 [00:50<67:51:19, 4.09it/s, grad_norm=2.05, loss_final=4.38, loss_mean=1.07, loss_mean_cls=3.36, proj_loss=-0.0551][2026-03-23 13:57:01] Step: 187, Training Logs: loss_final: 3.465180, loss_mean: 1.067420, proj_loss: -0.055022, loss_mean_cls: 2.452782, grad_norm: 1.094218
189
+ Steps: 0%| | 188/1000000 [00:50<67:50:38, 4.09it/s, grad_norm=1.09, loss_final=3.47, loss_mean=1.07, loss_mean_cls=2.45, proj_loss=-0.055][2026-03-23 13:57:02] Step: 188, Training Logs: loss_final: 4.428637, loss_mean: 1.075879, proj_loss: -0.056132, loss_mean_cls: 3.408890, grad_norm: 2.243081
190
+ Steps: 0%| | 189/1000000 [00:50<67:50:39, 4.09it/s, grad_norm=2.24, loss_final=4.43, loss_mean=1.08, loss_mean_cls=3.41, proj_loss=-0.0561][2026-03-23 13:57:02] Step: 189, Training Logs: loss_final: 3.874612, loss_mean: 1.059214, proj_loss: -0.055212, loss_mean_cls: 2.870610, grad_norm: 1.509933
191
+ Steps: 0%| | 190/1000000 [00:50<67:53:49, 4.09it/s, grad_norm=1.51, loss_final=3.87, loss_mean=1.06, loss_mean_cls=2.87, proj_loss=-0.0552][2026-03-23 13:57:02] Step: 190, Training Logs: loss_final: 3.696738, loss_mean: 1.066510, proj_loss: -0.058307, loss_mean_cls: 2.688535, grad_norm: 1.849224
192
+ Steps: 0%| | 191/1000000 [00:51<67:54:49, 4.09it/s, grad_norm=1.85, loss_final=3.7, loss_mean=1.07, loss_mean_cls=2.69, proj_loss=-0.0583][2026-03-23 13:57:02] Step: 191, Training Logs: loss_final: 4.354342, loss_mean: 1.020227, proj_loss: -0.055745, loss_mean_cls: 3.389860, grad_norm: 1.847939
193
+ Steps: 0%| | 192/1000000 [00:51<67:54:25, 4.09it/s, grad_norm=1.85, loss_final=4.35, loss_mean=1.02, loss_mean_cls=3.39, proj_loss=-0.0557][2026-03-23 13:57:03] Step: 192, Training Logs: loss_final: 4.409212, loss_mean: 1.048026, proj_loss: -0.056594, loss_mean_cls: 3.417780, grad_norm: 1.524102
194
+ Steps: 0%| | 193/1000000 [00:51<67:53:07, 4.09it/s, grad_norm=1.52, loss_final=4.41, loss_mean=1.05, loss_mean_cls=3.42, proj_loss=-0.0566][2026-03-23 13:57:03] Step: 193, Training Logs: loss_final: 3.909914, loss_mean: 1.034270, proj_loss: -0.058143, loss_mean_cls: 2.933788, grad_norm: 1.197792
195
+ Steps: 0%| | 194/1000000 [00:51<67:53:28, 4.09it/s, grad_norm=1.2, loss_final=3.91, loss_mean=1.03, loss_mean_cls=2.93, proj_loss=-0.0581][2026-03-23 13:57:03] Step: 194, Training Logs: loss_final: 3.518339, loss_mean: 1.061185, proj_loss: -0.053464, loss_mean_cls: 2.510617, grad_norm: 1.685875
196
+ Steps: 0%| | 195/1000000 [00:52<67:52:30, 4.09it/s, grad_norm=1.69, loss_final=3.52, loss_mean=1.06, loss_mean_cls=2.51, proj_loss=-0.0535][2026-03-23 13:57:03] Step: 195, Training Logs: loss_final: 4.144682, loss_mean: 1.069985, proj_loss: -0.056024, loss_mean_cls: 3.130722, grad_norm: 1.617648
197
+ Steps: 0%| | 196/1000000 [00:52<67:53:25, 4.09it/s, grad_norm=1.62, loss_final=4.14, loss_mean=1.07, loss_mean_cls=3.13, proj_loss=-0.056][2026-03-23 13:57:04] Step: 196, Training Logs: loss_final: 4.209203, loss_mean: 1.043093, proj_loss: -0.056968, loss_mean_cls: 3.223077, grad_norm: 1.561610
198
+ Steps: 0%| | 197/1000000 [00:52<67:51:41, 4.09it/s, grad_norm=1.56, loss_final=4.21, loss_mean=1.04, loss_mean_cls=3.22, proj_loss=-0.057][2026-03-23 13:57:04] Step: 197, Training Logs: loss_final: 3.986165, loss_mean: 1.048607, proj_loss: -0.058094, loss_mean_cls: 2.995652, grad_norm: 1.465502
199
+ Steps: 0%| | 198/1000000 [00:52<67:51:10, 4.09it/s, grad_norm=1.47, loss_final=3.99, loss_mean=1.05, loss_mean_cls=3, proj_loss=-0.0581][2026-03-23 13:57:04] Step: 198, Training Logs: loss_final: 4.109353, loss_mean: 1.057425, proj_loss: -0.057203, loss_mean_cls: 3.109131, grad_norm: 1.602210
200
+ Steps: 0%| | 199/1000000 [00:53<67:51:00, 4.09it/s, grad_norm=1.6, loss_final=4.11, loss_mean=1.06, loss_mean_cls=3.11, proj_loss=-0.0572][2026-03-23 13:57:04] Step: 199, Training Logs: loss_final: 3.827809, loss_mean: 1.060561, proj_loss: -0.056529, loss_mean_cls: 2.823776, grad_norm: 1.281930
201
+ Steps: 0%| | 200/1000000 [00:53<67:52:45, 4.09it/s, grad_norm=1.28, loss_final=3.83, loss_mean=1.06, loss_mean_cls=2.82, proj_loss=-0.0565][2026-03-23 13:57:05] Step: 200, Training Logs: loss_final: 4.859810, loss_mean: 1.018786, proj_loss: -0.056772, loss_mean_cls: 3.897796, grad_norm: 1.346983
202
+ Steps: 0%| | 201/1000000 [00:53<67:50:08, 4.09it/s, grad_norm=1.35, loss_final=4.86, loss_mean=1.02, loss_mean_cls=3.9, proj_loss=-0.0568][2026-03-23 13:57:05] Step: 201, Training Logs: loss_final: 3.945022, loss_mean: 1.039917, proj_loss: -0.058203, loss_mean_cls: 2.963308, grad_norm: 1.138654
203
+ Steps: 0%| | 202/1000000 [00:53<67:49:17, 4.09it/s, grad_norm=1.14, loss_final=3.95, loss_mean=1.04, loss_mean_cls=2.96, proj_loss=-0.0582][2026-03-23 13:57:05] Step: 202, Training Logs: loss_final: 3.795771, loss_mean: 1.045783, proj_loss: -0.058287, loss_mean_cls: 2.808275, grad_norm: 1.378415
204
+ Steps: 0%| | 203/1000000 [00:54<67:47:45, 4.10it/s, grad_norm=1.38, loss_final=3.8, loss_mean=1.05, loss_mean_cls=2.81, proj_loss=-0.0583][2026-03-23 13:57:05] Step: 203, Training Logs: loss_final: 3.873578, loss_mean: 1.057585, proj_loss: -0.057143, loss_mean_cls: 2.873136, grad_norm: 1.039413
205
+ Steps: 0%| | 204/1000000 [00:54<67:47:23, 4.10it/s, grad_norm=1.04, loss_final=3.87, loss_mean=1.06, loss_mean_cls=2.87, proj_loss=-0.0571][2026-03-23 13:57:06] Step: 204, Training Logs: loss_final: 3.869634, loss_mean: 1.025371, proj_loss: -0.059635, loss_mean_cls: 2.903898, grad_norm: 1.113446
206
+ Steps: 0%| | 205/1000000 [00:54<67:46:28, 4.10it/s, grad_norm=1.11, loss_final=3.87, loss_mean=1.03, loss_mean_cls=2.9, proj_loss=-0.0596][2026-03-23 13:57:06] Step: 205, Training Logs: loss_final: 4.516129, loss_mean: 1.026917, proj_loss: -0.059113, loss_mean_cls: 3.548324, grad_norm: 1.246919
207
+ Steps: 0%| | 206/1000000 [00:54<67:46:58, 4.10it/s, grad_norm=1.25, loss_final=4.52, loss_mean=1.03, loss_mean_cls=3.55, proj_loss=-0.0591][2026-03-23 13:57:06] Step: 206, Training Logs: loss_final: 3.870474, loss_mean: 1.033348, proj_loss: -0.057558, loss_mean_cls: 2.894683, grad_norm: 1.104687
208
+ Steps: 0%| | 207/1000000 [00:55<67:49:28, 4.09it/s, grad_norm=1.1, loss_final=3.87, loss_mean=1.03, loss_mean_cls=2.89, proj_loss=-0.0576][2026-03-23 13:57:06] Step: 207, Training Logs: loss_final: 3.657545, loss_mean: 1.036661, proj_loss: -0.054751, loss_mean_cls: 2.675634, grad_norm: 1.384344
209
+ Steps: 0%| | 208/1000000 [00:55<67:48:34, 4.10it/s, grad_norm=1.38, loss_final=3.66, loss_mean=1.04, loss_mean_cls=2.68, proj_loss=-0.0548][2026-03-23 13:57:07] Step: 208, Training Logs: loss_final: 3.776420, loss_mean: 1.025812, proj_loss: -0.057821, loss_mean_cls: 2.808429, grad_norm: 1.096370
210
+ Steps: 0%| | 209/1000000 [00:55<67:48:19, 4.10it/s, grad_norm=1.1, loss_final=3.78, loss_mean=1.03, loss_mean_cls=2.81, proj_loss=-0.0578][2026-03-23 13:57:07] Step: 209, Training Logs: loss_final: 3.761090, loss_mean: 1.062359, proj_loss: -0.056922, loss_mean_cls: 2.755653, grad_norm: 1.903223
211
+ Steps: 0%| | 210/1000000 [00:55<67:48:31, 4.10it/s, grad_norm=1.9, loss_final=3.76, loss_mean=1.06, loss_mean_cls=2.76, proj_loss=-0.0569][2026-03-23 13:57:07] Step: 210, Training Logs: loss_final: 3.906418, loss_mean: 1.041829, proj_loss: -0.057181, loss_mean_cls: 2.921770, grad_norm: 1.735096
212
+ Steps: 0%| | 211/1000000 [00:56<67:48:32, 4.10it/s, grad_norm=1.74, loss_final=3.91, loss_mean=1.04, loss_mean_cls=2.92, proj_loss=-0.0572][2026-03-23 13:57:07] Step: 211, Training Logs: loss_final: 3.858996, loss_mean: 1.023147, proj_loss: -0.059874, loss_mean_cls: 2.895722, grad_norm: 1.423884
213
+ Steps: 0%| | 212/1000000 [00:56<67:46:51, 4.10it/s, grad_norm=1.42, loss_final=3.86, loss_mean=1.02, loss_mean_cls=2.9, proj_loss=-0.0599][2026-03-23 13:57:08] Step: 212, Training Logs: loss_final: 4.049764, loss_mean: 1.020845, proj_loss: -0.057857, loss_mean_cls: 3.086776, grad_norm: 1.539841
214
+ Steps: 0%| | 213/1000000 [00:56<67:48:03, 4.10it/s, grad_norm=1.54, loss_final=4.05, loss_mean=1.02, loss_mean_cls=3.09, proj_loss=-0.0579][2026-03-23 13:57:08] Step: 213, Training Logs: loss_final: 3.666535, loss_mean: 1.064638, proj_loss: -0.058414, loss_mean_cls: 2.660310, grad_norm: 1.364981
215
+ Steps: 0%| | 214/1000000 [00:56<67:48:11, 4.10it/s, grad_norm=1.36, loss_final=3.67, loss_mean=1.06, loss_mean_cls=2.66, proj_loss=-0.0584][2026-03-23 13:57:08] Step: 214, Training Logs: loss_final: 4.115675, loss_mean: 1.022230, proj_loss: -0.055882, loss_mean_cls: 3.149326, grad_norm: 1.587428
216
+ Steps: 0%| | 215/1000000 [00:57<67:46:32, 4.10it/s, grad_norm=1.59, loss_final=4.12, loss_mean=1.02, loss_mean_cls=3.15, proj_loss=-0.0559][2026-03-23 13:57:08] Step: 215, Training Logs: loss_final: 4.290797, loss_mean: 1.006664, proj_loss: -0.057540, loss_mean_cls: 3.341673, grad_norm: 1.364294
217
+ Steps: 0%| | 216/1000000 [00:57<67:48:39, 4.10it/s, grad_norm=1.36, loss_final=4.29, loss_mean=1.01, loss_mean_cls=3.34, proj_loss=-0.0575][2026-03-23 13:57:09] Step: 216, Training Logs: loss_final: 4.299802, loss_mean: 1.033919, proj_loss: -0.057627, loss_mean_cls: 3.323509, grad_norm: 2.076446
218
+ Steps: 0%| | 217/1000000 [00:57<67:48:50, 4.10it/s, grad_norm=2.08, loss_final=4.3, loss_mean=1.03, loss_mean_cls=3.32, proj_loss=-0.0576][2026-03-23 13:57:09] Step: 217, Training Logs: loss_final: 4.032390, loss_mean: 1.038509, proj_loss: -0.059780, loss_mean_cls: 3.053661, grad_norm: 2.095842
219
+ Steps: 0%| | 218/1000000 [00:57<67:48:13, 4.10it/s, grad_norm=2.1, loss_final=4.03, loss_mean=1.04, loss_mean_cls=3.05, proj_loss=-0.0598][2026-03-23 13:57:09] Step: 218, Training Logs: loss_final: 4.852066, loss_mean: 0.991051, proj_loss: -0.055897, loss_mean_cls: 3.916911, grad_norm: 1.589446
220
+ Steps: 0%| | 219/1000000 [00:57<67:54:34, 4.09it/s, grad_norm=1.59, loss_final=4.85, loss_mean=0.991, loss_mean_cls=3.92, proj_loss=-0.0559][2026-03-23 13:57:09] Step: 219, Training Logs: loss_final: 3.452114, loss_mean: 1.032779, proj_loss: -0.057448, loss_mean_cls: 2.476782, grad_norm: 2.759307
221
+ Steps: 0%| | 220/1000000 [00:58<67:53:33, 4.09it/s, grad_norm=2.76, loss_final=3.45, loss_mean=1.03, loss_mean_cls=2.48, proj_loss=-0.0574][2026-03-23 13:57:09] Step: 220, Training Logs: loss_final: 4.474224, loss_mean: 1.012484, proj_loss: -0.056469, loss_mean_cls: 3.518209, grad_norm: 2.037073
222
+ Steps: 0%| | 221/1000000 [00:58<67:58:08, 4.09it/s, grad_norm=2.04, loss_final=4.47, loss_mean=1.01, loss_mean_cls=3.52, proj_loss=-0.0565][2026-03-23 13:57:10] Step: 221, Training Logs: loss_final: 3.619599, loss_mean: 1.027089, proj_loss: -0.055216, loss_mean_cls: 2.647726, grad_norm: 1.876838
223
+ Steps: 0%| | 222/1000000 [00:58<67:55:14, 4.09it/s, grad_norm=1.88, loss_final=3.62, loss_mean=1.03, loss_mean_cls=2.65, proj_loss=-0.0552][2026-03-23 13:57:10] Step: 222, Training Logs: loss_final: 3.956278, loss_mean: 0.994196, proj_loss: -0.058793, loss_mean_cls: 3.020875, grad_norm: 1.756858
224
+ Steps: 0%| | 223/1000000 [00:58<67:57:51, 4.09it/s, grad_norm=1.76, loss_final=3.96, loss_mean=0.994, loss_mean_cls=3.02, proj_loss=-0.0588][2026-03-23 13:57:10] Step: 223, Training Logs: loss_final: 3.396563, loss_mean: 1.048192, proj_loss: -0.056162, loss_mean_cls: 2.404532, grad_norm: 1.566004
225
+ Steps: 0%| | 224/1000000 [00:59<67:54:18, 4.09it/s, grad_norm=1.57, loss_final=3.4, loss_mean=1.05, loss_mean_cls=2.4, proj_loss=-0.0562][2026-03-23 13:57:10] Step: 224, Training Logs: loss_final: 4.041858, loss_mean: 1.026601, proj_loss: -0.059143, loss_mean_cls: 3.074400, grad_norm: 1.464198
226
+ Steps: 0%| | 225/1000000 [00:59<67:52:38, 4.09it/s, grad_norm=1.46, loss_final=4.04, loss_mean=1.03, loss_mean_cls=3.07, proj_loss=-0.0591][2026-03-23 13:57:11] Step: 225, Training Logs: loss_final: 3.787225, loss_mean: 1.008417, proj_loss: -0.057760, loss_mean_cls: 2.836567, grad_norm: 1.825020
227
+ Steps: 0%| | 226/1000000 [00:59<67:52:45, 4.09it/s, grad_norm=1.83, loss_final=3.79, loss_mean=1.01, loss_mean_cls=2.84, proj_loss=-0.0578][2026-03-23 13:57:11] Step: 226, Training Logs: loss_final: 3.892933, loss_mean: 1.035562, proj_loss: -0.058638, loss_mean_cls: 2.916008, grad_norm: 1.450337
228
+ Steps: 0%| | 227/1000000 [00:59<67:53:30, 4.09it/s, grad_norm=1.45, loss_final=3.89, loss_mean=1.04, loss_mean_cls=2.92, proj_loss=-0.0586][2026-03-23 13:57:11] Step: 227, Training Logs: loss_final: 4.605734, loss_mean: 1.015767, proj_loss: -0.058328, loss_mean_cls: 3.648294, grad_norm: 1.576961
229
+ Steps: 0%| | 228/1000000 [01:00<67:52:30, 4.09it/s, grad_norm=1.58, loss_final=4.61, loss_mean=1.02, loss_mean_cls=3.65, proj_loss=-0.0583][2026-03-23 13:57:11] Step: 228, Training Logs: loss_final: 3.988146, loss_mean: 1.020183, proj_loss: -0.057191, loss_mean_cls: 3.025154, grad_norm: 1.494521
230
+ Steps: 0%| | 229/1000000 [01:00<67:49:43, 4.09it/s, grad_norm=1.49, loss_final=3.99, loss_mean=1.02, loss_mean_cls=3.03, proj_loss=-0.0572][2026-03-23 13:57:12] Step: 229, Training Logs: loss_final: 4.138639, loss_mean: 1.001511, proj_loss: -0.058552, loss_mean_cls: 3.195680, grad_norm: 1.869578
231
+ Steps: 0%| | 230/1000000 [01:00<67:48:23, 4.10it/s, grad_norm=1.87, loss_final=4.14, loss_mean=1, loss_mean_cls=3.2, proj_loss=-0.0586][2026-03-23 13:57:12] Step: 230, Training Logs: loss_final: 3.587549, loss_mean: 1.026086, proj_loss: -0.057275, loss_mean_cls: 2.618739, grad_norm: 1.361713
232
+ Steps: 0%| | 231/1000000 [01:00<67:50:24, 4.09it/s, grad_norm=1.36, loss_final=3.59, loss_mean=1.03, loss_mean_cls=2.62, proj_loss=-0.0573][2026-03-23 13:57:12] Step: 231, Training Logs: loss_final: 4.299013, loss_mean: 1.019667, proj_loss: -0.057679, loss_mean_cls: 3.337025, grad_norm: 2.507548
233
+ Steps: 0%| | 232/1000000 [01:01<67:49:48, 4.09it/s, grad_norm=2.51, loss_final=4.3, loss_mean=1.02, loss_mean_cls=3.34, proj_loss=-0.0577][2026-03-23 13:57:12] Step: 232, Training Logs: loss_final: 3.918818, loss_mean: 1.041992, proj_loss: -0.057748, loss_mean_cls: 2.934574, grad_norm: 1.768478
234
+ Steps: 0%| | 233/1000000 [01:01<67:48:26, 4.10it/s, grad_norm=1.77, loss_final=3.92, loss_mean=1.04, loss_mean_cls=2.93, proj_loss=-0.0577][2026-03-23 13:57:13] Step: 233, Training Logs: loss_final: 4.135336, loss_mean: 1.012130, proj_loss: -0.055806, loss_mean_cls: 3.179013, grad_norm: 2.010638
235
+ Steps: 0%| | 234/1000000 [01:01<67:48:16, 4.10it/s, grad_norm=2.01, loss_final=4.14, loss_mean=1.01, loss_mean_cls=3.18, proj_loss=-0.0558][2026-03-23 13:57:13] Step: 234, Training Logs: loss_final: 4.622794, loss_mean: 1.008856, proj_loss: -0.059102, loss_mean_cls: 3.673040, grad_norm: 1.824655
236
+ Steps: 0%| | 235/1000000 [01:01<67:48:42, 4.10it/s, grad_norm=1.82, loss_final=4.62, loss_mean=1.01, loss_mean_cls=3.67, proj_loss=-0.0591][2026-03-23 13:57:13] Step: 235, Training Logs: loss_final: 4.330180, loss_mean: 1.019101, proj_loss: -0.054455, loss_mean_cls: 3.365534, grad_norm: 1.865192
237
+ Steps: 0%| | 236/1000000 [01:02<67:47:59, 4.10it/s, grad_norm=1.87, loss_final=4.33, loss_mean=1.02, loss_mean_cls=3.37, proj_loss=-0.0545][2026-03-23 13:57:13] Step: 236, Training Logs: loss_final: 4.471676, loss_mean: 1.005463, proj_loss: -0.057791, loss_mean_cls: 3.524004, grad_norm: 1.992857
238
+ Steps: 0%| | 237/1000000 [01:02<67:48:33, 4.10it/s, grad_norm=1.99, loss_final=4.47, loss_mean=1.01, loss_mean_cls=3.52, proj_loss=-0.0578][2026-03-23 13:57:14] Step: 237, Training Logs: loss_final: 4.708742, loss_mean: 1.007163, proj_loss: -0.057294, loss_mean_cls: 3.758873, grad_norm: 1.696959
239
+ Steps: 0%| | 238/1000000 [01:02<67:48:46, 4.10it/s, grad_norm=1.7, loss_final=4.71, loss_mean=1.01, loss_mean_cls=3.76, proj_loss=-0.0573][2026-03-23 13:57:14] Step: 238, Training Logs: loss_final: 4.700453, loss_mean: 1.011082, proj_loss: -0.057899, loss_mean_cls: 3.747270, grad_norm: 2.577658
240
+ Steps: 0%| | 239/1000000 [01:02<67:48:18, 4.10it/s, grad_norm=2.58, loss_final=4.7, loss_mean=1.01, loss_mean_cls=3.75, proj_loss=-0.0579][2026-03-23 13:57:14] Step: 239, Training Logs: loss_final: 4.519000, loss_mean: 0.999633, proj_loss: -0.058458, loss_mean_cls: 3.577825, grad_norm: 2.147466
241
+ Steps: 0%| | 240/1000000 [01:03<67:48:55, 4.10it/s, grad_norm=2.15, loss_final=4.52, loss_mean=1, loss_mean_cls=3.58, proj_loss=-0.0585][2026-03-23 13:57:14] Step: 240, Training Logs: loss_final: 3.874477, loss_mean: 1.005100, proj_loss: -0.059289, loss_mean_cls: 2.928666, grad_norm: 2.779819
242
+ Steps: 0%| | 241/1000000 [01:03<67:50:12, 4.09it/s, grad_norm=2.78, loss_final=3.87, loss_mean=1.01, loss_mean_cls=2.93, proj_loss=-0.0593][2026-03-23 13:57:15] Step: 241, Training Logs: loss_final: 4.244301, loss_mean: 0.978814, proj_loss: -0.057319, loss_mean_cls: 3.322806, grad_norm: 1.934856
243
+ Steps: 0%| | 242/1000000 [01:03<67:50:20, 4.09it/s, grad_norm=1.93, loss_final=4.24, loss_mean=0.979, loss_mean_cls=3.32, proj_loss=-0.0573][2026-03-23 13:57:15] Step: 242, Training Logs: loss_final: 4.099132, loss_mean: 1.017777, proj_loss: -0.056814, loss_mean_cls: 3.138168, grad_norm: 2.223382
244
+ Steps: 0%| | 243/1000000 [01:03<67:49:22, 4.09it/s, grad_norm=2.22, loss_final=4.1, loss_mean=1.02, loss_mean_cls=3.14, proj_loss=-0.0568][2026-03-23 13:57:15] Step: 243, Training Logs: loss_final: 3.643645, loss_mean: 1.043657, proj_loss: -0.057582, loss_mean_cls: 2.657569, grad_norm: 1.876094
245
+ Steps: 0%| | 244/1000000 [01:04<67:48:22, 4.10it/s, grad_norm=1.88, loss_final=3.64, loss_mean=1.04, loss_mean_cls=2.66, proj_loss=-0.0576][2026-03-23 13:57:15] Step: 244, Training Logs: loss_final: 3.791383, loss_mean: 1.039013, proj_loss: -0.057821, loss_mean_cls: 2.810191, grad_norm: 2.101933
246
+ Steps: 0%| | 245/1000000 [01:04<67:48:03, 4.10it/s, grad_norm=2.1, loss_final=3.79, loss_mean=1.04, loss_mean_cls=2.81, proj_loss=-0.0578][2026-03-23 13:57:16] Step: 245, Training Logs: loss_final: 4.122998, loss_mean: 1.003735, proj_loss: -0.056978, loss_mean_cls: 3.176242, grad_norm: 1.748520
247
+ Steps: 0%| | 246/1000000 [01:04<67:47:44, 4.10it/s, grad_norm=1.75, loss_final=4.12, loss_mean=1, loss_mean_cls=3.18, proj_loss=-0.057][2026-03-23 13:57:16] Step: 246, Training Logs: loss_final: 3.975110, loss_mean: 1.031366, proj_loss: -0.058318, loss_mean_cls: 3.002061, grad_norm: 2.088799
248
+ Steps: 0%| | 247/1000000 [01:04<67:49:45, 4.09it/s, grad_norm=2.09, loss_final=3.98, loss_mean=1.03, loss_mean_cls=3, proj_loss=-0.0583][2026-03-23 13:57:16] Step: 247, Training Logs: loss_final: 3.795434, loss_mean: 1.016577, proj_loss: -0.059947, loss_mean_cls: 2.838804, grad_norm: 1.548398
249
+ Steps: 0%| | 248/1000000 [01:05<67:48:52, 4.10it/s, grad_norm=1.55, loss_final=3.8, loss_mean=1.02, loss_mean_cls=2.84, proj_loss=-0.0599][2026-03-23 13:57:16] Step: 248, Training Logs: loss_final: 3.547227, loss_mean: 1.007156, proj_loss: -0.057843, loss_mean_cls: 2.597914, grad_norm: 2.097709
250
+ Steps: 0%| | 249/1000000 [01:05<67:47:51, 4.10it/s, grad_norm=2.1, loss_final=3.55, loss_mean=1.01, loss_mean_cls=2.6, proj_loss=-0.0578][2026-03-23 13:57:17] Step: 249, Training Logs: loss_final: 3.670953, loss_mean: 1.005734, proj_loss: -0.056730, loss_mean_cls: 2.721948, grad_norm: 1.993682
251
+ Steps: 0%| | 250/1000000 [01:05<67:48:32, 4.10it/s, grad_norm=1.99, loss_final=3.67, loss_mean=1.01, loss_mean_cls=2.72, proj_loss=-0.0567][2026-03-23 13:57:17] Step: 250, Training Logs: loss_final: 3.783008, loss_mean: 1.009454, proj_loss: -0.057429, loss_mean_cls: 2.830983, grad_norm: 2.015732
252
+ Steps: 0%| | 251/1000000 [01:05<67:49:52, 4.09it/s, grad_norm=2.02, loss_final=3.78, loss_mean=1.01, loss_mean_cls=2.83, proj_loss=-0.0574][2026-03-23 13:57:17] Step: 251, Training Logs: loss_final: 3.917508, loss_mean: 1.015374, proj_loss: -0.057006, loss_mean_cls: 2.959140, grad_norm: 1.911892
253
+ Steps: 0%| | 252/1000000 [01:06<67:48:52, 4.10it/s, grad_norm=1.91, loss_final=3.92, loss_mean=1.02, loss_mean_cls=2.96, proj_loss=-0.057][2026-03-23 13:57:17] Step: 252, Training Logs: loss_final: 4.345321, loss_mean: 0.999146, proj_loss: -0.059442, loss_mean_cls: 3.405616, grad_norm: 2.544859
254
+ Steps: 0%| | 253/1000000 [01:06<67:48:39, 4.10it/s, grad_norm=2.54, loss_final=4.35, loss_mean=0.999, loss_mean_cls=3.41, proj_loss=-0.0594][2026-03-23 13:57:18] Step: 253, Training Logs: loss_final: 3.264407, loss_mean: 1.012033, proj_loss: -0.060338, loss_mean_cls: 2.312712, grad_norm: 2.002759
255
+ Steps: 0%| | 254/1000000 [01:06<67:47:32, 4.10it/s, grad_norm=2, loss_final=3.26, loss_mean=1.01, loss_mean_cls=2.31, proj_loss=-0.0603][2026-03-23 13:57:18] Step: 254, Training Logs: loss_final: 3.680009, loss_mean: 1.009429, proj_loss: -0.057669, loss_mean_cls: 2.728250, grad_norm: 2.470939
256
+ Steps: 0%| | 255/1000000 [01:06<67:46:55, 4.10it/s, grad_norm=2.47, loss_final=3.68, loss_mean=1.01, loss_mean_cls=2.73, proj_loss=-0.0577][2026-03-23 13:57:18] Step: 255, Training Logs: loss_final: 3.617963, loss_mean: 1.040573, proj_loss: -0.058244, loss_mean_cls: 2.635633, grad_norm: 2.220286
257
+ Steps: 0%| | 256/1000000 [01:07<67:48:18, 4.10it/s, grad_norm=2.22, loss_final=3.62, loss_mean=1.04, loss_mean_cls=2.64, proj_loss=-0.0582][2026-03-23 13:57:18] Step: 256, Training Logs: loss_final: 4.080931, loss_mean: 0.997066, proj_loss: -0.058532, loss_mean_cls: 3.142397, grad_norm: 2.140220
258
+ Steps: 0%| | 257/1000000 [01:07<67:47:49, 4.10it/s, grad_norm=2.14, loss_final=4.08, loss_mean=0.997, loss_mean_cls=3.14, proj_loss=-0.0585][2026-03-23 13:57:19] Step: 257, Training Logs: loss_final: 3.699068, loss_mean: 1.000059, proj_loss: -0.058102, loss_mean_cls: 2.757111, grad_norm: 1.981827
259
+ Steps: 0%| | 258/1000000 [01:07<67:46:24, 4.10it/s, grad_norm=1.98, loss_final=3.7, loss_mean=1, loss_mean_cls=2.76, proj_loss=-0.0581][2026-03-23 13:57:19] Step: 258, Training Logs: loss_final: 3.992317, loss_mean: 1.029925, proj_loss: -0.058207, loss_mean_cls: 3.020598, grad_norm: 2.269195
260
+ Steps: 0%| | 259/1000000 [01:07<67:46:06, 4.10it/s, grad_norm=2.27, loss_final=3.99, loss_mean=1.03, loss_mean_cls=3.02, proj_loss=-0.0582][2026-03-23 13:57:19] Step: 259, Training Logs: loss_final: 3.969361, loss_mean: 1.005176, proj_loss: -0.055455, loss_mean_cls: 3.019640, grad_norm: 2.430135
261
+ Steps: 0%| | 260/1000000 [01:08<67:47:12, 4.10it/s, grad_norm=2.43, loss_final=3.97, loss_mean=1.01, loss_mean_cls=3.02, proj_loss=-0.0555][2026-03-23 13:57:19] Step: 260, Training Logs: loss_final: 3.946629, loss_mean: 0.986893, proj_loss: -0.055241, loss_mean_cls: 3.014977, grad_norm: 2.162997
262
+ Steps: 0%| | 261/1000000 [01:08<79:04:37, 3.51it/s, grad_norm=2.16, loss_final=3.95, loss_mean=0.987, loss_mean_cls=3.01, proj_loss=-0.0552][2026-03-23 13:57:20] Step: 261, Training Logs: loss_final: 3.920058, loss_mean: 0.984585, proj_loss: -0.057787, loss_mean_cls: 2.993260, grad_norm: 2.114045
263
+ Steps: 0%| | 262/1000000 [01:08<78:26:04, 3.54it/s, grad_norm=2.11, loss_final=3.92, loss_mean=0.985, loss_mean_cls=2.99, proj_loss=-0.0578][2026-03-23 13:57:20] Step: 262, Training Logs: loss_final: 4.358081, loss_mean: 0.977391, proj_loss: -0.057046, loss_mean_cls: 3.437736, grad_norm: 2.065785
264
+ Steps: 0%| | 263/1000000 [01:08<75:13:43, 3.69it/s, grad_norm=2.07, loss_final=4.36, loss_mean=0.977, loss_mean_cls=3.44, proj_loss=-0.057][2026-03-23 13:57:20] Step: 263, Training Logs: loss_final: 3.840959, loss_mean: 0.999257, proj_loss: -0.056460, loss_mean_cls: 2.898163, grad_norm: 2.413785
265
+ Steps: 0%| | 264/1000000 [01:09<72:59:25, 3.80it/s, grad_norm=2.41, loss_final=3.84, loss_mean=0.999, loss_mean_cls=2.9, proj_loss=-0.0565][2026-03-23 13:57:20] Step: 264, Training Logs: loss_final: 3.940903, loss_mean: 1.023376, proj_loss: -0.057920, loss_mean_cls: 2.975446, grad_norm: 1.844397
266
+ Steps: 0%| | 265/1000000 [01:09<71:26:14, 3.89it/s, grad_norm=1.84, loss_final=3.94, loss_mean=1.02, loss_mean_cls=2.98, proj_loss=-0.0579][2026-03-23 13:57:21] Step: 265, Training Logs: loss_final: 4.206155, loss_mean: 0.981618, proj_loss: -0.059432, loss_mean_cls: 3.283970, grad_norm: 2.069561
267
+ Steps: 0%| | 266/1000000 [01:09<70:20:03, 3.95it/s, grad_norm=2.07, loss_final=4.21, loss_mean=0.982, loss_mean_cls=3.28, proj_loss=-0.0594][2026-03-23 13:57:21] Step: 266, Training Logs: loss_final: 3.804409, loss_mean: 0.990602, proj_loss: -0.060707, loss_mean_cls: 2.874513, grad_norm: 1.578623
268
+ Steps: 0%| | 267/1000000 [01:09<69:36:48, 3.99it/s, grad_norm=1.58, loss_final=3.8, loss_mean=0.991, loss_mean_cls=2.87, proj_loss=-0.0607][2026-03-23 13:57:21] Step: 267, Training Logs: loss_final: 4.086649, loss_mean: 1.005726, proj_loss: -0.058538, loss_mean_cls: 3.139461, grad_norm: 2.692237
269
+ Steps: 0%| | 268/1000000 [01:10<69:04:32, 4.02it/s, grad_norm=2.69, loss_final=4.09, loss_mean=1.01, loss_mean_cls=3.14, proj_loss=-0.0585][2026-03-23 13:57:21] Step: 268, Training Logs: loss_final: 3.484965, loss_mean: 1.013161, proj_loss: -0.056986, loss_mean_cls: 2.528790, grad_norm: 2.118618
270
+ Steps: 0%| | 269/1000000 [01:10<69:49:28, 3.98it/s, grad_norm=2.12, loss_final=3.48, loss_mean=1.01, loss_mean_cls=2.53, proj_loss=-0.057][2026-03-23 13:57:22] Step: 269, Training Logs: loss_final: 3.312028, loss_mean: 1.008822, proj_loss: -0.058315, loss_mean_cls: 2.361521, grad_norm: 1.997808
271
+ Steps: 0%| | 270/1000000 [01:10<68:55:10, 4.03it/s, grad_norm=2, loss_final=3.31, loss_mean=1.01, loss_mean_cls=2.36, proj_loss=-0.0583][2026-03-23 13:57:22] Step: 270, Training Logs: loss_final: 3.996657, loss_mean: 1.009214, proj_loss: -0.058962, loss_mean_cls: 3.046405, grad_norm: 2.755926
272
+ Steps: 0%| | 271/1000000 [01:10<68:39:18, 4.04it/s, grad_norm=2.76, loss_final=4, loss_mean=1.01, loss_mean_cls=3.05, proj_loss=-0.059][2026-03-23 13:57:22] Step: 271, Training Logs: loss_final: 4.391388, loss_mean: 1.024230, proj_loss: -0.058988, loss_mean_cls: 3.426146, grad_norm: 2.096098
273
+ Steps: 0%| | 272/1000000 [01:11<68:23:53, 4.06it/s, grad_norm=2.1, loss_final=4.39, loss_mean=1.02, loss_mean_cls=3.43, proj_loss=-0.059][2026-03-23 13:57:22] Step: 272, Training Logs: loss_final: 3.920087, loss_mean: 1.024717, proj_loss: -0.059831, loss_mean_cls: 2.955201, grad_norm: 2.465345
274
+ Steps: 0%| | 273/1000000 [01:11<68:12:32, 4.07it/s, grad_norm=2.47, loss_final=3.92, loss_mean=1.02, loss_mean_cls=2.96, proj_loss=-0.0598][2026-03-23 13:57:23] Step: 273, Training Logs: loss_final: 3.677792, loss_mean: 1.019457, proj_loss: -0.056399, loss_mean_cls: 2.714733, grad_norm: 2.047685
275
+ Steps: 0%| | 274/1000000 [01:11<68:06:48, 4.08it/s, grad_norm=2.05, loss_final=3.68, loss_mean=1.02, loss_mean_cls=2.71, proj_loss=-0.0564][2026-03-23 13:57:23] Step: 274, Training Logs: loss_final: 3.611300, loss_mean: 1.009368, proj_loss: -0.057628, loss_mean_cls: 2.659561, grad_norm: 1.780220
276
+ Steps: 0%| | 275/1000000 [01:11<68:01:32, 4.08it/s, grad_norm=1.78, loss_final=3.61, loss_mean=1.01, loss_mean_cls=2.66, proj_loss=-0.0576][2026-03-23 13:57:23] Step: 275, Training Logs: loss_final: 3.762297, loss_mean: 1.000715, proj_loss: -0.057146, loss_mean_cls: 2.818728, grad_norm: 2.212818
277
+ Steps: 0%| | 276/1000000 [01:12<67:56:22, 4.09it/s, grad_norm=2.21, loss_final=3.76, loss_mean=1, loss_mean_cls=2.82, proj_loss=-0.0571][2026-03-23 13:57:23] Step: 276, Training Logs: loss_final: 4.110180, loss_mean: 1.001855, proj_loss: -0.058339, loss_mean_cls: 3.166663, grad_norm: 3.446807
278
+ Steps: 0%| | 277/1000000 [01:12<67:54:20, 4.09it/s, grad_norm=3.45, loss_final=4.11, loss_mean=1, loss_mean_cls=3.17, proj_loss=-0.0583][2026-03-23 13:57:24] Step: 277, Training Logs: loss_final: 3.589436, loss_mean: 1.014081, proj_loss: -0.058753, loss_mean_cls: 2.634108, grad_norm: 1.866073
279
+ Steps: 0%| | 278/1000000 [01:12<67:51:55, 4.09it/s, grad_norm=1.87, loss_final=3.59, loss_mean=1.01, loss_mean_cls=2.63, proj_loss=-0.0588][2026-03-23 13:57:24] Step: 278, Training Logs: loss_final: 3.909306, loss_mean: 1.036322, proj_loss: -0.057491, loss_mean_cls: 2.930474, grad_norm: 3.416674
280
+ Steps: 0%| | 279/1000000 [01:12<67:51:48, 4.09it/s, grad_norm=3.42, loss_final=3.91, loss_mean=1.04, loss_mean_cls=2.93, proj_loss=-0.0575][2026-03-23 13:57:24] Step: 279, Training Logs: loss_final: 3.635005, loss_mean: 1.021109, proj_loss: -0.058334, loss_mean_cls: 2.672229, grad_norm: 2.773069
281
+ Steps: 0%| | 280/1000000 [01:13<67:50:10, 4.09it/s, grad_norm=2.77, loss_final=3.64, loss_mean=1.02, loss_mean_cls=2.67, proj_loss=-0.0583][2026-03-23 13:57:24] Step: 280, Training Logs: loss_final: 4.731202, loss_mean: 0.976675, proj_loss: -0.056143, loss_mean_cls: 3.810670, grad_norm: 2.045803
282
+ Steps: 0%| | 281/1000000 [01:13<67:48:49, 4.10it/s, grad_norm=2.05, loss_final=4.73, loss_mean=0.977, loss_mean_cls=3.81, proj_loss=-0.0561][2026-03-23 13:57:25] Step: 281, Training Logs: loss_final: 4.352121, loss_mean: 1.020213, proj_loss: -0.057039, loss_mean_cls: 3.388947, grad_norm: 3.085879
283
+ Steps: 0%| | 282/1000000 [01:13<67:48:21, 4.10it/s, grad_norm=3.09, loss_final=4.35, loss_mean=1.02, loss_mean_cls=3.39, proj_loss=-0.057][2026-03-23 13:57:25] Step: 282, Training Logs: loss_final: 3.645599, loss_mean: 1.009737, proj_loss: -0.056890, loss_mean_cls: 2.692752, grad_norm: 2.743040
284
+ Steps: 0%| | 283/1000000 [01:13<67:47:40, 4.10it/s, grad_norm=2.74, loss_final=3.65, loss_mean=1.01, loss_mean_cls=2.69, proj_loss=-0.0569][2026-03-23 13:57:25] Step: 283, Training Logs: loss_final: 3.621672, loss_mean: 1.014742, proj_loss: -0.058281, loss_mean_cls: 2.665210, grad_norm: 2.332768
285
+ Steps: 0%| | 284/1000000 [01:14<67:48:17, 4.10it/s, grad_norm=2.33, loss_final=3.62, loss_mean=1.01, loss_mean_cls=2.67, proj_loss=-0.0583][2026-03-23 13:57:25] Step: 284, Training Logs: loss_final: 3.634438, loss_mean: 1.014846, proj_loss: -0.058438, loss_mean_cls: 2.678030, grad_norm: 2.844196
286
+ Steps: 0%| | 285/1000000 [01:14<67:48:15, 4.10it/s, grad_norm=2.84, loss_final=3.63, loss_mean=1.01, loss_mean_cls=2.68, proj_loss=-0.0584][2026-03-23 13:57:26] Step: 285, Training Logs: loss_final: 3.942388, loss_mean: 0.986733, proj_loss: -0.056891, loss_mean_cls: 3.012546, grad_norm: 2.134713
287
+ Steps: 0%| | 286/1000000 [01:14<67:47:40, 4.10it/s, grad_norm=2.13, loss_final=3.94, loss_mean=0.987, loss_mean_cls=3.01, proj_loss=-0.0569][2026-03-23 13:57:26] Step: 286, Training Logs: loss_final: 3.926682, loss_mean: 1.011870, proj_loss: -0.057176, loss_mean_cls: 2.971987, grad_norm: 2.768237
288
+ Steps: 0%| | 287/1000000 [01:14<67:46:57, 4.10it/s, grad_norm=2.77, loss_final=3.93, loss_mean=1.01, loss_mean_cls=2.97, proj_loss=-0.0572][2026-03-23 13:57:26] Step: 287, Training Logs: loss_final: 4.047902, loss_mean: 1.002864, proj_loss: -0.059520, loss_mean_cls: 3.104558, grad_norm: 2.354553
289
+ Steps: 0%| | 287/1000000 [01:14<67:46:57, 4.10it/s, grad_norm=2.35, loss_final=4.05, loss_mean=1, loss_mean_cls=3.1, proj_loss=-0.0595]
REG/wandb/run-20260323_135607-zue1y2ba/files/requirements.txt ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dill==0.3.8
2
+ mkl-service==2.4.0
3
+ mpmath==1.3.0
4
+ typing_extensions==4.12.2
5
+ urllib3==2.3.0
6
+ torch==2.5.1
7
+ ptyprocess==0.7.0
8
+ traitlets==5.14.3
9
+ pyasn1==0.6.1
10
+ opencv-python-headless==4.12.0.88
11
+ nest-asyncio==1.6.0
12
+ kiwisolver==1.4.8
13
+ click==8.2.1
14
+ fire==0.7.1
15
+ diffusers==0.35.1
16
+ accelerate==1.7.0
17
+ ipykernel==6.29.5
18
+ peft==0.17.1
19
+ attrs==24.3.0
20
+ six==1.17.0
21
+ numpy==2.0.1
22
+ yarl==1.18.0
23
+ huggingface_hub==0.34.4
24
+ Bottleneck==1.4.2
25
+ numexpr==2.11.0
26
+ dataclasses==0.6
27
+ typing-inspection==0.4.1
28
+ safetensors==0.5.3
29
+ pyparsing==3.2.3
30
+ psutil==7.0.0
31
+ imageio==2.37.0
32
+ debugpy==1.8.14
33
+ cycler==0.12.1
34
+ pyasn1_modules==0.4.2
35
+ matplotlib-inline==0.1.7
36
+ matplotlib==3.10.3
37
+ jedi==0.19.2
38
+ tokenizers==0.21.2
39
+ seaborn==0.13.2
40
+ timm==1.0.15
41
+ aiohappyeyeballs==2.6.1
42
+ hf-xet==1.1.8
43
+ multidict==6.1.0
44
+ tqdm==4.67.1
45
+ wheel==0.45.1
46
+ simsimd==6.5.1
47
+ sentencepiece==0.2.1
48
+ grpcio==1.74.0
49
+ asttokens==3.0.0
50
+ absl-py==2.3.1
51
+ stack-data==0.6.3
52
+ pandas==2.3.0
53
+ importlib_metadata==8.7.0
54
+ pytorch-image-generation-metrics==0.6.1
55
+ frozenlist==1.5.0
56
+ MarkupSafe==3.0.2
57
+ setuptools==78.1.1
58
+ multiprocess==0.70.15
59
+ pip==25.1
60
+ requests==2.32.3
61
+ mkl_random==1.2.8
62
+ tensorboard-plugin-wit==1.8.1
63
+ ExifRead-nocycle==3.0.1
64
+ webdataset==0.2.111
65
+ threadpoolctl==3.6.0
66
+ pyarrow==21.0.0
67
+ executing==2.2.0
68
+ decorator==5.2.1
69
+ contourpy==1.3.2
70
+ annotated-types==0.7.0
71
+ scikit-learn==1.7.1
72
+ jupyter_client==8.6.3
73
+ albumentations==1.4.24
74
+ wandb==0.25.0
75
+ certifi==2025.8.3
76
+ idna==3.7
77
+ xxhash==3.5.0
78
+ Jinja2==3.1.6
79
+ python-dateutil==2.9.0.post0
80
+ aiosignal==1.4.0
81
+ triton==3.1.0
82
+ torchvision==0.20.1
83
+ stringzilla==3.12.6
84
+ pure_eval==0.2.3
85
+ braceexpand==0.1.7
86
+ zipp==3.22.0
87
+ oauthlib==3.3.1
88
+ Markdown==3.8.2
89
+ fsspec==2025.3.0
90
+ fonttools==4.58.2
91
+ comm==0.2.2
92
+ ipython==9.3.0
93
+ img2dataset==1.47.0
94
+ networkx==3.4.2
95
+ PySocks==1.7.1
96
+ tzdata==2025.2
97
+ smmap==5.0.2
98
+ mkl_fft==1.3.11
99
+ sentry-sdk==2.29.1
100
+ Pygments==2.19.1
101
+ pexpect==4.9.0
102
+ ftfy==6.3.1
103
+ einops==0.8.1
104
+ requests-oauthlib==2.0.0
105
+ gitdb==4.0.12
106
+ albucore==0.0.23
107
+ torchdiffeq==0.2.5
108
+ GitPython==3.1.44
109
+ bitsandbytes==0.47.0
110
+ pytorch-fid==0.3.0
111
+ clean-fid==0.1.35
112
+ pytorch-gan-metrics==0.5.4
113
+ Brotli==1.0.9
114
+ charset-normalizer==3.3.2
115
+ gmpy2==2.2.1
116
+ pillow==11.1.0
117
+ PyYAML==6.0.2
118
+ tornado==6.5.1
119
+ termcolor==3.1.0
120
+ setproctitle==1.3.6
121
+ scipy==1.15.3
122
+ regex==2024.11.6
123
+ protobuf==6.31.1
124
+ platformdirs==4.3.8
125
+ joblib==1.5.1
126
+ cachetools==4.2.4
127
+ ipython_pygments_lexers==1.1.1
128
+ google-auth==1.35.0
129
+ transformers==4.53.2
130
+ torch-fidelity==0.3.0
131
+ tensorboard==2.4.0
132
+ filelock==3.17.0
133
+ packaging==25.0
134
+ propcache==0.3.1
135
+ pytz==2025.2
136
+ aiohttp==3.11.10
137
+ wcwidth==0.2.13
138
+ clip==0.2.0
139
+ Werkzeug==3.1.3
140
+ tensorboard-data-server==0.6.1
141
+ sympy==1.13.1
142
+ pyzmq==26.4.0
143
+ pydantic_core==2.33.2
144
+ prompt_toolkit==3.0.51
145
+ parso==0.8.4
146
+ docker-pycreds==0.4.0
147
+ rsa==4.9.1
148
+ pydantic==2.11.5
149
+ jupyter_core==5.8.1
150
+ google-auth-oauthlib==0.4.6
151
+ datasets==4.0.0
152
+ torch-tb-profiler==0.4.3
153
+ autocommand==2.2.2
154
+ backports.tarfile==1.2.0
155
+ importlib_metadata==8.0.0
156
+ jaraco.collections==5.1.0
157
+ jaraco.context==5.3.0
158
+ jaraco.functools==4.0.1
159
+ more-itertools==10.3.0
160
+ packaging==24.2
161
+ platformdirs==4.2.2
162
+ typeguard==4.3.0
163
+ inflect==7.3.1
164
+ jaraco.text==3.12.1
165
+ tomli==2.0.1
166
+ typing_extensions==4.12.2
167
+ wheel==0.45.1
168
+ zipp==3.19.2
REG/wandb/run-20260323_135607-zue1y2ba/files/wandb-metadata.json ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.12.9",
4
+ "startedAt": "2026-03-23T05:56:07.858187Z",
5
+ "args": [
6
+ "--report-to",
7
+ "wandb",
8
+ "--allow-tf32",
9
+ "--mixed-precision",
10
+ "bf16",
11
+ "--seed",
12
+ "0",
13
+ "--path-type",
14
+ "linear",
15
+ "--prediction",
16
+ "v",
17
+ "--weighting",
18
+ "uniform",
19
+ "--model",
20
+ "SiT-XL/2",
21
+ "--enc-type",
22
+ "dinov2-vit-b",
23
+ "--encoder-depth",
24
+ "8",
25
+ "--proj-coeff",
26
+ "0.5",
27
+ "--output-dir",
28
+ "exps",
29
+ "--exp-name",
30
+ "jsflow-experiment-0.75",
31
+ "--batch-size",
32
+ "256",
33
+ "--data-dir",
34
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
35
+ "--semantic-features-dir",
36
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
37
+ "--learning-rate",
38
+ "0.00005",
39
+ "--t-c",
40
+ "0.75",
41
+ "--cls",
42
+ "0.2",
43
+ "--ot-cls"
44
+ ],
45
+ "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
46
+ "codePath": "train.py",
47
+ "codePathLocal": "train.py",
48
+ "git": {
49
+ "remote": "https://github.com/Martinser/REG.git",
50
+ "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
51
+ },
52
+ "email": "2365972933@qq.com",
53
+ "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
54
+ "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
55
+ "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
56
+ "cpu_count": 96,
57
+ "cpu_count_logical": 192,
58
+ "gpu": "NVIDIA H100 80GB HBM3",
59
+ "gpu_count": 4,
60
+ "disk": {
61
+ "/": {
62
+ "total": "3838880616448",
63
+ "used": "357568126976"
64
+ }
65
+ },
66
+ "memory": {
67
+ "total": "2164115296256"
68
+ },
69
+ "gpu_nvidia": [
70
+ {
71
+ "name": "NVIDIA H100 80GB HBM3",
72
+ "memoryTotal": "85520809984",
73
+ "cudaCores": 16896,
74
+ "architecture": "Hopper",
75
+ "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
76
+ },
77
+ {
78
+ "name": "NVIDIA H100 80GB HBM3",
79
+ "memoryTotal": "85520809984",
80
+ "cudaCores": 16896,
81
+ "architecture": "Hopper",
82
+ "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
83
+ },
84
+ {
85
+ "name": "NVIDIA H100 80GB HBM3",
86
+ "memoryTotal": "85520809984",
87
+ "cudaCores": 16896,
88
+ "architecture": "Hopper",
89
+ "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
90
+ },
91
+ {
92
+ "name": "NVIDIA H100 80GB HBM3",
93
+ "memoryTotal": "85520809984",
94
+ "cudaCores": 16896,
95
+ "architecture": "Hopper",
96
+ "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
97
+ }
98
+ ],
99
+ "cudaVersion": "13.0",
100
+ "writerId": "l4vui4vfnl881ctol25fj9y70t6im9l9"
101
+ }
REG/wandb/run-20260323_135607-zue1y2ba/logs/debug-internal.log ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"time":"2026-03-23T13:56:08.183465712+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
2
+ {"time":"2026-03-23T13:56:10.661719633+08:00","level":"INFO","msg":"stream: created new stream","id":"zue1y2ba"}
3
+ {"time":"2026-03-23T13:56:10.661931952+08:00","level":"INFO","msg":"handler: started","stream_id":"zue1y2ba"}
4
+ {"time":"2026-03-23T13:56:10.662874633+08:00","level":"INFO","msg":"stream: started","id":"zue1y2ba"}
5
+ {"time":"2026-03-23T13:56:10.662895027+08:00","level":"INFO","msg":"writer: started","stream_id":"zue1y2ba"}
6
+ {"time":"2026-03-23T13:56:10.662918583+08:00","level":"INFO","msg":"sender: started","stream_id":"zue1y2ba"}
REG/wandb/run-20260323_135607-zue1y2ba/logs/debug.log ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
2
+ 2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_setup.py:_flush():81] Configure stats pid to 397944
3
+ 2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135607-zue1y2ba/logs/debug.log
5
+ 2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135607-zue1y2ba/logs/debug-internal.log
6
+ 2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:init():844] calling init triggers
7
+ 2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
8
+ config: {'_wandb': {}}
9
+ 2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:init():892] starting backend
10
+ 2026-03-23 13:56:08,167 INFO MainThread:397944 [wandb_init.py:init():895] sending inform_init request
11
+ 2026-03-23 13:56:08,180 INFO MainThread:397944 [wandb_init.py:init():903] backend started and connected
12
+ 2026-03-23 13:56:08,181 INFO MainThread:397944 [wandb_init.py:init():973] updated telemetry
13
+ 2026-03-23 13:56:08,194 INFO MainThread:397944 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
14
+ 2026-03-23 13:56:11,614 INFO MainThread:397944 [wandb_init.py:init():1042] starting run threads in backend
15
+ 2026-03-23 13:56:11,706 INFO MainThread:397944 [wandb_run.py:_console_start():2524] atexit reg
16
+ 2026-03-23 13:56:11,707 INFO MainThread:397944 [wandb_run.py:_redirect():2373] redirect: wrap_raw
17
+ 2026-03-23 13:56:11,707 INFO MainThread:397944 [wandb_run.py:_redirect():2442] Wrapping output streams.
18
+ 2026-03-23 13:56:11,707 INFO MainThread:397944 [wandb_run.py:_redirect():2465] Redirects installed.
19
+ 2026-03-23 13:56:11,712 INFO MainThread:397944 [wandb_init.py:init():1082] run started, returning control to user process
20
+ 2026-03-23 13:56:11,713 INFO MainThread:397944 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment-0.75', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.75, 'ot_cls': True}
REG/wandb/run-20260323_135841-w9holkos/files/requirements.txt ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dill==0.3.8
2
+ mkl-service==2.4.0
3
+ mpmath==1.3.0
4
+ typing_extensions==4.12.2
5
+ urllib3==2.3.0
6
+ torch==2.5.1
7
+ ptyprocess==0.7.0
8
+ traitlets==5.14.3
9
+ pyasn1==0.6.1
10
+ opencv-python-headless==4.12.0.88
11
+ nest-asyncio==1.6.0
12
+ kiwisolver==1.4.8
13
+ click==8.2.1
14
+ fire==0.7.1
15
+ diffusers==0.35.1
16
+ accelerate==1.7.0
17
+ ipykernel==6.29.5
18
+ peft==0.17.1
19
+ attrs==24.3.0
20
+ six==1.17.0
21
+ numpy==2.0.1
22
+ yarl==1.18.0
23
+ huggingface_hub==0.34.4
24
+ Bottleneck==1.4.2
25
+ numexpr==2.11.0
26
+ dataclasses==0.6
27
+ typing-inspection==0.4.1
28
+ safetensors==0.5.3
29
+ pyparsing==3.2.3
30
+ psutil==7.0.0
31
+ imageio==2.37.0
32
+ debugpy==1.8.14
33
+ cycler==0.12.1
34
+ pyasn1_modules==0.4.2
35
+ matplotlib-inline==0.1.7
36
+ matplotlib==3.10.3
37
+ jedi==0.19.2
38
+ tokenizers==0.21.2
39
+ seaborn==0.13.2
40
+ timm==1.0.15
41
+ aiohappyeyeballs==2.6.1
42
+ hf-xet==1.1.8
43
+ multidict==6.1.0
44
+ tqdm==4.67.1
45
+ wheel==0.45.1
46
+ simsimd==6.5.1
47
+ sentencepiece==0.2.1
48
+ grpcio==1.74.0
49
+ asttokens==3.0.0
50
+ absl-py==2.3.1
51
+ stack-data==0.6.3
52
+ pandas==2.3.0
53
+ importlib_metadata==8.7.0
54
+ pytorch-image-generation-metrics==0.6.1
55
+ frozenlist==1.5.0
56
+ MarkupSafe==3.0.2
57
+ setuptools==78.1.1
58
+ multiprocess==0.70.15
59
+ pip==25.1
60
+ requests==2.32.3
61
+ mkl_random==1.2.8
62
+ tensorboard-plugin-wit==1.8.1
63
+ ExifRead-nocycle==3.0.1
64
+ webdataset==0.2.111
65
+ threadpoolctl==3.6.0
66
+ pyarrow==21.0.0
67
+ executing==2.2.0
68
+ decorator==5.2.1
69
+ contourpy==1.3.2
70
+ annotated-types==0.7.0
71
+ scikit-learn==1.7.1
72
+ jupyter_client==8.6.3
73
+ albumentations==1.4.24
74
+ wandb==0.25.0
75
+ certifi==2025.8.3
76
+ idna==3.7
77
+ xxhash==3.5.0
78
+ Jinja2==3.1.6
79
+ python-dateutil==2.9.0.post0
80
+ aiosignal==1.4.0
81
+ triton==3.1.0
82
+ torchvision==0.20.1
83
+ stringzilla==3.12.6
84
+ pure_eval==0.2.3
85
+ braceexpand==0.1.7
86
+ zipp==3.22.0
87
+ oauthlib==3.3.1
88
+ Markdown==3.8.2
89
+ fsspec==2025.3.0
90
+ fonttools==4.58.2
91
+ comm==0.2.2
92
+ ipython==9.3.0
93
+ img2dataset==1.47.0
94
+ networkx==3.4.2
95
+ PySocks==1.7.1
96
+ tzdata==2025.2
97
+ smmap==5.0.2
98
+ mkl_fft==1.3.11
99
+ sentry-sdk==2.29.1
100
+ Pygments==2.19.1
101
+ pexpect==4.9.0
102
+ ftfy==6.3.1
103
+ einops==0.8.1
104
+ requests-oauthlib==2.0.0
105
+ gitdb==4.0.12
106
+ albucore==0.0.23
107
+ torchdiffeq==0.2.5
108
+ GitPython==3.1.44
109
+ bitsandbytes==0.47.0
110
+ pytorch-fid==0.3.0
111
+ clean-fid==0.1.35
112
+ pytorch-gan-metrics==0.5.4
113
+ Brotli==1.0.9
114
+ charset-normalizer==3.3.2
115
+ gmpy2==2.2.1
116
+ pillow==11.1.0
117
+ PyYAML==6.0.2
118
+ tornado==6.5.1
119
+ termcolor==3.1.0
120
+ setproctitle==1.3.6
121
+ scipy==1.15.3
122
+ regex==2024.11.6
123
+ protobuf==6.31.1
124
+ platformdirs==4.3.8
125
+ joblib==1.5.1
126
+ cachetools==4.2.4
127
+ ipython_pygments_lexers==1.1.1
128
+ google-auth==1.35.0
129
+ transformers==4.53.2
130
+ torch-fidelity==0.3.0
131
+ tensorboard==2.4.0
132
+ filelock==3.17.0
133
+ packaging==25.0
134
+ propcache==0.3.1
135
+ pytz==2025.2
136
+ aiohttp==3.11.10
137
+ wcwidth==0.2.13
138
+ clip==0.2.0
139
+ Werkzeug==3.1.3
140
+ tensorboard-data-server==0.6.1
141
+ sympy==1.13.1
142
+ pyzmq==26.4.0
143
+ pydantic_core==2.33.2
144
+ prompt_toolkit==3.0.51
145
+ parso==0.8.4
146
+ docker-pycreds==0.4.0
147
+ rsa==4.9.1
148
+ pydantic==2.11.5
149
+ jupyter_core==5.8.1
150
+ google-auth-oauthlib==0.4.6
151
+ datasets==4.0.0
152
+ torch-tb-profiler==0.4.3
153
+ autocommand==2.2.2
154
+ backports.tarfile==1.2.0
155
+ importlib_metadata==8.0.0
156
+ jaraco.collections==5.1.0
157
+ jaraco.context==5.3.0
158
+ jaraco.functools==4.0.1
159
+ more-itertools==10.3.0
160
+ packaging==24.2
161
+ platformdirs==4.2.2
162
+ typeguard==4.3.0
163
+ inflect==7.3.1
164
+ jaraco.text==3.12.1
165
+ tomli==2.0.1
166
+ typing_extensions==4.12.2
167
+ wheel==0.45.1
168
+ zipp==3.19.2
REG/wandb/run-20260323_135841-w9holkos/files/wandb-metadata.json ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.12.9",
4
+ "startedAt": "2026-03-23T05:58:41.322248Z",
5
+ "args": [
6
+ "--report-to",
7
+ "wandb",
8
+ "--allow-tf32",
9
+ "--mixed-precision",
10
+ "bf16",
11
+ "--seed",
12
+ "0",
13
+ "--path-type",
14
+ "linear",
15
+ "--prediction",
16
+ "v",
17
+ "--weighting",
18
+ "uniform",
19
+ "--model",
20
+ "SiT-XL/2",
21
+ "--enc-type",
22
+ "dinov2-vit-b",
23
+ "--encoder-depth",
24
+ "8",
25
+ "--proj-coeff",
26
+ "0.5",
27
+ "--output-dir",
28
+ "exps",
29
+ "--exp-name",
30
+ "jsflow-experiment-0.75",
31
+ "--batch-size",
32
+ "256",
33
+ "--data-dir",
34
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
35
+ "--semantic-features-dir",
36
+ "/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
37
+ "--learning-rate",
38
+ "0.00005",
39
+ "--t-c",
40
+ "0.75",
41
+ "--cls",
42
+ "0.05",
43
+ "--ot-cls"
44
+ ],
45
+ "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
46
+ "codePath": "train.py",
47
+ "codePathLocal": "train.py",
48
+ "git": {
49
+ "remote": "https://github.com/Martinser/REG.git",
50
+ "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
51
+ },
52
+ "email": "2365972933@qq.com",
53
+ "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
54
+ "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
55
+ "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
56
+ "cpu_count": 96,
57
+ "cpu_count_logical": 192,
58
+ "gpu": "NVIDIA H100 80GB HBM3",
59
+ "gpu_count": 4,
60
+ "disk": {
61
+ "/": {
62
+ "total": "3838880616448",
63
+ "used": "357568360448"
64
+ }
65
+ },
66
+ "memory": {
67
+ "total": "2164115296256"
68
+ },
69
+ "gpu_nvidia": [
70
+ {
71
+ "name": "NVIDIA H100 80GB HBM3",
72
+ "memoryTotal": "85520809984",
73
+ "cudaCores": 16896,
74
+ "architecture": "Hopper",
75
+ "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
76
+ },
77
+ {
78
+ "name": "NVIDIA H100 80GB HBM3",
79
+ "memoryTotal": "85520809984",
80
+ "cudaCores": 16896,
81
+ "architecture": "Hopper",
82
+ "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
83
+ },
84
+ {
85
+ "name": "NVIDIA H100 80GB HBM3",
86
+ "memoryTotal": "85520809984",
87
+ "cudaCores": 16896,
88
+ "architecture": "Hopper",
89
+ "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
90
+ },
91
+ {
92
+ "name": "NVIDIA H100 80GB HBM3",
93
+ "memoryTotal": "85520809984",
94
+ "cudaCores": 16896,
95
+ "architecture": "Hopper",
96
+ "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
97
+ }
98
+ ],
99
+ "cudaVersion": "13.0",
100
+ "writerId": "nlbia82zbry6kpqgagoidmc6x8szwd5d"
101
+ }
REG/wandb/run-20260323_135841-w9holkos/logs/debug-internal.log ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-03-23T13:58:41.647788404+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
2
+ {"time":"2026-03-23T13:58:42.578470875+08:00","level":"INFO","msg":"stream: created new stream","id":"w9holkos"}
3
+ {"time":"2026-03-23T13:58:42.578676113+08:00","level":"INFO","msg":"handler: started","stream_id":"w9holkos"}
4
+ {"time":"2026-03-23T13:58:42.579473589+08:00","level":"INFO","msg":"stream: started","id":"w9holkos"}
5
+ {"time":"2026-03-23T13:58:42.57951741+08:00","level":"INFO","msg":"sender: started","stream_id":"w9holkos"}
6
+ {"time":"2026-03-23T13:58:42.579478227+08:00","level":"INFO","msg":"writer: started","stream_id":"w9holkos"}
7
+ {"time":"2026-03-23T14:49:13.568442881+08:00","level":"INFO","msg":"api: retrying HTTP error","status":408,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>408 Request Timeout</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Request Timeout</h1>\n<h2>Your client has taken too long to issue its request.</h2>\n<h2></h2>\n</body></html>\n"}
8
+ {"time":"2026-03-23T14:52:15.597652411+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
9
+ {"time":"2026-03-23T14:52:26.072213509+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": write tcp 172.20.98.27:52324->35.186.228.49:443: write: broken pipe"}
10
+ {"time":"2026-03-23T17:02:52.905542765+08:00","level":"INFO","msg":"api: retrying HTTP error","status":408,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>408 Request Timeout</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Request Timeout</h1>\n<h2>Your client has taken too long to issue its request.</h2>\n<h2></h2>\n</body></html>\n"}
11
+ {"time":"2026-03-23T17:05:55.176103762+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
12
+ {"time":"2026-03-23T17:06:10.164453104+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": unexpected EOF"}
13
+ {"time":"2026-03-23T22:05:06.25355716+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:44154->35.186.228.49:443: read: connection reset by peer"}
14
+ {"time":"2026-03-23T22:05:20.791067182+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:40392->35.186.228.49:443: read: connection reset by peer"}
15
+ {"time":"2026-03-24T02:18:38.770696332+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
16
+ {"time":"2026-03-24T06:25:41.879737278+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
17
+ {"time":"2026-03-24T06:30:14.989373032+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
18
+ {"time":"2026-03-24T09:05:02.85908394+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:46722->35.186.228.49:443: read: connection reset by peer"}
19
+ {"time":"2026-03-25T04:41:04.741907157+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": unexpected EOF"}
REG/wandb/run-20260323_135841-w9holkos/logs/debug.log ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
2
+ 2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Configure stats pid to 400275
3
+ 2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135841-w9holkos/logs/debug.log
5
+ 2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135841-w9holkos/logs/debug-internal.log
6
+ 2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:init():844] calling init triggers
7
+ 2026-03-23 13:58:41,344 INFO MainThread:400275 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
8
+ config: {'_wandb': {}}
9
+ 2026-03-23 13:58:41,344 INFO MainThread:400275 [wandb_init.py:init():892] starting backend
10
+ 2026-03-23 13:58:41,630 INFO MainThread:400275 [wandb_init.py:init():895] sending inform_init request
11
+ 2026-03-23 13:58:41,643 INFO MainThread:400275 [wandb_init.py:init():903] backend started and connected
12
+ 2026-03-23 13:58:41,646 INFO MainThread:400275 [wandb_init.py:init():973] updated telemetry
13
+ 2026-03-23 13:58:41,659 INFO MainThread:400275 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
14
+ 2026-03-23 13:58:43,108 INFO MainThread:400275 [wandb_init.py:init():1042] starting run threads in backend
15
+ 2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_console_start():2524] atexit reg
16
+ 2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_redirect():2373] redirect: wrap_raw
17
+ 2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_redirect():2442] Wrapping output streams.
18
+ 2026-03-23 13:58:43,202 INFO MainThread:400275 [wandb_run.py:_redirect():2465] Redirects installed.
19
+ 2026-03-23 13:58:43,209 INFO MainThread:400275 [wandb_init.py:init():1082] run started, returning control to user process
20
+ 2026-03-23 13:58:43,210 INFO MainThread:400275 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment-0.75', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.05, 't_c': 0.75, 'ot_cls': True}