File size: 461 Bytes
b3cc940
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from typing import Sequence, Union

import torch
from transformers import PreTrainedTokenizerBase


def safe_decode(
    tokenizer: PreTrainedTokenizerBase,
    outputs: Union[torch.Tensor, Sequence[int]],
) -> str:
    """Decode token ids to text while keeping the leading spaces of the first token.

    A sentinel "^" token is placed in front of the ids before calling
    ``tokenizer.decode()`` and is cut off from the decoded string afterwards,
    so whatever whitespace the first real token carries survives decoding.

    Args:
        tokenizer: Tokenizer used both to obtain the sentinel id and to decode.
        outputs: Token ids to decode, either as a tensor (anything exposing
            ``.tolist()``) or as a plain sequence of ints.

    Returns:
        The decoded text with the sentinel removed.
    """
    # Workaround to make SentencePiece .decode() keep leading spaces in a token.
    # add_special_tokens=False makes sure the first id is the "^" piece itself and
    # not a BOS/CLS token that the tokenizer would otherwise prepend.
    fake_token = tokenizer("^", add_special_tokens=False)["input_ids"][0]

    # Tensors/arrays are converted with .tolist(); any other sequence is copied to a list
    # so it can be concatenated with the sentinel.
    token_ids = outputs.tolist() if hasattr(outputs, "tolist") else list(outputs)
    result = tokenizer.decode([fake_token] + token_ids)

    # We use .lstrip() since SentencePiece may add leading spaces, e.g. if the outputs are "</s>"
    return result.lstrip()[1:]