Spaces:
Running
Running
import asyncio
import html
import re  # for extracting citation IDs
import time
import traceback
from typing import List, Dict, Any

import nest_asyncio
import streamlit as st
# --- Configuration and Service Initialization ---
# Project modules are imported under a guard so that a failed import is
# reported in the page (and on stdout) before the script is stopped.
try:
    print("App: Loading config...")
    import config
    print("App: Loading utils...")
    from utils import clean_source_text
    print("App: Loading services...")
    from services.retriever import init_retriever, get_retriever_status
    from services.openai_service import init_openai_client, get_openai_status
    print("App: Loading RAG processor...")
    from rag_processor import execute_validate_generate_pipeline, PIPELINE_VALIDATE_GENERATE_GPT4O
    print("App: Imports successful.")
except Exception as setup_err:
    # Import failures get their own wording; any other exception raised while
    # the modules load is reported as a generic setup failure.
    if isinstance(setup_err, ImportError):
        fatal_msg = f"Fatal Error: Module import failed. {setup_err}"
    else:
        fatal_msg = f"Fatal Error during initial setup: {setup_err}"
    st.error(fatal_msg, icon="🚨")
    traceback.print_exc()
    st.stop()
nest_asyncio.apply()

# --- Streamlit Page Configuration ---
# Page config is set before anything else is drawn: the service-initialization
# error path below calls st.error, and Streamlit documents set_page_config as
# the first Streamlit command of a page.
st.set_page_config(page_title="Divrey Yoel AI Chat (GPT-4o Gen)", layout="wide")

# --- Initialize Required Services ---
print("App: Initializing services...")
try:
    retriever_ready_init, retriever_msg_init = init_retriever()
    openai_ready_init, openai_msg_init = init_openai_client()
    print("App: Service initialization calls complete.")
except Exception as init_err:
    # Deliberately not fatal: the script carries on after reporting the error.
    st.error(f"Error during service initialization: {init_err}", icon="🔥")
    traceback.print_exc()

# --- Page Styling and Header ---
st.markdown("""<style> /* ... Keep existing styles ... */ </style>""", unsafe_allow_html=True)
st.markdown("<h1 class='rtl-text'> דברות קודש - חיפוש ועיון</h1>", unsafe_allow_html=True)
st.markdown("<p class='rtl-text'>מבוסס על ספרי דברי יואל מסאטמאר זצוק'ל זי'ע - אחזור מידע חכם (RAG)</p>", unsafe_allow_html=True)
st.markdown("<p class='rtl-text' style='font-size: 0.9em; color: #555;'>תהליך: אחזור -> אימות (GPT-4o) -> יצירה (GPT-4o)</p>", unsafe_allow_html=True)
# --- UI Helper Functions --- | |
def display_sidebar() -> Dict[str, Any]:
    """Render the sidebar and return the search settings chosen there.

    Shows the status of the retriever and of OpenAI as reported by
    ``get_retriever_status`` / ``get_openai_status``, then the retrieval and
    validation sliders. If the retriever is not ready the script is halted
    with ``st.stop()``.

    Returns:
        A dict with ``n_retrieve`` (passages to retrieve), ``n_validate``
        (passages to validate) and ``services_ready`` (true only when both
        the retriever and OpenAI report ready).
    """
    st.sidebar.markdown("<h3 class='rtl-text'>מצב המערכת</h3>", unsafe_allow_html=True)
    retriever_ready, _ = get_retriever_status()
    openai_ready, _ = get_openai_status()

    st.sidebar.markdown(
        f"<p class='rtl-text'><strong>מאחזר (Pinecone):</strong> {'✅' if retriever_ready else '❌'}</p>",
        unsafe_allow_html=True
    )
    if not retriever_ready:
        st.sidebar.error("מאחזר אינו זמין.", icon="🛑")
        st.stop()

    st.sidebar.markdown("<hr>", unsafe_allow_html=True)
    st.sidebar.markdown(
        f"<p class='rtl-text'><strong>OpenAI ({config.OPENAI_VALIDATION_MODEL} / {config.OPENAI_GENERATION_MODEL}):</strong> {'✅' if openai_ready else '❌'}</p>",
        unsafe_allow_html=True
    )
    if not openai_ready:
        st.sidebar.error("OpenAI אינו זמין.", icon="⚠️")

    st.sidebar.markdown("<hr>", unsafe_allow_html=True)
    st.sidebar.markdown("<h3 class='rtl-text'>הגדרות חיפוש</h3>", unsafe_allow_html=True)
    n_retrieve = st.sidebar.slider("מספר פסקאות לאחזור", 1, 300, config.DEFAULT_N_RETRIEVE)

    # Never validate more passages than were retrieved, capped at 100.
    max_validate = min(n_retrieve, 100)
    if max_validate > 1:
        # Keep the default inside the slider's 1..max_validate range.
        default_validate = max(1, min(config.DEFAULT_N_VALIDATE, max_validate))
        n_validate = st.sidebar.slider(
            "פסקאות לאימות (GPT-4o)",
            1,
            max_validate,
            default_validate,
            disabled=not openai_ready
        )
    else:
        # With a single retrieved passage the slider would have min == max,
        # a degenerate range; NOTE(review): some Streamlit versions appear to
        # reject it outright — confirm. There is only one possible value.
        n_validate = 1

    st.sidebar.info("התשובות מבוססות רק על המקורות שאומתו.", icon="ℹ️")
    return {"n_retrieve": n_retrieve, "n_validate": n_validate, "services_ready": (retriever_ready and openai_ready)}
def display_chat_message(message: Dict[str, Any]):
    """Render one chat message inside a ``st.chat_message`` container.

    Args:
        message: Message dict. Reads ``role`` (defaults to ``"assistant"``),
            ``content`` and, for assistant messages, ``final_docs``.

    For an assistant message with a non-empty ``final_docs`` list, an
    expander lists each document's ``source_name`` and its ``hebrew_text``
    (passed through ``clean_source_text``). Entries that are not dicts are
    skipped but still consume a source number.
    """
    role = message.get("role", "assistant")
    with st.chat_message(role):
        # The user's prompt is free text, so it is rendered as Markdown only;
        # raw HTML stays enabled for content the app itself produced.
        st.markdown(message.get('content', '') or '', unsafe_allow_html=(role != "user"))

        docs = message.get("final_docs")
        if role != "assistant" or not docs:
            return

        # Plain-text label: expander labels take Markdown, not raw HTML, so a
        # <span> wrapper would be displayed literally.
        exp_title = f"הצג {len(docs)} קטעי מקור שנשלחו למחולל (GPT-4o)"
        with st.expander(exp_title, expanded=False):
            st.markdown("<div dir='rtl' class='expander-content'>", unsafe_allow_html=True)
            for i, doc in enumerate(docs, start=1):
                if not isinstance(doc, dict):
                    continue
                source = doc.get('source_name', '') or 'מקור לא ידוע'
                text = clean_source_text(doc.get('hebrew_text', ''))
                st.markdown(
                    f"<div class='source-info rtl-text'><strong>מקור {i}:</strong> {source}</div>",
                    unsafe_allow_html=True
                )
                st.markdown(f"<div class='hebrew-text'>{text}</div>", unsafe_allow_html=True)
            st.markdown("</div>", unsafe_allow_html=True)
def display_status_updates(status_log: List[str]):
    """Show the processing log in a collapsed expander.

    Args:
        status_log: Status lines to display. Nothing is rendered when the
            list is empty (or otherwise falsy).
    """
    if not status_log:
        return
    # Plain-text label: expander labels take Markdown, not raw HTML, so a
    # <span> wrapper would be displayed literally.
    with st.expander("הצג פרטי עיבוד", expanded=False):
        for u in status_log:
            st.markdown(
                f"<code class='status-update rtl-text'>- {u}</code>",
                unsafe_allow_html=True
            )
# --- Main Application Logic ---

# A response starting with one of these tags is treated as ready-made markup.
_HTML_PREFIXES = ('<div', '<p', '<ul', '<ol', '<strong')
# Citations in the generated answer have the form "(מקור N)".
_CITATION_RE = re.compile(r'\(מקור\s*([0-9]+)\)')


def _wrap_rtl(raw: str) -> str:
    """Wrap *raw* in an RTL div unless it already starts with an HTML tag."""
    if raw.strip().startswith(_HTML_PREFIXES):
        return raw
    return f"<div dir='rtl' class='rtl-text'>{raw or 'לא התקבלה תשובה מהמחולל.'}</div>"


def _select_cited_docs(raw: str, docs: List[Any]) -> List[Any]:
    """Return ``(number, doc)`` pairs for the documents cited in *raw*.

    Documents are numbered from 1 in list order. When the answer contains no
    citation at all, every document is returned. Non-dict entries are dropped
    but keep their number, so the numbering of the others is unaffected.
    """
    numbered = [(n, d) for n, d in enumerate(docs, start=1) if isinstance(d, dict)]
    # Compare as integers so that "(מקור 01)" still refers to document 1.
    cited = {int(n) for n in _CITATION_RE.findall(raw)}
    if not cited:
        return numbered
    return [(n, d) for n, d in numbered if n in cited]


def _render_cited_docs(docs_to_show: List[Any]) -> None:
    """Render the given ``(number, doc)`` pairs inside a collapsed expander."""
    if not docs_to_show:
        return
    # Plain-text label: expander labels take Markdown, not raw HTML.
    label = f"הצג {len(docs_to_show)} קטעי מקור שהוזכרו בתשובה"
    with st.expander(label, expanded=False):
        st.markdown("<div dir='rtl' class='expander-content'>", unsafe_allow_html=True)
        for idx, doc in docs_to_show:
            source = doc.get('source_name', '') or 'מקור לא ידוע'
            text = clean_source_text(doc.get('hebrew_text', ''))
            st.markdown(
                f"<div class='source-info rtl-text'><strong>מקור {idx}:</strong> {source}</div>",
                unsafe_allow_html=True
            )
            st.markdown(f"<div class='hebrew-text'>{text}</div>", unsafe_allow_html=True)
        st.markdown("</div>", unsafe_allow_html=True)


if "messages" not in st.session_state:
    st.session_state.messages = []

rag_params = display_sidebar()

# Render history
for msg in st.session_state.messages:
    display_chat_message(msg)

if prompt := st.chat_input("שאל שאלה בענייני חסידות...", disabled=not rag_params["services_ready"]):
    st.session_state.messages.append({"role": "user", "content": prompt})
    display_chat_message(st.session_state.messages[-1])

    with st.chat_message("assistant"):
        msg_placeholder = st.empty()
        status_container = st.status("מעבד בקשה...", expanded=True)
        chunks: List[str] = []
        try:
            def status_cb(m):
                # Status labels take Markdown, not raw HTML, so pass plain text.
                status_container.update(label=str(m))

            def stream_cb(c):
                # Re-render everything received so far, followed by a cursor.
                chunks.append(c)
                msg_placeholder.markdown(
                    f"<div dir='rtl' class='rtl-text'>{''.join(chunks)}▌</div>",
                    unsafe_allow_html=True
                )

            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # No current loop in this thread: create one and register it.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            final_rag = loop.run_until_complete(
                execute_validate_generate_pipeline(
                    history=st.session_state.messages,
                    params=rag_params,
                    status_callback=status_cb,
                    stream_callback=stream_cb
                )
            )

            if isinstance(final_rag, dict):
                # "or" fallbacks also cover keys that are present but None.
                raw = final_rag.get("final_response") or ""
                err = final_rag.get("error")
                log = final_rag.get("status_log") or []
                docs = final_rag.get("generator_input_documents") or []
                pipeline = final_rag.get("pipeline_used", PIPELINE_VALIDATE_GENERATE_GPT4O)

                # wrap in RTL div if needed
                final = _wrap_rtl(raw)
                msg_placeholder.markdown(final, unsafe_allow_html=True)

                # Show only the paragraphs cited in the answer (all of them
                # when the answer cites none).
                _render_cited_docs(_select_cited_docs(raw, docs))

                # store assistant message
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": final,
                    "final_docs": docs,
                    "pipeline_used": pipeline,
                    "status_log": log,
                    "error": err
                })
                display_status_updates(log)
                if err:
                    status_container.update(label="שגיאה בעיבוד!", state="error", expanded=False)
                else:
                    status_container.update(label="העיבוד הושלם!", state="complete", expanded=False)
            else:
                msg_placeholder.markdown(
                    "<div dir='rtl' class='rtl-text'><strong>שגיאה בלתי צפויה בתקשורת.</strong></div>",
                    unsafe_allow_html=True
                )
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": "שגיאה בלתי צפויה בתקשורת.",
                    "final_docs": [],
                    "pipeline_used": "Error",
                    "status_log": ["Unexpected result"],
                    "error": "Unexpected"
                })
                status_container.update(label="שגיאה בלתי צפויה!", state="error", expanded=False)
        except Exception as e:
            traceback.print_exc()
            # Escape the traceback: it contains text such as "<module>" that
            # would otherwise be parsed as HTML tags and disappear.
            tb_html = html.escape(traceback.format_exc())
            err_html = (f"<div dir='rtl' class='rtl-text'><strong>שגיאה קריטית!</strong><br>נסה לרענן."
                        f"<details><summary>פרטים</summary><pre>{tb_html}</pre></details></div>")
            # Rendered as HTML-enabled Markdown, like the other error branch
            # and like the history view that replays this stored content.
            msg_placeholder.markdown(err_html, unsafe_allow_html=True)
            st.session_state.messages.append({
                "role": "assistant",
                "content": err_html,
                "final_docs": [],
                "pipeline_used": "Critical Error",
                "status_log": [f"Critical: {type(e).__name__}"],
                "error": str(e)
            })
            status_container.update(label=str(e), state="error", expanded=False)