shriniket73 committed on
Commit: c6bc3fd
Parent: 196cb0c

Update app.py

Files changed (1):
  app.py  +68 -57
app.py CHANGED
@@ -21,24 +21,26 @@ from TTS.api import TTS
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 class TTSRequest(BaseModel):
     text: str
 
+
 class OptimizedTTSService:
     def __init__(self):
         logger.info("Initializing Optimized TTS Service...")
-
+
         try:
             # Set TTS home directory and accept license
-            os.environ['HOME'] = '/tmp/home'
-            os.environ['TTS_HOME'] = '/tmp/tts_home'
+            os.environ["HOME"] = "/tmp/home"
+            os.environ["TTS_HOME"] = "/tmp/tts_home"
             os.environ["COQUI_TOS_AGREED"] = "1"  # Accept TTS license
-
+
             # Set number of threads for PyTorch
             n_threads = max(2, multiprocessing.cpu_count() - 1)
             torch.set_num_threads(n_threads)
             logger.info(f"Using {n_threads} CPU threads")
-
+
             # Initialize TTS with error handling
             try:
                 model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
@@ -48,40 +50,40 @@ class OptimizedTTSService:
             except Exception as e:
                 logger.error(f"Failed to load TTS model: {str(e)}")
                 raise
-
+
             # Load latents
             try:
                 logger.info("Loading voice latents...")
                 latents_path = "models/goggins_latents.pt"
                 if not os.path.exists(latents_path):
                     raise FileNotFoundError(f"Latents file not found at {latents_path}")
-
-                self.latents = torch.load(latents_path, map_location='cpu')
+
+                self.latents = torch.load(latents_path, map_location="cpu")
                 logger.info("Latents loaded successfully")
             except Exception as e:
                 logger.error(f"Failed to load latents: {str(e)}")
                 raise
-
+
             # Initialize thread pool for parallel processing
             self.executor = ThreadPoolExecutor(max_workers=n_threads)
-
+
             # Configure model for inference
             self.model = self.tts.synthesizer.tts_model
             self.model.eval()
-
+
             # Initialize device
             self.device = torch.device("cpu")
             logger.info(f"Using device: {self.device}")
-
+
             # Initialize cache
             self._setup_cache()
-
+
             logger.info("Service initialization complete!")
-
+
         except Exception as e:
             logger.error(f"Failed to initialize TTS service: {str(e)}")
             raise
-
+
     def _setup_cache(self):
         """Setup caching mechanisms with error handling"""
         try:
@@ -93,25 +95,23 @@ class OptimizedTTSService:
         except Exception as e:
             logger.error(f"Failed to setup cache: {str(e)}")
             raise
-
+
     def _process_chunk(self, chunk: str) -> np.ndarray:
         """Process a single chunk of text with improved error handling"""
         try:
             # Convert latents to tensors
             speaker_embedding = torch.tensor(
-                self.latents['speaker_embedding'],
+                self.latents["speaker_embedding"],
                 dtype=torch.float32,
-                device=self.device
+                device=self.device,
             )
             gpt_cond_latent = torch.tensor(
-                self.latents['gpt_cond_latent'],
-                dtype=torch.float32,
-                device=self.device
+                self.latents["gpt_cond_latent"], dtype=torch.float32, device=self.device
             )
-
+
             # Get optimized parameters based on chunk length
             params = self._get_params_for_length(len(chunk))
-
+
             # Generate speech
             with torch.no_grad():
                 wav = self.model.inference(
@@ -119,76 +119,77 @@ class OptimizedTTSService:
                     language="en",
                     gpt_cond_latent=gpt_cond_latent,
                     speaker_embedding=speaker_embedding,
-                    **params
+                    **params,
                 )
-
+
             return wav["wav"]
-
+
         except Exception as e:
             logger.error(f"Error processing chunk '{chunk[:50]}...': {str(e)}")
             raise
-
+
     def _get_params_for_length(self, chunk_length: int) -> Dict:
         """Get optimized parameters based on text length"""
         if chunk_length <= 80:
             return {
-                'temperature': 0.75,
-                'length_penalty': 0.8,
-                'repetition_penalty': 1.8,
-                'top_k': 40,
-                'top_p': 0.80
+                "temperature": 0.75,
+                "length_penalty": 0.8,
+                "repetition_penalty": 1.8,
+                "top_k": 40,
+                "top_p": 0.80,
             }
         elif chunk_length <= 150:
             return {
-                'temperature': 0.85,
-                'length_penalty': 1.0,
-                'repetition_penalty': 2.0,
-                'top_k': 50,
-                'top_p': 0.85
+                "temperature": 0.85,
+                "length_penalty": 1.0,
+                "repetition_penalty": 2.0,
+                "top_k": 50,
+                "top_p": 0.85,
             }
         else:
             return {
-                'temperature': 0.95,
-                'length_penalty': 1.2,
-                'repetition_penalty': 2.2,
-                'top_k': 60,
-                'top_p': 0.90
+                "temperature": 0.95,
+                "length_penalty": 1.2,
+                "repetition_penalty": 2.2,
+                "top_k": 60,
+                "top_p": 0.90,
             }
-
+
     def generate_speech(self, text: str) -> np.ndarray:
         """Generate speech with improved error handling"""
         try:
             # Clean and validate input
             if not text or not text.strip():
                 raise ValueError("Empty text input")
-
+
             text = text.strip()
             if len(text) > 1000:  # Add reasonable limit
                 raise ValueError("Text too long (max 1000 characters)")
-
+
             # Process single chunk for short text
             if len(text) <= 150:
                 return self._process_chunk(text)
-
+
             # Split longer text into chunks
-            chunks = text.split('. ')
-            chunks = [chunk.strip() + '.' for chunk in chunks if chunk.strip()]
-
+            chunks = text.split(". ")
+            chunks = [chunk.strip() + "." for chunk in chunks if chunk.strip()]
+
             # Process chunks
             wavs = []
             for i, chunk in enumerate(chunks, 1):
                 logger.info(f"Processing chunk {i}/{len(chunks)}: {chunk[:50]}...")
                 wav = self._process_chunk(chunk)
                 wavs.append(wav)
-
+
             # Concatenate results
             final_wav = np.concatenate(wavs)
             return final_wav
-
+
         except Exception as e:
             logger.error(f"Error in generate_speech: {str(e)}")
             raise
 
+
 # Initialize FastAPI app
 app = FastAPI(title="Goggins TTS API")
 
@@ -204,6 +205,7 @@ app.add_middleware(
 # Initialize service
 service = None
 
+
 @app.on_event("startup")
 async def startup_event():
     global service
@@ -213,36 +215,45 @@ async def startup_event():
         logger.error(f"Failed to initialize service: {str(e)}")
         raise
 
+
 @app.post("/generate")
 async def generate_speech(request: TTSRequest):
     """Generate speech from text with detailed timing"""
     try:
         total_start = time.time()
         logger.info(f"\nReceived request for text: {request.text[:50]}...")
-
+
         # Model processing time
         model_start = time.time()
         wav = service.generate_speech(request.text)
         model_time = time.time() - model_start
-
+
         # Audio conversion time
         conversion_start = time.time()
         buffer = io.BytesIO()
         np.save(buffer, wav.astype(np.float32))
         audio_base64 = base64.b64encode(buffer.getvalue()).decode()
         conversion_time = time.time() - conversion_start
-
+
         # Total processing time
         total_time = time.time() - total_start
-
+
         timing_info = {
             "total_processing_time": round(total_time, 2),
             "model_processing_time": round(model_time, 2),
-            "audio_conversion_time": round(conversion_time, 2)
+            "audio_conversion_time": round(conversion_time, 2),
         }
-
+
         logger.info(f"Timing breakdown: {timing_info}")
 
+        # Add the missing return statement
+        return {"status": "success", "audio": audio_base64, "timing": timing_info}
+
+    except Exception as e:
+        logger.error(f"Error in generate_speech endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.get("/health")
 async def health_check():
     """Health check endpoint"""
 