import io
import json
import re
import time

import matplotlib.pyplot as plt
import streamlit as st
import torchaudio
import whisper
from audio_recorder_streamlit import audio_recorder
from transformers import AutoProcessor

from trainer import SpeechLLMLightning

# Function to plot the log-mel spectrogram of the recorded audio in the app
def plot_mel_spectrogram(mel_spec):
    plt.figure(figsize=(10, 4))
    plt.imshow(mel_spec.squeeze().cpu().numpy(), aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Mel Spectrogram')
    plt.tight_layout()
    st.pyplot(plt)
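
# Function to load the model, tokenizer, and processor, cached in session_state
# so they are only initialised once across Streamlit re-runs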
def get_or_load_model():
    if 'model' not in st.session_state or 'tokenizer' not in st.session_state or 'processor' not in st.session_state:
        ckpt_path = "checkpoints/pretrained_checkpoint.ckpt"
        model = SpeechLLMLightning.load_from_checkpoint(ckpt_path)
        tokenizer = model.llm_tokenizer
        model.eval()
        model.freeze()
        model.to('cuda')
        st.session_state.model = model
        st.session_state.tokenizer = tokenizer
        st.session_state.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
    return st.session_state.model, st.session_state.tokenizer, st.session_state.processor
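
# Function to extract the first JSON object from the model's raw output,
# repairing unquoted keys and trailing commas before parsing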
def extract_dictionary(input_string):
    json_str_match = re.search(r'\{.*\}', input_string)
    if not json_str_match:
        print(input_string)
        return "No valid JSON found."
    json_str = json_str_match.group(0)
    json_str = re.sub(r'(?<=\{|\,)\s*([^\"{}\[\]\s]+)\s*:', r'"\1":', json_str)  # Fix unquoted keys
    json_str = re.sub(r',\s*([\}\]])', r'\1', json_str)  # Remove trailing commas
    try:
        data_dict = json.loads(json_str)
        return data_dict
    except json.JSONDecodeError as e:
        return f"Error parsing JSON: {str(e)}"
pre_speech_prompt = '''Instruction:
Give me the following information about the speech [Transcript, Gender, Age, Emotion, Accent]
Input:
<speech>'''

post_speech_prompt = '''</speech>
Output:'''

# Function to generate a response from the model
def generate_response(mel, pre_speech_prompt, post_speech_prompt, model, tokenizer):
    output_prompt = '\n<s>'
    pre_tokenized_ids = tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
    post_tokenized_ids = tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
    output_tokenized_ids = tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]

    combined_embeds, atts, label_ids = model.encode(mel.cuda(), pre_tokenized_ids.cuda(), post_tokenized_ids.cuda(), output_tokenized_ids.cuda())

    start_time = time.time()  # Record start time
    out = model.llm_model.generate(
        inputs_embeds=combined_embeds,
        max_new_tokens=2000,
    ).cpu().tolist()[0]
    end_time = time.time()  # Record end time
    latency = (end_time - start_time) * 1000  # Latency in milliseconds

    output_text = tokenizer.decode(out, skip_special_tokens=True)
    return output_text, latency
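
# Function to pull the JSON between the <s> ... </s> markers out of a full
# generation, falling back to an empty dict when no match is found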
def extract_prediction_values(input_string):
    json_str_match = re.search(r'<s>\s*\{.*?\}\s*</s>', input_string)
    json_str = json_str_match.group(0) if json_str_match else '{}'
    return extract_dictionary(json_str)

# Load model and tokenizer once and store them in session_state
model, tokenizer, processor = get_or_load_model()

# Streamlit UI components
st.title("Multi-Modal Speech LLM")
st.write("Record an audio file to get its transcription and other metadata.")
pre_prompt = st.text_area("Pre Speech Prompt:", value=pre_speech_prompt, height=150)
post_prompt = st.text_area("Post Speech Prompt:", value=post_speech_prompt, height=100)

# Audio recording
audio_data = audio_recorder(sample_rate=16000)

# Transcription process
if audio_data is not None:
    with st.spinner('Transcribing...'):
        try:
            # Load the recorded bytes into a waveform tensor
            audio_buffer = io.BytesIO(audio_data)
            st.audio(audio_data, format='audio/wav', start_time=0)
            wav_tensor, sample_rate = torchaudio.load(audio_buffer)
            wav_tensor = wav_tensor.to('cuda')
            audio = wav_tensor.mean(0)  # Downmix to mono
            mel = whisper.log_mel_spectrogram(audio)
            plot_mel_spectrogram(mel)
            # The HuBERT processor expects CPU numpy input, not a CUDA tensor
            audio = processor(audio.squeeze().cpu().numpy(), return_tensors="pt", sampling_rate=16000).input_values
            # Process audio to get transcription and metadata
            prediction, latency = generate_response(audio.cuda(), pre_prompt, post_prompt, model, tokenizer)
            pred_dict = extract_dictionary(prediction)
            if isinstance(pred_dict, dict):
                user_utterance = '<user>' + pred_dict.get('Transcript', '')

            # Display the transcription and latency
            st.success('Transcription Complete')
            st.text_area("LLM Output:", value=json.dumps(pred_dict, indent=2) if isinstance(pred_dict, dict) else str(pred_dict), height=200, max_chars=None)
            st.write(f"Inference latency: {latency:.2f} ms")
        except Exception as e:
            st.error(f"An error occurred during transcription: {e}")