Fix tokenization: convert torch.Tensor token ids to a Python list before sentencepiece decoding
Browse files- tokenization_gptpangu.py +4 -0
tokenization_gptpangu.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import jieba
import sentencepiece
import torch

from transformers.tokenization_utils import PreTrainedTokenizer
|
5 |
|
@@ -37,6 +38,9 @@ class GPTPanguTokenizer(PreTrainedTokenizer):
|
|
37 |
return self.decode(ids)
|
38 |
|
39 |
def decode(self, tokens, **kwargs):
    """Decode token ids back into a text string.

    Args:
        tokens: token ids as a Python list (or nested list) — or a
            ``torch.Tensor``, which is converted with ``.tolist()`` first,
            since the sentencepiece decoder does not accept tensors.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        The decoded string, with the tokenizer's whitespace encoding
        reversed (presumably '\u2582' stands for space and '\u2583' for
        newline in the vocabulary — inferred from the replacements below).
    """
    # Fix: sentencepiece chokes on tensor input; normalize to a list.
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()

    text = self.sp.decode(tokens)
    # Drop sentencepiece's inserted spaces, then restore the encoded
    # space/newline placeholder characters.
    text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
    return text
|
1 |
from transformers.tokenization_utils import PreTrainedTokenizer
|
2 |
|
3 |
+
import torch
|
4 |
import sentencepiece
|
5 |
import jieba
|
6 |
|
38 |
return self.decode(ids)
|
39 |
|
40 |
def decode(self, tokens, **kwargs):
    """Turn a sequence of token ids into text.

    Tensor inputs are flattened to plain Python lists before being handed
    to the sentencepiece decoder. The decoded string then has sentencepiece
    spaces removed and the placeholder characters '\u2582' / '\u2583'
    mapped back to ' ' / '\n'.
    """
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.tolist()

    # Single-pass character remapping: delete raw spaces, restore the
    # encoded space and newline placeholders.
    remap = str.maketrans({' ': '', '\u2582': ' ', '\u2583': '\n'})
    return self.sp.decode(tokens).translate(remap)
|