dar-tau committed
Commit e651bb1 (1 parent: 20c0832)

Update interpret.py

Files changed (1): interpret.py (+1 -1)
interpret.py CHANGED
@@ -92,7 +92,7 @@ class InterpretationPrompt:
 
     def generate(self, model, embeds, k, layer_format='model.layers.{k}', **generation_kwargs):
         num_seqs = len(embeds[0])  # assumes the placeholder 0 exists
-        tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)])
+        tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
         module = model.get_submodule(layer_format.format(k=k))
         with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
             generated = model.generate(tokens_batch, **generation_kwargs)
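
For context, the change moves the freshly built token batch onto the model's device before model.generate() is called; a tensor created with torch.tensor(...) defaults to CPU and would trigger a device-mismatch error when the model sits on a GPU. Below is a minimal sketch of that failure mode and the fix, assuming a standard Hugging Face causal LM; the model name, prompt, and batch size are illustrative and not taken from this repository.

    # Sketch of the device-placement issue this commit addresses (assumptions:
    # "gpt2" as a stand-in model, a toy prompt, batch size 4).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    tokens = tokenizer("Hello", return_tensors="pt").input_ids  # lives on CPU

    # Before the fix: the batch stays on CPU, so generate() fails with a
    # device-mismatch error whenever the model is on GPU.
    # tokens_batch = tokens.repeat(4, 1)

    # After the fix: move the batch to the model's device first.
    tokens_batch = tokens.repeat(4, 1).to(model.device)
    generated = model.generate(tokens_batch, max_new_tokens=8)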