# Source: lolcats/configs/model/distill_llama3_8b_lk_smd_fd64.yaml
# Uploaded by ariG23498 (HF staff) in commit ae81e0f:
#   "chore: adding lolcats configs scrc and src"
# Original file size: 748 bytes.
# Model config: Llama 3 8B with `lolcats_llama` attention, a `softmax_dim`
# feature map, and the `untied_head_einsum` learned kernel (feature_dim 64).
name: llama

# Options for loading the pretrained base model.
model:
  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
  cache_dir: "/scr-ssd/mzhang/models/llama3"  # Set this to where you want to save checkpoint weights
  return_dict: true
  load_in_8bit: false
  load_in_4bit: false
  device_map: auto
  low_cpu_mem_usage: true
  torch_dtype: bfloat16
  attn_implementation: flash_attention_2
  rope_theta: 500000.0

# Settings for the `lolcats_llama` attention layers.
attention:
  attention_type: lolcats_llama
  feature_map: softmax_dim
  feature_map_kwargs:
    # Written as 1.0e-12 rather than 1e-12: YAML 1.1 resolvers (e.g. PyYAML's
    # default loader) only treat exponent notation as a float when it has a
    # '.' and a signed exponent, and would otherwise load the string "1e-12".
    eps: 1.0e-12
    # mlp: null  # to set
    fullspace: true
  layer_idx: null  # to set
  learned_kernel: untied_head_einsum
  learned_kernel_kwargs:
    feature_dim: 64
    skip_connection: false
    bias: false
    zero_init: false
  tie_qk_kernels: false
  train_qk: false