update README
Files changed: README.md (+119, -1); added assets/base.gif, assets/head.png, assets/sink.gif

README.md:
short_description: Preventing Local Pitfalls in Vector Quantization via Optimal
---

# Preventing Local Pitfalls in Vector Quantization via Optimal Transport

<p style="color: gray; font-size: 18px; font-weight: bold; text-align: center;">
  Paper |
  <a href="https://boruizhang.site/OptVQ/" style="text-decoration: none; color: white; background-color: #660874; padding: 4px 8px; border-radius: 8px; border-bottom: none;" target="_blank">Project Page</a> |
  <a href="https://huggingface.co/spaces/BorelTHU/OptVQ" style="text-decoration: none; color: white; background-color: #660874; padding: 4px 8px; border-radius: 8px; border-bottom: none;" target="_blank">HF Demo</a>
</p>

![head](assets/head.png)

## News

- [2024-11-26] Released the pre-trained models of OptVQ.

## Introduction

We conduct image reconstruction experiments on the ImageNet dataset, and the quantitative comparison is shown below:

| Model | Latent Size | #Tokens | From Scratch | SSIM ↑ | PSNR ↑ | LPIPS ↓ | rFID ↓ |
| - | - | - | - | - | - | - | - |
| taming-VQGAN | 16 × 16 | 1,024 | √ | 0.521 | 23.30 | 0.195 | 6.25 |
| MaskGIT-VQGAN | 16 × 16 | 1,024 | √ | - | - | - | 2.28 |
| Mo-VQGAN | 16 × 16 × 4 | 1,024 | √ | 0.673 | 22.42 | 0.113 | 1.12 |
| TiTok-S-128 | 128 | 4,096 | × | - | - | - | 1.71 |
| ViT-VQGAN | 32 × 32 | 8,192 | √ | - | - | - | 1.28 |
| taming-VQGAN | 16 × 16 | 16,384 | √ | 0.542 | 19.93 | 0.177 | 3.64 |
| RQ-VAE | 8 × 8 × 16 | 16,384 | √ | - | - | - | 1.83 |
| VQGAN-LC | 16 × 16 | 100,000 | × | 0.589 | 23.80 | 0.120 | 2.62 |
| OptVQ (ours) | 16 × 16 × 4 | 16,384 | √ | 0.717 | 26.59 | 0.076 | 1.00 |
| OptVQ (ours) | 16 × 16 × 8 | 16,384 | √ | 0.729 | 27.57 | 0.066 | 0.91 |

### Toy Example

We visualize the process of OptVQ and Vanilla VQ on a two-dimensional toy example. The left figure (red points) shows the baseline, Vanilla VQ; the right figure (green points) shows the proposed method, OptVQ. A minimal code sketch of the two assignment rules follows the figures.

<p float="left">
  <img src="assets/base.gif" width="300" />
  <img src="assets/sink.gif" width="300" />
</p>
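
To make the contrast concrete, the following is a minimal, self-contained sketch of the two assignment rules on 2D toy data. It only illustrates the general idea; it is not the code behind the animations, and the data, codebook initialization, and hyperparameters (`eps`, `n_iters`) are assumptions made for this example.

```python
# Illustrative sketch only: nearest-neighbor (vanilla VQ) assignment vs. a
# Sinkhorn-style optimal-transport assignment on 2D toy data. The helper
# names and hyperparameters are assumptions for this example, not OptVQ's API.
import torch

def nearest_assign(points, codebook):
    # Vanilla VQ: each point independently picks its closest code.
    dist = torch.cdist(points, codebook)       # (N, K) pairwise distances
    return dist.argmin(dim=1)

def sinkhorn_assign(points, codebook, eps=0.05, n_iters=50):
    # OT-style assignment: build an entropically regularized transport plan
    # whose column marginals are uniform, so every code receives some mass.
    cost = torch.cdist(points, codebook) ** 2  # (N, K) squared distances
    log_p = -cost / eps                        # unnormalized log transport plan
    for _ in range(n_iters):                   # Sinkhorn iterations in log space
        log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)  # normalize rows
        log_p = log_p - torch.logsumexp(log_p, dim=0, keepdim=True)  # normalize columns
    return log_p.argmax(dim=1)                 # hard assignment from the soft plan

if __name__ == "__main__":
    torch.manual_seed(0)
    points = torch.randn(512, 2)               # toy 2D data around the origin
    codebook = torch.randn(16, 2) + 3.0        # codes initialized far from the data
    for name, assign in (("nearest", nearest_assign), ("sinkhorn", sinkhorn_assign)):
        used = assign(points, codebook).unique().numel()
        print(f"{name:8s}: {used}/{codebook.shape[0]} codes in use")
```

With a poorly initialized codebook, the nearest-neighbor rule tends to route most points to the few codes that happen to sit near the data, while the transport-based assignment keeps every code in use; this is the qualitative difference the two animations above illustrate.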

## Usage

### Installation

Please install the dependencies by running the following commands:
```bash
# install the dependencies
pip install -r requirements.txt
# install the faiss-gpu package via conda
conda install -c pytorch -c nvidia faiss-gpu=1.8.0
# install the optvq package
pip install -e .
```
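
As a quick sanity check after installation, you can confirm that the main dependencies import cleanly. This is only a suggestion; it assumes the package names used in the rest of this README (`optvq`, `faiss`, `torch`):

```python
# Post-install check: confirm the key packages are importable.
import torch
import faiss   # installed via conda (faiss-gpu)
import optvq   # installed via `pip install -e .`

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("faiss:", faiss.__version__)
```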

### Inference

Please download the pre-trained models from the following links:

| Model | Link (Tsinghua) | Link (Hugging Face) |
| - | - | - |
| OptVQ (16 x 16 x 4) | [Download](https://cloud.tsinghua.edu.cn/d/91befd96f06a4a83bb03/) | [Download](https://huggingface.co/BorelTHU/optvq-16x16x4) |
| OptVQ (16 x 16 x 8) | [Download](https://cloud.tsinghua.edu.cn/d/309a55529e1f4f42a8d2/) | [Download](https://huggingface.co/BorelTHU/optvq-16x16x8) |

#### Option 1: Load from Hugging Face

You can load the model from the Hugging Face Hub with the following code:
```python
# Example: load the OptVQ model with a 16 x 16 x 4 latent
from optvq.models.vqgan_hf import VQModelHF

model = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
```
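
Assuming the returned `VQModelHF` behaves like a standard `torch.nn.Module` (the inference example below treats it that way), you will usually want to move it to the right device and switch it to evaluation mode before use; a minimal sketch:

```python
import torch

# Hypothetical setup step: pick a device and disable training-time behaviour.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()
```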

#### Option 2: Load from a local checkpoint

You can also load a locally downloaded checkpoint with the following code:
```python
# Example: load the OptVQ model with a 16 x 16 x 4 latent
import torch
from omegaconf import OmegaConf

from optvq.utils.init import initiate_from_config_recursively

config = OmegaConf.load("configs/optvq.yaml")
model = initiate_from_config_recursively(config.autoencoder)
params = torch.load(..., map_location="cpu")  # path to the downloaded checkpoint file
model.load_state_dict(params["model"])
```

#### Perform inference

After loading the model, you can perform inference (reconstruction):

```python
# load the dataset
dataset = ...  # the input should be normalized to [-1, 1]
data = dataset[...]  # size: (BS, C, H, W)

# reconstruct the input
with torch.no_grad():
    quant, *_ = model.encode(data)
    recon = model.decode(quant)
```
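
If you are starting from a single image file rather than a prepared dataset, the sketch below shows one way to build the normalized `(BS, C, H, W)` tensor and to turn the reconstruction back into an image. The file names, the 256 x 256 resolution, the `torchvision` preprocessing, and the assumption that the decoder output also lies in [-1, 1] are illustrative choices, not requirements stated by this repository:

```python
# Illustrative pre-/post-processing around the reconstruction call above.
import torch
from PIL import Image
from torchvision import transforms

to_tensor = transforms.Compose([
    transforms.Resize((256, 256)),               # assumed input resolution
    transforms.ToTensor(),                       # [0, 1], shape (C, H, W)
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # map [0, 1] -> [-1, 1]
])

img = Image.open("example.jpg").convert("RGB")   # hypothetical input image
device = next(model.parameters()).device         # assumes a torch.nn.Module, as above
data = to_tensor(img).unsqueeze(0).to(device)    # (1, C, H, W) in [-1, 1]

with torch.no_grad():
    quant, *_ = model.encode(data)
    recon = model.decode(quant)

# Assume the decoder output is also in [-1, 1]; map it back to a uint8 image.
recon = (recon.squeeze(0).clamp(-1, 1) + 1) / 2
array = (recon.permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
Image.fromarray(array).save("recon.png")
```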

### Evaluation

To evaluate the model, run the following command:
```bash
python eval.py --config $config_path --log_dir $log_dir --resume $resume --is_distributed
```

### Training

We will release the training scripts soon.

<!-- ## Citation

If you find this work useful, please consider citing it.

```bibtex
xxx
``` -->