shangeth's picture
Create app.py
af42712 verified
raw
history blame
No virus
5.25 kB
import streamlit as st
import torchaudio
import io
import matplotlib.pyplot as plt
import time # Import the time module
from audio_recorder_streamlit import audio_recorder
from trainer import SpeechLLMLightning
import re
import json
import whisper
import re
from transformers import AutoProcessor
def plot_mel_spectrogram(mel_spec):
    """Render a log-mel spectrogram as a heat map in the Streamlit page.

    Args:
        mel_spec: Tensor produced by ``whisper.log_mel_spectrogram`` on any
            device; it is squeezed and copied to the CPU before plotting.
            Presumably (n_mels, frames) after squeezing -- TODO confirm.
    """
    # Use an explicit figure rather than pyplot's global state: st.pyplot()
    # with the pyplot module itself is deprecated.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        image = ax.imshow(mel_spec.squeeze().cpu().numpy(), aspect='auto', origin='lower')
        # whisper.log_mel_spectrogram returns log10-scaled, normalised values
        # of order +/-1 rather than decibels, so show one decimal and no "dB".
        fig.colorbar(image, ax=ax, format='%+.1f')
        ax.set_title('Mel Spectrogram')
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # Streamlit reruns the script on every interaction; without closing,
        # each run would leave another figure open in pyplot's registry.
        plt.close(fig)
@st.cache_resource
def _load_resources():
    """Load the SpeechLLM checkpoint, its LLM tokenizer and the HuBERT processor.

    ``st.cache_resource`` keeps a single copy per server process that every
    browser session shares, instead of one GPU copy per session. NOTE(review):
    this assumes concurrent inference on the shared model is acceptable --
    confirm for multi-user deployments.
    """
    ckpt_path = "checkpoints/pretrained_checkpoint.ckpt"
    model = SpeechLLMLightning.load_from_checkpoint(ckpt_path)
    tokenizer = model.llm_tokenizer
    # Inference only: eval mode and frozen parameters.
    model.eval()
    model.freeze()
    model.to('cuda')
    processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
    return model, tokenizer, processor


def get_or_load_model():
    """Return ``(model, tokenizer, processor)``, loading them on first use.

    The objects are also stored in ``st.session_state`` under ``'model'``,
    ``'tokenizer'`` and ``'processor'``.
    """
    required = ('model', 'tokenizer', 'processor')
    if not all(key in st.session_state for key in required):
        model, tokenizer, processor = _load_resources()
        st.session_state.model = model
        st.session_state.tokenizer = tokenizer
        st.session_state.processor = processor
    return st.session_state.model, st.session_state.tokenizer, st.session_state.processor
def extract_dictionary(input_string):
    """Extract and parse the ``{...}`` object embedded in raw model output.

    The span from the first ``{`` to the last ``}`` on one line is taken.
    It is parsed as JSON as-is. Only if that fails are two regex repairs
    applied before a second parse:
      * bare keys are quoted;
      * trailing commas are dropped.

    Args:
        input_string: Text produced by the model.

    Returns:
        The parsed ``dict`` on success. Otherwise an error *string*: either
        ``"No valid JSON found."`` or ``"Error parsing JSON: ..."``. Callers
        must check the type before indexing the result.
    """
    json_str_match = re.search(r'\{.*\}', input_string)
    if not json_str_match:
        print(input_string)
        return "No valid JSON found."
    json_str = json_str_match.group(0)
    try:
        # Try the text untouched first. The repairs below are regex-based
        # and can corrupt valid JSON, e.g. a string value containing
        # ", word:" would get quotes inserted into it.
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass  # fall through to the repaired parse
    repaired = re.sub(r'(?<=\{|\,)\s*([^\"{}\[\]\s]+)\s*:', r'"\1":', json_str)  # Fix unquoted keys
    repaired = re.sub(r',\s*([\}\]])', r'\1', repaired)  # Remove trailing commas
    try:
        return json.loads(repaired)
    except json.JSONDecodeError as e:
        return f"Error parsing JSON: {e}"
# Default prompt halves placed before and after the speech input; they are
# shown as editable defaults in the UI text areas.
pre_speech_prompt = '''Instruction:
Give me the following information about the speech [Transcript, Gender, Age, Emotion, Accent]
Input:
<speech>'''
post_speech_prompt = '''</speech>
Output:'''
def generate_response(mel, pre_speech_prompt, post_speech_prompt, model, tokenizer, max_new_tokens=2000):
    """Run the speech LLM on one utterance and return its decoded output.

    Args:
        mel: Audio feature tensor, moved to the GPU with ``.cuda()``. Despite
            the name, the app script in this file passes the HuBERT
            processor's ``input_values`` here.
        pre_speech_prompt: Prompt text placed before the speech input.
        post_speech_prompt: Prompt text placed after the speech input.
        model: Loaded ``SpeechLLMLightning``; must expose ``encode`` and
            ``llm_model``.
        tokenizer: The model's LLM tokenizer.
        max_new_tokens: Cap on generated tokens (default 2000, as before).

    Returns:
        ``(output_text, latency_ms)``. The latency covers ``generate`` only,
        not feature encoding.
    """
    output_prompt = '\n<s>'

    def _token_ids(text):
        # Plain ids: no padding, no truncation, no automatically added
        # special tokens.
        return tokenizer(text, padding="do_not_pad", return_tensors='pt',
                         truncation=False, add_special_tokens=False)["input_ids"]

    combined_embeds, _atts, _label_ids = model.encode(
        mel.cuda(),
        _token_ids(pre_speech_prompt).cuda(),
        _token_ids(post_speech_prompt).cuda(),
        _token_ids(output_prompt).cuda(),
    )

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    # .cpu() copies the ids to the host, which waits for the GPU to finish,
    # so the interval covers the full generation. [0] takes the first
    # sequence of the batch (the app submits one utterance at a time).
    out = model.llm_model.generate(
        inputs_embeds=combined_embeds,
        max_new_tokens=max_new_tokens,
    ).cpu().tolist()[0]
    latency = (time.perf_counter() - start) * 1000  # milliseconds

    output_text = tokenizer.decode(out, skip_special_tokens=True)
    return output_text, latency
def extract_prediction_values(self, input_string):
    """Pull the ``<s>{...}</s>`` span out of raw model output and parse it.

    Although this is a module-level function, it keeps the ``self`` parameter
    it was written with. ``self`` is only used for its ``extract_dictionary``
    callable. When ``self`` has none (e.g. ``None`` is passed), the
    module-level ``extract_dictionary`` is used instead.

    Args:
        self: Object providing ``extract_dictionary``, or ``None``.
        input_string: Raw model output.

    Returns:
        Whatever the parser returns for the matched span (including the
        ``<s>``/``</s>`` markers). The parser is given ``'{}'`` when no such
        span is present.
    """
    json_str_match = re.search(r'<s>\s*\{.*?\}\s*</s>', input_string)
    # Explicit None check instead of a bare except around .group().
    json_str = json_str_match.group(0) if json_str_match is not None else '{}'
    parse = getattr(self, 'extract_dictionary', None) or extract_dictionary
    return parse(json_str)
# Load model and tokenizer once and store them in session_state
model, tokenizer, processor = get_or_load_model()

# Sample rate expected by both whisper.log_mel_spectrogram and the HuBERT
# processor used below.
MODEL_SAMPLE_RATE = 16000

# Streamlit UI components
st.title("Multi-Modal Speech LLM")
st.write("Record an audio file to get its transcription and other metadata.")
pre_prompt = st.text_area("Pre Speech Prompt:", value=pre_speech_prompt, height=150)
post_prompt = st.text_area("Post Speech Prompt:", value=post_speech_prompt, height=100)

# Audio recording
audio_data = audio_recorder(sample_rate=MODEL_SAMPLE_RATE)

# Transcription process
if audio_data is not None:
    with st.spinner('Transcribing...'):
        try:
            st.audio(audio_data, format='audio/wav', start_time=0)
            # Decode the recorded bytes into a (channels, samples) tensor.
            wav_tensor, sample_rate = torchaudio.load(io.BytesIO(audio_data))
            # The recorder is *asked* for 16 kHz, but verify the decoded
            # rate: feeding any other rate silently distorts the features.
            if sample_rate != MODEL_SAMPLE_RATE:
                wav_tensor = torchaudio.functional.resample(wav_tensor, sample_rate, MODEL_SAMPLE_RATE)
            # Average the channels into a mono waveform on the GPU.
            audio = wav_tensor.to('cuda').mean(0)
            mel = whisper.log_mel_spectrogram(audio)
            plot_mel_spectrogram(mel)
            # The mel is only displayed; the model consumes HuBERT features.
            audio = processor(audio.squeeze(), return_tensors="pt", sampling_rate=MODEL_SAMPLE_RATE).input_values
            prediction, latency = generate_response(audio.cuda(), pre_prompt, post_prompt, model, tokenizer)
            pred_dict = extract_dictionary(prediction)
            if isinstance(pred_dict, dict):
                st.success('Transcription Complete')
                st.text_area("LLM Output:", value=json.dumps(pred_dict, indent=2, ensure_ascii=False),
                             height=200, max_chars=None)
            else:
                # extract_dictionary reports parse failures as a string;
                # show it together with the raw model text.
                st.error(f"Could not parse model output: {pred_dict}")
                st.text_area("Raw LLM Output:", value=prediction, height=200, max_chars=None)
            # The model runs on CUDA, so this is GPU generation time.
            st.write(f"Generation latency: {latency:.2f} ms")
        except Exception as e:
            # Top-level UI boundary: report any failure in the page.
            st.error(f"An error occurred during transcription: {e}")