# mms-tts-ory / handler.py
# Custom inference endpoint handler (EndpointHandler) for a VITS text-to-speech model.
# Uploaded by akhilbattula in commit f876b9c ("Create handler.py"); original size 9.12 kB.
import torch
import numpy as np
import io
import base64
import subprocess
import tempfile
import os
from typing import Dict, Any
from transformers import VitsModel, AutoTokenizer
import scipy.io.wavfile as wavfile
class EndpointHandler:
    """
    Text-to-speech endpoint handler built around a Hugging Face ``VitsModel``.

    A request carries text under ``"inputs"``. The handler:

    1. synthesizes a waveform,
    2. packs it as 16-bit PCM WAV,
    3. converts it to MP3 (via ffmpeg, or the placeholder "manual" wrapper),
    4. returns the result base64-encoded in a JSON-friendly dict.
    """

    def __init__(self, path=""):
        """
        Load the VITS model and tokenizer from ``path`` and move the model to the
        available device (CUDA when present, otherwise CPU).

        Args:
            path: Anything accepted by ``from_pretrained`` (e.g. a local model
                directory).
        """
        self.model = VitsModel.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Inference only: switch off training-time behavior such as dropout.
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def wav_to_mp3_ffmpeg(self, wav_data: bytes, sample_rate: int = 16000) -> bytes:
        """
        Convert WAV bytes to MP3 bytes by running the ``ffmpeg`` executable.

        Input and output are staged in a private temporary directory that is
        removed on both success and failure.

        Args:
            wav_data: Complete WAV file contents.
            sample_rate: Output sample rate passed to ffmpeg's ``-ar`` option.

        Returns:
            The encoded MP3 file contents (libmp3lame, 128 kbps).

        Raises:
            Exception: "Error converting to MP3: ..." if ffmpeg is missing, exits
                non-zero, or a file operation fails. The original error is
                chained as ``__cause__``.
        """
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_wav_path = os.path.join(temp_dir, "input.wav")
                temp_mp3_path = os.path.join(temp_dir, "output.mp3")
                with open(temp_wav_path, "wb") as wav_file:
                    wav_file.write(wav_data)
                cmd = [
                    "ffmpeg", "-y",  # overwrite output file without prompting
                    "-i", temp_wav_path,  # input file
                    "-codec:a", "libmp3lame",  # MP3 encoder
                    "-b:a", "128k",  # bitrate
                    "-ar", str(sample_rate),  # output sample rate
                    temp_mp3_path,  # output file
                ]
                # stdin=DEVNULL: ffmpeg listens on stdin for interactive keys by
                # default; detach it so it cannot consume or block on ours.
                # errors="replace": stderr is not guaranteed to decode cleanly.
                result = subprocess.run(
                    cmd,
                    stdin=subprocess.DEVNULL,
                    capture_output=True,
                    text=True,
                    errors="replace",
                )
                if result.returncode != 0:
                    raise RuntimeError(f"FFmpeg error: {result.stderr}")
                with open(temp_mp3_path, "rb") as mp3_file:
                    return mp3_file.read()
        except Exception as e:
            raise Exception(f"Error converting to MP3: {e}") from e

    def wav_to_mp3_manual(self, wav_data: bytes) -> bytes:
        """
        Wrap WAV bytes in MP3-looking headers without encoding anything.

        The result is:

        - a 10-byte ID3v2.3 tag header whose tag size is 0,
        - then one fixed 4-byte MPEG audio frame header,
        - then the unchanged WAV file.

        This is NOT a valid MP3 stream and most decoders will reject it. It only
        exists as a dependency-free fallback when ffmpeg is unavailable.
        """
        # Simple ID3v2 header for MP3 ("ID3", v2.3, no flags, size 0)
        id3_header = b'ID3\x03\x00\x00\x00\x00\x00\x00'
        # Single MPEG audio frame header (not matched to the actual audio)
        mp3_frame_header = b'\xff\xfb\x90\x00'
        # Note: This is NOT a proper MP3 file, just a wrapper
        return id3_header + mp3_frame_header + wav_data

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Synthesize speech for one request.

        Args:
            data: Request payload, e.g.
                {"inputs": "text to convert to speech",
                 "parameters": {"conversion_method": "ffmpeg"}}
                ``conversion_method`` is "ffmpeg" (default) or "manual". Any
                other value uses the manual wrapper. If ffmpeg fails, the manual
                wrapper is used as a fallback.

        Returns:
            On success, a dict with these keys:

            - "audio": base64 string,
            - "sampling_rate",
            - "format": "mp3",
            - "text",
            - "conversion_method": the method actually used,
            - "content_type": "audio/mpeg".

            When "conversion_method" is "manual", the audio is the non-decodable
            wrapper from ``wav_to_mp3_manual``. On failure, returns
            {"error": <message>}.
        """
        try:
            inputs = data.get("inputs", "")
            if not inputs:
                return {"error": "No input text provided"}
            if not isinstance(inputs, str):
                # A list would be synthesized as a padded batch and written out as
                # a multi-channel WAV; only single texts are supported.
                return {"error": "Input text must be a single string"}

            # "parameters" may be absent or explicitly null in the JSON payload.
            parameters = data.get("parameters") or {}
            requested_method = parameters.get("conversion_method", "ffmpeg")

            input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(self.device)
            with torch.no_grad():
                output = self.model(input_ids)
            # One input text -> flat 1-D sample array. reshape(-1) also keeps a
            # length-1 output 1-D, whereas squeeze() would collapse it to 0-D.
            waveform = output.waveform.cpu().numpy().reshape(-1)

            # Use the rate the model was configured with rather than assuming one.
            sample_rate = int(getattr(self.model.config, "sampling_rate", 16000))

            # Peak-normalize to 95% of full scale to prevent clipping; skip
            # silent or empty output (np.max would raise on an empty array).
            peak = float(np.max(np.abs(waveform))) if waveform.size else 0.0
            if peak > 0:
                waveform = waveform / peak * 0.95
            waveform_int16 = (waveform * 32767).astype(np.int16)

            wav_buffer = io.BytesIO()
            wavfile.write(wav_buffer, sample_rate, waveform_int16)
            wav_data = wav_buffer.getvalue()

            # Track the method actually used so the response does not claim
            # "ffmpeg" after a silent fallback.
            mp3_data = None
            used_method = "manual"
            if requested_method == "ffmpeg":
                try:
                    mp3_data = self.wav_to_mp3_ffmpeg(wav_data, sample_rate)
                    used_method = "ffmpeg"
                except Exception as e:
                    print(f"FFmpeg conversion failed: {e}, falling back to manual method")
            if mp3_data is None:
                mp3_data = self.wav_to_mp3_manual(wav_data)

            # Base64 so the binary audio fits in a JSON response.
            audio_base64 = base64.b64encode(mp3_data).decode('utf-8')
            return {
                "audio": audio_base64,
                "sampling_rate": sample_rate,
                "format": "mp3",
                "text": inputs,
                "conversion_method": used_method,
                "content_type": "audio/mpeg"
            }
        except Exception as e:
            return {"error": f"Error processing request: {str(e)}"}
# Pure-Python "MP3-like" wrapper (no external dependencies).
# NOTE: despite the class name, nothing here performs MP3/LAME encoding.
class SimpleLAMEEncoder:
    """
    A very basic MP3-like wrapper using pure Python.

    It does not encode audio. It prepends a fixed ID3v2 tag header and a single
    MPEG audio frame header to the raw PCM payload of a WAV file, so most players
    will not decode the result. For production use, a real MP3 encoder (e.g.
    ffmpeg with libmp3lame) is required.
    """

    # ID3v2.3 tag header: "ID3", version 2.3, no flags, tag size 0 (never filled in).
    _ID3V2_HEADER = bytes([
        0x49, 0x44, 0x33,        # "ID3"
        0x03, 0x00,              # version 2.3
        0x00,                    # flags
        0x00, 0x00, 0x00, 0x00,  # tag size (left as 0)
    ])
    # MPEG audio frame header 0xFFFB9000. Per the MPEG-1 header layout this
    # declares Layer III, no CRC, 128 kbps, 44.1 kHz, stereo. It does NOT
    # describe 16 kHz mono PCM.
    _MP3_FRAME_HEADER = bytes([0xFF, 0xFB, 0x90, 0x00])

    @staticmethod
    def _wav_pcm_payload(wav_data: bytes) -> bytes:
        """
        Return the payload of the first ``data`` chunk of a RIFF/WAVE file.

        Walks the RIFF chunk list instead of assuming a fixed 44-byte header, so
        files with extra chunks (e.g. ``LIST``) are handled. Input that is not
        RIFF/WAVE, or that has no ``data`` chunk, falls back to skipping the
        first 44 bytes (the canonical PCM WAV header size).
        """
        if len(wav_data) >= 12 and wav_data[:4] == b"RIFF" and wav_data[8:12] == b"WAVE":
            offset = 12
            while offset + 8 <= len(wav_data):
                chunk_id = wav_data[offset:offset + 4]
                chunk_size = int.from_bytes(wav_data[offset + 4:offset + 8], "little")
                body_start = offset + 8
                if chunk_id == b"data":
                    # Slicing clamps a truncated/oversized length to what is present.
                    return wav_data[body_start:body_start + chunk_size]
                # RIFF chunks are word-aligned: odd-sized chunks carry one pad byte.
                offset = body_start + chunk_size + (chunk_size & 1)
        return wav_data[44:]

    @staticmethod
    def encode_wav_to_mp3_like(wav_data: bytes, sample_rate: int = 16000) -> bytes:
        """
        Build an MP3-looking byte string from a WAV file.

        The result may not be compatible with any player (see the class docstring).

        Args:
            wav_data: Complete WAV file bytes.
            sample_rate: Currently unused; kept for interface compatibility.

        Returns:
            ID3v2 header + MPEG frame header + the WAV's raw PCM payload.
        """
        audio_data = SimpleLAMEEncoder._wav_pcm_payload(wav_data)
        return bytes(
            SimpleLAMEEncoder._ID3V2_HEADER
            + SimpleLAMEEncoder._MP3_FRAME_HEADER
            + bytes(audio_data)
        )
# # Example usage and testing
# if __name__ == "__main__":
#     # Test the handler locally
#     handler = EndpointHandler("facebook/mms-tts-asm")
#     # Test input with ffmpeg conversion
#     test_data = {
#         "inputs": "Hello, this is a test of the text to speech system.",
#         "parameters": {"conversion_method": "ffmpeg"}
#     }
#     result = handler(test_data)
#     print("Handler result keys:", result.keys())
#     if "audio" in result:
#         print("MP3 audio generated successfully!")
#         print(f"Sampling rate: {result['sampling_rate']}")
#         print(f"Format: {result['format']}")
#         print(f"Conversion method: {result.get('conversion_method', 'unknown')}")
#         print(f"Audio data length: {len(result['audio'])} characters (base64)")
#         # Save the MP3 file for testing
#         with open("test_output.mp3", "wb") as f:
#             f.write(base64.b64decode(result['audio']))
#         print("Test MP3 saved as 'test_output.mp3'")
#     else:
#         print("Error:", result.get("error", "Unknown error"))
#     # Test with manual conversion method
#     print("\n--- Testing manual conversion ---")
#     test_data["parameters"]["conversion_method"] = "manual"
#     result_manual = handler(test_data)
#     if "audio" in result_manual:
#         print("Manual conversion successful!")
#         with open("test_output_manual.mp3", "wb") as f:
#             f.write(base64.b64decode(result_manual['audio']))
#         print("Manual MP3 saved as 'test_output_manual.mp3'")