from typing import List, Union

import torch
from transformers import PreTrainedTokenizerBase


def safe_decode(
    tokenizer: PreTrainedTokenizerBase,
    outputs: Union[torch.Tensor, List[int]],
) -> str:
    """Decode token ids to text while keeping the leading spaces of the first token.

    A marker token (the first id produced by tokenizing ``"^"``) is placed in
    front of ``outputs`` before calling ``tokenizer.decode``; the marker is
    then removed from the decoded string.

    Args:
        tokenizer: Tokenizer used both to obtain the marker id and to decode.
        outputs: Token ids to decode. Either an object exposing ``.tolist()``
            (e.g. a ``torch.Tensor``) or a plain sequence of ints.
            Assumes a flat (1-D) sequence of ids -- TODO confirm with callers.

    Returns:
        The decoded text with the marker (and any whitespace preceding it)
        removed.
    """
    # Workaround to make SentencePiece .decode() keep leading spaces in a token
    # NOTE(review): this assumes the first id of tokenizer("^") decodes to the
    # single character "^" (i.e. no special token is prepended) -- confirm for
    # the tokenizers in use.
    marker_id = tokenizer("^")["input_ids"][0]

    # Tensors / array-likes are converted via .tolist(); plain sequences of
    # ints are accepted as-is so callers need not wrap them in a tensor.
    token_ids = outputs.tolist() if hasattr(outputs, "tolist") else list(outputs)

    decoded = tokenizer.decode([marker_id] + token_ids)

    # We use .lstrip() since SentencePiece may add leading spaces before the
    # marker; after stripping them, the first character is the marker itself.
    return decoded.lstrip()[1:]