imthanhlv committed on
Commit
d7a3fe0
1 Parent(s): b544d5a

fixed prompt tokens

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -123,7 +123,7 @@ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
123
  tokens = torch.tensor(tokenizer.encode(prompt))
124
  tokens = tokens.unsqueeze(0).to(device)
125
  prompt_tokens = model.gpt.transformer.wte(tokens)
126
- generated = torch.cat((generated, prompt_tokens), dim=1)
127
 
128
  for i in range(entry_length):
129
  outputs = model.gpt(inputs_embeds=generated)
 
123
  tokens = torch.tensor(tokenizer.encode(prompt))
124
  tokens = tokens.unsqueeze(0).to(device)
125
  prompt_tokens = model.gpt.transformer.wte(tokens)
126
+ generated = torch.cat((generated, prompt_tokens), dim=1)
127
 
128
  for i in range(entry_length):
129
  outputs = model.gpt(inputs_embeds=generated)