name: llama
model:
  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
  cache_dir: "/scr-ssd/mzhang/models/llama3"  # Set this to where you want to save checkpoint weights
  return_dict: true
  load_in_8bit: false
  load_in_4bit: false
  device_map: auto
  low_cpu_mem_usage: true
  torch_dtype: bfloat16
  attn_implementation: flash_attention_2
  rope_theta: 500000.0
attention:
  attention_type: lolcats_llama
  feature_map: softmax_dim
  feature_map_kwargs:
    eps: 1e-12
    # mlp: null  # to set
    fullspace: true
  layer_idx: null  # to set
  learned_kernel: untied_head_einsum
  learned_kernel_kwargs:
    feature_dim: 64
    skip_connection: false
    bias: false
    zero_init: false
  tie_qk_kernels: false
  train_qk: false
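
For reference, here is a minimal sketch of how a config like this could be consumed. It assumes the `model` block maps directly onto Hugging Face `AutoModelForCausalLM.from_pretrained` keyword arguments, and that the `attention` block is read by the LoLCATs attention-swapping code (not shown). The filename `llama.yaml` and the loading code are illustrative assumptions, not part of this config.

```python
# Illustrative sketch (assumption, not this repo's loader): read the config
# above and build the base model from its `model` block with transformers.
import yaml
import torch
from transformers import AutoModelForCausalLM

with open("llama.yaml") as f:  # assumed filename for the config above
    cfg = yaml.safe_load(f)

model_kwargs = dict(cfg["model"])
name_or_path = model_kwargs.pop("pretrained_model_name_or_path")
# from_pretrained expects a torch dtype object rather than the YAML string
model_kwargs["torch_dtype"] = getattr(torch, model_kwargs["torch_dtype"])

model = AutoModelForCausalLM.from_pretrained(name_or_path, **model_kwargs)

# cfg["attention"] (attention_type, feature_map, learned_kernel, ...) would be
# consumed by the linear-attention conversion code, which is not shown here.
```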