import os
import queue
import time

import av
import numpy as np
import requests
import streamlit as st
import whisper
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from requests.adapters import HTTPAdapter
from streamlit_webrtc import RTCConfiguration, WebRtcMode, webrtc_streamer
from urllib3.util.retry import Retry  # requests.packages.urllib3 is a deprecated alias

# Load environment variables
load_dotenv()

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "audio_buffer" not in st.session_state:
    st.session_state.audio_buffer = queue.Queue()
if "recording" not in st.session_state:
    st.session_state.recording = False
if "webrtc_ctx" not in st.session_state:
    st.session_state.webrtc_ctx = None

# Prompt template
PROMPT_TEMPLATE = """
[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
Act as a compassionate therapist and provide empathetic responses using
therapeutic techniques. Always respond in Darija unless specifically asked
otherwise.

Previous conversation:
{chat_history}

User message: {question}

Context: {context} [/INST]
"""

# Setup retry strategy for direct HTTP calls to the Hugging Face API.
# Note: HuggingFaceEndpoint manages its own HTTP client internally, so this
# session is not passed to the LLM below; it is kept for any direct
# requests-based calls you may add.
retry_strategy = Retry(
    total=3,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
)
session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retry_strategy))


# Cache the heavyweight speech model so Streamlit reruns don't reload it.
@st.cache_resource
def load_whisper_model():
    return whisper.load_model("base")


whisper_model = load_whisper_model()

llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    temperature=0.7,
    do_sample=True,
    return_full_text=False,
    max_new_tokens=2048,
    top_p=0.9,
    repetition_penalty=1.2,
    huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
)

# Setup memory and conversation chain
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
)

embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en",
)

vectorstore = FAISS.from_texts(
    ["Initial therapeutic context"],
    embeddings,
)

qa_prompt = PromptTemplate(
    template=PROMPT_TEMPLATE,
    input_variables=["context", "chat_history", "question"],
)

conversation_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectorstore.as_retriever(),
    memory=memory,
    combine_docs_chain_kwargs={"prompt": qa_prompt},
    return_source_documents=False,
    chain_type="stuff",
)
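
# Optional smoke test (a minimal sketch; assumes HUGGINGFACE_API_TOKEN is set
# and the Inference API endpoint is reachable). Uncomment to verify the chain
# end-to-end before launching the full UI:
#
#     print(conversation_chain({"question": "كيف داير؟"})["answer"])  # "How are you?"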
def make_audio_frame_callback(audio_buffer: queue.Queue):
    """Build a frame callback bound to a queue.

    The callback runs in a streamlit_webrtc worker thread where
    st.session_state is not available, so the queue is captured by
    closure instead of being looked up inside the callback.
    """

    def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
        audio = frame.to_ndarray().flatten()
        audio_buffer.put(audio)
        return frame

    return audio_frame_callback


def get_ai_response(user_input: str) -> str:
    max_retries = 3
    for attempt in range(max_retries):
        try:
            if not user_input or len(user_input.strip()) == 0:
                return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."

            # Truncate overly long inputs to keep the request within limits
            if len(user_input) > 512:
                user_input = user_input[:512]

            response = conversation_chain({"question": user_input})

            if not response:
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # exponential backoff
                    continue
                return "عذراً، كاين مشكل. حاول مرة أخرى."

            return response["answer"]

        except requests.exceptions.HTTPError:
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
                continue
            return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
        except Exception as e:
            st.error(f"Error: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
                continue
            return "عذراً، كاين شي مشكل. حاول مرة أخرى."


def process_message(user_input: str) -> None:
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.spinner("جاري التفكير..."):
        ai_response = get_ai_response(user_input)
    if ai_response:
        st.session_state.messages.append(
            {"role": "assistant", "content": ai_response}
        )


def main():
    st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
    st.title("Darija AI Therapist 🧠")
    st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")

    col1, col2 = st.columns([9, 1])
    with col1:
        user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
    with col2:
        mic_icon = "🛑" if st.session_state.recording else "🎤"
        if st.button(mic_icon):
            st.session_state.recording = not st.session_state.recording
            if st.session_state.recording:
                # Start a fresh buffer for the new recording
                st.session_state.audio_buffer = queue.Queue()

    if st.session_state.recording:
        # The component must be rendered on every rerun while recording,
        # not only on the rerun in which the button was clicked.
        st.session_state.webrtc_ctx = webrtc_streamer(
            key="speech-to-text",
            mode=WebRtcMode.SENDONLY,
            audio_receiver_size=256,
            rtc_configuration=RTCConfiguration(
                {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
            ),
            media_stream_constraints={"video": False, "audio": True},
            async_processing=True,
            audio_frame_callback=make_audio_frame_callback(
                st.session_state.audio_buffer
            ),
        )
        st.info("🎙️ Recording...")
    elif st.session_state.webrtc_ctx is not None:
        # Recording just stopped: drain the queue and transcribe.
        st.info("🔄 Processing audio...")
        audio_frames = []
        while not st.session_state.audio_buffer.empty():
            audio_frames.append(st.session_state.audio_buffer.get())

        if audio_frames:
            audio_data = np.concatenate(audio_frames, axis=0).flatten()
            # Whisper expects float32 samples in [-1.0, 1.0]; WebRTC frames
            # arrive as 16-bit PCM, so normalize rather than cast to int16.
            audio_data = audio_data.astype(np.float32) / 32768.0
            # Assumption: the stream is 48 kHz mono. Whisper expects 16 kHz,
            # so decimate by 3 as a crude resample (a proper low-pass
            # resampler would be more accurate).
            audio_data = audio_data[::3]
            result = whisper_model.transcribe(audio_data, fp16=False)
            text = result.get("text", "")
            if text:
                process_message(text)
            else:
                st.warning("ما فهمتش الصوت. حاول مرة أخرى.")
            st.session_state.audio_buffer = queue.Queue()
        else:
            st.warning("ما تسجلش الصوت. حاول مرة أخرى.")

        # The component stops streaming once it is no longer rendered;
        # just drop the stale context reference.
        st.session_state.webrtc_ctx = None

    if user_input:
        process_message(user_input)

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])


if __name__ == "__main__":
    main()
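
# ---------------------------------------------------------------------------
# Running the app (a sketch of a typical setup; the package names below are
# inferred from the imports above and versions may need pinning):
#
#   pip install streamlit streamlit-webrtc langchain langchain-community \
#       faiss-cpu openai-whisper av python-dotenv sentence-transformers requests
#
#   # .env file next to this script:
#   HUGGINGFACE_API_TOKEN=<your Hugging Face token>
#
#   streamlit run app.py
# ---------------------------------------------------------------------------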