# Config for p-tuning / prompt-tuning a Megatron GPT model with virtual prompts.
# Top-level run name; interpolated below as ${name} (exp_manager.name, model.nemo_path).
name: megatron_virtual_prompt_gpt
# PyTorch Lightning Trainer arguments.
trainer:
  devices: 1
  accelerator: gpu
  num_nodes: 1
  precision: 16
  logger: False  # disabled here; logging is handled by exp_manager below
  enable_checkpointing: False  # disabled here; checkpointing is handled by exp_manager below
  replace_sampler_ddp: False
  max_epochs: 3
  max_steps: -1  # -1 = no step cap; training length is bounded by max_epochs
  log_every_n_steps: 10
  val_check_interval: 1.0  # float = fraction of an epoch between validation runs
  gradient_clip_val: 1.0
  resume_from_checkpoint: null  # path to a trainer checkpoint to resume from, if any
  benchmark: False
# NeMo experiment manager: logging, checkpointing, and early stopping.
exp_manager:
  explicit_log_dir: null
  exp_dir: null  # null = use the default experiment directory
  name: ${name}
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True  # don't fail resume when no checkpoint is found yet
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_loss
    save_top_k: 2
    mode: min
    save_nemo_on_train_end: False  # checkpoints are saved during training instead (see save_best_model)
    filename: 'megatron_gpt_prompt_tune--{val_loss:.3f}-{step}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    save_best_model: True
  create_early_stopping_callback: True
  early_stopping_callback_params:
    monitor: "val_loss"
    mode: "min"
    min_delta: 0.001
    patience: 10
    verbose: True
# Model, prompt-learning, data, and optimizer configuration.
model:
  seed: 1234
  nemo_path: ${name}.nemo  # .nemo file produced for this run
  virtual_prompt_style: 'p-tuning'
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  global_batch_size: 8
  micro_batch_size: 4
  validation_global_batch_size: ${model.global_batch_size}
  validation_micro_batch_size: ${model.micro_batch_size}
  validation_drop_last: False

  restore_path: null  # path to an existing prompt-learning .nemo to restore, if any
  language_model_path: ???  # REQUIRED: path to the pretrained base GPT .nemo model
  save_nemo_on_validation_end: True
  existing_tasks: ['boolq', 'intent_and_slot']  # tasks the restored model was already trained on
  new_tasks: ['rte']  # tasks to learn virtual prompts for in this run

  sequence_parallel: False

  # Activation checkpointing (all null = disabled).
  activations_checkpoint_granularity: null
  activations_checkpoint_method: null
  activations_checkpoint_num_layers: null

  # Per-task prompt formats. {field} placeholders are filled from the dataset;
  # <|VIRTUAL_PROMPT_n|> markers are replaced by virtual-token spans whose sizes
  # come from virtual_token_splits (splits must sum to total_virtual_tokens).
  task_templates:
    - taskname: "boolq"
      prompt_template: "<|VIRTUAL_PROMPT_0|> Passage: {passage} <|VIRTUAL_PROMPT_1|> \nQuestion: {question} \nAnswer: {answer}"
      total_virtual_tokens: 30
      virtual_token_splits: [20, 10]
      truncate_field: "passage"  # field trimmed first when the sequence is too long
      answer_only_loss: True
      answer_field: "answer"

    - taskname: "intent_and_slot"
      prompt_template: "<|VIRTUAL_PROMPT_0|> intent options: {intent_options} <|VIRTUAL_PROMPT_1|> slot options: {slot_options} <|VIRTUAL_PROMPT_2|> {utterance} \nintent: {intent} \nslot: {slot}"
      total_virtual_tokens: 30
      answer_only_loss: False
      virtual_token_splits: [15, 10, 5]
      truncate_field: null

    - taskname: "rte"
      prompt_template: "<|VIRTUAL_PROMPT_0|>{premise}\n{hypothesis}\nAnswer: {answer}"
      total_virtual_tokens: 9
      virtual_token_splits: [9]
      truncate_field: null
      answer_only_loss: True
      answer_field: "answer"

    - taskname: "squad"
      prompt_template: "<|VIRTUAL_PROMPT_0|> context: {context} question: {question} answer: {answer}"
      total_virtual_tokens: 10
      virtual_token_splits: [10]
      truncate_field: null
      answer_only_loss: True
      answer_field: "answer"

    - taskname: "taskname"
      prompt_template: "<|VIRTUAL_PROMPT_0|> {prompt} {completion}"
      total_virtual_tokens: 100
      virtual_token_splits: [100]
      truncate_field: null
      answer_only_loss: True
      answer_field: "completion"

  # Settings used only when virtual_prompt_style is 'prompt-tuning'.
  prompt_tuning:
    new_prompt_init_methods: ['text']
    new_prompt_init_text: ['some init text goes here']

  # Settings used only when virtual_prompt_style is 'p-tuning'.
  p_tuning:
    encoder_type: "tpmlp"
    dropout: 0.0
    num_layers: 2
    encoder_hidden: 2048
    init_std: 0.023

  data:
    train_ds: [data/rte_train.jsonl]
    validation_ds: [data/rte_val.jsonl]
    add_eos: True
    shuffle: True
    num_workers: 8
    pin_memory: True
    train_cache_data_path: null
    validation_cache_data_path: null
    test_cache_data_path: null
    load_cache: False
    max_seq_length: 1024
    min_seq_length: 1

  optim:
    name: fused_adam
    lr: 1e-4
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 50
      min_lr: 0.0
      constant_steps: 0
      monitor: val_loss
      reduce_on_plateau: False  # casing normalized to match the True/False style used file-wide