Fix set mem_id for inference and refactor
scripts/finetune.py
CHANGED
@@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         )
 
     if cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+        set_model_mem_id(model, tokenizer)
         model.set_mem_cache_args(
             max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
         )
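This is the fix named in the commit title: previously do_inference() configured the landmark memory cache without ever telling the model which token id marks a memory slot; only the training path set it. As a rough expansion of what the new set_model_mem_id(model, tokenizer) call does (see the monkeypatch diff below), assuming MEM_TOKEN is the memory token defined in llama_landmark_attn:

# Sketch of what the inference path now does before generation;
# MEM_TOKEN is defined in axolotl.monkeypatch.llama_landmark_attn.
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
model.set_mem_id(mem_id)
model.set_mem_cache_args(
    max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)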
src/axolotl/monkeypatch/llama_landmark_attn.py
CHANGED
@@ -29,6 +29,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from transformers import LlamaTokenizer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn():
     transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
     transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
     transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
+
+
+def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
+    mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+    model.set_mem_id(mem_id)
+
+
+def get_mem_id(tokenizer: LlamaTokenizer):
+    return tokenizer.convert_tokens_to_ids(MEM_TOKEN)
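The two new helpers are thin wrappers: get_mem_id() is just a tokenizer lookup of MEM_TOKEN, and set_model_mem_id() stores that id on the landmark-patched model so dataset preparation, training, and inference all agree on it. A minimal usage sketch, assuming MEM_TOKEN has already been added to the tokenizer as a special token (that registration step and the checkpoint name are assumptions, not part of this diff):

from transformers import LlamaTokenizer

from axolotl.monkeypatch.llama_landmark_attn import (
    MEM_TOKEN,
    get_mem_id,
    set_model_mem_id,
)

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
# Assumption: the memory token is registered during model/tokenizer setup;
# without it, convert_tokens_to_ids() would fall back to the unk id.
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

# get_mem_id() is equivalent to this lookup...
assert get_mem_id(tokenizer) == tokenizer.convert_tokens_to_ids(MEM_TOKEN)

# ...and set_model_mem_id(model, tokenizer) is equivalent to calling
# model.set_mem_id(get_mem_id(tokenizer)) on a landmark-patched model.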
src/axolotl/utils/trainer.py
CHANGED
@@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.is_llama_derived_model and cfg.landmark_attention:
         from functools import partial
 
-        from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
+        from axolotl.monkeypatch.llama_landmark_attn import (
+            add_mem_tokens,
+            get_mem_id,
+            set_model_mem_id,
+        )
 
-        mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
-        model.set_mem_id(mem_id)
+        set_model_mem_id(model, tokenizer)
 
         logging.info("Adding landmark attention tokens to dataset")
 
         for dataset in [train_dataset, eval_dataset]:
             dataset = dataset.map(
-                partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
+                partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
                 batched=False,
                 num_proc=32,
             )
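On the training side, the refactor replaces the inline mem_id computation with the shared helpers, so trainer.py and finetune.py resolve the memory-token id the same way. A rough sketch of the dataset pass; the toy example layout ("input_ids"/"labels") and the token-registration line are assumptions about what add_mem_tokens expects, not something this diff shows:

from functools import partial

from datasets import Dataset
from transformers import LlamaTokenizer

from axolotl.monkeypatch.llama_landmark_attn import (
    MEM_TOKEN,
    add_mem_tokens,
    get_mem_id,
)

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})  # assumed setup step

# Toy pre-tokenized dataset; real axolotl datasets come out of its tokenization step.
toy = Dataset.from_dict({"input_ids": [list(range(120))], "labels": [list(range(120))]})

# Insert a memory token every mem_freq=50 tokens, using the same id the model was given.
toy = toy.map(
    partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
    batched=False,
)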