# PIQ-ESC50 / hyperparams.yaml
# cemsubakan's picture
# adding the .ckpt files
# c858225
# raw history blame
# No virus
# 4.31 kB
# Generated 2023-07-14 from:
# /data2/cloned_repos/speechbrain-clone/recipes/ESC50/interpret/hparams/piq.yaml
# yamllint disable
# #################################
# The recipe for training PIQ on the ESC50 dataset.
#
# Author:
# * Cem Subakan 2022, 2023
# * Francesco Paissan 2022, 2023
# (based on the SpeechBrain UrbanSound8k recipe)
# #################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1234
# Applies torch.manual_seed to the same literal value as `seed` when the file
# is loaded.
# NOTE(review): `!!python/object/apply` is a Python-specific tag; a safe YAML
# loader will reject it, so this file needs a loader that allows object
# construction.
__set_seed: !!python/object/apply:torch.manual_seed [1234]
# Set up folders for reading from and writing to
# Dataset must already exist at `audio_data_folder`
data_folder: /data2/ESC-50-master
# Audio directory: the `audio` subfolder of `data_folder`
audio_data_folder: /data2/ESC-50-master/audio
experiment_name: piq
# Output locations; each path embeds the experiment name and the seed
# (piq/1234). `save_folder` and `train_log` live under `output_folder`.
output_folder: ./results/piq/1234
save_folder: ./results/piq/1234/save
train_log: ./results/piq/1234/train_log.txt
# Run-mode switches
# NOTE(review): presumably `interpret_period` is counted in epochs - confirm
# against the training script.
test_only: false
save_interpretations: true
interpret_period: 10
# Tensorboard logs (disabled here; the folder sits under `output_folder`)
use_tensorboard: false
tensorboard_logs_folder: ./results/piq/1234/tb_logs/
# Path where data manifest files will be stored
# (one JSON manifest per split, under the `manifest` subfolder of `data_folder`)
train_annotation: /data2/ESC-50-master/manifest/train.json
valid_annotation: /data2/ESC-50-master/manifest/valid.json
test_annotation: /data2/ESC-50-master/manifest/test.json
# Fold assignment for the train / valid / test split.
# NOTE(review): the previous comment described UrbanSound8k's 10 pre-separated
# folds (this recipe is based on the UrbanSound8k one); only folds 1-5 are
# referenced below.
train_fold_nums: [1, 2, 3]
valid_fold_nums: [4]
test_fold_nums: [5]
# NOTE(review): presumably `false` means the manifests are (re)generated
# before training - confirm against the training script.
skip_manifest_creation: false
ckpt_interval_minutes: 15 # save checkpoint every N min
# Training parameters
# `number_of_epochs`, `batch_size`, `lr` and `shuffle` carry the same values
# that appear in `epoch_counter`, `dataloader_options` and `opt_class`.
number_of_epochs: 200
batch_size: 16
lr: 0.0002
sample_rate: 16000
use_vq: true
rec_loss_coef: 1
use_mask_output: true
mask_th: 0.35
device: cuda
# Feature parameters
n_mels: 80
# Number of classes
out_n_neurons: 50
shuffle: true
# Options for the data loader. The entries must be nested under
# `dataloader_options`; at column 0 they would be parsed as top-level keys and
# collide with the top-level `batch_size` / `shuffle`.
dataloader_options:
  batch_size: 16
  shuffle: true
  num_workers: 0
# Epoch counter object; anchored as `id001` so the checkpointer can reference
# the same instance. `limit` matches `number_of_epochs`.
epoch_counter: &id001 !new:speechbrain.utils.epoch_loop.EpochCounter
  limit: 200
# Optimizer: `!name:` refers to torch.optim.Adam with the keyword arguments
# below. They must be nested, otherwise `lr` is a duplicate of the top-level
# `lr` key.
opt_class: !name:torch.optim.Adam
  lr: 0.0002
  weight_decay: 0.000002
# Learning-rate scheduler; the nested keys are its constructor arguments.
lr_annealing: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
  factor: 0.5
  patience: 3
  dont_halve_until_epoch: 100
# Logging + checkpoints
# File logger writing to the same path as the top-level `train_log`.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
  save_file: ./results/piq/1234/train_log.txt
# Checkpointer writing into the same directory as `save_folder`.
# `recoverables` maps a name to each object that is saved and restored:
#   - psi_model: the PIQ interpreter, constructed here and anchored as `id004`
#     so later keys can alias the same instance;
#   - counter: the epoch counter defined under `epoch_counter` (`id001`).
# The psi_model constructor arguments are nested one level deeper so that its
# `K` does not collide with the top-level `K`.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
  checkpoints_dir: ./results/piq/1234/save
  recoverables:
    psi_model: &id004 !new:speechbrain.lobes.models.PIQ.VectorQuantizedPSI_Audio
      dim: 256
      K: 1024
      shared_keys: 0
      activate_class_partitioning: true
      use_adapter: true
      adapter_reduce_dim: true
    counter: *id001
use_pretrained: true
# Embedding network, anchored as `id002` and referenced elsewhere in this file
# through `!ref <embedding_model>`.
# embedding_model: !new:custom_models.Conv2dEncoder_v2
embedding_model: &id002 !new:speechbrain.lobes.models.PIQ.Conv2dEncoder_v2
  dim: 256
# Classification head, anchored as `id003`. `input_size` equals the `dim` of
# `embedding_model` and `out_neurons` equals `out_n_neurons`.
classifier: &id003 !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
  input_size: 256
  out_neurons: 50
  lin_blocks: 1
# Interpretation hyperparams
# Same value as the `K` argument of `psi_model`.
K: 1024
# pre-processing
# These values are repeated in `compute_stft`, `compute_fbank` and
# `compute_istft`.
n_fft: 1024
spec_mag_power: 0.5
# NOTE(review): fractional values, so presumably durations (ms) rather than
# sample counts - confirm against the STFT / ISTFT classes.
hop_length: 11.6099
win_length: 23.2199
# STFT feature extractor, anchored as `id005` for reuse under `modules`.
# Its arguments must be nested; at column 0 they duplicate the top-level
# `n_fft`, `hop_length`, `win_length` and `sample_rate` keys.
compute_stft: &id005 !new:speechbrain.processing.features.STFT
  n_fft: 1024
  hop_length: 11.6099
  win_length: 23.2199
  sample_rate: 16000
# Filterbank, anchored as `id006` for reuse under `modules`. The values match
# the top-level `n_mels`, `n_fft` and `sample_rate`.
compute_fbank: &id006 !new:speechbrain.processing.features.Filterbank
  n_mels: 80
  n_fft: 1024
  sample_rate: 16000
# Inverse STFT, anchored as `id007` for reuse under `modules`. It uses the same
# hop / window values as `compute_stft`.
compute_istft: &id007 !new:speechbrain.processing.features.ISTFT
  sample_rate: 16000
  hop_length: 11.6099
  win_length: 23.2199
# Label encoder, constructed with no arguments.
label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder
# Top-level alias of the object anchored as `id004` (the psi_model listed under
# the checkpointer's recoverables); this is what `!ref <psi_model>` resolves to.
psi_model: *id004
# Named modules. Each entry aliases or references an object defined earlier in
# this file, so no new instances are created here. The entries must be nested:
# at column 0, `embedding_model` and `classifier` would overwrite the top-level
# objects with references to themselves.
modules:
  compute_stft: *id005
  compute_fbank: *id006
  compute_istft: *id007
  psi: *id004
  embedding_model: !ref <embedding_model>
  classifier: !ref <classifier>
# Checkpoint locations for the embedding model and classifier; the same
# strings appear under `pretrainer` -> `paths`.
embedding_model_path: fpaissan/conv2d_us8k/embedding_modelft.ckpt
classifier_model_path: fpaissan/conv2d_us8k/classifier.ckpt
# Pretrainer: `loadables` names the objects to load parameters into and `paths`
# gives the checkpoint source for each name. Both maps use the same four keys.
# They must be nested under their own parent, otherwise the second set of
# `embedding_model` / `classifier` / `psi` / `label_encoder` keys overwrites
# the first.
# NOTE(review): the `psi` path is an absolute local path (/data2/...), unlike
# the other three entries - confirm it resolves where this file is used.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    embedding_model: !ref <embedding_model>
    classifier: !ref <classifier>
    psi: !ref <psi_model>
    label_encoder: !ref <label_encoder>
  paths:
    embedding_model: fpaissan/conv2d_us8k/embedding_modelft.ckpt
    classifier: fpaissan/conv2d_us8k/classifier.ckpt
    psi: /data2/PIQ-ESC50/psi_model.ckpt
    label_encoder: speechbrain/cnn14-esc50/label_encoder.txt