import os
import logging

import numpy as np
import torch
import soundfile as sf
from scipy import signal
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    pipeline,
)

logger = logging.getLogger(__name__)


class InferenceRecipe:
    """Local speech-to-speech recipe: Whisper ASR -> causal-LM chat -> TTS.

    The model cache is expected to contain three subdirectories: ``asr``,
    ``llm`` and ``tts``.
    """

    def __init__(self, model_path="./models", device="cuda"):
        self.model_path = model_path
        self.device = device
        self.asr_processor = None
        self.asr_model = None
        self.chat_tokenizer = None
        self.chat_model = None
        self.tts_model = None
        # Fallback TTS output rate, used when the TTS pipeline does not
        # report its own sampling rate.
        self.tts_sample_rate = 22050
        self.initialize_models()

    def initialize_models(self):
        """Initialize all models from the local cache."""
        # --- ASR (Whisper) ---
        asr_path = os.path.join(self.model_path, "asr")
        logger.info(f"Loading ASR model from {asr_path}")
        self.asr_processor = WhisperProcessor.from_pretrained(asr_path, local_files_only=True)
        self.asr_model = WhisperForConditionalGeneration.from_pretrained(asr_path, local_files_only=True)
        self.asr_model = self.asr_model.to(self.device)

        # Transcribe English speech without timestamp tokens.
        self.asr_model.generation_config.no_timestamps_token_id = (
            self.asr_processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
        )
        self.asr_model.config.forced_decoder_ids = self.asr_processor.get_decoder_prompt_ids(
            language="english", task="transcribe"
        )
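        # NOTE: on newer transformers releases, forced_decoder_ids is deprecated
        # in favour of passing language/task directly to generate(), e.g.
        # self.asr_model.generate(input_features, language="english", task="transcribe").
        # The assignment above is kept for compatibility with older versions.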

        # --- Chat LLM ---
        dialogpt_path = os.path.join(self.model_path, "llm")
        logger.info(f"Loading chat model from {dialogpt_path}")
        self.chat_tokenizer = AutoTokenizer.from_pretrained(dialogpt_path, local_files_only=True)
        self.chat_model = AutoModelForCausalLM.from_pretrained(dialogpt_path, local_files_only=True)
        self.chat_model = self.chat_model.to(self.device)

        # --- TTS ---
        tts_path = os.path.join(self.model_path, "tts")
        logger.info(f"Loading TTS model from {tts_path}")
        self.tts_model = pipeline(
            "text-to-speech",
            model=tts_path,
            device=self.device,
            torch_dtype=torch.float32,
        )
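        # The "tts" directory may hold any checkpoint supported by the
        # text-to-speech pipeline task (e.g. SpeechT5 or MMS-TTS); the exact
        # model is not fixed by this recipe.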

    def inference(self, audio_array, sample_rate):
        """Run ASR -> chat -> TTS on one utterance and return reply audio and text."""
        logger.info(f"Running inference with audio shape: {audio_array.shape}")
        # Collapse (1, n) / (n, 1) input to a 1-D waveform.
        if audio_array.ndim == 2:
            audio_array = audio_array.squeeze()

        # --- ASR ---
        logger.info(f"Running ASR with audio shape: {audio_array.shape}")
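        # Whisper feature extractors are configured for 16 kHz input and do not
        # resample internally, so `audio_array` is assumed to already be at
        # 16 kHz (as in the smoke test below).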
        input_features = self.asr_processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features.to(self.device)
        generated_ids = self.asr_model.generate(input_features)
        text = self.asr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # --- Chat ---
        logger.info(f"Running chat with text: {text}")
        input_ids = self.chat_tokenizer.encode(text + self.chat_tokenizer.eos_token, return_tensors="pt")
        attention_mask = torch.ones_like(input_ids)
        chat_output = self.chat_model.generate(
            input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            max_length=1000,
            pad_token_id=self.chat_tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        reply = self.chat_tokenizer.decode(chat_output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

        # --- TTS ---
        logger.info(f"Running TTS with text: {reply}")
        # The text-to-speech pipeline returns a dict with the waveform under
        # "audio" and, for most models, its native rate under "sampling_rate".
        tts_output = self.tts_model(reply)
        audio_array = np.asarray(tts_output["audio"]).squeeze()
        tts_sample_rate = tts_output.get("sampling_rate", self.tts_sample_rate)
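        # Post-process the synthesized waveform: cast to float32, clip to
        # [-1, 1], and resample so the reply matches the caller's sample rate.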
        logger.info("Ensuring audio is in the correct format")
        audio_array = audio_array.astype(np.float32)
        audio_array = np.clip(audio_array, -1.0, 1.0)

        if sample_rate != tts_sample_rate:
            logger.info("Resampling audio to match the input rate")
            new_samples = int(len(audio_array) * sample_rate / tts_sample_rate)
            audio_array = signal.resample(audio_array, new_samples)

        # Guard against any model that still returns a multi-channel array.
        if audio_array.ndim > 1:
            audio_array = audio_array.squeeze()

        return {"audio": audio_array, "text": reply}


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    recipe = InferenceRecipe(
        model_path="./models",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Smoke test: three seconds of silence at 16 kHz.
    sr = 16000
    duration = 3
    audio = np.zeros(int(sr * duration), dtype=np.float32)
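    # To exercise the pipeline with real speech instead of silence, load a
    # mono 16 kHz recording, e.g.:
    #   audio, sr = sf.read("question.wav", dtype="float32")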
    response = recipe.inference(audio, sr)

    print(f"Audio shape: {response['audio'].shape}, Range: [{response['audio'].min()}, {response['audio'].max()}]")
    print(f"Generated text: {response['text']}")

    sf.write(
        "response.wav",
        response["audio"],
        sr,
        format="WAV",
        subtype="FLOAT",
    )