import json
import logging
import re
from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from transformers import SpeechT5Tokenizer
from transformers.models.speecht5.tokenization_speecht5 import (
    PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES,
)

logger = logging.getLogger(__name__)

# Characters passed through untouched instead of being phonemized: ASCII
# punctuation plus the corresponding full-width Japanese forms.
NP_CHARACTERS = " !\"#$%&'()=~|`{+*}<>?_-^\\@[;:],./ !”#$%&’()=~|`{+*}<>?_ー^¥@「;:」、。・`"
def _g2p_with_np(text: str, np_list: str) -> List[str]:
    """Phonemize ``text`` with pyopenjtalk, passing characters in ``np_list`` through as-is."""
    from pyopenjtalk import g2p

    # Split on the non-phoneme characters while keeping them as separate segments.
    np_pattern = re.compile(f"([{re.escape(np_list)}])")
    return list(
        chain.from_iterable(
            (segment,) if segment in np_list else g2p(segment, kana=False, join=False)
            for segment in np_pattern.split(text)
            if len(segment) > 0
        )
    )
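
# A minimal illustration (a sketch; the exact phoneme labels depend on the
# installed pyopenjtalk version). "。" is in NP_CHARACTERS, so it survives as-is:
#   _g2p_with_np("こんにちは。", NP_CHARACTERS)
#   -> ["k", "o", "N", "n", "i", "ch", "i", "w", "a", "。"]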
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "esnya/japanese_speecht5_tts": "https://huggingface.co/esnya/japanese_speecht5_tts/resolve/main/vocab.json",
    },
}
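
# vocab.json is a plain JSON mapping from phoneme/special-token labels to
# integer ids, for example (an illustration of the shape, not the actual file):
#   {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "a": 4, "i": 5, ...}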
class SpeechT5OpenjtalkTokenizer(SpeechT5Tokenizer):
    """SpeechT5 tokenizer that phonemizes Japanese text with pyopenjtalk instead of SentencePiece."""

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        unk_token: str = "<unk>",
        pad_token: str = "<pad>",
        non_phenome_characters: str = NP_CHARACTERS,
        **kwargs,
    ):
        # The parent class expects a SentencePiece model file; passing
        # vocab_file=None makes that load fail with a TypeError, which is
        # deliberately swallowed because this tokenizer uses a JSON vocab instead.
        try:
            super().__init__(
                vocab_file=None,
                bos_token=bos_token,
                eos_token=eos_token,
                unk_token=unk_token,
                pad_token=pad_token,
                **kwargs,
            )
        except TypeError:
            pass

        self.non_phenome_characters = non_phenome_characters
        self.vocab_file = vocab_file
        self._load_vocab()
    def _load_vocab(self):
        # Load the label -> id mapping and build the inverse id -> label mapping.
        if isinstance(self.vocab_file, str) and self.vocab_file.endswith(".json"):
            with open(self.vocab_file, encoding="utf-8") as f:
                self.label2id = json.load(f)
            self.id2label = {v: k for k, v in self.label2id.items()}
    @property
    def bos_token_id(self) -> Optional[int]:
        return super().bos_token_id

    @property
    def vocab_size(self):
        return len(self.label2id)

    def get_vocab(self) -> Dict[str, int]:
        return self.label2id
    def __getstate__(self):
        # Drop the (unpicklable) SentencePiece model if the parent managed to set one.
        state = super().__getstate__()
        state.pop("sp_model", None)
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self._load_vocab()
    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        save_path = Path(save_directory)
        if not save_path.is_dir():
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Follow the transformers convention: an optional prefix is prepended
        # to the vocabulary file name, e.g. "prefix-vocab.json".
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_path = save_path / f"{prefix}vocab.json"
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(self.label2id, f, ensure_ascii=False, indent=2)
        return (str(vocab_path),)
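    # For example (a sketch): tokenizer.save_vocabulary("out") writes
    # out/vocab.json and returns ("out/vocab.json",), while passing
    # filename_prefix="ja" writes out/ja-vocab.json instead.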
    def _tokenize(self, text: str) -> List[str]:
        # Tokens are pyopenjtalk phonemes plus passed-through non-phoneme characters.
        return _g2p_with_np(text, self.non_phenome_characters)

    def _convert_token_to_id(self, token):
        return self.label2id.get(token, self.label2id.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.id2label.get(index, self.unk_token)
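

# A minimal smoke test, a sketch assuming a local vocab.json and an installed
# pyopenjtalk; not part of the tokenizer implementation itself.
if __name__ == "__main__":
    tokenizer = SpeechT5OpenjtalkTokenizer("vocab.json")
    ids = tokenizer("こんにちは、世界。")["input_ids"]
    print(tokenizer.convert_ids_to_tokens(ids))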