File size: 5,588 Bytes
da8d589 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
from lxml import etree
from typing import Any, List, Dict, Union
import logging
from modules.data import styles_mgr
from modules.speaker import speaker_mgr
from box import Box
import copy
class SSMLContext(Box):
    """Rendering settings that apply to a piece of SSML text.

    The resolvers in ``create_ssml_parser`` deep-copy the incoming context
    before overriding it with an element's attributes, so settings made on an
    inner element do not leak back to its parent. Every setting starts out as
    ``None``, meaning "not specified".
    """

    def __init__(self, parent=None):
        # Optional enclosing context; nothing visible in this module sets it.
        self.parent: Union[SSMLContext, None] = parent
        # All remaining settings are stored as Box items, in this fixed order.
        unset_fields = (
            "style",
            "spk",
            "volume",
            "rate",
            "pitch",
            "temp",  # sampling temperature
            "top_p",
            "top_k",
            "seed",
            "noramalize",  # sic: the misspelled key is part of the public format
            "prompt1",
            "prompt2",
            "prefix",
        )
        for field_name in unset_fields:
            setattr(self, field_name, None)
class SSMLSegment(Box):
    """A run of text to synthesize, together with the context that applies to it.

    Attributes:
        attrs: the ``SSMLContext`` in effect for ``text``.
        text: the (already stripped) text content.
        params: initialised to ``None`` here; not filled in by this module.
    """

    def __init__(self, text: str, attrs: Union[SSMLContext, None] = None):
        # Build a fresh context per instance. The previous ``attrs=SSMLContext()``
        # default was evaluated once at class-definition time and shared by every
        # segment created without explicit attrs, so mutating one leaked into all.
        self.attrs = attrs if attrs is not None else SSMLContext()
        self.text = text
        self.params = None
class SSMLBreak:
    """A pause between segments; ``attrs.duration`` holds its length in milliseconds."""

    def __init__(self, duration_ms: Union[str, int, float]):
        duration_ms = self.parse_duration_ms(duration_ms)
        self.attrs = Box(**{"duration": duration_ms})

    @staticmethod
    def parse_duration_ms(value: Union[str, int, float]) -> int:
        """Convert a break duration to whole milliseconds.

        Accepts a bare number (interpreted as milliseconds) or a string with an
        optional ``ms`` or ``s`` suffix, e.g. ``500``, ``500.0``, ``"500ms"``,
        ``"1.5s"``. Fractional results are rounded to the nearest millisecond.

        Raises:
            ValueError: if ``value`` is not a finite number with a supported unit.
        """
        import math

        text = str(value).strip().lower()
        # Check "ms" before "s", since "ms" also ends with "s".
        if text.endswith("ms"):
            number, scale = text[:-2], 1
        elif text.endswith("s"):
            number, scale = text[:-1], 1000
        else:
            number, scale = text, 1
        # float() (not int()) so that float inputs such as 500.0 are accepted.
        milliseconds = float(number) * scale
        if not math.isfinite(milliseconds):
            raise ValueError(f"Invalid break duration: {value!r}")
        return round(milliseconds)
class SSMLParser:
    """Dispatch SSML elements to per-tag resolver callbacks.

    Resolvers are registered with the :meth:`resolver` decorator and looked up
    by exact tag name when :meth:`resolve` is called on an element.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.logger.debug("SSMLParser.__init__()")
        # (tag, callback) pairs in registration order; the first match wins.
        self.resolvers = []

    def resolver(self, tag: str):
        """Return a decorator that registers ``func`` as the handler for ``tag``.

        The handler is called as ``func(element, context, segments, parser)``
        and is returned unchanged, so it remains directly callable.
        """

        def decorator(func):
            self.resolvers.append((tag, func))
            return func

        return decorator

    def parse(self, ssml: str) -> "List[Union[SSMLSegment, SSMLBreak]]":
        """Parse an SSML document into a flat list of segments and breaks.

        Raises:
            lxml.etree.XMLSyntaxError: if ``ssml`` is not well-formed XML.
            NotImplementedError: if an element has no registered resolver.
        """
        # ``ssml`` may come from untrusted callers: do not expand DTD-declared
        # entities or touch the network (XXE / entity-expansion hardening).
        # Comments and processing instructions carry no speech content.
        xml_parser = etree.XMLParser(
            resolve_entities=False,
            no_network=True,
            remove_comments=True,
            remove_pis=True,
        )
        root = etree.fromstring(ssml, parser=xml_parser)
        root_ctx = SSMLContext()
        segments = []
        self.resolve(root, root_ctx, segments)
        return segments

    def resolve(
        self,
        element: "etree._Element",
        context: "SSMLContext",
        segments: "List[Union[SSMLSegment, SSMLBreak]]",
    ):
        """Dispatch ``element`` to the first resolver registered for its tag.

        Nodes whose ``tag`` is not a string (lxml uses this for comments,
        processing instructions and entity references) are skipped.

        Raises:
            NotImplementedError: if no resolver is registered for the tag.
        """
        tag = element.tag
        if not isinstance(tag, str):
            return
        handler = next((func for name, func in self.resolvers if name == tag), None)
        if handler is None:
            raise NotImplementedError(f"Tag {element.tag} not supported.")
        handler(element, context, segments, self)
def create_ssml_parser():
    """Build an ``SSMLParser`` with resolvers for the supported SSML tags.

    Supported tags:
      * ``<speak version="0.1">``: the root element; any other version raises
        ``ValueError``.
      * ``<voice>`` / ``<prosody>``: each overrides the inherited context with
        any of its attributes. The result applies to the element's own text,
        its nested children, and the text following each child.
      * ``<break time="...">``: emits an ``SSMLBreak``.
    """
    parser = SSMLParser()

    # Context attributes that <voice> and <prosody> may override.
    context_attrs = (
        "spk",
        "style",
        "volume",
        "rate",
        "pitch",
        "temp",  # sampling temperature
        "top_p",
        "top_k",
        "seed",
        "noramalize",  # sic: the misspelling is part of the attribute name
        "prompt1",
        "prompt2",
        "prefix",
    )

    def derive_context(element, context):
        """Deep-copy ``context`` and override it with ``element``'s attributes."""
        ctx = copy.deepcopy(context)
        for name in context_attrs:
            setattr(ctx, name, element.get(name, getattr(ctx, name)))
        return ctx

    def resolve_content(element, ctx, segments, parser):
        """Emit ``element``'s leading text, its children and their tails under ``ctx``."""
        if element.text and element.text.strip():
            segments.append(SSMLSegment(element.text.strip(), ctx))
        for child in element:
            parser.resolve(child, ctx, segments)
            # Text after a child's closing tag belongs to this element, not the child.
            if child.tail and child.tail.strip():
                segments.append(SSMLSegment(child.tail.strip(), ctx))

    @parser.resolver("speak")
    def tag_speak(element, context, segments, parser):
        ctx = copy.deepcopy(context)
        version = element.get("version")
        if version != "0.1":
            raise ValueError(f"Unsupported SSML version {version}")
        # NOTE: only child elements are processed; text placed directly inside
        # <speak> (outside any <voice>/<prosody>) is ignored.
        for child in element:
            parser.resolve(child, ctx, segments)

    @parser.resolver("voice")
    def tag_voice(element, context, segments, parser):
        resolve_content(element, derive_context(element, context), segments, parser)

    @parser.resolver("break")
    def tag_break(element, context, segments, parser):
        # SSMLBreak parses the "ms" suffix itself, so pass the raw attribute.
        segments.append(SSMLBreak(element.get("time", "0")))

    @parser.resolver("prosody")
    def tag_prosody(element, context, segments, parser):
        # Previously only the leading text was emitted; nested elements (e.g. a
        # <break/>) and any text following them were silently dropped.
        resolve_content(element, derive_context(element, context), segments, parser)

    return parser
if __name__ == "__main__":
    # Demo: parse a small document and print every resulting item.
    demo_parser = create_ssml_parser()
    demo_ssml = """
<speak version="0.1">
<voice spk="xiaoyan" style="news">
<prosody rate="fast">你好</prosody>
<break time="500ms"/>
<prosody rate="slow">你好</prosody>
</voice>
</speak>
"""
    for item in demo_parser.parse(demo_ssml):
        if isinstance(item, SSMLBreak):
            print("<break>", item.attrs)
            continue
        if not isinstance(item, SSMLSegment):
            raise ValueError("Unknown segment type")
        print(item.text, item.attrs)
|