aryan3212/clae-bengali-encoder
Continuous latent autoencoder for Bengali speech.
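Architecture
- Frontend: strided convolution stack with a total stride of 1280 samples (12.5 frames per second at 16 kHz); the decoder upsamples by the same factor
- Encoder: Conformer
- Self-supervised objective: LeJEPA
- Reconstruction loss: multi-resolution STFT
- Adversarial loss: LSGAN discriminator, enabled from step 20,000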
Training
- Step: 67101 (of 70,000 scheduled steps)
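Config
Full resolved training config. Indentation was lost in this copy, so the section nesting is not visible. Keys such as `enabled`, `lr`, `betas` and `weight` repeat because they belong to different sections. The checkpoint's `ckpt['cfg']` entry (see How to load) should hold the same config with its nesting intact.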
resolved_config_path: configs/kaggle_3m_gan.yaml
run:
run_id: final_aryan
out_dir: /kaggle/working/runs
seed: 0
amp: true
gpu_mem_fraction: 0.95
wandb:
enabled: true
project: continuous-latent-ae
name: exp_3m_gan
data:
sample_rate: 16000
segment_seconds: 3.0
train_manifest: /kaggle/working/manifests/train.jsonl
val_manifest: /kaggle/working/manifests/val.jsonl
num_workers: 3
pin_memory: true
persistent_workers: true
prefetch_factor: 6
aug:
wave_aug:
enabled: true
noise_prob: 1.0
noise_snr_min: 0.0
noise_snr_max: 25.0
lowpass_prob: 1.0
lowpass_min_freq: 2700.0
lowpass_max_freq: 8000.0
gain_prob: 1.0
gain_min: 0.6
gain_max: 1.5
clip_prob: 0.3
clip_min: 0.5
wave_chunk_mask:
enabled: true
target_ratio: 0.25
min_span_frames: 2
max_span_frames: 8
model:
frontend:
channels:
- 32
- 64
- 96
- 128
- 128
kernels:
- 10
- 8
- 8
- 4
- 4
strides:
- 5
- 4
- 4
- 4
- 4
groups: 8
encoder:
d_model: 128
n_layers: 5
num_heads: 4
feedforward_dim: 320
dropout: 0.1
cnn_module_kernel: 31
mhc:
enabled: true
num_streams: 2
start_layer: 2
period: 3
sinkhorn_iters: 10
tau: 0.05
dropout: 0.0
identity_mix: true
alpha_init: 0.01
decoder:
channels: 256
up_strides:
- 4
- 4
- 4
- 4
- 5
up_kernels:
- 8
- 8
- 8
- 8
- 10
res_blocks_per_up: 1
res_dilations:
- 1
- 3
- 9
film_hidden: 128
projector:
hidden_dim: 512
output_dim: 48
n_hidden_layers: 2
loss:
stft_weight: 0.1
stft:
fft_sizes:
- 256
- 512
- 1024
- 2048
hop_ratio: 0.25
win_ratio: 1.0
center: true
window: hann
logmag_eps: 0.001
sc_weight: 1.0
mag_weight: 1.0
logmag_weight: 1.0
wav_l1_weight: 0.0
jepa:
weight: 6.0
num_globals: 2
num_locals: 4
context_weight: 1.0
sigreg:
weight: 0.05
num_slices: 1024
t_max: 5.0
n_points: 17
adv:
enabled: true
adv_weight: 1.0
fm_weight: 2.0
adv_start_step: 20000
fm_start_step: 20000
lr: 0.0002
betas:
- 0.8
- 0.99
periods:
- 2
- 3
- 5
- 7
- 11
disc_channels:
- 24
- 48
- 64
- 96
loss_type: lsgan
adaptive: true
adaptive_max: 1.0
optim:
lr: 0.0005
betas:
- 0.9
- 0.999
eps: 1.0e-08
weight_decay: 1.0e-05
scheduler:
warmup_steps: 2000
total_steps: 70000
min_lr_ratio: 0.0
grad_clip: 1.0
train:
batch_size: 32
grad_accum_steps: 4
max_steps: 70000
log_interval_steps: 10
eval_interval_steps: 5000
save_interval_steps: 2000
probe_interval_steps: 1000
val_batches: null
eval:
enabled: true
emotion:
enabled: false
train_manifest: null
dev_manifest: null
label_key: emotion
steps: 2000
batch_size: 64
segment_seconds: null
gender:
enabled: false
train_manifest: null
dev_manifest: null
label_key: gender
steps: 1500
batch_size: 64
segment_seconds: null
asr:
enabled: true
train_manifest: staging/manifests/asr_probe_train.jsonl
dev_manifest: staging/manifests/asr_probe_val.jsonl
text_key: text
steps: 8000
batch_size: 16
segment_seconds: 15.0
max_samples: 500
How to load
import torch

# Load the checkpoint on CPU. weights_only=True restricts unpickling to tensors
# and plain containers/primitives. This is the safe choice for a downloaded
# file, and it is the default from PyTorch 2.6 onward.
# If this raises an UnpicklingError, the checkpoint holds other object types;
# pass weights_only=False only if you trust the file's source.
ckpt = torch.load('last.pt', map_location='cpu', weights_only=True)

# Model parameters: presumably a state dict for the model's load_state_dict().
# The model class comes from the training code, which this card does not include.
state_dict = ckpt['model']
# Training config stored alongside the weights.
cfg = ckpt['cfg']
- Downloads last month: 70
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support