---
# Hydra configuration for training the ReLiK retriever.
# Required to make the "experiments" dir the default one for the output of the models.
# The run directory is built from the model name plus the launch date and time.
hydra:
  run:
    dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
# Names reported to wandb; model_name is also interpolated into hydra.run.dir.
model_name: ${model.language_model}  # used to name the model in wandb
project_name: relik-retriever  # used to name the project in wandb
# Hydra defaults list: one config-group selection per entry.
# `_self_` is listed first, so the groups below are composed after this file.
defaults:
  - _self_
  - model: golden_retriever
  - index: inmemory
  - loss: nce_loss
  - optimizer: radamw
  - scheduler: linear_scheduler
  - data: dataset_v2  # iterable_in_batch_negatives #dataset_v2
  - logging: wandb_logging
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog
# Training settings.
# NOTE(review): the indentation of this section was lost; the nesting below is
# reconstructed from the `${train.top_k}` references and the `_target_`
# groupings. Confirm the placement of early_stopping_callback,
# model_checkpoint_callback and callbacks under `train` against the training
# script that reads this config.
train:
  # reproducibility
  seed: 42
  set_determinism_the_old_way: False
  # torch parameters
  float32_matmul_precision: "medium"
  # if true, only test the model
  only_test: False
  # if provided, initialize the model with the weights from the checkpoint
  pretrain_ckpt_path: null
  # if provided, start training from the checkpoint
  checkpoint_path: null

  # task specific parameter (reused below as ${train.top_k})
  top_k: 100

  # pl_trainer: keyword arguments for the class named in `_target_`
  pl_trainer:
    _target_: lightning.Trainer
    accelerator: gpu
    devices: 1
    num_nodes: 1
    strategy: auto
    accumulate_grad_batches: 1
    gradient_clip_val: 1.0
    val_check_interval: 1.0  # you can specify an int "n" here => validation every "n" steps
    check_val_every_n_epoch: 1
    # NOTE(review): max_epochs is 0 while max_steps is set; confirm how the
    # training script interprets max_epochs: 0 before relying on it.
    max_epochs: 0
    # underscore digit separator: read as the integer 25000 by YAML 1.1 loaders
    max_steps: 25_000
    deterministic: True
    fast_dev_run: False
    precision: 16
    reload_dataloaders_every_n_epochs: 1

  # Early stopping on the validation recall at top_k.
  early_stopping_callback:
    # alternative value: null (presumably disables early stopping; confirm)
    _target_: lightning.callbacks.EarlyStopping
    monitor: validate_recall@${train.top_k}
    mode: max
    patience: 3

  # Checkpointing on the same metric monitored by early stopping.
  model_checkpoint_callback:
    _target_: lightning.callbacks.ModelCheckpoint
    monitor: validate_recall@${train.top_k}
    mode: max
    verbose: True
    save_top_k: 1
    save_last: False
    filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}"
    auto_insert_metric_name: False

  callbacks:
    # Prediction callback with its nested evaluation callbacks.
    prediction_callback:
      _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback
      k: ${train.top_k}
      batch_size: 64
      precision: 16
      index_precision: 16
      other_callbacks:
        - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
          k: ${train.top_k}
          verbose: True
        - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
          k: 50
          verbose: True
          prog_bar: False
        - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
          k: ${train.top_k}
          verbose: True
        - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback

    # Negative-augmentation callback and its parameters.
    hard_negatives_callback:
      _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback
      k: ${train.top_k}
      batch_size: 64
      precision: 16
      index_precision: 16
      stages: [validate]  # [validate, sanity_check]
      # kept as a single plain scalar, exactly as in the original
      metrics_to_monitor:
        validate_recall@${train.top_k}
        # - sanity_check_recall@${train.top_k}
      threshold: 0.0
      max_negatives: 20
      add_with_probability: 1.0
      refresh_every_n_epochs: 1
      other_callbacks:
        - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
          k: ${train.top_k}
          verbose: True
          prefix: "train"

    # Utility callbacks; the ResetModelCallback entry is intentionally left commented out.
    utils_callbacks:
      - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback
      - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback
      # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback
      #   question_encoder: ${model.pl_module.model.question_encoder}
      #   passage_encoder: ${model.pl_module.model.passage_encoder}