phoebeklett commited on
Commit
35c3478
1 Parent(s): 4ea5d17

Upload 2 files

Browse files
Files changed (1) hide show
  1. modeling.py +4 -4
modeling.py CHANGED
@@ -47,7 +47,7 @@ from transformers.utils import (
47
  replace_return_docstrings,
48
  )
49
 
50
- from .configuration import ExtendedLlamaConfig
51
 
52
  logger = logging.get_logger(__name__)
53
 
@@ -1144,7 +1144,7 @@ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
1144
 
1145
  _tied_weights_keys = ["lm_head.weight"]
1146
 
1147
- def __init__(self, config, external_memories=None):
1148
  super().__init__(config)
1149
  self.model = ExtendedLlamaModel(config)
1150
  self.vocab_size = config.vocab_size
@@ -1242,9 +1242,9 @@ class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
1242
  if (
1243
  self.memory_ids is not None and self.memories is None
1244
  ):
 
1245
  self.memories = self.generate_cache(
1246
- torch.tensor(self.memory_ids, device=self.device),
1247
- cache_type=self.memory_type,
1248
  )
1249
  # EM: Remove special tokens from memory cache
1250
  if self.remove_special_ids:
 
47
  replace_return_docstrings,
48
  )
49
 
50
+ from emts_clean.src.llama.configuration import ExtendedLlamaConfig
51
 
52
  logger = logging.get_logger(__name__)
53
 
 
1144
 
1145
  _tied_weights_keys = ["lm_head.weight"]
1146
 
1147
+ def __init__(self, config, external_memories:list=None):
1148
  super().__init__(config)
1149
  self.model = ExtendedLlamaModel(config)
1150
  self.vocab_size = config.vocab_size
 
1242
  if (
1243
  self.memory_ids is not None and self.memories is None
1244
  ):
1245
+ self.memory_ids = torch.tensor([self.memory_ids], device=self.device) if type(self.memory_ids)==list else self.memory_ids
1246
  self.memories = self.generate_cache(
1247
+ self.memory_ids, cache_type=self.memory_type,
 
1248
  )
1249
  # EM: Remove special tokens from memory cache
1250
  if self.remove_special_ids: