Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- REG copy/LICENSE +21 -0
- REG copy/README.md +156 -0
- REG copy/dataset.py +80 -0
- REG copy/eval.sh +52 -0
- REG copy/generate.py +243 -0
- REG copy/loss.py +102 -0
- REG copy/requirements.txt +97 -0
- REG copy/samplers.py +169 -0
- REG copy/utils.py +225 -0
- REG/LICENSE +21 -0
- REG/README.md +156 -0
- REG/dataset.py +149 -0
- REG/eval.sh +52 -0
- REG/eval_custom_0.25.log +1 -0
- REG/generate.py +227 -0
- REG/loss.py +193 -0
- REG/requirements.txt +97 -0
- REG/sample_from_checkpoint.py +611 -0
- REG/sample_from_checkpoint_ddp.py +416 -0
- REG/samplers.py +840 -0
- REG/samples.sh +15 -0
- REG/samples_0.25_new.log +43 -0
- REG/samples_0.5.log +0 -0
- REG/samples_0.75.log +0 -0
- REG/samples_0.75_new.log +46 -0
- REG/samples_ddp.sh +32 -0
- REG/train.py +708 -0
- REG/train.sh +43 -0
- REG/train_resume_tc_velocity.sh +41 -0
- REG/utils.py +225 -0
- REG/wandb/run-20260322_150022-yhxc5cgu/files/config.yaml +202 -0
- REG/wandb/run-20260322_150022-yhxc5cgu/files/wandb-summary.json +1 -0
- REG/wandb/run-20260322_150443-e3yw9ii4/files/config.yaml +202 -0
- REG/wandb/run-20260322_150443-e3yw9ii4/files/output.log +15 -0
- REG/wandb/run-20260322_150443-e3yw9ii4/files/requirements.txt +168 -0
- REG/wandb/run-20260322_150443-e3yw9ii4/files/wandb-metadata.json +101 -0
- REG/wandb/run-20260322_150443-e3yw9ii4/files/wandb-summary.json +1 -0
- REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug-internal.log +7 -0
- REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug.log +22 -0
- REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug-internal.log +6 -0
- REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug.log +20 -0
- REG/wandb/run-20260323_135607-zue1y2ba/files/output.log +289 -0
- REG/wandb/run-20260323_135607-zue1y2ba/files/requirements.txt +168 -0
- REG/wandb/run-20260323_135607-zue1y2ba/files/wandb-metadata.json +101 -0
- REG/wandb/run-20260323_135607-zue1y2ba/logs/debug-internal.log +6 -0
- REG/wandb/run-20260323_135607-zue1y2ba/logs/debug.log +20 -0
- REG/wandb/run-20260323_135841-w9holkos/files/requirements.txt +168 -0
- REG/wandb/run-20260323_135841-w9holkos/files/wandb-metadata.json +101 -0
- REG/wandb/run-20260323_135841-w9holkos/logs/debug-internal.log +19 -0
- REG/wandb/run-20260323_135841-w9holkos/logs/debug.log +20 -0
REG copy/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Sihyun Yu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
REG copy/README.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<h1 align="center">Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think (NeurIPS 2025 Oral)
|
| 3 |
+
</h1>
|
| 4 |
+
<p align="center">
|
| 5 |
+
<a href='https://github.com/Martinser' style='text-decoration: none' >Ge Wu</a><sup>1</sup> 
|
| 6 |
+
<a href='https://github.com/ShenZhang-Shin' style='text-decoration: none' >Shen Zhang</a><sup>3</sup> 
|
| 7 |
+
<a href='' style='text-decoration: none' >Ruijing Shi</a><sup>1</sup> 
|
| 8 |
+
<a href='https://shgao.site/' style='text-decoration: none' >Shanghua Gao</a><sup>4</sup> 
|
| 9 |
+
<a href='https://zhenyuanchenai.github.io/' style='text-decoration: none' >Zhenyuan Chen</a><sup>1</sup> 
|
| 10 |
+
<a href='https://scholar.google.com/citations?user=6Z66DAwAAAAJ&hl=en' style='text-decoration: none' >Lei Wang</a><sup>1</sup> 
|
| 11 |
+
<a href='https://www.zhihu.com/people/chen-zhao-wei-16-2' style='text-decoration: none' >Zhaowei Chen</a><sup>3</sup> 
|
| 12 |
+
<a href='https://gao-hongcheng.github.io/' style='text-decoration: none' >Hongcheng Gao</a><sup>5</sup> 
|
| 13 |
+
<a href='https://scholar.google.com/citations?view_op=list_works&hl=zh-CN&hl=zh-CN&user=0xP6bxcAAAAJ' style='text-decoration: none' >Yao Tang</a><sup>3</sup> 
|
| 14 |
+
<a href='https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en' style='text-decoration: none' >Jian Yang</a><sup>1</sup> 
|
| 15 |
+
<a href='https://mmcheng.net/cmm/' style='text-decoration: none' >Ming-Ming Cheng</a><sup>1,2</sup> 
|
| 16 |
+
<a href='https://implus.github.io/' style='text-decoration: none' >Xiang Li</a><sup>1,2*</sup> 
|
| 17 |
+
<p align="center">
|
| 18 |
+
$^{1}$ VCIP, CS, Nankai University, $^{2}$ NKIARI, Shenzhen Futian, $^{3}$ JIIOV Technology,
|
| 19 |
+
$^{4}$ Harvard University, $^{5}$ University of Chinese Academy of Sciences
|
| 20 |
+
<p align='center'>
|
| 21 |
+
<div align="center">
|
| 22 |
+
<a href='https://arxiv.org/abs/2507.01467v2'><img src='https://img.shields.io/badge/arXiv-2507.01467v2-brown.svg?logo=arxiv&logoColor=white'></a>
|
| 23 |
+
<a href='https://huggingface.co/Martinser/REG/tree/main'><img src='https://img.shields.io/badge/🤗-Model-blue.svg'></a>
|
| 24 |
+
<a href='https://zhuanlan.zhihu.com/p/1952346823168595518'><img src='https://img.shields.io/badge/Zhihu-chinese_article-blue.svg?logo=zhihu&logoColor=white'></a>
|
| 25 |
+
</div>
|
| 26 |
+
<p align='center'>
|
| 27 |
+
</p>
|
| 28 |
+
</p>
|
| 29 |
+
</p>
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## 🚩 Overview
|
| 33 |
+
|
| 34 |
+

|
| 35 |
+
|
| 36 |
+
REPA and its variants effectively mitigate training challenges in diffusion models by incorporating external visual representations from pretrained models, through alignment between the noisy hidden projections of denoising networks and foundational clean image representations.
|
| 37 |
+
We argue that the external alignment, which is absent during the entire denoising inference process, falls short of fully harnessing the potential of discriminative representations.
|
| 38 |
+
|
| 39 |
+
In this work, we propose a straightforward method called Representation Entanglement for Generation (REG), which entangles low-level image latents with a single high-level class token from pretrained foundation models for denoising.
|
| 40 |
+
REG acquires the capability to produce coherent image-class pairs directly from pure noise,
|
| 41 |
+
substantially improving both generation quality and training efficiency.
|
| 42 |
+
This is accomplished with negligible additional inference overhead, **requiring only one single additional token for denoising (<0.5\% increase in FLOPs and latency).**
|
| 43 |
+
The inference process concurrently reconstructs both image latents and their corresponding global semantics, where the acquired semantic knowledge actively guides and enhances the image generation process.
|
| 44 |
+
|
| 45 |
+
On ImageNet $256{\times}256$, SiT-XL/2 + REG demonstrates remarkable convergence acceleration, **achieving $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA, respectively.**
|
| 46 |
+
More impressively, SiT-L/2 + REG trained for merely 400K iterations outperforms SiT-XL/2 + REPA trained for 4M iterations ($\textbf{10}\times$ longer).
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## 📰 News
|
| 51 |
+
|
| 52 |
+
- **[2025.08.05]** We have released the pre-trained weights of REG + SiT-XL/2 in 4M (800 epochs).
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
## 📝 Results
|
| 56 |
+
|
| 57 |
+
- Performance on ImageNet $256{\times}256$ with FID=1.36 by introducing a single class token.
|
| 58 |
+
- $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA.
|
| 59 |
+
|
| 60 |
+
<div align="center">
|
| 61 |
+
<img src="fig/img.png" alt="Results">
|
| 62 |
+
</div>
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## 📋 Plan
|
| 66 |
+
- More training steps on ImageNet 256&512 and T2I.
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
## 👊 Usage
|
| 70 |
+
|
| 71 |
+
### 1. Environment setup
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
conda create -n reg python=3.10.16 -y
|
| 75 |
+
conda activate reg
|
| 76 |
+
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
|
| 77 |
+
pip install -r requirements.txt
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### 2. Dataset
|
| 81 |
+
|
| 82 |
+
#### Dataset download
|
| 83 |
+
|
| 84 |
+
Currently, we provide experiments for ImageNet. You can place the data wherever you want and specify its location via the `--data-dir` argument in the training scripts.
|
| 85 |
+
|
| 86 |
+
#### Preprocessing data
|
| 87 |
+
Please refer to the preprocessing guide. And you can directly download our processed data, ImageNet data [link](https://huggingface.co/WindATree/ImageNet-256-VAE/tree/main), and ImageNet data after VAE encoder [link]( https://huggingface.co/WindATree/vae-sd/tree/main)
|
| 88 |
+
|
| 89 |
+
### 3. Training
|
| 90 |
+
Run train.sh
|
| 91 |
+
```bash
|
| 92 |
+
bash train.sh
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
train.sh contains the following content.
|
| 96 |
+
```bash
|
| 97 |
+
accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
|
| 98 |
+
--report-to="wandb" \
|
| 99 |
+
--allow-tf32 \
|
| 100 |
+
--mixed-precision="fp16" \
|
| 101 |
+
--seed=0 \
|
| 102 |
+
--path-type="linear" \
|
| 103 |
+
--prediction="v" \
|
| 104 |
+
--weighting="uniform" \
|
| 105 |
+
--model="SiT-B/2" \
|
| 106 |
+
--enc-type="dinov2-vit-b" \
|
| 107 |
+
--proj-coeff=0.5 \
|
| 108 |
+
--encoder-depth=4 \ #SiT-L/XL use 8, SiT-B use 4
|
| 109 |
+
--output-dir="your_path" \
|
| 110 |
+
--exp-name="linear-dinov2-b-enc4" \
|
| 111 |
+
--batch-size=256 \
|
| 112 |
+
--data-dir="data_path/imagenet_vae" \
|
| 113 |
+
--cls=0.03
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Then this script will automatically create the folder in `exps` to save logs and checkpoints. You can adjust the following options:
|
| 117 |
+
|
| 118 |
+
- `--models`: `[SiT-B/2, SiT-L/2, SiT-XL/2]`
|
| 119 |
+
- `--enc-type`: `[dinov2-vit-b, clip-vit-L]`
|
| 120 |
+
- `--proj-coeff`: Any values larger than 0
|
| 121 |
+
- `--encoder-depth`: Any values between 1 to the depth of the model
|
| 122 |
+
- `--output-dir`: Any directory that you want to save checkpoints and logs
|
| 123 |
+
- `--exp-name`: Any string name (the folder will be created under `output-dir`)
|
| 124 |
+
- `--cls`: Weight coefficients of REG loss
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
### 4. Generate images and evaluation
|
| 128 |
+
You can generate images and get the final results through the following script.
|
| 129 |
+
The weight of REG can be found in this [link](https://pan.baidu.com/s/1QX2p3ybh1KfNU7wsp5McWw?pwd=khpp) or [HF](https://huggingface.co/Martinser/REG/tree/main).
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
bash eval.sh
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
## Citation
|
| 137 |
+
If you find our work, this repository, or pretrained models useful, please consider giving a star and citation.
|
| 138 |
+
```
|
| 139 |
+
@article{wu2025representation,
|
| 140 |
+
title={Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think},
|
| 141 |
+
author={Wu, Ge and Zhang, Shen and Shi, Ruijing and Gao, Shanghua and Chen, Zhenyuan and Wang, Lei and Chen, Zhaowei and Gao, Hongcheng and Tang, Yao and Yang, Jian and others},
|
| 142 |
+
journal={arXiv preprint arXiv:2507.01467},
|
| 143 |
+
year={2025}
|
| 144 |
+
}
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## Contact
|
| 148 |
+
If you have any questions, please create an issue on this repository, or contact us at gewu.nku@gmail.com or via WeChat (wg1158848).
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
## Acknowledgements
|
| 152 |
+
|
| 153 |
+
Our code is based on [REPA](https://github.com/sihyun-yu/REPA), along with [SiT](https://github.com/willisma/SiT), [DINOv2](https://github.com/facebookresearch/dinov2), [ADM](https://github.com/openai/guided-diffusion) and [U-ViT](https://github.com/baofff/U-ViT) repositories. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well.
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
REG copy/dataset.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import PIL.Image
|
| 10 |
+
try:
|
| 11 |
+
import pyspng
|
| 12 |
+
except ImportError:
|
| 13 |
+
pyspng = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CustomDataset(Dataset):
    """Paired dataset of raw image arrays and precomputed VAE features.

    Expects two parallel directory trees under ``data_dir``:
    ``imagenet_256_vae`` (images / ``.npy`` arrays) and ``vae-sd``
    (VAE latents plus a ``dataset.json`` label index).
    """

    def __init__(self, data_dir):
        PIL.Image.init()
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}

        self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')
        self.features_dir = os.path.join(data_dir, 'vae-sd')

        # Walk both trees; keep only supported files, sorted so that the
        # i-th image and the i-th feature file correspond to each other.
        self._image_fnames = self._walk(self.images_dir)
        self.image_fnames = sorted(
            name for name in self._image_fnames
            if self._file_ext(name) in supported_ext
        )
        self._feature_fnames = self._walk(self.features_dir)
        self.feature_fnames = sorted(
            name for name in self._feature_fnames
            if self._file_ext(name) in supported_ext
        )

        # Labels live in a JSON index keyed by feature-relative path.
        fname = os.path.join(self.features_dir, 'dataset.json')
        if os.path.exists(fname):
            print(f"Using {fname}.")
        else:
            raise FileNotFoundError("Neither of the specified files exists.")

        with open(fname, 'rb') as f:
            label_index = dict(json.load(f)['labels'])
        # Normalize Windows separators so keys match the JSON index.
        ordered = [label_index[name.replace('\\', '/')] for name in self.feature_fnames]
        as_array = np.array(ordered)
        # 1-D -> integer class ids; 2-D -> dense float label vectors.
        self.labels = as_array.astype({1: np.int64, 2: np.float32}[as_array.ndim])

    @staticmethod
    def _walk(root):
        """Return the set of file paths under ``root``, relative to ``root``."""
        found = set()
        for base, _dirs, files in os.walk(root):
            for name in files:
                found.add(os.path.relpath(os.path.join(base, name), start=root))
        return found

    def _file_ext(self, fname):
        """Lower-cased extension of ``fname``, including the leading dot."""
        _, ext = os.path.splitext(fname)
        return ext.lower()

    def __len__(self):
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of feature files and label files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        """Return ``(image, features, label)`` tensors for sample ``idx``."""
        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if ext == '.npy':
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            elif ext == '.png' and pyspng is not None:
                # pyspng is an optional fast PNG decoder.
                image = pyspng.load(f.read())
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
            else:
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])
REG copy/eval.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Sampling + FID evaluation pipeline for a trained REG SiT-XL/2 checkpoint.
# Stage 1: multi-GPU image generation into .png files plus a packed .npz.
# Stage 2: ADM-style evaluator comparing the samples against the ImageNet
#          reference batch.

# Random master port in [1200, 1299] to avoid collisions with other jobs.
random_number=$((RANDOM % 100 + 1200))
NUM_GPUS=8
# Training step of the checkpoint to evaluate.
STEP="4000000"
SAVE_PATH="your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8"
VAE_PATH="your_vae_path/"
# Number of SDE sampler steps.
NUM_STEP=250
MODEL_SIZE='XL'
# Classifier-free guidance scales for image and class-token branches.
CFG_SCALE=2.3
CLS_CFG_SCALE=2.3
# Guidance-high interval bound passed to the sampler.
GH=0.85

# Disable NCCL peer-to-peer transport (works around hangs on some hosts).
export NCCL_P2P_DISABLE=1

# Stage 1: generate 50k samples with the EMA weights and pack them to .npz.
python -m torch.distributed.launch --master_port=$random_number --nproc_per_node=$NUM_GPUS generate.py \
    --model SiT-XL/2 \
    --num-fid-samples 50000 \
    --ckpt ${SAVE_PATH}/checkpoints/${STEP}.pt \
    --path-type=linear \
    --encoder-depth=8 \
    --projector-embed-dims=768 \
    --per-proc-batch-size=64 \
    --mode=sde \
    --num-steps=${NUM_STEP} \
    --cfg-scale=${CFG_SCALE} \
    --cls-cfg-scale=${CLS_CFG_SCALE} \
    --guidance-high=${GH} \
    --sample-dir ${SAVE_PATH}/checkpoints \
    --cls=768


# Stage 2: compute FID & co. The sample_batch name must match the folder
# naming scheme produced by generate.py.
python ./evaluations/evaluator.py \
    --ref_batch your_path/VIRTUAL_imagenet256_labeled.npz \
    --sample_batch ${SAVE_PATH}/checkpoints/SiT-${MODEL_SIZE}-2-${STEP}-size-256-vae-ema-cfg-${CFG_SCALE}-seed-0-sde-${GH}-${CLS_CFG_SCALE}.npz \
    --save_path ${SAVE_PATH}/checkpoints \
    --cfg_cond 1 \
    --step ${STEP} \
    --num_steps ${NUM_STEP} \
    --cfg ${CFG_SCALE} \
    --cls_cfg ${CLS_CFG_SCALE} \
    --gh ${GH}
REG copy/generate.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Samples a large number of images from a pre-trained SiT model using DDP.
|
| 9 |
+
Subsequently saves a .npz file that can be used to compute FID and other
|
| 10 |
+
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
|
| 11 |
+
|
| 12 |
+
For a simple single-GPU/CPU sampling script, see sample.py.
|
| 13 |
+
"""
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from models.sit import SiT_models
|
| 17 |
+
from diffusers.models import AutoencoderKL
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import os
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import math
|
| 23 |
+
import argparse
|
| 24 |
+
import socket
|
| 25 |
+
from samplers import euler_maruyama_sampler
|
| 26 |
+
from utils import load_legacy_checkpoints, download_model
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def setup_distributed():
    """Initialize the NCCL process group.

    When the script is launched via torchrun, the rendezvous environment
    variables already exist and initialization proceeds directly. When run
    as a plain ``python`` process (no RANK in the environment), synthesize
    a single-process world locally before initializing.
    """
    env = os.environ
    if "RANK" in env:
        dist.init_process_group("nccl")
        return

    env.setdefault("MASTER_ADDR", "127.0.0.1")
    if "MASTER_PORT" not in env:
        # Ask the OS for a free TCP port, then release it for the rendezvous.
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        probe.bind(("", 0))
        env["MASTER_PORT"] = str(probe.getsockname()[1])
        probe.close()
    env["RANK"] = "0"
    env["WORLD_SIZE"] = "1"
    env["LOCAL_RANK"] = "0"
    dist.init_process_group("nccl")
| 43 |
+
|
| 44 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """Pack ``num`` sequentially numbered .png samples into one .npz archive.

    Reads ``{sample_dir}/000000.png`` .. ``{sample_dir}/{num-1:06d}.png``,
    stacks them as uint8 RGB, and writes ``{sample_dir}.npz`` under the
    ``arr_0`` key (the format the ADM evaluator expects).

    Returns the path of the written archive.
    """
    stacked = np.stack([
        np.asarray(Image.open(f"{sample_dir}/{i:06d}.png")).astype(np.uint8)
        for i in tqdm(range(num), desc="Building .npz file from samples")
    ])
    # Sanity check: N images, 3 channels each.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main(args):
    """
    Run sampling.

    Distributed (one process per visible GPU) sampling loop: builds the SiT
    model, loads EMA weights from a checkpoint (or auto-downloads the
    pretrained SiT-XL/2 when --ckpt is None), draws image and class-token
    latents, runs the chosen sampler, decodes through the SD VAE, writes one
    .png per sample, and finally (rank 0) packs everything into one .npz.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP (works with plain `python` or torchrun)
    setup_distributed()
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Per-rank seed so each process draws distinct noise streams.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Load model:
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    # Latent grid is 1/8 of the pixel resolution (SD VAE downsampling factor).
    latent_size = args.resolution // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = True,
        z_dims = [int(z_dim) for z_dim in args.projector_embed_dims.split(',')],
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt

    # --- checkpoint resolution -------------------------------------------
    if ckpt_path is None:
        # No checkpoint given: fall back to the released pretrained SiT-XL/2.
        args.ckpt = 'SiT-XL-2-256x256.pt'
        assert args.model == 'SiT-XL/2'
        assert len(args.projector_embed_dims.split(',')) == 1
        assert int(args.projector_embed_dims.split(',')[0]) == 768
        state_dict = download_model('last.pt')
    else:
        # Train.py checkpoints store several state dicts; sample from 'ema'.
        state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']

    if args.legacy:
        # Remap key names from older checkpoint layouts.
        state_dict = load_legacy_checkpoints(
            state_dict=state_dict, encoder_depth=args.encoder_depth
        )
    model.load_state_dict(state_dict)
    # ---------------------------------------------------------------------

    model.eval()  # important!
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    #vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path="your_local_path/weight/").to(device)

    # Create folder to save samples:
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    # NOTE: eval.sh reconstructs this exact folder/archive name — keep in sync.
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}-{args.guidance_high}-{args.cls_cfg_scale}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Sample inputs:
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        # REG: one extra class-token latent per image (--cls is its dim).
        cls_z = torch.randn(n, args.cls, device=device)

        # Sample images:
        sampling_kwargs = dict(
            model=model,
            latents=z,
            y=y,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
            cls_latents=cls_z,
            args=args
        )
        with torch.no_grad():
            if args.mode == "sde":
                samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
            elif args.mode == "ode":# will support; currently terminates the run
                exit()
                #samples = euler_sampler(**sampling_kwargs).to(torch.float32)
            else:
                raise NotImplementedError()

        # 0.18215 — presumably the SD VAE latent scale used at training time;
        # TODO(review) confirm it matches the value used in train.py.
        latents_scale = torch.tensor(
            [0.18215, 0.18215, 0.18215, 0.18215, ]
        ).view(1, 4, 1, 1).to(device)
        latents_bias = -torch.tensor(
            [0., 0., 0., 0.,]
        ).view(1, 4, 1, 1).to(device)
        # Undo latent normalization, decode to pixels, map [-1, 1] -> [0, 255].
        samples = vae.decode((samples - latents_bias) / latents_scale).sample
        samples = (samples + 1) / 2.
        samples = torch.clamp(
            255. * samples, 0, 255
        ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            # Interleaved global index so ranks never collide on file names.
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # seed
    parser.add_argument("--global-seed", type=int, default=0)

    # precision
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")

    # logging/saving:
    # Fix: default was a machine-specific absolute path that does not exist
    # on other hosts. A default of None triggers the documented auto-download
    # path in main() (pretrained SiT-XL/2).
    parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sample-dir", type=str, default="samples_50")

    # model
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    # vae
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")

    # number of samples
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=100)

    # sampling related hyperparameters
    # NOTE(review): main() only implements "sde"; passing the default "ode"
    # currently exits without sampling. eval.sh passes --mode=sde explicitly.
    parser.add_argument("--mode", type=str, default="ode")
    parser.add_argument("--cfg-scale", type=float, default=1)
    parser.add_argument("--cls-cfg-scale", type=float, default=1)
    parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode
    parser.add_argument("--guidance-low", type=float, default=0.)
    parser.add_argument("--guidance-high", type=float, default=1.)
    parser.add_argument('--local-rank', default=-1, type=int)
    # Dimensionality of the class-token latent sampled alongside image latents.
    parser.add_argument('--cls', default=768, type=int)
    # will be deprecated
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # only for ode

    args = parser.parse_args()
    main(args)
|
REG copy/loss.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
def mean_flat(x):
    """Average a tensor over every dimension except the leading batch dimension."""
    non_batch_dims = list(range(1, x.dim()))
    return torch.mean(x, dim=non_batch_dims)
|
| 10 |
+
|
| 11 |
+
def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.

    (Docstring previously said "mean" — the function sums; fixed.)
    """
    return torch.sum(x, dim=list(range(1, len(x.size()))))
|
| 16 |
+
|
| 17 |
+
class SILoss:
    """Stochastic-interpolant (SiT-style) training loss with an additional
    class-token denoising branch (REG) and a REPA-style projection loss.

    Returns per-sample denoising losses for both the image latents and the
    class token, plus a negative-cosine alignment loss against external
    encoder features.
    """
    def __init__(
        self,
        prediction='v',           # only velocity ('v') prediction is implemented
        path_type="linear",       # interpolant schedule: "linear" or "cosine"
        weighting="uniform",      # timestep sampling: "uniform" or "lognormal"
        encoders=[],              # NOTE(review): mutable default argument — shared across instances if mutated
        accelerator=None,
        latents_scale=None,
        latents_bias=None,
    ):
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        self.encoders = encoders
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias

    def interpolant(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for the interpolant
        x_t = alpha_t * x0 + sigma_t * eps and its time derivative.

        For the linear path the derivatives are Python scalars; for the cosine
        path they are tensors shaped like ``t``.
        """
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()

        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
                 time_input=None, noises=None,):
        """Compute the REG training losses for one batch.

        Args:
            model: denoising network; called as
                ``model(x_t, t, **model_kwargs, cls_token=...)`` and expected to
                return ``(velocity, projected_features, cls_velocity)``.
            images: clean image latents, shape [B, C, H, W].
            zs: list of target encoder feature tensors (one per encoder).
            cls_token: clean class-token targets from the foundation encoder.
            time_input / noises: optional pre-sampled timesteps / noise
                (allows replaying the same (t, eps) across calls).
        Returns:
            (denoising_loss, proj_loss, time_input, noises, denoising_loss_cls)
        """
        if model_kwargs == None:
            model_kwargs = {}
        # sample timesteps (broadcastable [B, 1, 1, 1] shape)
        if time_input is None:
            if self.weighting == "uniform":
                time_input = torch.rand((images.shape[0], 1, 1, 1))
            elif self.weighting == "lognormal":
                # sample timestep according to log-normal distribution of sigmas following EDM
                rnd_normal = torch.randn((images.shape[0], 1 ,1, 1))
                sigma = rnd_normal.exp()
                if self.path_type == "linear":
                    time_input = sigma / (1 + sigma)
                elif self.path_type == "cosine":
                    time_input = 2 / np.pi * torch.atan(sigma)

        time_input = time_input.to(device=images.device, dtype=images.dtype)

        if noises is None:
            noises = torch.randn_like(images)
        # independent noise for the class-token branch
        noises_cls = torch.randn_like(cls_token)

        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)

        # noisy inputs for both branches; squeeze the trailing spatial singleton
        # dims so the coefficients broadcast against the 2-D cls_token
        model_input = alpha_t * images + sigma_t * noises
        cls_input = alpha_t.squeeze(-1).squeeze(-1) * cls_token + sigma_t.squeeze(-1).squeeze(-1) * noises_cls
        if self.prediction == 'v':
            # velocity target: d/dt of the interpolant
            model_target = d_alpha_t * images + d_sigma_t * noises
            # NOTE(review): under the cosine path d_alpha_t is [B,1,1,1] while
            # cls_token is 2-D — confirm broadcasting; the linear path uses
            # scalar derivatives and is safe.
            cls_target = d_alpha_t * cls_token + d_sigma_t * noises_cls
        else:
            raise NotImplementedError()

        model_output, zs_tilde, cls_output = model(model_input, time_input.flatten(), **model_kwargs,
                                                   cls_token=cls_input)

        # denoising losses (per-sample MSE) for latents and class token
        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss_cls = mean_flat((cls_output - cls_target) ** 2)

        # projection loss: negative cosine similarity between projected hidden
        # states and the external encoder features, averaged per encoder/sample
        proj_loss = 0.
        bsz = zs[0].shape[0]
        for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
            for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
                z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
                z_j = torch.nn.functional.normalize(z_j, dim=-1)
                proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        proj_loss /= (len(zs) * bsz)

        return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls
|
REG copy/requirements.txt
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE: this list was exported from a conda environment ("- pip:" section);
# install with: pip install -r requirements.txt
|
| 2 |
+
absl-py==2.2.2
|
| 3 |
+
accelerate==1.2.1
|
| 4 |
+
aiohappyeyeballs==2.6.1
|
| 5 |
+
aiohttp==3.11.16
|
| 6 |
+
aiosignal==1.3.2
|
| 7 |
+
astunparse==1.6.3
|
| 8 |
+
async-timeout==5.0.1
|
| 9 |
+
attrs==25.3.0
|
| 10 |
+
certifi==2022.12.7
|
| 11 |
+
charset-normalizer==2.1.1
|
| 12 |
+
click==8.1.8
|
| 13 |
+
datasets==2.20.0
|
| 14 |
+
diffusers==0.32.1
|
| 15 |
+
dill==0.3.8
|
| 16 |
+
docker-pycreds==0.4.0
|
| 17 |
+
einops==0.8.1
|
| 18 |
+
filelock==3.13.1
|
| 19 |
+
flatbuffers==25.2.10
|
| 20 |
+
frozenlist==1.5.0
|
| 21 |
+
fsspec==2024.5.0
|
| 22 |
+
ftfy==6.3.1
|
| 23 |
+
gast==0.6.0
|
| 24 |
+
gitdb==4.0.12
|
| 25 |
+
gitpython==3.1.44
|
| 26 |
+
google-pasta==0.2.0
|
| 27 |
+
grpcio==1.71.0
|
| 28 |
+
h5py==3.13.0
|
| 29 |
+
huggingface-hub==0.27.1
|
| 30 |
+
idna==3.4
|
| 31 |
+
importlib-metadata==8.6.1
|
| 32 |
+
jinja2==3.1.4
|
| 33 |
+
joblib==1.4.2
|
| 34 |
+
keras==3.9.2
|
| 35 |
+
libclang==18.1.1
|
| 36 |
+
markdown==3.8
|
| 37 |
+
markdown-it-py==3.0.0
|
| 38 |
+
markupsafe==2.1.5
|
| 39 |
+
mdurl==0.1.2
|
| 40 |
+
ml-dtypes==0.3.2
|
| 41 |
+
mpmath==1.3.0
|
| 42 |
+
multidict==6.4.3
|
| 43 |
+
multiprocess==0.70.16
|
| 44 |
+
namex==0.0.8
|
| 45 |
+
networkx==3.3
|
| 46 |
+
numpy==1.26.4
|
| 47 |
+
opt-einsum==3.4.0
|
| 48 |
+
optree==0.15.0
|
| 49 |
+
packaging==24.2
|
| 50 |
+
pandas==2.2.3
|
| 51 |
+
pillow==11.0.0
|
| 52 |
+
platformdirs==4.3.7
|
| 53 |
+
propcache==0.3.1
|
| 54 |
+
protobuf==4.25.6
|
| 55 |
+
psutil==7.0.0
|
| 56 |
+
pyarrow==19.0.1
|
| 57 |
+
pyarrow-hotfix==0.6
|
| 58 |
+
pygments==2.19.1
|
| 59 |
+
python-dateutil==2.9.0.post0
|
| 60 |
+
pytz==2025.2
|
| 61 |
+
pyyaml==6.0.2
|
| 62 |
+
regex==2024.11.6
|
| 63 |
+
requests==2.32.3
|
| 64 |
+
rich==14.0.0
|
| 65 |
+
safetensors==0.5.3
|
| 66 |
+
scikit-learn==1.5.1
|
| 67 |
+
scipy==1.15.2
|
| 68 |
+
sentry-sdk==2.26.1
|
| 69 |
+
setproctitle==1.3.5
|
| 70 |
+
six==1.17.0
|
| 71 |
+
smmap==5.0.2
|
| 72 |
+
sympy==1.13.1
|
| 73 |
+
tensorboard==2.16.1
|
| 74 |
+
tensorboard-data-server==0.7.2
|
| 75 |
+
tensorflow==2.16.1
|
| 76 |
+
tensorflow-io-gcs-filesystem==0.37.1
|
| 77 |
+
termcolor==3.0.1
|
| 78 |
+
tf-keras==2.16.0
|
| 79 |
+
threadpoolctl==3.6.0
|
| 80 |
+
timm==1.0.12
|
| 81 |
+
tokenizers==0.21.0
|
| 82 |
+
tqdm==4.67.1
|
| 83 |
+
transformers==4.47.0
|
| 84 |
+
triton==2.1.0
|
| 85 |
+
typing-extensions==4.12.2
|
| 86 |
+
tzdata==2025.2
|
| 87 |
+
urllib3==1.26.13
|
| 88 |
+
wandb==0.17.6
|
| 89 |
+
wcwidth==0.2.13
|
| 90 |
+
werkzeug==3.1.3
|
| 91 |
+
wrapt==1.17.2
|
| 92 |
+
xformer==1.0.1
|
| 93 |
+
xformers==0.0.23
|
| 94 |
+
xxhash==3.5.0
|
| 95 |
+
yarl==1.20.0
|
| 96 |
+
zipp==3.21.0
|
| 97 |
+
|
REG copy/samplers.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def expand_t_like_x(t, x_cur):
    """Reshape a [batch]-shaped time vector so it broadcasts against x_cur.

    Args:
        t: [batch_dim,] time vector.
        x_cur: [batch_dim, ...] data point supplying the target rank.
    """
    trailing = x_cur.dim() - 1
    return t.view(-1, *([1] * trailing))
|
| 14 |
+
|
| 15 |
+
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Wrapper function: transform a velocity prediction into a score.

    Uses the interpolant identity relating velocity and score:
    score(x_t) = ((alpha_t / d_alpha_t) * v_t - x_t) / var, with
    var = sigma_t^2 - (alpha_t / d_alpha_t) * d_sigma_t * sigma_t.

    Args:
        vt: [batch_dim, ...] shaped tensor; velocity model output
        xt: [batch_dim, ...] shaped tensor; x_t data point
        t: [batch_dim,] time tensor
        path_type: interpolant schedule, "linear" or "cosine"
    """
    t = expand_t_like_x(t, xt)
    if path_type == "linear":
        # derivatives of alpha_t = 1-t, sigma_t = t (constant +/-1, kept as
        # full tensors so the arithmetic below broadcasts uniformly)
        alpha_t, d_alpha_t = 1 - t, torch.ones_like(xt, device=xt.device) * -1
        sigma_t, d_sigma_t = t, torch.ones_like(xt, device=xt.device)
    elif path_type == "cosine":
        alpha_t = torch.cos(t * np.pi / 2)
        sigma_t = torch.sin(t * np.pi / 2)
        d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
        d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
    else:
        raise NotImplementedError

    mean = xt
    reverse_alpha_ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
    score = (reverse_alpha_ratio * vt - mean) / var

    return score
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def compute_diffusion(t_cur):
    """Diffusion coefficient used by the Euler–Maruyama SDE sampler: g(t) = 2t."""
    return t_cur * 2
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def euler_maruyama_sampler(
        model,
        latents,
        y,
        num_steps=20,
        heun=False,  # not used, just for compatibility with the ODE sampler API
        cfg_scale=1.0,
        guidance_low=0.0,
        guidance_high=1.0,
        path_type="linear",
        cls_latents=None,
        args=None,
        ):
    """Euler–Maruyama SDE sampler that jointly denoises image latents and the
    REG class token.

    Integrates from t=1 down to t=0.04 with stochastic steps, then takes one
    final deterministic (mean) step to t=0. Classifier-free guidance is applied
    only while guidance_low <= t <= guidance_high; the class-token branch has
    its own guidance scale ``args.cls_cfg_scale``.

    Args:
        model: denoising network returning (velocity, _, cls_velocity).
        latents: initial noise for image latents.
        y: class labels; 1000 is used as the null (unconditional) label.
        cls_latents: initial noise for the class token (required).
        args: namespace carrying ``cls_cfg_scale`` (read only when cfg_scale > 1).
    Returns:
        The final image latents (the denoised class token is computed but
        intentionally not returned).
    """
    # setup conditioning: null label batch is only needed when CFG is active
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype

    # integrate in float64 for numerical stability; the model itself is called
    # in its native dtype
    t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64)
    t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)])
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device

    with torch.no_grad():
        # stochastic steps over all intervals except the last
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # double the batch (cond + uncond) only inside the guidance window
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_x_cur] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                cls_model_input = cls_x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Brownian increments for both branches (dt < 0, hence abs())
            eps_i = torch.randn_like(x_cur).to(device)
            cls_eps_i = torch.randn_like(cls_x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))
            cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))

            # compute drift from the predicted velocity and implied score
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            # classifier-free guidance: mix cond/uncond drifts
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if args.cls_cfg_scale > 0:
                    cls_d_cur = cls_d_cur_uncond + args.cls_cfg_scale * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond
            # Euler–Maruyama update: drift * dt + diffusion * dW
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

        # last step: deterministic (mean) update from t=0.04 to t=0
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_x_cur] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            cls_model_input = cls_x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        # compute drift
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
        cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)

        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if args.cls_cfg_scale > 0:
                cls_d_cur = cls_d_cur_uncond + args.cls_cfg_scale * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        # no noise on the final step: return the mean of the transition
        mean_x = x_cur + dt * d_cur
        # NOTE(review): cls_mean_x is computed but never returned — confirm
        # whether callers ever need the final class token.
        cls_mean_x = cls_x_cur + dt * cls_d_cur

    return mean_x
|
REG copy/utils.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from torchvision.datasets.utils import download_url
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision.models as torchvision_models
|
| 5 |
+
import timm
|
| 6 |
+
from models import mocov3_vit
|
| 7 |
+
import math
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# code from SiT repository
|
| 12 |
+
pretrained_models = {'last.pt'}
|
| 13 |
+
|
| 14 |
+
def download_model(model_name):
    """
    Downloads a pre-trained SiT model from the web and loads it.

    Args:
        model_name: must be one of the keys in ``pretrained_models``
            (currently only 'last.pt').
    Returns:
        The object deserialized from the checkpoint file.
    """
    assert model_name in pretrained_models
    local_path = f'pretrained_models/{model_name}'
    # download once; subsequent calls reuse the cached file
    if not os.path.isfile(local_path):
        os.makedirs('pretrained_models', exist_ok=True)
        # NOTE(review): the URL is hard-coded to the 'last.pt' Dropbox link
        # regardless of model_name — fine while the set has one entry only.
        web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
        download_url(web_path, 'pretrained_models', filename=model_name)
    # map_location keeps tensors on CPU so loading works without a GPU
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model
|
| 26 |
+
|
| 27 |
+
def fix_mocov3_state_dict(state_dict):
    """Normalize a MoCo-v3 checkpoint state dict in place.

    Strips the 'module.base_encoder.' DDP prefix, repairs a handful of
    misnamed layer keys present in the released checkpoints, drops head/fc
    weights, and resamples the positional embedding to a 16x16 grid.
    All original keys are deleted; only the renamed backbone keys survive.
    """
    for k in list(state_dict.keys()):
        # retain only base_encoder up to before the embedding layer
        if k.startswith('module.base_encoder'):
            # fix naming bug in checkpoint
            new_k = k[len("module.base_encoder."):]
            # NOTE(review): the first check tests new_k while the next three
            # test k — equivalent here since the stripped prefix cannot
            # contain these substrings, but worth unifying.
            if "blocks.13.norm13" in new_k:
                new_k = new_k.replace("norm13", "norm1")
            if "blocks.13.mlp.fc13" in k:
                new_k = new_k.replace("fc13", "fc1")
            if "blocks.14.norm14" in k:
                new_k = new_k.replace("norm14", "norm2")
            if "blocks.14.mlp.fc14" in k:
                new_k = new_k.replace("fc14", "fc2")
            # keep only backbone weights (skip projection head / fc)
            if 'head' not in new_k and new_k.split('.')[0] != 'fc':
                state_dict[new_k] = state_dict[k]
        # delete renamed or unused original key
        del state_dict[k]
    if 'pos_embed' in state_dict.keys():
        # resample absolute position embeddings to the 16x16 patch grid
        state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
            state_dict['pos_embed'], [16, 16],
        )
    return state_dict
|
| 51 |
+
|
| 52 |
+
@torch.no_grad()
def load_encoders(enc_type, device, resolution=256):
    """Load one or more frozen vision encoders used as alignment targets.

    Args:
        enc_type: comma-separated specs of the form
            '<type>-<architecture>-<config>', e.g. 'dinov2-vit-b'.
        device: device to move each encoder to.
        resolution: training resolution; 512 is only supported for DINOv2.
    Returns:
        (encoders, encoder_types, architectures) — parallel lists.
    """
    assert (resolution == 256) or (resolution == 512)

    enc_names = enc_type.split(',')
    encoders, architectures, encoder_types = [], [], []
    for enc_name in enc_names:
        encoder_type, architecture, model_config = enc_name.split('-')
        # Currently, we only support 512x512 experiments with DINOv2 encoders.
        if resolution == 512:
            if encoder_type != 'dinov2':
                raise NotImplementedError(
                    "Currently, we only support 512x512 experiments with DINOv2 encoders."
                )

        architectures.append(architecture)
        encoder_types.append(encoder_type)
        if encoder_type == 'mocov3':
            if architecture == 'vit':
                if model_config == 's':
                    encoder = mocov3_vit.vit_small()
                elif model_config == 'b':
                    encoder = mocov3_vit.vit_base()
                elif model_config == 'l':
                    encoder = mocov3_vit.vit_large()
                ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
                state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
                # drop the projection head before loading, then replace it
                # with Identity so forward passes return backbone features
                del encoder.head
                encoder.load_state_dict(state_dict, strict=True)
                encoder.head = torch.nn.Identity()
            elif architecture == 'resnet':
                raise NotImplementedError()

            encoder = encoder.to(device)
            encoder.eval()

        elif 'dinov2' in encoder_type:
            import timm
            # prefer a local torch.hub cache; fall back to downloading
            # NOTE(review): bare except swallows all errors from the local
            # load (including real bugs) — consider narrowing.
            if 'reg' in encoder_type:
                try:
                    encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
                                             f'dinov2_vit{model_config}14_reg', source='local')
                except:
                    encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
            else:
                try:
                    encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
                                             f'dinov2_vit{model_config}14', source='local')
                except:
                    encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')

            print(f"Now you are using the {enc_name} as the aligning model")
            del encoder.head
            # 16 patches per 256px side (patch size 14 resampled to 16x16 grid)
            patch_resolution = 16 * (resolution // 256)
            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [patch_resolution, patch_resolution],
            )
            encoder.head = torch.nn.Identity()
            encoder = encoder.to(device)
            encoder.eval()

        elif 'dinov1' == encoder_type:
            import timm
            from models import dinov1
            encoder = dinov1.vit_base()
            ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
            if 'pos_embed' in ckpt.keys():
                ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
                    ckpt['pos_embed'], [16, 16],
                )
            del encoder.head
            encoder.head = torch.nn.Identity()
            encoder.load_state_dict(ckpt, strict=True)
            encoder = encoder.to(device)
            # expose a forward_features alias so all encoders share one API
            encoder.forward_features = encoder.forward
            encoder.eval()

        elif encoder_type == 'clip':
            import clip
            from models.clip_vit import UpdatedVisionTransformer
            encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
            encoder = UpdatedVisionTransformer(encoder_).to(device)
            encoder.embed_dim = encoder.model.transformer.width
            encoder.forward_features = encoder.forward
            encoder.eval()

        elif encoder_type == 'mae':
            from models.mae_vit import vit_large_patch16
            import timm
            kwargs = dict(img_size=256)
            encoder = vit_large_patch16(**kwargs).to(device)
            with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
                state_dict = torch.load(f)
            if 'pos_embed' in state_dict["model"].keys():
                state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
                    state_dict["model"]['pos_embed'], [16, 16],
                )
            encoder.load_state_dict(state_dict["model"])

            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [16, 16],
            )
            # NOTE(review): unlike the other branches, mae/jepa do not call
            # encoder.eval() here — confirm whether that is intentional.

        elif encoder_type == 'jepa':
            from models.jepa import vit_huge
            kwargs = dict(img_size=[224, 224], patch_size=14)
            encoder = vit_huge(**kwargs).to(device)
            with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
                state_dict = torch.load(f, map_location=device)
            new_state_dict = dict()
            # strip the 7-character 'module.' DDP prefix from each key
            for key, value in state_dict['encoder'].items():
                new_state_dict[key[7:]] = value
            encoder.load_state_dict(new_state_dict)
            encoder.forward_features = encoder.forward

        encoders.append(encoder)

    return encoders, encoder_types, architectures
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 174 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 175 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 176 |
+
def norm_cdf(x):
|
| 177 |
+
# Computes standard normal cumulative distribution function
|
| 178 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 179 |
+
|
| 180 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 181 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 182 |
+
"The distribution of values may be incorrect.",
|
| 183 |
+
stacklevel=2)
|
| 184 |
+
|
| 185 |
+
with torch.no_grad():
|
| 186 |
+
# Values are generated by using a truncated uniform distribution and
|
| 187 |
+
# then using the inverse CDF for the normal distribution.
|
| 188 |
+
# Get upper and lower cdf values
|
| 189 |
+
l = norm_cdf((a - mean) / std)
|
| 190 |
+
u = norm_cdf((b - mean) / std)
|
| 191 |
+
|
| 192 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 193 |
+
# [2l-1, 2u-1].
|
| 194 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 195 |
+
|
| 196 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 197 |
+
# standard normal
|
| 198 |
+
tensor.erfinv_()
|
| 199 |
+
|
| 200 |
+
# Transform to proper mean, std
|
| 201 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 202 |
+
tensor.add_(mean)
|
| 203 |
+
|
| 204 |
+
# Clamp to ensure it's in the proper range
|
| 205 |
+
tensor.clamp_(min=a, max=b)
|
| 206 |
+
return tensor
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Public wrapper: truncated-normal init with standard-normal defaults on [-2, 2]."""
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def load_legacy_checkpoints(state_dict, encoder_depth):
    """Remap legacy 'decoder_blocks.i.*' keys to 'blocks.(i + encoder_depth).*'.

    Older checkpoints stored encoder and decoder blocks separately; newer
    models use one unified 'blocks' list. Keys without 'decoder_blocks' are
    copied through untouched. Returns a new dict; the input is not modified.
    """
    remapped = dict()
    for key, value in state_dict.items():
        if 'decoder_blocks' not in key:
            remapped[key] = value
            continue
        pieces = key.split('.')
        # shift the decoder index past the encoder and merge namespaces
        pieces[1] = str(int(pieces[1]) + encoder_depth)
        pieces[0] = 'blocks'
        remapped['.'.join(pieces)] = value
    return remapped
|
REG/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Sihyun Yu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
REG/README.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<h1 align="center">Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think (NeurIPS 2025 Oral)
|
| 3 |
+
</h1>
|
| 4 |
+
<p align="center">
|
| 5 |
+
<a href='https://github.com/Martinser' style='text-decoration: none' >Ge Wu</a><sup>1</sup> 
|
| 6 |
+
<a href='https://github.com/ShenZhang-Shin' style='text-decoration: none' >Shen Zhang</a><sup>3</sup> 
|
| 7 |
+
<a href='' style='text-decoration: none' >Ruijing Shi</a><sup>1</sup> 
|
| 8 |
+
<a href='https://shgao.site/' style='text-decoration: none' >Shanghua Gao</a><sup>4</sup> 
|
| 9 |
+
<a href='https://zhenyuanchenai.github.io/' style='text-decoration: none' >Zhenyuan Chen</a><sup>1</sup> 
|
| 10 |
+
<a href='https://scholar.google.com/citations?user=6Z66DAwAAAAJ&hl=en' style='text-decoration: none' >Lei Wang</a><sup>1</sup> 
|
| 11 |
+
<a href='https://www.zhihu.com/people/chen-zhao-wei-16-2' style='text-decoration: none' >Zhaowei Chen</a><sup>3</sup> 
|
| 12 |
+
<a href='https://gao-hongcheng.github.io/' style='text-decoration: none' >Hongcheng Gao</a><sup>5</sup> 
|
| 13 |
+
<a href='https://scholar.google.com/citations?view_op=list_works&hl=zh-CN&hl=zh-CN&user=0xP6bxcAAAAJ' style='text-decoration: none' >Yao Tang</a><sup>3</sup> 
|
| 14 |
+
<a href='https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en' style='text-decoration: none' >Jian Yang</a><sup>1</sup> 
|
| 15 |
+
<a href='https://mmcheng.net/cmm/' style='text-decoration: none' >Ming-Ming Cheng</a><sup>1,2</sup> 
|
| 16 |
+
<a href='https://implus.github.io/' style='text-decoration: none' >Xiang Li</a><sup>1,2*</sup> 
|
| 17 |
+
<p align="center">
|
| 18 |
+
$^{1}$ VCIP, CS, Nankai University, $^{2}$ NKIARI, Shenzhen Futian, $^{3}$ JIIOV Technology,
|
| 19 |
+
$^{4}$ Harvard University, $^{5}$ University of Chinese Academy of Sciences
|
| 20 |
+
<p align='center'>
|
| 21 |
+
<div align="center">
|
| 22 |
+
<a href='https://arxiv.org/abs/2507.01467v2'><img src='https://img.shields.io/badge/arXiv-2507.01467v2-brown.svg?logo=arxiv&logoColor=white'></a>
|
| 23 |
+
<a href='https://huggingface.co/Martinser/REG/tree/main'><img src='https://img.shields.io/badge/🤗-Model-blue.svg'></a>
|
| 24 |
+
<a href='https://zhuanlan.zhihu.com/p/1952346823168595518'><img src='https://img.shields.io/badge/Zhihu-chinese_article-blue.svg?logo=zhihu&logoColor=white'></a>
|
| 25 |
+
</div>
|
| 26 |
+
<p align='center'>
|
| 27 |
+
</p>
|
| 28 |
+
</p>
|
| 29 |
+
</p>
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## 🚩 Overview
|
| 33 |
+
|
| 34 |
+

|
| 35 |
+
|
| 36 |
+
REPA and its variants effectively mitigate training challenges in diffusion models by incorporating external visual representations from pretrained models, through alignment between the noisy hidden projections of denoising networks and foundational clean image representations.
|
| 37 |
+
We argue that the external alignment, which is absent during the entire denoising inference process, falls short of fully harnessing the potential of discriminative representations.
|
| 38 |
+
|
| 39 |
+
In this work, we propose a straightforward method called Representation Entanglement for Generation (REG), which entangles low-level image latents with a single high-level class token from pretrained foundation models for denoising.
|
| 40 |
+
REG acquires the capability to produce coherent image-class pairs directly from pure noise,
|
| 41 |
+
substantially improving both generation quality and training efficiency.
|
| 42 |
+
This is accomplished with negligible additional inference overhead, **requiring only one single additional token for denoising (<0.5\% increase in FLOPs and latency).**
|
| 43 |
+
The inference process concurrently reconstructs both image latents and their corresponding global semantics, where the acquired semantic knowledge actively guides and enhances the image generation process.
|
| 44 |
+
|
| 45 |
+
On ImageNet $256{\times}256$, SiT-XL/2 + REG demonstrates remarkable convergence acceleration, **achieving $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA, respectively.**
|
| 46 |
+
More impressively, SiT-L/2 + REG trained for merely 400K iterations outperforms SiT-XL/2 + REPA trained for 4M iterations ($\textbf{10}\times$ longer).
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## 📰 News
|
| 51 |
+
|
| 52 |
+
- **[2025.08.05]** We have released the pre-trained weights of REG + SiT-XL/2 in 4M (800 epochs).
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
## 📝 Results
|
| 56 |
+
|
| 57 |
+
- Performance on ImageNet $256{\times}256$ with FID=1.36 by introducing a single class token.
|
| 58 |
+
- $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA.
|
| 59 |
+
|
| 60 |
+
<div align="center">
|
| 61 |
+
<img src="fig/img.png" alt="Results">
|
| 62 |
+
</div>
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## 📋 Plan
|
| 66 |
+
- More training steps on ImageNet 256&512 and T2I.
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
## 👊 Usage
|
| 70 |
+
|
| 71 |
+
### 1. Environment setup
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
conda create -n reg python=3.10.16 -y
|
| 75 |
+
conda activate reg
|
| 76 |
+
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
|
| 77 |
+
pip install -r requirements.txt
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### 2. Dataset
|
| 81 |
+
|
| 82 |
+
#### Dataset download
|
| 83 |
+
|
| 84 |
+
Currently, we provide experiments for ImageNet. You can place the data that you want and can specify it via the `--data-dir` argument in the training scripts.
|
| 85 |
+
|
| 86 |
+
#### Preprocessing data
|
| 87 |
+
Please refer to the preprocessing guide. Alternatively, you can directly download our processed data: the ImageNet data [link](https://huggingface.co/WindATree/ImageNet-256-VAE/tree/main) and the ImageNet data after the VAE encoder [link](https://huggingface.co/WindATree/vae-sd/tree/main).
|
| 88 |
+
|
| 89 |
+
### 3. Training
|
| 90 |
+
Run train.sh
|
| 91 |
+
```bash
|
| 92 |
+
bash train.sh
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
train.sh contains the following content.
|
| 96 |
+
```bash
|
| 97 |
+
accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
|
| 98 |
+
--report-to="wandb" \
|
| 99 |
+
--allow-tf32 \
|
| 100 |
+
--mixed-precision="fp16" \
|
| 101 |
+
--seed=0 \
|
| 102 |
+
--path-type="linear" \
|
| 103 |
+
--prediction="v" \
|
| 104 |
+
--weighting="uniform" \
|
| 105 |
+
--model="SiT-B/2" \
|
| 106 |
+
--enc-type="dinov2-vit-b" \
|
| 107 |
+
--proj-coeff=0.5 \
|
| 108 |
+
--encoder-depth=4 \ #SiT-L/XL use 8, SiT-B use 4
|
| 109 |
+
--output-dir="your_path" \
|
| 110 |
+
--exp-name="linear-dinov2-b-enc4" \
|
| 111 |
+
--batch-size=256 \
|
| 112 |
+
--data-dir="data_path/imagenet_vae" \
|
| 113 |
+
--cls=0.03
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Then this script will automatically create the folder in `exps` to save logs and checkpoints. You can adjust the following options:
|
| 117 |
+
|
| 118 |
+
- `--models`: `[SiT-B/2, SiT-L/2, SiT-XL/2]`
|
| 119 |
+
- `--enc-type`: `[dinov2-vit-b, clip-vit-L]`
|
| 120 |
+
- `--proj-coeff`: Any values larger than 0
|
| 121 |
+
- `--encoder-depth`: Any values between 1 to the depth of the model
|
| 122 |
+
- `--output-dir`: Any directory that you want to save checkpoints and logs
|
| 123 |
+
- `--exp-name`: Any string name (the folder will be created under `output-dir`)
|
| 124 |
+
- `--cls`: Weight coefficients of REG loss
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
### 4. Generate images and evaluation
|
| 128 |
+
You can generate images and get the final results through the following script.
|
| 129 |
+
The weight of REG can be found in this [link](https://pan.baidu.com/s/1QX2p3ybh1KfNU7wsp5McWw?pwd=khpp) or [HF](https://huggingface.co/Martinser/REG/tree/main).
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
bash eval.sh
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
## Citation
|
| 137 |
+
If you find our work, this repository, or pretrained models useful, please consider giving a star and citation.
|
| 138 |
+
```
|
| 139 |
+
@article{wu2025representation,
|
| 140 |
+
title={Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think},
|
| 141 |
+
author={Wu, Ge and Zhang, Shen and Shi, Ruijing and Gao, Shanghua and Chen, Zhenyuan and Wang, Lei and Chen, Zhaowei and Gao, Hongcheng and Tang, Yao and Yang, Jian and others},
|
| 142 |
+
journal={arXiv preprint arXiv:2507.01467},
|
| 143 |
+
year={2025}
|
| 144 |
+
}
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## Contact
|
| 148 |
+
If you have any questions, please create an issue on this repository, contact us at gewu.nku@gmail.com, or reach out via WeChat (wg1158848).
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
## Acknowledgements
|
| 152 |
+
|
| 153 |
+
Our code is based on [REPA](https://github.com/sihyun-yu/REPA), along with [SiT](https://github.com/willisma/SiT), [DINOv2](https://github.com/facebookresearch/dinov2), [ADM](https://github.com/openai/guided-diffusion) and [U-ViT](https://github.com/baofff/U-ViT) repositories. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well.
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
REG/dataset.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import PIL.Image
|
| 10 |
+
try:
|
| 11 |
+
import pyspng
|
| 12 |
+
except ImportError:
|
| 13 |
+
pyspng = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CustomDataset(Dataset):
    """
    Paired VAE-latent / semantic-feature dataset.

    Layout under ``data_dir``:
      * VAE latents live in ``imagenet_256_vae/``.
      * Without preprocessed semantic features: VAE statistics / pairing files
        live in ``vae-sd/`` (same layout as the original REG repo).
      * With ``semantic_features_dir`` (explicit, or auto-detected under
        ``imagenet_256_features/dinov2-vit-b_tmp/gpu0``): samples are indexed
        by that directory's ``dataset.json``, and the matching VAE-latent npy
        in ``imagenet_256_vae`` is inferred from each feature file name.
    """

    def __init__(self, data_dir, semantic_features_dir=None):
        # Populate PIL.Image.EXTENSION so we can recognize image files; also
        # accept raw .npy latent files.
        PIL.Image.init()
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}

        self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')

        if semantic_features_dir is None:
            # Auto-detect preprocessed semantic features at the conventional
            # location produced by the feature-extraction script.
            potential_semantic_dir = os.path.join(
                data_dir, 'imagenet_256_features', 'dinov2-vit-b_tmp', 'gpu0'
            )
            if os.path.exists(potential_semantic_dir):
                self.semantic_features_dir = potential_semantic_dir
                self.use_preprocessed_semantic = True
                print(f"Found preprocessed semantic features at: {self.semantic_features_dir}")
            else:
                self.semantic_features_dir = None
                self.use_preprocessed_semantic = False
        else:
            self.semantic_features_dir = semantic_features_dir
            self.use_preprocessed_semantic = True
            print(f"Using preprocessed semantic features from: {self.semantic_features_dir}")

        if self.use_preprocessed_semantic:
            # Index samples via the semantic directory's dataset.json:
            # a list of (feature_filename, label) entries.
            label_fname = os.path.join(self.semantic_features_dir, 'dataset.json')
            if not os.path.exists(label_fname):
                raise FileNotFoundError(f"Label file not found: {label_fname}")

            print(f"Using {label_fname}.")
            with open(label_fname, 'rb') as f:
                data = json.load(f)
            labels_list = data.get('labels', None)
            if labels_list is None:
                raise ValueError(f"'labels' field is missing in {label_fname}")

            semantic_fnames = []
            labels = []
            for entry in labels_list:
                if entry is None:
                    # Skip null placeholder entries in the manifest.
                    continue
                fname, lab = entry
                semantic_fnames.append(fname)
                # Missing labels default to class 0.
                labels.append(0 if lab is None else lab)

            self.semantic_fnames = semantic_fnames
            self.labels = np.array(labels, dtype=np.int64)
            self.num_samples = len(self.semantic_fnames)
            print(f"Loaded {self.num_samples} semantic entries from dataset.json")
        else:
            # Legacy REG layout: features (and dataset.json) live in vae-sd/,
            # paired positionally with the sorted file lists below.
            self.features_dir = os.path.join(data_dir, 'vae-sd')

            self._image_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.images_dir)
                for root, _dirs, files in os.walk(self.images_dir) for fname in files
            }
            self.image_fnames = sorted(
                fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext
            )
            self._feature_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.features_dir)
                for root, _dirs, files in os.walk(self.features_dir) for fname in files
            }
            self.feature_fnames = sorted(
                fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext
            )

            fname = os.path.join(self.features_dir, 'dataset.json')
            if os.path.exists(fname):
                print(f"Using {fname}.")
            else:
                raise FileNotFoundError("Neither of the specified files exists.")

            with open(fname, 'rb') as f:
                labels = json.load(f)['labels']
            # dataset.json maps (forward-slash) relative paths to labels.
            labels = dict(labels)
            labels = [labels[fname.replace('\\', '/')] for fname in self.feature_fnames]
            labels = np.array(labels)
            # 1-D -> integer class ids; 2-D -> float label vectors.
            self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    def _file_ext(self, fname):
        # Lower-cased extension including the dot, e.g. '.npy'.
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        if self.use_preprocessed_semantic:
            return self.num_samples
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of feature files and label files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        if self.use_preprocessed_semantic:
            semantic_fname = self.semantic_fnames[idx]
            basename = os.path.basename(semantic_fname)
            # Feature files are assumed to end in '-<index>.<ext>'; the first
            # five digits of the index select the shard subdirectory.
            # NOTE(review): verify this naming convention against the
            # feature-extraction script.
            idx_str = basename.split('-')[-1].split('.')[0]
            subdir = idx_str[:5]
            vae_relpath = os.path.join(subdir, f"img-mean-std-{idx_str}.npy")
            vae_path = os.path.join(self.images_dir, vae_relpath)

            with open(vae_path, 'rb') as f:
                image = np.load(f)

            semantic_path = os.path.join(self.semantic_features_dir, semantic_fname)
            semantic_features = np.load(semantic_path)

            # Returns a 4-tuple; the latent appears twice, presumably so the
            # caller's unpacking matches the legacy (image, features, ...)
            # shape -- TODO confirm against the training loop.
            return (
                torch.from_numpy(image).float(),
                torch.from_numpy(image).float(),
                torch.from_numpy(semantic_features).float(),
                torch.tensor(self.labels[idx]),
            )

        # Legacy path: positional pairing of sorted image and feature lists.
        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if image_ext == '.npy':
                # Pre-encoded latent: flatten any leading dims to channels.
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            elif image_ext == '.png' and pyspng is not None:
                # Fast PNG decode when pyspng is installed.
                image = pyspng.load(f.read())
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
            else:
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])
|
REG/eval.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Pick a quasi-random master port in [1200, 1299] so concurrent jobs on the
# same host do not collide on the DDP rendezvous address.
random_number=$((RANDOM % 100 + 1200))
NUM_GPUS=8
STEP="4000000"                  # training iteration of the checkpoint to evaluate
SAVE_PATH="your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8"
VAE_PATH="your_vae_path/"       # NOTE(review): not referenced below -- confirm it is needed
NUM_STEP=250                    # number of sampler steps
MODEL_SIZE='XL'
CFG_SCALE=2.3                   # classifier-free guidance scale (image channel)
CLS_CFG_SCALE=2.3               # guidance scale for the class-token channel
GH=0.85                         # guidance-high interval bound

# Disable NCCL peer-to-peer transport (workaround for P2P issues on some hosts).
export NCCL_P2P_DISABLE=1

# Stage 1: sample 50k images with DDP and pack them into a single .npz file.
# NOTE(review): torch.distributed.launch is deprecated in recent PyTorch
# releases; torchrun is the recommended entry point -- confirm against the
# pinned torch version (2.1.1).
python -m torch.distributed.launch --master_port=$random_number --nproc_per_node=$NUM_GPUS generate.py \
    --model SiT-XL/2 \
    --num-fid-samples 50000 \
    --ckpt ${SAVE_PATH}/checkpoints/${STEP}.pt \
    --path-type=linear \
    --encoder-depth=8 \
    --projector-embed-dims=768 \
    --per-proc-batch-size=64 \
    --mode=sde \
    --num-steps=${NUM_STEP} \
    --cfg-scale=${CFG_SCALE} \
    --cls-cfg-scale=${CLS_CFG_SCALE} \
    --guidance-high=${GH} \
    --sample-dir ${SAVE_PATH}/checkpoints \
    --cls=768


# Stage 2: compute FID and related metrics against the ADM reference batch.
# The sample_batch filename must match the folder name built in generate.py.
python ./evaluations/evaluator.py \
    --ref_batch your_path/VIRTUAL_imagenet256_labeled.npz \
    --sample_batch ${SAVE_PATH}/checkpoints/SiT-${MODEL_SIZE}-2-${STEP}-size-256-vae-ema-cfg-${CFG_SCALE}-seed-0-sde-${GH}-${CLS_CFG_SCALE}.npz \
    --save_path ${SAVE_PATH}/checkpoints \
    --cfg_cond 1 \
    --step ${STEP} \
    --num_steps ${NUM_STEP} \
    --cfg ${CFG_SCALE} \
    --cls_cfg ${CLS_CFG_SCALE} \
    --gh ${GH}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
REG/eval_custom_0.25.log
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python: can't open file 'fid_custom.py': [Errno 2] No such file or directory
|
REG/generate.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Samples a large number of images from a pre-trained SiT model using DDP.
|
| 9 |
+
Subsequently saves a .npz file that can be used to compute FID and other
|
| 10 |
+
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
|
| 11 |
+
|
| 12 |
+
For a simple single-GPU/CPU sampling script, see sample.py.
|
| 13 |
+
"""
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from models.sit import SiT_models
|
| 17 |
+
from diffusers.models import AutoencoderKL
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import os
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import math
|
| 23 |
+
import argparse
|
| 24 |
+
from samplers import euler_maruyama_sampler
|
| 25 |
+
from utils import load_legacy_checkpoints, download_model
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Expects files named ``000000.png`` .. ``{num-1:06d}.png`` inside
    ``sample_dir``; writes ``{sample_dir}.npz`` with key ``arr_0`` and
    returns its path.
    """
    frames = []
    progress = tqdm(range(num), desc="Building .npz file from samples")
    for idx in progress:
        img = Image.open(f"{sample_dir}/{idx:06d}.png")
        frames.append(np.asarray(img).astype(np.uint8))
    batch = np.stack(frames)
    # Sanity check: N RGB images of a common spatial size.
    assert batch.shape == (num, batch.shape[1], batch.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=batch)
    print(f"Saved .npz file to {npz_path} [shape={batch.shape}].")
    return npz_path
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def main(args):
    """
    Run DDP sampling: generate ``args.num_fid_samples`` images with a SiT
    model, save them as .png files, and (on rank 0) pack them into an .npz
    suitable for the ADM FID evaluator.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Set up DDP; each rank gets a distinct seed derived from the global seed.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Build the SiT model; latent resolution is 1/8 of pixel resolution
    # (SD-VAE downsamples by 8x).
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    latent_size = args.resolution // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg = True,
        z_dims = [int(z_dim) for z_dim in args.projector_embed_dims.split(',')],
        encoder_depth=args.encoder_depth,
        **block_kwargs,
    ).to(device)
    ckpt_path = args.ckpt

    # Load weights: either auto-download the released SiT-XL/2 checkpoint,
    # or read the EMA weights from a train.py checkpoint.
    if ckpt_path is None:
        args.ckpt = 'SiT-XL-2-256x256.pt'
        assert args.model == 'SiT-XL/2'
        assert len(args.projector_embed_dims.split(',')) == 1
        assert int(args.projector_embed_dims.split(',')[0]) == 768
        state_dict = download_model('last.pt')
    else:
        state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']

    if args.legacy:
        # Remap key names from the old checkpoint format.
        state_dict = load_legacy_checkpoints(
            state_dict=state_dict, encoder_depth=args.encoder_depth
        )
    model.load_state_dict(state_dict)

    model.eval()  # important!
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    # Offline alternative: load the VAE from a local directory instead, e.g.
    # vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path="your_local_path/weight/").to(device)

    # Create folder to save samples; the name encodes the sampling config so
    # the evaluator script can reconstruct it.
    model_string_name = args.model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.resolution}-vae-{args.vae}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}-{args.mode}-{args.guidance_high}-{args.cls_cfg_scale}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"projector Parameters: {sum(p.numel() for p in model.projectors.parameters()):,}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Sample inputs: image latents, class labels, and the REG class-token
        # latent (dimension args.cls).
        z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        cls_z = torch.randn(n, args.cls, device=device)

        # Sample images:
        sampling_kwargs = dict(
            model=model,
            latents=z,
            y=y,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
            cls_latents=cls_z,
            args=args
        )
        with torch.no_grad():
            if args.mode == "sde":
                samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
            elif args.mode == "ode":  # ODE sampling not supported yet
                exit()
                #samples = euler_sampler(**sampling_kwargs).to(torch.float32)
            else:
                raise NotImplementedError()

        # Undo the SD-VAE latent scaling (0.18215) before decoding.
        latents_scale = torch.tensor(
            [0.18215, 0.18215, 0.18215, 0.18215, ]
        ).view(1, 4, 1, 1).to(device)
        latents_bias = -torch.tensor(
            [0., 0., 0., 0.,]
        ).view(1, 4, 1, 1).to(device)
        samples = vae.decode((samples - latents_bias) / latents_scale).sample
        # Map decoder output from [-1, 1] to uint8 [0, 255], NHWC.
        samples = (samples + 1) / 2.
        samples = torch.clamp(
            255. * samples, 0, 255
        ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

        # Save samples to disk as individual .png files; the index interleaves
        # ranks so filenames are globally unique and contiguous.
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # seed
    parser.add_argument("--global-seed", type=int, default=0)

    # precision
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")

    # logging/saving:
    parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sample-dir", type=str, default="samples")

    # model
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-XL/2")
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    # vae
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")

    # number of samples
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)

    # sampling related hyperparameters
    parser.add_argument("--mode", type=str, default="ode")
    parser.add_argument("--cfg-scale", type=float, default=1.5)
    parser.add_argument("--cls-cfg-scale", type=float, default=1.5)
    parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False)  # only for ode
    parser.add_argument("--guidance-low", type=float, default=0.)
    parser.add_argument("--guidance-high", type=float, default=1.)
    # accepted (and ignored) for torch.distributed.launch compatibility
    parser.add_argument('--local-rank', default=-1, type=int)
    # dimension of the REG class-token latent
    parser.add_argument('--cls', default=768, type=int)
    # will be deprecated
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)  # only for ode


    args = parser.parse_args()
    main(args)
|
REG/loss.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from scipy.optimize import linear_sum_assignment
|
| 7 |
+
except ImportError:
|
| 8 |
+
linear_sum_assignment = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def ot_pair_noise_to_cls(noise_cls, cls_gt):
    """
    Minibatch optimal-transport pairing (in the spirit of
    conditional-flow-matching / torchcfm's scipy-based plan sampling):
    reorder ``noise_cls`` within the batch under a squared-Euclidean cost so
    that the returned noise[i] is the OT partner of cls_gt[i].

    Args:
        noise_cls: (N, D) tensor (or any shape flattenable to (N, -1)) of
            noise samples for the class-token channel.
        cls_gt: matching tensor of ground-truth class tokens; its order is
            preserved in the output.

    Returns:
        (noise_ot, cls_gt) where noise_ot[i] is optimally paired with
        cls_gt[i]. Falls back to the identity pairing when the batch has
        <= 1 element or scipy is unavailable.
    """
    n = noise_cls.shape[0]
    if n <= 1:
        return noise_cls, cls_gt
    if linear_sum_assignment is None:
        return noise_cls, cls_gt
    x0 = noise_cls.detach().float().reshape(n, -1)
    x1 = cls_gt.detach().float().reshape(n, -1)
    M = torch.cdist(x0, x1) ** 2
    # For a square cost matrix row_ind is arange(n) and col_ind[i] is the
    # cls index assigned to noise i. Since cls_gt keeps its order, position
    # k must receive the noise whose assigned cls is k — i.e. the INVERSE
    # permutation. Indexing with col_ind directly would pair
    # noise[col_ind[k]] with cls k, which is not the OT matching.
    _, col_ind = linear_sum_assignment(M.cpu().numpy())
    col_ind = torch.as_tensor(col_ind, device=noise_cls.device, dtype=torch.long)
    inv_perm = torch.argsort(col_ind)
    return noise_cls[inv_perm], cls_gt
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def mean_flat(x):
    """
    Average over every non-batch dimension, giving one value per sample.
    """
    reduce_dims = list(range(1, x.dim()))
    return x.mean(dim=reduce_dims)
|
| 35 |
+
|
| 36 |
+
def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.

    (Docstring fixed: the previous copy-pasted text said "mean", but the
    implementation sums.)
    """
    return torch.sum(x, dim=list(range(1, len(x.size()))))
|
| 41 |
+
|
| 42 |
+
class SILoss:
    """Stochastic-interpolant loss with a t_c-gated semantic (cls-token) flow branch.

    Time convention matches train.py / JsFlow: t=0 is the clean latent, t=1 is
    pure noise. The image branch is trained on the full path; the cls branch is
    only generative on t in (t_c, 1] and pinned to the ground truth on [0, t_c].
    """

    def __init__(
        self,
        prediction='v',
        path_type="linear",
        weighting="uniform",
        encoders=[],
        accelerator=None,
        latents_scale=None,
        latents_bias=None,
        t_c=0.5,
        ot_cls=True,
    ):
        # NOTE(review): mutable default `encoders=[]` is shared across instances;
        # harmless only if callers never mutate it in place — confirm.
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        self.encoders = encoders
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias
        # t in (t_c, 1]: cls evolves from (OT-paired) noise toward cls_gt (generative segment).
        # t in [0, t_c]: cls is held at the real cls_gt and its target velocity is 0.
        tc = float(t_c)
        # Clamp t_c away from 0/1 so the 1/(1 - t_c) scale factors below never blow up.
        self.t_c = min(max(tc, 1e-4), 1.0 - 1e-4)
        self.ot_cls = bool(ot_cls)

    def interpolant(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for the chosen path at time t."""
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()

        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
                 time_input=None, noises=None,):
        """Compute the training losses.

        Returns (denoising_loss, proj_loss, time_input, noises, denoising_loss_cls).
        """
        if model_kwargs == None:
            model_kwargs = {}
        # sample timesteps
        if time_input is None:
            if self.weighting == "uniform":
                time_input = torch.rand((images.shape[0], 1, 1, 1))
            elif self.weighting == "lognormal":
                # sample timestep according to log-normal distribution of sigmas following EDM
                rnd_normal = torch.randn((images.shape[0], 1 ,1, 1))
                sigma = rnd_normal.exp()
                if self.path_type == "linear":
                    time_input = sigma / (1 + sigma)
                elif self.path_type == "cosine":
                    time_input = 2 / np.pi * torch.atan(sigma)

        time_input = time_input.to(device=images.device, dtype=torch.float32)
        cls_token = cls_token.to(device=images.device, dtype=torch.float32)

        if noises is None:
            noises = torch.randn_like(images)
            noises_cls = torch.randn_like(cls_token)
        else:
            # Callers may pass (image_noise, cls_noise) as a pair, or image noise alone.
            if isinstance(noises, (tuple, list)) and len(noises) == 2:
                noises, noises_cls = noises
            else:
                noises_cls = torch.randn_like(cls_token)

        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)

        model_input = alpha_t * images + sigma_t * noises
        if self.prediction == 'v':
            model_target = d_alpha_t * images + d_sigma_t * noises
        else:
            raise NotImplementedError()

        N = images.shape[0]
        t_flat = time_input.view(-1).float()
        # Per-sample gate: 1 where t > t_c (generative cls segment), 0 otherwise.
        high_noise_mask = (t_flat > self.t_c).float().view(N, *([1] * (cls_token.dim() - 1)))
        low_noise_mask = 1.0 - high_noise_mask

        noise_cls_raw = noises_cls
        if self.ot_cls:
            # Re-pair cls noise to cls targets via minibatch optimal transport.
            noise_cls_paired, cls_gt_paired = ot_pair_noise_to_cls(noise_cls_raw, cls_token)
        else:
            noise_cls_paired, cls_gt_paired = noise_cls_raw, cls_token

        # Rescale t from (t_c, 1] to tau in (0, 1] for the semantic interpolation.
        tau_shape = (N,) + (1,) * max(0, cls_token.dim() - 1)
        tau = (time_input.reshape(tau_shape) - self.t_c) / (1.0 - self.t_c + 1e-8)
        tau = torch.clamp(tau, 0.0, 1.0)
        alpha_sem = 1.0 - tau
        sigma_sem = tau

        cls_t_high = alpha_sem * cls_gt_paired + sigma_sem * noise_cls_paired
        cls_t = high_noise_mask * cls_t_high + low_noise_mask * cls_token
        # Guard against NaN/Inf leaking into the model input.
        cls_t = torch.nan_to_num(cls_t, nan=0.0, posinf=1e4, neginf=-1e4)
        cls_t = torch.clamp(cls_t, -1e4, 1e4)

        # Detach the low-noise branch so gradients only flow through the generative segment.
        cls_for_model = cls_t * high_noise_mask + cls_t.detach() * low_noise_mask

        # Target cls velocity on (t_c, 1]: d(cls_t)/dt of the tau interpolation,
        # i.e. (noise - target) scaled by 1/(1 - t_c); zero elsewhere.
        inv_scale = 1.0 / (1.0 - self.t_c + 1e-8)
        v_cls_high = (noise_cls_paired - cls_gt_paired) * inv_scale
        v_cls_target = high_noise_mask * v_cls_high

        model_output, zs_tilde, cls_output = model(
            model_input, time_input.flatten(), **model_kwargs, cls_token=cls_for_model
        )

        #denoising_loss
        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss_cls = mean_flat((cls_output - v_cls_target) ** 2)

        # projection loss: negative cosine similarity between encoder features
        # and the model's projected features, averaged over encoders and batch.
        proj_loss = 0.
        bsz = zs[0].shape[0]
        for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
            for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
                z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
                z_j = torch.nn.functional.normalize(z_j, dim=-1)
                proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        proj_loss /= (len(zs) * bsz)

        return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls

    def tc_velocity_loss(self, model, images, model_kwargs=None, cls_token=None, noises=None):
        """
        Extra constraint: directly supervise the image velocity field at t = t_c,
        to stabilize the single-step (t_c -> 0) jump.

        Only affects the image branch; the original cls/projection losses are unchanged.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if cls_token is None:
            raise ValueError("tc_velocity_loss requires cls_token")
        if noises is None:
            noises = torch.randn_like(images)

        bsz = images.shape[0]
        # Evaluate the interpolant exactly at t = t_c for every sample.
        time_input = torch.full(
            (bsz, 1, 1, 1), float(self.t_c), device=images.device, dtype=torch.float32
        )
        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)
        model_input = alpha_t * images + sigma_t * noises
        model_target = d_alpha_t * images + d_sigma_t * noises

        model_output, _, _ = model(
            model_input, time_input.flatten(), **model_kwargs, cls_token=cls_token
        )
        return mean_flat((model_output - model_target) ** 2)
|
REG/requirements.txt
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- pip:
|
| 2 |
+
absl-py==2.2.2
|
| 3 |
+
accelerate==1.2.1
|
| 4 |
+
aiohappyeyeballs==2.6.1
|
| 5 |
+
aiohttp==3.11.16
|
| 6 |
+
aiosignal==1.3.2
|
| 7 |
+
astunparse==1.6.3
|
| 8 |
+
async-timeout==5.0.1
|
| 9 |
+
attrs==25.3.0
|
| 10 |
+
certifi==2022.12.7
|
| 11 |
+
charset-normalizer==2.1.1
|
| 12 |
+
click==8.1.8
|
| 13 |
+
datasets==2.20.0
|
| 14 |
+
diffusers==0.32.1
|
| 15 |
+
dill==0.3.8
|
| 16 |
+
docker-pycreds==0.4.0
|
| 17 |
+
einops==0.8.1
|
| 18 |
+
filelock==3.13.1
|
| 19 |
+
flatbuffers==25.2.10
|
| 20 |
+
frozenlist==1.5.0
|
| 21 |
+
fsspec==2024.5.0
|
| 22 |
+
ftfy==6.3.1
|
| 23 |
+
gast==0.6.0
|
| 24 |
+
gitdb==4.0.12
|
| 25 |
+
gitpython==3.1.44
|
| 26 |
+
google-pasta==0.2.0
|
| 27 |
+
grpcio==1.71.0
|
| 28 |
+
h5py==3.13.0
|
| 29 |
+
huggingface-hub==0.27.1
|
| 30 |
+
idna==3.4
|
| 31 |
+
importlib-metadata==8.6.1
|
| 32 |
+
jinja2==3.1.4
|
| 33 |
+
joblib==1.4.2
|
| 34 |
+
keras==3.9.2
|
| 35 |
+
libclang==18.1.1
|
| 36 |
+
markdown==3.8
|
| 37 |
+
markdown-it-py==3.0.0
|
| 38 |
+
markupsafe==2.1.5
|
| 39 |
+
mdurl==0.1.2
|
| 40 |
+
ml-dtypes==0.3.2
|
| 41 |
+
mpmath==1.3.0
|
| 42 |
+
multidict==6.4.3
|
| 43 |
+
multiprocess==0.70.16
|
| 44 |
+
namex==0.0.8
|
| 45 |
+
networkx==3.3
|
| 46 |
+
numpy==1.26.4
|
| 47 |
+
opt-einsum==3.4.0
|
| 48 |
+
optree==0.15.0
|
| 49 |
+
packaging==24.2
|
| 50 |
+
pandas==2.2.3
|
| 51 |
+
pillow==11.0.0
|
| 52 |
+
platformdirs==4.3.7
|
| 53 |
+
propcache==0.3.1
|
| 54 |
+
protobuf==4.25.6
|
| 55 |
+
psutil==7.0.0
|
| 56 |
+
pyarrow==19.0.1
|
| 57 |
+
pyarrow-hotfix==0.6
|
| 58 |
+
pygments==2.19.1
|
| 59 |
+
python-dateutil==2.9.0.post0
|
| 60 |
+
pytz==2025.2
|
| 61 |
+
pyyaml==6.0.2
|
| 62 |
+
regex==2024.11.6
|
| 63 |
+
requests==2.32.3
|
| 64 |
+
rich==14.0.0
|
| 65 |
+
safetensors==0.5.3
|
| 66 |
+
scikit-learn==1.5.1
|
| 67 |
+
scipy==1.15.2
|
| 68 |
+
sentry-sdk==2.26.1
|
| 69 |
+
setproctitle==1.3.5
|
| 70 |
+
six==1.17.0
|
| 71 |
+
smmap==5.0.2
|
| 72 |
+
sympy==1.13.1
|
| 73 |
+
tensorboard==2.16.1
|
| 74 |
+
tensorboard-data-server==0.7.2
|
| 75 |
+
tensorflow==2.16.1
|
| 76 |
+
tensorflow-io-gcs-filesystem==0.37.1
|
| 77 |
+
termcolor==3.0.1
|
| 78 |
+
tf-keras==2.16.0
|
| 79 |
+
threadpoolctl==3.6.0
|
| 80 |
+
timm==1.0.12
|
| 81 |
+
tokenizers==0.21.0
|
| 82 |
+
tqdm==4.67.1
|
| 83 |
+
transformers==4.47.0
|
| 84 |
+
triton==2.1.0
|
| 85 |
+
typing-extensions==4.12.2
|
| 86 |
+
tzdata==2025.2
|
| 87 |
+
urllib3==1.26.13
|
| 88 |
+
wandb==0.17.6
|
| 89 |
+
wcwidth==0.2.13
|
| 90 |
+
werkzeug==3.1.3
|
| 91 |
+
wrapt==1.17.2
|
| 92 |
+
xformer==1.0.1
|
| 93 |
+
xformers==0.0.23
|
| 94 |
+
xxhash==3.5.0
|
| 95 |
+
yarl==1.20.0
|
| 96 |
+
zipp==3.21.0
|
| 97 |
+
|
REG/sample_from_checkpoint.py
ADDED
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
从 REG/train.py 保存的检查点加载权重,在指定目录生成若干 PNG。
|
| 4 |
+
|
| 5 |
+
示例:
|
| 6 |
+
python sample_from_checkpoint.py \\
|
| 7 |
+
--ckpt exps/jsflow-experiment/checkpoints/0050000.pt \\
|
| 8 |
+
--out-dir ./samples_gen \\
|
| 9 |
+
--num-images 64 \\
|
| 10 |
+
--batch-size 8
|
| 11 |
+
|
| 12 |
+
# 按训练 t_c 分段分配步数(t=1→t_c 与 t_c→0;--t-c 可省略若检查点含 t_c):
|
| 13 |
+
python sample_from_checkpoint.py ... \\
|
| 14 |
+
--steps-before-tc 150 --steps-after-tc 100 --t-c 0.5
|
| 15 |
+
|
| 16 |
+
# 同一批初始噪声连跑两种 t_c 后段步数(输出到 out-dir 下子目录):
|
| 17 |
+
python sample_from_checkpoint.py ... \\
|
| 18 |
+
--steps-before-tc 150 --steps-after-tc 5 --dual-compare-after
|
| 19 |
+
# 分段时会在 at_tc/(或 at_tc/after_input、at_tc/after_equal_before)额外保存 t≈t_c 的解码图。
|
| 20 |
+
|
| 21 |
+
检查点需包含 train.py 写入的键:ema(或 model)、args(推荐,用于自动还原结构)。
|
| 22 |
+
若缺少 args,需通过命令行显式传入 --model、--resolution、--enc-type 等。
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import os
|
| 29 |
+
import sys
|
| 30 |
+
import types
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
from diffusers.models import AutoencoderKL
|
| 34 |
+
from PIL import Image
|
| 35 |
+
from tqdm import tqdm
|
| 36 |
+
|
| 37 |
+
from models.sit import SiT_models
|
| 38 |
+
from samplers import (
|
| 39 |
+
euler_maruyama_image_noise_before_tc_sampler,
|
| 40 |
+
euler_maruyama_image_noise_sampler,
|
| 41 |
+
euler_maruyama_sampler,
|
| 42 |
+
euler_ode_sampler,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def semantic_dim_from_enc_type(enc_type):
    """Infer the semantic/class-token dimension from an encoder type string (matches train.py)."""
    if enc_type is None:
        return 768
    name = str(enc_type).lower()
    # Ordered marker table; the first match wins (same precedence as the original chain).
    dim_by_marker = (
        (("vit-g", "vitg"), 1536),
        (("vit-l", "vitl"), 1024),
        (("vit-s", "vits"), 384),
    )
    for markers, dim in dim_by_marker:
        if any(marker in name for marker in markers):
            return dim
    return 768
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
    """Recover the training-time argparse.Namespace stored under ckpt["args"], if any.

    Accepts a Namespace, a plain dict, or a SimpleNamespace; returns None when
    the key is absent or of an unrecognized type.
    """
    stored = ckpt.get("args")
    if isinstance(stored, argparse.Namespace):
        return stored
    if isinstance(stored, dict):
        return argparse.Namespace(**stored)
    if isinstance(stored, types.SimpleNamespace):
        return argparse.Namespace(**vars(stored))
    # Covers both a missing "args" key (None) and unknown payload types.
    return None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load_vae(device: torch.device):
    """Load the sd-vae-ft-mse VAE, preferring the local diffusers cache (same strategy as train.py).

    Resolution order:
      1. dnnlib-managed diffusers cache (local files only);
      2. a filesystem scan of common cache roots for an sd-vae-ft-mse snapshot;
      3. download from the Hugging Face Hub.
    Raises RuntimeError when all three fail.
    """
    try:
        from preprocessing import dnnlib

        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        # Point huggingface_hub at the dnnlib cache for this process.
        os.environ["HF_HOME"] = cache_dir
        try:
            # Fast path: the model is already present in the dnnlib diffusers cache.
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            print(f"Loaded VAE from local cache: {cache_dir}")
            return vae
        except Exception:
            pass
        # Fallback: scan common cache roots for an on-disk sd-vae-ft-mse snapshot
        # (a directory containing config.json with "sd-vae-ft-mse" in its path).
        candidate_dir = None
        for root_dir in [
            cache_dir,
            os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
        ]:
            if not os.path.isdir(root_dir):
                continue
            for root, _, files in os.walk(root_dir):
                if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                    candidate_dir = root
                    break
            if candidate_dir is not None:
                break
        if candidate_dir is not None:
            vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
            vae.eval()
            print(f"Loaded VAE from {candidate_dir}")
            return vae
    except Exception as e:
        # Best-effort: a failed local search is reported but not fatal.
        print(f"VAE local cache search failed: {e}", file=sys.stderr)
    # Last resort: download from the Hugging Face Hub (requires network access).
    try:
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
        vae.eval()
        print("Loaded VAE from Hub: stabilityai/sd-vae-ft-mse")
        return vae
    except Exception as e:
        raise RuntimeError(
            "无法加载 VAE stabilityai/sd-vae-ft-mse,请确认已下载或网络可用。"
        ) from e
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
    """Instantiate the SiT model described by the training args.

    Returns (model, cls_dim) where cls_dim is the semantic feature width
    inferred from the encoder type (also used as the projection z-dim).
    Raises ValueError when ta.model is not a known SiT_models key.
    """
    res = int(getattr(ta, "resolution", 256))
    # sd-vae latents are 8x downsampled relative to pixel resolution.
    latent_size = res // 8
    enc_type = getattr(ta, "enc_type", "dinov2-vit-b")
    z_dims = [semantic_dim_from_enc_type(enc_type)]
    block_kwargs = {
        "fused_attn": getattr(ta, "fused_attn", True),
        "qk_norm": getattr(ta, "qk_norm", False),
    }
    cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
    if ta.model not in SiT_models:
        raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
    model = SiT_models[ta.model](
        input_size=latent_size,
        num_classes=int(getattr(ta, "num_classes", 1000)),
        # CFG embeddings are only needed when label dropout was used in training.
        use_cfg=(cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=int(getattr(ta, "encoder_depth", 8)),
        **block_kwargs,
    ).to(device)
    return model, z_dims[0]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def resolve_tc_schedule(cli, ta):
    """Resolve the two-segment Euler-step schedule split at t_c.

    When both --steps-before-tc and --steps-after-tc are given, returns
    (t_c, steps_before, steps_after), falling back to the checkpoint's
    args.t_c when --t-c is absent. Returns (None, None, None) when neither
    segment count is given (uniform --num-steps grid is used instead).
    Exits with status 1 on an inconsistent or under-specified combination.
    """
    before = cli.steps_before_tc
    after = cli.steps_after_tc
    tc = cli.t_c
    if before is None and after is None:
        return None, None, None
    if before is None or after is None:
        print(
            "使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。",
            file=sys.stderr,
        )
        sys.exit(1)
    if tc is None:
        tc = getattr(ta, "t_c", None) if ta is not None else None
        if tc is None:
            print(
                "分段采样需要 --t-c,或检查点 args 中含 t_c。",
                file=sys.stderr,
            )
            sys.exit(1)
    return float(tc), int(before), int(after)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def parse_cli():
    """Build and parse the CLI for sampling images from a REG checkpoint.

    All user-facing help strings are kept verbatim; the structure-fallback
    flags (--model, --resolution, ...) are only required when the checkpoint
    does not embed its training args.
    """
    p = argparse.ArgumentParser(description="REG 检查点采样出图(可选 ODE/EM/EM-图像噪声)")
    p.add_argument("--ckpt", type=str, required=True, help="train.py 保存的 .pt 路径")
    p.add_argument("--out-dir", type=str, required=True, help="输出 PNG 目录(会创建)")
    p.add_argument("--num-images", type=int, required=True, help="生成图片总数")
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--weights",
        type=str,
        choices=("ema", "model"),
        default="ema",
        help="使用检查点中的 ema 或 model 权重",
    )
    p.add_argument("--device", type=str, default="cuda", help="如 cuda 或 cuda:0")
    p.add_argument(
        "--num-steps",
        type=int,
        default=50,
        help="均匀时间网格时的欧拉步数(未使用 --steps-before-tc/--steps-after-tc 时生效)",
    )
    # Two-segment schedule: both step counts must be given together (see resolve_tc_schedule).
    p.add_argument(
        "--t-c",
        type=float,
        default=None,
        help="分段时刻:t∈(t_c,1] 与 t∈[0,t_c] 两段;缺省可用检查点 args.t_c(需配合两段步数)",
    )
    p.add_argument(
        "--steps-before-tc",
        type=int,
        default=None,
        help="从 t=1 积分到 t=t_c 的步数(与 --steps-after-tc 成对使用)",
    )
    p.add_argument(
        "--steps-after-tc",
        type=int,
        default=None,
        help="从 t=t_c 积分到 t=0(经 t_floor=0.04)的步数",
    )
    p.add_argument("--cfg-scale", type=float, default=1.0)
    p.add_argument("--cls-cfg-scale", type=float, default=0.0, help="cls 分支 CFG(>0 时需 cfg-scale>1)")
    p.add_argument("--guidance-low", type=float, default=0.0)
    p.add_argument("--guidance-high", type=float, default=1.0)
    p.add_argument(
        "--path-type",
        type=str,
        default=None,
        choices=["linear", "cosine"],
        help="默认从检查点 args 读取;可覆盖",
    )
    p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    # Structure fallbacks for checkpoints that do not embed their training args.
    p.add_argument("--model", type=str, default=None, help="无检查点 args 时必填;与 SiT_models 键一致,如 SiT-XL/2")
    p.add_argument("--resolution", type=int, default=None, choices=[256, 512])
    p.add_argument("--num-classes", type=int, default=None)
    p.add_argument("--encoder-depth", type=int, default=None)
    p.add_argument("--enc-type", type=str, default=None)
    p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--cfg-prob", type=float, default=None)
    p.add_argument(
        "--sampler",
        type=str,
        default="em_image_noise",
        choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
        help="采样器:ode=euler_sampler 确定性漂移(linspace 1→0 或 t_c 分段直连 0,无 t_floor;与 EM 网格不同),"
        "em=标准EM(含图像+cls噪声),em_image_noise=仅图像噪声,"
        "em_image_noise_before_tc=t<=t_c时图像去随机+cls全程去随机",
    )
    p.add_argument(
        "--dual-compare-after",
        action="store_true",
        help="需配合分段步数:同批 z/y/cls 连跑两次;after_input 用 --steps-after-tc,"
        "after_equal_before 将 after 步数设为与 --steps-before-tc 相同",
    )
    p.add_argument(
        "--save-fixed-trajectory",
        action="store_true",
        help="保存固定步采样轨迹(npy);仅对非 em 采样器启用,输出在 out-dir/trajectory",
    )
    return p.parse_args()
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
|
| 260 |
+
imgs = vae.decode((latents - latents_bias) / latents_scale).sample
|
| 261 |
+
imgs = (imgs + 1) / 2.0
|
| 262 |
+
imgs = torch.clamp(imgs, 0, 1)
|
| 263 |
+
return (
|
| 264 |
+
(imgs * 255.0)
|
| 265 |
+
.round()
|
| 266 |
+
.to(torch.uint8)
|
| 267 |
+
.permute(0, 2, 3, 1)
|
| 268 |
+
.cpu()
|
| 269 |
+
.numpy()
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def main():
|
| 274 |
+
cli = parse_cli()
|
| 275 |
+
device = torch.device(cli.device if torch.cuda.is_available() else "cpu")
|
| 276 |
+
if device.type == "cuda":
|
| 277 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 278 |
+
|
| 279 |
+
try:
|
| 280 |
+
ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
|
| 281 |
+
except TypeError:
|
| 282 |
+
ckpt = torch.load(cli.ckpt, map_location="cpu")
|
| 283 |
+
ta = load_train_args_from_ckpt(ckpt)
|
| 284 |
+
if ta is None:
|
| 285 |
+
if cli.model is None or cli.resolution is None or cli.enc_type is None:
|
| 286 |
+
print(
|
| 287 |
+
"检查点中无 args,请至少指定:--model --resolution --enc-type "
|
| 288 |
+
"(以及按需 --num-classes --encoder-depth)",
|
| 289 |
+
file=sys.stderr,
|
| 290 |
+
)
|
| 291 |
+
sys.exit(1)
|
| 292 |
+
ta = argparse.Namespace(
|
| 293 |
+
model=cli.model,
|
| 294 |
+
resolution=cli.resolution,
|
| 295 |
+
num_classes=cli.num_classes if cli.num_classes is not None else 1000,
|
| 296 |
+
encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
|
| 297 |
+
enc_type=cli.enc_type,
|
| 298 |
+
fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
|
| 299 |
+
qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
|
| 300 |
+
cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
|
| 301 |
+
)
|
| 302 |
+
else:
|
| 303 |
+
if cli.model is not None:
|
| 304 |
+
ta.model = cli.model
|
| 305 |
+
if cli.resolution is not None:
|
| 306 |
+
ta.resolution = cli.resolution
|
| 307 |
+
if cli.num_classes is not None:
|
| 308 |
+
ta.num_classes = cli.num_classes
|
| 309 |
+
if cli.encoder_depth is not None:
|
| 310 |
+
ta.encoder_depth = cli.encoder_depth
|
| 311 |
+
if cli.enc_type is not None:
|
| 312 |
+
ta.enc_type = cli.enc_type
|
| 313 |
+
if cli.fused_attn is not None:
|
| 314 |
+
ta.fused_attn = cli.fused_attn
|
| 315 |
+
if cli.qk_norm is not None:
|
| 316 |
+
ta.qk_norm = cli.qk_norm
|
| 317 |
+
if cli.cfg_prob is not None:
|
| 318 |
+
ta.cfg_prob = cli.cfg_prob
|
| 319 |
+
|
| 320 |
+
path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
|
| 321 |
+
|
| 322 |
+
tc_split = resolve_tc_schedule(cli, ta)
|
| 323 |
+
if cli.dual_compare_after and tc_split[0] is None:
|
| 324 |
+
print("--dual-compare-after 必须配合 --steps-before-tc 与 --steps-after-tc(分段采样)", file=sys.stderr)
|
| 325 |
+
sys.exit(1)
|
| 326 |
+
if tc_split[0] is not None:
|
| 327 |
+
if cli.dual_compare_after:
|
| 328 |
+
print(
|
| 329 |
+
f"双次对比:t_c={tc_split[0]}, before={tc_split[1]}, "
|
| 330 |
+
f"after_input={tc_split[2]}, after_equal_before={tc_split[1]}"
|
| 331 |
+
)
|
| 332 |
+
else:
|
| 333 |
+
print(
|
| 334 |
+
f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]} "
|
| 335 |
+
f"(总模型前向约 {tc_split[1] + tc_split[2] + 1} 次)"
|
| 336 |
+
)
|
| 337 |
+
else:
|
| 338 |
+
print(f"时间网格:均匀 num_steps={cli.num_steps}")
|
| 339 |
+
|
| 340 |
+
if cli.sampler == "ode":
|
| 341 |
+
sampler_fn = euler_ode_sampler
|
| 342 |
+
elif cli.sampler == "em":
|
| 343 |
+
sampler_fn = euler_maruyama_sampler
|
| 344 |
+
elif cli.sampler == "em_image_noise_before_tc":
|
| 345 |
+
sampler_fn = euler_maruyama_image_noise_before_tc_sampler
|
| 346 |
+
else:
|
| 347 |
+
sampler_fn = euler_maruyama_image_noise_sampler
|
| 348 |
+
|
| 349 |
+
model, cls_dim = build_model_from_train_args(ta, device)
|
| 350 |
+
wkey = cli.weights
|
| 351 |
+
if wkey not in ckpt:
|
| 352 |
+
raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
|
| 353 |
+
state = ckpt[wkey]
|
| 354 |
+
if cli.legacy:
|
| 355 |
+
from utils import load_legacy_checkpoints
|
| 356 |
+
|
| 357 |
+
state = load_legacy_checkpoints(
|
| 358 |
+
state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
|
| 359 |
+
)
|
| 360 |
+
model.load_state_dict(state, strict=True)
|
| 361 |
+
model.eval()
|
| 362 |
+
|
| 363 |
+
vae = load_vae(device)
|
| 364 |
+
latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
|
| 365 |
+
latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
|
| 366 |
+
|
| 367 |
+
sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))
|
| 368 |
+
|
| 369 |
+
at_tc_dir = at_tc_a = at_tc_b = None
|
| 370 |
+
pair_dir = None
|
| 371 |
+
traj_dir = traj_a = traj_b = None
|
| 372 |
+
if cli.dual_compare_after:
|
| 373 |
+
out_a = os.path.join(cli.out_dir, "after_input")
|
| 374 |
+
out_b = os.path.join(cli.out_dir, "after_equal_before")
|
| 375 |
+
pair_dir = os.path.join(cli.out_dir, "pair")
|
| 376 |
+
os.makedirs(out_a, exist_ok=True)
|
| 377 |
+
os.makedirs(out_b, exist_ok=True)
|
| 378 |
+
os.makedirs(pair_dir, exist_ok=True)
|
| 379 |
+
if tc_split[0] is not None:
|
| 380 |
+
at_tc_a = os.path.join(cli.out_dir, "at_tc", "after_input")
|
| 381 |
+
at_tc_b = os.path.join(cli.out_dir, "at_tc", "after_equal_before")
|
| 382 |
+
os.makedirs(at_tc_a, exist_ok=True)
|
| 383 |
+
os.makedirs(at_tc_b, exist_ok=True)
|
| 384 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 385 |
+
traj_a = os.path.join(cli.out_dir, "trajectory", "after_input")
|
| 386 |
+
traj_b = os.path.join(cli.out_dir, "trajectory", "after_equal_before")
|
| 387 |
+
os.makedirs(traj_a, exist_ok=True)
|
| 388 |
+
os.makedirs(traj_b, exist_ok=True)
|
| 389 |
+
else:
|
| 390 |
+
os.makedirs(cli.out_dir, exist_ok=True)
|
| 391 |
+
if tc_split[0] is not None:
|
| 392 |
+
at_tc_dir = os.path.join(cli.out_dir, "at_tc")
|
| 393 |
+
os.makedirs(at_tc_dir, exist_ok=True)
|
| 394 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 395 |
+
traj_dir = os.path.join(cli.out_dir, "trajectory")
|
| 396 |
+
os.makedirs(traj_dir, exist_ok=True)
|
| 397 |
+
latent_size = int(getattr(ta, "resolution", 256)) // 8
|
| 398 |
+
n_total = int(cli.num_images)
|
| 399 |
+
b = max(1, int(cli.batch_size))
|
| 400 |
+
|
| 401 |
+
torch.manual_seed(cli.seed)
|
| 402 |
+
if device.type == "cuda":
|
| 403 |
+
torch.cuda.manual_seed_all(cli.seed)
|
| 404 |
+
|
| 405 |
+
written = 0
|
| 406 |
+
pbar = tqdm(total=n_total, desc="sampling")
|
| 407 |
+
while written < n_total:
|
| 408 |
+
cur = min(b, n_total - written)
|
| 409 |
+
z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
|
| 410 |
+
y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
|
| 411 |
+
cls_z = torch.randn(cur, cls_dim, device=device)
|
| 412 |
+
|
| 413 |
+
with torch.no_grad():
|
| 414 |
+
base_kw = dict(
|
| 415 |
+
num_steps=cli.num_steps,
|
| 416 |
+
cfg_scale=cli.cfg_scale,
|
| 417 |
+
guidance_low=cli.guidance_low,
|
| 418 |
+
guidance_high=cli.guidance_high,
|
| 419 |
+
path_type=path_type,
|
| 420 |
+
cls_latents=cls_z,
|
| 421 |
+
args=sampler_args,
|
| 422 |
+
)
|
| 423 |
+
if cli.dual_compare_after:
|
| 424 |
+
tc_v, sb, sa_in = tc_split
|
| 425 |
+
# 两次完整采样会各自消耗 RNG;不重置则第二条的 1→t_c 噪声与第一条不同,z_tc/at_tc 会对不齐。
|
| 426 |
+
# 在固定 z/y/cls_z 之后打快照,第二条运行前恢复,使 t_c 中间态一致(仅后段步数不同)。
|
| 427 |
+
_rng_cpu_dual = torch.random.get_rng_state()
|
| 428 |
+
_rng_cuda_dual = (
|
| 429 |
+
torch.cuda.get_rng_state_all()
|
| 430 |
+
if device.type == "cuda"
|
| 431 |
+
else None
|
| 432 |
+
)
|
| 433 |
+
batch_imgs = {}
|
| 434 |
+
for _run_i, (subdir, sa, tc_save_dir) in enumerate(
|
| 435 |
+
(
|
| 436 |
+
(out_a, sa_in, at_tc_a),
|
| 437 |
+
(out_b, sb, at_tc_b),
|
| 438 |
+
)
|
| 439 |
+
):
|
| 440 |
+
if _run_i > 0:
|
| 441 |
+
torch.random.set_rng_state(_rng_cpu_dual)
|
| 442 |
+
if _rng_cuda_dual is not None:
|
| 443 |
+
torch.cuda.set_rng_state_all(_rng_cuda_dual)
|
| 444 |
+
em_kw = dict(base_kw)
|
| 445 |
+
em_kw["t_c"] = tc_v
|
| 446 |
+
em_kw["num_steps_before_tc"] = sb
|
| 447 |
+
em_kw["num_steps_after_tc"] = sa
|
| 448 |
+
if cli.sampler == "em_image_noise_before_tc":
|
| 449 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 450 |
+
latents, z_tc, cls_tc, cls_t0, traj = sampler_fn(
|
| 451 |
+
model,
|
| 452 |
+
z,
|
| 453 |
+
y,
|
| 454 |
+
**em_kw,
|
| 455 |
+
return_mid_state=True,
|
| 456 |
+
t_mid=float(tc_v),
|
| 457 |
+
return_cls_final=True,
|
| 458 |
+
return_trajectory=True,
|
| 459 |
+
)
|
| 460 |
+
else:
|
| 461 |
+
latents, z_tc, cls_tc, cls_t0 = sampler_fn(
|
| 462 |
+
model,
|
| 463 |
+
z,
|
| 464 |
+
y,
|
| 465 |
+
**em_kw,
|
| 466 |
+
return_mid_state=True,
|
| 467 |
+
t_mid=float(tc_v),
|
| 468 |
+
return_cls_final=True,
|
| 469 |
+
)
|
| 470 |
+
traj = None
|
| 471 |
+
else:
|
| 472 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 473 |
+
latents, z_tc, cls_tc, traj = sampler_fn(
|
| 474 |
+
model,
|
| 475 |
+
z,
|
| 476 |
+
y,
|
| 477 |
+
**em_kw,
|
| 478 |
+
return_mid_state=True,
|
| 479 |
+
t_mid=float(tc_v),
|
| 480 |
+
return_trajectory=True,
|
| 481 |
+
)
|
| 482 |
+
else:
|
| 483 |
+
latents, z_tc, cls_tc = sampler_fn(
|
| 484 |
+
model,
|
| 485 |
+
z,
|
| 486 |
+
y,
|
| 487 |
+
**em_kw,
|
| 488 |
+
return_mid_state=True,
|
| 489 |
+
t_mid=float(tc_v),
|
| 490 |
+
)
|
| 491 |
+
traj = None
|
| 492 |
+
cls_t0 = None
|
| 493 |
+
latents = latents.to(torch.float32)
|
| 494 |
+
imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
|
| 495 |
+
batch_imgs[subdir] = imgs
|
| 496 |
+
for i in range(cur):
|
| 497 |
+
Image.fromarray(imgs[i]).save(
|
| 498 |
+
os.path.join(subdir, f"{written + i:06d}.png")
|
| 499 |
+
)
|
| 500 |
+
if tc_save_dir is not None and z_tc is not None:
|
| 501 |
+
imgs_tc = _decode_to_uint8_hwc(
|
| 502 |
+
z_tc.to(torch.float32), latents_bias, latents_scale, vae
|
| 503 |
+
)
|
| 504 |
+
for i in range(cur):
|
| 505 |
+
Image.fromarray(imgs_tc[i]).save(
|
| 506 |
+
os.path.join(tc_save_dir, f"{written + i:06d}.png")
|
| 507 |
+
)
|
| 508 |
+
if traj is not None:
|
| 509 |
+
traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
|
| 510 |
+
save_traj_dir = traj_a if subdir == out_a else traj_b
|
| 511 |
+
np.save(os.path.join(save_traj_dir, f"{written:06d}_traj.npy"), traj_np)
|
| 512 |
+
imgs_a = batch_imgs.get(out_a)
|
| 513 |
+
imgs_b = batch_imgs.get(out_b)
|
| 514 |
+
if pair_dir is not None and imgs_a is not None and imgs_b is not None:
|
| 515 |
+
for i in range(cur):
|
| 516 |
+
pair_img = np.concatenate([imgs_a[i], imgs_b[i]], axis=1)
|
| 517 |
+
Image.fromarray(pair_img).save(
|
| 518 |
+
os.path.join(pair_dir, f"{written + i:06d}.png")
|
| 519 |
+
)
|
| 520 |
+
else:
|
| 521 |
+
em_kw = dict(base_kw)
|
| 522 |
+
if tc_split[0] is not None:
|
| 523 |
+
em_kw["t_c"] = tc_split[0]
|
| 524 |
+
em_kw["num_steps_before_tc"] = tc_split[1]
|
| 525 |
+
em_kw["num_steps_after_tc"] = tc_split[2]
|
| 526 |
+
if cli.sampler == "em_image_noise_before_tc":
|
| 527 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 528 |
+
latents, z_tc, cls_tc, cls_t0, traj = sampler_fn(
|
| 529 |
+
model,
|
| 530 |
+
z,
|
| 531 |
+
y,
|
| 532 |
+
**em_kw,
|
| 533 |
+
return_mid_state=True,
|
| 534 |
+
t_mid=float(tc_split[0]),
|
| 535 |
+
return_cls_final=True,
|
| 536 |
+
return_trajectory=True,
|
| 537 |
+
)
|
| 538 |
+
else:
|
| 539 |
+
latents, z_tc, cls_tc, cls_t0 = sampler_fn(
|
| 540 |
+
model,
|
| 541 |
+
z,
|
| 542 |
+
y,
|
| 543 |
+
**em_kw,
|
| 544 |
+
return_mid_state=True,
|
| 545 |
+
t_mid=float(tc_split[0]),
|
| 546 |
+
return_cls_final=True,
|
| 547 |
+
)
|
| 548 |
+
traj = None
|
| 549 |
+
else:
|
| 550 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 551 |
+
latents, z_tc, cls_tc, traj = sampler_fn(
|
| 552 |
+
model,
|
| 553 |
+
z,
|
| 554 |
+
y,
|
| 555 |
+
**em_kw,
|
| 556 |
+
return_mid_state=True,
|
| 557 |
+
t_mid=float(tc_split[0]),
|
| 558 |
+
return_trajectory=True,
|
| 559 |
+
)
|
| 560 |
+
else:
|
| 561 |
+
latents, z_tc, cls_tc = sampler_fn(
|
| 562 |
+
model,
|
| 563 |
+
z,
|
| 564 |
+
y,
|
| 565 |
+
**em_kw,
|
| 566 |
+
return_mid_state=True,
|
| 567 |
+
t_mid=float(tc_split[0]),
|
| 568 |
+
)
|
| 569 |
+
traj = None
|
| 570 |
+
cls_t0 = None
|
| 571 |
+
latents = latents.to(torch.float32)
|
| 572 |
+
if z_tc is not None and at_tc_dir is not None:
|
| 573 |
+
imgs_tc = _decode_to_uint8_hwc(
|
| 574 |
+
z_tc.to(torch.float32), latents_bias, latents_scale, vae
|
| 575 |
+
)
|
| 576 |
+
for i in range(cur):
|
| 577 |
+
Image.fromarray(imgs_tc[i]).save(
|
| 578 |
+
os.path.join(at_tc_dir, f"{written + i:06d}.png")
|
| 579 |
+
)
|
| 580 |
+
if traj is not None and traj_dir is not None:
|
| 581 |
+
traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
|
| 582 |
+
np.save(os.path.join(traj_dir, f"{written:06d}_traj.npy"), traj_np)
|
| 583 |
+
else:
|
| 584 |
+
latents = sampler_fn(model, z, y, **em_kw).to(torch.float32)
|
| 585 |
+
imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
|
| 586 |
+
for i in range(cur):
|
| 587 |
+
Image.fromarray(imgs[i]).save(
|
| 588 |
+
os.path.join(cli.out_dir, f"{written + i:06d}.png")
|
| 589 |
+
)
|
| 590 |
+
written += cur
|
| 591 |
+
pbar.update(cur)
|
| 592 |
+
pbar.close()
|
| 593 |
+
if cli.dual_compare_after:
|
| 594 |
+
msg = (
|
| 595 |
+
f"Done. Saved {written} images per run under {out_a} and {out_b} "
|
| 596 |
+
f"(parent: {cli.out_dir})"
|
| 597 |
+
)
|
| 598 |
+
if pair_dir is not None:
|
| 599 |
+
msg += f"; paired comparisons under {pair_dir}"
|
| 600 |
+
if tc_split[0] is not None and at_tc_a is not None:
|
| 601 |
+
msg += f"; t≈t_c decoded under {at_tc_a} and {at_tc_b}"
|
| 602 |
+
print(msg)
|
| 603 |
+
else:
|
| 604 |
+
msg = f"Done. Saved {written} images under {cli.out_dir}"
|
| 605 |
+
if tc_split[0] is not None and at_tc_dir is not None:
|
| 606 |
+
msg += f"; t≈t_c decoded under {at_tc_dir}"
|
| 607 |
+
print(msg)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
if __name__ == "__main__":
|
| 611 |
+
main()
|
REG/sample_from_checkpoint_ddp.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
DDP 多卡采样脚本(单路径,不做 dual-compare,不保存 t_c 中间态图)。
|
| 4 |
+
|
| 5 |
+
用法(4 卡示例):
|
| 6 |
+
torchrun --nproc_per_node=4 sample_from_checkpoint_ddp.py \
|
| 7 |
+
--ckpt exps/jsflow-experiment/checkpoints/0290000.pt \
|
| 8 |
+
--out-dir ./my_samples_ddp \
|
| 9 |
+
--num-images 50000 \
|
| 10 |
+
--batch-size 16 \
|
| 11 |
+
--t-c 0.75 --steps-before-tc 100 --steps-after-tc 5 \
|
| 12 |
+
--sampler em_image_noise_before_tc
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import math
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import types
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.distributed as dist
|
| 26 |
+
from diffusers.models import AutoencoderKL
|
| 27 |
+
from PIL import Image
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
|
| 30 |
+
from models.sit import SiT_models
|
| 31 |
+
from samplers import (
|
| 32 |
+
euler_maruyama_image_noise_before_tc_sampler,
|
| 33 |
+
euler_maruyama_image_noise_sampler,
|
| 34 |
+
euler_maruyama_sampler,
|
| 35 |
+
euler_ode_sampler,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def create_npz_from_sample_folder(sample_dir: str, num: int):
    """
    Assemble ``sample_dir/000000.png .. {num-1:06d}.png`` into a single
    ``.npz`` file (key ``arr_0``) next to the folder.

    Args:
        sample_dir: Directory containing zero-padded, sequentially named PNGs.
        num: Number of images to read (indices 0 .. num-1).

    Returns:
        Path of the written ``{sample_dir}.npz`` file.

    Raises:
        FileNotFoundError: If any expected PNG is missing.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # Use a context manager so each PIL file handle is closed promptly;
        # the original kept all `num` handles open until GC (fd leak for large num).
        with Image.open(os.path.join(sample_dir, f"{i:06d}.png")) as sample_pil:
            samples.append(np.asarray(sample_pil).astype(np.uint8))
    samples = np.stack(samples)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def semantic_dim_from_enc_type(enc_type):
    """Return the feature dimension of the named visual encoder.

    Matches ViT size tags ("vit-g"/"vitg", "vit-l"/"vitl", "vit-s"/"vits")
    case-insensitively; anything else — including ``None`` — maps to the
    ViT-B dimension of 768.
    """
    if enc_type is None:
        return 768
    name = str(enc_type).lower()
    # Checked in order; first matching size tag wins.
    size_table = (
        (("vit-g", "vitg"), 1536),
        (("vit-l", "vitl"), 1024),
        (("vit-s", "vits"), 384),
    )
    for tags, dim in size_table:
        if any(tag in name for tag in tags):
            return dim
    return 768
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
|
| 69 |
+
a = ckpt.get("args")
|
| 70 |
+
if a is None:
|
| 71 |
+
return None
|
| 72 |
+
if isinstance(a, argparse.Namespace):
|
| 73 |
+
return a
|
| 74 |
+
if isinstance(a, dict):
|
| 75 |
+
return argparse.Namespace(**a)
|
| 76 |
+
if isinstance(a, types.SimpleNamespace):
|
| 77 |
+
return argparse.Namespace(**vars(a))
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def load_vae(device: torch.device):
    """Load the SD VAE ("stabilityai/sd-vae-ft-mse") onto `device`, in eval mode.

    Resolution order (each step falls through to the next on any failure):
      1. Project cache dir from `preprocessing.dnnlib`, offline
         (`local_files_only=True`).
      2. Scan common local cache roots for an already-downloaded copy
         (a directory containing config.json whose path mentions
         "sd-vae-ft-mse").
      3. Plain `from_pretrained`, which may download from the HF hub.
    """
    try:
        from preprocessing import dnnlib

        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        # Quieten HF hub output and point HF_HOME at the project cache so
        # the offline lookup below searches the same location.
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        os.environ["HF_HOME"] = cache_dir
        try:
            # Attempt 1: offline load from the project cache.
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            return vae
        except Exception:
            pass
        # Attempt 2: walk likely cache roots for an existing snapshot dir.
        candidate_dir = None
        for root_dir in [
            cache_dir,
            os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
        ]:
            if not os.path.isdir(root_dir):
                continue
            for root, _, files in os.walk(root_dir):
                # Normalise Windows separators before the substring match.
                if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                    candidate_dir = root
                    break
            if candidate_dir is not None:
                break
        if candidate_dir is not None:
            vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
            vae.eval()
            return vae
    except Exception:
        # NOTE(review): broad catch is deliberate best-effort here — any
        # failure (missing preprocessing package, bad cache) falls through
        # to the online path below.
        pass
    # Attempt 3: last resort — may hit the network.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    vae.eval()
    return vae
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
    """Instantiate the SiT model described by training args `ta` on `device`.

    Returns a ``(model, cls_dim)`` pair where ``cls_dim`` is the semantic
    feature dimension implied by the encoder type.

    Raises:
        ValueError: If ``ta.model`` is not a registered SiT variant.
    """
    resolution = int(getattr(ta, "resolution", 256))
    enc_type = getattr(ta, "enc_type", "dinov2-vit-b")
    z_dims = [semantic_dim_from_enc_type(enc_type)]
    cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
    if ta.model not in SiT_models:
        raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
    # Latent grid is 1/8 of pixel resolution (SD VAE downsampling factor).
    model = SiT_models[ta.model](
        input_size=resolution // 8,
        num_classes=int(getattr(ta, "num_classes", 1000)),
        use_cfg=(cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=int(getattr(ta, "encoder_depth", 8)),
        fused_attn=getattr(ta, "fused_attn", True),
        qk_norm=getattr(ta, "qk_norm", False),
    ).to(device)
    return model, z_dims[0]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def resolve_tc_schedule(cli, ta):
    """Resolve the segmented time schedule ``(t_c, steps_before, steps_after)``.

    Returns ``(None, None, None)`` when neither segment step count was given
    (uniform schedule). Exits with an error message on stderr when the
    request is incomplete: only one step count given, or no ``t_c`` available
    from either the CLI or the checkpoint's training args.
    """
    before, after, t_c = cli.steps_before_tc, cli.steps_after_tc, cli.t_c
    if before is None and after is None:
        return None, None, None
    if before is None or after is None:
        print("使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。", file=sys.stderr)
        sys.exit(1)
    # CLI value wins; otherwise fall back to the t_c recorded at training time.
    if t_c is None and ta is not None:
        t_c = getattr(ta, "t_c", None)
    if t_c is None:
        print("分段采样需要 --t-c,或检查点 args 中含 t_c。", file=sys.stderr)
        sys.exit(1)
    return float(t_c), int(before), int(after)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def parse_cli():
    """Build and parse the command-line interface for DDP sampling."""
    parser = argparse.ArgumentParser(description="REG DDP 检查点采样(单路径,无 at_tc 图)")
    # --- I/O and run setup ---
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--out-dir", type=str, required=True)
    parser.add_argument("--num-images", type=int, required=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--weights", type=str, choices=("ema", "model"), default="ema")
    parser.add_argument("--device", type=str, default="cuda")
    # --- time schedule (uniform, or segmented around t_c) ---
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--t-c", type=float, default=None)
    parser.add_argument("--steps-before-tc", type=int, default=None)
    parser.add_argument("--steps-after-tc", type=int, default=None)
    # --- guidance ---
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--cls-cfg-scale", type=float, default=0.0)
    parser.add_argument("--guidance-low", type=float, default=0.0)
    parser.add_argument("--guidance-high", type=float, default=1.0)
    parser.add_argument("--path-type", type=str, default=None, choices=["linear", "cosine"])
    # --- model overrides (None means "take the value from the checkpoint args") ---
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=None, choices=[256, 512])
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=None)
    parser.add_argument("--enc-type", type=str, default=None)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
    parser.add_argument("--cfg-prob", type=float, default=None)
    # --- sampler selection and optional artefacts ---
    parser.add_argument(
        "--sampler",
        type=str,
        default="em_image_noise_before_tc",
        choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
    )
    parser.add_argument(
        "--save-fixed-trajectory",
        action="store_true",
        help="保存本 rank 轨迹(npy)到 out-dir/trajectory_rank{rank}",
    )
    parser.add_argument(
        "--save-npz",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="采样结束后由 rank0 汇总 PNG 并保存 out-dir.npz",
    )
    return parser.parse_args()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
|
| 213 |
+
imgs = vae.decode((latents - latents_bias) / latents_scale).sample
|
| 214 |
+
imgs = (imgs + 1) / 2.0
|
| 215 |
+
imgs = torch.clamp(imgs, 0, 1)
|
| 216 |
+
return (
|
| 217 |
+
(imgs * 255.0)
|
| 218 |
+
.round()
|
| 219 |
+
.to(torch.uint8)
|
| 220 |
+
.permute(0, 2, 3, 1)
|
| 221 |
+
.cpu()
|
| 222 |
+
.numpy()
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def init_ddp():
    """Initialise torch.distributed from torchrun environment variables.

    Returns ``(use_ddp, rank, world_size, local_rank)``. When RANK or
    WORLD_SIZE is absent, no process group is created and single-process
    defaults ``(False, 0, 1, 0)`` are returned.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return False, 0, 1, 0
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl", init_method="env://")
    # Pin this process to its local GPU before any CUDA allocation.
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def main():
    """DDP sampling entry point: load checkpoint, sample, decode, save PNGs.

    Each rank draws its own seeded noise, samples `iterations` batches, and
    writes images into a shared out-dir using an interleaved global index so
    ranks never collide. Rank 0 optionally bundles everything into an .npz.
    """
    cli = parse_cli()
    use_ddp, rank, world_size, local_rank = init_ddp()

    if torch.cuda.is_available():
        # Under DDP each process binds to its local GPU; otherwise honour --device.
        device = torch.device(f"cuda:{local_rank}" if use_ddp else cli.device)
        torch.backends.cuda.matmul.allow_tf32 = True
    else:
        device = torch.device("cpu")

    # weights_only is unavailable on older torch; retry without it.
    try:
        ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(cli.ckpt, map_location="cpu")
    ta = load_train_args_from_ckpt(ckpt)
    if ta is None:
        # No training args recorded: reconstruct the minimum from the CLI.
        if cli.model is None or cli.resolution is None or cli.enc_type is None:
            print("检查点中无 args,请至少指定:--model --resolution --enc-type", file=sys.stderr)
            sys.exit(1)
        ta = argparse.Namespace(
            model=cli.model,
            resolution=cli.resolution,
            num_classes=cli.num_classes if cli.num_classes is not None else 1000,
            encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
            enc_type=cli.enc_type,
            fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
            qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
            cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
        )
    else:
        # Explicit CLI values override whatever the checkpoint recorded.
        if cli.model is not None:
            ta.model = cli.model
        if cli.resolution is not None:
            ta.resolution = cli.resolution
        if cli.num_classes is not None:
            ta.num_classes = cli.num_classes
        if cli.encoder_depth is not None:
            ta.encoder_depth = cli.encoder_depth
        if cli.enc_type is not None:
            ta.enc_type = cli.enc_type
        if cli.fused_attn is not None:
            ta.fused_attn = cli.fused_attn
        if cli.qk_norm is not None:
            ta.qk_norm = cli.qk_norm
        if cli.cfg_prob is not None:
            ta.cfg_prob = cli.cfg_prob

    path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
    tc_split = resolve_tc_schedule(cli, ta)

    if rank == 0:
        if tc_split[0] is not None:
            print(
                f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]}"
            )
        else:
            print(f"时间网格:均匀 num_steps={cli.num_steps}")

    # Select the sampler implementation for this run.
    if cli.sampler == "ode":
        sampler_fn = euler_ode_sampler
    elif cli.sampler == "em":
        sampler_fn = euler_maruyama_sampler
    elif cli.sampler == "em_image_noise_before_tc":
        sampler_fn = euler_maruyama_image_noise_before_tc_sampler
    else:
        sampler_fn = euler_maruyama_image_noise_sampler

    model, cls_dim = build_model_from_train_args(ta, device)
    wkey = cli.weights
    if wkey not in ckpt:
        raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
    state = ckpt[wkey]
    if cli.legacy:
        from utils import load_legacy_checkpoints

        # Remap old-format state dict keys before loading.
        state = load_legacy_checkpoints(
            state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
        )
    model.load_state_dict(state, strict=True)
    model.eval()

    vae = load_vae(device)
    # SD-VAE latent scaling constant (0.18215); bias is zero.
    latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
    latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
    sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))

    os.makedirs(cli.out_dir, exist_ok=True)
    traj_dir = None
    if cli.save_fixed_trajectory and cli.sampler != "em":
        # Per-rank trajectory folder to avoid cross-rank filename clashes.
        traj_dir = os.path.join(cli.out_dir, f"trajectory_rank{rank}")
        os.makedirs(traj_dir, exist_ok=True)

    latent_size = int(getattr(ta, "resolution", 256)) // 8
    n_total = int(cli.num_images)
    b = max(1, int(cli.batch_size))
    global_batch_size = b * world_size
    # Round the target up to a whole number of global batches; the surplus
    # images (gidx >= n_total) are sampled but not saved.
    total_samples = int(math.ceil(n_total / global_batch_size) * global_batch_size)
    samples_needed_this_gpu = int(total_samples // world_size)
    if samples_needed_this_gpu % b != 0:
        raise ValueError("samples_needed_this_gpu must be divisible by per-rank batch size")
    iterations = int(samples_needed_this_gpu // b)

    # Distinct seed per rank so ranks draw independent noise.
    seed_rank = int(cli.seed) + int(rank)
    torch.manual_seed(seed_rank)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed_rank)

    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    pbar = range(iterations)
    pbar = tqdm(pbar, desc="sampling") if rank == 0 else pbar
    total = 0
    written_local = 0
    for _ in pbar:
        cur = b
        z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
        cls_z = torch.randn(cur, cls_dim, device=device)

        with torch.no_grad():
            em_kw = dict(
                num_steps=cli.num_steps,
                cfg_scale=cli.cfg_scale,
                guidance_low=cli.guidance_low,
                guidance_high=cli.guidance_high,
                path_type=path_type,
                cls_latents=cls_z,
                args=sampler_args,
            )
            if tc_split[0] is not None:
                em_kw["t_c"] = tc_split[0]
                em_kw["num_steps_before_tc"] = tc_split[1]
                em_kw["num_steps_after_tc"] = tc_split[2]

            if cli.save_fixed_trajectory and cli.sampler != "em":
                # NOTE(review): both branches below are identical — the
                # sampler-specific split looks vestigial and could be collapsed.
                if cli.sampler == "em_image_noise_before_tc":
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
                else:
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
            else:
                latents = sampler_fn(model, z, y, **em_kw)
                traj = None

        latents = latents.to(torch.float32)
        imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
        for i, img in enumerate(imgs):
            # Interleaved global index: rank r's i-th image of this round
            # lands at total + i*world_size + r, so ranks never collide and
            # the union covers [total, total + global_batch_size).
            gidx = i * world_size + rank + total
            if gidx < n_total:
                Image.fromarray(img).save(os.path.join(cli.out_dir, f"{gidx:06d}.png"))
                written_local += 1

        if traj is not None and traj_dir is not None:
            # One trajectory file per batch, named after its first image index.
            traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
            first_idx = rank + total
            if first_idx < n_total:
                np.save(os.path.join(traj_dir, f"{first_idx:06d}_traj.npy"), traj_np)

        total += global_batch_size
        if use_ddp:
            # Keep ranks in lockstep each round.
            dist.barrier()
    if rank == 0 and hasattr(pbar, "close"):
        pbar.close()

    if use_ddp:
        # All PNGs must be on disk before rank 0 aggregates them.
        dist.barrier()
    if rank == 0:
        if cli.save_npz:
            create_npz_from_sample_folder(cli.out_dir, n_total)
        print(f"Done. Saved {n_total} images under {cli.out_dir} (world_size={world_size}).")
    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
|
| 416 |
+
|
REG/samplers.py
ADDED
|
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def expand_t_like_x(t, x_cur):
|
| 6 |
+
"""Function to reshape time t to broadcastable dimension of x
|
| 7 |
+
Args:
|
| 8 |
+
t: [batch_dim,], time vector
|
| 9 |
+
x: [batch_dim,...], data point
|
| 10 |
+
"""
|
| 11 |
+
dims = [1] * (len(x_cur.size()) - 1)
|
| 12 |
+
t = t.view(t.size(0), *dims)
|
| 13 |
+
return t
|
| 14 |
+
|
| 15 |
+
def get_score_from_velocity(vt, xt, t, path_type="linear"):
|
| 16 |
+
"""Wrapper function: transfrom velocity prediction model to score
|
| 17 |
+
Args:
|
| 18 |
+
velocity: [batch_dim, ...] shaped tensor; velocity model output
|
| 19 |
+
x: [batch_dim, ...] shaped tensor; x_t data point
|
| 20 |
+
t: [batch_dim,] time tensor
|
| 21 |
+
"""
|
| 22 |
+
t = expand_t_like_x(t, xt)
|
| 23 |
+
if path_type == "linear":
|
| 24 |
+
alpha_t, d_alpha_t = 1 - t, torch.ones_like(xt, device=xt.device) * -1
|
| 25 |
+
sigma_t, d_sigma_t = t, torch.ones_like(xt, device=xt.device)
|
| 26 |
+
elif path_type == "cosine":
|
| 27 |
+
alpha_t = torch.cos(t * np.pi / 2)
|
| 28 |
+
sigma_t = torch.sin(t * np.pi / 2)
|
| 29 |
+
d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
|
| 30 |
+
d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
|
| 31 |
+
else:
|
| 32 |
+
raise NotImplementedError
|
| 33 |
+
|
| 34 |
+
mean = xt
|
| 35 |
+
reverse_alpha_ratio = alpha_t / d_alpha_t
|
| 36 |
+
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
|
| 37 |
+
score = (reverse_alpha_ratio * vt - mean) / var
|
| 38 |
+
|
| 39 |
+
return score
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def compute_diffusion(t_cur):
    """Diffusion coefficient g(t)^2 used by the reverse SDE; linear schedule 2*t."""
    return t_cur * 2
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_sampling_time_steps(
    num_steps=50,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    t_floor=0.04,
):
    """Build the t=1 -> t=0 sampling grid, keeping t_floor before the final jump to 0.

    - Default: uniform linspace(1, t_floor, num_steps), then append 0.
    - Segmented (t_c and both step counts given): nb+1 points from 1 to t_c,
      then na points from t_c (exclusive) down to t_floor, then append 0.

    Returns:
        1-D float64 tensor of time points, ending in 0.0.

    Raises:
        ValueError: on invalid step counts or an out-of-range t_c.
    """
    t_floor = float(t_floor)
    segmented = not (t_c is None or num_steps_before_tc is None or num_steps_after_tc is None)
    zero = torch.tensor([0.0], dtype=torch.float64)

    if not segmented:
        ns = int(num_steps)
        if ns < 1:
            raise ValueError("num_steps must be >= 1")
        grid = torch.linspace(1.0, t_floor, ns, dtype=torch.float64)
        return torch.cat([grid, zero])

    t_c = float(t_c)
    nb, na = int(num_steps_before_tc), int(num_steps_after_tc)
    if nb < 1 or na < 1:
        raise ValueError("num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c")
    if not (0.0 < t_c < 1.0):
        raise ValueError("t_c must be in (0, 1)")
    if t_c <= t_floor:
        raise ValueError(f"t_c ({t_c}) must be > t_floor ({t_floor})")

    head = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64)
    # Drop the duplicated t_c at the start of the second segment.
    tail = torch.linspace(t_c, t_floor, na + 1, dtype=torch.float64)[1:]
    return torch.cat([head, tail, zero])
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc):
|
| 85 |
+
"""仅在 1→t_c→0 分段网格下启用:t∈[0,t_c] 段固定使用到达 t_c 时的 cls。"""
|
| 86 |
+
return (
|
| 87 |
+
t_c is not None
|
| 88 |
+
and num_steps_before_tc is not None
|
| 89 |
+
and num_steps_after_tc is not None
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _cls_effective_and_freeze(
|
| 94 |
+
cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
|
| 95 |
+
):
|
| 96 |
+
"""
|
| 97 |
+
时间从 1 减到 0:当 t_cur <= t_c 时冻结 cls(取首次进入该段时的 cls_x_cur)。
|
| 98 |
+
返回 (用于前向的 cls, 更新后的 cls_frozen)。
|
| 99 |
+
"""
|
| 100 |
+
if not freeze_after_tc or t_c_v is None:
|
| 101 |
+
return cls_x_cur, cls_frozen
|
| 102 |
+
if float(t_cur) <= float(t_c_v) + 1e-9:
|
| 103 |
+
if cls_frozen is None:
|
| 104 |
+
cls_frozen = cls_x_cur.clone()
|
| 105 |
+
return cls_frozen, cls_frozen
|
| 106 |
+
return cls_x_cur, cls_frozen
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _build_euler_sampler_time_steps(
|
| 110 |
+
num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
|
| 111 |
+
):
|
| 112 |
+
"""
|
| 113 |
+
euler_sampler / REG ODE 用时间网格:默认 linspace(1,0);分段时为 1→t_c→0 直连,无 t_floor。
|
| 114 |
+
"""
|
| 115 |
+
if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None:
|
| 116 |
+
ns = int(num_steps)
|
| 117 |
+
if ns < 1:
|
| 118 |
+
raise ValueError("num_steps must be >= 1")
|
| 119 |
+
return torch.linspace(1.0, 0.0, ns + 1, dtype=torch.float64, device=device)
|
| 120 |
+
t_c = float(t_c)
|
| 121 |
+
nb = int(num_steps_before_tc)
|
| 122 |
+
na = int(num_steps_after_tc)
|
| 123 |
+
if nb < 1 or na < 1:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
"num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c"
|
| 126 |
+
)
|
| 127 |
+
if not (0.0 < t_c < 1.0):
|
| 128 |
+
raise ValueError("t_c must be in (0, 1)")
|
| 129 |
+
p1 = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64, device=device)
|
| 130 |
+
p2 = torch.linspace(t_c, 0.0, na + 1, dtype=torch.float64, device=device)
|
| 131 |
+
return torch.cat([p1, p2[1:]])
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def euler_maruyama_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    deterministic=False,
    return_trajectory=False,
):
    """Euler–Maruyama SDE sampler for REG (image latent + cls token jointly).

    The drift term and the score/velocity transform match euler_ode_sampler
    (euler_sampler); deterministic=True disables the diffusion noise term.
    The ODE path uses euler_sampler's linspace(1→0) / t_c segmented grid
    (no t_floor); this function instead uses build_sampling_time_steps
    (which keeps t_floor before the final jump to 0), aligned with EM/SDE.

    Args:
        model: callable as model(x, t, y=..., cls_token=...) returning
            (velocity, _, cls_velocity).
        latents: [B, ...] initial image latent noise at t=1.
        y: [B] class labels; 1000 is used as the null class for CFG.
        num_steps: step count when no t_c segmentation is used.
        heun: ignored; kept for signature compatibility.
        cfg_scale: classifier-free guidance scale (>1 enables CFG).
        guidance_low / guidance_high: time window in which CFG applies.
        path_type: interpolant schedule for get_score_from_velocity.
        cls_latents: [B, ...] initial cls-token noise; must not be None.
        args: optional namespace; args.cls_cfg_scale drives cls-channel CFG.
        return_mid_state: also return the state closest to t_mid.
        t_mid: time at which the mid state is captured.
        t_c / num_steps_before_tc / num_steps_after_tc: optional segmented
            grid 1→t_c→0; when all three are set, cls is frozen for t<=t_c.
        deterministic: drop the stochastic increments (drift only).
        return_trajectory: also return the list of intermediate x states.

    Returns:
        mean_x, optionally followed by (z_mid, cls_mid) and/or traj per flags.
    """
    # setup conditioning
    if cfg_scale > 1.0:
        # Null-class labels (1000) for the unconditional half of the CFG batch.
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64 for stability; cast back to model dtype per forward call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # All steps except the last one; the final step (below) has no noise term.
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur  # negative: time runs from 1 down toward 0
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Once t <= t_c (segmented grid), the cls fed to the model is frozen.
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                # Capture the pre-step state only if t_cur is the closer endpoint;
                # otherwise the post-step state is captured after the update below.
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate the batch for the cond/uncond CFG forward pass.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            if deterministic:
                deps = torch.zeros_like(x_cur)
                cls_deps = torch.zeros_like(cls_model_input[: x_cur.size(0)])
            else:
                # Brownian increments scaled by sqrt(|dt|).
                eps_i = torch.randn_like(x_cur).to(device)
                cls_eps_i = torch.randn_like(cls_model_input[: x_cur.size(0)]).to(device)
                deps = eps_i * torch.sqrt(torch.abs(dt))
                cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))

            # compute drift
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Reverse-SDE drift: d = v - 0.5 * g(t)^2 * score.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # Without an explicit cls CFG scale, keep the conditional branch.
                    cls_d_cur = cls_d_cur_cond
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                # In the t <= t_c segment the cls state stays pinned to the frozen copy.
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # last step: deterministic Euler step (no stochastic term) down to t=0
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        # compute drift
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
        cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)

        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur  # d_cur [b, 4, 32 ,32]

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    if return_trajectory:
        traj.append(mean_x.clone())

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def euler_maruyama_image_noise_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """EM sampling variant: only the image latent receives stochastic diffusion
    noise; the cls/token channel is integrated deterministically (drift only).

    Bug fix: the original body gated the drift on an undefined local
    ``add_img_noise`` (left over from euler_maruyama_image_noise_before_tc_sampler),
    which raised NameError on the first step. In this variant image noise is
    always on, so the score-corrected reverse-SDE drift is used unconditionally.

    Args and returns mirror euler_maruyama_sampler (minus ``deterministic``).
    """
    if cfg_scale > 1.0:
        # Null-class labels (1000) for the unconditional half of the CFG batch.
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur  # negative: time runs from 1 toward 0
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Once t <= t_c (segmented grid), the cls fed to the model is frozen.
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate the batch for the cond/uncond CFG forward pass.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Image-only Brownian increment, scaled by sqrt(|dt|).
            eps_i = torch.randn_like(x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Image noise is always on here, so always use the score-corrected
            # reverse-SDE drift d = v - 0.5 * g(t)^2 * score (this replaces the
            # broken `if add_img_noise:` gate).
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond

            # Image latent gets the stochastic term; cls follows drift only.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Final deterministic Euler step to t=0: no stochastic term and, as in
        # the ODE sampler, plain velocity drift d = v.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    # Consistency with euler_maruyama_sampler: record the final state too.
    if return_trajectory:
        traj.append(mean_x.clone())

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def euler_maruyama_image_noise_before_tc_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_cls_final=False,
    return_trajectory=False,
):
    """EM sampling variant:
    - the image latent receives stochastic diffusion noise while t > t_c;
    - the image latent gets no stochastic term (drift only) once t <= t_c;
    - the cls/token channel never receives a stochastic term.

    Args mirror euler_maruyama_sampler; return_cls_final additionally returns
    the final cls state. Returns mean_x plus optional (z_mid, cls_mid),
    cls_mean_x and traj depending on the return_* flags.
    """
    if cfg_scale > 1.0:
        # Null-class labels (1000) for the unconditional half of the CFG batch.
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_freeze = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    # Noise gating uses t_c alone — independent of the cls-freeze behaviour,
    # which additionally requires the segmented grid (t_c_freeze).
    t_c_v = None if t_c is None else float(t_c)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur  # negative: time runs from 1 toward 0
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Once t <= t_c (segmented grid), the cls fed to the model is frozen.
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
            )

            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate the batch for the cond/uncond CFG forward pass.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Turn image stochasticity off after crossing into t <= t_c;
            # keep image noise for the t > t_c segment.
            add_img_noise = True
            if t_c_v is not None and float(t_next) <= t_c_v:
                add_img_noise = False

            eps_i = torch.randn_like(x_cur).to(device) if add_img_noise else torch.zeros_like(x_cur)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            if add_img_noise:
                # Stochastic segment: score-corrected reverse-SDE drift.
                s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
                d_cur = v_cur - 0.5 * diffusion * s_cur

                cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
                cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
            else:
                # Deterministic segment t <= t_c: explicit Euler with the plain
                # velocity drift (no score correction), matching the ODE logic.
                d_cur = v_cur
                cls_d_cur = cls_v_cur

            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # Without an explicit cls CFG scale, keep the conditional branch.
                    cls_d_cur = cls_d_cur_cond

            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
                # In the t <= t_c segment the cls state stays pinned to the frozen copy.
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Final deterministic Euler step to t=0.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        # Final step has no stochastic term; use d = v, consistent with the ODE.
        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    if return_trajectory and return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x, traj
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory and return_cls_final:
        return mean_x, cls_mean_x, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    if return_cls_final:
        return mean_x, cls_mean_x
    return mean_x
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def euler_ode_sampler(
    model,
    latents,
    y,
    num_steps=20,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """ODE entry point for REG.

    Decoupled from the SDE samplers: delegates directly to euler_sampler,
    which uses a linspace(1 -> 0) grid (or the 1 -> t_c -> 0 segmented grid)
    with no t_floor.
    """
    forwarded = dict(
        num_steps=num_steps,
        heun=False,
        cfg_scale=cfg_scale,
        guidance_low=guidance_low,
        guidance_high=guidance_high,
        path_type=path_type,
        cls_latents=cls_latents,
        args=args,
        return_mid_state=return_mid_state,
        t_mid=t_mid,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
        return_trajectory=return_trajectory,
    )
    return euler_sampler(model, latents, y, **forwarded)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def euler_sampler(
|
| 719 |
+
model,
|
| 720 |
+
latents,
|
| 721 |
+
y,
|
| 722 |
+
num_steps=20,
|
| 723 |
+
heun=False,
|
| 724 |
+
cfg_scale=1.0,
|
| 725 |
+
guidance_low=0.0,
|
| 726 |
+
guidance_high=1.0,
|
| 727 |
+
path_type="linear",
|
| 728 |
+
cls_latents=None,
|
| 729 |
+
args=None,
|
| 730 |
+
return_mid_state=False,
|
| 731 |
+
t_mid=0.5,
|
| 732 |
+
t_c=None,
|
| 733 |
+
num_steps_before_tc=None,
|
| 734 |
+
num_steps_after_tc=None,
|
| 735 |
+
return_trajectory=False,
|
| 736 |
+
):
|
| 737 |
+
"""
|
| 738 |
+
轻量确定性漂移采样(与 glflow 同名同参的前缀兼容:model, latents, y, num_steps, heun, cfg, guidance, path_type, cls_latents, args)。
|
| 739 |
+
|
| 740 |
+
- 默认:linspace(1, 0, num_steps+1),无 t_floor(与原先独立 ODE 一致)。
|
| 741 |
+
- 可选:同时传入 t_c、num_steps_before_tc、num_steps_after_tc 时,网格为 1→t_c→0;并与 EM 一致在 t≤t_c 段冻结 cls。
|
| 742 |
+
- 可选:return_mid_state / return_trajectory 供 train.py 与 sample_from_checkpoint 使用。
|
| 743 |
+
|
| 744 |
+
REG 的 SiT 需要 cls_token;cls_latents 不可为 None。heun 占位未使用。
|
| 745 |
+
"""
|
| 746 |
+
if cls_latents is None:
|
| 747 |
+
raise ValueError(
|
| 748 |
+
"euler_sampler: 本仓库 REG SiT 需要 cls_token,请传入 cls_latents(例如高斯噪声或训练中的 cls 初值)。"
|
| 749 |
+
)
|
| 750 |
+
if cfg_scale > 1.0:
|
| 751 |
+
y_null = torch.full((y.size(0),), 1000, device=y.device, dtype=y.dtype)
|
| 752 |
+
else:
|
| 753 |
+
y_null = None
|
| 754 |
+
_dtype = latents.dtype
|
| 755 |
+
cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
|
| 756 |
+
device = latents.device
|
| 757 |
+
|
| 758 |
+
t_steps = _build_euler_sampler_time_steps(
|
| 759 |
+
num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
|
| 760 |
+
)
|
| 761 |
+
freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
|
| 762 |
+
t_c_v = float(t_c) if freeze_after_tc else None
|
| 763 |
+
|
| 764 |
+
x_next = latents.to(torch.float64)
|
| 765 |
+
cls_x_next = cls_latents.to(torch.float64)
|
| 766 |
+
z_mid = cls_mid = None
|
| 767 |
+
t_mid = float(t_mid)
|
| 768 |
+
cls_frozen = None
|
| 769 |
+
traj = [x_next.clone()] if return_trajectory else None
|
| 770 |
+
|
| 771 |
+
with torch.no_grad():
|
| 772 |
+
for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
|
| 773 |
+
dt = t_next - t_cur
|
| 774 |
+
x_cur = x_next
|
| 775 |
+
cls_x_cur = cls_x_next
|
| 776 |
+
|
| 777 |
+
cls_model_input, cls_frozen = _cls_effective_and_freeze(
|
| 778 |
+
cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
tc, tn = float(t_cur), float(t_next)
|
| 782 |
+
if return_mid_state and z_mid is None and tn <= t_mid <= tc:
|
| 783 |
+
if abs(tc - t_mid) < abs(tn - t_mid):
|
| 784 |
+
z_mid = x_cur.clone()
|
| 785 |
+
cls_mid = cls_model_input.clone()
|
| 786 |
+
|
| 787 |
+
if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
|
| 788 |
+
model_input = torch.cat([x_cur] * 2, dim=0)
|
| 789 |
+
cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
|
| 790 |
+
y_cur = torch.cat([y, y_null], dim=0)
|
| 791 |
+
else:
|
| 792 |
+
model_input = x_cur
|
| 793 |
+
y_cur = y
|
| 794 |
+
|
| 795 |
+
time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur
|
| 796 |
+
|
| 797 |
+
v_cur, _, cls_v_cur = model(
|
| 798 |
+
model_input.to(dtype=_dtype),
|
| 799 |
+
time_input.to(dtype=_dtype),
|
| 800 |
+
y_cur,
|
| 801 |
+
cls_token=cls_model_input.to(dtype=_dtype),
|
| 802 |
+
)
|
| 803 |
+
v_cur = v_cur.to(torch.float64)
|
| 804 |
+
cls_v_cur = cls_v_cur.to(torch.float64)
|
| 805 |
+
|
| 806 |
+
# ODE: follow velocity parameterization directly (d/dt x_t = v_t).
|
| 807 |
+
# This aligns with velocity training target and avoids extra v->score->drift conversion.
|
| 808 |
+
d_cur = v_cur
|
| 809 |
+
cls_d_cur = cls_v_cur
|
| 810 |
+
|
| 811 |
+
if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
|
| 812 |
+
d_cur_cond, d_cur_uncond = d_cur.chunk(2)
|
| 813 |
+
d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
|
| 814 |
+
|
| 815 |
+
cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
|
| 816 |
+
if cls_cfg > 0:
|
| 817 |
+
cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
|
| 818 |
+
else:
|
| 819 |
+
cls_d_cur = cls_d_cur_cond
|
| 820 |
+
|
| 821 |
+
x_next = x_cur + dt * d_cur
|
| 822 |
+
if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
|
| 823 |
+
cls_x_next = cls_frozen
|
| 824 |
+
else:
|
| 825 |
+
cls_x_next = cls_x_cur + dt * cls_d_cur
|
| 826 |
+
|
| 827 |
+
if return_trajectory:
|
| 828 |
+
traj.append(x_next.clone())
|
| 829 |
+
|
| 830 |
+
if return_mid_state and z_mid is None and tn <= t_mid <= tc:
|
| 831 |
+
z_mid = x_next.clone()
|
| 832 |
+
cls_mid = cls_x_next.clone()
|
| 833 |
+
|
| 834 |
+
if return_trajectory and return_mid_state:
|
| 835 |
+
return x_next, z_mid, cls_mid, traj
|
| 836 |
+
if return_trajectory:
|
| 837 |
+
return x_next, traj
|
| 838 |
+
if return_mid_state:
|
| 839 |
+
return x_next, z_mid, cls_mid
|
| 840 |
+
return x_next
|
REG/samples.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# 双次对比步数请用 --dual-compare-after(见 sample_from_checkpoint.py),输出在 out-dir 子目录。
|
| 3 |
+
|
| 4 |
+
CUDA_VISIBLE_DEVICES=1 python sample_from_checkpoint.py \
|
| 5 |
+
--ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75-0.01-one-step/checkpoints/1970000.pt \
|
| 6 |
+
--out-dir ./my_samples_test \
|
| 7 |
+
--num-images 100 \
|
| 8 |
+
--batch-size 4 \
|
| 9 |
+
--seed 0 \
|
| 10 |
+
--t-c 0.75 \
|
| 11 |
+
--steps-before-tc 100\
|
| 12 |
+
--steps-after-tc 2 \
|
| 13 |
+
--sampler ode \
|
| 14 |
+
--cfg-scale 1.0 \
|
| 15 |
+
--dual-compare-after \
|
REG/samples_0.25_new.log
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
W0408 11:05:05.680000 100112 site-packages/torch/distributed/run.py:793]
|
| 2 |
+
W0408 11:05:05.680000 100112 site-packages/torch/distributed/run.py:793] *****************************************
|
| 3 |
+
W0408 11:05:05.680000 100112 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 4 |
+
W0408 11:05:05.680000 100112 site-packages/torch/distributed/run.py:793] *****************************************
|
| 5 |
+
时间网格:t_c=0.25, 步数 (1→t_c)=100, (t_c→0)=2
|
| 6 |
+
Total number of images that will be sampled: 5120
|
| 7 |
+
|
| 8 |
+
[rank0]:[W408 11:06:31.621528015 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 9 |
+
[rank2]:[W408 11:06:31.627182760 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2] using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 10 |
+
[rank1]:[W408 11:06:32.966218444 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 11 |
+
|
| 12 |
+
W0408 11:11:04.746000 100112 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 100158 closing signal SIGTERM
|
| 13 |
+
W0408 11:11:04.748000 100112 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 100159 closing signal SIGTERM
|
| 14 |
+
W0408 11:11:04.749000 100112 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 100160 closing signal SIGTERM
|
| 15 |
+
W0408 11:11:04.749000 100112 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 100161 closing signal SIGTERM
|
| 16 |
+
Traceback (most recent call last):
|
| 17 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/torchrun", line 33, in <module>
|
| 18 |
+
sys.exit(load_entry_point('torch==2.5.1', 'console_scripts', 'torchrun')())
|
| 19 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 20 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
|
| 21 |
+
return f(*args, **kwargs)
|
| 22 |
+
^^^^^^^^^^^^^^^^^^
|
| 23 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 919, in main
|
| 24 |
+
run(args)
|
| 25 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 910, in run
|
| 26 |
+
elastic_launch(
|
| 27 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
|
| 28 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 29 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 30 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 260, in launch_agent
|
| 31 |
+
result = agent.run()
|
| 32 |
+
^^^^^^^^^^^
|
| 33 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/metrics/api.py", line 137, in wrapper
|
| 34 |
+
result = f(*args, **kwargs)
|
| 35 |
+
^^^^^^^^^^^^^^^^^^
|
| 36 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 696, in run
|
| 37 |
+
result = self._invoke_run(role)
|
| 38 |
+
^^^^^^^^^^^^^^^^^^^^^^
|
| 39 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/agent/server/api.py", line 855, in _invoke_run
|
| 40 |
+
time.sleep(monitor_interval)
|
| 41 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 84, in _terminate_process_handler
|
| 42 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 43 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 100112 got signal: 15
|
REG/samples_0.5.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
REG/samples_0.75.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
REG/samples_0.75_new.log
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793]
|
| 2 |
+
W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] *****************************************
|
| 3 |
+
W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 4 |
+
W0325 16:55:16.395000 538513 site-packages/torch/distributed/run.py:793] *****************************************
|
| 5 |
+
时间网格:t_c=0.75, 步数 (1→t_c)=100, (t_c→0)=50
|
| 6 |
+
Total number of images that will be sampled: 40192
|
| 7 |
+
|
| 8 |
+
[rank3]:[W325 16:57:00.799344818 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3] using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 9 |
+
[rank1]:[W325 16:57:00.847229448 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 10 |
+
[rank0]:[W325 16:57:01.326116049 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
|
| 11 |
+
|
| 12 |
+
W0325 18:03:18.212000 538513 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 538677 closing signal SIGTERM
|
| 13 |
+
W0325 18:03:18.212000 538513 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 538678 closing signal SIGTERM
|
| 14 |
+
E0325 18:03:18.554000 538513 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -9) local_rank: 1 (pid: 538676) of binary: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
|
| 15 |
+
Traceback (most recent call last):
|
| 16 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/torchrun", line 33, in <module>
|
| 17 |
+
sys.exit(load_entry_point('torch==2.5.1', 'console_scripts', 'torchrun')())
|
| 18 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 19 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
|
| 20 |
+
return f(*args, **kwargs)
|
| 21 |
+
^^^^^^^^^^^^^^^^^^
|
| 22 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 919, in main
|
| 23 |
+
run(args)
|
| 24 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/run.py", line 910, in run
|
| 25 |
+
elastic_launch(
|
| 26 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
|
| 27 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 28 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 29 |
+
File "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
|
| 30 |
+
raise ChildFailedError(
|
| 31 |
+
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
|
| 32 |
+
==========================================================
|
| 33 |
+
sample_from_checkpoint_ddp.py FAILED
|
| 34 |
+
----------------------------------------------------------
|
| 35 |
+
Failures:
|
| 36 |
+
<NO_OTHER_FAILURES>
|
| 37 |
+
----------------------------------------------------------
|
| 38 |
+
Root Cause (first observed failure):
|
| 39 |
+
[0]:
|
| 40 |
+
time : 2026-03-25_18:03:18
|
| 41 |
+
host : 24c964746905d416ce09d045f9a06f23-taskrole1-0
|
| 42 |
+
rank : 1 (local_rank: 1)
|
| 43 |
+
exitcode : -9 (pid: 538676)
|
| 44 |
+
error_file: <N/A>
|
| 45 |
+
traceback : Signal 9 (SIGKILL) received by PID 538676
|
| 46 |
+
==========================================================
|
REG/samples_ddp.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
# 4 卡 DDP 单路径采样(不做 dual-compare,不保存 at_tc 中间图)
|
| 4 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup nohup torchrun \
|
| 5 |
+
--nnodes=1 \
|
| 6 |
+
--nproc_per_node=4 \
|
| 7 |
+
--rdzv_endpoint=localhost:29110 \
|
| 8 |
+
sample_from_checkpoint_ddp.py \
|
| 9 |
+
--ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.25/checkpoints/0230000.pt \
|
| 10 |
+
--out-dir ./my_samples_25 \
|
| 11 |
+
--num-images 5000 \
|
| 12 |
+
--batch-size 64 \
|
| 13 |
+
--seed 0 \
|
| 14 |
+
--t-c 0.25 \
|
| 15 |
+
--steps-before-tc 100 \
|
| 16 |
+
--steps-after-tc 2 \
|
| 17 |
+
--sampler em_image_noise_before_tc \
|
| 18 |
+
--cfg-scale 1.0 \
|
| 19 |
+
> samples_0.25_new.log 2>&1 &
|
| 20 |
+
|
| 21 |
+
# nohup python sample_from_checkpoint_ddp.py \
|
| 22 |
+
# --ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.5/checkpoints/0250000.pt \
|
| 23 |
+
# --out-dir ./my_samples_5 \
|
| 24 |
+
# --num-images 20000 \
|
| 25 |
+
# --batch-size 16 \
|
| 26 |
+
# --seed 0 \
|
| 27 |
+
# --t-c 0.5 \
|
| 28 |
+
# --steps-before-tc 100 \
|
| 29 |
+
# --steps-after-tc 50 \
|
| 30 |
+
# --sampler em_image_noise_before_tc \
|
| 31 |
+
# --cfg-scale 1.0 \
|
| 32 |
+
# > samples_0.5.log 2>&1 &
|
REG/train.py
ADDED
|
@@ -0,0 +1,708 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.checkpoint
|
| 14 |
+
from tqdm.auto import tqdm
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
|
| 17 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
| 18 |
+
from accelerate.logging import get_logger
|
| 19 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 20 |
+
|
| 21 |
+
from models.sit import SiT_models
|
| 22 |
+
from loss import SILoss
|
| 23 |
+
from utils import load_encoders
|
| 24 |
+
|
| 25 |
+
from dataset import CustomDataset
|
| 26 |
+
from diffusers.models import AutoencoderKL
|
| 27 |
+
# import wandb_utils
|
| 28 |
+
import wandb
|
| 29 |
+
import math
|
| 30 |
+
from torchvision.utils import make_grid
|
| 31 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 32 |
+
from torchvision.transforms import Normalize
|
| 33 |
+
from PIL import Image
|
| 34 |
+
|
| 35 |
+
logger = get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def semantic_dim_from_enc_type(enc_type):
|
| 39 |
+
"""DINOv2 等 enc_type 字符串推断 class token 维度(与预处理特征一致)。"""
|
| 40 |
+
if enc_type is None:
|
| 41 |
+
return 768
|
| 42 |
+
s = str(enc_type).lower()
|
| 43 |
+
if "vit-g" in s or "vitg" in s:
|
| 44 |
+
return 1536
|
| 45 |
+
if "vit-l" in s or "vitl" in s:
|
| 46 |
+
return 1024
|
| 47 |
+
if "vit-s" in s or "vits" in s:
|
| 48 |
+
return 384
|
| 49 |
+
return 768
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
| 53 |
+
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def preprocess_raw_image(x, enc_type):
|
| 58 |
+
resolution = x.shape[-1]
|
| 59 |
+
if 'clip' in enc_type:
|
| 60 |
+
x = x / 255.
|
| 61 |
+
x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
|
| 62 |
+
x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
|
| 63 |
+
elif 'mocov3' in enc_type or 'mae' in enc_type:
|
| 64 |
+
x = x / 255.
|
| 65 |
+
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
| 66 |
+
elif 'dinov2' in enc_type:
|
| 67 |
+
x = x / 255.
|
| 68 |
+
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
| 69 |
+
x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
|
| 70 |
+
elif 'dinov1' in enc_type:
|
| 71 |
+
x = x / 255.
|
| 72 |
+
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
| 73 |
+
elif 'jepa' in enc_type:
|
| 74 |
+
x = x / 255.
|
| 75 |
+
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
| 76 |
+
x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
|
| 77 |
+
|
| 78 |
+
return x
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def array2grid(x):
|
| 82 |
+
nrow = round(math.sqrt(x.size(0)))
|
| 83 |
+
x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1))
|
| 84 |
+
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@torch.no_grad()
|
| 89 |
+
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
|
| 90 |
+
device = moments.device
|
| 91 |
+
|
| 92 |
+
mean, std = torch.chunk(moments, 2, dim=1)
|
| 93 |
+
z = mean + std * torch.randn_like(mean)
|
| 94 |
+
z = (z * latents_scale + latents_bias)
|
| 95 |
+
return z
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@torch.no_grad()
|
| 99 |
+
def update_ema(ema_model, model, decay=0.9999):
|
| 100 |
+
"""
|
| 101 |
+
Step the EMA model towards the current model.
|
| 102 |
+
"""
|
| 103 |
+
ema_params = OrderedDict(ema_model.named_parameters())
|
| 104 |
+
model_params = OrderedDict(model.named_parameters())
|
| 105 |
+
|
| 106 |
+
for name, param in model_params.items():
|
| 107 |
+
name = name.replace("module.", "")
|
| 108 |
+
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
|
| 109 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def create_logger(logging_dir):
|
| 113 |
+
"""
|
| 114 |
+
Create a logger that writes to a log file and stdout.
|
| 115 |
+
"""
|
| 116 |
+
logging.basicConfig(
|
| 117 |
+
level=logging.INFO,
|
| 118 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
| 119 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
| 120 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
| 121 |
+
)
|
| 122 |
+
logger = logging.getLogger(__name__)
|
| 123 |
+
return logger
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def requires_grad(model, flag=True):
|
| 127 |
+
"""
|
| 128 |
+
Set requires_grad flag for all parameters in a model.
|
| 129 |
+
"""
|
| 130 |
+
for p in model.parameters():
|
| 131 |
+
p.requires_grad = flag
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
#################################################################################
|
| 135 |
+
# Training Loop #
|
| 136 |
+
#################################################################################
|
| 137 |
+
|
| 138 |
+
def main(args):
|
| 139 |
+
# set accelerator
|
| 140 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 141 |
+
accelerator_project_config = ProjectConfiguration(
|
| 142 |
+
project_dir=args.output_dir, logging_dir=logging_dir
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
accelerator = Accelerator(
|
| 146 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 147 |
+
mixed_precision=args.mixed_precision,
|
| 148 |
+
log_with=args.report_to,
|
| 149 |
+
project_config=accelerator_project_config,
|
| 150 |
+
kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
if accelerator.is_main_process:
|
| 154 |
+
os.makedirs(args.output_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
|
| 155 |
+
save_dir = os.path.join(args.output_dir, args.exp_name)
|
| 156 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 157 |
+
args_dict = vars(args)
|
| 158 |
+
# Save to a JSON file
|
| 159 |
+
json_dir = os.path.join(save_dir, "args.json")
|
| 160 |
+
with open(json_dir, 'w') as f:
|
| 161 |
+
json.dump(args_dict, f, indent=4)
|
| 162 |
+
checkpoint_dir = f"{save_dir}/checkpoints" # Stores saved model checkpoints
|
| 163 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 164 |
+
logger = create_logger(save_dir)
|
| 165 |
+
logger.info(f"Experiment directory created at {save_dir}")
|
| 166 |
+
device = accelerator.device
|
| 167 |
+
if torch.backends.mps.is_available():
|
| 168 |
+
accelerator.native_amp = False
|
| 169 |
+
if args.seed is not None:
|
| 170 |
+
set_seed(args.seed + accelerator.process_index)
|
| 171 |
+
|
| 172 |
+
# Create model:
|
| 173 |
+
assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
|
| 174 |
+
latent_size = args.resolution // 8
|
| 175 |
+
|
| 176 |
+
train_dataset = CustomDataset(
|
| 177 |
+
args.data_dir, semantic_features_dir=args.semantic_features_dir
|
| 178 |
+
)
|
| 179 |
+
use_preprocessed_semantic = train_dataset.use_preprocessed_semantic
|
| 180 |
+
|
| 181 |
+
if use_preprocessed_semantic:
|
| 182 |
+
encoders, encoder_types, architectures = [], [], []
|
| 183 |
+
z_dims = [semantic_dim_from_enc_type(args.enc_type)]
|
| 184 |
+
if accelerator.is_main_process:
|
| 185 |
+
logger.info(
|
| 186 |
+
f"Preprocessed semantic features: skip loading online encoder, z_dims={z_dims}"
|
| 187 |
+
)
|
| 188 |
+
elif args.enc_type is not None:
|
| 189 |
+
encoders, encoder_types, architectures = load_encoders(
|
| 190 |
+
args.enc_type, device, args.resolution
|
| 191 |
+
)
|
| 192 |
+
z_dims = [encoder.embed_dim for encoder in encoders]
|
| 193 |
+
else:
|
| 194 |
+
raise NotImplementedError()
|
| 195 |
+
block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
|
| 196 |
+
model = SiT_models[args.model](
|
| 197 |
+
input_size=latent_size,
|
| 198 |
+
num_classes=args.num_classes,
|
| 199 |
+
use_cfg = (args.cfg_prob > 0),
|
| 200 |
+
z_dims = z_dims,
|
| 201 |
+
encoder_depth=args.encoder_depth,
|
| 202 |
+
**block_kwargs
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
model = model.to(device)
|
| 206 |
+
ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
|
| 207 |
+
requires_grad(ema, False)
|
| 208 |
+
|
| 209 |
+
latents_scale = torch.tensor(
|
| 210 |
+
[0.18215, 0.18215, 0.18215, 0.18215]
|
| 211 |
+
).view(1, 4, 1, 1).to(device)
|
| 212 |
+
latents_bias = torch.tensor(
|
| 213 |
+
[0., 0., 0., 0.]
|
| 214 |
+
).view(1, 4, 1, 1).to(device)
|
| 215 |
+
|
| 216 |
+
# VAE decoder:采样阶段将 latent 解码为图像(与根目录 train.py / 预处理一致:sd-vae-ft-mse)
|
| 217 |
+
try:
|
| 218 |
+
from preprocessing import dnnlib
|
| 219 |
+
cache_dir = dnnlib.make_cache_dir_path("diffusers")
|
| 220 |
+
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
|
| 221 |
+
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 222 |
+
os.environ["HF_HOME"] = cache_dir
|
| 223 |
+
try:
|
| 224 |
+
vae = AutoencoderKL.from_pretrained(
|
| 225 |
+
"stabilityai/sd-vae-ft-mse",
|
| 226 |
+
cache_dir=cache_dir,
|
| 227 |
+
local_files_only=True,
|
| 228 |
+
).to(device)
|
| 229 |
+
vae.eval()
|
| 230 |
+
if accelerator.is_main_process:
|
| 231 |
+
logger.info(
|
| 232 |
+
"Loaded VAE 'stabilityai/sd-vae-ft-mse' from local diffusers cache "
|
| 233 |
+
f"at '{cache_dir}' for intermediate sampling."
|
| 234 |
+
)
|
| 235 |
+
except Exception as e_main:
|
| 236 |
+
vae = None
|
| 237 |
+
candidate_dir = None
|
| 238 |
+
possible_roots = [
|
| 239 |
+
cache_dir,
|
| 240 |
+
os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
|
| 241 |
+
os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
|
| 242 |
+
os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
|
| 243 |
+
]
|
| 244 |
+
checked_roots = []
|
| 245 |
+
for root_dir in possible_roots:
|
| 246 |
+
if not os.path.isdir(root_dir):
|
| 247 |
+
continue
|
| 248 |
+
checked_roots.append(root_dir)
|
| 249 |
+
for root, dirs, files in os.walk(root_dir):
|
| 250 |
+
if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
|
| 251 |
+
candidate_dir = root
|
| 252 |
+
break
|
| 253 |
+
if candidate_dir is not None:
|
| 254 |
+
break
|
| 255 |
+
if candidate_dir is not None:
|
| 256 |
+
try:
|
| 257 |
+
vae = AutoencoderKL.from_pretrained(
|
| 258 |
+
candidate_dir,
|
| 259 |
+
local_files_only=True,
|
| 260 |
+
).to(device)
|
| 261 |
+
vae.eval()
|
| 262 |
+
if accelerator.is_main_process:
|
| 263 |
+
logger.info(
|
| 264 |
+
"Loaded VAE 'stabilityai/sd-vae-ft-mse' from discovered local path "
|
| 265 |
+
f"'{candidate_dir}'. Searched roots: {checked_roots}"
|
| 266 |
+
)
|
| 267 |
+
except Exception as e_fallback:
|
| 268 |
+
if accelerator.is_main_process:
|
| 269 |
+
logger.warning(
|
| 270 |
+
"Tried to load VAE from discovered local path "
|
| 271 |
+
f"'{candidate_dir}' but failed: {e_fallback}"
|
| 272 |
+
)
|
| 273 |
+
if vae is None and accelerator.is_main_process:
|
| 274 |
+
logger.warning(
|
| 275 |
+
"Could not load VAE 'stabilityai/sd-vae-ft-mse' via repo name or local search. "
|
| 276 |
+
f"Last repo-level error: {e_main}"
|
| 277 |
+
)
|
| 278 |
+
except Exception as e:
|
| 279 |
+
vae = None
|
| 280 |
+
if accelerator.is_main_process:
|
| 281 |
+
logger.warning(
|
| 282 |
+
f"Failed to initialize VAE loading logic (will skip image decoding): {e}"
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# create loss function
|
| 286 |
+
loss_fn = SILoss(
|
| 287 |
+
prediction=args.prediction,
|
| 288 |
+
path_type=args.path_type,
|
| 289 |
+
encoders=encoders,
|
| 290 |
+
accelerator=accelerator,
|
| 291 |
+
latents_scale=latents_scale,
|
| 292 |
+
latents_bias=latents_bias,
|
| 293 |
+
weighting=args.weighting,
|
| 294 |
+
t_c=args.t_c,
|
| 295 |
+
ot_cls=args.ot_cls,
|
| 296 |
+
)
|
| 297 |
+
if accelerator.is_main_process:
|
| 298 |
+
logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 299 |
+
|
| 300 |
+
# Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
|
| 301 |
+
if args.allow_tf32:
|
| 302 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 303 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 304 |
+
|
| 305 |
+
optimizer = torch.optim.AdamW(
|
| 306 |
+
model.parameters(),
|
| 307 |
+
lr=args.learning_rate,
|
| 308 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 309 |
+
weight_decay=args.adam_weight_decay,
|
| 310 |
+
eps=args.adam_epsilon,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Setup data(train_dataset 已在上方创建)
|
| 314 |
+
local_batch_size = int(args.batch_size // accelerator.num_processes)
|
| 315 |
+
train_dataloader = DataLoader(
|
| 316 |
+
train_dataset,
|
| 317 |
+
batch_size=local_batch_size,
|
| 318 |
+
shuffle=True,
|
| 319 |
+
num_workers=args.num_workers,
|
| 320 |
+
pin_memory=True,
|
| 321 |
+
drop_last=True
|
| 322 |
+
)
|
| 323 |
+
if accelerator.is_main_process:
|
| 324 |
+
logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")
|
| 325 |
+
|
| 326 |
+
# Prepare models for training:
|
| 327 |
+
update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
|
| 328 |
+
model.train() # important! This enables embedding dropout for classifier-free guidance
|
| 329 |
+
ema.eval() # EMA model should always be in eval mode
|
| 330 |
+
|
| 331 |
+
# resume:
|
| 332 |
+
global_step = 0
|
| 333 |
+
if args.resume_from_ckpt is not None:
|
| 334 |
+
ckpt = torch.load(args.resume_from_ckpt, map_location="cpu")
|
| 335 |
+
model.load_state_dict(ckpt["model"])
|
| 336 |
+
ema.load_state_dict(ckpt["ema"])
|
| 337 |
+
if "opt" in ckpt:
|
| 338 |
+
optimizer.load_state_dict(ckpt["opt"])
|
| 339 |
+
global_step = int(ckpt.get("steps", 0))
|
| 340 |
+
if accelerator.is_main_process:
|
| 341 |
+
logger.info(
|
| 342 |
+
f"Resumed from ckpt: {args.resume_from_ckpt} (global_step={global_step})"
|
| 343 |
+
)
|
| 344 |
+
elif args.resume_step > 0:
|
| 345 |
+
ckpt_name = str(args.resume_step).zfill(7) +'.pt'
|
| 346 |
+
ckpt = torch.load(
|
| 347 |
+
f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
|
| 348 |
+
map_location='cpu',
|
| 349 |
+
)
|
| 350 |
+
model.load_state_dict(ckpt['model'])
|
| 351 |
+
ema.load_state_dict(ckpt['ema'])
|
| 352 |
+
optimizer.load_state_dict(ckpt['opt'])
|
| 353 |
+
global_step = ckpt['steps']
|
| 354 |
+
|
| 355 |
+
model, optimizer, train_dataloader = accelerator.prepare(
|
| 356 |
+
model, optimizer, train_dataloader
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
if accelerator.is_main_process:
|
| 360 |
+
tracker_config = vars(copy.deepcopy(args))
|
| 361 |
+
accelerator.init_trackers(
|
| 362 |
+
project_name="REG",
|
| 363 |
+
config=tracker_config,
|
| 364 |
+
init_kwargs={
|
| 365 |
+
"wandb": {"name": f"{args.exp_name}"}
|
| 366 |
+
},
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
progress_bar = tqdm(
|
| 371 |
+
range(0, args.max_train_steps),
|
| 372 |
+
initial=global_step,
|
| 373 |
+
desc="Steps",
|
| 374 |
+
# Only show the progress bar once on each machine.
|
| 375 |
+
disable=not accelerator.is_local_main_process,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# Labels to condition the model with (feel free to change):
|
| 379 |
+
sample_batch_size = 64 // accelerator.num_processes
|
| 380 |
+
first_batch = next(iter(train_dataloader))
|
| 381 |
+
if len(first_batch) == 4:
|
| 382 |
+
gt_raw_images, gt_xs, _, _ = first_batch
|
| 383 |
+
else:
|
| 384 |
+
gt_raw_images, gt_xs, _ = first_batch
|
| 385 |
+
assert gt_raw_images.shape[-1] == args.resolution
|
| 386 |
+
gt_xs = gt_xs[:sample_batch_size]
|
| 387 |
+
gt_xs = sample_posterior(
|
| 388 |
+
gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
|
| 389 |
+
)
|
| 390 |
+
ys = torch.randint(1000, size=(sample_batch_size,), device=device)
|
| 391 |
+
ys = ys.to(device)
|
| 392 |
+
# Create sampling noise:
|
| 393 |
+
n = ys.size(0)
|
| 394 |
+
xT = torch.randn((n, 4, latent_size, latent_size), device=device)
|
| 395 |
+
|
| 396 |
+
for epoch in range(args.epochs):
|
| 397 |
+
model.train()
|
| 398 |
+
for batch in train_dataloader:
|
| 399 |
+
if len(batch) == 4:
|
| 400 |
+
raw_image, x, r_preprocessed, y = batch
|
| 401 |
+
use_sem_file = True
|
| 402 |
+
else:
|
| 403 |
+
raw_image, x, y = batch
|
| 404 |
+
r_preprocessed = None
|
| 405 |
+
use_sem_file = False
|
| 406 |
+
|
| 407 |
+
raw_image = raw_image.to(device)
|
| 408 |
+
x = x.squeeze(dim=1).to(device).float()
|
| 409 |
+
y = y.to(device)
|
| 410 |
+
if args.legacy:
|
| 411 |
+
# In our early experiments, we accidentally apply label dropping twice:
|
| 412 |
+
# once in train.py and once in sit.py.
|
| 413 |
+
# We keep this option for exact reproducibility with previous runs.
|
| 414 |
+
drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
|
| 415 |
+
labels = torch.where(drop_ids, args.num_classes, y)
|
| 416 |
+
else:
|
| 417 |
+
labels = y
|
| 418 |
+
with torch.no_grad():
|
| 419 |
+
x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
|
| 420 |
+
zs = []
|
| 421 |
+
if use_sem_file and r_preprocessed is not None:
|
| 422 |
+
cls_token = r_preprocessed.to(device).float()
|
| 423 |
+
if cls_token.dim() == 1:
|
| 424 |
+
cls_token = cls_token.unsqueeze(0)
|
| 425 |
+
while cls_token.dim() > 2:
|
| 426 |
+
cls_token = cls_token.squeeze(1)
|
| 427 |
+
base_m = model.module if hasattr(model, "module") else model
|
| 428 |
+
n_pad = base_m.x_embedder.num_patches
|
| 429 |
+
zs = [
|
| 430 |
+
torch.cat(
|
| 431 |
+
[
|
| 432 |
+
cls_token.unsqueeze(1),
|
| 433 |
+
cls_token.unsqueeze(1).expand(-1, n_pad, -1),
|
| 434 |
+
],
|
| 435 |
+
dim=1,
|
| 436 |
+
)
|
| 437 |
+
]
|
| 438 |
+
else:
|
| 439 |
+
with accelerator.autocast():
|
| 440 |
+
for encoder, encoder_type, arch in zip(
|
| 441 |
+
encoders, encoder_types, architectures
|
| 442 |
+
):
|
| 443 |
+
raw_image_ = preprocess_raw_image(raw_image, encoder_type)
|
| 444 |
+
z = encoder.forward_features(raw_image_)
|
| 445 |
+
if 'dinov2' in encoder_type:
|
| 446 |
+
dense_z = z['x_norm_patchtokens']
|
| 447 |
+
cls_token = z['x_norm_clstoken']
|
| 448 |
+
dense_z = torch.cat([cls_token.unsqueeze(1), dense_z], dim=1)
|
| 449 |
+
else:
|
| 450 |
+
exit()
|
| 451 |
+
zs.append(dense_z)
|
| 452 |
+
|
| 453 |
+
with accelerator.accumulate(model):
|
| 454 |
+
model_kwargs = dict(y=labels)
|
| 455 |
+
loss1, proj_loss1, time_input, noises, loss2 = loss_fn(model, x, model_kwargs, zs=zs,
|
| 456 |
+
cls_token=cls_token,
|
| 457 |
+
time_input=None, noises=None)
|
| 458 |
+
loss_mean = loss1.mean()
|
| 459 |
+
loss_mean_cls = loss2.mean() * args.cls
|
| 460 |
+
proj_loss_mean = proj_loss1.mean() * args.proj_coeff
|
| 461 |
+
tc_vel_loss = torch.tensor(0.0, device=device)
|
| 462 |
+
if args.tc_velocity_loss_coeff > 0:
|
| 463 |
+
tc_vel_loss = loss_fn.tc_velocity_loss(
|
| 464 |
+
model,
|
| 465 |
+
x,
|
| 466 |
+
model_kwargs=model_kwargs,
|
| 467 |
+
cls_token=cls_token,
|
| 468 |
+
noises=noises,
|
| 469 |
+
).mean()
|
| 470 |
+
loss = (
|
| 471 |
+
loss_mean
|
| 472 |
+
+ proj_loss_mean
|
| 473 |
+
+ loss_mean_cls
|
| 474 |
+
+ args.tc_velocity_loss_coeff * tc_vel_loss
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
## optimization
|
| 479 |
+
accelerator.backward(loss)
|
| 480 |
+
if accelerator.sync_gradients:
|
| 481 |
+
params_to_clip = model.parameters()
|
| 482 |
+
grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 483 |
+
optimizer.step()
|
| 484 |
+
optimizer.zero_grad(set_to_none=True)
|
| 485 |
+
|
| 486 |
+
if accelerator.sync_gradients:
|
| 487 |
+
update_ema(ema, model) # change ema function
|
| 488 |
+
|
| 489 |
+
### enter
|
| 490 |
+
if accelerator.sync_gradients:
|
| 491 |
+
progress_bar.update(1)
|
| 492 |
+
global_step += 1
|
| 493 |
+
if global_step % args.checkpointing_steps == 0 and global_step > 0:
|
| 494 |
+
if accelerator.is_main_process:
|
| 495 |
+
checkpoint = {
|
| 496 |
+
"model": model.module.state_dict(),
|
| 497 |
+
"ema": ema.state_dict(),
|
| 498 |
+
"opt": optimizer.state_dict(),
|
| 499 |
+
"args": args,
|
| 500 |
+
"steps": global_step,
|
| 501 |
+
}
|
| 502 |
+
checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
|
| 503 |
+
torch.save(checkpoint, checkpoint_path)
|
| 504 |
+
logger.info(f"Saved checkpoint to {checkpoint_path}")
|
| 505 |
+
|
| 506 |
+
if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
|
| 507 |
+
t_mid_vis = float(args.t_c)
|
| 508 |
+
tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
|
| 509 |
+
logging.info(
|
| 510 |
+
f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} → t=0)..."
|
| 511 |
+
)
|
| 512 |
+
ema.eval()
|
| 513 |
+
with torch.no_grad():
|
| 514 |
+
latent_size = args.resolution // 8
|
| 515 |
+
n_samples = min(16, args.batch_size)
|
| 516 |
+
base_model = model.module if hasattr(model, "module") else model
|
| 517 |
+
cls_dim = base_model.z_dims[0]
|
| 518 |
+
shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
|
| 519 |
+
torch.manual_seed(shared_seed)
|
| 520 |
+
z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device)
|
| 521 |
+
torch.manual_seed(shared_seed)
|
| 522 |
+
cls_init = torch.randn(n_samples, cls_dim, device=device)
|
| 523 |
+
y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)
|
| 524 |
+
|
| 525 |
+
from samplers import euler_maruyama_sampler
|
| 526 |
+
z_0, z_mid, _ = euler_maruyama_sampler(
|
| 527 |
+
ema,
|
| 528 |
+
z_init,
|
| 529 |
+
y_samples,
|
| 530 |
+
num_steps=50,
|
| 531 |
+
cfg_scale=1.0,
|
| 532 |
+
guidance_low=0.0,
|
| 533 |
+
guidance_high=1.0,
|
| 534 |
+
path_type=args.path_type,
|
| 535 |
+
cls_latents=cls_init,
|
| 536 |
+
args=args,
|
| 537 |
+
return_mid_state=True,
|
| 538 |
+
t_mid=t_mid_vis,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
|
| 542 |
+
t0_dir = os.path.join(samples_root, "t0")
|
| 543 |
+
t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
|
| 544 |
+
os.makedirs(t0_dir, exist_ok=True)
|
| 545 |
+
os.makedirs(t_mid_dir, exist_ok=True)
|
| 546 |
+
|
| 547 |
+
if vae is not None:
|
| 548 |
+
z_f = z_0.to(dtype=torch.float32)
|
| 549 |
+
samples_final = vae.decode((z_f - latents_bias) / latents_scale).sample
|
| 550 |
+
samples_final = (samples_final + 1) / 2.0
|
| 551 |
+
samples_final = samples_final.clamp(0, 1)
|
| 552 |
+
grid_final = array2grid(samples_final)
|
| 553 |
+
Image.fromarray(grid_final).save(
|
| 554 |
+
os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
if z_mid is not None:
|
| 558 |
+
z_m = z_mid.to(dtype=torch.float32)
|
| 559 |
+
samples_mid = vae.decode((z_m - latents_bias) / latents_scale).sample
|
| 560 |
+
samples_mid = (samples_mid + 1) / 2.0
|
| 561 |
+
samples_mid = samples_mid.clamp(0, 1)
|
| 562 |
+
grid_mid = array2grid(samples_mid)
|
| 563 |
+
Image.fromarray(grid_mid).save(
|
| 564 |
+
os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
|
| 565 |
+
)
|
| 566 |
+
else:
|
| 567 |
+
logging.warning(
|
| 568 |
+
f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
|
| 569 |
+
f"skip t0_{tc_tag} image this step."
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
del z_init, cls_init, y_samples, z_0
|
| 573 |
+
if z_mid is not None:
|
| 574 |
+
del z_mid
|
| 575 |
+
if vae is not None:
|
| 576 |
+
del samples_final, grid_final
|
| 577 |
+
if "samples_mid" in locals():
|
| 578 |
+
del samples_mid, grid_mid
|
| 579 |
+
torch.cuda.empty_cache()
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
logs = {
|
| 583 |
+
"loss_final": accelerator.gather(loss).mean().detach().item(),
|
| 584 |
+
"loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
|
| 585 |
+
"proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
|
| 586 |
+
"loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
|
| 587 |
+
"loss_tc_vel": accelerator.gather(tc_vel_loss).mean().detach().item(),
|
| 588 |
+
"grad_norm": accelerator.gather(grad_norm).mean().detach().item()
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
|
| 592 |
+
logging.info(f"Step: {global_step}, Training Logs: {log_message}")
|
| 593 |
+
|
| 594 |
+
progress_bar.set_postfix(**logs)
|
| 595 |
+
accelerator.log(logs, step=global_step)
|
| 596 |
+
|
| 597 |
+
if global_step >= args.max_train_steps:
|
| 598 |
+
break
|
| 599 |
+
if global_step >= args.max_train_steps:
|
| 600 |
+
break
|
| 601 |
+
|
| 602 |
+
model.eval() # important! This disables randomized embedding dropout
|
| 603 |
+
# do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
|
| 604 |
+
|
| 605 |
+
accelerator.wait_for_everyone()
|
| 606 |
+
if accelerator.is_main_process:
|
| 607 |
+
logger.info("Done!")
|
| 608 |
+
accelerator.end_training()
|
| 609 |
+
|
| 610 |
+
def parse_args(input_args=None):
    """Build and parse the command-line options for REG training.

    Args:
        input_args: Optional list of argument strings. When ``None`` (the
            default), arguments are read from ``sys.argv`` as usual.

    Returns:
        argparse.Namespace: All training options with their defaults applied.
    """
    parser = argparse.ArgumentParser(description="Training")

    # logging:
    parser.add_argument("--output-dir", type=str, default="exps")
    parser.add_argument("--exp-name", type=str, required=True)
    parser.add_argument("--logging-dir", type=str, default="logs")
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--sampling-steps", type=int, default=2000)
    parser.add_argument("--resume-step", type=int, default=0)
    parser.add_argument(
        "--resume-from-ckpt",
        type=str,
        default=None,
        help="直接从指定 checkpoint 路径续训(优先于 --resume-step)。",
    )

    # model
    parser.add_argument("--model", type=str)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--ops-head", type=int, default=16)

    # dataset
    parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
    parser.add_argument(
        "--semantic-features-dir",
        type=str,
        default=None,
        help="预处理 DINOv2 class token 等特征目录(含 dataset.json)。"
        "默认 None 时若存在 data-dir/imagenet_256_features/dinov2-vit-b_tmp/gpu0 则自动使用。",
    )
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    parser.add_argument("--batch-size", type=int, default=256)

    # precision
    parser.add_argument("--allow-tf32", action="store_true")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])

    # optimization
    parser.add_argument("--epochs", type=int, default=14000)
    parser.add_argument("--max-train-steps", type=int, default=10000000)
    parser.add_argument("--checkpointing-steps", type=int, default=10000)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
    parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")

    # seed
    parser.add_argument("--seed", type=int, default=0)

    # cpu
    parser.add_argument("--num-workers", type=int, default=4)

    # loss
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--prediction", type=str, default="v", choices=["v"])  # currently we only support v-prediction
    parser.add_argument("--cfg-prob", type=float, default=0.1)
    parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
    parser.add_argument("--proj-coeff", type=float, default=0.5)
    # BUGFIX: the help text used to read "Max gradient norm." (copy-pasted from
    # --max-grad-norm); this option actually selects the loss time-weighting scheme.
    parser.add_argument("--weighting", default="uniform", type=str,
                        help="Loss time-weighting scheme (e.g. 'uniform').")
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--cls", type=float, default=0.03)
    parser.add_argument(
        "--t-c",
        type=float,
        default=0.5,
        help="语义分界时刻(与脚本内 t 约定一致:t=1 噪声→t=0 数据)。"
        "t∈(t_c,1]:cls 沿 OT 配对后的路径插值(CFM/OT-CFM 式 minibatch OT);"
        "t∈[0,t_c]:cls 固定为真实 encoder cls,目标 cls 速度为 0。",
    )
    parser.add_argument(
        "--ot-cls",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="在 t>t_c 段对 cls 噪声与 batch 内 cls_gt 做 minibatch 最优传输配对(需 scipy);关闭则退化为独立高斯噪声配对。",
    )
    parser.add_argument(
        "--tc-velocity-loss-coeff",
        type=float,
        default=0.0,
        help="额外 t=t_c 图像速度场监督项权重(>0 启用,用于增强单步性)。",
    )

    # Parse either the caller-supplied argv slice or the real command line.
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args
|
| 704 |
+
|
| 705 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run the training loop.
    main(parse_args())
|
REG/train.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# REG/train.py: like the main repository, but the data root directory and the
# preprocessed cls-feature directory can be specified independently.
# Data layout: VAE latents under ${DATA_DIR}/imagenet_256_vae/;
# img-feature-*.npy + dataset.json under ${SEMANTIC_FEATURES_DIR}/
# (same layout as produced by parallel_encode).

NUM_GPUS=4

# ------------ Adjust these paths for your machine ------------
DATA_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256"
SEMANTIC_FEATURES_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"

# Background-launch example (same style as the main experiment scripts):
# nohup bash train.sh > jsflow-experiment.log 2>&1 &

nohup accelerate launch --multi_gpu --num_processes "${NUM_GPUS}" --mixed_precision bf16 train.py \
  --report-to wandb \
  --allow-tf32 \
  --mixed-precision bf16 \
  --seed 0 \
  --path-type linear \
  --prediction v \
  --weighting uniform \
  --model SiT-XL/2 \
  --enc-type dinov2-vit-b \
  --encoder-depth 8 \
  --proj-coeff 0.5 \
  --output-dir exps \
  --exp-name jsflow-experiment-0.75-0.01 \
  --batch-size 256 \
  --data-dir "${DATA_DIR}" \
  --semantic-features-dir "${SEMANTIC_FEATURES_DIR}" \
  --learning-rate 0.00005 \
  --t-c 0.75 \
  --cls 0.01 \
  --ot-cls \
  > jsflow-experiment.log 2>&1 &

# Notes:
# - To skip preprocessed features and extract DINO features online: drop
#   --semantic-features-dir and make sure data-dir follows the original REG
#   layout (imagenet_256_vae + vae-sd).
# - To disable minibatch OT: append --no-ot-cls.
# - The main repo's --weight-ratio / --semantic-reg-coeff / --repa-* options are
#   not implemented in this REG script; use --proj-coeff for projection strength
#   and --cls for the cls flow-loss weight.
|
REG/train_resume_tc_velocity.sh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Resume training from a specific checkpoint, adding an extra velocity-field
# supervision term at t_c (used to strengthen one-step generation).
set -euo pipefail

NUM_GPUS=4

DATA_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256"
SEMANTIC_FEATURES_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"

# User-specified checkpoint to resume training from
RESUME_CKPT="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75-0.01-one-step/checkpoints/1920000.pt"

# Weight of the added t_c velocity-field loss (tune up/down as needed)
TC_VEL_COEFF=2

nohup accelerate launch --multi_gpu --num_processes "${NUM_GPUS}" --mixed_precision bf16 train.py \
  --report-to wandb \
  --allow-tf32 \
  --mixed-precision bf16 \
  --seed 0 \
  --path-type linear \
  --prediction v \
  --weighting uniform \
  --model SiT-XL/2 \
  --enc-type dinov2-vit-b \
  --encoder-depth 8 \
  --proj-coeff 0.5 \
  --output-dir exps \
  --exp-name jsflow-experiment-0.75-0.01-one-step \
  --batch-size 256 \
  --data-dir "${DATA_DIR}" \
  --semantic-features-dir "${SEMANTIC_FEATURES_DIR}" \
  --learning-rate 0.00005 \
  --t-c 0.75 \
  --cls 0.005 \
  --ot-cls \
  --resume-from-ckpt "${RESUME_CKPT}" \
  --tc-velocity-loss-coeff "${TC_VEL_COEFF}" \
  > jsflow-experiment-0.75-0.01-tcvel.log 2>&1 &

echo "Launched resume training with tc velocity loss."
|
REG/utils.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from torchvision.datasets.utils import download_url
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision.models as torchvision_models
|
| 5 |
+
import timm
|
| 6 |
+
from models import mocov3_vit
|
| 7 |
+
import math
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# code from SiT repository
|
| 12 |
+
pretrained_models = {'last.pt'}
|
| 13 |
+
|
| 14 |
+
def download_model(model_name):
    """
    Fetch a pre-trained SiT checkpoint, downloading it on first use.

    The file is cached under ``pretrained_models/`` and loaded onto CPU.
    *model_name* must be one of the known ``pretrained_models`` entries.
    """
    assert model_name in pretrained_models
    local_path = f'pretrained_models/{model_name}'
    if not os.path.isfile(local_path):
        # First call: create the cache directory and pull the checkpoint.
        os.makedirs('pretrained_models', exist_ok=True)
        web_path = (
            'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/'
            'last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
        )
        download_url(web_path, 'pretrained_models', filename=model_name)
    # Force everything onto CPU regardless of where it was saved.
    return torch.load(local_path, map_location=lambda storage, loc: storage)
|
| 26 |
+
|
| 27 |
+
def fix_mocov3_state_dict(state_dict):
    """Rewrite a raw MoCo-v3 checkpoint state dict so it loads into a plain ViT.

    Keeps only ``module.base_encoder.*`` entries (with the prefix stripped),
    repairs a known layer-naming bug in the released checkpoint
    (``norm13``/``fc13``/``norm14``/``fc14``), drops projection-head and ``fc``
    keys, and resamples the positional embedding to a 16x16 token grid.
    Mutates *state_dict* in place and returns it.
    """
    for k in list(state_dict.keys()):
        # retain only base_encoder up to before the embedding layer
        if k.startswith('module.base_encoder'):
            # fix naming bug in checkpoint
            new_k = k[len("module.base_encoder."):]
            # NOTE: the substring tests mix `new_k` and `k`; they are equivalent
            # here because `new_k` is just `k` with the prefix removed.
            if "blocks.13.norm13" in new_k:
                new_k = new_k.replace("norm13", "norm1")
            if "blocks.13.mlp.fc13" in k:
                new_k = new_k.replace("fc13", "fc1")
            if "blocks.14.norm14" in k:
                new_k = new_k.replace("norm14", "norm2")
            if "blocks.14.mlp.fc14" in k:
                new_k = new_k.replace("fc14", "fc2")
            # remove prefix
            if 'head' not in new_k and new_k.split('.')[0] != 'fc':
                state_dict[new_k] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]
    if 'pos_embed' in state_dict.keys():
        # Resample the absolute positional embedding to a 16x16 patch grid
        # (matches the 256px / patch-16 configuration used elsewhere in the file).
        state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
            state_dict['pos_embed'], [16, 16],
        )
    return state_dict
|
| 51 |
+
|
| 52 |
+
@torch.no_grad()
|
| 53 |
+
def load_encoders(enc_type, device, resolution=256):
|
| 54 |
+
assert (resolution == 256) or (resolution == 512)
|
| 55 |
+
|
| 56 |
+
enc_names = enc_type.split(',')
|
| 57 |
+
encoders, architectures, encoder_types = [], [], []
|
| 58 |
+
for enc_name in enc_names:
|
| 59 |
+
encoder_type, architecture, model_config = enc_name.split('-')
|
| 60 |
+
# Currently, we only support 512x512 experiments with DINOv2 encoders.
|
| 61 |
+
if resolution == 512:
|
| 62 |
+
if encoder_type != 'dinov2':
|
| 63 |
+
raise NotImplementedError(
|
| 64 |
+
"Currently, we only support 512x512 experiments with DINOv2 encoders."
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
architectures.append(architecture)
|
| 68 |
+
encoder_types.append(encoder_type)
|
| 69 |
+
if encoder_type == 'mocov3':
|
| 70 |
+
if architecture == 'vit':
|
| 71 |
+
if model_config == 's':
|
| 72 |
+
encoder = mocov3_vit.vit_small()
|
| 73 |
+
elif model_config == 'b':
|
| 74 |
+
encoder = mocov3_vit.vit_base()
|
| 75 |
+
elif model_config == 'l':
|
| 76 |
+
encoder = mocov3_vit.vit_large()
|
| 77 |
+
ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
|
| 78 |
+
state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
|
| 79 |
+
del encoder.head
|
| 80 |
+
encoder.load_state_dict(state_dict, strict=True)
|
| 81 |
+
encoder.head = torch.nn.Identity()
|
| 82 |
+
elif architecture == 'resnet':
|
| 83 |
+
raise NotImplementedError()
|
| 84 |
+
|
| 85 |
+
encoder = encoder.to(device)
|
| 86 |
+
encoder.eval()
|
| 87 |
+
|
| 88 |
+
elif 'dinov2' in encoder_type:
|
| 89 |
+
import timm
|
| 90 |
+
if 'reg' in encoder_type:
|
| 91 |
+
try:
|
| 92 |
+
encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
|
| 93 |
+
f'dinov2_vit{model_config}14_reg', source='local')
|
| 94 |
+
except:
|
| 95 |
+
encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
|
| 96 |
+
else:
|
| 97 |
+
try:
|
| 98 |
+
encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
|
| 99 |
+
f'dinov2_vit{model_config}14', source='local')
|
| 100 |
+
except:
|
| 101 |
+
encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')
|
| 102 |
+
|
| 103 |
+
print(f"Now you are using the {enc_name} as the aligning model")
|
| 104 |
+
del encoder.head
|
| 105 |
+
patch_resolution = 16 * (resolution // 256)
|
| 106 |
+
encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 107 |
+
encoder.pos_embed.data, [patch_resolution, patch_resolution],
|
| 108 |
+
)
|
| 109 |
+
encoder.head = torch.nn.Identity()
|
| 110 |
+
encoder = encoder.to(device)
|
| 111 |
+
encoder.eval()
|
| 112 |
+
|
| 113 |
+
elif 'dinov1' == encoder_type:
|
| 114 |
+
import timm
|
| 115 |
+
from models import dinov1
|
| 116 |
+
encoder = dinov1.vit_base()
|
| 117 |
+
ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
|
| 118 |
+
if 'pos_embed' in ckpt.keys():
|
| 119 |
+
ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 120 |
+
ckpt['pos_embed'], [16, 16],
|
| 121 |
+
)
|
| 122 |
+
del encoder.head
|
| 123 |
+
encoder.head = torch.nn.Identity()
|
| 124 |
+
encoder.load_state_dict(ckpt, strict=True)
|
| 125 |
+
encoder = encoder.to(device)
|
| 126 |
+
encoder.forward_features = encoder.forward
|
| 127 |
+
encoder.eval()
|
| 128 |
+
|
| 129 |
+
elif encoder_type == 'clip':
|
| 130 |
+
import clip
|
| 131 |
+
from models.clip_vit import UpdatedVisionTransformer
|
| 132 |
+
encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
|
| 133 |
+
encoder = UpdatedVisionTransformer(encoder_).to(device)
|
| 134 |
+
#.to(device)
|
| 135 |
+
encoder.embed_dim = encoder.model.transformer.width
|
| 136 |
+
encoder.forward_features = encoder.forward
|
| 137 |
+
encoder.eval()
|
| 138 |
+
|
| 139 |
+
elif encoder_type == 'mae':
|
| 140 |
+
from models.mae_vit import vit_large_patch16
|
| 141 |
+
import timm
|
| 142 |
+
kwargs = dict(img_size=256)
|
| 143 |
+
encoder = vit_large_patch16(**kwargs).to(device)
|
| 144 |
+
with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
|
| 145 |
+
state_dict = torch.load(f)
|
| 146 |
+
if 'pos_embed' in state_dict["model"].keys():
|
| 147 |
+
state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 148 |
+
state_dict["model"]['pos_embed'], [16, 16],
|
| 149 |
+
)
|
| 150 |
+
encoder.load_state_dict(state_dict["model"])
|
| 151 |
+
|
| 152 |
+
encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 153 |
+
encoder.pos_embed.data, [16, 16],
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
elif encoder_type == 'jepa':
|
| 157 |
+
from models.jepa import vit_huge
|
| 158 |
+
kwargs = dict(img_size=[224, 224], patch_size=14)
|
| 159 |
+
encoder = vit_huge(**kwargs).to(device)
|
| 160 |
+
with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
|
| 161 |
+
state_dict = torch.load(f, map_location=device)
|
| 162 |
+
new_state_dict = dict()
|
| 163 |
+
for key, value in state_dict['encoder'].items():
|
| 164 |
+
new_state_dict[key[7:]] = value
|
| 165 |
+
encoder.load_state_dict(new_state_dict)
|
| 166 |
+
encoder.forward_features = encoder.forward
|
| 167 |
+
|
| 168 |
+
encoders.append(encoder)
|
| 169 |
+
|
| 170 |
+
return encoders, encoder_types, architectures
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 174 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 175 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 176 |
+
def norm_cdf(x):
|
| 177 |
+
# Computes standard normal cumulative distribution function
|
| 178 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 179 |
+
|
| 180 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 181 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 182 |
+
"The distribution of values may be incorrect.",
|
| 183 |
+
stacklevel=2)
|
| 184 |
+
|
| 185 |
+
with torch.no_grad():
|
| 186 |
+
# Values are generated by using a truncated uniform distribution and
|
| 187 |
+
# then using the inverse CDF for the normal distribution.
|
| 188 |
+
# Get upper and lower cdf values
|
| 189 |
+
l = norm_cdf((a - mean) / std)
|
| 190 |
+
u = norm_cdf((b - mean) / std)
|
| 191 |
+
|
| 192 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 193 |
+
# [2l-1, 2u-1].
|
| 194 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 195 |
+
|
| 196 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 197 |
+
# standard normal
|
| 198 |
+
tensor.erfinv_()
|
| 199 |
+
|
| 200 |
+
# Transform to proper mean, std
|
| 201 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 202 |
+
tensor.add_(mean)
|
| 203 |
+
|
| 204 |
+
# Clamp to ensure it's in the proper range
|
| 205 |
+
tensor.clamp_(min=a, max=b)
|
| 206 |
+
return tensor
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in-place with truncated-normal values.

    Convenience wrapper around ``_no_grad_trunc_normal_`` with the standard
    defaults (standard normal truncated to [-2, 2]). Returns ``tensor``.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def load_legacy_checkpoints(state_dict, encoder_depth):
    """Remap legacy checkpoint keys to the current unified block layout.

    Legacy checkpoints keep decoder weights under ``decoder_blocks.N.*``;
    the current model stores them as ``blocks.(N + encoder_depth).*``.
    All other entries are passed through unchanged.

    Args:
        state_dict: mapping of parameter names to tensors (legacy layout).
        encoder_depth: number of encoder blocks, used to offset decoder indices.

    Returns:
        A new dict with remapped keys; values are the original objects.
    """
    remapped = dict()
    for name, value in state_dict.items():
        if 'decoder_blocks' not in name:
            remapped[name] = value
            continue
        # 'decoder_blocks.N.rest' -> 'blocks.(N + encoder_depth).rest'
        pieces = name.split('.')
        pieces[0] = 'blocks'
        pieces[1] = str(int(pieces[1]) + encoder_depth)
        remapped['.'.join(pieces)] = value
    return remapped
|
REG/wandb/run-20260322_150022-yhxc5cgu/files/config.yaml
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_wandb:
|
| 2 |
+
value:
|
| 3 |
+
cli_version: 0.25.0
|
| 4 |
+
e:
|
| 5 |
+
ucanic8s891x6sl28vnbha78lzoecw66:
|
| 6 |
+
args:
|
| 7 |
+
- --report-to
|
| 8 |
+
- wandb
|
| 9 |
+
- --allow-tf32
|
| 10 |
+
- --mixed-precision
|
| 11 |
+
- bf16
|
| 12 |
+
- --seed
|
| 13 |
+
- "0"
|
| 14 |
+
- --path-type
|
| 15 |
+
- linear
|
| 16 |
+
- --prediction
|
| 17 |
+
- v
|
| 18 |
+
- --weighting
|
| 19 |
+
- uniform
|
| 20 |
+
- --model
|
| 21 |
+
- SiT-XL/2
|
| 22 |
+
- --enc-type
|
| 23 |
+
- dinov2-vit-b
|
| 24 |
+
- --encoder-depth
|
| 25 |
+
- "8"
|
| 26 |
+
- --proj-coeff
|
| 27 |
+
- "0.5"
|
| 28 |
+
- --output-dir
|
| 29 |
+
- exps
|
| 30 |
+
- --exp-name
|
| 31 |
+
- jsflow-experiment
|
| 32 |
+
- --batch-size
|
| 33 |
+
- "256"
|
| 34 |
+
- --data-dir
|
| 35 |
+
- /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
|
| 36 |
+
- --semantic-features-dir
|
| 37 |
+
- /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
|
| 38 |
+
- --learning-rate
|
| 39 |
+
- "0.00005"
|
| 40 |
+
- --t-c
|
| 41 |
+
- "0.5"
|
| 42 |
+
- --cls
|
| 43 |
+
- "0.2"
|
| 44 |
+
- --ot-cls
|
| 45 |
+
codePath: train.py
|
| 46 |
+
codePathLocal: train.py
|
| 47 |
+
cpu_count: 96
|
| 48 |
+
cpu_count_logical: 192
|
| 49 |
+
cudaVersion: "13.0"
|
| 50 |
+
disk:
|
| 51 |
+
/:
|
| 52 |
+
total: "3838880616448"
|
| 53 |
+
used: "357557354496"
|
| 54 |
+
email: 2365972933@qq.com
|
| 55 |
+
executable: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
|
| 56 |
+
git:
|
| 57 |
+
commit: 021ea2e50c38c5803bd9afff16316958a01fbd1d
|
| 58 |
+
remote: https://github.com/Martinser/REG.git
|
| 59 |
+
gpu: NVIDIA H100 80GB HBM3
|
| 60 |
+
gpu_count: 4
|
| 61 |
+
gpu_nvidia:
|
| 62 |
+
- architecture: Hopper
|
| 63 |
+
cudaCores: 16896
|
| 64 |
+
memoryTotal: "85520809984"
|
| 65 |
+
name: NVIDIA H100 80GB HBM3
|
| 66 |
+
uuid: GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc
|
| 67 |
+
- architecture: Hopper
|
| 68 |
+
cudaCores: 16896
|
| 69 |
+
memoryTotal: "85520809984"
|
| 70 |
+
name: NVIDIA H100 80GB HBM3
|
| 71 |
+
uuid: GPU-a09f2421-99e6-a72e-63bd-fd7452510758
|
| 72 |
+
- architecture: Hopper
|
| 73 |
+
cudaCores: 16896
|
| 74 |
+
memoryTotal: "85520809984"
|
| 75 |
+
name: NVIDIA H100 80GB HBM3
|
| 76 |
+
uuid: GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d
|
| 77 |
+
- architecture: Hopper
|
| 78 |
+
cudaCores: 16896
|
| 79 |
+
memoryTotal: "85520809984"
|
| 80 |
+
name: NVIDIA H100 80GB HBM3
|
| 81 |
+
uuid: GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e
|
| 82 |
+
host: 24c964746905d416ce09d045f9a06f23-taskrole1-0
|
| 83 |
+
memory:
|
| 84 |
+
total: "2164115296256"
|
| 85 |
+
os: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
|
| 86 |
+
program: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py
|
| 87 |
+
python: CPython 3.12.9
|
| 88 |
+
root: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG
|
| 89 |
+
startedAt: "2026-03-22T07:00:22.092510Z"
|
| 90 |
+
writerId: ucanic8s891x6sl28vnbha78lzoecw66
|
| 91 |
+
m: []
|
| 92 |
+
python_version: 3.12.9
|
| 93 |
+
t:
|
| 94 |
+
"1":
|
| 95 |
+
- 1
|
| 96 |
+
- 5
|
| 97 |
+
- 11
|
| 98 |
+
- 41
|
| 99 |
+
- 49
|
| 100 |
+
- 53
|
| 101 |
+
- 63
|
| 102 |
+
- 71
|
| 103 |
+
- 83
|
| 104 |
+
- 98
|
| 105 |
+
"2":
|
| 106 |
+
- 1
|
| 107 |
+
- 5
|
| 108 |
+
- 11
|
| 109 |
+
- 41
|
| 110 |
+
- 49
|
| 111 |
+
- 53
|
| 112 |
+
- 63
|
| 113 |
+
- 71
|
| 114 |
+
- 83
|
| 115 |
+
- 98
|
| 116 |
+
"3":
|
| 117 |
+
- 13
|
| 118 |
+
"4": 3.12.9
|
| 119 |
+
"5": 0.25.0
|
| 120 |
+
"6": 4.53.2
|
| 121 |
+
"12": 0.25.0
|
| 122 |
+
"13": linux-x86_64
|
| 123 |
+
adam_beta1:
|
| 124 |
+
value: 0.9
|
| 125 |
+
adam_beta2:
|
| 126 |
+
value: 0.999
|
| 127 |
+
adam_epsilon:
|
| 128 |
+
value: 1e-08
|
| 129 |
+
adam_weight_decay:
|
| 130 |
+
value: 0
|
| 131 |
+
allow_tf32:
|
| 132 |
+
value: true
|
| 133 |
+
batch_size:
|
| 134 |
+
value: 256
|
| 135 |
+
cfg_prob:
|
| 136 |
+
value: 0.1
|
| 137 |
+
checkpointing_steps:
|
| 138 |
+
value: 10000
|
| 139 |
+
cls:
|
| 140 |
+
value: 0.2
|
| 141 |
+
data_dir:
|
| 142 |
+
value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
|
| 143 |
+
enc_type:
|
| 144 |
+
value: dinov2-vit-b
|
| 145 |
+
encoder_depth:
|
| 146 |
+
value: 8
|
| 147 |
+
epochs:
|
| 148 |
+
value: 1400
|
| 149 |
+
exp_name:
|
| 150 |
+
value: jsflow-experiment
|
| 151 |
+
fused_attn:
|
| 152 |
+
value: true
|
| 153 |
+
gradient_accumulation_steps:
|
| 154 |
+
value: 1
|
| 155 |
+
learning_rate:
|
| 156 |
+
value: 5e-05
|
| 157 |
+
legacy:
|
| 158 |
+
value: false
|
| 159 |
+
logging_dir:
|
| 160 |
+
value: logs
|
| 161 |
+
max_grad_norm:
|
| 162 |
+
value: 1
|
| 163 |
+
max_train_steps:
|
| 164 |
+
value: 1000000
|
| 165 |
+
mixed_precision:
|
| 166 |
+
value: bf16
|
| 167 |
+
model:
|
| 168 |
+
value: SiT-XL/2
|
| 169 |
+
num_classes:
|
| 170 |
+
value: 1000
|
| 171 |
+
num_workers:
|
| 172 |
+
value: 4
|
| 173 |
+
ops_head:
|
| 174 |
+
value: 16
|
| 175 |
+
ot_cls:
|
| 176 |
+
value: true
|
| 177 |
+
output_dir:
|
| 178 |
+
value: exps
|
| 179 |
+
path_type:
|
| 180 |
+
value: linear
|
| 181 |
+
prediction:
|
| 182 |
+
value: v
|
| 183 |
+
proj_coeff:
|
| 184 |
+
value: 0.5
|
| 185 |
+
qk_norm:
|
| 186 |
+
value: false
|
| 187 |
+
report_to:
|
| 188 |
+
value: wandb
|
| 189 |
+
resolution:
|
| 190 |
+
value: 256
|
| 191 |
+
resume_step:
|
| 192 |
+
value: 0
|
| 193 |
+
sampling_steps:
|
| 194 |
+
value: 2000
|
| 195 |
+
seed:
|
| 196 |
+
value: 0
|
| 197 |
+
semantic_features_dir:
|
| 198 |
+
value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
|
| 199 |
+
t_c:
|
| 200 |
+
value: 0.5
|
| 201 |
+
weighting:
|
| 202 |
+
value: uniform
|
REG/wandb/run-20260322_150022-yhxc5cgu/files/wandb-summary.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"_wandb":{"runtime":2},"_runtime":2}
|
REG/wandb/run-20260322_150443-e3yw9ii4/files/config.yaml
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_wandb:
|
| 2 |
+
value:
|
| 3 |
+
cli_version: 0.25.0
|
| 4 |
+
e:
|
| 5 |
+
q63x26q8nhayytv8q2rrmj9j9uy9kvub:
|
| 6 |
+
args:
|
| 7 |
+
- --report-to
|
| 8 |
+
- wandb
|
| 9 |
+
- --allow-tf32
|
| 10 |
+
- --mixed-precision
|
| 11 |
+
- bf16
|
| 12 |
+
- --seed
|
| 13 |
+
- "0"
|
| 14 |
+
- --path-type
|
| 15 |
+
- linear
|
| 16 |
+
- --prediction
|
| 17 |
+
- v
|
| 18 |
+
- --weighting
|
| 19 |
+
- uniform
|
| 20 |
+
- --model
|
| 21 |
+
- SiT-XL/2
|
| 22 |
+
- --enc-type
|
| 23 |
+
- dinov2-vit-b
|
| 24 |
+
- --encoder-depth
|
| 25 |
+
- "8"
|
| 26 |
+
- --proj-coeff
|
| 27 |
+
- "0.5"
|
| 28 |
+
- --output-dir
|
| 29 |
+
- exps
|
| 30 |
+
- --exp-name
|
| 31 |
+
- jsflow-experiment
|
| 32 |
+
- --batch-size
|
| 33 |
+
- "256"
|
| 34 |
+
- --data-dir
|
| 35 |
+
- /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
|
| 36 |
+
- --semantic-features-dir
|
| 37 |
+
- /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
|
| 38 |
+
- --learning-rate
|
| 39 |
+
- "0.00005"
|
| 40 |
+
- --t-c
|
| 41 |
+
- "0.5"
|
| 42 |
+
- --cls
|
| 43 |
+
- "0.2"
|
| 44 |
+
- --ot-cls
|
| 45 |
+
codePath: train.py
|
| 46 |
+
codePathLocal: train.py
|
| 47 |
+
cpu_count: 96
|
| 48 |
+
cpu_count_logical: 192
|
| 49 |
+
cudaVersion: "13.0"
|
| 50 |
+
disk:
|
| 51 |
+
/:
|
| 52 |
+
total: "3838880616448"
|
| 53 |
+
used: "357557714944"
|
| 54 |
+
email: 2365972933@qq.com
|
| 55 |
+
executable: /gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python
|
| 56 |
+
git:
|
| 57 |
+
commit: 021ea2e50c38c5803bd9afff16316958a01fbd1d
|
| 58 |
+
remote: https://github.com/Martinser/REG.git
|
| 59 |
+
gpu: NVIDIA H100 80GB HBM3
|
| 60 |
+
gpu_count: 4
|
| 61 |
+
gpu_nvidia:
|
| 62 |
+
- architecture: Hopper
|
| 63 |
+
cudaCores: 16896
|
| 64 |
+
memoryTotal: "85520809984"
|
| 65 |
+
name: NVIDIA H100 80GB HBM3
|
| 66 |
+
uuid: GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc
|
| 67 |
+
- architecture: Hopper
|
| 68 |
+
cudaCores: 16896
|
| 69 |
+
memoryTotal: "85520809984"
|
| 70 |
+
name: NVIDIA H100 80GB HBM3
|
| 71 |
+
uuid: GPU-a09f2421-99e6-a72e-63bd-fd7452510758
|
| 72 |
+
- architecture: Hopper
|
| 73 |
+
cudaCores: 16896
|
| 74 |
+
memoryTotal: "85520809984"
|
| 75 |
+
name: NVIDIA H100 80GB HBM3
|
| 76 |
+
uuid: GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d
|
| 77 |
+
- architecture: Hopper
|
| 78 |
+
cudaCores: 16896
|
| 79 |
+
memoryTotal: "85520809984"
|
| 80 |
+
name: NVIDIA H100 80GB HBM3
|
| 81 |
+
uuid: GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e
|
| 82 |
+
host: 24c964746905d416ce09d045f9a06f23-taskrole1-0
|
| 83 |
+
memory:
|
| 84 |
+
total: "2164115296256"
|
| 85 |
+
os: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
|
| 86 |
+
program: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py
|
| 87 |
+
python: CPython 3.12.9
|
| 88 |
+
root: /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG
|
| 89 |
+
startedAt: "2026-03-22T07:04:43.133739Z"
|
| 90 |
+
writerId: q63x26q8nhayytv8q2rrmj9j9uy9kvub
|
| 91 |
+
m: []
|
| 92 |
+
python_version: 3.12.9
|
| 93 |
+
t:
|
| 94 |
+
"1":
|
| 95 |
+
- 1
|
| 96 |
+
- 5
|
| 97 |
+
- 11
|
| 98 |
+
- 41
|
| 99 |
+
- 49
|
| 100 |
+
- 53
|
| 101 |
+
- 63
|
| 102 |
+
- 71
|
| 103 |
+
- 83
|
| 104 |
+
- 98
|
| 105 |
+
"2":
|
| 106 |
+
- 1
|
| 107 |
+
- 5
|
| 108 |
+
- 11
|
| 109 |
+
- 41
|
| 110 |
+
- 49
|
| 111 |
+
- 53
|
| 112 |
+
- 63
|
| 113 |
+
- 71
|
| 114 |
+
- 83
|
| 115 |
+
- 98
|
| 116 |
+
"3":
|
| 117 |
+
- 13
|
| 118 |
+
"4": 3.12.9
|
| 119 |
+
"5": 0.25.0
|
| 120 |
+
"6": 4.53.2
|
| 121 |
+
"12": 0.25.0
|
| 122 |
+
"13": linux-x86_64
|
| 123 |
+
adam_beta1:
|
| 124 |
+
value: 0.9
|
| 125 |
+
adam_beta2:
|
| 126 |
+
value: 0.999
|
| 127 |
+
adam_epsilon:
|
| 128 |
+
value: 1e-08
|
| 129 |
+
adam_weight_decay:
|
| 130 |
+
value: 0
|
| 131 |
+
allow_tf32:
|
| 132 |
+
value: true
|
| 133 |
+
batch_size:
|
| 134 |
+
value: 256
|
| 135 |
+
cfg_prob:
|
| 136 |
+
value: 0.1
|
| 137 |
+
checkpointing_steps:
|
| 138 |
+
value: 10000
|
| 139 |
+
cls:
|
| 140 |
+
value: 0.2
|
| 141 |
+
data_dir:
|
| 142 |
+
value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256
|
| 143 |
+
enc_type:
|
| 144 |
+
value: dinov2-vit-b
|
| 145 |
+
encoder_depth:
|
| 146 |
+
value: 8
|
| 147 |
+
epochs:
|
| 148 |
+
value: 1400
|
| 149 |
+
exp_name:
|
| 150 |
+
value: jsflow-experiment
|
| 151 |
+
fused_attn:
|
| 152 |
+
value: true
|
| 153 |
+
gradient_accumulation_steps:
|
| 154 |
+
value: 1
|
| 155 |
+
learning_rate:
|
| 156 |
+
value: 5e-05
|
| 157 |
+
legacy:
|
| 158 |
+
value: false
|
| 159 |
+
logging_dir:
|
| 160 |
+
value: logs
|
| 161 |
+
max_grad_norm:
|
| 162 |
+
value: 1
|
| 163 |
+
max_train_steps:
|
| 164 |
+
value: 1000000
|
| 165 |
+
mixed_precision:
|
| 166 |
+
value: bf16
|
| 167 |
+
model:
|
| 168 |
+
value: SiT-XL/2
|
| 169 |
+
num_classes:
|
| 170 |
+
value: 1000
|
| 171 |
+
num_workers:
|
| 172 |
+
value: 4
|
| 173 |
+
ops_head:
|
| 174 |
+
value: 16
|
| 175 |
+
ot_cls:
|
| 176 |
+
value: true
|
| 177 |
+
output_dir:
|
| 178 |
+
value: exps
|
| 179 |
+
path_type:
|
| 180 |
+
value: linear
|
| 181 |
+
prediction:
|
| 182 |
+
value: v
|
| 183 |
+
proj_coeff:
|
| 184 |
+
value: 0.5
|
| 185 |
+
qk_norm:
|
| 186 |
+
value: false
|
| 187 |
+
report_to:
|
| 188 |
+
value: wandb
|
| 189 |
+
resolution:
|
| 190 |
+
value: 256
|
| 191 |
+
resume_step:
|
| 192 |
+
value: 0
|
| 193 |
+
sampling_steps:
|
| 194 |
+
value: 2000
|
| 195 |
+
seed:
|
| 196 |
+
value: 0
|
| 197 |
+
semantic_features_dir:
|
| 198 |
+
value: /gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0
|
| 199 |
+
t_c:
|
| 200 |
+
value: 0.5
|
| 201 |
+
weighting:
|
| 202 |
+
value: uniform
|
REG/wandb/run-20260322_150443-e3yw9ii4/files/output.log
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Steps: 0%| | 1/1000000 [00:02<588:29:38, 2.12s/it][[34m2026-03-22 15:04:48[0m] Generating EMA samples for evaluation (SDE → t=0)...
|
| 2 |
+
Traceback (most recent call last):
|
| 3 |
+
File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 572, in <module>
|
| 4 |
+
main(args)
|
| 5 |
+
File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 444, in main
|
| 6 |
+
if vae is not None:
|
| 7 |
+
^^^
|
| 8 |
+
NameError: name 'vae' is not defined
|
| 9 |
+
[rank0]: Traceback (most recent call last):
|
| 10 |
+
[rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 572, in <module>
|
| 11 |
+
[rank0]: main(args)
|
| 12 |
+
[rank0]: File "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py", line 444, in main
|
| 13 |
+
[rank0]: if vae is not None:
|
| 14 |
+
[rank0]: ^^^
|
| 15 |
+
[rank0]: NameError: name 'vae' is not defined
|
REG/wandb/run-20260322_150443-e3yw9ii4/files/requirements.txt
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dill==0.3.8
|
| 2 |
+
mkl-service==2.4.0
|
| 3 |
+
mpmath==1.3.0
|
| 4 |
+
typing_extensions==4.12.2
|
| 5 |
+
urllib3==2.3.0
|
| 6 |
+
torch==2.5.1
|
| 7 |
+
ptyprocess==0.7.0
|
| 8 |
+
traitlets==5.14.3
|
| 9 |
+
pyasn1==0.6.1
|
| 10 |
+
opencv-python-headless==4.12.0.88
|
| 11 |
+
nest-asyncio==1.6.0
|
| 12 |
+
kiwisolver==1.4.8
|
| 13 |
+
click==8.2.1
|
| 14 |
+
fire==0.7.1
|
| 15 |
+
diffusers==0.35.1
|
| 16 |
+
accelerate==1.7.0
|
| 17 |
+
ipykernel==6.29.5
|
| 18 |
+
peft==0.17.1
|
| 19 |
+
attrs==24.3.0
|
| 20 |
+
six==1.17.0
|
| 21 |
+
numpy==2.0.1
|
| 22 |
+
yarl==1.18.0
|
| 23 |
+
huggingface_hub==0.34.4
|
| 24 |
+
Bottleneck==1.4.2
|
| 25 |
+
numexpr==2.11.0
|
| 26 |
+
dataclasses==0.6
|
| 27 |
+
typing-inspection==0.4.1
|
| 28 |
+
safetensors==0.5.3
|
| 29 |
+
pyparsing==3.2.3
|
| 30 |
+
psutil==7.0.0
|
| 31 |
+
imageio==2.37.0
|
| 32 |
+
debugpy==1.8.14
|
| 33 |
+
cycler==0.12.1
|
| 34 |
+
pyasn1_modules==0.4.2
|
| 35 |
+
matplotlib-inline==0.1.7
|
| 36 |
+
matplotlib==3.10.3
|
| 37 |
+
jedi==0.19.2
|
| 38 |
+
tokenizers==0.21.2
|
| 39 |
+
seaborn==0.13.2
|
| 40 |
+
timm==1.0.15
|
| 41 |
+
aiohappyeyeballs==2.6.1
|
| 42 |
+
hf-xet==1.1.8
|
| 43 |
+
multidict==6.1.0
|
| 44 |
+
tqdm==4.67.1
|
| 45 |
+
wheel==0.45.1
|
| 46 |
+
simsimd==6.5.1
|
| 47 |
+
sentencepiece==0.2.1
|
| 48 |
+
grpcio==1.74.0
|
| 49 |
+
asttokens==3.0.0
|
| 50 |
+
absl-py==2.3.1
|
| 51 |
+
stack-data==0.6.3
|
| 52 |
+
pandas==2.3.0
|
| 53 |
+
importlib_metadata==8.7.0
|
| 54 |
+
pytorch-image-generation-metrics==0.6.1
|
| 55 |
+
frozenlist==1.5.0
|
| 56 |
+
MarkupSafe==3.0.2
|
| 57 |
+
setuptools==78.1.1
|
| 58 |
+
multiprocess==0.70.15
|
| 59 |
+
pip==25.1
|
| 60 |
+
requests==2.32.3
|
| 61 |
+
mkl_random==1.2.8
|
| 62 |
+
tensorboard-plugin-wit==1.8.1
|
| 63 |
+
ExifRead-nocycle==3.0.1
|
| 64 |
+
webdataset==0.2.111
|
| 65 |
+
threadpoolctl==3.6.0
|
| 66 |
+
pyarrow==21.0.0
|
| 67 |
+
executing==2.2.0
|
| 68 |
+
decorator==5.2.1
|
| 69 |
+
contourpy==1.3.2
|
| 70 |
+
annotated-types==0.7.0
|
| 71 |
+
scikit-learn==1.7.1
|
| 72 |
+
jupyter_client==8.6.3
|
| 73 |
+
albumentations==1.4.24
|
| 74 |
+
wandb==0.25.0
|
| 75 |
+
certifi==2025.8.3
|
| 76 |
+
idna==3.7
|
| 77 |
+
xxhash==3.5.0
|
| 78 |
+
Jinja2==3.1.6
|
| 79 |
+
python-dateutil==2.9.0.post0
|
| 80 |
+
aiosignal==1.4.0
|
| 81 |
+
triton==3.1.0
|
| 82 |
+
torchvision==0.20.1
|
| 83 |
+
stringzilla==3.12.6
|
| 84 |
+
pure_eval==0.2.3
|
| 85 |
+
braceexpand==0.1.7
|
| 86 |
+
zipp==3.22.0
|
| 87 |
+
oauthlib==3.3.1
|
| 88 |
+
Markdown==3.8.2
|
| 89 |
+
fsspec==2025.3.0
|
| 90 |
+
fonttools==4.58.2
|
| 91 |
+
comm==0.2.2
|
| 92 |
+
ipython==9.3.0
|
| 93 |
+
img2dataset==1.47.0
|
| 94 |
+
networkx==3.4.2
|
| 95 |
+
PySocks==1.7.1
|
| 96 |
+
tzdata==2025.2
|
| 97 |
+
smmap==5.0.2
|
| 98 |
+
mkl_fft==1.3.11
|
| 99 |
+
sentry-sdk==2.29.1
|
| 100 |
+
Pygments==2.19.1
|
| 101 |
+
pexpect==4.9.0
|
| 102 |
+
ftfy==6.3.1
|
| 103 |
+
einops==0.8.1
|
| 104 |
+
requests-oauthlib==2.0.0
|
| 105 |
+
gitdb==4.0.12
|
| 106 |
+
albucore==0.0.23
|
| 107 |
+
torchdiffeq==0.2.5
|
| 108 |
+
GitPython==3.1.44
|
| 109 |
+
bitsandbytes==0.47.0
|
| 110 |
+
pytorch-fid==0.3.0
|
| 111 |
+
clean-fid==0.1.35
|
| 112 |
+
pytorch-gan-metrics==0.5.4
|
| 113 |
+
Brotli==1.0.9
|
| 114 |
+
charset-normalizer==3.3.2
|
| 115 |
+
gmpy2==2.2.1
|
| 116 |
+
pillow==11.1.0
|
| 117 |
+
PyYAML==6.0.2
|
| 118 |
+
tornado==6.5.1
|
| 119 |
+
termcolor==3.1.0
|
| 120 |
+
setproctitle==1.3.6
|
| 121 |
+
scipy==1.15.3
|
| 122 |
+
regex==2024.11.6
|
| 123 |
+
protobuf==6.31.1
|
| 124 |
+
platformdirs==4.3.8
|
| 125 |
+
joblib==1.5.1
|
| 126 |
+
cachetools==4.2.4
|
| 127 |
+
ipython_pygments_lexers==1.1.1
|
| 128 |
+
google-auth==1.35.0
|
| 129 |
+
transformers==4.53.2
|
| 130 |
+
torch-fidelity==0.3.0
|
| 131 |
+
tensorboard==2.4.0
|
| 132 |
+
filelock==3.17.0
|
| 133 |
+
packaging==25.0
|
| 134 |
+
propcache==0.3.1
|
| 135 |
+
pytz==2025.2
|
| 136 |
+
aiohttp==3.11.10
|
| 137 |
+
wcwidth==0.2.13
|
| 138 |
+
clip==0.2.0
|
| 139 |
+
Werkzeug==3.1.3
|
| 140 |
+
tensorboard-data-server==0.6.1
|
| 141 |
+
sympy==1.13.1
|
| 142 |
+
pyzmq==26.4.0
|
| 143 |
+
pydantic_core==2.33.2
|
| 144 |
+
prompt_toolkit==3.0.51
|
| 145 |
+
parso==0.8.4
|
| 146 |
+
docker-pycreds==0.4.0
|
| 147 |
+
rsa==4.9.1
|
| 148 |
+
pydantic==2.11.5
|
| 149 |
+
jupyter_core==5.8.1
|
| 150 |
+
google-auth-oauthlib==0.4.6
|
| 151 |
+
datasets==4.0.0
|
| 152 |
+
torch-tb-profiler==0.4.3
|
| 153 |
+
autocommand==2.2.2
|
| 154 |
+
backports.tarfile==1.2.0
|
| 155 |
+
importlib_metadata==8.0.0
|
| 156 |
+
jaraco.collections==5.1.0
|
| 157 |
+
jaraco.context==5.3.0
|
| 158 |
+
jaraco.functools==4.0.1
|
| 159 |
+
more-itertools==10.3.0
|
| 160 |
+
packaging==24.2
|
| 161 |
+
platformdirs==4.2.2
|
| 162 |
+
typeguard==4.3.0
|
| 163 |
+
inflect==7.3.1
|
| 164 |
+
jaraco.text==3.12.1
|
| 165 |
+
tomli==2.0.1
|
| 166 |
+
typing_extensions==4.12.2
|
| 167 |
+
wheel==0.45.1
|
| 168 |
+
zipp==3.19.2
|
REG/wandb/run-20260322_150443-e3yw9ii4/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
|
| 3 |
+
"python": "CPython 3.12.9",
|
| 4 |
+
"startedAt": "2026-03-22T07:04:43.133739Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--report-to",
|
| 7 |
+
"wandb",
|
| 8 |
+
"--allow-tf32",
|
| 9 |
+
"--mixed-precision",
|
| 10 |
+
"bf16",
|
| 11 |
+
"--seed",
|
| 12 |
+
"0",
|
| 13 |
+
"--path-type",
|
| 14 |
+
"linear",
|
| 15 |
+
"--prediction",
|
| 16 |
+
"v",
|
| 17 |
+
"--weighting",
|
| 18 |
+
"uniform",
|
| 19 |
+
"--model",
|
| 20 |
+
"SiT-XL/2",
|
| 21 |
+
"--enc-type",
|
| 22 |
+
"dinov2-vit-b",
|
| 23 |
+
"--encoder-depth",
|
| 24 |
+
"8",
|
| 25 |
+
"--proj-coeff",
|
| 26 |
+
"0.5",
|
| 27 |
+
"--output-dir",
|
| 28 |
+
"exps",
|
| 29 |
+
"--exp-name",
|
| 30 |
+
"jsflow-experiment",
|
| 31 |
+
"--batch-size",
|
| 32 |
+
"256",
|
| 33 |
+
"--data-dir",
|
| 34 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
|
| 35 |
+
"--semantic-features-dir",
|
| 36 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
|
| 37 |
+
"--learning-rate",
|
| 38 |
+
"0.00005",
|
| 39 |
+
"--t-c",
|
| 40 |
+
"0.5",
|
| 41 |
+
"--cls",
|
| 42 |
+
"0.2",
|
| 43 |
+
"--ot-cls"
|
| 44 |
+
],
|
| 45 |
+
"program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
|
| 46 |
+
"codePath": "train.py",
|
| 47 |
+
"codePathLocal": "train.py",
|
| 48 |
+
"git": {
|
| 49 |
+
"remote": "https://github.com/Martinser/REG.git",
|
| 50 |
+
"commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
|
| 51 |
+
},
|
| 52 |
+
"email": "2365972933@qq.com",
|
| 53 |
+
"root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
|
| 54 |
+
"host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
|
| 55 |
+
"executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
|
| 56 |
+
"cpu_count": 96,
|
| 57 |
+
"cpu_count_logical": 192,
|
| 58 |
+
"gpu": "NVIDIA H100 80GB HBM3",
|
| 59 |
+
"gpu_count": 4,
|
| 60 |
+
"disk": {
|
| 61 |
+
"/": {
|
| 62 |
+
"total": "3838880616448",
|
| 63 |
+
"used": "357557714944"
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"memory": {
|
| 67 |
+
"total": "2164115296256"
|
| 68 |
+
},
|
| 69 |
+
"gpu_nvidia": [
|
| 70 |
+
{
|
| 71 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 72 |
+
"memoryTotal": "85520809984",
|
| 73 |
+
"cudaCores": 16896,
|
| 74 |
+
"architecture": "Hopper",
|
| 75 |
+
"uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 79 |
+
"memoryTotal": "85520809984",
|
| 80 |
+
"cudaCores": 16896,
|
| 81 |
+
"architecture": "Hopper",
|
| 82 |
+
"uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 86 |
+
"memoryTotal": "85520809984",
|
| 87 |
+
"cudaCores": 16896,
|
| 88 |
+
"architecture": "Hopper",
|
| 89 |
+
"uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 93 |
+
"memoryTotal": "85520809984",
|
| 94 |
+
"cudaCores": 16896,
|
| 95 |
+
"architecture": "Hopper",
|
| 96 |
+
"uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"cudaVersion": "13.0",
|
| 100 |
+
"writerId": "q63x26q8nhayytv8q2rrmj9j9uy9kvub"
|
| 101 |
+
}
|
REG/wandb/run-20260322_150443-e3yw9ii4/files/wandb-summary.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"_wandb":{"runtime":3},"_runtime":3}
|
REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-03-22T15:04:43.390486873+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
|
| 2 |
+
{"time":"2026-03-22T15:04:44.970687851+08:00","level":"INFO","msg":"stream: created new stream","id":"e3yw9ii4"}
|
| 3 |
+
{"time":"2026-03-22T15:04:44.970802178+08:00","level":"INFO","msg":"handler: started","stream_id":"e3yw9ii4"}
|
| 4 |
+
{"time":"2026-03-22T15:04:44.971744065+08:00","level":"INFO","msg":"stream: started","id":"e3yw9ii4"}
|
| 5 |
+
{"time":"2026-03-22T15:04:44.97174913+08:00","level":"INFO","msg":"writer: started","stream_id":"e3yw9ii4"}
|
| 6 |
+
{"time":"2026-03-22T15:04:44.971758857+08:00","level":"INFO","msg":"sender: started","stream_id":"e3yw9ii4"}
|
| 7 |
+
{"time":"2026-03-22T15:04:50.286711145+08:00","level":"INFO","msg":"stream: closing","id":"e3yw9ii4"}
|
REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug.log
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-22 15:04:43,155 INFO MainThread:326012 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
|
| 2 |
+
2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_setup.py:_flush():81] Configure stats pid to 326012
|
| 3 |
+
2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug.log
|
| 5 |
+
2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150443-e3yw9ii4/logs/debug-internal.log
|
| 6 |
+
2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:init():844] calling init triggers
|
| 7 |
+
2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'_wandb': {}}
|
| 9 |
+
2026-03-22 15:04:43,156 INFO MainThread:326012 [wandb_init.py:init():892] starting backend
|
| 10 |
+
2026-03-22 15:04:43,378 INFO MainThread:326012 [wandb_init.py:init():895] sending inform_init request
|
| 11 |
+
2026-03-22 15:04:43,388 INFO MainThread:326012 [wandb_init.py:init():903] backend started and connected
|
| 12 |
+
2026-03-22 15:04:43,389 INFO MainThread:326012 [wandb_init.py:init():973] updated telemetry
|
| 13 |
+
2026-03-22 15:04:43,402 INFO MainThread:326012 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-03-22 15:04:46,450 INFO MainThread:326012 [wandb_init.py:init():1042] starting run threads in backend
|
| 15 |
+
2026-03-22 15:04:46,541 INFO MainThread:326012 [wandb_run.py:_console_start():2524] atexit reg
|
| 16 |
+
2026-03-22 15:04:46,541 INFO MainThread:326012 [wandb_run.py:_redirect():2373] redirect: wrap_raw
|
| 17 |
+
2026-03-22 15:04:46,541 INFO MainThread:326012 [wandb_run.py:_redirect():2442] Wrapping output streams.
|
| 18 |
+
2026-03-22 15:04:46,541 INFO MainThread:326012 [wandb_run.py:_redirect():2465] Redirects installed.
|
| 19 |
+
2026-03-22 15:04:46,545 INFO MainThread:326012 [wandb_init.py:init():1082] run started, returning control to user process
|
| 20 |
+
2026-03-22 15:04:46,545 INFO MainThread:326012 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
|
| 21 |
+
2026-03-22 15:04:50,286 INFO wandb-AsyncioManager-main:326012 [service_client.py:_forward_responses():134] Reached EOF.
|
| 22 |
+
2026-03-22 15:04:50,286 INFO wandb-AsyncioManager-main:326012 [mailbox.py:close():155] Closing mailbox, abandoning 1 handles.
|
REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-03-22T15:06:35.605542246+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
|
| 2 |
+
{"time":"2026-03-22T15:06:37.585674571+08:00","level":"INFO","msg":"stream: created new stream","id":"o2w3z8rq"}
|
| 3 |
+
{"time":"2026-03-22T15:06:37.585934805+08:00","level":"INFO","msg":"handler: started","stream_id":"o2w3z8rq"}
|
| 4 |
+
{"time":"2026-03-22T15:06:37.586954142+08:00","level":"INFO","msg":"stream: started","id":"o2w3z8rq"}
|
| 5 |
+
{"time":"2026-03-22T15:06:37.587002572+08:00","level":"INFO","msg":"sender: started","stream_id":"o2w3z8rq"}
|
| 6 |
+
{"time":"2026-03-22T15:06:37.58696296+08:00","level":"INFO","msg":"writer: started","stream_id":"o2w3z8rq"}
|
REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug.log
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
|
| 2 |
+
2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_setup.py:_flush():81] Configure stats pid to 328110
|
| 3 |
+
2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug.log
|
| 5 |
+
2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260322_150635-o2w3z8rq/logs/debug-internal.log
|
| 6 |
+
2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:init():844] calling init triggers
|
| 7 |
+
2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'_wandb': {}}
|
| 9 |
+
2026-03-22 15:06:35,364 INFO MainThread:328110 [wandb_init.py:init():892] starting backend
|
| 10 |
+
2026-03-22 15:06:35,590 INFO MainThread:328110 [wandb_init.py:init():895] sending inform_init request
|
| 11 |
+
2026-03-22 15:06:35,601 INFO MainThread:328110 [wandb_init.py:init():903] backend started and connected
|
| 12 |
+
2026-03-22 15:06:35,604 INFO MainThread:328110 [wandb_init.py:init():973] updated telemetry
|
| 13 |
+
2026-03-22 15:06:35,618 INFO MainThread:328110 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-03-22 15:06:38,166 INFO MainThread:328110 [wandb_init.py:init():1042] starting run threads in backend
|
| 15 |
+
2026-03-22 15:06:38,258 INFO MainThread:328110 [wandb_run.py:_console_start():2524] atexit reg
|
| 16 |
+
2026-03-22 15:06:38,258 INFO MainThread:328110 [wandb_run.py:_redirect():2373] redirect: wrap_raw
|
| 17 |
+
2026-03-22 15:06:38,259 INFO MainThread:328110 [wandb_run.py:_redirect():2442] Wrapping output streams.
|
| 18 |
+
2026-03-22 15:06:38,259 INFO MainThread:328110 [wandb_run.py:_redirect():2465] Redirects installed.
|
| 19 |
+
2026-03-22 15:06:38,262 INFO MainThread:328110 [wandb_init.py:init():1082] run started, returning control to user process
|
| 20 |
+
2026-03-22 15:06:38,263 INFO MainThread:328110 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.5, 'ot_cls': True}
|
REG/wandb/run-20260323_135607-zue1y2ba/files/output.log
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Steps: 0%| | 1/1000000 [00:02<603:07:37, 2.17s/it][[34m2026-03-23 13:56:13[0m] Generating EMA samples (Euler-Maruyama; t≈0.75 → t=0)...
|
| 2 |
+
[[34m2026-03-23 13:56:16[0m] Step: 1, Training Logs: loss_final: 4.886707, loss_mean: 1.706308, proj_loss: 0.001541, loss_mean_cls: 3.178859, grad_norm: 1.484174
|
| 3 |
+
Steps: 0%| | 2/1000000 [00:04<669:29:28, 2.41s/it, grad_norm=1.48, loss_final=4.89, loss_mean=1.71, loss_mean_cls=3.18, proj_loss=0.00154][[34m2026-03-23 13:56:16[0m] Step: 2, Training Logs: loss_final: 4.294325, loss_mean: 1.689698, proj_loss: -0.010266, loss_mean_cls: 2.614893, grad_norm: 1.062257
|
| 4 |
+
Steps: 0%| | 3/1000000 [00:04<394:40:41, 1.42s/it, grad_norm=1.06, loss_final=4.29, loss_mean=1.69, loss_mean_cls=2.61, proj_loss=-0.0103][[34m2026-03-23 13:56:16[0m] Step: 3, Training Logs: loss_final: 4.829489, loss_mean: 1.666703, proj_loss: -0.019215, loss_mean_cls: 3.182001, grad_norm: 1.115940
|
| 5 |
+
Steps: 0%| | 4/1000000 [00:05<265:36:26, 1.05it/s, grad_norm=1.12, loss_final=4.83, loss_mean=1.67, loss_mean_cls=3.18, proj_loss=-0.0192][[34m2026-03-23 13:56:16[0m] Step: 4, Training Logs: loss_final: 4.838729, loss_mean: 1.683388, proj_loss: -0.026348, loss_mean_cls: 3.181689, grad_norm: 0.753645
|
| 6 |
+
Steps: 0%| | 5/1000000 [00:05<194:13:56, 1.43it/s, grad_norm=0.754, loss_final=4.84, loss_mean=1.68, loss_mean_cls=3.18, proj_loss=-0.0263][[34m2026-03-23 13:56:17[0m] Step: 5, Training Logs: loss_final: 4.434019, loss_mean: 1.678989, proj_loss: -0.034618, loss_mean_cls: 2.789647, grad_norm: 0.828005
|
| 7 |
+
Steps: 0%| | 6/1000000 [00:05<151:13:33, 1.84it/s, grad_norm=0.828, loss_final=4.43, loss_mean=1.68, loss_mean_cls=2.79, proj_loss=-0.0346][[34m2026-03-23 13:56:17[0m] Step: 6, Training Logs: loss_final: 4.717834, loss_mean: 1.685299, proj_loss: -0.039478, loss_mean_cls: 3.072013, grad_norm: 0.944329
|
| 8 |
+
Steps: 0%| | 7/1000000 [00:05<123:56:37, 2.24it/s, grad_norm=0.944, loss_final=4.72, loss_mean=1.69, loss_mean_cls=3.07, proj_loss=-0.0395][[34m2026-03-23 13:56:17[0m] Step: 7, Training Logs: loss_final: 4.768410, loss_mean: 1.687585, proj_loss: -0.042706, loss_mean_cls: 3.123530, grad_norm: 0.827338
|
| 9 |
+
Steps: 0%| | 8/1000000 [00:06<106:03:40, 2.62it/s, grad_norm=0.827, loss_final=4.77, loss_mean=1.69, loss_mean_cls=3.12, proj_loss=-0.0427][[34m2026-03-23 13:56:17[0m] Step: 8, Training Logs: loss_final: 4.815378, loss_mean: 1.655840, proj_loss: -0.044987, loss_mean_cls: 3.204525, grad_norm: 0.854108
|
| 10 |
+
Steps: 0%| | 9/1000000 [00:06<94:05:27, 2.95it/s, grad_norm=0.854, loss_final=4.82, loss_mean=1.66, loss_mean_cls=3.2, proj_loss=-0.045] [[34m2026-03-23 13:56:18[0m] Step: 9, Training Logs: loss_final: 4.777181, loss_mean: 1.657609, proj_loss: -0.046041, loss_mean_cls: 3.165613, grad_norm: 0.939216
|
| 11 |
+
Steps: 0%| | 10/1000000 [00:06<86:16:44, 3.22it/s, grad_norm=0.939, loss_final=4.78, loss_mean=1.66, loss_mean_cls=3.17, proj_loss=-0.046][[34m2026-03-23 13:56:18[0m] Step: 10, Training Logs: loss_final: 4.966327, loss_mean: 1.662955, proj_loss: -0.047787, loss_mean_cls: 3.351158, grad_norm: 0.991139
|
| 12 |
+
Steps: 0%| | 11/1000000 [00:06<80:53:45, 3.43it/s, grad_norm=0.991, loss_final=4.97, loss_mean=1.66, loss_mean_cls=3.35, proj_loss=-0.0478][[34m2026-03-23 13:56:18[0m] Step: 11, Training Logs: loss_final: 5.292466, loss_mean: 1.650309, proj_loss: -0.049361, loss_mean_cls: 3.691518, grad_norm: 0.968609
|
| 13 |
+
Steps: 0%| | 12/1000000 [00:07<76:53:52, 3.61it/s, grad_norm=0.969, loss_final=5.29, loss_mean=1.65, loss_mean_cls=3.69, proj_loss=-0.0494][[34m2026-03-23 13:56:18[0m] Step: 12, Training Logs: loss_final: 5.088713, loss_mean: 1.626424, proj_loss: -0.049838, loss_mean_cls: 3.512127, grad_norm: 1.162980
|
| 14 |
+
Steps: 0%| | 13/1000000 [00:07<74:06:27, 3.75it/s, grad_norm=1.16, loss_final=5.09, loss_mean=1.63, loss_mean_cls=3.51, proj_loss=-0.0498][[34m2026-03-23 13:56:19[0m] Step: 13, Training Logs: loss_final: 4.928039, loss_mean: 1.623318, proj_loss: -0.051428, loss_mean_cls: 3.356148, grad_norm: 1.370081
|
| 15 |
+
Steps: 0%| | 14/1000000 [00:07<72:19:32, 3.84it/s, grad_norm=1.37, loss_final=4.93, loss_mean=1.62, loss_mean_cls=3.36, proj_loss=-0.0514][[34m2026-03-23 13:56:19[0m] Step: 14, Training Logs: loss_final: 4.342262, loss_mean: 1.599151, proj_loss: -0.051557, loss_mean_cls: 2.794668, grad_norm: 1.247963
|
| 16 |
+
Steps: 0%| | 15/1000000 [00:07<70:56:55, 3.92it/s, grad_norm=1.25, loss_final=4.34, loss_mean=1.6, loss_mean_cls=2.79, proj_loss=-0.0516][[34m2026-03-23 13:56:19[0m] Step: 15, Training Logs: loss_final: 5.220107, loss_mean: 1.589133, proj_loss: -0.054820, loss_mean_cls: 3.685793, grad_norm: 1.207489
|
| 17 |
+
Steps: 0%| | 16/1000000 [00:08<69:57:27, 3.97it/s, grad_norm=1.21, loss_final=5.22, loss_mean=1.59, loss_mean_cls=3.69, proj_loss=-0.0548][[34m2026-03-23 13:56:19[0m] Step: 16, Training Logs: loss_final: 4.653113, loss_mean: 1.599458, proj_loss: -0.052329, loss_mean_cls: 3.105984, grad_norm: 1.168053
|
| 18 |
+
Steps: 0%| | 17/1000000 [00:08<69:18:40, 4.01it/s, grad_norm=1.17, loss_final=4.65, loss_mean=1.6, loss_mean_cls=3.11, proj_loss=-0.0523][[34m2026-03-23 13:56:20[0m] Step: 17, Training Logs: loss_final: 4.984623, loss_mean: 1.639365, proj_loss: -0.051690, loss_mean_cls: 3.396948, grad_norm: 2.396022
|
| 19 |
+
Steps: 0%| | 18/1000000 [00:08<69:00:10, 4.03it/s, grad_norm=2.4, loss_final=4.98, loss_mean=1.64, loss_mean_cls=3.4, proj_loss=-0.0517][[34m2026-03-23 13:56:20[0m] Step: 18, Training Logs: loss_final: 4.356441, loss_mean: 1.560717, proj_loss: -0.054883, loss_mean_cls: 2.850607, grad_norm: 0.931938
|
| 20 |
+
Steps: 0%| | 19/1000000 [00:08<68:37:45, 4.05it/s, grad_norm=0.932, loss_final=4.36, loss_mean=1.56, loss_mean_cls=2.85, proj_loss=-0.0549][[34m2026-03-23 13:56:20[0m] Step: 19, Training Logs: loss_final: 4.922860, loss_mean: 1.625848, proj_loss: -0.055209, loss_mean_cls: 3.352221, grad_norm: 2.649956
|
| 21 |
+
Steps: 0%| | 20/1000000 [00:09<68:31:08, 4.05it/s, grad_norm=2.65, loss_final=4.92, loss_mean=1.63, loss_mean_cls=3.35, proj_loss=-0.0552][[34m2026-03-23 13:56:20[0m] Step: 20, Training Logs: loss_final: 4.693020, loss_mean: 1.635111, proj_loss: -0.053463, loss_mean_cls: 3.111372, grad_norm: 2.886806
|
| 22 |
+
Steps: 0%| | 21/1000000 [00:09<68:17:30, 4.07it/s, grad_norm=2.89, loss_final=4.69, loss_mean=1.64, loss_mean_cls=3.11, proj_loss=-0.0535][[34m2026-03-23 13:56:21[0m] Step: 21, Training Logs: loss_final: 4.847960, loss_mean: 1.608342, proj_loss: -0.053926, loss_mean_cls: 3.293545, grad_norm: 1.804076
|
| 23 |
+
Steps: 0%| | 22/1000000 [00:09<68:07:59, 4.08it/s, grad_norm=1.8, loss_final=4.85, loss_mean=1.61, loss_mean_cls=3.29, proj_loss=-0.0539][[34m2026-03-23 13:56:21[0m] Step: 22, Training Logs: loss_final: 4.531857, loss_mean: 1.556872, proj_loss: -0.053500, loss_mean_cls: 3.028484, grad_norm: 1.136153
|
| 24 |
+
Steps: 0%| | 23/1000000 [00:09<68:01:00, 4.08it/s, grad_norm=1.14, loss_final=4.53, loss_mean=1.56, loss_mean_cls=3.03, proj_loss=-0.0535][[34m2026-03-23 13:56:21[0m] Step: 23, Training Logs: loss_final: 4.347858, loss_mean: 1.571953, proj_loss: -0.051053, loss_mean_cls: 2.826958, grad_norm: 1.379415
|
| 25 |
+
Steps: 0%| | 24/1000000 [00:10<67:55:02, 4.09it/s, grad_norm=1.38, loss_final=4.35, loss_mean=1.57, loss_mean_cls=2.83, proj_loss=-0.0511][[34m2026-03-23 13:56:21[0m] Step: 24, Training Logs: loss_final: 4.812301, loss_mean: 1.573597, proj_loss: -0.054251, loss_mean_cls: 3.292954, grad_norm: 1.536997
|
| 26 |
+
Steps: 0%| | 25/1000000 [00:10<67:50:56, 4.09it/s, grad_norm=1.54, loss_final=4.81, loss_mean=1.57, loss_mean_cls=3.29, proj_loss=-0.0543][[34m2026-03-23 13:56:22[0m] Step: 25, Training Logs: loss_final: 5.140490, loss_mean: 1.572134, proj_loss: -0.055976, loss_mean_cls: 3.624332, grad_norm: 1.501310
|
| 27 |
+
Steps: 0%| | 26/1000000 [00:10<67:49:32, 4.10it/s, grad_norm=1.5, loss_final=5.14, loss_mean=1.57, loss_mean_cls=3.62, proj_loss=-0.056][[34m2026-03-23 13:56:22[0m] Step: 26, Training Logs: loss_final: 4.614564, loss_mean: 1.575298, proj_loss: -0.055522, loss_mean_cls: 3.094787, grad_norm: 1.374379
|
| 28 |
+
Steps: 0%| | 27/1000000 [00:10<67:48:51, 4.10it/s, grad_norm=1.37, loss_final=4.61, loss_mean=1.58, loss_mean_cls=3.09, proj_loss=-0.0555][[34m2026-03-23 13:56:22[0m] Step: 27, Training Logs: loss_final: 4.428196, loss_mean: 1.559373, proj_loss: -0.054907, loss_mean_cls: 2.923730, grad_norm: 1.247505
|
| 29 |
+
Steps: 0%| | 28/1000000 [00:11<67:51:05, 4.09it/s, grad_norm=1.25, loss_final=4.43, loss_mean=1.56, loss_mean_cls=2.92, proj_loss=-0.0549][[34m2026-03-23 13:56:22[0m] Step: 28, Training Logs: loss_final: 4.514633, loss_mean: 1.561718, proj_loss: -0.054182, loss_mean_cls: 3.007097, grad_norm: 1.272959
|
| 30 |
+
Steps: 0%| | 29/1000000 [00:11<67:50:57, 4.09it/s, grad_norm=1.27, loss_final=4.51, loss_mean=1.56, loss_mean_cls=3.01, proj_loss=-0.0542][[34m2026-03-23 13:56:23[0m] Step: 29, Training Logs: loss_final: 4.064789, loss_mean: 1.520746, proj_loss: -0.055004, loss_mean_cls: 2.599047, grad_norm: 1.213601
|
| 31 |
+
Steps: 0%| | 30/1000000 [00:11<67:49:21, 4.10it/s, grad_norm=1.21, loss_final=4.06, loss_mean=1.52, loss_mean_cls=2.6, proj_loss=-0.055][[34m2026-03-23 13:56:23[0m] Step: 30, Training Logs: loss_final: 4.343926, loss_mean: 1.523063, proj_loss: -0.055674, loss_mean_cls: 2.876538, grad_norm: 1.148791
|
| 32 |
+
Steps: 0%| | 31/1000000 [00:11<67:50:28, 4.09it/s, grad_norm=1.15, loss_final=4.34, loss_mean=1.52, loss_mean_cls=2.88, proj_loss=-0.0557][[34m2026-03-23 13:56:23[0m] Step: 31, Training Logs: loss_final: 4.943371, loss_mean: 1.500860, proj_loss: -0.056089, loss_mean_cls: 3.498600, grad_norm: 1.126629
|
| 33 |
+
Steps: 0%| | 32/1000000 [00:12<67:51:49, 4.09it/s, grad_norm=1.13, loss_final=4.94, loss_mean=1.5, loss_mean_cls=3.5, proj_loss=-0.0561][[34m2026-03-23 13:56:23[0m] Step: 32, Training Logs: loss_final: 4.913333, loss_mean: 1.487980, proj_loss: -0.053908, loss_mean_cls: 3.479261, grad_norm: 1.111562
|
| 34 |
+
Steps: 0%| | 33/1000000 [00:12<67:49:26, 4.10it/s, grad_norm=1.11, loss_final=4.91, loss_mean=1.49, loss_mean_cls=3.48, proj_loss=-0.0539][[34m2026-03-23 13:56:24[0m] Step: 33, Training Logs: loss_final: 5.135122, loss_mean: 1.485391, proj_loss: -0.058625, loss_mean_cls: 3.708356, grad_norm: 1.166992
|
| 35 |
+
Steps: 0%| | 34/1000000 [00:12<67:49:31, 4.10it/s, grad_norm=1.17, loss_final=5.14, loss_mean=1.49, loss_mean_cls=3.71, proj_loss=-0.0586][[34m2026-03-23 13:56:24[0m] Step: 34, Training Logs: loss_final: 4.094241, loss_mean: 1.475285, proj_loss: -0.056888, loss_mean_cls: 2.675844, grad_norm: 1.037110
|
| 36 |
+
Steps: 0%| | 35/1000000 [00:12<67:47:41, 4.10it/s, grad_norm=1.04, loss_final=4.09, loss_mean=1.48, loss_mean_cls=2.68, proj_loss=-0.0569][[34m2026-03-23 13:56:24[0m] Step: 35, Training Logs: loss_final: 4.305937, loss_mean: 1.460218, proj_loss: -0.052705, loss_mean_cls: 2.898424, grad_norm: 1.091828
|
| 37 |
+
Steps: 0%| | 36/1000000 [00:13<67:50:00, 4.09it/s, grad_norm=1.09, loss_final=4.31, loss_mean=1.46, loss_mean_cls=2.9, proj_loss=-0.0527][[34m2026-03-23 13:56:24[0m] Step: 36, Training Logs: loss_final: 4.382219, loss_mean: 1.466251, proj_loss: -0.054344, loss_mean_cls: 2.970313, grad_norm: 1.093016
|
| 38 |
+
Steps: 0%| | 37/1000000 [00:13<68:14:38, 4.07it/s, grad_norm=1.09, loss_final=4.38, loss_mean=1.47, loss_mean_cls=2.97, proj_loss=-0.0543][[34m2026-03-23 13:56:25[0m] Step: 37, Training Logs: loss_final: 4.428336, loss_mean: 1.469206, proj_loss: -0.057364, loss_mean_cls: 3.016495, grad_norm: 1.468433
|
| 39 |
+
Steps: 0%| | 38/1000000 [00:13<68:07:38, 4.08it/s, grad_norm=1.47, loss_final=4.43, loss_mean=1.47, loss_mean_cls=3.02, proj_loss=-0.0574][[34m2026-03-23 13:56:25[0m] Step: 38, Training Logs: loss_final: 4.412411, loss_mean: 1.453988, proj_loss: -0.056726, loss_mean_cls: 3.015149, grad_norm: 0.825492
|
| 40 |
+
Steps: 0%| | 39/1000000 [00:13<68:01:05, 4.08it/s, grad_norm=0.825, loss_final=4.41, loss_mean=1.45, loss_mean_cls=3.02, proj_loss=-0.0567][[34m2026-03-23 13:56:25[0m] Step: 39, Training Logs: loss_final: 4.554020, loss_mean: 1.450588, proj_loss: -0.055717, loss_mean_cls: 3.159149, grad_norm: 1.281297
|
| 41 |
+
Steps: 0%| | 40/1000000 [00:14<68:02:46, 4.08it/s, grad_norm=1.28, loss_final=4.55, loss_mean=1.45, loss_mean_cls=3.16, proj_loss=-0.0557][[34m2026-03-23 13:56:25[0m] Step: 40, Training Logs: loss_final: 4.900630, loss_mean: 1.421538, proj_loss: -0.056336, loss_mean_cls: 3.535427, grad_norm: 1.096009
|
| 42 |
+
Steps: 0%| | 41/1000000 [00:14<67:56:56, 4.09it/s, grad_norm=1.1, loss_final=4.9, loss_mean=1.42, loss_mean_cls=3.54, proj_loss=-0.0563][[34m2026-03-23 13:56:26[0m] Step: 41, Training Logs: loss_final: 5.000130, loss_mean: 1.418576, proj_loss: -0.056068, loss_mean_cls: 3.637622, grad_norm: 1.111790
|
| 43 |
+
Steps: 0%| | 42/1000000 [00:14<67:53:46, 4.09it/s, grad_norm=1.11, loss_final=5, loss_mean=1.42, loss_mean_cls=3.64, proj_loss=-0.0561][[34m2026-03-23 13:56:26[0m] Step: 42, Training Logs: loss_final: 4.223974, loss_mean: 1.442425, proj_loss: -0.054436, loss_mean_cls: 2.835985, grad_norm: 1.223004
|
| 44 |
+
Steps: 0%| | 43/1000000 [00:14<67:51:37, 4.09it/s, grad_norm=1.22, loss_final=4.22, loss_mean=1.44, loss_mean_cls=2.84, proj_loss=-0.0544][[34m2026-03-23 13:56:26[0m] Step: 43, Training Logs: loss_final: 5.023710, loss_mean: 1.441718, proj_loss: -0.055548, loss_mean_cls: 3.637539, grad_norm: 1.391159
|
| 45 |
+
Steps: 0%| | 44/1000000 [00:15<67:53:09, 4.09it/s, grad_norm=1.39, loss_final=5.02, loss_mean=1.44, loss_mean_cls=3.64, proj_loss=-0.0555][[34m2026-03-23 13:56:26[0m] Step: 44, Training Logs: loss_final: 5.000957, loss_mean: 1.414230, proj_loss: -0.055362, loss_mean_cls: 3.642089, grad_norm: 1.162115
|
| 46 |
+
Steps: 0%| | 45/1000000 [00:15<67:51:51, 4.09it/s, grad_norm=1.16, loss_final=5, loss_mean=1.41, loss_mean_cls=3.64, proj_loss=-0.0554][[34m2026-03-23 13:56:27[0m] Step: 45, Training Logs: loss_final: 4.689414, loss_mean: 1.389372, proj_loss: -0.054793, loss_mean_cls: 3.354835, grad_norm: 0.814496
|
| 47 |
+
Steps: 0%| | 46/1000000 [00:15<67:51:06, 4.09it/s, grad_norm=0.814, loss_final=4.69, loss_mean=1.39, loss_mean_cls=3.35, proj_loss=-0.0548][[34m2026-03-23 13:56:27[0m] Step: 46, Training Logs: loss_final: 4.452005, loss_mean: 1.403370, proj_loss: -0.056346, loss_mean_cls: 3.104981, grad_norm: 1.062373
|
| 48 |
+
Steps: 0%| | 47/1000000 [00:15<67:50:59, 4.09it/s, grad_norm=1.06, loss_final=4.45, loss_mean=1.4, loss_mean_cls=3.1, proj_loss=-0.0563][[34m2026-03-23 13:56:27[0m] Step: 47, Training Logs: loss_final: 4.638161, loss_mean: 1.413495, proj_loss: -0.055944, loss_mean_cls: 3.280609, grad_norm: 0.873316
|
| 49 |
+
Steps: 0%| | 48/1000000 [00:15<67:51:08, 4.09it/s, grad_norm=0.873, loss_final=4.64, loss_mean=1.41, loss_mean_cls=3.28, proj_loss=-0.0559][[34m2026-03-23 13:56:27[0m] Step: 48, Training Logs: loss_final: 4.429680, loss_mean: 1.411911, proj_loss: -0.054931, loss_mean_cls: 3.072701, grad_norm: 0.759066
|
| 50 |
+
Steps: 0%| | 49/1000000 [00:16<67:51:48, 4.09it/s, grad_norm=0.759, loss_final=4.43, loss_mean=1.41, loss_mean_cls=3.07, proj_loss=-0.0549][[34m2026-03-23 13:56:27[0m] Step: 49, Training Logs: loss_final: 4.901843, loss_mean: 1.388557, proj_loss: -0.055650, loss_mean_cls: 3.568936, grad_norm: 0.750354
|
| 51 |
+
Steps: 0%| | 50/1000000 [00:16<67:49:39, 4.10it/s, grad_norm=0.75, loss_final=4.9, loss_mean=1.39, loss_mean_cls=3.57, proj_loss=-0.0557][[34m2026-03-23 13:56:28[0m] Step: 50, Training Logs: loss_final: 4.100002, loss_mean: 1.412004, proj_loss: -0.052795, loss_mean_cls: 2.740793, grad_norm: 0.767011
|
| 52 |
+
Steps: 0%| | 51/1000000 [00:16<67:51:25, 4.09it/s, grad_norm=0.767, loss_final=4.1, loss_mean=1.41, loss_mean_cls=2.74, proj_loss=-0.0528][[34m2026-03-23 13:56:28[0m] Step: 51, Training Logs: loss_final: 4.759223, loss_mean: 1.391587, proj_loss: -0.054960, loss_mean_cls: 3.422596, grad_norm: 0.719157
|
| 53 |
+
Steps: 0%| | 52/1000000 [00:16<67:51:11, 4.09it/s, grad_norm=0.719, loss_final=4.76, loss_mean=1.39, loss_mean_cls=3.42, proj_loss=-0.055][[34m2026-03-23 13:56:28[0m] Step: 52, Training Logs: loss_final: 4.482382, loss_mean: 1.395736, proj_loss: -0.056075, loss_mean_cls: 3.142721, grad_norm: 0.726095
|
| 54 |
+
Steps: 0%| | 53/1000000 [00:17<67:50:55, 4.09it/s, grad_norm=0.726, loss_final=4.48, loss_mean=1.4, loss_mean_cls=3.14, proj_loss=-0.0561][[34m2026-03-23 13:56:28[0m] Step: 53, Training Logs: loss_final: 4.466825, loss_mean: 1.373211, proj_loss: -0.055033, loss_mean_cls: 3.148648, grad_norm: 0.663936
|
| 55 |
+
Steps: 0%| | 54/1000000 [00:17<67:49:43, 4.10it/s, grad_norm=0.664, loss_final=4.47, loss_mean=1.37, loss_mean_cls=3.15, proj_loss=-0.055][[34m2026-03-23 13:56:29[0m] Step: 54, Training Logs: loss_final: 4.687413, loss_mean: 1.367031, proj_loss: -0.056856, loss_mean_cls: 3.377239, grad_norm: 0.706146
|
| 56 |
+
Steps: 0%| | 55/1000000 [00:17<67:47:50, 4.10it/s, grad_norm=0.706, loss_final=4.69, loss_mean=1.37, loss_mean_cls=3.38, proj_loss=-0.0569][[34m2026-03-23 13:56:29[0m] Step: 55, Training Logs: loss_final: 4.724741, loss_mean: 1.355043, proj_loss: -0.056556, loss_mean_cls: 3.426254, grad_norm: 0.891681
|
| 57 |
+
Steps: 0%| | 56/1000000 [00:17<67:49:45, 4.10it/s, grad_norm=0.892, loss_final=4.72, loss_mean=1.36, loss_mean_cls=3.43, proj_loss=-0.0566][[34m2026-03-23 13:56:29[0m] Step: 56, Training Logs: loss_final: 4.598913, loss_mean: 1.357820, proj_loss: -0.053707, loss_mean_cls: 3.294799, grad_norm: 0.688579
|
| 58 |
+
Steps: 0%| | 57/1000000 [00:18<67:48:24, 4.10it/s, grad_norm=0.689, loss_final=4.6, loss_mean=1.36, loss_mean_cls=3.29, proj_loss=-0.0537][[34m2026-03-23 13:56:29[0m] Step: 57, Training Logs: loss_final: 4.647234, loss_mean: 1.346388, proj_loss: -0.055346, loss_mean_cls: 3.356192, grad_norm: 0.643795
|
| 59 |
+
Steps: 0%| | 58/1000000 [00:18<67:48:01, 4.10it/s, grad_norm=0.644, loss_final=4.65, loss_mean=1.35, loss_mean_cls=3.36, proj_loss=-0.0553][[34m2026-03-23 13:56:30[0m] Step: 58, Training Logs: loss_final: 4.629024, loss_mean: 1.361648, proj_loss: -0.056354, loss_mean_cls: 3.323730, grad_norm: 0.640788
|
| 60 |
+
Steps: 0%| | 59/1000000 [00:18<67:48:29, 4.10it/s, grad_norm=0.641, loss_final=4.63, loss_mean=1.36, loss_mean_cls=3.32, proj_loss=-0.0564][[34m2026-03-23 13:56:30[0m] Step: 59, Training Logs: loss_final: 4.318052, loss_mean: 1.352944, proj_loss: -0.054974, loss_mean_cls: 3.020082, grad_norm: 1.103365
|
| 61 |
+
Steps: 0%| | 60/1000000 [00:18<67:52:13, 4.09it/s, grad_norm=1.1, loss_final=4.32, loss_mean=1.35, loss_mean_cls=3.02, proj_loss=-0.055][[34m2026-03-23 13:56:30[0m] Step: 60, Training Logs: loss_final: 4.558999, loss_mean: 1.356349, proj_loss: -0.056304, loss_mean_cls: 3.258954, grad_norm: 2.178408
|
| 62 |
+
Steps: 0%| | 61/1000000 [00:19<67:50:06, 4.09it/s, grad_norm=2.18, loss_final=4.56, loss_mean=1.36, loss_mean_cls=3.26, proj_loss=-0.0563][[34m2026-03-23 13:56:30[0m] Step: 61, Training Logs: loss_final: 4.405947, loss_mean: 1.345089, proj_loss: -0.054088, loss_mean_cls: 3.114946, grad_norm: 0.833812
|
| 63 |
+
Steps: 0%| | 62/1000000 [00:19<67:49:48, 4.09it/s, grad_norm=0.834, loss_final=4.41, loss_mean=1.35, loss_mean_cls=3.11, proj_loss=-0.0541][[34m2026-03-23 13:56:31[0m] Step: 62, Training Logs: loss_final: 4.248630, loss_mean: 1.369897, proj_loss: -0.055107, loss_mean_cls: 2.933840, grad_norm: 2.604588
|
| 64 |
+
Steps: 0%| | 63/1000000 [00:19<67:49:04, 4.10it/s, grad_norm=2.6, loss_final=4.25, loss_mean=1.37, loss_mean_cls=2.93, proj_loss=-0.0551][[34m2026-03-23 13:56:31[0m] Step: 63, Training Logs: loss_final: 4.370800, loss_mean: 1.356826, proj_loss: -0.054216, loss_mean_cls: 3.068191, grad_norm: 1.995899
|
| 65 |
+
Steps: 0%| | 64/1000000 [00:19<67:52:12, 4.09it/s, grad_norm=2, loss_final=4.37, loss_mean=1.36, loss_mean_cls=3.07, proj_loss=-0.0542][[34m2026-03-23 13:56:31[0m] Step: 64, Training Logs: loss_final: 3.620857, loss_mean: 1.371637, proj_loss: -0.056188, loss_mean_cls: 2.305408, grad_norm: 1.270011
|
| 66 |
+
Steps: 0%| | 65/1000000 [00:20<67:52:26, 4.09it/s, grad_norm=1.27, loss_final=3.62, loss_mean=1.37, loss_mean_cls=2.31, proj_loss=-0.0562][[34m2026-03-23 13:56:31[0m] Step: 65, Training Logs: loss_final: 3.841930, loss_mean: 1.345907, proj_loss: -0.057290, loss_mean_cls: 2.553313, grad_norm: 1.103519
|
| 67 |
+
Steps: 0%| | 66/1000000 [00:20<67:50:32, 4.09it/s, grad_norm=1.1, loss_final=3.84, loss_mean=1.35, loss_mean_cls=2.55, proj_loss=-0.0573][[34m2026-03-23 13:56:32[0m] Step: 66, Training Logs: loss_final: 4.823415, loss_mean: 1.309009, proj_loss: -0.057261, loss_mean_cls: 3.571667, grad_norm: 1.146726
|
| 68 |
+
Steps: 0%| | 67/1000000 [00:20<67:51:11, 4.09it/s, grad_norm=1.15, loss_final=4.82, loss_mean=1.31, loss_mean_cls=3.57, proj_loss=-0.0573][[34m2026-03-23 13:56:32[0m] Step: 67, Training Logs: loss_final: 4.654735, loss_mean: 1.315030, proj_loss: -0.055812, loss_mean_cls: 3.395517, grad_norm: 0.900543
|
| 69 |
+
Steps: 0%| | 68/1000000 [00:20<67:53:43, 4.09it/s, grad_norm=0.901, loss_final=4.65, loss_mean=1.32, loss_mean_cls=3.4, proj_loss=-0.0558][[34m2026-03-23 13:56:32[0m] Step: 68, Training Logs: loss_final: 3.865313, loss_mean: 1.311920, proj_loss: -0.055171, loss_mean_cls: 2.608564, grad_norm: 0.968363
|
| 70 |
+
Steps: 0%| | 69/1000000 [00:21<67:51:01, 4.09it/s, grad_norm=0.968, loss_final=3.87, loss_mean=1.31, loss_mean_cls=2.61, proj_loss=-0.0552][[34m2026-03-23 13:56:32[0m] Step: 69, Training Logs: loss_final: 4.332185, loss_mean: 1.287292, proj_loss: -0.055524, loss_mean_cls: 3.100417, grad_norm: 0.992981
|
| 71 |
+
Steps: 0%| | 70/1000000 [00:21<67:50:34, 4.09it/s, grad_norm=0.993, loss_final=4.33, loss_mean=1.29, loss_mean_cls=3.1, proj_loss=-0.0555][[34m2026-03-23 13:56:33[0m] Step: 70, Training Logs: loss_final: 5.214293, loss_mean: 1.271327, proj_loss: -0.056000, loss_mean_cls: 3.998966, grad_norm: 0.950980
|
| 72 |
+
Steps: 0%| | 71/1000000 [00:21<67:51:17, 4.09it/s, grad_norm=0.951, loss_final=5.21, loss_mean=1.27, loss_mean_cls=4, proj_loss=-0.056][[34m2026-03-23 13:56:33[0m] Step: 71, Training Logs: loss_final: 4.271416, loss_mean: 1.298077, proj_loss: -0.058767, loss_mean_cls: 3.032106, grad_norm: 0.897326
|
| 73 |
+
Steps: 0%| | 72/1000000 [00:21<67:53:03, 4.09it/s, grad_norm=0.897, loss_final=4.27, loss_mean=1.3, loss_mean_cls=3.03, proj_loss=-0.0588][[34m2026-03-23 13:56:33[0m] Step: 72, Training Logs: loss_final: 4.944882, loss_mean: 1.275258, proj_loss: -0.054986, loss_mean_cls: 3.724611, grad_norm: 1.406698
|
| 74 |
+
Steps: 0%| | 73/1000000 [00:22<67:51:07, 4.09it/s, grad_norm=1.41, loss_final=4.94, loss_mean=1.28, loss_mean_cls=3.72, proj_loss=-0.055][[34m2026-03-23 13:56:33[0m] Step: 73, Training Logs: loss_final: 4.262454, loss_mean: 1.391114, proj_loss: -0.054645, loss_mean_cls: 2.925985, grad_norm: 4.082608
|
| 75 |
+
Steps: 0%| | 74/1000000 [00:22<67:49:45, 4.09it/s, grad_norm=4.08, loss_final=4.26, loss_mean=1.39, loss_mean_cls=2.93, proj_loss=-0.0546][[34m2026-03-23 13:56:34[0m] Step: 74, Training Logs: loss_final: 4.441927, loss_mean: 1.368322, proj_loss: -0.056059, loss_mean_cls: 3.129665, grad_norm: 3.740271
|
| 76 |
+
Steps: 0%| | 75/1000000 [00:22<67:48:41, 4.10it/s, grad_norm=3.74, loss_final=4.44, loss_mean=1.37, loss_mean_cls=3.13, proj_loss=-0.0561][[34m2026-03-23 13:56:34[0m] Step: 75, Training Logs: loss_final: 4.786633, loss_mean: 1.323001, proj_loss: -0.054923, loss_mean_cls: 3.518555, grad_norm: 1.360202
|
| 77 |
+
Steps: 0%| | 76/1000000 [00:22<67:47:44, 4.10it/s, grad_norm=1.36, loss_final=4.79, loss_mean=1.32, loss_mean_cls=3.52, proj_loss=-0.0549][[34m2026-03-23 13:56:34[0m] Step: 76, Training Logs: loss_final: 4.880804, loss_mean: 1.312154, proj_loss: -0.058812, loss_mean_cls: 3.627462, grad_norm: 2.462873
|
| 78 |
+
Steps: 0%| | 77/1000000 [00:23<67:48:38, 4.10it/s, grad_norm=2.46, loss_final=4.88, loss_mean=1.31, loss_mean_cls=3.63, proj_loss=-0.0588][[34m2026-03-23 13:56:34[0m] Step: 77, Training Logs: loss_final: 4.347323, loss_mean: 1.277282, proj_loss: -0.056339, loss_mean_cls: 3.126381, grad_norm: 1.089468
|
| 79 |
+
Steps: 0%| | 78/1000000 [00:23<67:47:33, 4.10it/s, grad_norm=1.09, loss_final=4.35, loss_mean=1.28, loss_mean_cls=3.13, proj_loss=-0.0563][[34m2026-03-23 13:56:35[0m] Step: 78, Training Logs: loss_final: 4.452549, loss_mean: 1.308045, proj_loss: -0.057751, loss_mean_cls: 3.202255, grad_norm: 1.409237
|
| 80 |
+
Steps: 0%| | 79/1000000 [00:23<67:48:01, 4.10it/s, grad_norm=1.41, loss_final=4.45, loss_mean=1.31, loss_mean_cls=3.2, proj_loss=-0.0578][[34m2026-03-23 13:56:35[0m] Step: 79, Training Logs: loss_final: 4.300065, loss_mean: 1.256680, proj_loss: -0.057729, loss_mean_cls: 3.101114, grad_norm: 1.063736
|
| 81 |
+
Steps: 0%| | 80/1000000 [00:23<67:46:24, 4.10it/s, grad_norm=1.06, loss_final=4.3, loss_mean=1.26, loss_mean_cls=3.1, proj_loss=-0.0577][[34m2026-03-23 13:56:35[0m] Step: 80, Training Logs: loss_final: 5.255699, loss_mean: 1.230699, proj_loss: -0.055231, loss_mean_cls: 4.080230, grad_norm: 1.159766
|
| 82 |
+
Steps: 0%| | 81/1000000 [00:24<67:45:22, 4.10it/s, grad_norm=1.16, loss_final=5.26, loss_mean=1.23, loss_mean_cls=4.08, proj_loss=-0.0552][[34m2026-03-23 13:56:35[0m] Step: 81, Training Logs: loss_final: 4.184104, loss_mean: 1.256977, proj_loss: -0.057558, loss_mean_cls: 2.984685, grad_norm: 1.074197
|
| 83 |
+
Steps: 0%| | 82/1000000 [00:24<67:46:48, 4.10it/s, grad_norm=1.07, loss_final=4.18, loss_mean=1.26, loss_mean_cls=2.98, proj_loss=-0.0576][[34m2026-03-23 13:56:36[0m] Step: 82, Training Logs: loss_final: 4.368578, loss_mean: 1.246449, proj_loss: -0.055964, loss_mean_cls: 3.178093, grad_norm: 1.212082
|
| 84 |
+
Steps: 0%| | 83/1000000 [00:24<67:46:42, 4.10it/s, grad_norm=1.21, loss_final=4.37, loss_mean=1.25, loss_mean_cls=3.18, proj_loss=-0.056][[34m2026-03-23 13:56:36[0m] Step: 83, Training Logs: loss_final: 4.066692, loss_mean: 1.279943, proj_loss: -0.057100, loss_mean_cls: 2.843850, grad_norm: 1.081535
|
| 85 |
+
Steps: 0%| | 84/1000000 [00:24<67:45:25, 4.10it/s, grad_norm=1.08, loss_final=4.07, loss_mean=1.28, loss_mean_cls=2.84, proj_loss=-0.0571][[34m2026-03-23 13:56:36[0m] Step: 84, Training Logs: loss_final: 4.419612, loss_mean: 1.222443, proj_loss: -0.056123, loss_mean_cls: 3.253291, grad_norm: 1.017672
|
| 86 |
+
Steps: 0%| | 85/1000000 [00:25<67:46:48, 4.10it/s, grad_norm=1.02, loss_final=4.42, loss_mean=1.22, loss_mean_cls=3.25, proj_loss=-0.0561][[34m2026-03-23 13:56:36[0m] Step: 85, Training Logs: loss_final: 4.330603, loss_mean: 1.217940, proj_loss: -0.055811, loss_mean_cls: 3.168474, grad_norm: 1.048292
|
| 87 |
+
Steps: 0%| | 86/1000000 [00:25<67:46:18, 4.10it/s, grad_norm=1.05, loss_final=4.33, loss_mean=1.22, loss_mean_cls=3.17, proj_loss=-0.0558][[34m2026-03-23 13:56:37[0m] Step: 86, Training Logs: loss_final: 4.073160, loss_mean: 1.261520, proj_loss: -0.057447, loss_mean_cls: 2.869087, grad_norm: 2.760909
|
| 88 |
+
Steps: 0%| | 87/1000000 [00:25<67:46:56, 4.10it/s, grad_norm=2.76, loss_final=4.07, loss_mean=1.26, loss_mean_cls=2.87, proj_loss=-0.0574][[34m2026-03-23 13:56:37[0m] Step: 87, Training Logs: loss_final: 4.349278, loss_mean: 1.225316, proj_loss: -0.055978, loss_mean_cls: 3.179941, grad_norm: 1.415633
|
| 89 |
+
Steps: 0%| | 88/1000000 [00:25<67:48:19, 4.10it/s, grad_norm=1.42, loss_final=4.35, loss_mean=1.23, loss_mean_cls=3.18, proj_loss=-0.056][[34m2026-03-23 13:56:37[0m] Step: 88, Training Logs: loss_final: 4.191838, loss_mean: 1.244159, proj_loss: -0.053158, loss_mean_cls: 3.000837, grad_norm: 2.142123
|
| 90 |
+
Steps: 0%| | 89/1000000 [00:26<67:48:47, 4.10it/s, grad_norm=2.14, loss_final=4.19, loss_mean=1.24, loss_mean_cls=3, proj_loss=-0.0532][[34m2026-03-23 13:56:37[0m] Step: 89, Training Logs: loss_final: 3.810838, loss_mean: 1.241034, proj_loss: -0.056567, loss_mean_cls: 2.626369, grad_norm: 1.977536
|
| 91 |
+
Steps: 0%| | 90/1000000 [00:26<67:49:26, 4.10it/s, grad_norm=1.98, loss_final=3.81, loss_mean=1.24, loss_mean_cls=2.63, proj_loss=-0.0566][[34m2026-03-23 13:56:37[0m] Step: 90, Training Logs: loss_final: 4.302361, loss_mean: 1.235736, proj_loss: -0.055757, loss_mean_cls: 3.122383, grad_norm: 1.465364
|
| 92 |
+
Steps: 0%| | 91/1000000 [00:26<67:49:17, 4.10it/s, grad_norm=1.47, loss_final=4.3, loss_mean=1.24, loss_mean_cls=3.12, proj_loss=-0.0558][[34m2026-03-23 13:56:38[0m] Step: 91, Training Logs: loss_final: 4.793566, loss_mean: 1.295077, proj_loss: -0.054678, loss_mean_cls: 3.553167, grad_norm: 4.777657
|
| 93 |
+
Steps: 0%| | 92/1000000 [00:26<67:50:23, 4.09it/s, grad_norm=4.78, loss_final=4.79, loss_mean=1.3, loss_mean_cls=3.55, proj_loss=-0.0547][[34m2026-03-23 13:56:38[0m] Step: 92, Training Logs: loss_final: 4.318158, loss_mean: 1.268673, proj_loss: -0.056761, loss_mean_cls: 3.106246, grad_norm: 3.593853
|
| 94 |
+
Steps: 0%| | 93/1000000 [00:26<67:49:16, 4.10it/s, grad_norm=3.59, loss_final=4.32, loss_mean=1.27, loss_mean_cls=3.11, proj_loss=-0.0568][[34m2026-03-23 13:56:38[0m] Step: 93, Training Logs: loss_final: 4.038731, loss_mean: 1.237636, proj_loss: -0.058346, loss_mean_cls: 2.859441, grad_norm: 2.444315
|
| 95 |
+
Steps: 0%| | 94/1000000 [00:27<67:49:30, 4.10it/s, grad_norm=2.44, loss_final=4.04, loss_mean=1.24, loss_mean_cls=2.86, proj_loss=-0.0583][[34m2026-03-23 13:56:38[0m] Step: 94, Training Logs: loss_final: 5.108729, loss_mean: 1.212993, proj_loss: -0.058394, loss_mean_cls: 3.954130, grad_norm: 1.895388
|
| 96 |
+
Steps: 0%| | 95/1000000 [00:27<67:48:34, 4.10it/s, grad_norm=1.9, loss_final=5.11, loss_mean=1.21, loss_mean_cls=3.95, proj_loss=-0.0584][[34m2026-03-23 13:56:39[0m] Step: 95, Training Logs: loss_final: 4.209163, loss_mean: 1.207963, proj_loss: -0.055074, loss_mean_cls: 3.056275, grad_norm: 1.442985
|
| 97 |
+
Steps: 0%| | 96/1000000 [00:27<67:50:56, 4.09it/s, grad_norm=1.44, loss_final=4.21, loss_mean=1.21, loss_mean_cls=3.06, proj_loss=-0.0551][[34m2026-03-23 13:56:39[0m] Step: 96, Training Logs: loss_final: 4.601864, loss_mean: 1.191701, proj_loss: -0.057576, loss_mean_cls: 3.467739, grad_norm: 1.299760
|
| 98 |
+
Steps: 0%| | 97/1000000 [00:27<67:51:29, 4.09it/s, grad_norm=1.3, loss_final=4.6, loss_mean=1.19, loss_mean_cls=3.47, proj_loss=-0.0576][[34m2026-03-23 13:56:39[0m] Step: 97, Training Logs: loss_final: 4.614489, loss_mean: 1.185281, proj_loss: -0.056098, loss_mean_cls: 3.485306, grad_norm: 1.128512
|
| 99 |
+
Steps: 0%| | 98/1000000 [00:28<67:50:38, 4.09it/s, grad_norm=1.13, loss_final=4.61, loss_mean=1.19, loss_mean_cls=3.49, proj_loss=-0.0561][[34m2026-03-23 13:56:39[0m] Step: 98, Training Logs: loss_final: 4.294365, loss_mean: 1.196555, proj_loss: -0.058301, loss_mean_cls: 3.156111, grad_norm: 1.399289
|
| 100 |
+
Steps: 0%| | 99/1000000 [00:28<67:49:41, 4.09it/s, grad_norm=1.4, loss_final=4.29, loss_mean=1.2, loss_mean_cls=3.16, proj_loss=-0.0583][[34m2026-03-23 13:56:40[0m] Step: 99, Training Logs: loss_final: 4.215539, loss_mean: 1.199367, proj_loss: -0.057493, loss_mean_cls: 3.073665, grad_norm: 1.343760
|
| 101 |
+
Steps: 0%| | 100/1000000 [00:28<67:51:40, 4.09it/s, grad_norm=1.34, loss_final=4.22, loss_mean=1.2, loss_mean_cls=3.07, proj_loss=-0.0575][[34m2026-03-23 13:56:40[0m] Step: 100, Training Logs: loss_final: 4.488984, loss_mean: 1.149765, proj_loss: -0.056815, loss_mean_cls: 3.396034, grad_norm: 1.045582
|
| 102 |
+
Steps: 0%| | 101/1000000 [00:28<67:50:50, 4.09it/s, grad_norm=1.05, loss_final=4.49, loss_mean=1.15, loss_mean_cls=3.4, proj_loss=-0.0568][[34m2026-03-23 13:56:40[0m] Step: 101, Training Logs: loss_final: 4.368436, loss_mean: 1.178788, proj_loss: -0.058116, loss_mean_cls: 3.247765, grad_norm: 1.342256
|
| 103 |
+
Steps: 0%| | 102/1000000 [00:29<67:50:14, 4.09it/s, grad_norm=1.34, loss_final=4.37, loss_mean=1.18, loss_mean_cls=3.25, proj_loss=-0.0581][[34m2026-03-23 13:56:40[0m] Step: 102, Training Logs: loss_final: 4.096940, loss_mean: 1.195665, proj_loss: -0.053437, loss_mean_cls: 2.954712, grad_norm: 1.167338
|
| 104 |
+
Steps: 0%| | 103/1000000 [00:29<67:52:23, 4.09it/s, grad_norm=1.17, loss_final=4.1, loss_mean=1.2, loss_mean_cls=2.95, proj_loss=-0.0534][[34m2026-03-23 13:56:41[0m] Step: 103, Training Logs: loss_final: 4.406911, loss_mean: 1.123473, proj_loss: -0.055880, loss_mean_cls: 3.339318, grad_norm: 0.934016
|
| 105 |
+
Steps: 0%| | 104/1000000 [00:29<67:52:00, 4.09it/s, grad_norm=0.934, loss_final=4.41, loss_mean=1.12, loss_mean_cls=3.34, proj_loss=-0.0559][[34m2026-03-23 13:56:41[0m] Step: 104, Training Logs: loss_final: 3.831389, loss_mean: 1.127624, proj_loss: -0.056442, loss_mean_cls: 2.760207, grad_norm: 1.079511
|
| 106 |
+
Steps: 0%| | 105/1000000 [00:29<67:50:29, 4.09it/s, grad_norm=1.08, loss_final=3.83, loss_mean=1.13, loss_mean_cls=2.76, proj_loss=-0.0564][[34m2026-03-23 13:56:41[0m] Step: 105, Training Logs: loss_final: 3.949926, loss_mean: 1.169529, proj_loss: -0.056276, loss_mean_cls: 2.836674, grad_norm: 0.823243
|
| 107 |
+
Steps: 0%| | 106/1000000 [00:30<67:49:55, 4.09it/s, grad_norm=0.823, loss_final=3.95, loss_mean=1.17, loss_mean_cls=2.84, proj_loss=-0.0563][[34m2026-03-23 13:56:41[0m] Step: 106, Training Logs: loss_final: 4.447761, loss_mean: 1.131930, proj_loss: -0.057379, loss_mean_cls: 3.373209, grad_norm: 0.780265
|
| 108 |
+
Steps: 0%| | 107/1000000 [00:30<67:50:26, 4.09it/s, grad_norm=0.78, loss_final=4.45, loss_mean=1.13, loss_mean_cls=3.37, proj_loss=-0.0574][[34m2026-03-23 13:56:42[0m] Step: 107, Training Logs: loss_final: 4.185682, loss_mean: 1.137062, proj_loss: -0.056792, loss_mean_cls: 3.105412, grad_norm: 0.896073
|
| 109 |
+
Steps: 0%| | 108/1000000 [00:30<67:51:47, 4.09it/s, grad_norm=0.896, loss_final=4.19, loss_mean=1.14, loss_mean_cls=3.11, proj_loss=-0.0568][[34m2026-03-23 13:56:42[0m] Step: 108, Training Logs: loss_final: 4.381330, loss_mean: 1.127375, proj_loss: -0.057231, loss_mean_cls: 3.311185, grad_norm: 0.791624
|
| 110 |
+
Steps: 0%| | 109/1000000 [00:30<67:50:27, 4.09it/s, grad_norm=0.792, loss_final=4.38, loss_mean=1.13, loss_mean_cls=3.31, proj_loss=-0.0572][[34m2026-03-23 13:56:42[0m] Step: 109, Training Logs: loss_final: 4.101658, loss_mean: 1.133106, proj_loss: -0.058366, loss_mean_cls: 3.026918, grad_norm: 1.389539
|
| 111 |
+
Steps: 0%| | 110/1000000 [00:31<67:50:56, 4.09it/s, grad_norm=1.39, loss_final=4.1, loss_mean=1.13, loss_mean_cls=3.03, proj_loss=-0.0584][[34m2026-03-23 13:56:42[0m] Step: 110, Training Logs: loss_final: 4.222694, loss_mean: 1.125139, proj_loss: -0.056927, loss_mean_cls: 3.154482, grad_norm: 1.616820
|
| 112 |
+
Steps: 0%| | 111/1000000 [00:31<67:52:18, 4.09it/s, grad_norm=1.62, loss_final=4.22, loss_mean=1.13, loss_mean_cls=3.15, proj_loss=-0.0569][[34m2026-03-23 13:56:43[0m] Step: 111, Training Logs: loss_final: 3.836849, loss_mean: 1.123032, proj_loss: -0.057576, loss_mean_cls: 2.771394, grad_norm: 1.516640
|
| 113 |
+
Steps: 0%| | 112/1000000 [00:31<67:50:27, 4.09it/s, grad_norm=1.52, loss_final=3.84, loss_mean=1.12, loss_mean_cls=2.77, proj_loss=-0.0576][[34m2026-03-23 13:56:43[0m] Step: 112, Training Logs: loss_final: 4.385750, loss_mean: 1.170239, proj_loss: -0.057827, loss_mean_cls: 3.273338, grad_norm: 3.234420
|
| 114 |
+
Steps: 0%| | 113/1000000 [00:31<67:50:21, 4.09it/s, grad_norm=3.23, loss_final=4.39, loss_mean=1.17, loss_mean_cls=3.27, proj_loss=-0.0578][[34m2026-03-23 13:56:43[0m] Step: 113, Training Logs: loss_final: 4.791656, loss_mean: 1.139730, proj_loss: -0.057919, loss_mean_cls: 3.709845, grad_norm: 2.920728
|
| 115 |
+
Steps: 0%| | 114/1000000 [00:32<67:52:04, 4.09it/s, grad_norm=2.92, loss_final=4.79, loss_mean=1.14, loss_mean_cls=3.71, proj_loss=-0.0579][[34m2026-03-23 13:56:43[0m] Step: 114, Training Logs: loss_final: 4.198164, loss_mean: 1.112274, proj_loss: -0.054186, loss_mean_cls: 3.140076, grad_norm: 1.197350
|
| 116 |
+
Steps: 0%| | 115/1000000 [00:32<67:52:07, 4.09it/s, grad_norm=1.2, loss_final=4.2, loss_mean=1.11, loss_mean_cls=3.14, proj_loss=-0.0542][[34m2026-03-23 13:56:44[0m] Step: 115, Training Logs: loss_final: 4.495673, loss_mean: 1.149152, proj_loss: -0.056316, loss_mean_cls: 3.402837, grad_norm: 1.762717
|
| 117 |
+
Steps: 0%| | 116/1000000 [00:32<67:52:00, 4.09it/s, grad_norm=1.76, loss_final=4.5, loss_mean=1.15, loss_mean_cls=3.4, proj_loss=-0.0563][[34m2026-03-23 13:56:44[0m] Step: 116, Training Logs: loss_final: 3.863493, loss_mean: 1.173006, proj_loss: -0.057397, loss_mean_cls: 2.747884, grad_norm: 2.976036
|
| 118 |
+
Steps: 0%| | 117/1000000 [00:32<67:52:53, 4.09it/s, grad_norm=2.98, loss_final=3.86, loss_mean=1.17, loss_mean_cls=2.75, proj_loss=-0.0574][[34m2026-03-23 13:56:44[0m] Step: 117, Training Logs: loss_final: 4.637027, loss_mean: 1.119735, proj_loss: -0.056042, loss_mean_cls: 3.573334, grad_norm: 2.094532
|
| 119 |
+
Steps: 0%| | 118/1000000 [00:33<67:52:15, 4.09it/s, grad_norm=2.09, loss_final=4.64, loss_mean=1.12, loss_mean_cls=3.57, proj_loss=-0.056][[34m2026-03-23 13:56:44[0m] Step: 118, Training Logs: loss_final: 3.896062, loss_mean: 1.133374, proj_loss: -0.057744, loss_mean_cls: 2.820432, grad_norm: 2.091462
|
| 120 |
+
Steps: 0%| | 119/1000000 [00:33<67:52:19, 4.09it/s, grad_norm=2.09, loss_final=3.9, loss_mean=1.13, loss_mean_cls=2.82, proj_loss=-0.0577][[34m2026-03-23 13:56:45[0m] Step: 119, Training Logs: loss_final: 4.494811, loss_mean: 1.127229, proj_loss: -0.056165, loss_mean_cls: 3.423747, grad_norm: 2.573430
|
| 121 |
+
Steps: 0%| | 120/1000000 [00:33<67:52:07, 4.09it/s, grad_norm=2.57, loss_final=4.49, loss_mean=1.13, loss_mean_cls=3.42, proj_loss=-0.0562][[34m2026-03-23 13:56:45[0m] Step: 120, Training Logs: loss_final: 4.395417, loss_mean: 1.094057, proj_loss: -0.057758, loss_mean_cls: 3.359118, grad_norm: 1.652974
|
| 122 |
+
Steps: 0%| | 121/1000000 [00:33<67:49:02, 4.10it/s, grad_norm=1.65, loss_final=4.4, loss_mean=1.09, loss_mean_cls=3.36, proj_loss=-0.0578][[34m2026-03-23 13:56:45[0m] Step: 121, Training Logs: loss_final: 3.780295, loss_mean: 1.143866, proj_loss: -0.057413, loss_mean_cls: 2.693842, grad_norm: 2.330424
|
| 123 |
+
Steps: 0%| | 122/1000000 [00:34<67:51:29, 4.09it/s, grad_norm=2.33, loss_final=3.78, loss_mean=1.14, loss_mean_cls=2.69, proj_loss=-0.0574][[34m2026-03-23 13:56:45[0m] Step: 122, Training Logs: loss_final: 5.011144, loss_mean: 1.122681, proj_loss: -0.058837, loss_mean_cls: 3.947300, grad_norm: 1.509129
|
| 124 |
+
Steps: 0%| | 123/1000000 [00:34<67:51:03, 4.09it/s, grad_norm=1.51, loss_final=5.01, loss_mean=1.12, loss_mean_cls=3.95, proj_loss=-0.0588][[34m2026-03-23 13:56:46[0m] Step: 123, Training Logs: loss_final: 4.172338, loss_mean: 1.115492, proj_loss: -0.056144, loss_mean_cls: 3.112991, grad_norm: 1.528705
|
| 125 |
+
Steps: 0%| | 124/1000000 [00:34<67:50:03, 4.09it/s, grad_norm=1.53, loss_final=4.17, loss_mean=1.12, loss_mean_cls=3.11, proj_loss=-0.0561][[34m2026-03-23 13:56:46[0m] Step: 124, Training Logs: loss_final: 4.010628, loss_mean: 1.117229, proj_loss: -0.054621, loss_mean_cls: 2.948020, grad_norm: 1.329769
|
| 126 |
+
Steps: 0%| | 125/1000000 [00:34<67:48:41, 4.10it/s, grad_norm=1.33, loss_final=4.01, loss_mean=1.12, loss_mean_cls=2.95, proj_loss=-0.0546][[34m2026-03-23 13:56:46[0m] Step: 125, Training Logs: loss_final: 3.764182, loss_mean: 1.127925, proj_loss: -0.057976, loss_mean_cls: 2.694233, grad_norm: 1.674507
|
| 127 |
+
Steps: 0%| | 126/1000000 [00:35<67:47:01, 4.10it/s, grad_norm=1.67, loss_final=3.76, loss_mean=1.13, loss_mean_cls=2.69, proj_loss=-0.058][[34m2026-03-23 13:56:46[0m] Step: 126, Training Logs: loss_final: 4.371668, loss_mean: 1.097472, proj_loss: -0.055551, loss_mean_cls: 3.329747, grad_norm: 1.962917
|
| 128 |
+
Steps: 0%| | 127/1000000 [00:35<67:47:01, 4.10it/s, grad_norm=1.96, loss_final=4.37, loss_mean=1.1, loss_mean_cls=3.33, proj_loss=-0.0556][[34m2026-03-23 13:56:47[0m] Step: 127, Training Logs: loss_final: 4.417885, loss_mean: 1.097842, proj_loss: -0.057305, loss_mean_cls: 3.377348, grad_norm: 1.860904
|
| 129 |
+
Steps: 0%| | 128/1000000 [00:35<67:46:39, 4.10it/s, grad_norm=1.86, loss_final=4.42, loss_mean=1.1, loss_mean_cls=3.38, proj_loss=-0.0573][[34m2026-03-23 13:56:47[0m] Step: 128, Training Logs: loss_final: 4.017467, loss_mean: 1.103394, proj_loss: -0.057005, loss_mean_cls: 2.971077, grad_norm: 1.122871
|
| 130 |
+
Steps: 0%| | 129/1000000 [00:35<67:55:50, 4.09it/s, grad_norm=1.12, loss_final=4.02, loss_mean=1.1, loss_mean_cls=2.97, proj_loss=-0.057][[34m2026-03-23 13:56:47[0m] Step: 129, Training Logs: loss_final: 4.208836, loss_mean: 1.102054, proj_loss: -0.056320, loss_mean_cls: 3.163102, grad_norm: 1.851275
|
| 131 |
+
Steps: 0%| | 130/1000000 [00:36<67:52:57, 4.09it/s, grad_norm=1.85, loss_final=4.21, loss_mean=1.1, loss_mean_cls=3.16, proj_loss=-0.0563][[34m2026-03-23 13:56:47[0m] Step: 130, Training Logs: loss_final: 4.207572, loss_mean: 1.127353, proj_loss: -0.057998, loss_mean_cls: 3.138218, grad_norm: 1.432804
|
| 132 |
+
Steps: 0%| | 131/1000000 [00:36<67:51:54, 4.09it/s, grad_norm=1.43, loss_final=4.21, loss_mean=1.13, loss_mean_cls=3.14, proj_loss=-0.058][[34m2026-03-23 13:56:48[0m] Step: 131, Training Logs: loss_final: 4.999227, loss_mean: 1.065102, proj_loss: -0.057811, loss_mean_cls: 3.991935, grad_norm: 1.260865
|
| 133 |
+
Steps: 0%| | 132/1000000 [00:36<67:50:15, 4.09it/s, grad_norm=1.26, loss_final=5, loss_mean=1.07, loss_mean_cls=3.99, proj_loss=-0.0578][[34m2026-03-23 13:56:48[0m] Step: 132, Training Logs: loss_final: 4.958624, loss_mean: 1.080354, proj_loss: -0.055375, loss_mean_cls: 3.933645, grad_norm: 1.087305
|
| 134 |
+
Steps: 0%| | 133/1000000 [00:36<67:49:24, 4.10it/s, grad_norm=1.09, loss_final=4.96, loss_mean=1.08, loss_mean_cls=3.93, proj_loss=-0.0554][[34m2026-03-23 13:56:48[0m] Step: 133, Training Logs: loss_final: 4.336567, loss_mean: 1.074137, proj_loss: -0.057906, loss_mean_cls: 3.320336, grad_norm: 0.845532
|
| 135 |
+
Steps: 0%| | 134/1000000 [00:36<67:49:27, 4.09it/s, grad_norm=0.846, loss_final=4.34, loss_mean=1.07, loss_mean_cls=3.32, proj_loss=-0.0579][[34m2026-03-23 13:56:48[0m] Step: 134, Training Logs: loss_final: 4.113417, loss_mean: 1.101781, proj_loss: -0.055828, loss_mean_cls: 3.067464, grad_norm: 0.853130
|
| 136 |
+
Steps: 0%| | 135/1000000 [00:37<67:48:26, 4.10it/s, grad_norm=0.853, loss_final=4.11, loss_mean=1.1, loss_mean_cls=3.07, proj_loss=-0.0558][[34m2026-03-23 13:56:48[0m] Step: 135, Training Logs: loss_final: 4.451187, loss_mean: 1.071875, proj_loss: -0.057522, loss_mean_cls: 3.436834, grad_norm: 1.211643
|
| 137 |
+
Steps: 0%| | 136/1000000 [00:37<67:48:00, 4.10it/s, grad_norm=1.21, loss_final=4.45, loss_mean=1.07, loss_mean_cls=3.44, proj_loss=-0.0575][[34m2026-03-23 13:56:49[0m] Step: 136, Training Logs: loss_final: 3.470972, loss_mean: 1.105708, proj_loss: -0.057636, loss_mean_cls: 2.422900, grad_norm: 1.028398
|
| 138 |
+
Steps: 0%| | 137/1000000 [00:37<67:48:46, 4.10it/s, grad_norm=1.03, loss_final=3.47, loss_mean=1.11, loss_mean_cls=2.42, proj_loss=-0.0576][[34m2026-03-23 13:56:49[0m] Step: 137, Training Logs: loss_final: 4.167720, loss_mean: 1.116012, proj_loss: -0.056297, loss_mean_cls: 3.108005, grad_norm: 1.331365
|
| 139 |
+
Steps: 0%| | 138/1000000 [00:37<67:49:54, 4.09it/s, grad_norm=1.33, loss_final=4.17, loss_mean=1.12, loss_mean_cls=3.11, proj_loss=-0.0563][[34m2026-03-23 13:56:49[0m] Step: 138, Training Logs: loss_final: 3.969088, loss_mean: 1.080545, proj_loss: -0.058032, loss_mean_cls: 2.946575, grad_norm: 1.159432
|
| 140 |
+
Steps: 0%| | 139/1000000 [00:38<67:49:15, 4.10it/s, grad_norm=1.16, loss_final=3.97, loss_mean=1.08, loss_mean_cls=2.95, proj_loss=-0.058][[34m2026-03-23 13:56:49[0m] Step: 139, Training Logs: loss_final: 4.553603, loss_mean: 1.100756, proj_loss: -0.057699, loss_mean_cls: 3.510547, grad_norm: 1.379237
|
| 141 |
+
Steps: 0%| | 140/1000000 [00:38<67:48:29, 4.10it/s, grad_norm=1.38, loss_final=4.55, loss_mean=1.1, loss_mean_cls=3.51, proj_loss=-0.0577][[34m2026-03-23 13:56:50[0m] Step: 140, Training Logs: loss_final: 3.823996, loss_mean: 1.091382, proj_loss: -0.055826, loss_mean_cls: 2.788440, grad_norm: 1.269147
|
| 142 |
+
Steps: 0%| | 141/1000000 [00:38<67:49:20, 4.10it/s, grad_norm=1.27, loss_final=3.82, loss_mean=1.09, loss_mean_cls=2.79, proj_loss=-0.0558][[34m2026-03-23 13:56:50[0m] Step: 141, Training Logs: loss_final: 4.244388, loss_mean: 1.086618, proj_loss: -0.060178, loss_mean_cls: 3.217947, grad_norm: 1.224737
|
| 143 |
+
Steps: 0%| | 142/1000000 [00:38<67:48:47, 4.10it/s, grad_norm=1.22, loss_final=4.24, loss_mean=1.09, loss_mean_cls=3.22, proj_loss=-0.0602][[34m2026-03-23 13:56:50[0m] Step: 142, Training Logs: loss_final: 4.555122, loss_mean: 1.074355, proj_loss: -0.056447, loss_mean_cls: 3.537215, grad_norm: 1.288044
|
| 144 |
+
Steps: 0%| | 143/1000000 [00:39<67:49:15, 4.10it/s, grad_norm=1.29, loss_final=4.56, loss_mean=1.07, loss_mean_cls=3.54, proj_loss=-0.0564][[34m2026-03-23 13:56:50[0m] Step: 143, Training Logs: loss_final: 4.631051, loss_mean: 1.031162, proj_loss: -0.056307, loss_mean_cls: 3.656196, grad_norm: 0.875154
|
| 145 |
+
Steps: 0%| | 144/1000000 [00:39<67:47:56, 4.10it/s, grad_norm=0.875, loss_final=4.63, loss_mean=1.03, loss_mean_cls=3.66, proj_loss=-0.0563][[34m2026-03-23 13:56:51[0m] Step: 144, Training Logs: loss_final: 4.039268, loss_mean: 1.089356, proj_loss: -0.056316, loss_mean_cls: 3.006228, grad_norm: 1.155283
|
| 146 |
+
Steps: 0%| | 145/1000000 [00:39<67:55:59, 4.09it/s, grad_norm=1.16, loss_final=4.04, loss_mean=1.09, loss_mean_cls=3.01, proj_loss=-0.0563][[34m2026-03-23 13:56:51[0m] Step: 145, Training Logs: loss_final: 3.760608, loss_mean: 1.070431, proj_loss: -0.059299, loss_mean_cls: 2.749476, grad_norm: 1.032837
|
| 147 |
+
Steps: 0%| | 146/1000000 [00:39<67:53:40, 4.09it/s, grad_norm=1.03, loss_final=3.76, loss_mean=1.07, loss_mean_cls=2.75, proj_loss=-0.0593][[34m2026-03-23 13:56:51[0m] Step: 146, Training Logs: loss_final: 4.181468, loss_mean: 1.069318, proj_loss: -0.057231, loss_mean_cls: 3.169381, grad_norm: 1.739599
|
| 148 |
+
Steps: 0%| | 147/1000000 [00:40<87:04:35, 3.19it/s, grad_norm=1.74, loss_final=4.18, loss_mean=1.07, loss_mean_cls=3.17, proj_loss=-0.0572][[34m2026-03-23 13:56:52[0m] Step: 147, Training Logs: loss_final: 4.397898, loss_mean: 1.077711, proj_loss: -0.057317, loss_mean_cls: 3.377505, grad_norm: 1.324051
|
| 149 |
+
Steps: 0%| | 148/1000000 [00:40<81:18:29, 3.42it/s, grad_norm=1.32, loss_final=4.4, loss_mean=1.08, loss_mean_cls=3.38, proj_loss=-0.0573][[34m2026-03-23 13:56:52[0m] Step: 148, Training Logs: loss_final: 3.678877, loss_mean: 1.095331, proj_loss: -0.056643, loss_mean_cls: 2.640189, grad_norm: 0.957576
|
| 150 |
+
Steps: 0%| | 149/1000000 [00:40<77:21:07, 3.59it/s, grad_norm=0.958, loss_final=3.68, loss_mean=1.1, loss_mean_cls=2.64, proj_loss=-0.0566][[34m2026-03-23 13:56:52[0m] Step: 149, Training Logs: loss_final: 3.970488, loss_mean: 1.096237, proj_loss: -0.057982, loss_mean_cls: 2.932232, grad_norm: 0.751983
|
| 151 |
+
Steps: 0%| | 150/1000000 [00:41<74:28:13, 3.73it/s, grad_norm=0.752, loss_final=3.97, loss_mean=1.1, loss_mean_cls=2.93, proj_loss=-0.058][[34m2026-03-23 13:56:52[0m] Step: 150, Training Logs: loss_final: 3.589296, loss_mean: 1.085876, proj_loss: -0.059221, loss_mean_cls: 2.562641, grad_norm: 1.001571
|
| 152 |
+
Steps: 0%| | 151/1000000 [00:41<72:29:17, 3.83it/s, grad_norm=1, loss_final=3.59, loss_mean=1.09, loss_mean_cls=2.56, proj_loss=-0.0592][[34m2026-03-23 13:56:53[0m] Step: 151, Training Logs: loss_final: 3.809782, loss_mean: 1.041376, proj_loss: -0.057593, loss_mean_cls: 2.825998, grad_norm: 0.772958
|
| 153 |
+
Steps: 0%| | 152/1000000 [00:41<71:06:45, 3.91it/s, grad_norm=0.773, loss_final=3.81, loss_mean=1.04, loss_mean_cls=2.83, proj_loss=-0.0576][[34m2026-03-23 13:56:53[0m] Step: 152, Training Logs: loss_final: 3.756772, loss_mean: 1.077754, proj_loss: -0.055612, loss_mean_cls: 2.734629, grad_norm: 0.942414
|
| 154 |
+
Steps: 0%| | 153/1000000 [00:41<70:15:13, 3.95it/s, grad_norm=0.942, loss_final=3.76, loss_mean=1.08, loss_mean_cls=2.73, proj_loss=-0.0556][[34m2026-03-23 13:56:53[0m] Step: 153, Training Logs: loss_final: 3.815321, loss_mean: 1.091242, proj_loss: -0.055186, loss_mean_cls: 2.779264, grad_norm: 1.427716
|
| 155 |
+
Steps: 0%| | 154/1000000 [00:42<69:30:56, 4.00it/s, grad_norm=1.43, loss_final=3.82, loss_mean=1.09, loss_mean_cls=2.78, proj_loss=-0.0552][[34m2026-03-23 13:56:53[0m] Step: 154, Training Logs: loss_final: 4.484177, loss_mean: 1.070384, proj_loss: -0.056569, loss_mean_cls: 3.470362, grad_norm: 0.890748
|
| 156 |
+
Steps: 0%| | 155/1000000 [00:42<68:58:59, 4.03it/s, grad_norm=0.891, loss_final=4.48, loss_mean=1.07, loss_mean_cls=3.47, proj_loss=-0.0566][[34m2026-03-23 13:56:54[0m] Step: 155, Training Logs: loss_final: 3.992099, loss_mean: 1.094668, proj_loss: -0.056406, loss_mean_cls: 2.953837, grad_norm: 1.103738
|
| 157 |
+
Steps: 0%| | 156/1000000 [00:42<68:38:04, 4.05it/s, grad_norm=1.1, loss_final=3.99, loss_mean=1.09, loss_mean_cls=2.95, proj_loss=-0.0564][[34m2026-03-23 13:56:54[0m] Step: 156, Training Logs: loss_final: 4.835541, loss_mean: 1.066255, proj_loss: -0.057052, loss_mean_cls: 3.826337, grad_norm: 1.572935
|
| 158 |
+
Steps: 0%| | 157/1000000 [00:42<68:32:39, 4.05it/s, grad_norm=1.57, loss_final=4.84, loss_mean=1.07, loss_mean_cls=3.83, proj_loss=-0.0571][[34m2026-03-23 13:56:54[0m] Step: 157, Training Logs: loss_final: 4.422526, loss_mean: 1.037669, proj_loss: -0.057450, loss_mean_cls: 3.442307, grad_norm: 0.881763
|
| 159 |
+
Steps: 0%| | 158/1000000 [00:43<68:18:41, 4.07it/s, grad_norm=0.882, loss_final=4.42, loss_mean=1.04, loss_mean_cls=3.44, proj_loss=-0.0574][[34m2026-03-23 13:56:54[0m] Step: 158, Training Logs: loss_final: 3.933815, loss_mean: 1.094053, proj_loss: -0.058011, loss_mean_cls: 2.897773, grad_norm: 1.524001
|
| 160 |
+
Steps: 0%| | 159/1000000 [00:43<68:08:39, 4.08it/s, grad_norm=1.52, loss_final=3.93, loss_mean=1.09, loss_mean_cls=2.9, proj_loss=-0.058][[34m2026-03-23 13:56:55[0m] Step: 159, Training Logs: loss_final: 3.294493, loss_mean: 1.101539, proj_loss: -0.057312, loss_mean_cls: 2.250265, grad_norm: 0.862417
|
| 161 |
+
Steps: 0%| | 160/1000000 [00:43<68:00:56, 4.08it/s, grad_norm=0.862, loss_final=3.29, loss_mean=1.1, loss_mean_cls=2.25, proj_loss=-0.0573][[34m2026-03-23 13:56:55[0m] Step: 160, Training Logs: loss_final: 4.003653, loss_mean: 1.049939, proj_loss: -0.056768, loss_mean_cls: 3.010482, grad_norm: 1.406133
|
| 162 |
+
Steps: 0%| | 161/1000000 [00:43<68:05:23, 4.08it/s, grad_norm=1.41, loss_final=4, loss_mean=1.05, loss_mean_cls=3.01, proj_loss=-0.0568][[34m2026-03-23 13:56:55[0m] Step: 161, Training Logs: loss_final: 3.735500, loss_mean: 1.059312, proj_loss: -0.055975, loss_mean_cls: 2.732163, grad_norm: 1.257030
|
| 163 |
+
Steps: 0%| | 162/1000000 [00:44<67:59:39, 4.08it/s, grad_norm=1.26, loss_final=3.74, loss_mean=1.06, loss_mean_cls=2.73, proj_loss=-0.056][[34m2026-03-23 13:56:55[0m] Step: 162, Training Logs: loss_final: 4.134730, loss_mean: 1.051557, proj_loss: -0.055925, loss_mean_cls: 3.139099, grad_norm: 1.383831
|
| 164 |
+
Steps: 0%| | 163/1000000 [00:44<67:54:57, 4.09it/s, grad_norm=1.38, loss_final=4.13, loss_mean=1.05, loss_mean_cls=3.14, proj_loss=-0.0559][[34m2026-03-23 13:56:56[0m] Step: 163, Training Logs: loss_final: 3.976990, loss_mean: 1.080228, proj_loss: -0.055570, loss_mean_cls: 2.952332, grad_norm: 2.321836
|
| 165 |
+
Steps: 0%| | 164/1000000 [00:44<67:52:28, 4.09it/s, grad_norm=2.32, loss_final=3.98, loss_mean=1.08, loss_mean_cls=2.95, proj_loss=-0.0556][[34m2026-03-23 13:56:56[0m] Step: 164, Training Logs: loss_final: 4.920438, loss_mean: 1.053734, proj_loss: -0.057831, loss_mean_cls: 3.924535, grad_norm: 1.427572
|
| 166 |
+
Steps: 0%| | 165/1000000 [00:44<67:51:20, 4.09it/s, grad_norm=1.43, loss_final=4.92, loss_mean=1.05, loss_mean_cls=3.92, proj_loss=-0.0578][[34m2026-03-23 13:56:56[0m] Step: 165, Training Logs: loss_final: 4.075354, loss_mean: 1.102409, proj_loss: -0.055738, loss_mean_cls: 3.028682, grad_norm: 1.554521
|
| 167 |
+
Steps: 0%| | 166/1000000 [00:45<67:50:22, 4.09it/s, grad_norm=1.55, loss_final=4.08, loss_mean=1.1, loss_mean_cls=3.03, proj_loss=-0.0557][[34m2026-03-23 13:56:56[0m] Step: 166, Training Logs: loss_final: 3.970260, loss_mean: 1.052696, proj_loss: -0.054281, loss_mean_cls: 2.971844, grad_norm: 1.398542
|
| 168 |
+
Steps: 0%| | 167/1000000 [00:45<67:49:03, 4.10it/s, grad_norm=1.4, loss_final=3.97, loss_mean=1.05, loss_mean_cls=2.97, proj_loss=-0.0543][[34m2026-03-23 13:56:57[0m] Step: 167, Training Logs: loss_final: 4.506516, loss_mean: 1.058694, proj_loss: -0.055691, loss_mean_cls: 3.503513, grad_norm: 1.465548
|
| 169 |
+
Steps: 0%| | 168/1000000 [00:45<67:48:15, 4.10it/s, grad_norm=1.47, loss_final=4.51, loss_mean=1.06, loss_mean_cls=3.5, proj_loss=-0.0557][[34m2026-03-23 13:56:57[0m] Step: 168, Training Logs: loss_final: 3.774779, loss_mean: 1.070169, proj_loss: -0.056798, loss_mean_cls: 2.761408, grad_norm: 1.063631
|
| 170 |
+
Steps: 0%| | 169/1000000 [00:45<67:55:41, 4.09it/s, grad_norm=1.06, loss_final=3.77, loss_mean=1.07, loss_mean_cls=2.76, proj_loss=-0.0568][[34m2026-03-23 13:56:57[0m] Step: 169, Training Logs: loss_final: 3.799341, loss_mean: 1.072764, proj_loss: -0.057054, loss_mean_cls: 2.783631, grad_norm: 1.432508
|
| 171 |
+
Steps: 0%| | 170/1000000 [00:46<67:53:00, 4.09it/s, grad_norm=1.43, loss_final=3.8, loss_mean=1.07, loss_mean_cls=2.78, proj_loss=-0.0571][[34m2026-03-23 13:56:57[0m] Step: 170, Training Logs: loss_final: 3.950215, loss_mean: 1.050979, proj_loss: -0.058352, loss_mean_cls: 2.957588, grad_norm: 1.403183
|
| 172 |
+
Steps: 0%| | 171/1000000 [00:46<67:51:12, 4.09it/s, grad_norm=1.4, loss_final=3.95, loss_mean=1.05, loss_mean_cls=2.96, proj_loss=-0.0584][[34m2026-03-23 13:56:58[0m] Step: 171, Training Logs: loss_final: 4.198855, loss_mean: 1.049812, proj_loss: -0.056677, loss_mean_cls: 3.205721, grad_norm: 0.921758
|
| 173 |
+
Steps: 0%| | 172/1000000 [00:46<67:49:22, 4.09it/s, grad_norm=0.922, loss_final=4.2, loss_mean=1.05, loss_mean_cls=3.21, proj_loss=-0.0567][[34m2026-03-23 13:56:58[0m] Step: 172, Training Logs: loss_final: 3.441348, loss_mean: 1.076280, proj_loss: -0.057867, loss_mean_cls: 2.422936, grad_norm: 0.854744
|
| 174 |
+
Steps: 0%| | 173/1000000 [00:46<67:50:30, 4.09it/s, grad_norm=0.855, loss_final=3.44, loss_mean=1.08, loss_mean_cls=2.42, proj_loss=-0.0579][[34m2026-03-23 13:56:58[0m] Step: 173, Training Logs: loss_final: 3.798677, loss_mean: 1.085783, proj_loss: -0.057882, loss_mean_cls: 2.770777, grad_norm: 1.453923
|
| 175 |
+
Steps: 0%| | 174/1000000 [00:46<67:50:14, 4.09it/s, grad_norm=1.45, loss_final=3.8, loss_mean=1.09, loss_mean_cls=2.77, proj_loss=-0.0579][[34m2026-03-23 13:56:58[0m] Step: 174, Training Logs: loss_final: 3.112827, loss_mean: 1.064432, proj_loss: -0.058472, loss_mean_cls: 2.106868, grad_norm: 1.250231
|
| 176 |
+
Steps: 0%| | 175/1000000 [00:47<67:49:59, 4.09it/s, grad_norm=1.25, loss_final=3.11, loss_mean=1.06, loss_mean_cls=2.11, proj_loss=-0.0585][[34m2026-03-23 13:56:58[0m] Step: 175, Training Logs: loss_final: 3.708531, loss_mean: 1.080565, proj_loss: -0.055104, loss_mean_cls: 2.683070, grad_norm: 1.818153
|
| 177 |
+
Steps: 0%| | 176/1000000 [00:47<67:49:26, 4.09it/s, grad_norm=1.82, loss_final=3.71, loss_mean=1.08, loss_mean_cls=2.68, proj_loss=-0.0551][[34m2026-03-23 13:56:59[0m] Step: 176, Training Logs: loss_final: 4.057882, loss_mean: 1.041325, proj_loss: -0.057044, loss_mean_cls: 3.073601, grad_norm: 1.056386
|
| 178 |
+
Steps: 0%| | 177/1000000 [00:47<67:50:14, 4.09it/s, grad_norm=1.06, loss_final=4.06, loss_mean=1.04, loss_mean_cls=3.07, proj_loss=-0.057][[34m2026-03-23 13:56:59[0m] Step: 177, Training Logs: loss_final: 3.843677, loss_mean: 1.082743, proj_loss: -0.056343, loss_mean_cls: 2.817278, grad_norm: 1.454900
|
| 179 |
+
Steps: 0%| | 178/1000000 [00:47<67:50:45, 4.09it/s, grad_norm=1.45, loss_final=3.84, loss_mean=1.08, loss_mean_cls=2.82, proj_loss=-0.0563][[34m2026-03-23 13:56:59[0m] Step: 178, Training Logs: loss_final: 4.592722, loss_mean: 1.051988, proj_loss: -0.056379, loss_mean_cls: 3.597114, grad_norm: 1.553295
|
| 180 |
+
Steps: 0%| | 179/1000000 [00:48<67:51:26, 4.09it/s, grad_norm=1.55, loss_final=4.59, loss_mean=1.05, loss_mean_cls=3.6, proj_loss=-0.0564][[34m2026-03-23 13:56:59[0m] Step: 179, Training Logs: loss_final: 4.083874, loss_mean: 1.044562, proj_loss: -0.057760, loss_mean_cls: 3.097073, grad_norm: 1.207282
|
| 181 |
+
Steps: 0%| | 180/1000000 [00:48<67:50:22, 4.09it/s, grad_norm=1.21, loss_final=4.08, loss_mean=1.04, loss_mean_cls=3.1, proj_loss=-0.0578][[34m2026-03-23 13:57:00[0m] Step: 180, Training Logs: loss_final: 3.533881, loss_mean: 1.084617, proj_loss: -0.055471, loss_mean_cls: 2.504735, grad_norm: 1.384633
|
| 182 |
+
Steps: 0%| | 181/1000000 [00:48<67:49:31, 4.09it/s, grad_norm=1.38, loss_final=3.53, loss_mean=1.08, loss_mean_cls=2.5, proj_loss=-0.0555][[34m2026-03-23 13:57:00[0m] Step: 181, Training Logs: loss_final: 4.409495, loss_mean: 1.036472, proj_loss: -0.057444, loss_mean_cls: 3.430467, grad_norm: 1.297928
|
| 183 |
+
Steps: 0%| | 182/1000000 [00:48<67:48:57, 4.10it/s, grad_norm=1.3, loss_final=4.41, loss_mean=1.04, loss_mean_cls=3.43, proj_loss=-0.0574][[34m2026-03-23 13:57:00[0m] Step: 182, Training Logs: loss_final: 3.176256, loss_mean: 1.091783, proj_loss: -0.057242, loss_mean_cls: 2.141715, grad_norm: 1.965931
|
| 184 |
+
Steps: 0%| | 183/1000000 [00:49<67:49:28, 4.09it/s, grad_norm=1.97, loss_final=3.18, loss_mean=1.09, loss_mean_cls=2.14, proj_loss=-0.0572][[34m2026-03-23 13:57:00[0m] Step: 183, Training Logs: loss_final: 3.651900, loss_mean: 1.063701, proj_loss: -0.057272, loss_mean_cls: 2.645471, grad_norm: 1.260781
|
| 185 |
+
Steps: 0%| | 184/1000000 [00:49<67:49:34, 4.09it/s, grad_norm=1.26, loss_final=3.65, loss_mean=1.06, loss_mean_cls=2.65, proj_loss=-0.0573][[34m2026-03-23 13:57:01[0m] Step: 184, Training Logs: loss_final: 4.134811, loss_mean: 1.052191, proj_loss: -0.056823, loss_mean_cls: 3.139443, grad_norm: 1.083165
|
| 186 |
+
Steps: 0%| | 185/1000000 [00:49<67:54:27, 4.09it/s, grad_norm=1.08, loss_final=4.13, loss_mean=1.05, loss_mean_cls=3.14, proj_loss=-0.0568][[34m2026-03-23 13:57:01[0m] Step: 185, Training Logs: loss_final: 3.952762, loss_mean: 1.034234, proj_loss: -0.056918, loss_mean_cls: 2.975446, grad_norm: 1.370537
|
| 187 |
+
Steps: 0%| | 186/1000000 [00:49<67:53:04, 4.09it/s, grad_norm=1.37, loss_final=3.95, loss_mean=1.03, loss_mean_cls=2.98, proj_loss=-0.0569][[34m2026-03-23 13:57:01[0m] Step: 186, Training Logs: loss_final: 4.377052, loss_mean: 1.068180, proj_loss: -0.055129, loss_mean_cls: 3.364000, grad_norm: 2.047776
|
| 188 |
+
Steps: 0%| | 187/1000000 [00:50<67:51:19, 4.09it/s, grad_norm=2.05, loss_final=4.38, loss_mean=1.07, loss_mean_cls=3.36, proj_loss=-0.0551][[34m2026-03-23 13:57:01[0m] Step: 187, Training Logs: loss_final: 3.465180, loss_mean: 1.067420, proj_loss: -0.055022, loss_mean_cls: 2.452782, grad_norm: 1.094218
|
| 189 |
+
Steps: 0%| | 188/1000000 [00:50<67:50:38, 4.09it/s, grad_norm=1.09, loss_final=3.47, loss_mean=1.07, loss_mean_cls=2.45, proj_loss=-0.055][[34m2026-03-23 13:57:02[0m] Step: 188, Training Logs: loss_final: 4.428637, loss_mean: 1.075879, proj_loss: -0.056132, loss_mean_cls: 3.408890, grad_norm: 2.243081
|
| 190 |
+
Steps: 0%| | 189/1000000 [00:50<67:50:39, 4.09it/s, grad_norm=2.24, loss_final=4.43, loss_mean=1.08, loss_mean_cls=3.41, proj_loss=-0.0561][[34m2026-03-23 13:57:02[0m] Step: 189, Training Logs: loss_final: 3.874612, loss_mean: 1.059214, proj_loss: -0.055212, loss_mean_cls: 2.870610, grad_norm: 1.509933
|
| 191 |
+
Steps: 0%| | 190/1000000 [00:50<67:53:49, 4.09it/s, grad_norm=1.51, loss_final=3.87, loss_mean=1.06, loss_mean_cls=2.87, proj_loss=-0.0552][[34m2026-03-23 13:57:02[0m] Step: 190, Training Logs: loss_final: 3.696738, loss_mean: 1.066510, proj_loss: -0.058307, loss_mean_cls: 2.688535, grad_norm: 1.849224
|
| 192 |
+
Steps: 0%| | 191/1000000 [00:51<67:54:49, 4.09it/s, grad_norm=1.85, loss_final=3.7, loss_mean=1.07, loss_mean_cls=2.69, proj_loss=-0.0583][[34m2026-03-23 13:57:02[0m] Step: 191, Training Logs: loss_final: 4.354342, loss_mean: 1.020227, proj_loss: -0.055745, loss_mean_cls: 3.389860, grad_norm: 1.847939
|
| 193 |
+
Steps: 0%| | 192/1000000 [00:51<67:54:25, 4.09it/s, grad_norm=1.85, loss_final=4.35, loss_mean=1.02, loss_mean_cls=3.39, proj_loss=-0.0557][[34m2026-03-23 13:57:03[0m] Step: 192, Training Logs: loss_final: 4.409212, loss_mean: 1.048026, proj_loss: -0.056594, loss_mean_cls: 3.417780, grad_norm: 1.524102
|
| 194 |
+
Steps: 0%| | 193/1000000 [00:51<67:53:07, 4.09it/s, grad_norm=1.52, loss_final=4.41, loss_mean=1.05, loss_mean_cls=3.42, proj_loss=-0.0566][[34m2026-03-23 13:57:03[0m] Step: 193, Training Logs: loss_final: 3.909914, loss_mean: 1.034270, proj_loss: -0.058143, loss_mean_cls: 2.933788, grad_norm: 1.197792
|
| 195 |
+
Steps: 0%| | 194/1000000 [00:51<67:53:28, 4.09it/s, grad_norm=1.2, loss_final=3.91, loss_mean=1.03, loss_mean_cls=2.93, proj_loss=-0.0581][[34m2026-03-23 13:57:03[0m] Step: 194, Training Logs: loss_final: 3.518339, loss_mean: 1.061185, proj_loss: -0.053464, loss_mean_cls: 2.510617, grad_norm: 1.685875
|
| 196 |
+
Steps: 0%| | 195/1000000 [00:52<67:52:30, 4.09it/s, grad_norm=1.69, loss_final=3.52, loss_mean=1.06, loss_mean_cls=2.51, proj_loss=-0.0535][[34m2026-03-23 13:57:03[0m] Step: 195, Training Logs: loss_final: 4.144682, loss_mean: 1.069985, proj_loss: -0.056024, loss_mean_cls: 3.130722, grad_norm: 1.617648
|
| 197 |
+
Steps: 0%| | 196/1000000 [00:52<67:53:25, 4.09it/s, grad_norm=1.62, loss_final=4.14, loss_mean=1.07, loss_mean_cls=3.13, proj_loss=-0.056][[34m2026-03-23 13:57:04[0m] Step: 196, Training Logs: loss_final: 4.209203, loss_mean: 1.043093, proj_loss: -0.056968, loss_mean_cls: 3.223077, grad_norm: 1.561610
|
| 198 |
+
Steps: 0%| | 197/1000000 [00:52<67:51:41, 4.09it/s, grad_norm=1.56, loss_final=4.21, loss_mean=1.04, loss_mean_cls=3.22, proj_loss=-0.057][[34m2026-03-23 13:57:04[0m] Step: 197, Training Logs: loss_final: 3.986165, loss_mean: 1.048607, proj_loss: -0.058094, loss_mean_cls: 2.995652, grad_norm: 1.465502
|
| 199 |
+
Steps: 0%| | 198/1000000 [00:52<67:51:10, 4.09it/s, grad_norm=1.47, loss_final=3.99, loss_mean=1.05, loss_mean_cls=3, proj_loss=-0.0581][[34m2026-03-23 13:57:04[0m] Step: 198, Training Logs: loss_final: 4.109353, loss_mean: 1.057425, proj_loss: -0.057203, loss_mean_cls: 3.109131, grad_norm: 1.602210
|
| 200 |
+
Steps: 0%| | 199/1000000 [00:53<67:51:00, 4.09it/s, grad_norm=1.6, loss_final=4.11, loss_mean=1.06, loss_mean_cls=3.11, proj_loss=-0.0572][[34m2026-03-23 13:57:04[0m] Step: 199, Training Logs: loss_final: 3.827809, loss_mean: 1.060561, proj_loss: -0.056529, loss_mean_cls: 2.823776, grad_norm: 1.281930
|
| 201 |
+
Steps: 0%| | 200/1000000 [00:53<67:52:45, 4.09it/s, grad_norm=1.28, loss_final=3.83, loss_mean=1.06, loss_mean_cls=2.82, proj_loss=-0.0565][[34m2026-03-23 13:57:05[0m] Step: 200, Training Logs: loss_final: 4.859810, loss_mean: 1.018786, proj_loss: -0.056772, loss_mean_cls: 3.897796, grad_norm: 1.346983
|
| 202 |
+
Steps: 0%| | 201/1000000 [00:53<67:50:08, 4.09it/s, grad_norm=1.35, loss_final=4.86, loss_mean=1.02, loss_mean_cls=3.9, proj_loss=-0.0568][[34m2026-03-23 13:57:05[0m] Step: 201, Training Logs: loss_final: 3.945022, loss_mean: 1.039917, proj_loss: -0.058203, loss_mean_cls: 2.963308, grad_norm: 1.138654
|
| 203 |
+
Steps: 0%| | 202/1000000 [00:53<67:49:17, 4.09it/s, grad_norm=1.14, loss_final=3.95, loss_mean=1.04, loss_mean_cls=2.96, proj_loss=-0.0582][[34m2026-03-23 13:57:05[0m] Step: 202, Training Logs: loss_final: 3.795771, loss_mean: 1.045783, proj_loss: -0.058287, loss_mean_cls: 2.808275, grad_norm: 1.378415
|
| 204 |
+
Steps: 0%| | 203/1000000 [00:54<67:47:45, 4.10it/s, grad_norm=1.38, loss_final=3.8, loss_mean=1.05, loss_mean_cls=2.81, proj_loss=-0.0583][[34m2026-03-23 13:57:05[0m] Step: 203, Training Logs: loss_final: 3.873578, loss_mean: 1.057585, proj_loss: -0.057143, loss_mean_cls: 2.873136, grad_norm: 1.039413
|
| 205 |
+
Steps: 0%| | 204/1000000 [00:54<67:47:23, 4.10it/s, grad_norm=1.04, loss_final=3.87, loss_mean=1.06, loss_mean_cls=2.87, proj_loss=-0.0571][[34m2026-03-23 13:57:06[0m] Step: 204, Training Logs: loss_final: 3.869634, loss_mean: 1.025371, proj_loss: -0.059635, loss_mean_cls: 2.903898, grad_norm: 1.113446
|
| 206 |
+
Steps: 0%| | 205/1000000 [00:54<67:46:28, 4.10it/s, grad_norm=1.11, loss_final=3.87, loss_mean=1.03, loss_mean_cls=2.9, proj_loss=-0.0596][[34m2026-03-23 13:57:06[0m] Step: 205, Training Logs: loss_final: 4.516129, loss_mean: 1.026917, proj_loss: -0.059113, loss_mean_cls: 3.548324, grad_norm: 1.246919
|
| 207 |
+
Steps: 0%| | 206/1000000 [00:54<67:46:58, 4.10it/s, grad_norm=1.25, loss_final=4.52, loss_mean=1.03, loss_mean_cls=3.55, proj_loss=-0.0591][[34m2026-03-23 13:57:06[0m] Step: 206, Training Logs: loss_final: 3.870474, loss_mean: 1.033348, proj_loss: -0.057558, loss_mean_cls: 2.894683, grad_norm: 1.104687
|
| 208 |
+
Steps: 0%| | 207/1000000 [00:55<67:49:28, 4.09it/s, grad_norm=1.1, loss_final=3.87, loss_mean=1.03, loss_mean_cls=2.89, proj_loss=-0.0576][[34m2026-03-23 13:57:06[0m] Step: 207, Training Logs: loss_final: 3.657545, loss_mean: 1.036661, proj_loss: -0.054751, loss_mean_cls: 2.675634, grad_norm: 1.384344
|
| 209 |
+
Steps: 0%| | 208/1000000 [00:55<67:48:34, 4.10it/s, grad_norm=1.38, loss_final=3.66, loss_mean=1.04, loss_mean_cls=2.68, proj_loss=-0.0548][[34m2026-03-23 13:57:07[0m] Step: 208, Training Logs: loss_final: 3.776420, loss_mean: 1.025812, proj_loss: -0.057821, loss_mean_cls: 2.808429, grad_norm: 1.096370
|
| 210 |
+
Steps: 0%| | 209/1000000 [00:55<67:48:19, 4.10it/s, grad_norm=1.1, loss_final=3.78, loss_mean=1.03, loss_mean_cls=2.81, proj_loss=-0.0578][[34m2026-03-23 13:57:07[0m] Step: 209, Training Logs: loss_final: 3.761090, loss_mean: 1.062359, proj_loss: -0.056922, loss_mean_cls: 2.755653, grad_norm: 1.903223
|
| 211 |
+
Steps: 0%| | 210/1000000 [00:55<67:48:31, 4.10it/s, grad_norm=1.9, loss_final=3.76, loss_mean=1.06, loss_mean_cls=2.76, proj_loss=-0.0569][[34m2026-03-23 13:57:07[0m] Step: 210, Training Logs: loss_final: 3.906418, loss_mean: 1.041829, proj_loss: -0.057181, loss_mean_cls: 2.921770, grad_norm: 1.735096
|
| 212 |
+
Steps: 0%| | 211/1000000 [00:56<67:48:32, 4.10it/s, grad_norm=1.74, loss_final=3.91, loss_mean=1.04, loss_mean_cls=2.92, proj_loss=-0.0572][[34m2026-03-23 13:57:07[0m] Step: 211, Training Logs: loss_final: 3.858996, loss_mean: 1.023147, proj_loss: -0.059874, loss_mean_cls: 2.895722, grad_norm: 1.423884
|
| 213 |
+
Steps: 0%| | 212/1000000 [00:56<67:46:51, 4.10it/s, grad_norm=1.42, loss_final=3.86, loss_mean=1.02, loss_mean_cls=2.9, proj_loss=-0.0599][[34m2026-03-23 13:57:08[0m] Step: 212, Training Logs: loss_final: 4.049764, loss_mean: 1.020845, proj_loss: -0.057857, loss_mean_cls: 3.086776, grad_norm: 1.539841
|
| 214 |
+
Steps: 0%| | 213/1000000 [00:56<67:48:03, 4.10it/s, grad_norm=1.54, loss_final=4.05, loss_mean=1.02, loss_mean_cls=3.09, proj_loss=-0.0579][[34m2026-03-23 13:57:08[0m] Step: 213, Training Logs: loss_final: 3.666535, loss_mean: 1.064638, proj_loss: -0.058414, loss_mean_cls: 2.660310, grad_norm: 1.364981
|
| 215 |
+
Steps: 0%| | 214/1000000 [00:56<67:48:11, 4.10it/s, grad_norm=1.36, loss_final=3.67, loss_mean=1.06, loss_mean_cls=2.66, proj_loss=-0.0584][[34m2026-03-23 13:57:08[0m] Step: 214, Training Logs: loss_final: 4.115675, loss_mean: 1.022230, proj_loss: -0.055882, loss_mean_cls: 3.149326, grad_norm: 1.587428
|
| 216 |
+
Steps: 0%| | 215/1000000 [00:57<67:46:32, 4.10it/s, grad_norm=1.59, loss_final=4.12, loss_mean=1.02, loss_mean_cls=3.15, proj_loss=-0.0559][[34m2026-03-23 13:57:08[0m] Step: 215, Training Logs: loss_final: 4.290797, loss_mean: 1.006664, proj_loss: -0.057540, loss_mean_cls: 3.341673, grad_norm: 1.364294
|
| 217 |
+
Steps: 0%| | 216/1000000 [00:57<67:48:39, 4.10it/s, grad_norm=1.36, loss_final=4.29, loss_mean=1.01, loss_mean_cls=3.34, proj_loss=-0.0575][[34m2026-03-23 13:57:09[0m] Step: 216, Training Logs: loss_final: 4.299802, loss_mean: 1.033919, proj_loss: -0.057627, loss_mean_cls: 3.323509, grad_norm: 2.076446
|
| 218 |
+
Steps: 0%| | 217/1000000 [00:57<67:48:50, 4.10it/s, grad_norm=2.08, loss_final=4.3, loss_mean=1.03, loss_mean_cls=3.32, proj_loss=-0.0576][[34m2026-03-23 13:57:09[0m] Step: 217, Training Logs: loss_final: 4.032390, loss_mean: 1.038509, proj_loss: -0.059780, loss_mean_cls: 3.053661, grad_norm: 2.095842
|
| 219 |
+
Steps: 0%| | 218/1000000 [00:57<67:48:13, 4.10it/s, grad_norm=2.1, loss_final=4.03, loss_mean=1.04, loss_mean_cls=3.05, proj_loss=-0.0598][[34m2026-03-23 13:57:09[0m] Step: 218, Training Logs: loss_final: 4.852066, loss_mean: 0.991051, proj_loss: -0.055897, loss_mean_cls: 3.916911, grad_norm: 1.589446
|
| 220 |
+
Steps: 0%| | 219/1000000 [00:57<67:54:34, 4.09it/s, grad_norm=1.59, loss_final=4.85, loss_mean=0.991, loss_mean_cls=3.92, proj_loss=-0.0559][[34m2026-03-23 13:57:09[0m] Step: 219, Training Logs: loss_final: 3.452114, loss_mean: 1.032779, proj_loss: -0.057448, loss_mean_cls: 2.476782, grad_norm: 2.759307
|
| 221 |
+
Steps: 0%| | 220/1000000 [00:58<67:53:33, 4.09it/s, grad_norm=2.76, loss_final=3.45, loss_mean=1.03, loss_mean_cls=2.48, proj_loss=-0.0574][[34m2026-03-23 13:57:09[0m] Step: 220, Training Logs: loss_final: 4.474224, loss_mean: 1.012484, proj_loss: -0.056469, loss_mean_cls: 3.518209, grad_norm: 2.037073
|
| 222 |
+
Steps: 0%| | 221/1000000 [00:58<67:58:08, 4.09it/s, grad_norm=2.04, loss_final=4.47, loss_mean=1.01, loss_mean_cls=3.52, proj_loss=-0.0565][[34m2026-03-23 13:57:10[0m] Step: 221, Training Logs: loss_final: 3.619599, loss_mean: 1.027089, proj_loss: -0.055216, loss_mean_cls: 2.647726, grad_norm: 1.876838
|
| 223 |
+
Steps: 0%| | 222/1000000 [00:58<67:55:14, 4.09it/s, grad_norm=1.88, loss_final=3.62, loss_mean=1.03, loss_mean_cls=2.65, proj_loss=-0.0552][[34m2026-03-23 13:57:10[0m] Step: 222, Training Logs: loss_final: 3.956278, loss_mean: 0.994196, proj_loss: -0.058793, loss_mean_cls: 3.020875, grad_norm: 1.756858
|
| 224 |
+
Steps: 0%| | 223/1000000 [00:58<67:57:51, 4.09it/s, grad_norm=1.76, loss_final=3.96, loss_mean=0.994, loss_mean_cls=3.02, proj_loss=-0.0588][[34m2026-03-23 13:57:10[0m] Step: 223, Training Logs: loss_final: 3.396563, loss_mean: 1.048192, proj_loss: -0.056162, loss_mean_cls: 2.404532, grad_norm: 1.566004
|
| 225 |
+
Steps: 0%| | 224/1000000 [00:59<67:54:18, 4.09it/s, grad_norm=1.57, loss_final=3.4, loss_mean=1.05, loss_mean_cls=2.4, proj_loss=-0.0562][[34m2026-03-23 13:57:10[0m] Step: 224, Training Logs: loss_final: 4.041858, loss_mean: 1.026601, proj_loss: -0.059143, loss_mean_cls: 3.074400, grad_norm: 1.464198
|
| 226 |
+
Steps: 0%| | 225/1000000 [00:59<67:52:38, 4.09it/s, grad_norm=1.46, loss_final=4.04, loss_mean=1.03, loss_mean_cls=3.07, proj_loss=-0.0591][[34m2026-03-23 13:57:11[0m] Step: 225, Training Logs: loss_final: 3.787225, loss_mean: 1.008417, proj_loss: -0.057760, loss_mean_cls: 2.836567, grad_norm: 1.825020
|
| 227 |
+
Steps: 0%| | 226/1000000 [00:59<67:52:45, 4.09it/s, grad_norm=1.83, loss_final=3.79, loss_mean=1.01, loss_mean_cls=2.84, proj_loss=-0.0578][[34m2026-03-23 13:57:11[0m] Step: 226, Training Logs: loss_final: 3.892933, loss_mean: 1.035562, proj_loss: -0.058638, loss_mean_cls: 2.916008, grad_norm: 1.450337
|
| 228 |
+
Steps: 0%| | 227/1000000 [00:59<67:53:30, 4.09it/s, grad_norm=1.45, loss_final=3.89, loss_mean=1.04, loss_mean_cls=2.92, proj_loss=-0.0586][[34m2026-03-23 13:57:11[0m] Step: 227, Training Logs: loss_final: 4.605734, loss_mean: 1.015767, proj_loss: -0.058328, loss_mean_cls: 3.648294, grad_norm: 1.576961
|
| 229 |
+
Steps: 0%| | 228/1000000 [01:00<67:52:30, 4.09it/s, grad_norm=1.58, loss_final=4.61, loss_mean=1.02, loss_mean_cls=3.65, proj_loss=-0.0583][[34m2026-03-23 13:57:11[0m] Step: 228, Training Logs: loss_final: 3.988146, loss_mean: 1.020183, proj_loss: -0.057191, loss_mean_cls: 3.025154, grad_norm: 1.494521
|
| 230 |
+
Steps: 0%| | 229/1000000 [01:00<67:49:43, 4.09it/s, grad_norm=1.49, loss_final=3.99, loss_mean=1.02, loss_mean_cls=3.03, proj_loss=-0.0572][[34m2026-03-23 13:57:12[0m] Step: 229, Training Logs: loss_final: 4.138639, loss_mean: 1.001511, proj_loss: -0.058552, loss_mean_cls: 3.195680, grad_norm: 1.869578
|
| 231 |
+
Steps: 0%| | 230/1000000 [01:00<67:48:23, 4.10it/s, grad_norm=1.87, loss_final=4.14, loss_mean=1, loss_mean_cls=3.2, proj_loss=-0.0586][[34m2026-03-23 13:57:12[0m] Step: 230, Training Logs: loss_final: 3.587549, loss_mean: 1.026086, proj_loss: -0.057275, loss_mean_cls: 2.618739, grad_norm: 1.361713
|
| 232 |
+
Steps: 0%| | 231/1000000 [01:00<67:50:24, 4.09it/s, grad_norm=1.36, loss_final=3.59, loss_mean=1.03, loss_mean_cls=2.62, proj_loss=-0.0573][[34m2026-03-23 13:57:12[0m] Step: 231, Training Logs: loss_final: 4.299013, loss_mean: 1.019667, proj_loss: -0.057679, loss_mean_cls: 3.337025, grad_norm: 2.507548
|
| 233 |
+
Steps: 0%| | 232/1000000 [01:01<67:49:48, 4.09it/s, grad_norm=2.51, loss_final=4.3, loss_mean=1.02, loss_mean_cls=3.34, proj_loss=-0.0577][[34m2026-03-23 13:57:12[0m] Step: 232, Training Logs: loss_final: 3.918818, loss_mean: 1.041992, proj_loss: -0.057748, loss_mean_cls: 2.934574, grad_norm: 1.768478
|
| 234 |
+
Steps: 0%| | 233/1000000 [01:01<67:48:26, 4.10it/s, grad_norm=1.77, loss_final=3.92, loss_mean=1.04, loss_mean_cls=2.93, proj_loss=-0.0577][[34m2026-03-23 13:57:13[0m] Step: 233, Training Logs: loss_final: 4.135336, loss_mean: 1.012130, proj_loss: -0.055806, loss_mean_cls: 3.179013, grad_norm: 2.010638
|
| 235 |
+
Steps: 0%| | 234/1000000 [01:01<67:48:16, 4.10it/s, grad_norm=2.01, loss_final=4.14, loss_mean=1.01, loss_mean_cls=3.18, proj_loss=-0.0558][[34m2026-03-23 13:57:13[0m] Step: 234, Training Logs: loss_final: 4.622794, loss_mean: 1.008856, proj_loss: -0.059102, loss_mean_cls: 3.673040, grad_norm: 1.824655
|
| 236 |
+
Steps: 0%| | 235/1000000 [01:01<67:48:42, 4.10it/s, grad_norm=1.82, loss_final=4.62, loss_mean=1.01, loss_mean_cls=3.67, proj_loss=-0.0591][[34m2026-03-23 13:57:13[0m] Step: 235, Training Logs: loss_final: 4.330180, loss_mean: 1.019101, proj_loss: -0.054455, loss_mean_cls: 3.365534, grad_norm: 1.865192
|
| 237 |
+
Steps: 0%| | 236/1000000 [01:02<67:47:59, 4.10it/s, grad_norm=1.87, loss_final=4.33, loss_mean=1.02, loss_mean_cls=3.37, proj_loss=-0.0545][[34m2026-03-23 13:57:13[0m] Step: 236, Training Logs: loss_final: 4.471676, loss_mean: 1.005463, proj_loss: -0.057791, loss_mean_cls: 3.524004, grad_norm: 1.992857
|
| 238 |
+
Steps: 0%| | 237/1000000 [01:02<67:48:33, 4.10it/s, grad_norm=1.99, loss_final=4.47, loss_mean=1.01, loss_mean_cls=3.52, proj_loss=-0.0578][[34m2026-03-23 13:57:14[0m] Step: 237, Training Logs: loss_final: 4.708742, loss_mean: 1.007163, proj_loss: -0.057294, loss_mean_cls: 3.758873, grad_norm: 1.696959
|
| 239 |
+
Steps: 0%| | 238/1000000 [01:02<67:48:46, 4.10it/s, grad_norm=1.7, loss_final=4.71, loss_mean=1.01, loss_mean_cls=3.76, proj_loss=-0.0573][[34m2026-03-23 13:57:14[0m] Step: 238, Training Logs: loss_final: 4.700453, loss_mean: 1.011082, proj_loss: -0.057899, loss_mean_cls: 3.747270, grad_norm: 2.577658
|
| 240 |
+
Steps: 0%| | 239/1000000 [01:02<67:48:18, 4.10it/s, grad_norm=2.58, loss_final=4.7, loss_mean=1.01, loss_mean_cls=3.75, proj_loss=-0.0579][[34m2026-03-23 13:57:14[0m] Step: 239, Training Logs: loss_final: 4.519000, loss_mean: 0.999633, proj_loss: -0.058458, loss_mean_cls: 3.577825, grad_norm: 2.147466
|
| 241 |
+
Steps: 0%| | 240/1000000 [01:03<67:48:55, 4.10it/s, grad_norm=2.15, loss_final=4.52, loss_mean=1, loss_mean_cls=3.58, proj_loss=-0.0585][[34m2026-03-23 13:57:14[0m] Step: 240, Training Logs: loss_final: 3.874477, loss_mean: 1.005100, proj_loss: -0.059289, loss_mean_cls: 2.928666, grad_norm: 2.779819
|
| 242 |
+
Steps: 0%| | 241/1000000 [01:03<67:50:12, 4.09it/s, grad_norm=2.78, loss_final=3.87, loss_mean=1.01, loss_mean_cls=2.93, proj_loss=-0.0593][[34m2026-03-23 13:57:15[0m] Step: 241, Training Logs: loss_final: 4.244301, loss_mean: 0.978814, proj_loss: -0.057319, loss_mean_cls: 3.322806, grad_norm: 1.934856
|
| 243 |
+
Steps: 0%| | 242/1000000 [01:03<67:50:20, 4.09it/s, grad_norm=1.93, loss_final=4.24, loss_mean=0.979, loss_mean_cls=3.32, proj_loss=-0.0573][[34m2026-03-23 13:57:15[0m] Step: 242, Training Logs: loss_final: 4.099132, loss_mean: 1.017777, proj_loss: -0.056814, loss_mean_cls: 3.138168, grad_norm: 2.223382
|
| 244 |
+
Steps: 0%| | 243/1000000 [01:03<67:49:22, 4.09it/s, grad_norm=2.22, loss_final=4.1, loss_mean=1.02, loss_mean_cls=3.14, proj_loss=-0.0568][[34m2026-03-23 13:57:15[0m] Step: 243, Training Logs: loss_final: 3.643645, loss_mean: 1.043657, proj_loss: -0.057582, loss_mean_cls: 2.657569, grad_norm: 1.876094
|
| 245 |
+
Steps: 0%| | 244/1000000 [01:04<67:48:22, 4.10it/s, grad_norm=1.88, loss_final=3.64, loss_mean=1.04, loss_mean_cls=2.66, proj_loss=-0.0576][[34m2026-03-23 13:57:15[0m] Step: 244, Training Logs: loss_final: 3.791383, loss_mean: 1.039013, proj_loss: -0.057821, loss_mean_cls: 2.810191, grad_norm: 2.101933
|
| 246 |
+
Steps: 0%| | 245/1000000 [01:04<67:48:03, 4.10it/s, grad_norm=2.1, loss_final=3.79, loss_mean=1.04, loss_mean_cls=2.81, proj_loss=-0.0578][[34m2026-03-23 13:57:16[0m] Step: 245, Training Logs: loss_final: 4.122998, loss_mean: 1.003735, proj_loss: -0.056978, loss_mean_cls: 3.176242, grad_norm: 1.748520
|
| 247 |
+
Steps: 0%| | 246/1000000 [01:04<67:47:44, 4.10it/s, grad_norm=1.75, loss_final=4.12, loss_mean=1, loss_mean_cls=3.18, proj_loss=-0.057][[34m2026-03-23 13:57:16[0m] Step: 246, Training Logs: loss_final: 3.975110, loss_mean: 1.031366, proj_loss: -0.058318, loss_mean_cls: 3.002061, grad_norm: 2.088799
|
| 248 |
+
Steps: 0%| | 247/1000000 [01:04<67:49:45, 4.09it/s, grad_norm=2.09, loss_final=3.98, loss_mean=1.03, loss_mean_cls=3, proj_loss=-0.0583][[34m2026-03-23 13:57:16[0m] Step: 247, Training Logs: loss_final: 3.795434, loss_mean: 1.016577, proj_loss: -0.059947, loss_mean_cls: 2.838804, grad_norm: 1.548398
|
| 249 |
+
Steps: 0%| | 248/1000000 [01:05<67:48:52, 4.10it/s, grad_norm=1.55, loss_final=3.8, loss_mean=1.02, loss_mean_cls=2.84, proj_loss=-0.0599][[34m2026-03-23 13:57:16[0m] Step: 248, Training Logs: loss_final: 3.547227, loss_mean: 1.007156, proj_loss: -0.057843, loss_mean_cls: 2.597914, grad_norm: 2.097709
|
| 250 |
+
Steps: 0%| | 249/1000000 [01:05<67:47:51, 4.10it/s, grad_norm=2.1, loss_final=3.55, loss_mean=1.01, loss_mean_cls=2.6, proj_loss=-0.0578][[34m2026-03-23 13:57:17[0m] Step: 249, Training Logs: loss_final: 3.670953, loss_mean: 1.005734, proj_loss: -0.056730, loss_mean_cls: 2.721948, grad_norm: 1.993682
|
| 251 |
+
Steps: 0%| | 250/1000000 [01:05<67:48:32, 4.10it/s, grad_norm=1.99, loss_final=3.67, loss_mean=1.01, loss_mean_cls=2.72, proj_loss=-0.0567][[34m2026-03-23 13:57:17[0m] Step: 250, Training Logs: loss_final: 3.783008, loss_mean: 1.009454, proj_loss: -0.057429, loss_mean_cls: 2.830983, grad_norm: 2.015732
|
| 252 |
+
Steps: 0%| | 251/1000000 [01:05<67:49:52, 4.09it/s, grad_norm=2.02, loss_final=3.78, loss_mean=1.01, loss_mean_cls=2.83, proj_loss=-0.0574][[34m2026-03-23 13:57:17[0m] Step: 251, Training Logs: loss_final: 3.917508, loss_mean: 1.015374, proj_loss: -0.057006, loss_mean_cls: 2.959140, grad_norm: 1.911892
|
| 253 |
+
Steps: 0%| | 252/1000000 [01:06<67:48:52, 4.10it/s, grad_norm=1.91, loss_final=3.92, loss_mean=1.02, loss_mean_cls=2.96, proj_loss=-0.057][[34m2026-03-23 13:57:17[0m] Step: 252, Training Logs: loss_final: 4.345321, loss_mean: 0.999146, proj_loss: -0.059442, loss_mean_cls: 3.405616, grad_norm: 2.544859
|
| 254 |
+
Steps: 0%| | 253/1000000 [01:06<67:48:39, 4.10it/s, grad_norm=2.54, loss_final=4.35, loss_mean=0.999, loss_mean_cls=3.41, proj_loss=-0.0594][[34m2026-03-23 13:57:18[0m] Step: 253, Training Logs: loss_final: 3.264407, loss_mean: 1.012033, proj_loss: -0.060338, loss_mean_cls: 2.312712, grad_norm: 2.002759
|
| 255 |
+
Steps: 0%| | 254/1000000 [01:06<67:47:32, 4.10it/s, grad_norm=2, loss_final=3.26, loss_mean=1.01, loss_mean_cls=2.31, proj_loss=-0.0603][[34m2026-03-23 13:57:18[0m] Step: 254, Training Logs: loss_final: 3.680009, loss_mean: 1.009429, proj_loss: -0.057669, loss_mean_cls: 2.728250, grad_norm: 2.470939
|
| 256 |
+
Steps: 0%| | 255/1000000 [01:06<67:46:55, 4.10it/s, grad_norm=2.47, loss_final=3.68, loss_mean=1.01, loss_mean_cls=2.73, proj_loss=-0.0577][[34m2026-03-23 13:57:18[0m] Step: 255, Training Logs: loss_final: 3.617963, loss_mean: 1.040573, proj_loss: -0.058244, loss_mean_cls: 2.635633, grad_norm: 2.220286
|
| 257 |
+
Steps: 0%| | 256/1000000 [01:07<67:48:18, 4.10it/s, grad_norm=2.22, loss_final=3.62, loss_mean=1.04, loss_mean_cls=2.64, proj_loss=-0.0582][[34m2026-03-23 13:57:18[0m] Step: 256, Training Logs: loss_final: 4.080931, loss_mean: 0.997066, proj_loss: -0.058532, loss_mean_cls: 3.142397, grad_norm: 2.140220
|
| 258 |
+
Steps: 0%| | 257/1000000 [01:07<67:47:49, 4.10it/s, grad_norm=2.14, loss_final=4.08, loss_mean=0.997, loss_mean_cls=3.14, proj_loss=-0.0585][[34m2026-03-23 13:57:19[0m] Step: 257, Training Logs: loss_final: 3.699068, loss_mean: 1.000059, proj_loss: -0.058102, loss_mean_cls: 2.757111, grad_norm: 1.981827
|
| 259 |
+
Steps: 0%| | 258/1000000 [01:07<67:46:24, 4.10it/s, grad_norm=1.98, loss_final=3.7, loss_mean=1, loss_mean_cls=2.76, proj_loss=-0.0581][[34m2026-03-23 13:57:19[0m] Step: 258, Training Logs: loss_final: 3.992317, loss_mean: 1.029925, proj_loss: -0.058207, loss_mean_cls: 3.020598, grad_norm: 2.269195
|
| 260 |
+
Steps: 0%| | 259/1000000 [01:07<67:46:06, 4.10it/s, grad_norm=2.27, loss_final=3.99, loss_mean=1.03, loss_mean_cls=3.02, proj_loss=-0.0582][[34m2026-03-23 13:57:19[0m] Step: 259, Training Logs: loss_final: 3.969361, loss_mean: 1.005176, proj_loss: -0.055455, loss_mean_cls: 3.019640, grad_norm: 2.430135
|
| 261 |
+
Steps: 0%| | 260/1000000 [01:08<67:47:12, 4.10it/s, grad_norm=2.43, loss_final=3.97, loss_mean=1.01, loss_mean_cls=3.02, proj_loss=-0.0555][[34m2026-03-23 13:57:19[0m] Step: 260, Training Logs: loss_final: 3.946629, loss_mean: 0.986893, proj_loss: -0.055241, loss_mean_cls: 3.014977, grad_norm: 2.162997
|
| 262 |
+
Steps: 0%| | 261/1000000 [01:08<79:04:37, 3.51it/s, grad_norm=2.16, loss_final=3.95, loss_mean=0.987, loss_mean_cls=3.01, proj_loss=-0.0552][[34m2026-03-23 13:57:20[0m] Step: 261, Training Logs: loss_final: 3.920058, loss_mean: 0.984585, proj_loss: -0.057787, loss_mean_cls: 2.993260, grad_norm: 2.114045
|
| 263 |
+
Steps: 0%| | 262/1000000 [01:08<78:26:04, 3.54it/s, grad_norm=2.11, loss_final=3.92, loss_mean=0.985, loss_mean_cls=2.99, proj_loss=-0.0578][[34m2026-03-23 13:57:20[0m] Step: 262, Training Logs: loss_final: 4.358081, loss_mean: 0.977391, proj_loss: -0.057046, loss_mean_cls: 3.437736, grad_norm: 2.065785
|
| 264 |
+
Steps: 0%| | 263/1000000 [01:08<75:13:43, 3.69it/s, grad_norm=2.07, loss_final=4.36, loss_mean=0.977, loss_mean_cls=3.44, proj_loss=-0.057][[34m2026-03-23 13:57:20[0m] Step: 263, Training Logs: loss_final: 3.840959, loss_mean: 0.999257, proj_loss: -0.056460, loss_mean_cls: 2.898163, grad_norm: 2.413785
|
| 265 |
+
Steps: 0%| | 264/1000000 [01:09<72:59:25, 3.80it/s, grad_norm=2.41, loss_final=3.84, loss_mean=0.999, loss_mean_cls=2.9, proj_loss=-0.0565][[34m2026-03-23 13:57:20[0m] Step: 264, Training Logs: loss_final: 3.940903, loss_mean: 1.023376, proj_loss: -0.057920, loss_mean_cls: 2.975446, grad_norm: 1.844397
|
| 266 |
+
Steps: 0%| | 265/1000000 [01:09<71:26:14, 3.89it/s, grad_norm=1.84, loss_final=3.94, loss_mean=1.02, loss_mean_cls=2.98, proj_loss=-0.0579][[34m2026-03-23 13:57:21[0m] Step: 265, Training Logs: loss_final: 4.206155, loss_mean: 0.981618, proj_loss: -0.059432, loss_mean_cls: 3.283970, grad_norm: 2.069561
|
| 267 |
+
Steps: 0%| | 266/1000000 [01:09<70:20:03, 3.95it/s, grad_norm=2.07, loss_final=4.21, loss_mean=0.982, loss_mean_cls=3.28, proj_loss=-0.0594][[34m2026-03-23 13:57:21[0m] Step: 266, Training Logs: loss_final: 3.804409, loss_mean: 0.990602, proj_loss: -0.060707, loss_mean_cls: 2.874513, grad_norm: 1.578623
|
| 268 |
+
Steps: 0%| | 267/1000000 [01:09<69:36:48, 3.99it/s, grad_norm=1.58, loss_final=3.8, loss_mean=0.991, loss_mean_cls=2.87, proj_loss=-0.0607][[34m2026-03-23 13:57:21[0m] Step: 267, Training Logs: loss_final: 4.086649, loss_mean: 1.005726, proj_loss: -0.058538, loss_mean_cls: 3.139461, grad_norm: 2.692237
|
| 269 |
+
Steps: 0%| | 268/1000000 [01:10<69:04:32, 4.02it/s, grad_norm=2.69, loss_final=4.09, loss_mean=1.01, loss_mean_cls=3.14, proj_loss=-0.0585][[34m2026-03-23 13:57:21[0m] Step: 268, Training Logs: loss_final: 3.484965, loss_mean: 1.013161, proj_loss: -0.056986, loss_mean_cls: 2.528790, grad_norm: 2.118618
|
| 270 |
+
Steps: 0%| | 269/1000000 [01:10<69:49:28, 3.98it/s, grad_norm=2.12, loss_final=3.48, loss_mean=1.01, loss_mean_cls=2.53, proj_loss=-0.057][[34m2026-03-23 13:57:22[0m] Step: 269, Training Logs: loss_final: 3.312028, loss_mean: 1.008822, proj_loss: -0.058315, loss_mean_cls: 2.361521, grad_norm: 1.997808
|
| 271 |
+
Steps: 0%| | 270/1000000 [01:10<68:55:10, 4.03it/s, grad_norm=2, loss_final=3.31, loss_mean=1.01, loss_mean_cls=2.36, proj_loss=-0.0583][[34m2026-03-23 13:57:22[0m] Step: 270, Training Logs: loss_final: 3.996657, loss_mean: 1.009214, proj_loss: -0.058962, loss_mean_cls: 3.046405, grad_norm: 2.755926
|
| 272 |
+
Steps: 0%| | 271/1000000 [01:10<68:39:18, 4.04it/s, grad_norm=2.76, loss_final=4, loss_mean=1.01, loss_mean_cls=3.05, proj_loss=-0.059][[34m2026-03-23 13:57:22[0m] Step: 271, Training Logs: loss_final: 4.391388, loss_mean: 1.024230, proj_loss: -0.058988, loss_mean_cls: 3.426146, grad_norm: 2.096098
|
| 273 |
+
Steps: 0%| | 272/1000000 [01:11<68:23:53, 4.06it/s, grad_norm=2.1, loss_final=4.39, loss_mean=1.02, loss_mean_cls=3.43, proj_loss=-0.059][[34m2026-03-23 13:57:22[0m] Step: 272, Training Logs: loss_final: 3.920087, loss_mean: 1.024717, proj_loss: -0.059831, loss_mean_cls: 2.955201, grad_norm: 2.465345
|
| 274 |
+
Steps: 0%| | 273/1000000 [01:11<68:12:32, 4.07it/s, grad_norm=2.47, loss_final=3.92, loss_mean=1.02, loss_mean_cls=2.96, proj_loss=-0.0598][[34m2026-03-23 13:57:23[0m] Step: 273, Training Logs: loss_final: 3.677792, loss_mean: 1.019457, proj_loss: -0.056399, loss_mean_cls: 2.714733, grad_norm: 2.047685
|
| 275 |
+
Steps: 0%| | 274/1000000 [01:11<68:06:48, 4.08it/s, grad_norm=2.05, loss_final=3.68, loss_mean=1.02, loss_mean_cls=2.71, proj_loss=-0.0564][[34m2026-03-23 13:57:23[0m] Step: 274, Training Logs: loss_final: 3.611300, loss_mean: 1.009368, proj_loss: -0.057628, loss_mean_cls: 2.659561, grad_norm: 1.780220
|
| 276 |
+
Steps: 0%| | 275/1000000 [01:11<68:01:32, 4.08it/s, grad_norm=1.78, loss_final=3.61, loss_mean=1.01, loss_mean_cls=2.66, proj_loss=-0.0576][[34m2026-03-23 13:57:23[0m] Step: 275, Training Logs: loss_final: 3.762297, loss_mean: 1.000715, proj_loss: -0.057146, loss_mean_cls: 2.818728, grad_norm: 2.212818
|
| 277 |
+
Steps: 0%| | 276/1000000 [01:12<67:56:22, 4.09it/s, grad_norm=2.21, loss_final=3.76, loss_mean=1, loss_mean_cls=2.82, proj_loss=-0.0571][[34m2026-03-23 13:57:23[0m] Step: 276, Training Logs: loss_final: 4.110180, loss_mean: 1.001855, proj_loss: -0.058339, loss_mean_cls: 3.166663, grad_norm: 3.446807
|
| 278 |
+
Steps: 0%| | 277/1000000 [01:12<67:54:20, 4.09it/s, grad_norm=3.45, loss_final=4.11, loss_mean=1, loss_mean_cls=3.17, proj_loss=-0.0583][[34m2026-03-23 13:57:24[0m] Step: 277, Training Logs: loss_final: 3.589436, loss_mean: 1.014081, proj_loss: -0.058753, loss_mean_cls: 2.634108, grad_norm: 1.866073
|
| 279 |
+
Steps: 0%| | 278/1000000 [01:12<67:51:55, 4.09it/s, grad_norm=1.87, loss_final=3.59, loss_mean=1.01, loss_mean_cls=2.63, proj_loss=-0.0588][[34m2026-03-23 13:57:24[0m] Step: 278, Training Logs: loss_final: 3.909306, loss_mean: 1.036322, proj_loss: -0.057491, loss_mean_cls: 2.930474, grad_norm: 3.416674
|
| 280 |
+
Steps: 0%| | 279/1000000 [01:12<67:51:48, 4.09it/s, grad_norm=3.42, loss_final=3.91, loss_mean=1.04, loss_mean_cls=2.93, proj_loss=-0.0575][[34m2026-03-23 13:57:24[0m] Step: 279, Training Logs: loss_final: 3.635005, loss_mean: 1.021109, proj_loss: -0.058334, loss_mean_cls: 2.672229, grad_norm: 2.773069
|
| 281 |
+
Steps: 0%| | 280/1000000 [01:13<67:50:10, 4.09it/s, grad_norm=2.77, loss_final=3.64, loss_mean=1.02, loss_mean_cls=2.67, proj_loss=-0.0583][[34m2026-03-23 13:57:24[0m] Step: 280, Training Logs: loss_final: 4.731202, loss_mean: 0.976675, proj_loss: -0.056143, loss_mean_cls: 3.810670, grad_norm: 2.045803
|
| 282 |
+
Steps: 0%| | 281/1000000 [01:13<67:48:49, 4.10it/s, grad_norm=2.05, loss_final=4.73, loss_mean=0.977, loss_mean_cls=3.81, proj_loss=-0.0561][[34m2026-03-23 13:57:25[0m] Step: 281, Training Logs: loss_final: 4.352121, loss_mean: 1.020213, proj_loss: -0.057039, loss_mean_cls: 3.388947, grad_norm: 3.085879
|
| 283 |
+
Steps: 0%| | 282/1000000 [01:13<67:48:21, 4.10it/s, grad_norm=3.09, loss_final=4.35, loss_mean=1.02, loss_mean_cls=3.39, proj_loss=-0.057][[34m2026-03-23 13:57:25[0m] Step: 282, Training Logs: loss_final: 3.645599, loss_mean: 1.009737, proj_loss: -0.056890, loss_mean_cls: 2.692752, grad_norm: 2.743040
|
| 284 |
+
Steps: 0%| | 283/1000000 [01:13<67:47:40, 4.10it/s, grad_norm=2.74, loss_final=3.65, loss_mean=1.01, loss_mean_cls=2.69, proj_loss=-0.0569][[34m2026-03-23 13:57:25[0m] Step: 283, Training Logs: loss_final: 3.621672, loss_mean: 1.014742, proj_loss: -0.058281, loss_mean_cls: 2.665210, grad_norm: 2.332768
|
| 285 |
+
Steps: 0%| | 284/1000000 [01:14<67:48:17, 4.10it/s, grad_norm=2.33, loss_final=3.62, loss_mean=1.01, loss_mean_cls=2.67, proj_loss=-0.0583][[34m2026-03-23 13:57:25[0m] Step: 284, Training Logs: loss_final: 3.634438, loss_mean: 1.014846, proj_loss: -0.058438, loss_mean_cls: 2.678030, grad_norm: 2.844196
|
| 286 |
+
Steps: 0%| | 285/1000000 [01:14<67:48:15, 4.10it/s, grad_norm=2.84, loss_final=3.63, loss_mean=1.01, loss_mean_cls=2.68, proj_loss=-0.0584][[34m2026-03-23 13:57:26[0m] Step: 285, Training Logs: loss_final: 3.942388, loss_mean: 0.986733, proj_loss: -0.056891, loss_mean_cls: 3.012546, grad_norm: 2.134713
|
| 287 |
+
Steps: 0%| | 286/1000000 [01:14<67:47:40, 4.10it/s, grad_norm=2.13, loss_final=3.94, loss_mean=0.987, loss_mean_cls=3.01, proj_loss=-0.0569][[34m2026-03-23 13:57:26[0m] Step: 286, Training Logs: loss_final: 3.926682, loss_mean: 1.011870, proj_loss: -0.057176, loss_mean_cls: 2.971987, grad_norm: 2.768237
|
| 288 |
+
Steps: 0%| | 287/1000000 [01:14<67:46:57, 4.10it/s, grad_norm=2.77, loss_final=3.93, loss_mean=1.01, loss_mean_cls=2.97, proj_loss=-0.0572][[34m2026-03-23 13:57:26[0m] Step: 287, Training Logs: loss_final: 4.047902, loss_mean: 1.002864, proj_loss: -0.059520, loss_mean_cls: 3.104558, grad_norm: 2.354553
|
| 289 |
+
Steps: 0%| | 287/1000000 [01:14<67:46:57, 4.10it/s, grad_norm=2.35, loss_final=4.05, loss_mean=1, loss_mean_cls=3.1, proj_loss=-0.0595]
|
REG/wandb/run-20260323_135607-zue1y2ba/files/requirements.txt
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dill==0.3.8
|
| 2 |
+
mkl-service==2.4.0
|
| 3 |
+
mpmath==1.3.0
|
| 4 |
+
typing_extensions==4.12.2
|
| 5 |
+
urllib3==2.3.0
|
| 6 |
+
torch==2.5.1
|
| 7 |
+
ptyprocess==0.7.0
|
| 8 |
+
traitlets==5.14.3
|
| 9 |
+
pyasn1==0.6.1
|
| 10 |
+
opencv-python-headless==4.12.0.88
|
| 11 |
+
nest-asyncio==1.6.0
|
| 12 |
+
kiwisolver==1.4.8
|
| 13 |
+
click==8.2.1
|
| 14 |
+
fire==0.7.1
|
| 15 |
+
diffusers==0.35.1
|
| 16 |
+
accelerate==1.7.0
|
| 17 |
+
ipykernel==6.29.5
|
| 18 |
+
peft==0.17.1
|
| 19 |
+
attrs==24.3.0
|
| 20 |
+
six==1.17.0
|
| 21 |
+
numpy==2.0.1
|
| 22 |
+
yarl==1.18.0
|
| 23 |
+
huggingface_hub==0.34.4
|
| 24 |
+
Bottleneck==1.4.2
|
| 25 |
+
numexpr==2.11.0
|
| 26 |
+
dataclasses==0.6
|
| 27 |
+
typing-inspection==0.4.1
|
| 28 |
+
safetensors==0.5.3
|
| 29 |
+
pyparsing==3.2.3
|
| 30 |
+
psutil==7.0.0
|
| 31 |
+
imageio==2.37.0
|
| 32 |
+
debugpy==1.8.14
|
| 33 |
+
cycler==0.12.1
|
| 34 |
+
pyasn1_modules==0.4.2
|
| 35 |
+
matplotlib-inline==0.1.7
|
| 36 |
+
matplotlib==3.10.3
|
| 37 |
+
jedi==0.19.2
|
| 38 |
+
tokenizers==0.21.2
|
| 39 |
+
seaborn==0.13.2
|
| 40 |
+
timm==1.0.15
|
| 41 |
+
aiohappyeyeballs==2.6.1
|
| 42 |
+
hf-xet==1.1.8
|
| 43 |
+
multidict==6.1.0
|
| 44 |
+
tqdm==4.67.1
|
| 45 |
+
wheel==0.45.1
|
| 46 |
+
simsimd==6.5.1
|
| 47 |
+
sentencepiece==0.2.1
|
| 48 |
+
grpcio==1.74.0
|
| 49 |
+
asttokens==3.0.0
|
| 50 |
+
absl-py==2.3.1
|
| 51 |
+
stack-data==0.6.3
|
| 52 |
+
pandas==2.3.0
|
| 53 |
+
importlib_metadata==8.7.0
|
| 54 |
+
pytorch-image-generation-metrics==0.6.1
|
| 55 |
+
frozenlist==1.5.0
|
| 56 |
+
MarkupSafe==3.0.2
|
| 57 |
+
setuptools==78.1.1
|
| 58 |
+
multiprocess==0.70.15
|
| 59 |
+
pip==25.1
|
| 60 |
+
requests==2.32.3
|
| 61 |
+
mkl_random==1.2.8
|
| 62 |
+
tensorboard-plugin-wit==1.8.1
|
| 63 |
+
ExifRead-nocycle==3.0.1
|
| 64 |
+
webdataset==0.2.111
|
| 65 |
+
threadpoolctl==3.6.0
|
| 66 |
+
pyarrow==21.0.0
|
| 67 |
+
executing==2.2.0
|
| 68 |
+
decorator==5.2.1
|
| 69 |
+
contourpy==1.3.2
|
| 70 |
+
annotated-types==0.7.0
|
| 71 |
+
scikit-learn==1.7.1
|
| 72 |
+
jupyter_client==8.6.3
|
| 73 |
+
albumentations==1.4.24
|
| 74 |
+
wandb==0.25.0
|
| 75 |
+
certifi==2025.8.3
|
| 76 |
+
idna==3.7
|
| 77 |
+
xxhash==3.5.0
|
| 78 |
+
Jinja2==3.1.6
|
| 79 |
+
python-dateutil==2.9.0.post0
|
| 80 |
+
aiosignal==1.4.0
|
| 81 |
+
triton==3.1.0
|
| 82 |
+
torchvision==0.20.1
|
| 83 |
+
stringzilla==3.12.6
|
| 84 |
+
pure_eval==0.2.3
|
| 85 |
+
braceexpand==0.1.7
|
| 86 |
+
zipp==3.22.0
|
| 87 |
+
oauthlib==3.3.1
|
| 88 |
+
Markdown==3.8.2
|
| 89 |
+
fsspec==2025.3.0
|
| 90 |
+
fonttools==4.58.2
|
| 91 |
+
comm==0.2.2
|
| 92 |
+
ipython==9.3.0
|
| 93 |
+
img2dataset==1.47.0
|
| 94 |
+
networkx==3.4.2
|
| 95 |
+
PySocks==1.7.1
|
| 96 |
+
tzdata==2025.2
|
| 97 |
+
smmap==5.0.2
|
| 98 |
+
mkl_fft==1.3.11
|
| 99 |
+
sentry-sdk==2.29.1
|
| 100 |
+
Pygments==2.19.1
|
| 101 |
+
pexpect==4.9.0
|
| 102 |
+
ftfy==6.3.1
|
| 103 |
+
einops==0.8.1
|
| 104 |
+
requests-oauthlib==2.0.0
|
| 105 |
+
gitdb==4.0.12
|
| 106 |
+
albucore==0.0.23
|
| 107 |
+
torchdiffeq==0.2.5
|
| 108 |
+
GitPython==3.1.44
|
| 109 |
+
bitsandbytes==0.47.0
|
| 110 |
+
pytorch-fid==0.3.0
|
| 111 |
+
clean-fid==0.1.35
|
| 112 |
+
pytorch-gan-metrics==0.5.4
|
| 113 |
+
Brotli==1.0.9
|
| 114 |
+
charset-normalizer==3.3.2
|
| 115 |
+
gmpy2==2.2.1
|
| 116 |
+
pillow==11.1.0
|
| 117 |
+
PyYAML==6.0.2
|
| 118 |
+
tornado==6.5.1
|
| 119 |
+
termcolor==3.1.0
|
| 120 |
+
setproctitle==1.3.6
|
| 121 |
+
scipy==1.15.3
|
| 122 |
+
regex==2024.11.6
|
| 123 |
+
protobuf==6.31.1
|
| 124 |
+
platformdirs==4.3.8
|
| 125 |
+
joblib==1.5.1
|
| 126 |
+
cachetools==4.2.4
|
| 127 |
+
ipython_pygments_lexers==1.1.1
|
| 128 |
+
google-auth==1.35.0
|
| 129 |
+
transformers==4.53.2
|
| 130 |
+
torch-fidelity==0.3.0
|
| 131 |
+
tensorboard==2.4.0
|
| 132 |
+
filelock==3.17.0
|
| 133 |
+
packaging==25.0
|
| 134 |
+
propcache==0.3.1
|
| 135 |
+
pytz==2025.2
|
| 136 |
+
aiohttp==3.11.10
|
| 137 |
+
wcwidth==0.2.13
|
| 138 |
+
clip==0.2.0
|
| 139 |
+
Werkzeug==3.1.3
|
| 140 |
+
tensorboard-data-server==0.6.1
|
| 141 |
+
sympy==1.13.1
|
| 142 |
+
pyzmq==26.4.0
|
| 143 |
+
pydantic_core==2.33.2
|
| 144 |
+
prompt_toolkit==3.0.51
|
| 145 |
+
parso==0.8.4
|
| 146 |
+
docker-pycreds==0.4.0
|
| 147 |
+
rsa==4.9.1
|
| 148 |
+
pydantic==2.11.5
|
| 149 |
+
jupyter_core==5.8.1
|
| 150 |
+
google-auth-oauthlib==0.4.6
|
| 151 |
+
datasets==4.0.0
|
| 152 |
+
torch-tb-profiler==0.4.3
|
| 153 |
+
autocommand==2.2.2
|
| 154 |
+
backports.tarfile==1.2.0
|
| 155 |
+
importlib_metadata==8.0.0
|
| 156 |
+
jaraco.collections==5.1.0
|
| 157 |
+
jaraco.context==5.3.0
|
| 158 |
+
jaraco.functools==4.0.1
|
| 159 |
+
more-itertools==10.3.0
|
| 160 |
+
packaging==24.2
|
| 161 |
+
platformdirs==4.2.2
|
| 162 |
+
typeguard==4.3.0
|
| 163 |
+
inflect==7.3.1
|
| 164 |
+
jaraco.text==3.12.1
|
| 165 |
+
tomli==2.0.1
|
| 166 |
+
typing_extensions==4.12.2
|
| 167 |
+
wheel==0.45.1
|
| 168 |
+
zipp==3.19.2
|
REG/wandb/run-20260323_135607-zue1y2ba/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
|
| 3 |
+
"python": "CPython 3.12.9",
|
| 4 |
+
"startedAt": "2026-03-23T05:56:07.858187Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--report-to",
|
| 7 |
+
"wandb",
|
| 8 |
+
"--allow-tf32",
|
| 9 |
+
"--mixed-precision",
|
| 10 |
+
"bf16",
|
| 11 |
+
"--seed",
|
| 12 |
+
"0",
|
| 13 |
+
"--path-type",
|
| 14 |
+
"linear",
|
| 15 |
+
"--prediction",
|
| 16 |
+
"v",
|
| 17 |
+
"--weighting",
|
| 18 |
+
"uniform",
|
| 19 |
+
"--model",
|
| 20 |
+
"SiT-XL/2",
|
| 21 |
+
"--enc-type",
|
| 22 |
+
"dinov2-vit-b",
|
| 23 |
+
"--encoder-depth",
|
| 24 |
+
"8",
|
| 25 |
+
"--proj-coeff",
|
| 26 |
+
"0.5",
|
| 27 |
+
"--output-dir",
|
| 28 |
+
"exps",
|
| 29 |
+
"--exp-name",
|
| 30 |
+
"jsflow-experiment-0.75",
|
| 31 |
+
"--batch-size",
|
| 32 |
+
"256",
|
| 33 |
+
"--data-dir",
|
| 34 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
|
| 35 |
+
"--semantic-features-dir",
|
| 36 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
|
| 37 |
+
"--learning-rate",
|
| 38 |
+
"0.00005",
|
| 39 |
+
"--t-c",
|
| 40 |
+
"0.75",
|
| 41 |
+
"--cls",
|
| 42 |
+
"0.2",
|
| 43 |
+
"--ot-cls"
|
| 44 |
+
],
|
| 45 |
+
"program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
|
| 46 |
+
"codePath": "train.py",
|
| 47 |
+
"codePathLocal": "train.py",
|
| 48 |
+
"git": {
|
| 49 |
+
"remote": "https://github.com/Martinser/REG.git",
|
| 50 |
+
"commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
|
| 51 |
+
},
|
| 52 |
+
"email": "2365972933@qq.com",
|
| 53 |
+
"root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
|
| 54 |
+
"host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
|
| 55 |
+
"executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
|
| 56 |
+
"cpu_count": 96,
|
| 57 |
+
"cpu_count_logical": 192,
|
| 58 |
+
"gpu": "NVIDIA H100 80GB HBM3",
|
| 59 |
+
"gpu_count": 4,
|
| 60 |
+
"disk": {
|
| 61 |
+
"/": {
|
| 62 |
+
"total": "3838880616448",
|
| 63 |
+
"used": "357568126976"
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"memory": {
|
| 67 |
+
"total": "2164115296256"
|
| 68 |
+
},
|
| 69 |
+
"gpu_nvidia": [
|
| 70 |
+
{
|
| 71 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 72 |
+
"memoryTotal": "85520809984",
|
| 73 |
+
"cudaCores": 16896,
|
| 74 |
+
"architecture": "Hopper",
|
| 75 |
+
"uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 79 |
+
"memoryTotal": "85520809984",
|
| 80 |
+
"cudaCores": 16896,
|
| 81 |
+
"architecture": "Hopper",
|
| 82 |
+
"uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 86 |
+
"memoryTotal": "85520809984",
|
| 87 |
+
"cudaCores": 16896,
|
| 88 |
+
"architecture": "Hopper",
|
| 89 |
+
"uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 93 |
+
"memoryTotal": "85520809984",
|
| 94 |
+
"cudaCores": 16896,
|
| 95 |
+
"architecture": "Hopper",
|
| 96 |
+
"uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"cudaVersion": "13.0",
|
| 100 |
+
"writerId": "l4vui4vfnl881ctol25fj9y70t6im9l9"
|
| 101 |
+
}
|
REG/wandb/run-20260323_135607-zue1y2ba/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-03-23T13:56:08.183465712+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
|
| 2 |
+
{"time":"2026-03-23T13:56:10.661719633+08:00","level":"INFO","msg":"stream: created new stream","id":"zue1y2ba"}
|
| 3 |
+
{"time":"2026-03-23T13:56:10.661931952+08:00","level":"INFO","msg":"handler: started","stream_id":"zue1y2ba"}
|
| 4 |
+
{"time":"2026-03-23T13:56:10.662874633+08:00","level":"INFO","msg":"stream: started","id":"zue1y2ba"}
|
| 5 |
+
{"time":"2026-03-23T13:56:10.662895027+08:00","level":"INFO","msg":"writer: started","stream_id":"zue1y2ba"}
|
| 6 |
+
{"time":"2026-03-23T13:56:10.662918583+08:00","level":"INFO","msg":"sender: started","stream_id":"zue1y2ba"}
|
REG/wandb/run-20260323_135607-zue1y2ba/logs/debug.log
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
|
| 2 |
+
2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_setup.py:_flush():81] Configure stats pid to 397944
|
| 3 |
+
2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135607-zue1y2ba/logs/debug.log
|
| 5 |
+
2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135607-zue1y2ba/logs/debug-internal.log
|
| 6 |
+
2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:init():844] calling init triggers
|
| 7 |
+
2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'_wandb': {}}
|
| 9 |
+
2026-03-23 13:56:07,881 INFO MainThread:397944 [wandb_init.py:init():892] starting backend
|
| 10 |
+
2026-03-23 13:56:08,167 INFO MainThread:397944 [wandb_init.py:init():895] sending inform_init request
|
| 11 |
+
2026-03-23 13:56:08,180 INFO MainThread:397944 [wandb_init.py:init():903] backend started and connected
|
| 12 |
+
2026-03-23 13:56:08,181 INFO MainThread:397944 [wandb_init.py:init():973] updated telemetry
|
| 13 |
+
2026-03-23 13:56:08,194 INFO MainThread:397944 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-03-23 13:56:11,614 INFO MainThread:397944 [wandb_init.py:init():1042] starting run threads in backend
|
| 15 |
+
2026-03-23 13:56:11,706 INFO MainThread:397944 [wandb_run.py:_console_start():2524] atexit reg
|
| 16 |
+
2026-03-23 13:56:11,707 INFO MainThread:397944 [wandb_run.py:_redirect():2373] redirect: wrap_raw
|
| 17 |
+
2026-03-23 13:56:11,707 INFO MainThread:397944 [wandb_run.py:_redirect():2442] Wrapping output streams.
|
| 18 |
+
2026-03-23 13:56:11,707 INFO MainThread:397944 [wandb_run.py:_redirect():2465] Redirects installed.
|
| 19 |
+
2026-03-23 13:56:11,712 INFO MainThread:397944 [wandb_init.py:init():1082] run started, returning control to user process
|
| 20 |
+
2026-03-23 13:56:11,713 INFO MainThread:397944 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment-0.75', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.2, 't_c': 0.75, 'ot_cls': True}
|
REG/wandb/run-20260323_135841-w9holkos/files/requirements.txt
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dill==0.3.8
|
| 2 |
+
mkl-service==2.4.0
|
| 3 |
+
mpmath==1.3.0
|
| 4 |
+
typing_extensions==4.12.2
|
| 5 |
+
urllib3==2.3.0
|
| 6 |
+
torch==2.5.1
|
| 7 |
+
ptyprocess==0.7.0
|
| 8 |
+
traitlets==5.14.3
|
| 9 |
+
pyasn1==0.6.1
|
| 10 |
+
opencv-python-headless==4.12.0.88
|
| 11 |
+
nest-asyncio==1.6.0
|
| 12 |
+
kiwisolver==1.4.8
|
| 13 |
+
click==8.2.1
|
| 14 |
+
fire==0.7.1
|
| 15 |
+
diffusers==0.35.1
|
| 16 |
+
accelerate==1.7.0
|
| 17 |
+
ipykernel==6.29.5
|
| 18 |
+
peft==0.17.1
|
| 19 |
+
attrs==24.3.0
|
| 20 |
+
six==1.17.0
|
| 21 |
+
numpy==2.0.1
|
| 22 |
+
yarl==1.18.0
|
| 23 |
+
huggingface_hub==0.34.4
|
| 24 |
+
Bottleneck==1.4.2
|
| 25 |
+
numexpr==2.11.0
|
| 26 |
+
dataclasses==0.6
|
| 27 |
+
typing-inspection==0.4.1
|
| 28 |
+
safetensors==0.5.3
|
| 29 |
+
pyparsing==3.2.3
|
| 30 |
+
psutil==7.0.0
|
| 31 |
+
imageio==2.37.0
|
| 32 |
+
debugpy==1.8.14
|
| 33 |
+
cycler==0.12.1
|
| 34 |
+
pyasn1_modules==0.4.2
|
| 35 |
+
matplotlib-inline==0.1.7
|
| 36 |
+
matplotlib==3.10.3
|
| 37 |
+
jedi==0.19.2
|
| 38 |
+
tokenizers==0.21.2
|
| 39 |
+
seaborn==0.13.2
|
| 40 |
+
timm==1.0.15
|
| 41 |
+
aiohappyeyeballs==2.6.1
|
| 42 |
+
hf-xet==1.1.8
|
| 43 |
+
multidict==6.1.0
|
| 44 |
+
tqdm==4.67.1
|
| 45 |
+
wheel==0.45.1
|
| 46 |
+
simsimd==6.5.1
|
| 47 |
+
sentencepiece==0.2.1
|
| 48 |
+
grpcio==1.74.0
|
| 49 |
+
asttokens==3.0.0
|
| 50 |
+
absl-py==2.3.1
|
| 51 |
+
stack-data==0.6.3
|
| 52 |
+
pandas==2.3.0
|
| 53 |
+
importlib_metadata==8.7.0
|
| 54 |
+
pytorch-image-generation-metrics==0.6.1
|
| 55 |
+
frozenlist==1.5.0
|
| 56 |
+
MarkupSafe==3.0.2
|
| 57 |
+
setuptools==78.1.1
|
| 58 |
+
multiprocess==0.70.15
|
| 59 |
+
pip==25.1
|
| 60 |
+
requests==2.32.3
|
| 61 |
+
mkl_random==1.2.8
|
| 62 |
+
tensorboard-plugin-wit==1.8.1
|
| 63 |
+
ExifRead-nocycle==3.0.1
|
| 64 |
+
webdataset==0.2.111
|
| 65 |
+
threadpoolctl==3.6.0
|
| 66 |
+
pyarrow==21.0.0
|
| 67 |
+
executing==2.2.0
|
| 68 |
+
decorator==5.2.1
|
| 69 |
+
contourpy==1.3.2
|
| 70 |
+
annotated-types==0.7.0
|
| 71 |
+
scikit-learn==1.7.1
|
| 72 |
+
jupyter_client==8.6.3
|
| 73 |
+
albumentations==1.4.24
|
| 74 |
+
wandb==0.25.0
|
| 75 |
+
certifi==2025.8.3
|
| 76 |
+
idna==3.7
|
| 77 |
+
xxhash==3.5.0
|
| 78 |
+
Jinja2==3.1.6
|
| 79 |
+
python-dateutil==2.9.0.post0
|
| 80 |
+
aiosignal==1.4.0
|
| 81 |
+
triton==3.1.0
|
| 82 |
+
torchvision==0.20.1
|
| 83 |
+
stringzilla==3.12.6
|
| 84 |
+
pure_eval==0.2.3
|
| 85 |
+
braceexpand==0.1.7
|
| 86 |
+
zipp==3.22.0
|
| 87 |
+
oauthlib==3.3.1
|
| 88 |
+
Markdown==3.8.2
|
| 89 |
+
fsspec==2025.3.0
|
| 90 |
+
fonttools==4.58.2
|
| 91 |
+
comm==0.2.2
|
| 92 |
+
ipython==9.3.0
|
| 93 |
+
img2dataset==1.47.0
|
| 94 |
+
networkx==3.4.2
|
| 95 |
+
PySocks==1.7.1
|
| 96 |
+
tzdata==2025.2
|
| 97 |
+
smmap==5.0.2
|
| 98 |
+
mkl_fft==1.3.11
|
| 99 |
+
sentry-sdk==2.29.1
|
| 100 |
+
Pygments==2.19.1
|
| 101 |
+
pexpect==4.9.0
|
| 102 |
+
ftfy==6.3.1
|
| 103 |
+
einops==0.8.1
|
| 104 |
+
requests-oauthlib==2.0.0
|
| 105 |
+
gitdb==4.0.12
|
| 106 |
+
albucore==0.0.23
|
| 107 |
+
torchdiffeq==0.2.5
|
| 108 |
+
GitPython==3.1.44
|
| 109 |
+
bitsandbytes==0.47.0
|
| 110 |
+
pytorch-fid==0.3.0
|
| 111 |
+
clean-fid==0.1.35
|
| 112 |
+
pytorch-gan-metrics==0.5.4
|
| 113 |
+
Brotli==1.0.9
|
| 114 |
+
charset-normalizer==3.3.2
|
| 115 |
+
gmpy2==2.2.1
|
| 116 |
+
pillow==11.1.0
|
| 117 |
+
PyYAML==6.0.2
|
| 118 |
+
tornado==6.5.1
|
| 119 |
+
termcolor==3.1.0
|
| 120 |
+
setproctitle==1.3.6
|
| 121 |
+
scipy==1.15.3
|
| 122 |
+
regex==2024.11.6
|
| 123 |
+
protobuf==6.31.1
|
| 124 |
+
platformdirs==4.3.8
|
| 125 |
+
joblib==1.5.1
|
| 126 |
+
cachetools==4.2.4
|
| 127 |
+
ipython_pygments_lexers==1.1.1
|
| 128 |
+
google-auth==1.35.0
|
| 129 |
+
transformers==4.53.2
|
| 130 |
+
torch-fidelity==0.3.0
|
| 131 |
+
tensorboard==2.4.0
|
| 132 |
+
filelock==3.17.0
|
| 133 |
+
packaging==25.0
|
| 134 |
+
propcache==0.3.1
|
| 135 |
+
pytz==2025.2
|
| 136 |
+
aiohttp==3.11.10
|
| 137 |
+
wcwidth==0.2.13
|
| 138 |
+
clip==0.2.0
|
| 139 |
+
Werkzeug==3.1.3
|
| 140 |
+
tensorboard-data-server==0.6.1
|
| 141 |
+
sympy==1.13.1
|
| 142 |
+
pyzmq==26.4.0
|
| 143 |
+
pydantic_core==2.33.2
|
| 144 |
+
prompt_toolkit==3.0.51
|
| 145 |
+
parso==0.8.4
|
| 146 |
+
docker-pycreds==0.4.0
|
| 147 |
+
rsa==4.9.1
|
| 148 |
+
pydantic==2.11.5
|
| 149 |
+
jupyter_core==5.8.1
|
| 150 |
+
google-auth-oauthlib==0.4.6
|
| 151 |
+
datasets==4.0.0
|
| 152 |
+
torch-tb-profiler==0.4.3
|
| 153 |
+
autocommand==2.2.2
|
| 154 |
+
backports.tarfile==1.2.0
|
| 155 |
+
importlib_metadata==8.0.0
|
| 156 |
+
jaraco.collections==5.1.0
|
| 157 |
+
jaraco.context==5.3.0
|
| 158 |
+
jaraco.functools==4.0.1
|
| 159 |
+
more-itertools==10.3.0
|
| 160 |
+
packaging==24.2
|
| 161 |
+
platformdirs==4.2.2
|
| 162 |
+
typeguard==4.3.0
|
| 163 |
+
inflect==7.3.1
|
| 164 |
+
jaraco.text==3.12.1
|
| 165 |
+
tomli==2.0.1
|
| 166 |
+
typing_extensions==4.12.2
|
| 167 |
+
wheel==0.45.1
|
| 168 |
+
zipp==3.19.2
|
REG/wandb/run-20260323_135841-w9holkos/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
|
| 3 |
+
"python": "CPython 3.12.9",
|
| 4 |
+
"startedAt": "2026-03-23T05:58:41.322248Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--report-to",
|
| 7 |
+
"wandb",
|
| 8 |
+
"--allow-tf32",
|
| 9 |
+
"--mixed-precision",
|
| 10 |
+
"bf16",
|
| 11 |
+
"--seed",
|
| 12 |
+
"0",
|
| 13 |
+
"--path-type",
|
| 14 |
+
"linear",
|
| 15 |
+
"--prediction",
|
| 16 |
+
"v",
|
| 17 |
+
"--weighting",
|
| 18 |
+
"uniform",
|
| 19 |
+
"--model",
|
| 20 |
+
"SiT-XL/2",
|
| 21 |
+
"--enc-type",
|
| 22 |
+
"dinov2-vit-b",
|
| 23 |
+
"--encoder-depth",
|
| 24 |
+
"8",
|
| 25 |
+
"--proj-coeff",
|
| 26 |
+
"0.5",
|
| 27 |
+
"--output-dir",
|
| 28 |
+
"exps",
|
| 29 |
+
"--exp-name",
|
| 30 |
+
"jsflow-experiment-0.75",
|
| 31 |
+
"--batch-size",
|
| 32 |
+
"256",
|
| 33 |
+
"--data-dir",
|
| 34 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
|
| 35 |
+
"--semantic-features-dir",
|
| 36 |
+
"/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0",
|
| 37 |
+
"--learning-rate",
|
| 38 |
+
"0.00005",
|
| 39 |
+
"--t-c",
|
| 40 |
+
"0.75",
|
| 41 |
+
"--cls",
|
| 42 |
+
"0.05",
|
| 43 |
+
"--ot-cls"
|
| 44 |
+
],
|
| 45 |
+
"program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/train.py",
|
| 46 |
+
"codePath": "train.py",
|
| 47 |
+
"codePathLocal": "train.py",
|
| 48 |
+
"git": {
|
| 49 |
+
"remote": "https://github.com/Martinser/REG.git",
|
| 50 |
+
"commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
|
| 51 |
+
},
|
| 52 |
+
"email": "2365972933@qq.com",
|
| 53 |
+
"root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG",
|
| 54 |
+
"host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
|
| 55 |
+
"executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
|
| 56 |
+
"cpu_count": 96,
|
| 57 |
+
"cpu_count_logical": 192,
|
| 58 |
+
"gpu": "NVIDIA H100 80GB HBM3",
|
| 59 |
+
"gpu_count": 4,
|
| 60 |
+
"disk": {
|
| 61 |
+
"/": {
|
| 62 |
+
"total": "3838880616448",
|
| 63 |
+
"used": "357568360448"
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"memory": {
|
| 67 |
+
"total": "2164115296256"
|
| 68 |
+
},
|
| 69 |
+
"gpu_nvidia": [
|
| 70 |
+
{
|
| 71 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 72 |
+
"memoryTotal": "85520809984",
|
| 73 |
+
"cudaCores": 16896,
|
| 74 |
+
"architecture": "Hopper",
|
| 75 |
+
"uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 79 |
+
"memoryTotal": "85520809984",
|
| 80 |
+
"cudaCores": 16896,
|
| 81 |
+
"architecture": "Hopper",
|
| 82 |
+
"uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 86 |
+
"memoryTotal": "85520809984",
|
| 87 |
+
"cudaCores": 16896,
|
| 88 |
+
"architecture": "Hopper",
|
| 89 |
+
"uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"name": "NVIDIA H100 80GB HBM3",
|
| 93 |
+
"memoryTotal": "85520809984",
|
| 94 |
+
"cudaCores": 16896,
|
| 95 |
+
"architecture": "Hopper",
|
| 96 |
+
"uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
|
| 97 |
+
}
|
| 98 |
+
],
|
| 99 |
+
"cudaVersion": "13.0",
|
| 100 |
+
"writerId": "nlbia82zbry6kpqgagoidmc6x8szwd5d"
|
| 101 |
+
}
|
REG/wandb/run-20260323_135841-w9holkos/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2026-03-23T13:58:41.647788404+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
|
| 2 |
+
{"time":"2026-03-23T13:58:42.578470875+08:00","level":"INFO","msg":"stream: created new stream","id":"w9holkos"}
|
| 3 |
+
{"time":"2026-03-23T13:58:42.578676113+08:00","level":"INFO","msg":"handler: started","stream_id":"w9holkos"}
|
| 4 |
+
{"time":"2026-03-23T13:58:42.579473589+08:00","level":"INFO","msg":"stream: started","id":"w9holkos"}
|
| 5 |
+
{"time":"2026-03-23T13:58:42.57951741+08:00","level":"INFO","msg":"sender: started","stream_id":"w9holkos"}
|
| 6 |
+
{"time":"2026-03-23T13:58:42.579478227+08:00","level":"INFO","msg":"writer: started","stream_id":"w9holkos"}
|
| 7 |
+
{"time":"2026-03-23T14:49:13.568442881+08:00","level":"INFO","msg":"api: retrying HTTP error","status":408,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>408 Request Timeout</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Request Timeout</h1>\n<h2>Your client has taken too long to issue its request.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 8 |
+
{"time":"2026-03-23T14:52:15.597652411+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
|
| 9 |
+
{"time":"2026-03-23T14:52:26.072213509+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": write tcp 172.20.98.27:52324->35.186.228.49:443: write: broken pipe"}
|
| 10 |
+
{"time":"2026-03-23T17:02:52.905542765+08:00","level":"INFO","msg":"api: retrying HTTP error","status":408,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>408 Request Timeout</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Request Timeout</h1>\n<h2>Your client has taken too long to issue its request.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 11 |
+
{"time":"2026-03-23T17:05:55.176103762+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": net/http: request canceled (Client.Timeout exceeded while awaiting headers)"}
|
| 12 |
+
{"time":"2026-03-23T17:06:10.164453104+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": unexpected EOF"}
|
| 13 |
+
{"time":"2026-03-23T22:05:06.25355716+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:44154->35.186.228.49:443: read: connection reset by peer"}
|
| 14 |
+
{"time":"2026-03-23T22:05:20.791067182+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:40392->35.186.228.49:443: read: connection reset by peer"}
|
| 15 |
+
{"time":"2026-03-24T02:18:38.770696332+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 16 |
+
{"time":"2026-03-24T06:25:41.879737278+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 17 |
+
{"time":"2026-03-24T06:30:14.989373032+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
|
| 18 |
+
{"time":"2026-03-24T09:05:02.85908394+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": read tcp 172.20.98.27:46722->35.186.228.49:443: read: connection reset by peer"}
|
| 19 |
+
{"time":"2026-03-25T04:41:04.741907157+08:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/2365972933-teleai/REG/w9holkos/file_stream\": unexpected EOF"}
|
REG/wandb/run-20260323_135841-w9holkos/logs/debug.log
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
|
| 2 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Configure stats pid to 400275
|
| 3 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 4 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135841-w9holkos/logs/debug.log
|
| 5 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/wandb/run-20260323_135841-w9holkos/logs/debug-internal.log
|
| 6 |
+
2026-03-23 13:58:41,343 INFO MainThread:400275 [wandb_init.py:init():844] calling init triggers
|
| 7 |
+
2026-03-23 13:58:41,344 INFO MainThread:400275 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
|
| 8 |
+
config: {'_wandb': {}}
|
| 9 |
+
2026-03-23 13:58:41,344 INFO MainThread:400275 [wandb_init.py:init():892] starting backend
|
| 10 |
+
2026-03-23 13:58:41,630 INFO MainThread:400275 [wandb_init.py:init():895] sending inform_init request
|
| 11 |
+
2026-03-23 13:58:41,643 INFO MainThread:400275 [wandb_init.py:init():903] backend started and connected
|
| 12 |
+
2026-03-23 13:58:41,646 INFO MainThread:400275 [wandb_init.py:init():973] updated telemetry
|
| 13 |
+
2026-03-23 13:58:41,659 INFO MainThread:400275 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
|
| 14 |
+
2026-03-23 13:58:43,108 INFO MainThread:400275 [wandb_init.py:init():1042] starting run threads in backend
|
| 15 |
+
2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_console_start():2524] atexit reg
|
| 16 |
+
2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_redirect():2373] redirect: wrap_raw
|
| 17 |
+
2026-03-23 13:58:43,201 INFO MainThread:400275 [wandb_run.py:_redirect():2442] Wrapping output streams.
|
| 18 |
+
2026-03-23 13:58:43,202 INFO MainThread:400275 [wandb_run.py:_redirect():2465] Redirects installed.
|
| 19 |
+
2026-03-23 13:58:43,209 INFO MainThread:400275 [wandb_init.py:init():1082] run started, returning control to user process
|
| 20 |
+
2026-03-23 13:58:43,210 INFO MainThread:400275 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': 'exps', 'exp_name': 'jsflow-experiment-0.75', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 2000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'bf16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 5e-05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.05, 't_c': 0.75, 'ot_cls': True}
|