# chaty/app.py
from fastrtc import (
ReplyOnPause, AdditionalOutputs, Stream,
audio_to_bytes, aggregate_bytes_to_16bit
)
import gradio as gr
import time
import numpy as np
import torch
import os
import tempfile
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
pipeline,
AutoTokenizer,
AutoModelForCausalLM
)
from gtts import gTTS
from scipy.io import wavfile
# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Step 1: Audio transcription with Whisper
def load_asr_model():
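    """Load the Whisper-small checkpoint and wrap it in an ASR pipeline.

    Callers below pass dicts of the form {"sampling_rate": int, "raw": float32 array
    normalised to [-1.0, 1.0]} and read the transcript from the "text" key.
    """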
model_id = "openai/whisper-small"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
return pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
return_timestamps=False,
torch_dtype=torch_dtype,
device=device,
)
# Step 2: Text generation with a smaller LLM
def load_llm_model():
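    """Load the OPT-1.3B causal LM and its tokenizer.

    If the tokenizer has no pad token, a dedicated [PAD] token is added and the
    model embeddings are resized so padding is never confused with end-of-sequence.
    """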
model_id = "facebook/opt-1.3b"
# Load tokenizer with special attention to the padding token
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Print initial configuration
print(f"Initial pad token ID: {tokenizer.pad_token_id}, EOS token ID: {tokenizer.eos_token_id}")
# For OPT models specifically - configure tokenizer before loading model
if tokenizer.pad_token is None:
# Use a completely different token as pad token - must be done before model loading
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Ensure pad token is really different from EOS token
assert tokenizer.pad_token_id != tokenizer.eos_token_id, "Pad token still same as EOS token!"
print(f"Added special PAD token with ID {tokenizer.pad_token_id} (different from EOS: {tokenizer.eos_token_id})")
# Load model with the knowledge that tokenizer may have been modified
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True
)
# Resize embeddings to match tokenizer
model.resize_token_embeddings(len(tokenizer))
# CRITICAL: Make sure model config knows about the pad token
model.config.pad_token_id = tokenizer.pad_token_id
# OPT models need this explicit configuration
if hasattr(model.config, "word_embed_proj_dim"):
model.config._remove_wrong_keys = False
# Move model to device
model.to(device)
print(f"Final token setup - Pad token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
print(f"Model config pad_token_id: {model.config.pad_token_id}")
return model, tokenizer
# Step 3: Text-to-Speech with gTTS (Google Text-to-Speech)
def gtts_text_to_speech(text):
"""Convert text to speech using gTTS and ensure proper WAV format."""
# Import numpy and wavfile at the function level to ensure they're available in all code paths
import numpy as np
from scipy.io import wavfile
# Create absolute paths for temporary files
temp_dir = tempfile.gettempdir()
mp3_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.mp3")
wav_filename = os.path.join(temp_dir, f"tts_temp_{os.getpid()}_{time.time()}.wav")
try:
# Make sure text is not empty
if not text or text.isspace():
text = "I don't have a response for that."
# Create gTTS object and save to MP3
tts = gTTS(text=text, lang='en', slow=False)
tts.save(mp3_filename)
print(f"MP3 file created: {mp3_filename}, size: {os.path.getsize(mp3_filename)}")
# Try multiple methods to convert MP3 to WAV
wav_created = False
# Method 1: Try ffmpeg (most reliable)
try:
import subprocess
cmd = ['ffmpeg', '-y', '-i', mp3_filename, '-acodec', 'pcm_s16le', '-ar', '24000', '-ac', '1', wav_filename]
print(f"Running ffmpeg command: {' '.join(cmd)}")
result = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True
)
if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
print(f"WAV file successfully created with ffmpeg: {wav_filename}, size: {os.path.getsize(wav_filename)}")
wav_created = True
else:
print(f"ffmpeg ran but WAV file is missing or too small: {wav_filename}")
except Exception as e:
print(f"ffmpeg conversion failed: {str(e)}")
# Method 2: Try pydub if ffmpeg failed
if not wav_created:
try:
from pydub import AudioSegment
print("Converting MP3 to WAV using pydub...")
sound = AudioSegment.from_mp3(mp3_filename)
sound = sound.set_frame_rate(24000).set_channels(1)
sound.export(wav_filename, format="wav")
if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
print(f"WAV file successfully created with pydub: {wav_filename}, size: {os.path.getsize(wav_filename)}")
wav_created = True
else:
print(f"pydub ran but WAV file is missing or too small")
except Exception as e:
print(f"pydub conversion failed: {str(e)}")
# Method 3: Direct WAV creation
if not wav_created:
try:
print("Generating synthetic speech directly...")
# Generate a simple speech-like tone pattern
sample_rate = 24000
duration = len(text) * 0.075 # Approx timing
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
# Create a speech-like tone with some variation
frequencies = [220, 440, 330, 550]
audio = np.zeros_like(t)
for i, freq in enumerate(frequencies):
audio += 0.2 * np.sin(2 * np.pi * freq * t + i)
# Add some envelope
envelope = np.ones_like(t)
attack = int(0.01 * sample_rate)
release = int(0.1 * sample_rate)
envelope[:attack] = np.linspace(0, 1, attack)
envelope[-release:] = np.linspace(1, 0, release)
audio = audio * envelope
# Normalize and convert to int16
audio = audio / np.max(np.abs(audio))
audio = (audio * 32767).astype(np.int16)
# Save as WAV
wavfile.write(wav_filename, sample_rate, audio)
if os.path.exists(wav_filename) and os.path.getsize(wav_filename) > 100:
print(f"WAV file successfully created directly: {wav_filename}, size: {os.path.getsize(wav_filename)}")
wav_created = True
except Exception as e:
print(f"Direct WAV creation failed: {str(e)}")
# Read the WAV file if it was created
if wav_created:
try:
# Add a small delay to ensure the file is fully written
time.sleep(0.1)
# Read WAV file with scipy
print(f"Reading WAV file: {wav_filename}")
sample_rate, audio_data = wavfile.read(wav_filename)
# Convert to expected format
audio_data = audio_data.reshape(1, -1).astype(np.int16)
print(f"WAV file read successfully, shape: {audio_data.shape}, sample rate: {sample_rate}")
return (sample_rate, audio_data)
except Exception as e:
print(f"Error reading WAV file: {str(e)}")
# If all else fails, generate a simple tone
print("All methods failed. Falling back to synthetic audio tone")
sample_rate = 24000
duration_sec = max(1, len(text) * 0.1)
tone_length = int(sample_rate * duration_sec)
audio_data = np.sin(2 * np.pi * np.arange(tone_length) * 440 / sample_rate)
audio_data = (audio_data * 32767).astype(np.int16)
audio_data = audio_data.reshape(1, -1)
return (sample_rate, audio_data)
except Exception as e:
print(f"Unexpected error in text-to-speech: {str(e)}")
# Generate a simple tone as last resort
sample_rate = 24000
audio_data = np.sin(2 * np.pi * np.arange(sample_rate) * 440 / sample_rate)
audio_data = (audio_data * 32767).astype(np.int16)
audio_data = audio_data.reshape(1, -1)
return (sample_rate, audio_data)
finally:
# Clean up temporary files
for filename in [mp3_filename, wav_filename]:
try:
if os.path.exists(filename):
os.remove(filename)
except Exception as e:
print(f"Failed to remove temporary file {filename}: {str(e)}")
# Initialize models
print("Loading ASR model...")
asr_pipeline = load_asr_model()
print("Loading LLM model...")
llm_model, llm_tokenizer = load_llm_model()
# Chat history management
chat_history = []
def generate_response(prompt):
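    """Append the user prompt to the running chat history, build a plain-text
    System/User/Assistant prompt, and return the model's next reply."""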
# If chat history is empty, add a system message
if not chat_history:
chat_history.append({"role": "system", "content": "You are a helpful, friendly AI assistant. Keep your responses concise and conversational."})
# Add user message to history
chat_history.append({"role": "user", "content": prompt})
# Build full prompt from chat history
full_prompt = ""
for message in chat_history:
if message["role"] == "system":
full_prompt += f"System: {message['content']}\n"
elif message["role"] == "user":
full_prompt += f"User: {message['content']}\n"
elif message["role"] == "assistant":
full_prompt += f"Assistant: {message['content']}\n"
full_prompt += "Assistant: "
# Use encode_plus which offers more control
encoded_input = llm_tokenizer.encode_plus(
full_prompt,
return_tensors="pt",
padding=False, # Don't pad here - we'll handle it manually
add_special_tokens=True,
return_attention_mask=True
)
# Extract and move tensors to device
input_ids = encoded_input["input_ids"].to(device)
# Create attention mask explicitly - all 1s for a non-padded sequence
attention_mask = torch.ones_like(input_ids).to(device)
# Print for debugging
print(f"Input shape: {input_ids.shape}, Attention mask shape: {attention_mask.shape}")
# Generate with very explicit parameters for OPT models
with torch.no_grad():
try:
output = llm_model.generate(
input_ids=input_ids,
attention_mask=attention_mask, # Explicitly pass attention mask
max_new_tokens=128,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=llm_tokenizer.pad_token_id, # Explicitly set pad token ID
eos_token_id=llm_tokenizer.eos_token_id, # Explicitly set EOS token ID
use_cache=True,
no_repeat_ngram_size=3,
# Add these parameters specifically for OPT
forced_bos_token_id=None,
forced_eos_token_id=None,
num_beams=1 # Simple greedy decoding with temperature
)
except Exception as e:
print(f"Error during generation: {e}")
# Fallback with simpler parameters
output = llm_model.generate(
input_ids=input_ids,
max_new_tokens=128,
do_sample=True,
temperature=0.7
)
    # Decode only the newly generated tokens (everything after the prompt)
    generated_tokens = output[0][input_ids.shape[1]:]
    response_text = llm_tokenizer.decode(generated_tokens, skip_special_tokens=True)
    # Stop at the next turn marker in case the model keeps writing the dialogue
    response_text = response_text.split("User:")[0].strip()
# Add assistant response to history
chat_history.append({"role": "assistant", "content": response_text})
    # Keep history manageable: once it exceeds 10 entries, drop the oldest
    # non-system message (index 0 is the system prompt)
    if len(chat_history) > 10:
        chat_history.pop(1)
return response_text
def response(audio: tuple[int, np.ndarray]):
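    """fastrtc handler: transcribe the incoming audio, generate a text reply,
    synthesise it with gTTS, and yield the result in ~200 ms audio chunks."""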
# Step 1: Convert audio to float32 before passing to ASR
sample_rate, audio_data = audio
# Convert int16 audio to float32
audio_float32 = audio_data.flatten().astype(np.float32) / 32768.0 # Normalize to [-1.0, 1.0]
# Speech-to-Text with correct data type
transcript = asr_pipeline({
"sampling_rate": sample_rate,
"raw": audio_float32
})
prompt = transcript["text"]
print(f"Transcribed: {prompt}")
# Step 2: Generate text response
response_text = generate_response(prompt)
print(f"Response: {response_text}")
# Step 3: Text-to-Speech using gTTS
sample_rate, audio_array = gtts_text_to_speech(response_text)
# Convert to expected format and yield chunks
chunk_size = int(sample_rate * 0.2) # 200ms chunks
for i in range(0, audio_array.shape[1], chunk_size):
chunk = audio_array[:, i:i+chunk_size]
if chunk.size > 0: # Ensure we don't yield empty chunks
yield (sample_rate, chunk)
stream = Stream(
modality="audio",
mode="send-receive",
handler=ReplyOnPause(response),
)
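# NOTE: the Stream above is only constructed here; serving it over WebRTC would
# require additional FastRTC server wiring. The __main__ block below launches the
# Gradio demo instead.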
# For testing without WebRTC
def demo():
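    """Gradio fallback UI: record from the microphone, run the same
    ASR -> LLM -> TTS pipeline, and play back the synthesised reply."""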
with gr.Blocks() as demo:
gr.Markdown("# Local Voice Chatbot")
audio_input = gr.Audio(sources=["microphone"], type="numpy")
audio_output = gr.Audio()
def process_audio(audio):
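            """Run one turn of the pipeline on a recorded (sample_rate, array) clip."""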
if audio is None:
return None
sample_rate, audio_array = audio
# Convert to float32 for ASR
audio_float32 = audio_array.flatten().astype(np.float32) / 32768.0
transcript = asr_pipeline({
"sampling_rate": sample_rate,
"raw": audio_float32
})
prompt = transcript["text"]
print(f"Transcribed: {prompt}")
response_text = generate_response(prompt)
print(f"Response: {response_text}")
sample_rate, audio_array = gtts_text_to_speech(response_text)
return (sample_rate, audio_array[0])
audio_input.change(process_audio, inputs=[audio_input], outputs=[audio_output])
demo.launch()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--demo", action="store_true", help="Run Gradio demo instead of WebRTC")
args = parser.parse_args()
    # The WebRTC path is currently disabled (issues when hosting on Hugging Face),
    # so the Gradio demo is always launched regardless of --demo.
    demo()
    # To serve the `stream` object over WebRTC instead, the FastRTC server code
    # would go here (e.g. behind `if not args.demo:`).