"""Ask Tricare — a Dash/Flask chat application for TRICARE benefit questions.

Features:
  * per-session chat state persisted to JSON files under the system temp dir,
  * document/image uploads with text extraction (txt/md/pdf/docx/xlsx),
  * lightweight RAG: OpenAI embeddings over a global ``docs/`` folder plus
    per-session user uploads, cosine-similarity search at question time,
  * streamed ChatCompletion responses polled by a ``dcc.Interval``.

NOTE(review): this app stores chat transcripts unencrypted on local disk and
runs with ``debug=True`` bound to 0.0.0.0 — not suitable for production as-is.
"""

import os
import threading
import logging
import uuid
import shutil
import json
import tempfile
import glob
import base64
import datetime
import io

from flask import Flask, g, request as flask_request, make_response
import dash
from dash import dcc, html, Input, Output, State, callback_context, no_update
import dash_bootstrap_components as dbc
import openai
from werkzeug.utils import secure_filename
import numpy as np
import PyPDF2
import docx
import openpyxl

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(threadName)s %(message)s",
)
logger = logging.getLogger("AskTricare")

app_flask = Flask(__name__)

# In-memory per-session state and locks; state is also mirrored to disk so it
# survives worker threads reloading it (see load/save_session_state).
SESSION_DATA = {}
SESSION_LOCKS = {}
SESSION_DIR_BASE = os.path.join(tempfile.gettempdir(), "asktricare_sessions")
os.makedirs(SESSION_DIR_BASE, exist_ok=True)

openai.api_key = os.environ.get("OPENAI_API_KEY")

# Global (docs/ folder) embedding store: filename -> embedding / chunk text.
EMBEDDING_INDEX = {}
EMBEDDING_TEXTS = {}
EMBEDDING_MODEL = "text-embedding-ada-002"


def get_session_id():
    """Return the session id from the request cookie, minting one if absent.

    A freshly minted id is stashed on ``flask.g`` so that the after_request
    hook sets the *same* id in the response cookie.  (Previously both this
    function and the cookie hook generated independent uuids, so the cookie
    never matched the session directory created during the request.)
    """
    sid = flask_request.cookies.get("asktricare_session_id")
    if not sid:
        sid = getattr(g, "asktricare_session_id", None)
        if not sid:
            sid = str(uuid.uuid4())
            g.asktricare_session_id = sid
    return sid


def get_session_dir(session_id):
    """Return (creating if needed) the on-disk directory for this session."""
    d = os.path.join(SESSION_DIR_BASE, session_id)
    os.makedirs(d, exist_ok=True)
    return d


def get_session_lock(session_id):
    """Return the per-session lock, creating it atomically.

    ``dict.setdefault`` is atomic in CPython, which closes the
    check-then-set race the previous ``if sid not in ...`` version had when
    two threads hit a new session simultaneously.
    """
    return SESSION_LOCKS.setdefault(session_id, threading.Lock())


def get_session_state(session_id):
    """Return the in-memory state dict for a session, initialising defaults."""
    if session_id not in SESSION_DATA:
        SESSION_DATA[session_id] = {
            "messages": [],
            "uploads": [],
            # Naive UTC kept deliberately: cleanup_sessions compares it
            # against naive utcnow(); switching to aware datetimes would
            # break parsing of already-persisted sessions.
            "created": datetime.datetime.utcnow().isoformat(),
            "streaming": False,
            "stream_buffer": "",
        }
    return SESSION_DATA[session_id]


def save_session_state(session_id):
    """Persist the session state dict to <session_dir>/state.json."""
    state = get_session_state(session_id)
    d = get_session_dir(session_id)
    with open(os.path.join(d, "state.json"), "w") as f:
        json.dump(state, f)


def load_session_state(session_id):
    """Refresh the in-memory state from disk, if a state.json exists."""
    d = get_session_dir(session_id)
    path = os.path.join(d, "state.json")
    if os.path.exists(path):
        with open(path, "r") as f:
            SESSION_DATA[session_id] = json.load(f)


def load_system_prompt():
    """Load system_prompt.txt from the CWD, falling back to a built-in prompt."""
    prompt_path = os.path.join(os.getcwd(), "system_prompt.txt")
    try:
        with open(prompt_path, "r", encoding="utf-8") as f:
            return f.read().strip()
    except Exception as e:
        logger.error(f"Failed to load system prompt: {e}")
        return (
            "You are Ask Tricare, a helpful assistant for TRICARE health "
            "benefits. Respond conversationally, and cite relevant sources "
            "when possible. If you do not know, say so."
        )


def _read_doc_text(doc_path):
    """Extract text from a docs/ file: PyPDF2 for .pdf, plain read otherwise.

    The previous implementation opened PDFs in text mode with errors
    ignored, which embedded binary garbage instead of the document text.
    """
    if doc_path.lower().endswith(".pdf"):
        text = ""
        with open(doc_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                page_text = page.extract_text()
                if page_text:
                    text += page_text + "\n"
        return text
    with open(doc_path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()


def embed_docs_folder():
    """Embed the first 4000 chars of every *.txt/*.md/*.pdf in ./docs once.

    Populates EMBEDDING_INDEX / EMBEDDING_TEXTS; already-indexed filenames
    are skipped, so re-calling is cheap.
    """
    global EMBEDDING_INDEX, EMBEDDING_TEXTS
    docs_folder = os.path.join(os.getcwd(), "docs")
    if not os.path.isdir(docs_folder):
        logger.warning(f"Docs folder '{docs_folder}' does not exist. Skipping embedding.")
        return
    doc_files = []
    for ext in ("*.txt", "*.md", "*.pdf"):
        doc_files.extend(glob.glob(os.path.join(docs_folder, ext)))
    for doc_path in doc_files:
        fname = os.path.basename(doc_path)
        if fname in EMBEDDING_INDEX:
            continue
        try:
            text = _read_doc_text(doc_path)
            if not text.strip():
                continue
            chunk = text[:4000]  # single-chunk embedding keeps the index tiny
            response = openai.Embedding.create(input=[chunk], model=EMBEDDING_MODEL)
            embedding = response['data'][0]['embedding']
            EMBEDDING_INDEX[fname] = embedding
            EMBEDDING_TEXTS[fname] = chunk
            logger.info(f"Embedded doc: {fname}")
        except Exception as e:
            logger.error(f"Embedding failed for {fname}: {e}")


# Build the global index at import time (no-op if docs/ is absent).
embed_docs_folder()


def embed_user_doc(session_id, filename, text):
    """Embed an uploaded document chunk into the session's user_embeds.json."""
    session_dir = get_session_dir(session_id)
    if not text.strip():
        return
    try:
        chunk = text[:4000]
        response = openai.Embedding.create(input=[chunk], model=EMBEDDING_MODEL)
        embedding = response['data'][0]['embedding']
        user_embeds_path = os.path.join(session_dir, "user_embeds.json")
        if os.path.exists(user_embeds_path):
            with open(user_embeds_path, "r") as f:
                user_embeds = json.load(f)
        else:
            user_embeds = {"embeddings": [], "texts": [], "filenames": []}
        user_embeds["embeddings"].append(embedding)
        user_embeds["texts"].append(chunk)
        user_embeds["filenames"].append(filename)
        with open(user_embeds_path, "w") as f:
            json.dump(user_embeds, f)
        # Log the actual filename (previously logged the literal "(unknown)").
        logger.info(f"Session {session_id}: Embedded user doc {filename}")
    except Exception as e:
        logger.error(f"Session {session_id}: Failed to embed user doc {filename}: {e}")


def get_user_embeddings(session_id):
    """Return (embedding matrix, texts, filenames) for the session's uploads."""
    session_dir = get_session_dir(session_id)
    user_embeds_path = os.path.join(session_dir, "user_embeds.json")
    if os.path.exists(user_embeds_path):
        with open(user_embeds_path, "r") as f:
            d = json.load(f)
        return (
            np.array(d.get("embeddings", [])),
            d.get("texts", []),
            d.get("filenames", []),
        )
    return np.array([]), [], []


def semantic_search(query, embed_matrix, texts, filenames, top_k=2):
    """Cosine-similarity search of `query` against `embed_matrix`.

    Returns up to ``top_k`` dicts with filename/text/score, or [] on error
    or when the matrix is empty.
    """
    if len(embed_matrix) == 0:
        return []
    try:
        q_embed = openai.Embedding.create(
            input=[query], model=EMBEDDING_MODEL
        )["data"][0]["embedding"]
        q_embed = np.array(q_embed)
        embed_matrix = np.array(embed_matrix)
        # +1e-8 guards against division by zero for degenerate vectors.
        scores = np.dot(embed_matrix, q_embed) / (
            np.linalg.norm(embed_matrix, axis=1) * np.linalg.norm(q_embed) + 1e-8
        )
        idx = np.argsort(scores)[::-1][:top_k]
        return [
            {"filename": filenames[i], "text": texts[i], "score": float(scores[i])}
            for i in idx
        ]
    except Exception as e:
        logger.error(f"Semantic search error: {e}")
        return []


app = dash.Dash(
    __name__,
    server=app_flask,
    suppress_callback_exceptions=True,
    external_stylesheets=[dbc.themes.BOOTSTRAP, "/assets/custom.css"],
    update_title="Ask Tricare",
)


def chat_message_card(msg, is_user):
    """Render one chat bubble, right-aligned for the user, left for the bot."""
    align = "flex-end" if is_user else "flex-start"
    color = "primary" if is_user else "secondary"
    avatar = "🧑" if is_user else "🤖"
    return html.Div(
        dbc.Card(
            dbc.CardBody([
                html.Div([
                    html.Span(avatar, style={"fontSize": "2rem"}),
                    html.Span(
                        msg,
                        style={
                            "whiteSpace": "pre-wrap",
                            "marginLeft": "0.75rem",
                            "overflowWrap": "break-word",
                            "wordBreak": "break-word",
                        },
                    ),
                ], style={"display": "flex", "alignItems": "center"})
            ]),
            className="mb-2 ms-3 me-3",
            color=color,
            inverse=is_user,
            style={"maxWidth": "80%"},
        ),
        style={"display": "flex", "justifyContent": align, "width": "100%"},
    )


def uploaded_file_card(filename, is_img):
    """Render a small card for an uploaded file in the sidebar list."""
    icon = "🖼️" if is_img else "📄"
    return dbc.Card(
        dbc.CardBody([
            html.Span(icon, style={"fontSize": "2rem", "marginRight": "0.5rem"}),
            html.Span(filename),
        ]),
        className="mb-2",
        # "tertiary" is not a Bootstrap theme color; "light" is the closest
        # valid neutral tone accepted by dbc.Card.
        color="light",
    )


def disclaimer_card():
    """Static PII/PHI disclaimer shown in the left navbar."""
    return dbc.Card(
        dbc.CardBody([
            html.H5("Disclaimer", className="card-title"),
            html.P(
                "This information is not private. Do not send PII or PHI. For official guidance visit the Tricare website.",
                style={"fontSize": "0.95rem"},
            ),
        ]),
        className="mb-2",
    )


def left_navbar_static():
    """Left column: title, disclaimer, upload control, and upload list."""
    return html.Div([
        html.H3("Ask Tricare", className="mb-3 mt-3", style={"fontWeight": "bold"}),
        disclaimer_card(),
        dcc.Upload(
            id="file-upload",
            children=dbc.Button(
                "Upload Document/Image",
                color="secondary",
                className="mb-2",
                style={"width": "100%"},
            ),
            multiple=True,
            style={"width": "100%"},
        ),
        html.Div(id="upload-list"),
        html.Hr(),
    ], style={
        "padding": "1rem",
        "backgroundColor": "#f8f9fa",
        "height": "100vh",
        "overflowY": "auto",
    })


def chat_box_card():
    """Scrollable chat history container."""
    return dbc.Card(
        dbc.CardBody([
            html.Div(
                id="chat-window-container",
                children=[html.Div(id="chat-window", style={"width": "100%"})],
                style={
                    "height": "70vh",
                    "overflowY": "auto",
                    "overflowX": "hidden",
                    "backgroundColor": "#fff",
                    "padding": "0.5rem",
                    "borderRadius": "0.5rem",
                },
            )
        ]),
        className="mt-3",
        style={"height": "72vh", "overflowY": "hidden", "overflowX": "hidden"},
    )


def user_input_card():
    """Question textarea plus Send button and the hidden Enter-key proxy button."""
    return dbc.Card(
        dbc.CardBody([
            html.Div([
                dcc.Textarea(
                    id="user-input",
                    placeholder="Type your question...",
                    style={
                        "width": "100%",
                        "height": "60px",
                        "resize": "vertical",
                        "wordWrap": "break-word",
                    },
                    wrap="soft",
                    maxLength=1000,
                    n_blur=0,
                ),
                dcc.Store(id="enter-triggered", data=False),
                html.Div([
                    dbc.Button(
                        "Send",
                        id="send-btn",
                        color="primary",
                        className="mt-2 me-2",
                        style={"minWidth": "100px"},
                    ),
                ], style={"float": "right", "display": "flex", "gap": "0.5rem"}),
                dcc.Store(id="user-input-store", data="", storage_type="session"),
                # Invisible button "clicked" by the clientside Enter handler.
                html.Button(id='hidden-send', style={'display': 'none'}),
            ], style={"marginTop": "1rem"}),
            html.Div(id="error-message", style={"color": "#bb2124", "marginTop": "0.5rem"}),
            dcc.Store(id="should-clear-input", data=False),
        ])
    )


def right_main_static():
    """Right column: chat box, input card, spinner, and the stream poller."""
    return html.Div([
        chat_box_card(),
        user_input_card(),
        dcc.Loading(
            id="loading",
            type="default",
            fullscreen=False,
            style={"position": "absolute", "top": "5%", "left": "50%"},
        ),
        # Polls the server-side stream buffer while a reply is streaming.
        dcc.Interval(
            id="stream-interval",
            interval=400,
            n_intervals=0,
            disabled=True,
            max_intervals=1000,
        ),
        dcc.Store(id="client-question", data=""),
    ], style={
        "padding": "1rem",
        "backgroundColor": "#fff",
        "height": "100vh",
        "overflowY": "auto",
    })


app.layout = html.Div([
    dcc.Store(id="session-id", storage_type="local"),
    dcc.Location(id="url"),
    html.Div([
        html.Div(
            left_navbar_static(),
            id='left-navbar',
            style={
                "width": "30vw",
                "height": "100vh",
                "position": "fixed",
                "left": 0,
                "top": 0,
                "zIndex": 2,
                "overflowY": "auto",
            },
        ),
        html.Div(
            right_main_static(),
            id='right-main',
            style={"marginLeft": "30vw", "width": "70vw", "overflowY": "auto"},
        ),
    ], style={"display": "flex"}),
    dcc.Store(id="clear-input", data=False),
    dcc.Store(id="scroll-bottom", data=0),
    dcc.Store(id="enter-pressed", data=False),
])

# Clientside: install a one-time keydown handler so plain Enter (without
# Shift) clicks the hidden send button instead of inserting a newline.
app.clientside_callback(
    """
    function(n, value) {
        var ta = document.getElementById('user-input');
        if (!ta) return window.dash_clientside.no_update;
        if (!window._asktricare_enter_handler) {
            ta.addEventListener('keydown', function(e) {
                if (e.key === 'Enter' && !e.shiftKey) {
                    e.preventDefault();
                    var btn = document.getElementById('hidden-send');
                    if (btn) btn.click();
                }
            });
            window._asktricare_enter_handler = true;
        }
        return window.dash_clientside.no_update;
    }
    """,
    Output('enter-pressed', 'data'),
    Input('user-input', 'n_blur'),
    State('user-input', 'value'),
)

# Clientside: scroll the chat window to the bottom whenever scroll-bottom
# is incremented by the main callback.
app.clientside_callback(
    """
    function(scrollIndex) {
        var chatContainer = document.getElementById('chat-window-container');
        if (chatContainer) {
            chatContainer.scrollTop = chatContainer.scrollHeight;
        }
        return null;
    }
    """,
    Output('clear-input', 'data'),  # dummy output
    Input('scroll-bottom', 'data'),
)


def _is_supported_doc(filename):
    """True when the extension is one we can extract text from."""
    ext = os.path.splitext(filename)[1].lower()
    return ext in [".txt", ".pdf", ".md", ".docx", ".xlsx"]


def _extract_text_from_upload(filepath, ext):
    """Extract plain text from an uploaded file; returns "" on any failure.

    Supported: .txt/.md (raw read), .pdf (PyPDF2), .docx (python-docx),
    .xlsx (openpyxl, tab-joined non-empty rows).
    """
    try:
        if ext in [".txt", ".md"]:
            with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        elif ext == ".pdf":
            try:
                text = ""
                with open(filepath, "rb") as f:
                    reader = PyPDF2.PdfReader(f)
                    for page in reader.pages:
                        page_text = page.extract_text()
                        if page_text:
                            text += page_text + "\n"
                return text
            except Exception as e:
                logger.error(f"Error reading PDF {filepath}: {e}")
                return ""
        elif ext == ".docx":
            try:
                doc = docx.Document(filepath)
                paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
                return "\n".join(paragraphs)
            except Exception as e:
                logger.error(f"Error reading DOCX {filepath}: {e}")
                return ""
        elif ext == ".xlsx":
            try:
                wb = openpyxl.load_workbook(filepath, read_only=True, data_only=True)
                text_rows = []
                for ws in wb.worksheets:
                    for row in ws.iter_rows(values_only=True):
                        row_strs = [str(cell) for cell in row if cell is not None]
                        if any(row_strs):
                            text_rows.append("\t".join(row_strs))
                return "\n".join(text_rows)
            except Exception as e:
                logger.error(f"Error reading XLSX {filepath}: {e}")
                return ""
        else:
            return ""
    except Exception as e:
        logger.error(f"Error extracting text from {filepath}: {e}")
        return ""


@app.callback(
    Output("session-id", "data"),
    Input("url", "href"),
    prevent_initial_call=False,
)
def assign_session_id(_):
    """On page load: resolve the session id, ensure its dir, load its state."""
    sid = get_session_id()
    get_session_dir(sid)
    load_session_state(sid)
    logger.info(f"Assigned session id: {sid}")
    return sid


def _render_chat_cards(messages):
    """Map stored messages to chat bubble components."""
    return [
        chat_message_card(m['content'], is_user=(m['role'] == "user"))
        for m in messages
    ]


def _render_upload_cards(uploads):
    """Map stored upload records to sidebar file cards."""
    return [
        uploaded_file_card(os.path.basename(f["name"]), f["is_img"])
        for f in uploads
    ]


def _handle_file_upload(session_id, state, file_contents, file_names):
    """Persist uploads, embed extractable documents, append upload messages.

    Mutates ``state`` (uploads + messages) and saves it.  Returns the list
    of extracted document texts that should be forwarded to the assistant
    as a combined question.  Caller must hold the session lock.
    """
    if not isinstance(file_contents, list):
        file_contents = [file_contents]
        file_names = [file_names]
    uploads = []
    file_upload_messages = []
    doc_texts_to_send = []
    session_dir = get_session_dir(session_id)
    for c, n in zip(file_contents, file_names):
        header, data = c.split(',', 1)  # dcc.Upload sends "data:<mime>;base64,<data>"
        ext = os.path.splitext(n)[1].lower()
        is_img = ext in [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
        # Timestamp prefix avoids collisions; secure_filename strips path tricks.
        fname = secure_filename(f"{datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{n}")
        fp = os.path.join(session_dir, fname)
        with open(fp, "wb") as f:
            f.write(base64.b64decode(data))
        uploads.append({"name": fname, "is_img": is_img, "path": fp})
        if _is_supported_doc(n) and not is_img:
            text = _extract_text_from_upload(fp, ext)
            if text.strip():
                embed_user_doc(session_id, fname, text)
                logger.info(f"Session {session_id}: Uploaded doc '{n}' embedded for user vector store")
                preview = text[:1000]
                file_upload_messages.append({
                    "role": "user",
                    "content": f"[Document uploaded: {n}]\n{preview if preview.strip() else '[No text extracted]'}",
                })
                doc_texts_to_send.append(text.strip())
            else:
                file_upload_messages.append({
                    "role": "user",
                    "content": f"[Document uploaded: {n}]\n[No text extracted]",
                })
        elif is_img:
            file_upload_messages.append({
                "role": "user",
                "content": f"[Image uploaded: {n}]",
            })
        else:
            file_upload_messages.append({
                "role": "user",
                "content": f"[File uploaded: {n}]",
            })
    state["uploads"].extend(uploads)
    state["messages"].extend(file_upload_messages)
    save_session_state(session_id)
    logger.info(f"Session {session_id}: Uploaded files {[u['name'] for u in uploads]}")
    return doc_texts_to_send


def _run_stream_worker(session_id, messages, question, log_label):
    """Background thread body: RAG lookup + streamed ChatCompletion.

    Replaces the two near-identical nested workers (doc-upload and chat).
    Streams partial output into ``state["stream_buffer"]`` so the polling
    interval can render it, then commits the full reply and clears the
    streaming flag.  ``log_label`` is "doc upload" or "chat" and only
    affects log lines.
    """
    try:
        system_prompt = load_system_prompt()
        rag_chunks = []
        try:
            global_embeds, global_texts, global_fnames = [], [], []
            for fname, emb in EMBEDDING_INDEX.items():
                global_embeds.append(emb)
                global_texts.append(EMBEDDING_TEXTS[fname])
                global_fnames.append(fname)
            for r in semantic_search(question, global_embeds, global_texts, global_fnames, top_k=2):
                rag_chunks.append(f"Global doc [{r['filename']}]:\n{r['text'][:1000]}")
            user_embeds, user_texts, user_fnames = get_user_embeddings(session_id)
            for r in semantic_search(question, user_embeds, user_texts, user_fnames, top_k=2):
                rag_chunks.append(f"User upload [{r['filename']}]:\n{r['text'][:1000]}")
        except Exception as e:
            logger.error(f"Session {session_id}: RAG error ({log_label}): {e}")
        msg_list = [{"role": "system", "content": system_prompt}]
        if rag_chunks:
            context_block = (
                "The following sources may help answer the question:\n\n"
                + "\n\n".join(rag_chunks)
                + "\n\n"
            )
            msg_list.append({"role": "system", "content": context_block})
        for m in messages:
            msg_list.append({"role": m["role"], "content": m["content"]})
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=msg_list,
            max_tokens=700,
            temperature=0.2,
            stream=True,
        )
        reply = ""
        for chunk in response:
            content = chunk["choices"][0]["delta"].get("content", "")
            if content:
                reply += content
                # Publish the partial reply for the polling interval.
                with get_session_lock(session_id):
                    load_session_state(session_id)
                    state = get_session_state(session_id)
                    state["stream_buffer"] = reply
                    save_session_state(session_id)
        with get_session_lock(session_id):
            load_session_state(session_id)
            state = get_session_state(session_id)
            state["messages"].append({"role": "assistant", "content": reply})
            state["stream_buffer"] = ""
            state["streaming"] = False
            save_session_state(session_id)
        if log_label == "doc upload":
            logger.info(f"Session {session_id}: Assistant responded to doc upload")
        else:
            logger.info(f"Session {session_id}: User: {question} | Assistant: {reply}")
    except Exception as e:
        # On any failure, clear the streaming flag so the UI doesn't poll forever.
        with get_session_lock(session_id):
            load_session_state(session_id)
            state = get_session_state(session_id)
            state["streaming"] = False
            state["stream_buffer"] = ""
            save_session_state(session_id)
        logger.error(f"Session {session_id}: Streaming error ({log_label}): {e}")


@app.callback(
    Output("upload-list", "children"),
    Output("chat-window", "children"),
    Output("error-message", "children"),
    Output("stream-interval", "disabled"),
    Output("stream-interval", "n_intervals"),
    Output("user-input", "value"),
    Output("scroll-bottom", "data"),
    Input("session-id", "data"),
    Input("send-btn", "n_clicks"),
    Input("file-upload", "contents"),
    Input("stream-interval", "n_intervals"),
    Input('hidden-send', 'n_clicks'),
    State("file-upload", "filename"),
    State("user-input", "value"),
    State("scroll-bottom", "data"),
    prevent_initial_call=False,
)
def main_callback(session_id, send_clicks, file_contents, stream_n,
                  hidden_send_clicks, file_names, user_input, scroll_bottom):
    """Single callback driving uploads, sends, and stream polling.

    Returns (upload cards, chat cards, error text, interval-disabled,
    interval reset, textarea value, scroll counter).  Streaming work runs
    in a daemon thread; the interval polls stream_buffer until done.
    """
    trigger = (
        callback_context.triggered[0]['prop_id'].split('.')[0]
        if callback_context.triggered else ""
    )
    session_id = session_id or get_session_id()
    with get_session_lock(session_id):
        load_session_state(session_id)
        state = get_session_state(session_id)
        error = ""

        # --- File upload: persist, embed, and (if text was extracted)
        # forward the combined document text to the assistant.
        if trigger == "file-upload" and file_contents and file_names:
            doc_texts_to_send = _handle_file_upload(session_id, state, file_contents, file_names)
            if doc_texts_to_send:
                doc_question = "\n\n".join(doc_texts_to_send)
                state["messages"].append({"role": "user", "content": doc_question})
                state["streaming"] = True
                state["stream_buffer"] = ""
                save_session_state(session_id)
                threading.Thread(
                    target=_run_stream_worker,
                    args=(session_id, list(state["messages"]), doc_question, "doc upload"),
                    daemon=True,
                ).start()
            return (
                _render_upload_cards(state.get("uploads", [])),
                _render_chat_cards(state.get("messages", [])),
                error,
                (not state.get("streaming", False)),
                0,
                no_update,
                scroll_bottom + 1,
            )

        # --- Send (button or Enter proxy): append the question and stream.
        send_triggered = trigger in ("send-btn", "hidden-send")
        if send_triggered and user_input and user_input.strip():
            question = user_input.strip()
            state["messages"].append({"role": "user", "content": question})
            state["streaming"] = True
            state["stream_buffer"] = ""
            save_session_state(session_id)
            threading.Thread(
                target=_run_stream_worker,
                args=(session_id, list(state["messages"]), question, "chat"),
                daemon=True,
            ).start()

        # --- Poll tick: render committed messages plus the live buffer.
        if trigger == "stream-interval":
            chat_cards = _render_chat_cards(state.get("messages", []))
            upload_cards = _render_upload_cards(state.get("uploads", []))
            if state.get("streaming", False):
                if state.get("stream_buffer", ""):
                    chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
                return upload_cards, chat_cards, "", False, stream_n + 1, no_update, scroll_bottom + 1
            # Stream finished: stop the interval and reset its counter.
            return upload_cards, chat_cards, "", True, 0, no_update, scroll_bottom + 1

        # --- Default render (initial load, send fall-through, etc.).
        upload_cards = _render_upload_cards(state.get("uploads", []))
        chat_cards = _render_chat_cards(state.get("messages", []))
        if send_triggered:
            # Clear the textarea after a send.
            return upload_cards, chat_cards, error, (not state.get("streaming", False)), 0, "", scroll_bottom + 1
        if trigger == "file-upload":
            return upload_cards, chat_cards, error, (not state.get("streaming", False)), 0, no_update, scroll_bottom + 1
        return upload_cards, chat_cards, error, (not state.get("streaming", False)), 0, no_update, scroll_bottom


@app_flask.after_request
def set_session_cookie(resp):
    """Ensure the session cookie is set, reusing any id minted this request.

    Uses the id stashed on flask.g by get_session_id() so the cookie matches
    the session directory created during the request (previously a second,
    unrelated uuid was generated here).
    """
    sid = flask_request.cookies.get("asktricare_session_id")
    if not sid:
        sid = getattr(g, "asktricare_session_id", None) or str(uuid.uuid4())
        resp.set_cookie("asktricare_session_id", sid, max_age=60 * 60 * 24 * 7, path="/")
    return resp


def cleanup_sessions(max_age_hours=48):
    """Delete session directories older than ``max_age_hours``.

    NOTE(review): nothing in this module schedules this — it must be invoked
    externally (cron, admin shell) or session dirs grow without bound.
    """
    now = datetime.datetime.utcnow()
    for sid in os.listdir(SESSION_DIR_BASE):
        d = os.path.join(SESSION_DIR_BASE, sid)
        try:
            state_path = os.path.join(d, "state.json")
            if os.path.exists(state_path):
                with open(state_path, "r") as f:
                    st = json.load(f)
                created = st.get("created")
                if created and (now - datetime.datetime.fromisoformat(created)).total_seconds() > max_age_hours * 3600:
                    shutil.rmtree(d)
                    logger.info(f"Cleaned up session {sid}")
        except Exception as e:
            logger.error(f"Cleanup error for {sid}: {e}")


# Optional CUDA setup; torch is not a hard dependency of this app.
# NOTE(review): set_default_tensor_type is deprecated in modern torch and
# nothing here uses tensors — this block is likely vestigial.
try:
    import torch
    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        logger.info("CUDA GPU detected and configured.")
except Exception as e:
    logger.warning(f"CUDA config failed: {e}")


if __name__ == '__main__':
    print("Starting the Dash application...")
    # NOTE(review): debug=True on 0.0.0.0 exposes the Werkzeug debugger to
    # the network — disable for any non-local deployment.
    app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
    print("Dash application has finished running.")