shangeth committed
Commit af42712
1 Parent(s): 8972619

Create app.py

Files changed (1)
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
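# Streamlit demo for a multi-modal Speech LLM: records audio in the browser,
# plots its mel spectrogram, and queries the LLM for the transcript plus
# speaker metadata (gender, age, emotion, accent).
# Minimal run note, assuming checkpoints/pretrained_checkpoint.ckpt, the local
# `trainer` module, and a CUDA device are available:  streamlit run app.py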
import io
import json
import re
import time  # used to time LLM generation

import streamlit as st
import torchaudio
import matplotlib.pyplot as plt
import whisper
from transformers import AutoProcessor

from audio_recorder_streamlit import audio_recorder
from trainer import SpeechLLMLightning

# Plot the mel spectrogram of the recorded audio inside the app
def plot_mel_spectrogram(mel_spec):
    plt.figure(figsize=(10, 4))
    plt.imshow(mel_spec.squeeze().cpu().numpy(), aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Mel Spectrogram')
    plt.tight_layout()
    st.pyplot(plt)

# Load the model, tokenizer, and feature processor once and cache them in session_state
def get_or_load_model():
    if 'model' not in st.session_state or 'tokenizer' not in st.session_state or 'processor' not in st.session_state:
        ckpt_path = "checkpoints/pretrained_checkpoint.ckpt"
        model = SpeechLLMLightning.load_from_checkpoint(ckpt_path)
        tokenizer = model.llm_tokenizer
        model.eval()
        model.freeze()
        model.to('cuda')
        st.session_state.model = model
        st.session_state.tokenizer = tokenizer
        st.session_state.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
    return st.session_state.model, st.session_state.tokenizer, st.session_state.processor
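
# Note: Streamlit re-executes this whole script on every user interaction, so
# caching the checkpoint in session_state avoids reloading the model per rerun.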

def extract_dictionary(input_string):
    json_str_match = re.search(r'\{.*\}', input_string)
    if not json_str_match:
        print(input_string)
        return "No valid JSON found."

    json_str = json_str_match.group(0)

    json_str = re.sub(r'(?<=\{|\,)\s*([^\"{}\[\]\s]+)\s*:', r'"\1":', json_str)  # quote unquoted keys
    json_str = re.sub(r',\s*([\}\]])', r'\1', json_str)  # remove trailing commas

    try:
        data_dict = json.loads(json_str)
        return data_dict
    except json.JSONDecodeError as e:
        return f"Error parsing JSON: {str(e)}"
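
# A quick illustration (hypothetical LLM output with unquoted keys):
#   extract_dictionary('Output: {Transcript: "hi there", Gender: "male",}')
# quotes the bare keys and strips the trailing comma, yielding
#   {'Transcript': 'hi there', 'Gender': 'male'}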

pre_speech_prompt = '''Instruction:
Give me the following information about the speech [Transcript, Gender, Age, Emotion, Accent]

Input:
<speech>'''

post_speech_prompt = '''</speech>

Output:'''
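
# The sequence fed to the LLM is, conceptually,
#   pre_speech_prompt + <speech embeddings> + post_speech_prompt + '\n<s>'
# with model.encode() splicing the encoded audio between the two text prompts
# (see generate_response below).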

# Generate a response from the model and measure generation latency
def generate_response(speech_input, pre_speech_prompt, post_speech_prompt, model, tokenizer):
    output_prompt = '\n<s>'

    pre_tokenized_ids = tokenizer(pre_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
    post_tokenized_ids = tokenizer(post_speech_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]
    output_tokenized_ids = tokenizer(output_prompt, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)["input_ids"]

    # speech_input holds the processed input_values from the HuBERT processor
    combined_embeds, atts, label_ids = model.encode(speech_input.cuda(), pre_tokenized_ids.cuda(), post_tokenized_ids.cuda(), output_tokenized_ids.cuda())

    start_time = time.time()
    out = model.llm_model.generate(
        inputs_embeds=combined_embeds,
        max_new_tokens=2000,
    ).cpu().tolist()[0]
    end_time = time.time()

    latency = (end_time - start_time) * 1000  # milliseconds

    output_text = tokenizer.decode(out, skip_special_tokens=True)
    return output_text, latency
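
# The model is expected to wrap its JSON answer in <s>...</s>; the helper below
# extracts that span (falling back to '{}') before parsing it with extract_dictionary.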
def extract_prediction_values(input_string):
    json_str_match = re.search(r'<s>\s*\{.*?\}\s*</s>', input_string)
    try:
        json_str = json_str_match.group(0)
    except AttributeError:
        json_str = '{}'
    return extract_dictionary(json_str)

# Load model, tokenizer, and processor once and store them in session_state
model, tokenizer, processor = get_or_load_model()

# Streamlit UI components
st.title("Multi-Modal Speech LLM")
st.write("Record an audio file to get its transcription and other metadata.")

pre_prompt = st.text_area("Pre Speech Prompt:", value=pre_speech_prompt, height=150)
post_prompt = st.text_area("Post Speech Prompt:", value=post_speech_prompt, height=100)

# Audio recording
audio_data = audio_recorder(sample_rate=16000)

# Transcription process
if audio_data is not None:
    with st.spinner('Transcribing...'):
        try:
            # Load the recorded bytes into a waveform tensor
            audio_buffer = io.BytesIO(audio_data)
            st.audio(audio_data, format='audio/wav', start_time=0)
            wav_tensor, sample_rate = torchaudio.load(audio_buffer)
            audio = wav_tensor.mean(0)  # downmix to mono; keep on CPU for the feature extractor
            mel = whisper.log_mel_spectrogram(audio)
            plot_mel_spectrogram(mel)

            # HuBERT feature extraction for the speech encoder
            audio = processor(audio.squeeze(), return_tensors="pt", sampling_rate=16000).input_values

            # Run the model to get the transcription and metadata
            prediction, latency = generate_response(audio.cuda(), pre_prompt, post_prompt, model, tokenizer)
            pred_dict = extract_dictionary(prediction)

            # extract_dictionary may return an error string, so guard the lookup
            user_utterance = '<user>' + pred_dict['Transcript'] if isinstance(pred_dict, dict) else ''

            # Display the output and latency
            st.success('Transcription Complete')
            st.text_area("LLM Output:", value=str(pred_dict), height=200, max_chars=None)
            st.write(f"Latency: {latency:.2f} ms")

        except Exception as e:
            st.error(f"An error occurred during transcription: {e}")