Spaces:
Running
Running
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
Credits | |
This code is modified from https://github.com/GitYCC/g2pW | |
""" | |
from typing import Dict | |
from typing import List | |
from typing import Tuple | |
import numpy as np | |
from .utils import tokenize_and_map | |
ANCHOR_CHAR = "β" | |
def prepare_onnx_input(
    tokenizer,
    labels: List[str],
    char2phonemes: Dict[str, List[int]],
    chars: List[str],
    texts: List[str],
    query_ids: List[int],
    use_mask: bool = False,
    window_size: int = None,
    max_len: int = 512,
) -> Dict[str, np.ndarray]:
    """Build batched ONNX model inputs for a set of polyphonic-character queries.

    ``texts[i]`` is paired with ``query_ids[i]``, the character index of the
    character whose pronunciation is to be predicted in that text.

    Args:
        tokenizer: Tokenizer handed to ``tokenize_and_map``; it must also
            provide ``convert_tokens_to_ids``.
        labels: All phoneme labels; their count fixes the width of each
            ``phoneme_masks`` row.
        char2phonemes: Maps a character to the label indices it may take.
            Only consulted when ``use_mask`` is true.
        chars: Known characters; the position of the query character in this
            list becomes its ``char_id``.
        texts: Input sentences. Each is lower-cased before tokenization.
        query_ids: Character index of the query character in each text.
        use_mask: If true, each ``phoneme_masks`` row enables only the labels
            listed in ``char2phonemes`` for the query character; otherwise all
            labels are enabled.
        window_size: If truthy, each text is first cut to a character window
            around its query character (see ``_truncate_texts``).
        max_len: Maximum token sequence length including ``[CLS]``/``[SEP]``.

    Returns:
        Dict with ``input_ids``, ``token_type_ids``, ``attention_masks``,
        ``char_ids`` and ``position_ids`` as int64 arrays and
        ``phoneme_masks`` as a float32 array. Token-level rows that are
        shorter than the longest row are right-padded with ``[PAD]``
        (attention mask 0). An empty dict is returned if any text fails to
        tokenize.
    """
    # One truthiness test decides both whether to window and which texts to
    # read, so None and 0 both mean "no windowing".
    if window_size:
        texts, query_ids = _truncate_texts(
            window_size=window_size, texts=texts, query_ids=query_ids
        )

    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []
    for idx, raw_text in enumerate(texts):
        text = raw_text.lower()
        query_id = query_ids[idx]
        try:
            tokens, text2token, token2text = tokenize_and_map(
                tokenizer=tokenizer, text=text
            )
        except Exception:
            # Best effort: report the offending text and abandon the batch.
            print(f'warning: text "{text}" is invalid')
            return {}

        text, query_id, tokens, text2token, token2text = _truncate(
            max_len=max_len,
            text=text,
            query_id=query_id,
            tokens=tokens,
            text2token=text2token,
            token2text=token2text,
        )

        processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
        seq_len = len(processed_tokens)
        query_char = text[query_id]
        if use_mask:
            allowed = set(char2phonemes[query_char])
            phoneme_mask = [1 if i in allowed else 0 for i in range(len(labels))]
        else:
            phoneme_mask = [1] * len(labels)

        input_ids.append(list(tokenizer.convert_tokens_to_ids(processed_tokens)))
        token_type_ids.append([0] * seq_len)
        attention_masks.append([1] * seq_len)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(chars.index(query_char))
        # +1 because the [CLS] token occupies position 0.
        position_ids.append(text2token[query_id] + 1)

    # Windowing/truncation can yield rows of different lengths, which
    # np.array cannot stack into an int64 matrix. Right-pad with [PAD] and
    # mask the padding out; equal-length batches are left untouched.
    longest = max((len(ids) for ids in input_ids), default=0)
    if any(len(ids) < longest for ids in input_ids):
        pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"])[0]
        for ids, types, mask in zip(input_ids, token_type_ids, attention_masks):
            missing = longest - len(ids)
            ids.extend([pad_id] * missing)
            types.extend([0] * missing)
            mask.extend([0] * missing)

    return {
        "input_ids": np.array(input_ids).astype(np.int64),
        "token_type_ids": np.array(token_type_ids).astype(np.int64),
        "attention_masks": np.array(attention_masks).astype(np.int64),
        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
        "char_ids": np.array(char_ids).astype(np.int64),
        "position_ids": np.array(position_ids).astype(np.int64),
    }
def _truncate_texts( | |
window_size: int, texts: List[str], query_ids: List[int] | |
) -> Tuple[List[str], List[int]]: | |
truncated_texts = [] | |
truncated_query_ids = [] | |
for text, query_id in zip(texts, query_ids): | |
start = max(0, query_id - window_size // 2) | |
end = min(len(text), query_id + window_size // 2) | |
truncated_text = text[start:end] | |
truncated_texts.append(truncated_text) | |
truncated_query_id = query_id - start | |
truncated_query_ids.append(truncated_query_id) | |
return truncated_texts, truncated_query_ids | |
def _truncate( | |
max_len: int, | |
text: str, | |
query_id: int, | |
tokens: List[str], | |
text2token: List[int], | |
token2text: List[Tuple[int]], | |
): | |
truncate_len = max_len - 2 | |
if len(tokens) <= truncate_len: | |
return (text, query_id, tokens, text2token, token2text) | |
token_position = text2token[query_id] | |
token_start = token_position - truncate_len // 2 | |
token_end = token_start + truncate_len | |
font_exceed_dist = -token_start | |
back_exceed_dist = token_end - len(tokens) | |
if font_exceed_dist > 0: | |
token_start += font_exceed_dist | |
token_end += font_exceed_dist | |
elif back_exceed_dist > 0: | |
token_start -= back_exceed_dist | |
token_end -= back_exceed_dist | |
start = token2text[token_start][0] | |
end = token2text[token_end - 1][1] | |
return ( | |
text[start:end], | |
query_id - start, | |
tokens[token_start:token_end], | |
[i - token_start if i is not None else None for i in text2token[start:end]], | |
[(s - start, e - start) for s, e in token2text[token_start:token_end]], | |
) | |
def get_phoneme_labels(
    polyphonic_chars: List[List[str]],
) -> Tuple[List[str], Dict[str, List[int]]]:
    """Collect the phoneme label set and each character's candidate labels.

    Args:
        polyphonic_chars: ``(char, phoneme)`` pairs.

    Returns:
        ``(labels, char2phonemes)``. ``labels`` is the sorted list of distinct
        phonemes. ``char2phonemes`` maps each character to the indices into
        ``labels`` of its phonemes, in input order; a pair that occurs more
        than once contributes a repeated index.
    """
    labels = sorted({phoneme for _, phoneme in polyphonic_chars})
    # Prebuilt lookup avoids an O(len(labels)) list.index() call per pair.
    label_index = {label: idx for idx, label in enumerate(labels)}
    char2phonemes: Dict[str, List[int]] = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(label_index[phoneme])
    return labels, char2phonemes
def get_char_phoneme_labels(
    polyphonic_chars: List[List[str]],
) -> Tuple[List[str], Dict[str, List[int]]]:
    """Build per-character phoneme labels of the form ``"<char> <phoneme>"``.

    Args:
        polyphonic_chars: ``(char, phoneme)`` pairs.

    Returns:
        ``(labels, char2phonemes)``. ``labels`` is the sorted list of distinct
        ``f"{char} {phoneme}"`` strings. ``char2phonemes`` maps each character
        to the indices into ``labels`` of its own labels, in input order; a
        pair that occurs more than once contributes a repeated index.
    """
    labels = sorted({f"{char} {phoneme}" for char, phoneme in polyphonic_chars})
    # Prebuilt lookup avoids an O(len(labels)) list.index() call per pair.
    label_index = {label: idx for idx, label in enumerate(labels)}
    char2phonemes: Dict[str, List[int]] = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(label_index[f"{char} {phoneme}"])
    return labels, char2phonemes