# cascadedS2S/inference.py
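"""Cascaded speech-to-speech inference: Whisper ASR -> DialoGPT chat -> MMS TTS.

All three models are loaded from a local cache directory; `inference(audio_array,
sample_rate)` returns the synthesized reply audio along with the generated text.
"""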
import logging
import os

import numpy as np
import soundfile as sf
import torch
from scipy import signal
from transformers import AutoTokenizer, AutoModelForCausalLM, WhisperProcessor, WhisperForConditionalGeneration, pipeline

logger = logging.getLogger(__name__)

class InferenceRecipe:
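    """Cascaded S2S recipe: Whisper ASR -> DialoGPT chat -> Facebook MMS TTS."""
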
    def __init__(self, model_path='./models', device='cuda'):
        self.device = device
        self.asr_processor = None
        self.asr_model = None
        self.chat_tokenizer = None
        self.chat_model = None
        self.tts_model = None
        self.tts_sample_rate = 22050  # Fallback TTS output sample rate (used if the pipeline does not report one)
        self.model_path = model_path
        self.initialize_models()

    def initialize_models(self):
        """Initialize models from the local cache."""
        # ASR: OpenAI Whisper
        asr_path = os.path.join(self.model_path, 'asr')
        logger.info(f"Loading ASR model from {asr_path}")
        self.asr_processor = WhisperProcessor.from_pretrained(asr_path, local_files_only=True)
        self.asr_model = WhisperForConditionalGeneration.from_pretrained(asr_path, local_files_only=True)
        self.asr_model = self.asr_model.to(self.device)

        # Configure Whisper decoding: register the no-timestamps token and force English transcription
        self.asr_model.generation_config.no_timestamps_token_id = self.asr_processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
        self.asr_model.config.forced_decoder_ids = self.asr_processor.get_decoder_prompt_ids(language="english", task="transcribe")

        # Chat: DialoGPT
        dialogpt_path = os.path.join(self.model_path, "llm")
        logger.info(f"Loading Chat model from {dialogpt_path}")
        self.chat_tokenizer = AutoTokenizer.from_pretrained(dialogpt_path)
        self.chat_model = AutoModelForCausalLM.from_pretrained(dialogpt_path)
        self.chat_model = self.chat_model.to(self.device)

        # TTS: Facebook MMS
        logger.info(f"Loading TTS model from {self.model_path}")
        self.tts_model = pipeline(
            "text-to-speech",
            model=os.path.join(self.model_path, "tts"),
            device=self.device,
            torch_dtype=torch.float32,
        )

    def inference(self, audio_array, sample_rate):
        """Run the ASR -> chat -> TTS cascade on a single utterance."""
        logger.info(f"Running inference with audio shape: {audio_array.shape}")
        if len(audio_array.shape) == 2:
            audio_array = audio_array.squeeze()

        # Speech-to-Text using Whisper
        logger.info(f"Running ASR with audio shape: {audio_array.shape}")
        # Process audio input
        input_features = self.asr_processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(self.device)
        # Generate transcription
        generated_ids = self.asr_model.generate(input_features)
        text = self.asr_processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0]

        # Generate response with proper attention mask
        logger.info(f"Running Chat with text: {text}")
        input_ids = self.chat_tokenizer.encode(text + self.chat_tokenizer.eos_token, return_tensors="pt")
        attention_mask = torch.ones_like(input_ids)
        chat_output = self.chat_model.generate(
            input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            max_length=1000,
            pad_token_id=self.chat_tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens (everything after the prompt)
        reply = self.chat_tokenizer.decode(chat_output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

        # Text-to-Speech using the HF pipeline
        logger.info(f"Running TTS with text: {reply}")
        tts_output = self.tts_model(reply)
        audio_array = tts_output['audio']
        tts_sample_rate = tts_output.get('sampling_rate', self.tts_sample_rate)

        # Ensure audio is in the correct format: 1D float32 in [-1.0, 1.0]
        logger.info("Ensuring audio is in correct format")
        audio_array = audio_array.astype(np.float32)
        if len(audio_array.shape) > 1:
            audio_array = audio_array.squeeze()
        audio_array = np.clip(audio_array, -1.0, 1.0)

        # Resample the TTS output to match the input rate
        if sample_rate != tts_sample_rate:
            logger.info("Resampling audio to match input rate")
            samples = len(audio_array)
            new_samples = int(samples * sample_rate / tts_sample_rate)
            audio_array = signal.resample(audio_array, new_samples)

        return {"audio": audio_array, "text": reply}

if __name__ == "__main__":
    recipe = InferenceRecipe(model_path="./models")  # Specify your cache directory here

    # Test with realistic input (silent audio)
    sr = 16000
    duration = 3
    audio = np.zeros(int(sr * duration))  # Silent input
    response = recipe.inference(audio, sr)
    print(f"Audio shape: {response['audio'].shape}, Range: [{response['audio'].min()}, {response['audio'].max()}]")
    print(f"Generated text: {response['text']}")

    # Save with explicit format
    sf.write(
        "response.wav",
        response['audio'],
        sr,
        format='WAV',
        subtype='FLOAT'
    )
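
    # Optional sketch: run the same pipeline on a real recording instead of silence.
    # "input.wav" is a hypothetical file name (not part of the original script) and is
    # assumed to be a mono recording.
    if os.path.exists("input.wav"):
        real_audio, real_sr = sf.read("input.wav", dtype="float32")
        real_response = recipe.inference(real_audio, real_sr)
        print(f"Generated text (input.wav): {real_response['text']}")
        sf.write("response_real.wav", real_response["audio"], real_sr, format="WAV", subtype="FLOAT")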