ethanlshen commited on
Commit
70c86a7
1 Parent(s): cfcc2e7

Update superposed/llama/superposed_generation.py

Browse files
superposed/llama/superposed_generation.py CHANGED
@@ -150,7 +150,7 @@ class SuperposedLlama:
150
  ngrams=ngrams,
151
  get_time=get_time,
152
  penalty=penalty)
153
- unseen_first = torch.ones(bsz)
154
  # Superposition matrix
155
  token_weights = torch.zeros(bsz, self.model.vocab_size)
156
  if verbose:
 
150
  ngrams=ngrams,
151
  get_time=get_time,
152
  penalty=penalty)
153
+ unseen_first = torch.ones(bsz, device=self.device)
154
  # Superposition matrix
155
  token_weights = torch.zeros(bsz, self.model.vocab_size)
156
  if verbose: