aryan3212/clae-bengali-encoder

A continuous-latent autoencoder for Bengali speech sampled at 16 kHz.

Architecture

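  • Encoder: Conformer
  • Self-supervised objective: LeJEPA (JEPA prediction loss plus a SIGReg regularizer)
  • Reconstruction loss: multi-resolution STFT
  • Adversarial loss: LSGAN with feature matching, both enabled from step 20,000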

Training

  • Checkpoint step: 67,101 (the config below sets max_steps: 70000)

Config

# Resolved training configuration for this checkpoint.
resolved_config_path: configs/kaggle_3m_gan.yaml

# Run identity, output location, precision and experiment logging.
run:
  run_id: final_aryan
  out_dir: /kaggle/working/runs
  seed: 0
  amp: true
  gpu_mem_fraction: 0.95
  wandb:
    enabled: true
    project: continuous-latent-ae
    name: exp_3m_gan

# Audio input and data-loader settings.
data:
  sample_rate: 16000
  segment_seconds: 3.0  # 3.0 s at 16000 Hz is 48000 samples
  train_manifest: /kaggle/working/manifests/train.jsonl
  val_manifest: /kaggle/working/manifests/val.jsonl
  num_workers: 3
  pin_memory: true
  persistent_workers: true
  prefetch_factor: 6

# Waveform augmentation. Noise, lowpass and gain each have probability 1.0;
# clipping has probability 0.3.
aug:
  wave_aug:
    enabled: true
    noise_prob: 1.0
    noise_snr_min: 0.0
    noise_snr_max: 25.0
    lowpass_prob: 1.0
    lowpass_min_freq: 2700.0
    # NOTE(review): 8000.0 equals the Nyquist frequency for sample_rate
    # 16000; confirm the lowpass implementation accepts a cutoff there.
    lowpass_max_freq: 8000.0
    gain_prob: 1.0
    gain_min: 0.6
    gain_max: 1.5
    clip_prob: 0.3
    clip_min: 0.5
  wave_chunk_mask:
    enabled: true
    target_ratio: 0.25
    min_span_frames: 2
    max_span_frames: 8

# Network architecture.
model:
  frontend:
    channels: [32, 64, 96, 128, 128]
    kernels: [10, 8, 8, 4, 4]
    # Stride product: 5 * 4 * 4 * 4 * 4 = 1280.
    strides: [5, 4, 4, 4, 4]
    groups: 8
  encoder:
    d_model: 128
    n_layers: 5
    num_heads: 4
    feedforward_dim: 320
    dropout: 0.1
    cnn_module_kernel: 31
    mhc:
      enabled: true
      num_streams: 2
      start_layer: 2
      period: 3
      sinkhorn_iters: 10
      tau: 0.05
      dropout: 0.0
      identity_mix: true
      alpha_init: 0.01
  decoder:
    channels: 256
    # The frontend strides/kernels in reverse order; stride product is
    # also 1280.
    up_strides: [4, 4, 4, 4, 5]
    up_kernels: [8, 8, 8, 8, 10]
    res_blocks_per_up: 1
    res_dilations: [1, 3, 9]
    film_hidden: 128
  projector:
    hidden_dim: 512
    output_dim: 48
    n_hidden_layers: 2

# Loss terms and their weights.
loss:
  stft_weight: 0.1
  stft:
    fft_sizes: [256, 512, 1024, 2048]
    hop_ratio: 0.25
    win_ratio: 1.0
    center: true
    window: hann
    logmag_eps: 0.001
    sc_weight: 1.0
    mag_weight: 1.0
    logmag_weight: 1.0
  wav_l1_weight: 0.0
  jepa:
    weight: 6.0
    num_globals: 2
    num_locals: 4
    context_weight: 1.0
  sigreg:
    weight: 0.05
    num_slices: 1024
    t_max: 5.0
    n_points: 17
  # Adversarial and feature-matching losses both start at step 20000.
  adv:
    enabled: true
    adv_weight: 1.0
    fm_weight: 2.0
    adv_start_step: 20000
    fm_start_step: 20000
    lr: 0.0002
    betas: [0.8, 0.99]
    periods: [2, 3, 5, 7, 11]
    disc_channels: [24, 48, 64, 96]
    loss_type: lsgan
    adaptive: true
    adaptive_max: 1.0

# Optimizer and learning-rate schedule.
optim:
  lr: 0.0005
  betas: [0.9, 0.999]
  eps: 1.0e-08
  weight_decay: 1.0e-05
  scheduler:
    warmup_steps: 2000
    total_steps: 70000  # same value as train.max_steps
    min_lr_ratio: 0.0
  grad_clip: 1.0

# Training loop cadence.
train:
  batch_size: 32
  grad_accum_steps: 4  # presumably 32 x 4 = 128 samples per optimizer step
  max_steps: 70000
  log_interval_steps: 10
  eval_interval_steps: 5000
  save_interval_steps: 2000
  probe_interval_steps: 1000
  val_batches: null

# Downstream probes; only the ASR probe is enabled.
eval:
  enabled: true
  emotion:
    enabled: false
    train_manifest: null
    dev_manifest: null
    label_key: emotion
    steps: 2000
    batch_size: 64
    segment_seconds: null
  gender:
    enabled: false
    train_manifest: null
    dev_manifest: null
    label_key: gender
    steps: 1500
    batch_size: 64
    segment_seconds: null
  asr:
    enabled: true
    train_manifest: staging/manifests/asr_probe_train.jsonl
    dev_manifest: staging/manifests/asr_probe_val.jsonl
    text_key: text
    steps: 8000
    batch_size: 16
    segment_seconds: 15.0
    max_samples: 500

How to load the checkpoint

import torch

# The checkpoint is a pickle-based file. weights_only=True restricts
# unpickling to tensors and plain containers/primitives, so a downloaded
# file cannot execute arbitrary code while loading. This is already the
# default from PyTorch 2.6 on; passing it explicitly gives older versions
# the same safe behavior. If loading fails because the checkpoint holds
# other object types, fall back to weights_only=False only for a file
# you trust.
ckpt = torch.load('last.pt', map_location='cpu', weights_only=True)
state_dict = ckpt['model']  # model parameters, loaded onto the CPU
cfg = ckpt['cfg']           # configuration saved with the checkpoint
Downloads last month: 70
Inference Providers (new): This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Spaces using aryan3212/clae-bengali-encoder: 1