petalschatlvn / utils.py
lavanjv's picture
Upload 7 files
2b58075
raw
history blame
461 Bytes
import torch
from transformers import PreTrainedTokenizerBase
def safe_decode(tokenizer: "PreTrainedTokenizerBase", outputs: torch.Tensor) -> str:
    """Decode token ids to text while keeping the leading space of the first token.

    SentencePiece-style ``.decode()`` drops the leading space of the first token in
    a sequence. As a workaround, a throwaway "^" marker token is prepended before
    decoding, and the marker's text is removed from the result afterwards.

    Args:
        tokenizer: Tokenizer used both to encode the marker and to decode the ids.
        outputs: Token ids to decode. Anything with ``.tolist()`` (e.g. a tensor)
            is converted that way; any other iterable of ints is passed through
            ``list()``. Presumably a 1-D sequence of ids -- TODO confirm with callers.

    Returns:
        The decoded text, with any leading space of the first output token kept.
    """
    marker = "^"
    # add_special_tokens=False: a tokenizer that prepends a BOS token would
    # otherwise make input_ids[0] the BOS id instead of the marker, and the
    # decoded text would start with e.g. "<s>" rather than "^".
    fake_token = tokenizer(marker, add_special_tokens=False)["input_ids"][0]
    token_ids = outputs.tolist() if hasattr(outputs, "tolist") else list(outputs)
    result = tokenizer.decode([fake_token] + token_ids)

    # We use .lstrip() since SentencePiece may add leading spaces, e.g. if the outputs are "</s>"
    result = result.lstrip()
    # Remove the marker only if it is really there; slicing blindly would eat a
    # genuine character whenever the marker decodes to something else.
    if result.startswith(marker):
        result = result[len(marker):]
    return result