Keeby-smilyai committed on
Commit
a9da23a
·
verified ·
1 Parent(s): 41c0043

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +61 -54
src/streamlit_app.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  # --- ENVIRONMENT SETUP ---
4
  # Set cache and config directories to writable paths
5
  os.environ['HF_HOME'] = '/tmp/huggingface_cache'
6
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
7
  os.environ['STREAMLIT_CONFIG_DIR'] = '/tmp/.streamlit' # Fix permission error
8
 
9
  import streamlit as st
@@ -30,12 +29,13 @@ st.set_page_config(
30
  # --- CSS STYLING ---
31
  st.markdown("""
32
  <style>
33
- /* Hide Streamlit header, footer, and menu */
34
  #MainMenu {visibility: hidden;}
35
  footer {visibility: hidden;}
36
  header {visibility: hidden;}
37
- .st-emotion-cache-1yfpk7 {display: none;}
38
- .st-emotion-cache-1a2p8c {display: none;}
 
39
  /* Full-screen body */
40
  body {
41
  background-color: #f0f2f6;
@@ -137,59 +137,63 @@ def get_session_state():
137
 
138
  session_state = get_session_state()
139
 
140
- # --- MODEL LOADING ---
141
- def load_models():
142
- if not st.session_state.models_loaded:
143
- # Hide spinner for seamless experience (CSS hides it anyway)
144
- with st.spinner(text=""):
145
- st.session_state.models = {}
146
- device = 0 if torch.cuda.is_available() else -1
147
- dtype = torch.float16 if device == 0 else torch.float32
148
-
149
- # Speech-to-Text: Load Whisper manually to ignore mismatched sizes
150
- whisper_model = WhisperForConditionalGeneration.from_pretrained(
151
- "openai/whisper-small",
152
- device_map="auto" if device == 0 else None,
153
- dtype=dtype,
154
- ignore_mismatched_sizes=True,
155
- low_cpu_mem_usage=True
156
- )
157
- whisper_tokenizer = AutoTokenizer.from_pretrained("openai/whisper-small")
158
- whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
159
- st.session_state.models['stt'] = pipeline(
160
- "automatic-speech-recognition",
161
- model=whisper_model,
162
- tokenizer=whisper_tokenizer,
163
- feature_extractor=whisper_feature_extractor,
164
- device=device
165
- )
166
 
167
- # Large Language Model: Load Qwen manually
168
- qwen_model = AutoModelForCausalLM.from_pretrained(
169
- "Qwen/Qwen2-0.5B-Instruct",
170
- device_map="auto" if device == 0 else None,
171
- dtype=dtype,
172
- ignore_mismatched_sizes=True,
173
- low_cpu_mem_usage=True,
174
- trust_remote_code=True
175
- )
176
- qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)
177
- st.session_state.models['llm'] = pipeline(
178
- "text-generation",
179
- model=qwen_model,
180
- tokenizer=qwen_tokenizer,
181
- device=device
182
- )
 
 
 
 
 
183
 
184
- # Text-to-Speech: Load TTS manually if needed, but usually no mismatch
185
- st.session_state.models['tts'] = pipeline(
186
- "text-to-speech", model="nineninesix/kani-tts-450m-0.1-pt", device=device
187
- )
188
- st.session_state.models_loaded = True
189
-
190
- return st.session_state.models['stt'], st.session_state.models['llm'], st.session_state.models['tts']
191
 
192
- stt_pipe, llm_pipe, tts_pipe = load_models()
 
 
 
193
 
194
  # --- CORE FUNCTIONS ---
195
  def transcribe_audio(audio_bytes):
@@ -246,6 +250,9 @@ def update_ui_status(status):
246
  circle_class = status
247
  circle_placeholder.markdown(f'<div class="voice-circle {circle_class}"></div>', unsafe_allow_html=True)
248
 
 
 
 
249
  # --- APP LOGIC: Auto mode ---
250
  # Always listening - no chat history shown
251
 
 
3
  # --- ENVIRONMENT SETUP ---
4
  # Set cache and config directories to writable paths
5
  os.environ['HF_HOME'] = '/tmp/huggingface_cache'
 
6
  os.environ['STREAMLIT_CONFIG_DIR'] = '/tmp/.streamlit' # Fix permission error
7
 
8
  import streamlit as st
 
29
  # --- CSS STYLING ---
30
  st.markdown("""
31
  <style>
32
+ /* Hide Streamlit header, footer, and menu - more robust selectors */
33
  #MainMenu {visibility: hidden;}
34
  footer {visibility: hidden;}
35
  header {visibility: hidden;}
36
+ /* Hide Streamlit chrome via data-testid attribute selectors (more stable than version-specific emotion-cache class names) */
37
+ [data-testid="stHeader"] {display: none;}
38
+ [data-testid="stFooter"] {display: none;}
39
  /* Full-screen body */
40
  body {
41
  background-color: #f0f2f6;
 
137
 
138
  session_state = get_session_state()
139
 
140
+ # --- MODEL LOADING WITH CACHE ---
141
+ @st.cache_resource
142
+ def load_stt_model():
143
+ device = 0 if torch.cuda.is_available() else -1
144
+ dtype = torch.float16 if device == 0 else torch.float32
145
+
146
+ # Speech-to-Text: Load Whisper manually to ignore mismatched sizes
147
+ whisper_model = WhisperForConditionalGeneration.from_pretrained(
148
+ "openai/whisper-small",
149
+ device_map="auto" if device == 0 else None,
150
+ dtype=dtype,
151
+ ignore_mismatched_sizes=True,
152
+ low_cpu_mem_usage=True
153
+ )
154
+ whisper_tokenizer = AutoTokenizer.from_pretrained("openai/whisper-small")
155
+ whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
156
+ return pipeline(
157
+ "automatic-speech-recognition",
158
+ model=whisper_model,
159
+ tokenizer=whisper_tokenizer,
160
+ feature_extractor=whisper_feature_extractor,
161
+ device=device
162
+ )
 
 
 
163
 
164
+ @st.cache_resource
165
+ def load_llm_model():
166
+ device = 0 if torch.cuda.is_available() else -1
167
+ dtype = torch.float16 if device == 0 else torch.float32
168
+
169
+ # Large Language Model: Load Qwen manually
170
+ qwen_model = AutoModelForCausalLM.from_pretrained(
171
+ "Qwen/Qwen2-0.5B-Instruct",
172
+ device_map="auto" if device == 0 else None,
173
+ dtype=dtype,
174
+ ignore_mismatched_sizes=True,
175
+ low_cpu_mem_usage=True,
176
+ trust_remote_code=True
177
+ )
178
+ qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)
179
+ return pipeline(
180
+ "text-generation",
181
+ model=qwen_model,
182
+ tokenizer=qwen_tokenizer,
183
+ device=device
184
+ )
185
 
186
+ @st.cache_resource
187
+ def load_tts_model():
188
+ device = 0 if torch.cuda.is_available() else -1
189
+ return pipeline(
190
+ "text-to-speech", model="nineninesix/kani-tts-450m-0.1-pt", device=device
191
+ )
 
192
 
193
+ # Load models using cache
194
+ stt_pipe = load_stt_model()
195
+ llm_pipe = load_llm_model()
196
+ tts_pipe = load_tts_model()
197
 
198
  # --- CORE FUNCTIONS ---
199
  def transcribe_audio(audio_bytes):
 
250
  circle_class = status
251
  circle_placeholder.markdown(f'<div class="voice-circle {circle_class}"></div>', unsafe_allow_html=True)
252
 
253
+ # Debug: Temporary message to confirm rendering (remove after testing)
254
+ st.write("App loaded - if you see this, basic rendering works. Circle should appear below.")
255
+
256
  # --- APP LOGIC: Auto mode ---
257
  # Always listening - no chat history shown
258