Spaces: shriniket73 committed
Commit c6bc3fd • Parent(s): 196cb0c
Update app.py

app.py CHANGED
@@ -21,24 +21,26 @@ from TTS.api import TTS
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 class TTSRequest(BaseModel):
     text: str
 
+
 class OptimizedTTSService:
     def __init__(self):
         logger.info("Initializing Optimized TTS Service...")
+
         try:
             # Set TTS home directory and accept license
+            os.environ["HOME"] = "/tmp/home"
+            os.environ["TTS_HOME"] = "/tmp/tts_home"
             os.environ["COQUI_TOS_AGREED"] = "1"  # Accept TTS license
+
             # Set number of threads for PyTorch
             n_threads = max(2, multiprocessing.cpu_count() - 1)
             torch.set_num_threads(n_threads)
             logger.info(f"Using {n_threads} CPU threads")
+
             # Initialize TTS with error handling
             try:
                 model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
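Note on the environment variables added above: Coqui TTS resolves its model cache from HOME/TTS_HOME, and this commit points both at /tmp, presumably because the default home directory is not writable in this Space; COQUI_TOS_AGREED=1 skips the interactive XTTS license prompt. A minimal standalone sketch of the same setup, runnable outside the service (the explicit directory creation is an added assumption, not in the commit):

import os

# Must be set before the model is downloaded/loaded for the first time.
os.environ["HOME"] = "/tmp/home"
os.environ["TTS_HOME"] = "/tmp/tts_home"
os.environ["COQUI_TOS_AGREED"] = "1"  # skip the interactive license prompt

os.makedirs("/tmp/tts_home", exist_ok=True)  # assumed: ensure the cache dir exists

from TTS.api import TTS  # imported after the env vars are in place

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")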
@@ -48,40 +50,40 @@ class OptimizedTTSService:
             except Exception as e:
                 logger.error(f"Failed to load TTS model: {str(e)}")
                 raise
+
             # Load latents
             try:
                 logger.info("Loading voice latents...")
                 latents_path = "models/goggins_latents.pt"
                 if not os.path.exists(latents_path):
                     raise FileNotFoundError(f"Latents file not found at {latents_path}")
+
+                self.latents = torch.load(latents_path, map_location="cpu")
                 logger.info("Latents loaded successfully")
             except Exception as e:
                 logger.error(f"Failed to load latents: {str(e)}")
                 raise
+
             # Initialize thread pool for parallel processing
             self.executor = ThreadPoolExecutor(max_workers=n_threads)
+
             # Configure model for inference
             self.model = self.tts.synthesizer.tts_model
             self.model.eval()
+
             # Initialize device
             self.device = torch.device("cpu")
             logger.info(f"Using device: {self.device}")
+
             # Initialize cache
             self._setup_cache()
+
             logger.info("Service initialization complete!")
+
         except Exception as e:
             logger.error(f"Failed to initialize TTS service: {str(e)}")
             raise
+
     def _setup_cache(self):
         """Setup caching mechanisms with error handling"""
         try:
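The latents file loaded above is the speaker's voice identity precomputed offline; the next hunk shows the two keys the service reads from it ("speaker_embedding" and "gpt_cond_latent"). A hedged sketch of how such a file could be produced with the stock XTTS v2 conditioning API — this script is not part of the commit, and the reference clip path is hypothetical:

import torch
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
xtts = tts.synthesizer.tts_model

# XTTS derives both conditioning tensors from one or more reference clips.
gpt_cond_latent, speaker_embedding = xtts.get_conditioning_latents(
    audio_path=["goggins_reference.wav"]  # hypothetical reference recording
)
torch.save(
    {"gpt_cond_latent": gpt_cond_latent, "speaker_embedding": speaker_embedding},
    "models/goggins_latents.pt",
)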
@@ -93,25 +95,23 @@ class OptimizedTTSService:
         except Exception as e:
             logger.error(f"Failed to setup cache: {str(e)}")
             raise
+
     def _process_chunk(self, chunk: str) -> np.ndarray:
         """Process a single chunk of text with improved error handling"""
         try:
             # Convert latents to tensors
             speaker_embedding = torch.tensor(
+                self.latents["speaker_embedding"],
                 dtype=torch.float32,
+                device=self.device,
             )
             gpt_cond_latent = torch.tensor(
+                self.latents["gpt_cond_latent"], dtype=torch.float32, device=self.device
             )
+
             # Get optimized parameters based on chunk length
             params = self._get_params_for_length(len(chunk))
+
             # Generate speech
             with torch.no_grad():
                 wav = self.model.inference(
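One detail in the hunk above: the stored latents are re-wrapped with torch.tensor() on every chunk, which always copies the data (and, if the stored values are already tensors, PyTorch warns and recommends clone().detach()). A hedged sketch of converting once at load time instead, assuming the same file layout — an alternative, not what the commit does:

import torch

latents = torch.load("models/goggins_latents.pt", map_location="cpu")

# torch.as_tensor reuses the existing storage when dtype/device already match,
# so the conversion cost is paid once rather than on every chunk.
speaker_embedding = torch.as_tensor(latents["speaker_embedding"], dtype=torch.float32)
gpt_cond_latent = torch.as_tensor(latents["gpt_cond_latent"], dtype=torch.float32)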
@@ -119,76 +119,77 @@ class OptimizedTTSService:
                     language="en",
                     gpt_cond_latent=gpt_cond_latent,
                     speaker_embedding=speaker_embedding,
+                    **params,
                 )
+
             return wav["wav"]
+
         except Exception as e:
             logger.error(f"Error processing chunk '{chunk[:50]}...': {str(e)}")
             raise
+
     def _get_params_for_length(self, chunk_length: int) -> Dict:
         """Get optimized parameters based on text length"""
         if chunk_length <= 80:
             return {
+                "temperature": 0.75,
+                "length_penalty": 0.8,
+                "repetition_penalty": 1.8,
+                "top_k": 40,
+                "top_p": 0.80,
             }
         elif chunk_length <= 150:
             return {
+                "temperature": 0.85,
+                "length_penalty": 1.0,
+                "repetition_penalty": 2.0,
+                "top_k": 50,
+                "top_p": 0.85,
             }
         else:
             return {
+                "temperature": 0.95,
+                "length_penalty": 1.2,
+                "repetition_penalty": 2.2,
+                "top_k": 60,
+                "top_p": 0.90,
             }
+
     def generate_speech(self, text: str) -> np.ndarray:
         """Generate speech with improved error handling"""
         try:
             # Clean and validate input
             if not text or not text.strip():
                 raise ValueError("Empty text input")
+
             text = text.strip()
             if len(text) > 1000:  # Add reasonable limit
                 raise ValueError("Text too long (max 1000 characters)")
+
             # Process single chunk for short text
             if len(text) <= 150:
                 return self._process_chunk(text)
+
             # Split longer text into chunks
+            chunks = text.split(". ")
+            chunks = [chunk.strip() + "." for chunk in chunks if chunk.strip()]
+
             # Process chunks
             wavs = []
             for i, chunk in enumerate(chunks, 1):
                 logger.info(f"Processing chunk {i}/{len(chunks)}: {chunk[:50]}...")
                 wav = self._process_chunk(chunk)
                 wavs.append(wav)
+
             # Concatenate results
             final_wav = np.concatenate(wavs)
             return final_wav
+
         except Exception as e:
             logger.error(f"Error in generate_speech: {str(e)}")
             raise
 
+
 # Initialize FastAPI app
 app = FastAPI(title="Goggins TTS API")
 
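The sentence splitter added above is deliberately simple: it cuts on ". " and re-appends a period. A quick illustration of where that helps and where it misfires — abbreviations become their own chunks, and the final sentence, which keeps its original period through the split, picks up a second one:

text = "Stay hard. Mr. Goggins ran 100 miles. No excuses."
chunks = [c.strip() + "." for c in text.split(". ") if c.strip()]
print(chunks)
# ['Stay hard.', 'Mr.', 'Goggins ran 100 miles.', 'No excuses..']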
@@ -204,6 +205,7 @@ app.add_middleware(
 # Initialize service
 service = None
 
+
 @app.on_event("startup")
 async def startup_event():
     global service
@@ -213,36 +215,45 @@ async def startup_event():
         logger.error(f"Failed to initialize service: {str(e)}")
         raise
 
+
 @app.post("/generate")
 async def generate_speech(request: TTSRequest):
     """Generate speech from text with detailed timing"""
     try:
         total_start = time.time()
         logger.info(f"\nReceived request for text: {request.text[:50]}...")
+
         # Model processing time
         model_start = time.time()
         wav = service.generate_speech(request.text)
         model_time = time.time() - model_start
+
         # Audio conversion time
         conversion_start = time.time()
         buffer = io.BytesIO()
         np.save(buffer, wav.astype(np.float32))
         audio_base64 = base64.b64encode(buffer.getvalue()).decode()
         conversion_time = time.time() - conversion_start
+
         # Total processing time
         total_time = time.time() - total_start
+
         timing_info = {
             "total_processing_time": round(total_time, 2),
             "model_processing_time": round(model_time, 2),
+            "audio_conversion_time": round(conversion_time, 2),
         }
+
         logger.info(f"Timing breakdown: {timing_info}")
 
+        # Add the missing return statement
+        return {"status": "success", "audio": audio_base64, "timing": timing_info}
+
+    except Exception as e:
+        logger.error(f"Error in generate_speech endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.get("/health")
 async def health_check():
     """Health check endpoint"""
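On the wire, /generate now returns the raw float32 samples serialized with np.save and then base64-encoded, so a client has to reverse both steps. A hedged client-side sketch — the URL, timeout, and output handling are assumptions; XTTS v2 emits 24 kHz audio:

import base64
import io

import numpy as np
import requests
import soundfile as sf

resp = requests.post(
    "http://localhost:7860/generate",  # hypothetical local Space URL
    json={"text": "Who's gonna carry the boats?"},
    timeout=300,
)
resp.raise_for_status()
payload = resp.json()

# Undo base64, then undo np.save (.npy container) to recover the waveform.
wav = np.load(io.BytesIO(base64.b64decode(payload["audio"])))
sf.write("output.wav", wav, 24000)  # assumed sample rate: XTTS v2 outputs 24 kHz
print(payload["timing"])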