hubert-base-korean
Model Details
Hubert(Hidden-Unit BERT)๋ Facebook์์ ์ ์ํ Speech Representation Learning ๋ชจ๋ธ์ ๋๋ค. Hubert๋ ๊ธฐ์กด์ ์์ฑ ์ธ์ ๋ชจ๋ธ๊ณผ ๋ฌ๋ฆฌ, ์์ฑ ์ ํธ๋ฅผ raw waveform์์ ๋ฐ๋ก ํ์ตํ๋ self-supervised learning ๋ฐฉ์์ ์ฌ์ฉํฉ๋๋ค.
์ด ์ฐ๊ตฌ๋ ๊ตฌ๊ธ์ TPU Research Cloud(TRC)๋ฅผ ํตํด ์ง์๋ฐ์ Cloud TPU๋ก ํ์ต๋์์ต๋๋ค.
Model Description
Base | Large | ||
CNN Encoder | strides | 5, 2, 2, 2, 2, 2, 2 | |
kernel width | 10, 3, 3, 3, 3, 2, 2 | ||
channel | 512 | ||
Transformer Encoder | Layer | 12 | 24 |
embedding dim | 768 | 1024 | |
inner FFN dim | 3072 | 4096 | |
attention heads | 8 | 16 | |
Projection | dim | 256 | 768 |
Params | 95M | 317M |
How to Get Started with the Model
Pytorch
import torch
from transformers import HubertModel
model = HubertModel.from_pretrained("team-lucid/hubert-base-korean")
wav = torch.ones(1, 16000)
outputs = model(wav)
print(f"Input: {wav.shape}") # [1, 16000]
print(f"Output: {outputs.last_hidden_state.shape}") # [1, 49, 768]
JAX/Flax
import jax.numpy as jnp
from transformers import FlaxAutoModel
model = FlaxAutoModel.from_pretrained("team-lucid/hubert-base-korean", trust_remote_code=True)
wav = jnp.ones((1, 16000))
outputs = model(wav)
print(f"Input: {wav.shape}") # [1, 16000]
print(f"Output: {outputs.last_hidden_state.shape}") # [1, 49, 768]
Training Details
Training Data
ํด๋น ๋ชจ๋ธ์ ๊ณผํ๊ธฐ์ ์ ๋ณดํต์ ๋ถ์ ์ฌ์์ผ๋ก ํ๊ตญ์ง๋ฅ์ ๋ณด์ฌํ์งํฅ์์ ์ง์์ ๋ฐ์ ๊ตฌ์ถ๋ ์์ ๋ํ ์์ฑ(์ผ๋ฐ๋จ์ฌ), ๋คํ์ ์์ฑํฉ์ฑ ๋ฐ์ดํฐ, ๋ฐฉ์ก ์ฝํ ์ธ ๋ํ์ฒด ์์ฑ์ธ์ ๋ฐ์ดํฐ ์์ ์ฝ 4,000์๊ฐ์ ์ถ์ถํด ํ์ต๋์์ต๋๋ค.
Training Procedure
์ ๋ ผ๋ฌธ๊ณผ ๋์ผํ๊ฒ MFCC ๊ธฐ๋ฐ์ผ๋ก Base ๋ชจ๋ธ์ ํ์ตํ ๋ค์, 500 cluster๋ก k-means๋ฅผ ์ํํด ๋ค์ Base์ Large ๋ชจ๋ธ์ ํ์ตํ์ต๋๋ค.
Training Hyperparameters
Hyperparameter | Base | Large |
---|---|---|
Warmup Steps | 32,000 | 32,000 |
Learning Rates | 5e-4 | 1.5e-3 |
Batch Size | 128 | 128 |
Weight Decay | 0.01 | 0.01 |
Max Steps | 400,000 | 400,000 |
Learning Rate Decay | 0.1 | 0.1 |
0.9 | 0.9 | |
0.99 | 0.99 |
- Downloads last month
- 208