phoebeklett
committed on
Commit
•
35c3478
1
Parent(s):
4ea5d17
Upload 2 files
Browse files — modeling.py +4 -4
modeling.py
CHANGED
@@ -47,7 +47,7 @@ from transformers.utils import (
|
|
47 |
replace_return_docstrings,
|
48 |
)
|
49 |
|
50 |
-
from .configuration import ExtendedLlamaConfig
|
51 |
|
52 |
logger = logging.get_logger(__name__)
|
53 |
|
@@ -1144,7 +1144,7 @@ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
|
|
1144 |
|
1145 |
_tied_weights_keys = ["lm_head.weight"]
|
1146 |
|
1147 |
-
def __init__(self, config, external_memories=None):
|
1148 |
super().__init__(config)
|
1149 |
self.model = ExtendedLlamaModel(config)
|
1150 |
self.vocab_size = config.vocab_size
|
@@ -1242,9 +1242,9 @@ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
|
|
1242 |
if (
|
1243 |
self.memory_ids is not None and self.memories is None
|
1244 |
):
|
|
|
1245 |
self.memories = self.generate_cache(
|
1246 |
-
|
1247 |
-
cache_type=self.memory_type,
|
1248 |
)
|
1249 |
# EM: Remove special tokens from memory cache
|
1250 |
if self.remove_special_ids:
|
|
|
47 |
replace_return_docstrings,
|
48 |
)
|
49 |
|
50 |
+
from emts_clean.src.llama.configuration import ExtendedLlamaConfig
|
51 |
|
52 |
logger = logging.get_logger(__name__)
|
53 |
|
|
|
1144 |
|
1145 |
_tied_weights_keys = ["lm_head.weight"]
|
1146 |
|
1147 |
+
def __init__(self, config, external_memories:list=None):
|
1148 |
super().__init__(config)
|
1149 |
self.model = ExtendedLlamaModel(config)
|
1150 |
self.vocab_size = config.vocab_size
|
|
|
1242 |
if (
|
1243 |
self.memory_ids is not None and self.memories is None
|
1244 |
):
|
1245 |
+
self.memory_ids = torch.tensor([self.memory_ids], device=self.device) if type(self.memory_ids)==list else self.memory_ids
|
1246 |
self.memories = self.generate_cache(
|
1247 |
+
self.memory_ids, cache_type=self.memory_type,
|
|
|
1248 |
)
|
1249 |
# EM: Remove special tokens from memory cache
|
1250 |
if self.remove_special_ids:
|