# Spaces:
# Sleeping
# Sleeping
import base64
import io
import logging
import os
import time
import traceback
import uuid
from logging.handlers import RotatingFileHandler
from pathlib import Path

import boto3
import librosa
import magic  # For MIME type detection
import numpy as np
import soundfile as sf
import torch
from botocore.exceptions import NoCredentialsError
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydub import AudioSegment

# Import functions from other modules
from asr import transcribe, ASR_LANGUAGES
from tts import synthesize, TTS_LANGUAGES
from lid import identify
from asr import ASR_SAMPLING_RATE
# --- Logging -----------------------------------------------------------------
# The root logger prints INFO and above to the console. This module's logger
# also writes to a size-capped rotating file (10,000,000 bytes per file,
# 5 backups kept).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = RotatingFileHandler('app.log', maxBytes=10_000_000, backupCount=5)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# --- Web application -----------------------------------------------------------
# NOTE(review): no route decorators (e.g. @app.post(...)) are visible on the
# handlers in this copy of the file. Confirm the endpoints are registered on `app`.
app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")

# --- S3 storage for synthesized audio --------------------------------------------
S3_BUCKET = "afri"
S3_REGION = "eu-west-3"
# None when the variables are unset. boto3 treats None as "not provided" and
# resolves credentials itself.
S3_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

s3_client = boto3.client(
    's3',
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)

# Define request models
class AudioRequest(BaseModel):
    """Request body for the audio endpoints (transcribe_audio and identify_language)."""

    audio: str # Base64 encoded audio or video data
    language: str  # passed to `transcribe`; identify_language does not read it

class TTSRequest(BaseModel):
    """Request body for text-to-speech synthesis (handled by synthesize_speech)."""

    text: str
    language: str  # first whitespace-separated token is checked against TTS_LANGUAGES
    speed: float  # synthesize_speech rejects values outside [0.5, 2.0]

def detect_mime_type(input_bytes):
    """Sniff the MIME type of an in-memory payload with libmagic.

    Args:
        input_bytes: Raw file contents.

    Returns:
        A MIME string such as ``"video/webm"``.
    """
    detector = magic.Magic(mime=True)
    mime_type = detector.from_buffer(input_bytes)
    return mime_type

def extract_audio(input_bytes): | |
mime_type = detect_mime_type(input_bytes) | |
if mime_type.startswith('audio/'): | |
return sf.read(io.BytesIO(input_bytes)) | |
elif mime_type.startswith('video/webm'): | |
audio = AudioSegment.from_file(io.BytesIO(input_bytes), format="webm") | |
audio_array = np.array(audio.get_array_of_samples()) | |
sample_rate = audio.frame_rate | |
return audio_array, sample_rate | |
else: | |
raise ValueError(f"Unsupported MIME type: {mime_type}") | |
async def transcribe_audio(request: AudioRequest):
    """Transcribe base64-encoded audio (or WebM video) in ``request.language``.

    The payload is decoded, mixed down to mono, cast to float32 and resampled
    to ``ASR_SAMPLING_RATE`` before it is passed to ``transcribe``.

    Returns:
        JSONResponse ``{"transcription": ...}`` on success.
        HTTP 400 ``{"message": "Invalid input", "details": ...}`` when the
        payload is not valid base64 or not a supported media type.
        HTTP 500 with error details for any other failure.
    """
    try:
        try:
            input_bytes = base64.b64decode(request.audio)
            audio_array, sample_rate = extract_audio(input_bytes)
        except ValueError as ve:
            # binascii.Error (bad base64) and extract_audio's "Unsupported MIME
            # type" are both ValueErrors. They are client errors, not server errors.
            logger.error(f"ValueError in transcribe_audio: {str(ve)}", exc_info=True)
            return JSONResponse(
                status_code=400,
                content={"message": "Invalid input", "details": str(ve)}
            )

        # Mix multichannel (frames, channels) audio down to mono.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)

        result = transcribe(audio_array, request.language)
        return JSONResponse(content={"transcription": result})
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during transcription", "details": error_details}
        )

def _peak_normalize_to_int16(audio):
    """Peak-normalize *audio* and convert it to 16-bit PCM samples.

    The largest-magnitude sample is scaled to +/-32767 and every other sample
    proportionally. The int16 cast truncates toward zero.

    Args:
        audio: Array-like of numeric samples.

    Returns:
        ``np.ndarray`` of dtype int16 with the same shape as *audio*.

    Raises:
        ValueError: If *audio* is empty or all zeros, since neither can be
            normalized.
    """
    audio = np.asarray(audio, dtype=np.float32)
    if audio.size == 0:
        raise ValueError("Synthesis failed to produce audio")
    max_value = np.max(np.abs(audio))
    if max_value == 0:
        raise ValueError("Generated audio is silent (all zeros)")
    return (audio / max_value * 32767).astype(np.int16)


async def synthesize_speech(request: TTSRequest):
    """Synthesize ``request.text`` to a WAV file, upload it to S3 and return its URL.

    Only the first whitespace-separated token of ``request.language`` is
    checked against ``TTS_LANGUAGES``. The full string is passed to
    ``synthesize``.

    Returns:
        JSONResponse ``{"audio_url": url}`` on success.
        HTTP 400 ``{"message": "Invalid input", "details": ...}`` in these
        cases: empty text, an unsupported language, a speed outside
        [0.5, 2.0], or synthesis yielding no audio or silent audio.
        HTTP 500 ``{"detail": ...}`` (via ``HTTPException``) when the S3
        upload fails.
        HTTP 500 with error details for any other failure.
    """
    logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
    try:
        # Extract the ISO code from the full language name. An empty or
        # whitespace-only string yields "" and fails validation below with a
        # 400, instead of raising IndexError (which would become a 500).
        lang_parts = request.language.split()
        lang_code = lang_parts[0] if lang_parts else ""

        # Input validation
        if not request.text.strip():
            raise ValueError("Text cannot be empty")
        if lang_code not in TTS_LANGUAGES:
            raise ValueError(f"Unsupported language: {request.language}")
        if not 0.5 <= request.speed <= 2.0:
            raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")

        logger.info(f"Calling synthesize function with lang_code: {lang_code}")
        result, filtered_text = synthesize(request.text, request.language, request.speed)
        logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")
        if result is None:
            logger.error("Synthesize function returned None")
            raise ValueError("Synthesis failed to produce audio")

        sample_rate, audio = result
        logger.info(f"Synthesis result: sample_rate={sample_rate}, audio_shape={audio.shape if isinstance(audio, np.ndarray) else 'not numpy array'}, audio_dtype={audio.dtype if isinstance(audio, np.ndarray) else type(audio)}")

        audio = _peak_normalize_to_int16(audio)
        logger.info(f"Int16 audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Writing audio to buffer")
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format='wav')
        buffer.seek(0)
        logger.info(f"Buffer size: {buffer.getbuffer().nbytes} bytes")

        # A timestamp alone collides when two requests finish in the same
        # second, overwriting one user's audio with another's. The random
        # suffix keeps object keys unique.
        filename = f"synthesized_audio_{int(time.time())}_{uuid.uuid4().hex}.wav"

        # Upload to S3 without ACL
        try:
            s3_client.upload_fileobj(
                buffer,
                S3_BUCKET,
                filename,
                ExtraArgs={'ContentType': 'audio/wav'}
            )
        except NoCredentialsError as e:
            logger.error("AWS credentials not available or invalid")
            raise HTTPException(status_code=500, detail="Could not upload file to S3: Missing or invalid credentials") from e
        except Exception as e:
            logger.error(f"Failed to upload to S3: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Could not upload file to S3: {str(e)}") from e
        logger.info(f"File uploaded successfully to S3: {filename}")

        # Path-style object URL. No ACL is set on upload, so public
        # readability presumably relies on the bucket policy. Confirm this.
        url = f"https://s3.{S3_REGION}.amazonaws.com/{S3_BUCKET}/{filename}"
        logger.info(f"Public URL generated: {url}")
        return JSONResponse(content={"audio_url": url})
    except HTTPException:
        # Let the explicit S3 error response through instead of having the
        # generic handler below re-wrap it.
        raise
    except ValueError as ve:
        logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
        return JSONResponse(
            status_code=400,
            content={"message": "Invalid input", "details": str(ve)}
        )
    except Exception as e:
        logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "type": type(e).__name__,
            "traceback": traceback.format_exc()
        }
        return JSONResponse(
            status_code=500,
            content={"message": "An unexpected error occurred during speech synthesis", "details": error_details}
        )
    finally:
        logger.info("Synthesize request completed")

async def identify_language(request: AudioRequest):
    """Identify the spoken language of base64-encoded audio (or WebM video).

    ``request.language`` is not read by this handler.

    Returns:
        JSONResponse ``{"language_identification": ...}`` on success.
        HTTP 400 ``{"message": "Invalid input", "details": ...}`` when the
        payload is not valid base64 or not a supported media type.
        HTTP 500 with error details for any other failure.
    """
    try:
        try:
            input_bytes = base64.b64decode(request.audio)
            audio_array, _sample_rate = extract_audio(input_bytes)
        except ValueError as ve:
            # binascii.Error (bad base64) and extract_audio's "Unsupported MIME
            # type" are both ValueErrors. They are client errors, not server errors.
            logger.error(f"ValueError in identify_language: {str(ve)}", exc_info=True)
            return JSONResponse(
                status_code=400,
                content={"message": "Invalid input", "details": str(ve)}
            )

        # NOTE(review): unlike transcribe_audio, this handler discards the sample
        # rate and neither mixes to mono nor resamples. Confirm that `identify`
        # accepts arbitrary-rate, possibly multichannel input.
        result = identify(audio_array)
        return JSONResponse(content={"language_identification": result})
    except Exception as e:
        logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during language identification", "details": error_details}
        )

async def get_asr_languages():
    """Serve the ASR language table as JSON.

    If the table cannot be rendered as a JSON response, the error is logged and
    an HTTP 500 response with the message and traceback is returned instead.
    """
    try:
        response = JSONResponse(content=ASR_LANGUAGES)
    except Exception as exc:
        logger.error(f"Error in get_asr_languages: {str(exc)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching ASR languages",
                "details": {"error": str(exc), "traceback": traceback.format_exc()},
            },
        )
    return response

async def get_tts_languages():
    """Serve the TTS language table as JSON.

    If the table cannot be rendered as a JSON response, the error is logged and
    an HTTP 500 response with the message and traceback is returned instead.
    """
    try:
        response = JSONResponse(content=TTS_LANGUAGES)
    except Exception as exc:
        logger.error(f"Error in get_tts_languages: {str(exc)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching TTS languages",
                "details": {"error": str(exc), "traceback": traceback.format_exc()},
            },
        )
    return response