zhzluke96 commited on
Commit
c5458aa
1 Parent(s): f34bda5
modules/api/impl/google_api.py CHANGED
@@ -11,6 +11,7 @@ from modules.utils.audio import apply_prosody_to_audio_data
11
  from modules.normalization import text_normalize
12
 
13
  from modules import generate_audio as generate
 
14
 
15
 
16
  from modules.ssml import parse_ssml
@@ -74,6 +75,8 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
74
  volume_gain_db = audioConfig.get("volumeGainDb", 0)
75
 
76
  batch_size = audioConfig.get("batchSize", 1)
 
 
77
  spliter_threshold = audioConfig.get("spliterThreshold", 100)
78
 
79
  # TODO sample_rate
@@ -84,6 +87,18 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
84
  # TODO maybe need to change the sample rate
85
  sample_rate = 24000
86
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  try:
88
  if input.text:
89
  # 处理文本合成逻辑
 
11
  from modules.normalization import text_normalize
12
 
13
  from modules import generate_audio as generate
14
+ from modules.speaker import speaker_mgr
15
 
16
 
17
  from modules.ssml import parse_ssml
 
75
  volume_gain_db = audioConfig.get("volumeGainDb", 0)
76
 
77
  batch_size = audioConfig.get("batchSize", 1)
78
+
79
+ # TODO spliter_threshold
80
  spliter_threshold = audioConfig.get("spliterThreshold", 100)
81
 
82
  # TODO sample_rate
 
87
  # TODO maybe need to change the sample rate
88
  sample_rate = 24000
89
 
90
+ # TODO 使用 speaker
91
+ spk = speaker_mgr.get_speaker(voice_name)
92
+ if spk is None:
93
+ raise HTTPException(
94
+ status_code=400, detail="The specified voice name is not supported."
95
+ )
96
+
97
+ if audio_format != "mp3" and audio_format != "wav":
98
+ raise HTTPException(
99
+ status_code=400, detail="Invalid audio encoding format specified."
100
+ )
101
+
102
  try:
103
  if input.text:
104
  # 处理文本合成逻辑
modules/api/impl/openai_api.py CHANGED
@@ -20,6 +20,9 @@ import pyrubberband as pyrb
20
  from modules.api import utils as api_utils
21
  from modules.api.Api import APIManager
22
 
 
 
 
23
  import numpy as np
24
 
25
 
@@ -29,6 +32,8 @@ class AudioSpeechRequest(BaseModel):
29
  voice: str = "female2"
30
  response_format: Literal["mp3", "wav"] = "mp3"
31
  speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
 
 
32
  style: str = ""
33
  # 是否开启batch合成,小于等于1表示不适用batch
34
  # 开启batch合成会自动分割句子
@@ -43,20 +48,27 @@ async def openai_speech_api(
43
  ..., description="JSON body with model, input text, and voice"
44
  )
45
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  try:
47
- model = request.model
48
- input_text = request.input
49
- voice = request.voice
50
- style = request.style
51
- response_format = request.response_format
52
- batch_size = request.batch_size
53
- spliter_threshold = request.spliter_threshold
54
- speed = request.speed
55
- speed = clip(speed, 0.1, 10)
56
-
57
- if not input_text:
58
- raise HTTPException(status_code=400, detail="Input text is required.")
59
 
 
60
  # Normalize the text
61
  text = text_normalize(input_text, is_end=True)
62
 
@@ -112,7 +124,7 @@ class TranscribeSegment(BaseModel):
112
  start: float
113
  end: float
114
  text: str
115
- tokens: List[int]
116
  temperature: float
117
  avg_logprob: float
118
  compression_ratio: float
@@ -124,7 +136,7 @@ class TranscriptionsVerboseResponse(BaseModel):
124
  language: str
125
  duration: float
126
  text: str
127
- segments: List[TranscribeSegment]
128
 
129
 
130
  def setup(app: APIManager):
@@ -146,8 +158,8 @@ openai api document:
146
 
147
  @app.post(
148
  "/v1/audio/transcriptions",
149
- response_class=TranscriptionsVerboseResponse,
150
- description="WIP",
151
  )
152
  async def transcribe(
153
  file: UploadFile = File(...),
@@ -159,12 +171,4 @@ openai api document:
159
  timestamp_granularities: List[str] = Form(["segment"]),
160
  ):
161
  # TODO: Implement transcribe
162
- return {
163
- "file": file.filename,
164
- "model": model,
165
- "language": language,
166
- "prompt": prompt,
167
- "response_format": response_format,
168
- "temperature": temperature,
169
- "timestamp_granularities": timestamp_granularities,
170
- }
 
20
  from modules.api import utils as api_utils
21
  from modules.api.Api import APIManager
22
 
23
+ from modules.speaker import speaker_mgr
24
+ from modules.data import styles_mgr
25
+
26
  import numpy as np
27
 
28
 
 
32
  voice: str = "female2"
33
  response_format: Literal["mp3", "wav"] = "mp3"
34
  speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
35
+ seed: int = 42
36
+ temperature: float = 0.3
37
  style: str = ""
38
  # 是否开启batch合成,小于等于1表示不适用batch
39
  # 开启batch合成会自动分割句子
 
48
  ..., description="JSON body with model, input text, and voice"
49
  )
50
  ):
51
+ model = request.model
52
+ input_text = request.input
53
+ voice = request.voice
54
+ style = request.style
55
+ response_format = request.response_format
56
+ batch_size = request.batch_size
57
+ spliter_threshold = request.spliter_threshold
58
+ speed = request.speed
59
+ speed = clip(speed, 0.1, 10)
60
+
61
+ if not input_text:
62
+ raise HTTPException(status_code=400, detail="Input text is required.")
63
+ if speaker_mgr.get_speaker(voice) is None:
64
+ raise HTTPException(status_code=400, detail="Invalid voice.")
65
  try:
66
+ if style:
67
+ styles_mgr.find_item_by_name(style)
68
+ except:
69
+ raise HTTPException(status_code=400, detail="Invalid style.")
 
 
 
 
 
 
 
 
70
 
71
+ try:
72
  # Normalize the text
73
  text = text_normalize(input_text, is_end=True)
74
 
 
124
  start: float
125
  end: float
126
  text: str
127
+ tokens: list[int]
128
  temperature: float
129
  avg_logprob: float
130
  compression_ratio: float
 
136
  language: str
137
  duration: float
138
  text: str
139
+ segments: list[TranscribeSegment]
140
 
141
 
142
  def setup(app: APIManager):
 
158
 
159
  @app.post(
160
  "/v1/audio/transcriptions",
161
+ response_model=TranscriptionsVerboseResponse,
162
+ description="Transcribes audio into the input language.",
163
  )
164
  async def transcribe(
165
  file: UploadFile = File(...),
 
171
  timestamp_granularities: List[str] = Form(["segment"]),
172
  ):
173
  # TODO: Implement transcribe
174
+ return api_utils.success_response("not implemented yet")
 
 
 
 
 
 
 
 
modules/api/utils.py CHANGED
@@ -29,12 +29,6 @@ class BaseResponse(BaseModel):
29
  message: str
30
  data: Any
31
 
32
- class Config:
33
- json_encoders = {
34
- torch.Tensor: lambda v: v.tolist(),
35
- Speaker: lambda v: v.to_json(),
36
- }
37
-
38
 
39
  def success_response(data: Any, message: str = "ok") -> BaseResponse:
40
  return BaseResponse(message=message, data=data)
 
29
  message: str
30
  data: Any
31
 
 
 
 
 
 
 
32
 
33
  def success_response(data: Any, message: str = "ok") -> BaseResponse:
34
  return BaseResponse(message=message, data=data)
modules/devices/devices.py CHANGED
@@ -74,7 +74,7 @@ def get_target_device_id_or_memory_available_gpu():
74
 
75
 
76
  def get_optimal_device_name():
77
- if config.runtime_env_vars.use_cpu:
78
  return "cpu"
79
 
80
  if torch.cuda.is_available():
 
74
 
75
 
76
  def get_optimal_device_name():
77
+ if config.runtime_env_vars.use_cpu == "all":
78
  return "cpu"
79
 
80
  if torch.cuda.is_available():
modules/utils/CsvMgr.py CHANGED
@@ -15,6 +15,7 @@ class DataNotFoundError(Exception):
15
  pass
16
 
17
 
 
18
  class BaseManager:
19
  def __init__(self, csv_file):
20
  self.csv_file = csv_file
 
15
  pass
16
 
17
 
18
+ # FIXME: 😓这个东西写的比较拉跨,最好找个什么csv库替代掉...
19
  class BaseManager:
20
  def __init__(self, csv_file):
21
  self.csv_file = csv_file
webui.py CHANGED
@@ -40,7 +40,7 @@ from modules.api.utils import calc_spk_style
40
  import modules.generate_audio as generate
41
 
42
  from modules.normalization import text_normalize
43
- from modules import refiner, config
44
 
45
  from modules.utils import env, audio
46
  from modules.SentenceSplitter import SentenceSplitter
@@ -101,6 +101,7 @@ def synthesize_ssml(ssml: str, batch_size=4):
101
  if len(segments) == 0:
102
  return None
103
 
 
104
  synthesize = SynthesizeSegments(batch_size=batch_size)
105
  audio_segments = synthesize.synthesize_segments(segments)
106
  combined_audio = combine_audio_segments(audio_segments)
@@ -157,6 +158,7 @@ def tts_generate(
157
  if not disable_normalize:
158
  text = text_normalize(text)
159
 
 
160
  sample_rate, audio_data = synthesize_audio(
161
  text=text,
162
  temperature=temperature,
 
40
  import modules.generate_audio as generate
41
 
42
  from modules.normalization import text_normalize
43
+ from modules import refiner, config, models
44
 
45
  from modules.utils import env, audio
46
  from modules.SentenceSplitter import SentenceSplitter
 
101
  if len(segments) == 0:
102
  return None
103
 
104
+ models.load_chat_tts()
105
  synthesize = SynthesizeSegments(batch_size=batch_size)
106
  audio_segments = synthesize.synthesize_segments(segments)
107
  combined_audio = combine_audio_segments(audio_segments)
 
158
  if not disable_normalize:
159
  text = text_normalize(text)
160
 
161
+ models.load_chat_tts()
162
  sample_rate, audio_data = synthesize_audio(
163
  text=text,
164
  temperature=temperature,