syberWolf commited on
Commit
6bfeebe
1 Parent(s): 2b0d49f

squeeze em bits

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -19,7 +19,7 @@ def generate_response(text):
19
  mask_token_id = tokenizer.mask_token_id
20
  mask_token_index = torch.where(input_ids == mask_token_id)[1]
21
  token_logits = logits[0, mask_token_index, :]
22
- top_5_tokens = torch.topk(token_logits, k=5).indices # get top 5 tokens
23
  predicted_tokens = tokenizer.convert_ids_to_tokens(top_5_tokens.tolist()) # convert ids to tokens
24
 
25
  # Choose one of the predicted tokens randomly and replace the mask with it
 
19
  mask_token_id = tokenizer.mask_token_id
20
  mask_token_index = torch.where(input_ids == mask_token_id)[1]
21
  token_logits = logits[0, mask_token_index, :]
22
+ top_5_tokens = torch.topk(token_logits.squeeze(), k=5).indices # get top 5 tokens
23
  predicted_tokens = tokenizer.convert_ids_to_tokens(top_5_tokens.tolist()) # convert ids to tokens
24
 
25
  # Choose one of the predicted tokens randomly and replace the mask with it