Keeby-smilyai commited on
Commit
73ec700
·
verified ·
1 Parent(s): b16b776

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +201 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,203 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import (
4
+ pipeline,
5
+ AutoTokenizer,
6
+ AutoModelForCausalLM,
7
+ )
8
+ import numpy as np
9
+ import time
10
+
11
+ # --- 0. Streamlit App Configuration ---
12
+ # Set the page configuration for a cleaner look
13
+ st.set_page_config(
14
+ page_title="Nova Voice Chat (Streamlit)",
15
+ layout="centered",
16
+ initial_sidebar_state="collapsed"
17
+ )
18
+
19
+ # A custom component is needed for microphone recording in Streamlit
20
+ try:
21
+ from st_audiorec import st_audiorec
22
+ except ImportError:
23
+ st.error("Please install the st-audiorec component: `pip install st-audiorec`")
24
+ st.stop()
25
+
26
+
27
+ # --- 1. Global Model Loading (Cached) ---
28
+ @st.cache_resource
29
+ def load_models():
30
+ """Loads all models and pipes, cached globally."""
31
+ with st.spinner("Loading AI models..."):
32
+ print("Loading models...")
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+
35
+ # 1. Speech-to-Text (STT)
36
+ stt_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device=device)
37
+
38
+ # 2. Large Language Model (LLM)
39
+ model_name = "Qwen/Qwen2-0.5B-Instruct"
40
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
41
+ if tokenizer.pad_token is None:
42
+ tokenizer.pad_token = tokenizer.eos_token
43
+
44
+ llm_model = AutoModelForCausalLM.from_pretrained(
45
+ model_name,
46
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
47
+ device_map="auto"
48
+ )
49
+
50
+ # 3. Text-to-Speech (TTS)
51
+ tts_pipe = pipeline("text-to-speech", model="facebook/mms-tts-eng", device=device)
52
+ print("Models loaded.")
53
+ st.success("Models loaded!")
54
+
55
+ return stt_pipe, tokenizer, llm_model, tts_pipe
56
+
57
+ # Load the models once
58
+ try:
59
+ STT_PIPE, TOKENIZER, LLM_MODEL, TTS_PIPE = load_models()
60
+ except Exception as e:
61
+ st.error(f"Failed to load models. Please check your hardware and dependencies. Error: {e}")
62
+ st.stop()
63
+
64
+
65
+ # --- 2. State Initialization and Functions ---
66
+
67
+ def get_initial_chat_history():
68
+ """Returns the initial chat history structure."""
69
+ return [
70
+ {
71
+ "role": "system",
72
+ "content": "You are Nova, an AI assistant. You are friendly and helpful. Respond naturally, as if in a conversation."
73
+ }
74
+ ]
75
+
76
+ # Initialize session state for chat history
77
+ if 'chat_history' not in st.session_state:
78
+ st.session_state.chat_history = get_initial_chat_history()
79
+
80
+ # Placeholder for the status text
81
+ if 'status_text' not in st.session_state:
82
+ st.session_state.status_text = "I'm listening..."
83
+
84
+ # Placeholder for the audio playback element
85
+ if 'audio_to_play' not in st.session_state:
86
+ st.session_state.audio_to_play = None
87
+
88
+
89
+ def process_audio_file(wav_audio_data):
90
+ """
91
+ Handles the entire voice interaction flow: STT -> LLM -> TTS.
92
+ This function is called when a recording is finished.
93
+ """
94
+ if wav_audio_data is None:
95
+ st.session_state.status_text = "Didn't catch that. Try again."
96
+ st.session_state.audio_to_play = None
97
+ st.rerun() # Rerun to update status
98
+ return
99
+
100
+ st.session_state.status_text = "Thinking..."
101
+ st.rerun()
102
+
103
+ try:
104
+ # Save the audio data to a temporary file for the STT pipe
105
+ # The Gradio version received a file path, st_audiorec gives raw bytes.
106
+ import tempfile
107
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
108
+ tmp_file.write(wav_audio_data)
109
+ audio_filepath = tmp_file.name
110
+
111
+ # 1. Speech-to-Text (STT)
112
+ result = STT_PIPE(audio_filepath)
113
+ transcript = result["text"].strip() if result and result["text"] else ""
114
+
115
+ if not transcript:
116
+ st.session_state.status_text = "I couldn't hear anything clearly."
117
+ st.session_state.audio_to_play = None
118
+ st.rerun()
119
+ return
120
+
121
+ # 2. LLM Inference
122
+ st.session_state.chat_history.append({"role": "user", "content": transcript})
123
+
124
+ # Manage context length (keep system prompt + last 9 exchanges)
125
+ if len(st.session_state.chat_history) > 10:
126
+ st.session_state.chat_history = [st.session_state.chat_history[0]] + st.session_state.chat_history[-9:]
127
+
128
+ text = TOKENIZER.apply_chat_template(st.session_state.chat_history, tokenize=False, add_generation_prompt=True)
129
+ model_inputs = TOKENIZER([text], return_tensors="pt").to(LLM_MODEL.device)
130
+
131
+ with torch.no_grad():
132
+ generated_ids = LLM_MODEL.generate(**model_inputs, max_new_tokens=256, pad_token_id=TOKENIZER.eos_token_id)
133
+
134
+ response_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
135
+ response = TOKENIZER.decode(response_ids, skip_special_tokens=True).strip()
136
+
137
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
138
+
139
+ # 3. Text-to-Speech (TTS)
140
+ st.session_state.status_text = "Responding..."
141
+ st.rerun()
142
+
143
+ speech = TTS_PIPE(response)
144
+
145
+ # Save audio data and trigger playback
146
+ audio_data = (speech["sampling_rate"], speech["audio"].astype(np.float32))
147
+
148
+ st.session_state.audio_to_play = audio_data
149
+ st.session_state.status_text = "I'm listening..."
150
+ st.rerun()
151
+
152
+ except Exception as e:
153
+ print(f"Error in process_audio: {e}")
154
+ st.session_state.status_text = "An error occurred."
155
+ st.session_state.audio_to_play = None
156
+ st.rerun()
157
+
158
+
159
+ # --- 3. Streamlit Interface ---
160
+
161
+ st.title("Nova Voice Chat 🎤")
162
+ st.markdown("---")
163
+
164
+ # The custom microphone recording component
165
+ wav_audio_data = st_audiorec()
166
+
167
+ # When a recording is completed, the component returns the audio data,
168
+ # which triggers the processing function.
169
+ if wav_audio_data is not None:
170
+ process_audio_file(wav_audio_data)
171
+
172
+
173
+ # Status Text Display
174
+ st.markdown(
175
+ f'<h2 style="text-align: center; color: #1A73E8;">{st.session_state.status_text}</h2>',
176
+ unsafe_allow_html=True
177
+ )
178
+ st.markdown(
179
+ '<p style="text-align: center; color: #8C8C8C; font-size: 14px;">Qwen2-0.5B-Instruct</p>',
180
+ unsafe_allow_html=True
181
+ )
182
+
183
+ # Audio Playback
184
+ # Display the audio player only when there's new audio to play
185
+ if st.session_state.audio_to_play is not None:
186
+ sampling_rate, audio_array = st.session_state.audio_to_play
187
+ st.audio(audio_array, sample_rate=sampling_rate)
188
+
189
+ # Reset the audio state after playback starts (or immediately, as Streamlit reruns)
190
+ st.session_state.audio_to_play = None
191
+
192
+ # Optional: Display chat history in the sidebar
193
+ with st.sidebar:
194
+ st.subheader("Chat History")
195
+ # Display the conversation (excluding the system prompt)
196
+ for message in st.session_state.chat_history[1:]:
197
+ with st.chat_message(message["role"]):
198
+ st.write(message["content"])
199
 
200
+ if st.button("Reset Chat"):
201
+ st.session_state.chat_history = get_initial_chat_history()
202
+ st.session_state.status_text = "Chat reset. I'm listening..."
203
+ st.rerun()