Балаганский Никита Николаевич commited on
Commit
e2f0b3b
1 Parent(s): 116ed83

add logging

Browse files
Files changed (1) hide show
  1. generator.py +2 -2
generator.py CHANGED
@@ -199,12 +199,12 @@ class Generator:
199
  return input_ids, past, ended_sequences
200
 
201
  def get_input_ids(self, input_prompt, num_samples):
202
- input_ids = torch.tensor([[self.lm.config.bos_token_id]])
203
  if input_prompt is not None:
204
  input_prompt = self.tokenizer(
205
  input_prompt, return_tensors="pt"
206
  ).input_ids
207
- input_ids = torch.cat([input_ids, input_prompt], 1)
208
  input_ids = input_ids.repeat(num_samples, 1).to(self.device)
209
  past = None
210
  ended_sequences = torch.zeros(
 
199
  return input_ids, past, ended_sequences
200
 
201
  def get_input_ids(self, input_prompt, num_samples):
202
+ #input_ids = torch.tensor([[self.lm.config.bos_token_id]])
203
  if input_prompt is not None:
204
  input_prompt = self.tokenizer(
205
  input_prompt, return_tensors="pt"
206
  ).input_ids
207
+ input_ids = input_prompt
208
  input_ids = input_ids.repeat(num_samples, 1).to(self.device)
209
  past = None
210
  ended_sequences = torch.zeros(