Spaces:
Runtime error
Runtime error
squeeze em bits
Browse files
app.py
CHANGED
@@ -19,7 +19,7 @@ def generate_response(text):
|
|
19 |
mask_token_id = tokenizer.mask_token_id
|
20 |
mask_token_index = torch.where(input_ids == mask_token_id)[1]
|
21 |
token_logits = logits[0, mask_token_index, :]
|
22 |
-
top_5_tokens = torch.topk(token_logits, k=5).indices # get top 5 tokens
|
23 |
predicted_tokens = tokenizer.convert_ids_to_tokens(top_5_tokens.tolist()) # convert ids to tokens
|
24 |
|
25 |
# Choose one of the predicted tokens randomly and replace the mask with it
|
|
|
19 |
mask_token_id = tokenizer.mask_token_id
|
20 |
mask_token_index = torch.where(input_ids == mask_token_id)[1]
|
21 |
token_logits = logits[0, mask_token_index, :]
|
22 |
+
top_5_tokens = torch.topk(token_logits.squeeze(), k=5).indices # get top 5 tokens
|
23 |
predicted_tokens = tokenizer.convert_ids_to_tokens(top_5_tokens.tolist()) # convert ids to tokens
|
24 |
|
25 |
# Choose one of the predicted tokens randomly and replace the mask with it
|