dpo:
  data:
    splits: "random" # {random, novel_premises}
    train_size: 0.8 # fraction of the data used for training
    include_next_state: false
    # paths are relative to the project root
    raw_data: "data/time_filtered_v3.json"
    formatted_dataset_dir: "data/straight_shot_proof_sample/filtered_negative_tactics_dataset.json"
    expand_records: false
    # processed_data: "data/straight_shot_proof_sample/dpo_expanded_dataset"
    # processed_data: "data/straight_shot_proof_sample/dpo_single_entry_dataset"
    processed_data: "data/straight_shot_proof_sample/dpo_flattened_dataset"
    # prompt formatting {llemma, deepseek}
    # model_prompt_template: "deepseek"
    model_prompt_template: "llemma"
    use_sts_format: false

  model:
    # TODO: figure this out
    base_model_id: "deepseek-ai/DeepSeek-Prover-V1"
    max_seq_length: 1024 # maximum sequence length in tokens
    packing: true # pack examples together for better efficiency

  training_args:
    # output_dir: "dspv1_dpo_dspfmt_medium"
    output_dir: "dspv1_dpo_llemmafmt_medium" # directory to save to and repository id (relative to project root)
    num_train_epochs: 3 # number of training epochs
    per_device_train_batch_size: 3 # batch size per device during training
    gradient_accumulation_steps: 2 # number of steps before performing a backward/update pass
    gradient_checkpointing: true # use gradient checkpointing to save memory
    optim: "adamw_torch_fused" # use fused adamw optimizer
    logging_steps: 10 # log every 10 steps
    save_strategy: "epoch" # save checkpoint every epoch
    learning_rate: 0.0002 # learning rate, based on the QLoRA paper
    bf16: true # use bfloat16 precision
    tf32: true # use tf32 precision
    max_grad_norm: 0.3 # max gradient norm, based on the QLoRA paper
    warmup_ratio: 0.03 # warmup ratio, based on the QLoRA paper
    lr_scheduler_type: "constant" # use constant learning rate scheduler
    push_to_hub: true # push model to hub
    report_to: "tensorboard" # report metrics to tensorboard
    beta: 0.01 # TODO: tune this (beta for the DPO loss)

  bnb:
    _target_: transformers.BitsAndBytesConfig
    load_in_4bit: true # quantize the base model to 4-bit
    bnb_4bit_use_double_quant: true
    bnb_4bit_quant_type: nf4
    bnb_4bit_compute_dtype: bfloat16

  lora:
    _target_: peft.LoraConfig
    r: 16 # LoRA rank
    lora_alpha: 32 # LoRA scaling factor
    lora_dropout: 0.05
    bias: "none"
    target_modules: "all-linear" # apply LoRA to every linear layer
    task_type: "CAUSAL_LM"
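
# Note: with per_device_train_batch_size=3 and gradient_accumulation_steps=2,
# each optimizer step sees an effective batch of 3 * 2 = 6 examples per device
# (times the number of devices under data parallelism).
#
# The `_target_` keys under `bnb` and `lora` follow Hydra's instantiate
# convention. Below is a minimal sketch of how this config might be consumed,
# kept in comments so the file stays valid YAML; the config path, config name,
# and surrounding script are assumptions, not part of this repo:
#
#   import hydra
#   from hydra.utils import instantiate
#   from transformers import AutoModelForCausalLM
#
#   @hydra.main(config_path="conf", config_name="dpo", version_base=None)
#   def main(cfg):
#       # Build transformers.BitsAndBytesConfig / peft.LoraConfig from _target_
#       bnb_config = instantiate(cfg.dpo.bnb)
#       lora_config = instantiate(cfg.dpo.lora)
#       # Load the quantized base model named in the config
#       model = AutoModelForCausalLM.from_pretrained(
#           cfg.dpo.model.base_model_id,
#           quantization_config=bnb_config,
#       )
#       # ... DPO training with lora_config and cfg.dpo.training_args follows
#
#   if __name__ == "__main__":
#       main()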