Update interpret.py
Browse files- interpret.py +1 -1
interpret.py
CHANGED
@@ -92,7 +92,7 @@ class InterpretationPrompt:
|
|
92 |
|
93 |
def generate(self, model, embeds, k, layer_format='model.layers.{k}', **generation_kwargs):
|
94 |
num_seqs = len(embeds[0]) # assumes the placeholder 0 exists
|
95 |
-
tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)])
|
96 |
module = model.get_submodule(layer_format.format(k=k))
|
97 |
with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
|
98 |
generated = model.generate(tokens_batch, **generation_kwargs)
|
|
|
92 |
|
93 |
def generate(self, model, embeds, k, layer_format='model.layers.{k}', **generation_kwargs):
|
94 |
num_seqs = len(embeds[0]) # assumes the placeholder 0 exists
|
95 |
+
tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
|
96 |
module = model.get_submodule(layer_format.format(k=k))
|
97 |
with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
|
98 |
generated = model.generate(tokens_batch, **generation_kwargs)
|