from fastrtc import (
    ReplyOnPause, AdditionalOutputs, Stream,
    audio_to_bytes, aggregate_bytes_to_16bit
)
import gradio as gr
import time
import numpy as np
import torch
import os
import tempfile
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)
from gtts import gTTS
from scipy.io import wavfile

# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Step 1: Audio transcription with Whisper
def load_asr_model():
    model_id = "openai/whisper-small"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        return_timestamps=False,
        torch_dtype=torch_dtype,
        device=device,
    )

# Step 2: Text generation with a smaller LLM
def load_llm_model():
    model_id = "facebook/opt-1.3b"

    # Load tokenizer with special attention to the padding token
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Print initial configuration
    print(f"Initial pad token ID: {tokenizer.pad_token_id}, EOS token ID: {tokenizer.eos_token_id}")

    # For OPT models specifically - configure tokenizer before loading the model
    if tokenizer.pad_token is None:
        # Use a completely different token as pad token - must be done before model loading
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # Ensure pad token is really different from EOS token
        assert tokenizer.pad_token_id != tokenizer.eos_token_id, "Pad token still same as EOS token!"
        print(f"Added special PAD token with ID {tokenizer.pad_token_id} (different from EOS: {tokenizer.eos_token_id})")

    # Load model with the knowledge that the tokenizer may have been modified
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True
    )

    # Resize embeddings to match the (possibly extended) tokenizer
    model.resize_token_embeddings(len(tokenizer))

    # Make sure the model config knows about the pad token
    model.config.pad_token_id = tokenizer.pad_token_id

    # OPT models need this explicit configuration
    if hasattr(model.config, "word_embed_proj_dim"):
        model.config._remove_wrong_keys = False

    # Move model to device
    model.to(device)

    print(f"Final token setup - Pad token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
    print(f"Model config pad_token_id: {model.config.pad_token_id}")

    return model, tokenizer

# Step 3: Text-to-Speech with gTTS (Google Text-to-Speech)
def gtts_text_to_speech(text):
    """Convert text to speech using gTTS and ensure proper WAV format."""
    # Import numpy and wavfile at the function level to ensure they're available in all code paths
    import numpy as np
    from scipy.io import wavfile

    # Create absolute paths for temporary files
    temp_dir = tempfile.gettempdir()
    mp3_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.mp3")
    wav_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.wav")

    try:
        # Make sure text is not empty
        if not text or text.isspace():
            text = "I don't have a response for that."

        # Create gTTS object and save to MP3
        tts = gTTS(text=text, lang='en', slow=False)
        tts.save(mp3_filename)
        print(f"MP3 file created: {mp3_filename}, size: {os.path.getsize(mp3_filename)}")

        # Try multiple methods to convert MP3 to WAV
        wav_created = False

        # Method 1: Try ffmpeg (most reliable)
        try:
            import subprocess
            cmd = ['ffmpeg', '-y', '-i', mp3_filename, '-acodec', 'pcm_s16le', '-ar', '24000', '-ac', '1', wav_filename]
            print(f"Running ffmpeg command: {' '.join(cmd)}")
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=True
            )
            if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                print(f"WAV file successfully created with ffmpeg: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                wav_created = True
            else:
                print(f"ffmpeg ran but WAV file is missing or too small: {wav_filename}")
        except Exception as e:
            print(f"ffmpeg conversion failed: {str(e)}")

        # Method 2: Try pydub if ffmpeg failed
        if not wav_created:
            try:
                from pydub import AudioSegment
                print("Converting MP3 to WAV using pydub...")
                sound = AudioSegment.from_mp3(mp3_filename)
                sound = sound.set_frame_rate(24000).set_channels(1)
                sound.export(wav_filename, format="wav")
                if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                    print(f"WAV file successfully created with pydub: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                    wav_created = True
                else:
                    print("pydub ran but WAV file is missing or too small")
            except Exception as e:
                print(f"pydub conversion failed: {str(e)}")

        # Method 3: Direct WAV creation
        if not wav_created:
            try:
                print("Generating synthetic speech directly...")
                # Generate a simple speech-like tone pattern
                sample_rate = 24000
                duration = len(text) * 0.075  # Approximate timing
                t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
                # Create a speech-like tone with some variation
                frequencies = [220, 440, 330, 550]
                audio = np.zeros_like(t)
                for i, freq in enumerate(frequencies):
                    audio += 0.2 * np.sin(2 * np.pi * freq * t + i)
                # Add a simple attack/release envelope
                envelope = np.ones_like(t)
                attack = int(0.01 * sample_rate)
                release = int(0.1 * sample_rate)
                envelope[:attack] = np.linspace(0, 1, attack)
                envelope[-release:] = np.linspace(1, 0, release)
                audio = audio * envelope
                # Normalize and convert to int16
                audio = audio / np.max(np.abs(audio))
                audio = (audio * 32767).astype(np.int16)
                # Save as WAV
                wavfile.write(wav_filename, sample_rate, audio)
                if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
                    print(f"WAV file successfully created directly: {wav_filename}, size: {os.path.getsize(wav_filename)}")
                    wav_created = True
            except Exception as e:
                print(f"Direct WAV creation failed: {str(e)}")

        # Read the WAV file if it was created
        if wav_created:
            try:
                # Add a small delay to ensure the file is fully written
                time.sleep(0.1)
                # Read WAV file with scipy
                print(f"Reading WAV file: {wav_filename}")
                sample_rate, audio_data = wavfile.read(wav_filename)
                # Convert to the expected (1, n_samples) int16 format
                audio_data = audio_data.reshape(1, -1).astype(np.int16)
                print(f"WAV file read successfully, shape: {audio_data.shape}, sample rate: {sample_rate}")
                return (sample_rate, audio_data)
            except Exception as e:
                print(f"Error reading WAV file: {str(e)}")

        # If all else fails, generate a simple tone
        print("All methods failed. Falling back to synthetic audio tone")
        sample_rate = 24000
        duration_sec = max(1, len(text) * 0.1)
        tone_length = int(sample_rate * duration_sec)
        audio_data = np.sin(2 * np.pi * np.arange(tone_length) * 440 / sample_rate)
        audio_data = (audio_data * 32767).astype(np.int16)
        audio_data = audio_data.reshape(1, -1)
        return (sample_rate, audio_data)
    except Exception as e:
        print(f"Unexpected error in text-to-speech: {str(e)}")
        # Generate a simple tone as a last resort
        sample_rate = 24000
        audio_data = np.sin(2 * np.pi * np.arange(sample_rate) * 440 / sample_rate)
        audio_data = (audio_data * 32767).astype(np.int16)
        audio_data = audio_data.reshape(1, -1)
        return (sample_rate, audio_data)
    finally:
        # Clean up temporary files
        for filename in [mp3_filename, wav_filename]:
            try:
                if os.path.exists(filename):
                    os.remove(filename)
            except Exception as e:
                print(f"Failed to remove temporary file {filename}: {str(e)}")

# Initialize models
print("Loading ASR model...")
asr_pipeline = load_asr_model()
print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()

# Chat history management
chat_history = []

def generate_response(prompt):
    # If chat history is empty, add a system message
    if not chat_history:
        chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})

    # Add user message to history
    chat_history.append({"role": "user", "content": prompt})

    # Build full prompt from chat history
    full_prompt = ""
    for message in chat_history:
        if message["role"] == "system":
            full_prompt += f"System: {message['content']}\n"
        elif message["role"] == "user":
            full_prompt += f"User: {message['content']}\n"
        elif message["role"] == "assistant":
            full_prompt += f"Assistant: {message['content']}\n"
    full_prompt += "Assistant: "

    # Use encode_plus, which offers more control over tokenization
    encoded_input = llm_tokenizer.encode_plus(
        full_prompt,
        return_tensors="pt",
        padding=False,  # Don't pad here - we handle it manually
        add_special_tokens=True,
        return_attention_mask=True
    )

    # Extract and move tensors to device
    input_ids = encoded_input["input_ids"].to(device)
    # Create attention mask explicitly - all 1s for a non-padded sequence
    attention_mask = torch.ones_like(input_ids).to(device)

    # Print for debugging
    print(f"Input shape: {input_ids.shape}, Attention mask shape: {attention_mask.shape}")

    # Generate with very explicit parameters for OPT models
    with torch.no_grad():
        try:
            output = llm_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,  # Explicitly pass attention mask
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=llm_tokenizer.pad_token_id,  # Explicitly set pad token ID
                eos_token_id=llm_tokenizer.eos_token_id,  # Explicitly set EOS token ID
                use_cache=True,
                no_repeat_ngram_size=3,
                # Explicitly disable forced BOS/EOS tokens (OPT-specific precaution)
                forced_bos_token_id=None,
                forced_eos_token_id=None,
                num_beams=1  # Single-beam sampling, no beam search
            )
        except Exception as e:
            print(f"Error during generation: {e}")
            # Fallback with simpler parameters
            output = llm_model.generate(
                input_ids=input_ids,
                max_new_tokens=128,
                do_sample=True,
                temperature=0.7
            )

    # Decode the full sequence and keep only the text after the final "Assistant: " marker
    response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
    response_text = response_text.split("Assistant: ")[-1].strip()

    # Add assistant response to history
    chat_history.append({"role": "assistant", "content": response_text})

    # Keep history manageable: once it grows past 10 entries,
    # drop the oldest non-system message and keep the system prompt
    if len(chat_history) > 10:
        chat_history.pop(1)

    return response_text

def response(audio: tuple[int, np.ndarray]):
    # Step 1: Convert audio to float32 before passing to ASR
    sample_rate, audio_data = audio
    # Convert int16 audio to float32, normalized to [-1.0, 1.0]
    audio_float32 = audio_data.flatten().astype(np.float32) / 32768.0

    # Speech-to-Text with the correct data type
    transcript = asr_pipeline({
        "sampling_rate": sample_rate,
        "raw": audio_float32
    })
    prompt = transcript["text"]
    print(f"Transcribed: {prompt}")

    # Step 2: Generate text response
    response_text = generate_response(prompt)
    print(f"Response: {response_text}")

    # Step 3: Text-to-Speech using gTTS
    sample_rate, audio_array = gtts_text_to_speech(response_text)

    # Convert to the expected format and yield 200ms chunks
    chunk_size = int(sample_rate * 0.2)
    for i in range(0, audio_array.shape[1], chunk_size):
        chunk = audio_array[:, i:i + chunk_size]
        if chunk.size > 0:  # Ensure we don't yield empty chunks
            yield (sample_rate, chunk)

stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
)

# For testing without WebRTC
def demo():
    with gr.Blocks() as ui:
        gr.Markdown("# Local Voice Chatbot")
        audio_input = gr.Audio(sources=["microphone"], type="numpy")
        audio_output = gr.Audio()

        def process_audio(audio):
            if audio is None:
                return None
            sample_rate, audio_array = audio
            # Convert to float32 for ASR
            audio_float32 = audio_array.flatten().astype(np.float32) / 32768.0
            transcript = asr_pipeline({
                "sampling_rate": sample_rate,
                "raw": audio_float32
            })
            prompt = transcript["text"]
            print(f"Transcribed: {prompt}")
            response_text = generate_response(prompt)
            print(f"Response: {response_text}")
            sample_rate, audio_array = gtts_text_to_speech(response_text)
            return (sample_rate, audio_array[0])

        audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])

    ui.launch()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
    args = parser.parse_args()

    # Always run the Gradio demo for now (workaround for Hugging Face Spaces issues)
    demo()
    # if args.demo:
    #     demo()
    # else:
    #     # For running with FastRTC
    #     # You would need to add your FastRTC server code here
    #     pass
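
    # A possible shape for the FastRTC branch above, kept as a commented sketch so the
    # current behavior is unchanged. It assumes fastrtc's Stream exposes its built-in
    # Gradio UI as `stream.ui` (treat this as an assumption, not a confirmed API):
    #
    #     if args.demo:
    #         demo()
    #     else:
    #         # Launch FastRTC's bundled WebRTC interface for the `stream` defined above
    #         stream.ui.launch()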