# Randomized Autoregressive Visual Generation
RAR is an autoregressive (AR) image generator that is fully compatible with language modeling. It introduces a randomness annealing strategy with a permuted objective at no additional cost, which improves the model's ability to learn bidirectional context while leaving the autoregressive framework intact. RAR achieves an FID of 1.48 on the ImageNet-256 benchmark, setting a new state of the art and significantly outperforming prior AR image generators.
<p>
<img src="assets/rar_overview.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/perf_comp.png" alt="teaser" width=90% height=90%>
</p>
## Contributions
#### We introduce RAR, an improved training strategy that enables a standard autoregressive image generator to achieve state-of-the-art performance.
#### The proposed RAR is extremely simple yet effective: during training, we randomly permute the input token sequence with probability r, where r starts at 1.0 and linearly decays to 0.0 over the course of training. This simple strategy enables better bidirectional representation learning, which is missing from standard raster-order AR image generator training.
#### RAR keeps the AR framework intact, so it is fully compatible with LLM optimization techniques such as KV-cache, leading to significantly faster sampling than MAR-H or MaskBit while maintaining better performance.
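
The annealing schedule described above can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: the function names `permute_probability` and `maybe_permute` are hypothetical and do not come from the RAR codebase, the decay is assumed to span the whole run exactly as stated above, and the sketch leaves out everything else the model needs to train on permuted sequences.

```python
import random


def permute_probability(step, total_steps):
    """Linearly decay r from 1.0 at step 0 to 0.0 at total_steps (assumed schedule)."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    return max(0.0, 1.0 - step / total_steps)


def maybe_permute(tokens, r, rng):
    """With probability r, return the tokens in a random order; otherwise keep raster order.

    Also returns the order that was used, so the original positions stay recoverable.
    """
    order = list(range(len(tokens)))
    if rng.random() < r:
        rng.shuffle(order)
    return [tokens[i] for i in order], order


# Example: halfway through training, half of the sequences are permuted on average.
rng = random.Random(0)
r = permute_probability(step=50, total_steps=100)
permuted, order = maybe_permute([10, 11, 12, 13], r, rng)
```

At the end of training r reaches 0.0, so every sequence is seen in raster order, which is the order used at sampling time.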
## Model Zoo
| Model | Link | FID |
| ------------- | ------------- | ------------- |
| RAR-B | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_b.bin)| 1.95 (generation) |
| RAR-L | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_l.bin)| 1.70 (generation) |
| RAR-XL | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_xl.bin)| 1.50 (generation) |
| RAR-XXL | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_xxl.bin)| 1.48 (generation) |
Please note that these models are trained only on ImageNet, a limited academic dataset, and are intended for research purposes only.
## Installation
```shell
pip3 install -r requirements.txt
```
## Get Started
```python
import torch
from PIL import Image
import numpy as np
import demo_util
from huggingface_hub import hf_hub_download
from utils.train_utils import create_pretrained_tokenizer
# Choose one from ["rar_b", "rar_l", "rar_xl", "rar_xxl"]
rar_model_name = ["rar_b", "rar_l", "rar_xl", "rar_xxl"][3]
# download the maskgit-vq tokenizer
hf_hub_download(repo_id="fun-research/TiTok", filename="maskgit-vqgan-imagenet-f16-256.bin", local_dir="./")
# download the rar generator weight
hf_hub_download(repo_id="yucornetto/RAR", filename=f"{rar_model_name}.bin", local_dir="./")
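# load config
# Assumption: the generator YAML from the sampling commands is loaded here,
# with the same per-size overrides those commands pass on the command line.
config = demo_util.get_config("configs/training/generator/rar.yaml")
config.experiment.generator_checkpoint = f"{rar_model_name}.bin"
config.model.generator.hidden_size = {"rar_b": 768, "rar_l": 1024, "rar_xl": 1280, "rar_xxl": 1408}[rar_model_name]
config.model.generator.num_hidden_layers = {"rar_b": 24, "rar_l": 24, "rar_xl": 32, "rar_xxl": 40}[rar_model_name]
config.model.generator.num_attention_heads = 16
config.model.generator.intermediate_size = {"rar_b": 3072, "rar_l": 4096, "rar_xl": 5120, "rar_xxl": 6144}[rar_model_name]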
device = "cuda"
# maskgit-vq as tokenizer
tokenizer = create_pretrained_tokenizer(config)
generator = demo_util.get_rar_generator(config)
tokenizer.to(device)
generator.to(device)
# generate an image
sample_labels = [torch.randint(0, 1000, size=(1,)).item()] # random IN-1k class (upper bound is exclusive)
generated_image = demo_util.sample_fn(
generator=generator,
tokenizer=tokenizer,
labels=sample_labels,
randomize_temperature=1.0,
guidance_scale=4.0,
guidance_scale_pow=0.0, # constant cfg
device=device
)
Image.fromarray(generated_image[0]).save(f"assets/rar_generated_{sample_labels[0]}.png")
```
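
The `guidance_scale` argument in the snippet above controls classifier-free guidance. As a rough illustration of the general technique, and not necessarily of how `demo_util.sample_fn` implements it, one common formulation pushes the class-conditional logits away from the unconditional ones by the guidance scale. The helper name `cfg_logits` and the plain-list representation are assumptions made for this sketch.

```python
def cfg_logits(cond_logits, uncond_logits, scale):
    """Combine conditional and unconditional logits with a guidance scale.

    scale == 1.0 returns the conditional logits unchanged; larger values
    extrapolate further away from the unconditional prediction.
    """
    if len(cond_logits) != len(uncond_logits):
        raise ValueError("logit vectors must have the same length")
    return [u + scale * (c - u) for c, u in zip(cond_logits, uncond_logits)]


# Example with two-token toy logits and the guidance scale used above.
guided = cfg_logits([1.0, 2.0], [0.0, 0.0], scale=4.0)
```

With `guidance_scale_pow=0.0` the snippet above keeps the scale constant during sampling; this sketch does not model any schedule for non-zero values.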
## Testing on ImageNet-1K Benchmark
We provide a [sampling script](./sample_imagenet_rar.py) for reproducing the generation results on the ImageNet-1K benchmark.
```bash
# Prepare ADM evaluation script
git clone https://github.com/openai/guided-diffusion.git
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
```
```bash
# Reproducing RAR-B
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
experiment.output_dir="rar_b" \
experiment.generator_checkpoint="rar_b.bin" \
model.generator.hidden_size=768 \
model.generator.num_hidden_layers=24 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=3072 \
model.generator.randomize_temperature=1.0 \
model.generator.guidance_scale=16.0 \
model.generator.guidance_scale_pow=2.75
# Run eval script. The result FID should be ~1.95
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_b.npz
# Reproducing RAR-L
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
experiment.output_dir="rar_l" \
experiment.generator_checkpoint="rar_l.bin" \
model.generator.hidden_size=1024 \
model.generator.num_hidden_layers=24 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=4096 \
model.generator.randomize_temperature=1.02 \
model.generator.guidance_scale=15.5 \
model.generator.guidance_scale_pow=2.5
# Run eval script. The result FID should be ~1.70
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_l.npz
# Reproducing RAR-XL
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
experiment.output_dir="rar_xl" \
experiment.generator_checkpoint="rar_xl.bin" \
model.generator.hidden_size=1280 \
model.generator.num_hidden_layers=32 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=5120 \
model.generator.randomize_temperature=1.02 \
model.generator.guidance_scale=6.9 \
model.generator.guidance_scale_pow=1.5
# Run eval script. The result FID should be ~1.50
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_xl.npz
# Reproducing RAR-XXL
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
experiment.output_dir="rar_xxl" \
experiment.generator_checkpoint="rar_xxl.bin" \
model.generator.hidden_size=1408 \
model.generator.num_hidden_layers=40 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=6144 \
model.generator.randomize_temperature=1.02 \
model.generator.guidance_scale=8.0 \
model.generator.guidance_scale_pow=1.2
# Run eval script. The result FID should be ~1.48
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_xxl.npz
```
## Training Preparation
We pretokenize the whole dataset to speed up training. We have uploaded [the pretokenized data](https://huggingface.co/yucornetto/RAR/blob/main/maskgitvq.jsonl) so you can train RAR directly. The training script downloads the prerequisite checkpoints and dataset automatically.
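
If you want to inspect the pretokenized file yourself, the `.jsonl` extension suggests one JSON record per line. The sketch below reads such a file generically. The helper name `iter_jsonl` is hypothetical, and because the record schema is not documented here, the code assumes no particular field names.

```python
import json


def iter_jsonl(path):
    """Yield one parsed JSON record per non-empty line of a JSON Lines file."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)
```

For example, `next(iter_jsonl("maskgitvq.jsonl"))` would return the first record once the file has been downloaded to the working directory.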
## Training
We provide example commands to train RAR as follows:
```bash
# Training for RAR-B
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
experiment.project="rar" \
experiment.name="rar_b" \
experiment.output_dir="rar_b" \
model.generator.hidden_size=768 \
model.generator.num_hidden_layers=24 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=3072
# Training for RAR-L
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
experiment.project="rar" \
experiment.name="rar_l" \
experiment.output_dir="rar_l" \
model.generator.hidden_size=1024 \
model.generator.num_hidden_layers=24 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=4096
# Training for RAR-XL
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
experiment.project="rar" \
experiment.name="rar_xl" \
experiment.output_dir="rar_xl" \
model.generator.hidden_size=1280 \
model.generator.num_hidden_layers=32 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=5120
# Training for RAR-XXL
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
experiment.project="rar" \
experiment.name="rar_xxl" \
experiment.output_dir="rar_xxl" \
model.generator.hidden_size=1408 \
model.generator.num_hidden_layers=40 \
model.generator.num_attention_heads=16 \
model.generator.intermediate_size=6144
```
You may remove `WANDB_MODE=offline` to enable online wandb logging, if you have configured it.
Notably, you can enable gradient checkpointing by adding the flag `model.generator.use_checkpoint=True`, and adjust the number of machines and GPUs to suit your setup. All RAR checkpoints were trained with a global batch size of 2048.
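
When you change the number of machines or GPUs, the per-device batch size has to change too if you want to keep the global batch size of 2048. The arithmetic is sketched below. The helper name and the `grad_accum_steps` parameter are hypothetical, and whether the training config exposes gradient accumulation is an assumption you should check against your own setup.

```python
def per_device_batch_size(global_batch_size, num_processes, grad_accum_steps=1):
    """Return the per-device batch size that yields the requested global batch size."""
    denom = num_processes * grad_accum_steps
    if denom <= 0:
        raise ValueError("num_processes and grad_accum_steps must be positive")
    if global_batch_size % denom != 0:
        raise ValueError("global batch size is not divisible by processes * accumulation steps")
    return global_batch_size // denom


# 32 processes, as in the example commands above: 2048 / 32 = 64 samples per device.
batch_32_gpus = per_device_batch_size(2048, num_processes=32)
# 8 processes with 4 accumulation steps reach the same global batch size.
batch_8_gpus = per_device_batch_size(2048, num_processes=8, grad_accum_steps=4)
```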
## Visualizations
<p>
<img src="assets/vis1.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/vis2.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/vis3.png" alt="teaser" width=90% height=90%>
</p>
## Citing
If you use our work in your research, please use the following BibTeX entry.
```BibTeX
@article{yu2024randomized,
author = {Qihang Yu and Ju He and Xueqing Deng and Xiaohui Shen and Liang-Chieh Chen},
title = {Randomized Autoregressive Visual Generation},
journal = {arXiv preprint arXiv:2411.00776},
year = {2024}
}
```