Nanobit committed
Commit b521206
Parent: 289d5c4

Feat: Add rope scaling (#343)


* Feat: Add rope scaling

* fix: move rope config

Files changed (2):
  1. README.md +4 -0
  2. src/axolotl/utils/models.py +3 -1
README.md CHANGED
@@ -474,6 +474,10 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # llama only
 xpos_rope:
+# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+rope_scaling:
+  type: # linear | dynamic
+  factor: # float
 
 # resume from a specific checkpoint dir
 resume_from_checkpoint:
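For illustration only (these values are not part of the commit), a filled-in entry in the axolotl YAML would look like:

rope_scaling:
  type: linear
  factor: 2.0

In the transformers implementation this commit hooks into, linear divides the position indices by factor, while dynamic (NTK-aware) rescales the rotary base as sequences grow past the pretrained context length.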
src/axolotl/utils/models.py CHANGED
@@ -219,7 +219,9 @@ def load_model(
     elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
         from transformers import LlamaForCausalLM
 
-        config = LlamaConfig.from_pretrained(base_model_config)
+        config = LlamaConfig.from_pretrained(
+            base_model_config, rope_scaling=cfg.rope_scaling
+        )
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=config,
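For reference, the patched load path can be reproduced with plain transformers as in the sketch below. It assumes transformers >= 4.31 (which added rope_scaling in PR #24653); the model id and the scaling values are placeholders for illustration, not taken from this commit.

from transformers import LlamaConfig, LlamaForCausalLM

# Placeholder repo id; axolotl reads this from base_model_config in the YAML.
base_model_config = "NousResearch/Llama-2-7b-hf"

# cfg.rope_scaling arrives from the YAML as a dict like the one below;
# from_pretrained applies extra keyword arguments as overrides on the loaded config.
config = LlamaConfig.from_pretrained(
    base_model_config,
    rope_scaling={"type": "linear", "factor": 2.0},  # illustrative values
)

model = LlamaForCausalLM.from_pretrained(base_model_config, config=config)

When rope_scaling is left out of the axolotl YAML, cfg.rope_scaling resolves to None and the override is a no-op, so the base model keeps its default (unscaled) RoPE.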