import numpy as np |
import torch |
import librosa |
from transformers import AutoTokenizer, AutoModelForCausalLM, WhisperProcessor, WhisperForConditionalGeneration, pipeline |
import soundfile as sf |
import os |
import logging |
logger = logging.getLogger(__name__) |
class InferenceRecipe:
    """Speech-in, speech-out pipeline: ASR -> chat model -> TTS.

    On construction the three models are loaded from the ``asr``, ``llm`` and
    ``tts`` sub-directories of ``model_path`` and placed on ``device``.
    :meth:`inference` transcribes an audio clip, generates a text reply and
    synthesises that reply as audio at the caller's sampling rate.
    """

    # Fallback ASR input rate (Hz), used only when the loaded processor does
    # not expose ``feature_extractor.sampling_rate``.
    _DEFAULT_ASR_SAMPLE_RATE = 16000

    def __init__(self, model_path='./models', device='cuda'):
        """Store the configuration and eagerly load all models.

        Args:
            model_path: Directory holding the ``asr``, ``llm`` and ``tts``
                model sub-directories.
            device: Device handed to ``.to()`` and ``pipeline(device=...)``.
        """
        self.device = device
        self.asr_processor = None
        self.asr_model = None
        self.chat_tokenizer = None
        self.chat_model = None
        self.tts_model = None
        # Fallback TTS output rate (Hz) for pipelines that do not report a
        # ``sampling_rate`` next to the audio they return.
        self.tts_sample_rate = 22050
        self.model_path = model_path
        self.initialize_models()

    def initialize_models(self):
        """Initialize models from local cache"""
        asr_path = os.path.join(self.model_path, 'asr')
        logger.info("Loading ASR model from %s", asr_path)
        self.asr_processor = WhisperProcessor.from_pretrained(
            asr_path, local_files_only=True
        )
        self.asr_model = WhisperForConditionalGeneration.from_pretrained(
            asr_path, local_files_only=True
        )
        self.asr_model = self.asr_model.to(self.device)
        self.asr_model.generation_config.no_timestamps_token_id = (
            self.asr_processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
        )
        self.asr_model.config.forced_decoder_ids = (
            self.asr_processor.get_decoder_prompt_ids(
                language="english", task="transcribe"
            )
        )

        dialogpt_path = os.path.join(self.model_path, "llm")
        logger.info("Loading Chat model from %s", dialogpt_path)
        # local_files_only mirrors the ASR loading above so that no model is
        # ever looked up remotely.
        self.chat_tokenizer = AutoTokenizer.from_pretrained(
            dialogpt_path, local_files_only=True
        )
        self.chat_model = AutoModelForCausalLM.from_pretrained(
            dialogpt_path, local_files_only=True
        )
        self.chat_model = self.chat_model.to(self.device)

        tts_path = os.path.join(self.model_path, "tts")
        logger.info("Loading TTS model from %s", tts_path)
        self.tts_model = pipeline(
            "text-to-speech",
            model=tts_path,
            device=self.device,
            torch_dtype=torch.float32,
            # Same offline-only loading as the ASR and chat models above.
            model_kwargs={"local_files_only": True},
        )

    def inference(self, audio_array, sample_rate):
        """Transcribe a clip, generate a reply and synthesise it as audio.

        Args:
            audio_array: Input waveform. A 2-D array is squeezed first.
            sample_rate: Sampling rate of ``audio_array`` in Hz. The returned
                audio is resampled to this rate.

        Returns:
            A dict with ``"audio"`` (squeezed float32 array clipped to
            ``[-1, 1]``) and ``"text"`` (the generated reply).
        """
        audio_array = np.asarray(audio_array)
        logger.info("Running inference with audio shape: %s", audio_array.shape)
        if audio_array.ndim == 2:
            audio_array = np.atleast_1d(audio_array.squeeze())

        text = self._transcribe(audio_array, sample_rate)
        reply = self._chat(text)
        return {"audio": self._synthesize(reply, sample_rate), "text": reply}

    def _transcribe(self, audio_array, sample_rate):
        """Return the ASR transcript of ``audio_array`` (at ``sample_rate`` Hz)."""
        asr_rate = getattr(
            getattr(self.asr_processor, "feature_extractor", None),
            "sampling_rate",
            self._DEFAULT_ASR_SAMPLE_RATE,
        )
        if sample_rate != asr_rate:
            # The processor is always told ``asr_rate``, so the samples have
            # to be converted to that rate before they are handed over.
            logger.info(
                "Resampling input audio from %s Hz to %s Hz for ASR",
                sample_rate,
                asr_rate,
            )
            audio_array = self._resample(audio_array, sample_rate, asr_rate)

        logger.info("Running ASR with audio shape: %s", audio_array.shape)
        input_features = self.asr_processor(
            audio_array,
            sampling_rate=asr_rate,
            return_tensors="pt",
        ).input_features.to(self.device)
        generated_ids = self.asr_model.generate(input_features)
        return self.asr_processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )[0]

    def _chat(self, text):
        """Return the chat model's reply to ``text``."""
        logger.info("Running Chat with text: %s", text)
        input_ids = self.chat_tokenizer.encode(
            text + self.chat_tokenizer.eos_token, return_tensors="pt"
        )
        attention_mask = torch.ones_like(input_ids)
        chat_output = self.chat_model.generate(
            input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            max_length=1000,
            pad_token_id=self.chat_tokenizer.eos_token_id,
        )
        # Skip the positions occupied by the prompt so that only the newly
        # generated tokens are decoded.
        new_tokens = chat_output[0, input_ids.shape[-1]:]
        return self.chat_tokenizer.decode(new_tokens, skip_special_tokens=True)

    def _synthesize(self, reply, sample_rate):
        """Return ``reply`` as float32 audio in ``[-1, 1]`` at ``sample_rate`` Hz."""
        logger.info("Running TTS with text: %s", reply)
        tts_output = self.tts_model(reply)
        # Prefer the rate reported with the audio; the attribute is only a
        # fallback for outputs that carry no ``sampling_rate`` entry.
        tts_rate = tts_output.get("sampling_rate") or self.tts_sample_rate

        logger.info("Ensuring audio is in correct format")
        audio = np.asarray(tts_output['audio'], dtype=np.float32)

        # Squeeze BEFORE resampling: with a (1, N) waveform the sample count
        # would otherwise be read from the singleton axis.
        logger.info("Ensuring audio is 1D")
        if audio.ndim > 1:
            audio = audio.squeeze()
        audio = np.atleast_1d(audio)

        if sample_rate != tts_rate:
            logger.info("Resampling audio to match input rate")
            audio = self._resample(audio, tts_rate, sample_rate)

        # Clip and cast last: resampling may overshoot the [-1, 1] range and
        # may hand back a wider dtype.
        return np.clip(audio, -1.0, 1.0).astype(np.float32)

    @staticmethod
    def _resample(audio, src_rate, dst_rate):
        """Resample ``audio`` along its last axis from ``src_rate`` to ``dst_rate``.

        NOTE(review): assumes the last axis is time - confirm for any
        multi-channel layout.
        """
        if src_rate == dst_rate or audio.shape[-1] == 0:
            return audio
        from scipy import signal
        new_length = max(1, int(audio.shape[-1] * dst_rate / src_rate))
        return signal.resample(audio, new_length, axis=-1)
if __name__ == "__main__":
    # Smoke test: feed three seconds of silence through the pipeline, report
    # what came back, and save the synthesised reply as a float WAV file.
    sample_rate = 16000
    seconds = 3
    silence = np.zeros(int(sample_rate * seconds))

    engine = InferenceRecipe(model_path="./models")
    response = engine.inference(silence, sample_rate)

    print(f"Audio shape: {response['audio'].shape}, Range: [{response['audio'].min()}, {response['audio'].max()}]")
    print(f"Generated text: {response['text']}")

    # Written at the input rate, which is the rate handed to inference().
    sf.write(
        file="response.wav",
        data=response['audio'],
        samplerate=sample_rate,
        subtype='FLOAT',
        format='WAV',
    )