# lolcats/configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml
# Provenance: uploaded by ariG23498 (HF staff), commit ae81e0f
# ("chore: adding lolcats configs scrc and src"), 893 bytes.
# Config for distilling meta-llama/Meta-Llama-3.1-8B into a LoLCATs
# linearized-attention model (softmax_dim feature map, feature_dim = 64).
name: llama

# Keyword arguments forwarded to the pretrained-model loader.
model:
  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
  cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b"  # Set this to where you want to save checkpoint weights
  return_dict: true
  load_in_8bit: false
  load_in_4bit: false
  device_map: auto
  low_cpu_mem_usage: true
  torch_dtype: bfloat16
  # NOTE(review): presumably eager (non-fused) attention is required so the
  # teacher's attention weights are exposed during distillation — confirm.
  attn_implementation: eager
  rope_theta: 500000.0
  # Llama-3.1-style RoPE frequency scaling for long-context extrapolation.
  rope_scaling:
    factor: 8.0
    low_freq_factor: 1.0
    high_freq_factor: 4.0
    original_max_position_embeddings: 8192
    rope_type: llama3

# Linearized (lolcats) attention substituted for softmax attention.
attention:
  attention_type: lolcats_llama
  feature_map: softmax_dim
  feature_map_kwargs:
    # NOTE(review): PyYAML resolves exponent-only literals such as 1e-12 as
    # *strings* (no decimal point); confirm the consumer casts this to float.
    eps: 1e-12
    # mlp: null # to set
    fullspace: true
  layer_idx: null # to set
  learned_kernel: untied_head_einsum
  learned_kernel_kwargs:
    feature_dim: 64
    skip_connection: false
    bias: false
    zero_init: false
  tie_qk_kernels: false
  train_qk: false