File size: 9,177 Bytes
ab03d65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
Facebook AI Similarity Search (Faiss) ํ
==================
# Faiss์ ๋ํ์ฌ
Faiss ๋ Facebook Research๊ฐ ๊ฐ๋ฐํ๋, ๊ณ ๋ฐ๋ ๋ฒกํฐ ์ด์ ๊ฒ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์
๋๋ค. ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ (Approximate Neigbor Search)์ ์ฝ๊ฐ์ ์ ํ์ฑ์ ํฌ์ํ์ฌ ์ ์ฌ ๋ฒกํฐ๋ฅผ ๊ณ ์์ผ๋ก ์ฐพ์ต๋๋ค.
## RVC์ ์์ด์ Faiss
RVC์์๋ HuBERT๋ก ๋ณํํ feature์ embedding์ ์ํด ํ๋ จ ๋ฐ์ดํฐ์์ ์์ฑ๋ embedding๊ณผ ์ ์ฌํ embadding์ ๊ฒ์ํ๊ณ ํผํฉํ์ฌ ์๋์ ์์ฑ์ ๋์ฑ ๊ฐ๊น์ด ๋ณํ์ ๋ฌ์ฑํฉ๋๋ค. ๊ทธ๋ฌ๋, ์ด ํ์๋ฒ์ ๋จ์ํ ์ํํ๋ฉด ์๊ฐ์ด ๋ค์ ์๋ชจ๋๋ฏ๋ก, ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ์ ํตํด ๊ณ ์ ๋ณํ์ ๊ฐ๋ฅ์ผ ํ๊ณ ์์ต๋๋ค.
# ๊ตฌํ ๊ฐ์
๋ชจ๋ธ์ด ์์นํ `/logs/your-experiment/3_feature256`์๋ ๊ฐ ์์ฑ ๋ฐ์ดํฐ์์ HuBERT๊ฐ ์ถ์ถํ feature๋ค์ด ์์ต๋๋ค. ์ฌ๊ธฐ์์ ํ์ผ ์ด๋ฆ๋ณ๋ก ์ ๋ ฌ๋ npy ํ์ผ์ ์ฝ๊ณ , ๋ฒกํฐ๋ฅผ ์ฐ๊ฒฐํ์ฌ big_npy ([N, 256] ๋ชจ์์ ๋ฒกํฐ) ๋ฅผ ๋ง๋ญ๋๋ค. big_npy๋ฅผ `/logs/your-experiment/total_fea.npy`๋ก ์ ์ฅํ ํ, Faiss๋ก ํ์ต์ํต๋๋ค.
2023/04/18 ๊ธฐ์ค์ผ๋ก, Faiss์ Index Factory ๊ธฐ๋ฅ์ ์ด์ฉํด, L2 ๊ฑฐ๋ฆฌ์ ๊ทผ๊ฑฐํ๋ IVF๋ฅผ ์ด์ฉํ๊ณ ์์ต๋๋ค. IVF์ ๋ถํ ์(n_ivf)๋ N//39๋ก, n_probe๋ int(np.power(n_ivf, 0.3))๊ฐ ์ฌ์ฉ๋๊ณ ์์ต๋๋ค. (infer-web.py์ train_index ์ฃผ์๋ฅผ ์ฐพ์ผ์ญ์์ค.)
์ด ํ์์๋ ๋จผ์ ์ด๋ฌํ ๋งค๊ฐ ๋ณ์์ ์๋ฏธ๋ฅผ ์ค๋ช
ํ๊ณ , ๊ฐ๋ฐ์๊ฐ ์ถํ ๋ ๋์ index๋ฅผ ์์ฑํ ์ ์๋๋ก ํ๋ ์กฐ์ธ์ ์์ฑํฉ๋๋ค.
# ๋ฐฉ๋ฒ์ ์ค๋ช
## Index factory
index factory๋ ์ฌ๋ฌ ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ์ ๋ฌธ์์ด๋ก ์ฐ๊ฒฐํ๋ pipeline์ ๋ฌธ์์ด๋ก ํ๊ธฐํ๋ Faiss๋ง์ ๋
์์ ์ธ ๊ธฐ๋ฒ์
๋๋ค. ์ด๋ฅผ ํตํด index factory์ ๋ฌธ์์ด์ ๋ณ๊ฒฝํ๋ ๊ฒ๋ง์ผ๋ก ๋ค์ํ ๊ทผ์ฌ ๊ทผ์ ํ์์ ์๋ํด ๋ณผ ์ ์์ต๋๋ค. RVC์์๋ ๋ค์๊ณผ ๊ฐ์ด ์ฌ์ฉ๋ฉ๋๋ค:
```python
index = Faiss.index_factory(256, "IVF%s,Flat" % n_ivf)
```
`index_factory`์ ์ธ์๋ค ์ค ์ฒซ ๋ฒ์งธ๋ ๋ฒกํฐ์ ์ฐจ์ ์์ด๊ณ , ๋๋ฒ์งธ๋ index factory ๋ฌธ์์ด์ด๋ฉฐ, ์ธ๋ฒ์งธ์๋ ์ฌ์ฉํ ๊ฑฐ๋ฆฌ๋ฅผ ์ง์ ํ ์ ์์ต๋๋ค.
๊ธฐ๋ฒ์ ๋ณด๋ค ์์ธํ ์ค๋ช
์ https://github.com/facebookresearch/Faiss/wiki/The-index-factory ๋ฅผ ํ์ธํด ์ฃผ์ญ์์ค.
## ๊ฑฐ๋ฆฌ์ ๋ํ index
embedding์ ์ ์ฌ๋๋ก์ ์ฌ์ฉ๋๋ ๋ํ์ ์ธ ์งํ๋ก์ ์ดํ์ 2๊ฐ๊ฐ ์์ต๋๋ค.
- ์ ํด๋ฆฌ๋ ๊ฑฐ๋ฆฌ (METRIC_L2)
- ๋ด์ (ๅ
็ฉ) (METRIC_INNER_PRODUCT)
์ ํด๋ฆฌ๋ ๊ฑฐ๋ฆฌ์์๋ ๊ฐ ์ฐจ์์์ ์ ๊ณฑ์ ์ฐจ๋ฅผ ๊ตฌํ๊ณ , ๊ฐ ์ฐจ์์์ ๊ตฌํ ์ฐจ๋ฅผ ๋ชจ๋ ๋ํ ํ ์ ๊ณฑ๊ทผ์ ์ทจํฉ๋๋ค. ์ด๊ฒ์ ์ผ์์ ์ผ๋ก ์ฌ์ฉ๋๋ 2์ฐจ์, 3์ฐจ์์์์ ๊ฑฐ๋ฆฌ์ ์ฐ์ฐ๋ฒ๊ณผ ๊ฐ์ต๋๋ค. ๋ด์ ์ ๊ทธ ๊ฐ์ ๊ทธ๋๋ก ์ ์ฌ๋ ์งํ๋ก ์ฌ์ฉํ์ง ์๊ณ , L2 ์ ๊ทํ๋ฅผ ํ ์ดํ ๋ด์ ์ ์ทจํ๋ ์ฝ์ฌ์ธ ์ ์ฌ๋๋ฅผ ์ฌ์ฉํฉ๋๋ค.
์ด๋ ์ชฝ์ด ๋ ์ข์์ง๋ ๊ฒฝ์ฐ์ ๋ฐ๋ผ ๋ค๋ฅด์ง๋ง, word2vec์์ ์ป์ embedding ๋ฐ ArcFace๋ฅผ ํ์ฉํ ์ด๋ฏธ์ง ๊ฒ์ ๋ชจ๋ธ์ ์ฝ์ฌ์ธ ์ ์ฌ์ฑ์ด ์ด์ฉ๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. numpy๋ฅผ ์ฌ์ฉํ์ฌ ๋ฒกํฐ X์ ๋ํด L2 ์ ๊ทํ๋ฅผ ํ๊ณ ์ ํ๋ ๊ฒฝ์ฐ, 0 division์ ํผํ๊ธฐ ์ํด ์ถฉ๋ถํ ์์ ๊ฐ์ eps๋ก ํ ๋ค ์ดํ์ ์ฝ๋๋ฅผ ํ์ฉํ๋ฉด ๋ฉ๋๋ค.
```python
X_normed = X / np.maximum(eps, np.linalg.norm(X, ord=2, axis=-1, keepdims=True))
```
๋ํ, `index factory`์ 3๋ฒ์งธ ์ธ์์ ๊ฑด๋ค์ฃผ๋ ๊ฐ์ ์ ํํ๋ ๊ฒ์ ํตํด ๊ณ์ฐ์ ์ฌ์ฉํ๋ ๊ฑฐ๋ฆฌ index๋ฅผ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค.
```python
index = Faiss.index_factory(dimention, text, Faiss.METRIC_INNER_PRODUCT)
```
## IVF
IVF (Inverted file indexes)๋ ์ญ์์ธ ํ์๋ฒ๊ณผ ์ ์ฌํ ์๊ณ ๋ฆฌ์ฆ์
๋๋ค. ํ์ต์์๋ ๊ฒ์ ๋์์ ๋ํด k-ํ๊ท ๊ตฐ์ง๋ฒ์ ์ค์ํ๊ณ ํด๋ฌ์คํฐ ์ค์ฌ์ ์ด์ฉํด ๋ณด๋ก๋
ธ์ด ๋ถํ ์ ์ค์ํฉ๋๋ค. ๊ฐ ๋ฐ์ดํฐ ํฌ์ธํธ์๋ ํด๋ฌ์คํฐ๊ฐ ํ ๋น๋๋ฏ๋ก, ํด๋ฌ์คํฐ์์ ๋ฐ์ดํฐ ํฌ์ธํธ๋ฅผ ์กฐํํ๋ dictionary๋ฅผ ๋ง๋ญ๋๋ค.
์๋ฅผ ๋ค์ด, ํด๋ฌ์คํฐ๊ฐ ๋ค์๊ณผ ๊ฐ์ด ํ ๋น๋ ๊ฒฝ์ฐ
|index|Cluster|
|-----|-------|
|1|A|
|2|B|
|3|A|
|4|C|
|5|B|
IVF ์ดํ์ ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
|cluster|index|
|-------|-----|
|A|1, 3|
|B|2, 5|
|C|4|
ํ์ ์, ์ฐ์ ํด๋ฌ์คํฐ์์ `n_probe`๊ฐ์ ํด๋ฌ์คํฐ๋ฅผ ํ์ํ ๋ค์, ๊ฐ ํด๋ฌ์คํฐ์ ์ํ ๋ฐ์ดํฐ ํฌ์ธํธ์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํฉ๋๋ค.
# ๊ถ์ฅ ๋งค๊ฐ๋ณ์
index์ ์ ํ ๋ฐฉ๋ฒ์ ๋ํด์๋ ๊ณต์์ ์ผ๋ก ๊ฐ์ด๋ ๋ผ์ธ์ด ์์ผ๋ฏ๋ก, ๊ฑฐ๊ธฐ์ ์คํด ์ค๋ช
ํฉ๋๋ค.
https://github.com/facebookresearch/Faiss/wiki/Guidelines-to-choose-an-index
1M ์ดํ์ ๋ฐ์ดํฐ ์ธํธ์ ์์ด์๋ 4bit-PQ๊ฐ 2023๋
4์ ์์ ์์๋ Faiss๋ก ์ด์ฉํ ์ ์๋ ๊ฐ์ฅ ํจ์จ์ ์ธ ์๋ฒ์
๋๋ค. ์ด๊ฒ์ IVF์ ์กฐํฉํด, 4bit-PQ๋ก ํ๋ณด๋ฅผ ์ถ๋ ค๋ด๊ณ , ๋ง์ง๋ง์ผ๋ก ์ดํ์ index factory๋ฅผ ์ด์ฉํ์ฌ ์ ํํ ์งํ๋ก ๊ฑฐ๋ฆฌ๋ฅผ ์ฌ๊ณ์ฐํ๋ฉด ๋ฉ๋๋ค.
```python
index = Faiss.index_factory(256, "IVF1024,PQ128x4fs,RFlat")
```
## IVF ๊ถ์ฅ ๋งค๊ฐ๋ณ์
IVF์ ์๊ฐ ๋๋ฌด ๋ง์ผ๋ฉด, ๊ฐ๋ น ๋ฐ์ดํฐ ์์ ์๋งํผ IVF๋ก ์์ํ(Quantization)๋ฅผ ์ํํ๋ฉด, ์ด๊ฒ์ ์์ ํ์๊ณผ ๊ฐ์์ ธ ํจ์จ์ด ๋๋น ์ง๊ฒ ๋ฉ๋๋ค. 1M ์ดํ์ ๊ฒฝ์ฐ IVF ๊ฐ์ ๋ฐ์ดํฐ ํฌ์ธํธ ์ N์ ๋ํด 4sqrt(N) ~ 16sqrt(N)๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
n_probe๋ n_probe์ ์์ ๋น๋กํ์ฌ ๊ณ์ฐ ์๊ฐ์ด ๋์ด๋๋ฏ๋ก ์ ํ๋์ ์๊ฐ์ ์ ์ ํ ๊ท ํ์ ๋ง์ถ์ด ์ฃผ์ญ์์ค. ๊ฐ์ธ์ ์ผ๋ก RVC์ ์์ด์ ๊ทธ๋ ๊ฒ๊น์ง ์ ํ๋๋ ํ์ ์๋ค๊ณ ์๊ฐํ๊ธฐ ๋๋ฌธ์ n_probe = 1์ด๋ฉด ๋๋ค๊ณ ์๊ฐํฉ๋๋ค.
## FastScan
FastScan์ ์ง์ ์์ํ๋ฅผ ๋ ์ง์คํฐ์์ ์ํํจ์ผ๋ก์จ ๊ฑฐ๋ฆฌ์ ๊ณ ์ ๊ทผ์ฌ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๋ ๋ฐฉ๋ฒ์
๋๋ค.์ง์ ์์ํ๋ ํ์ต์์ d์ฐจ์๋ง๋ค(๋ณดํต d=2)์ ๋
๋ฆฝ์ ์ผ๋ก ํด๋ฌ์คํฐ๋ง์ ์ค์ํด, ํด๋ฌ์คํฐ๋ผ๋ฆฌ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ฌ์ ๊ณ์ฐํด lookup table๋ฅผ ์์ฑํฉ๋๋ค. ์์ธก์๋ lookup table์ ๋ณด๋ฉด ๊ฐ ์ฐจ์์ ๊ฑฐ๋ฆฌ๋ฅผ O(1)๋ก ๊ณ์ฐํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ PQ ๋ค์์ ์ง์ ํ๋ ์ซ์๋ ์ผ๋ฐ์ ์ผ๋ก ๋ฒกํฐ์ ์ ๋ฐ ์ฐจ์์ ์ง์ ํฉ๋๋ค.
FastScan์ ๋ํ ์์ธํ ์ค๋ช
์ ๊ณต์ ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
https://github.com/facebookresearch/Faiss/wiki/Fast-accumulation-of-PQ-and-AQ-codes-(FastScan)
## RFlat
RFlat์ FastScan์ด ๊ณ์ฐํ ๋๋ต์ ์ธ ๊ฑฐ๋ฆฌ๋ฅผ index factory์ 3๋ฒ์งธ ์ธ์๋ก ์ง์ ํ ์ ํํ ๊ฑฐ๋ฆฌ๋ก ๋ค์ ๊ณ์ฐํ๋ผ๋ ์ธ์คํธ๋ญ์
์
๋๋ค. k๊ฐ์ ๊ทผ์ ๋ณ์๋ฅผ ๊ฐ์ ธ์ฌ ๋ k*k_factor๊ฐ์ ์ ์ ๋ํด ์ฌ๊ณ์ฐ์ด ์ด๋ฃจ์ด์ง๋๋ค.
# Embedding ํ
ํฌ๋
## Alpha ์ฟผ๋ฆฌ ํ์ฅ
ํด๋ฆฌ ํ์ฅ์ด๋ ํ์์์ ์ฌ์ฉ๋๋ ๊ธฐ์ ๋ก, ์๋ฅผ ๋ค์ด ์ ๋ฌธ ํ์ ์, ์
๋ ฅ๋ ๊ฒ์๋ฌธ์ ๋จ์ด๋ฅผ ๋ช ๊ฐ๋ฅผ ์ถ๊ฐํจ์ผ๋ก์จ ๊ฒ์ ์ ํ๋๋ฅผ ์ฌ๋ฆฌ๋ ๋ฐฉ๋ฒ์
๋๋ค. ๋ฐฑํฐ ํ์์ ์ํด์๋ ๋ช๊ฐ์ง ๋ฐฉ๋ฒ์ด ์ ์๋์๋๋ฐ, ๊ทธ ์ค ฮฑ-์ฟผ๋ฆฌ ํ์ฅ์ ์ถ๊ฐ ํ์ต์ด ํ์ ์๋ ๋งค์ฐ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ผ๋ก ์๋ ค์ ธ ์์ต๋๋ค. [Attention-Based Query Expansion Learning](https://arxiv.org/abs/2007.08019)์ [2nd place solution of kaggle shopee competition](https://www.kaggle.com/code/lyakaap/2nd-place-solution/notebook) ๋
ผ๋ฌธ์์ ์๊ฐ๋ ๋ฐ ์์ต๋๋ค..
ฮฑ-์ฟผ๋ฆฌ ํ์ฅ์ ํ ๋ฒกํฐ์ ์ธ์ ํ ๋ฒกํฐ๋ฅผ ์ ์ฌ๋์ ฮฑ๊ณฑํ ๊ฐ์ค์น๋ก ๋ํด์ฃผ๋ฉด ๋ฉ๋๋ค. ์ฝ๋๋ก ์์๋ฅผ ๋ค์ด ๋ณด๊ฒ ์ต๋๋ค. big_npy๋ฅผ ฮฑ query expansion๋ก ๋์ฒดํฉ๋๋ค.
```python
alpha = 3.
index = Faiss.index_factory(256, "IVF512,PQ128x4fs,RFlat")
original_norm = np.maximum(np.linalg.norm(big_npy, ord=2, axis=1, keepdims=True), 1e-9)
big_npy /= original_norm
index.train(big_npy)
index.add(big_npy)
dist, neighbor = index.search(big_npy, num_expand)
expand_arrays = []
ixs = np.arange(big_npy.shape[0])
for i in range(-(-big_npy.shape[0]//batch_size)):
ix = ixs[i*batch_size:(i+1)*batch_size]
weight = np.power(np.einsum("nd,nmd->nm", big_npy[ix], big_npy[neighbor[ix]]), alpha)
expand_arrays.append(np.sum(big_npy[neighbor[ix]] * np.expand_dims(weight, axis=2),axis=1))
big_npy = np.concatenate(expand_arrays, axis=0)
# index version ์ ๊ทํ
big_npy = big_npy / np.maximum(np.linalg.norm(big_npy, ord=2, axis=1, keepdims=True), 1e-9)
```
์ ํ
ํฌ๋์ ํ์์ ์ํํ๋ ์ฟผ๋ฆฌ์๋, ํ์ ๋์ DB์๋ ์ ์ ๊ฐ๋ฅํ ํ
ํฌ๋์
๋๋ค.
## MiniBatch KMeans์ ์ํ embedding ์์ถ
total_fea.npy๊ฐ ๋๋ฌด ํด ๊ฒฝ์ฐ K-means๋ฅผ ์ด์ฉํ์ฌ ๋ฒกํฐ๋ฅผ ์๊ฒ ๋ง๋๋ ๊ฒ์ด ๊ฐ๋ฅํฉ๋๋ค. ์ดํ ์ฝ๋๋ก embedding์ ์์ถ์ด ๊ฐ๋ฅํฉ๋๋ค. n_clusters์ ์์ถํ๊ณ ์ ํ๋ ํฌ๊ธฐ๋ฅผ ์ง์ ํ๊ณ batch_size์ 256 * CPU์ ์ฝ์ด ์๋ฅผ ์ง์ ํจ์ผ๋ก์จ CPU ๋ณ๋ ฌํ์ ํํ์ ์ถฉ๋ถํ ์ป์ ์ ์์ต๋๋ค.
```python
import multiprocessing
from sklearn.cluster import MiniBatchKMeans
kmeans = MiniBatchKMeans(n_clusters=10000, batch_size=256 * multiprocessing.cpu_count(), init="random")
kmeans.fit(big_npy)
sample_npy = kmeans.cluster_centers_
``` |