imthanhlv committed on
Commit
b544d5a
1 Parent(s): 3800c65

fixed prompt

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -123,7 +123,8 @@ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
123
  tokens = torch.tensor(tokenizer.encode(prompt))
124
  tokens = tokens.unsqueeze(0).to(device)
125
  prompt_tokens = model.gpt.transformer.wte(tokens)
126
- print(">>>>", generated.shape, prompt_tokens.shape)
 
127
  for i in range(entry_length):
128
  outputs = model.gpt(inputs_embeds=generated)
129
  logits = outputs.logits
 
123
  tokens = torch.tensor(tokenizer.encode(prompt))
124
  tokens = tokens.unsqueeze(0).to(device)
125
  prompt_tokens = model.gpt.transformer.wte(tokens)
126
+ generated = torch.cat((generated, prompt_tokens), dim=1)
127
+
128
  for i in range(entry_length):
129
  outputs = model.gpt(inputs_embeds=generated)
130
  logits = outputs.logits