Feat: Add rope scaling (#343)
Browse files* Feat: Add rope scaling
* fix: move rope config
- README.md +4 -0
- src/axolotl/utils/models.py +3 -1
README.md
CHANGED
@@ -474,6 +474,10 @@ landmark_attention:
|
|
474 |
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
475 |
# llama only
|
476 |
xpos_rope:
|
|
|
|
|
|
|
|
|
477 |
|
478 |
# resume from a specific checkpoint dir
|
479 |
resume_from_checkpoint:
|
|
|
474 |
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
475 |
# llama only
|
476 |
xpos_rope:
|
477 |
+
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
478 |
+
rope_scaling:
|
479 |
+
type: # linear | dynamic
|
480 |
+
factor: # float
|
481 |
|
482 |
# resume from a specific checkpoint dir
|
483 |
resume_from_checkpoint:
|
src/axolotl/utils/models.py
CHANGED
@@ -219,7 +219,9 @@ def load_model(
|
|
219 |
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
220 |
from transformers import LlamaForCausalLM
|
221 |
|
222 |
-
config = LlamaConfig.from_pretrained(
|
|
|
|
|
223 |
model = LlamaForCausalLM.from_pretrained(
|
224 |
base_model,
|
225 |
config=config,
|
|
|
219 |
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
220 |
from transformers import LlamaForCausalLM
|
221 |
|
222 |
+
config = LlamaConfig.from_pretrained(
|
223 |
+
base_model_config, rope_scaling=cfg.rope_scaling
|
224 |
+
)
|
225 |
model = LlamaForCausalLM.from_pretrained(
|
226 |
base_model,
|
227 |
config=config,
|