07Codex07 committed
Commit 30f67dd · 1 Parent(s): 463a7b5

Initial PrepGraph backend

Files changed (5)
  1. chatbot_graph.py +210 -0
  2. chatbot_retriever.py +417 -0
  3. main_api.py +309 -0
  4. memory_store.py +110 -0
  5. requirements.txt +10 -0
chatbot_graph.py ADDED
@@ -0,0 +1,210 @@
# chatbot_graph.py
import os
from dotenv import load_dotenv
import gradio as gr
import logging
from typing import List

load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# LLM client (Groq wrapper)
try:
    from langchain_groq import ChatGroq
except Exception:
    ChatGroq = None
    logger.warning("langchain_groq.ChatGroq not importable. Ensure langchain-groq is installed in requirements.")

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

from chatbot_retriever import retrieve_node_from_rows
from memory_store import init_db, save_message, get_last_messages, build_gradio_history

# initialize DB early
init_db()

# Instantiate Groq LLM (will require GROQ_API_KEY in env)
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.1-8b-instant")
GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
GROQ_TEMP = float(os.getenv("GROQ_TEMP", "0.2"))

if ChatGroq:
    llm = ChatGroq(model=GROQ_MODEL, api_key=GROQ_API_KEY, temperature=GROQ_TEMP)
else:
    llm = None


def _extract_answer_from_response(response):
    # robust extraction helper: handle the common shapes returned by LLM wrappers (simplified)
    try:
        if hasattr(response, "content"):
            c = response.content
            if isinstance(c, str) and c.strip():
                return c.strip()
            if isinstance(c, (list, tuple)):
                parts = [str(x) for x in c if x is not None]
                if parts:
                    return "".join(parts).strip()
            if isinstance(c, dict):
                for key in ("answer", "text", "content", "output_text", "generated_text"):
                    v = c.get(key)
                    if v:
                        if isinstance(v, (list, tuple)):
                            return "".join([str(x) for x in v]).strip()
                        return str(v).strip()
        if isinstance(response, dict):
            for key in ("answer", "text", "content"):
                v = response.get(key)
                if v:
                    return str(v)
            choices = response.get("choices") or response.get("outputs")
            if isinstance(choices, (list, tuple)) and choices:
                first = choices[0]
                if isinstance(first, dict):
                    msg = first.get("message") or first.get("text") or first.get("content")
                    if msg:
                        if isinstance(msg, (list, tuple)):
                            return "".join([str(x) for x in msg])
                        return str(msg)
        if hasattr(response, "generations"):
            gens = getattr(response, "generations")
            if gens:
                for outer in gens:
                    for g in outer:
                        if hasattr(g, "text") and g.text:
                            return str(g.text)
                        if hasattr(g, "message") and getattr(g.message, "content", None):
                            return str(g.message.content)
        s = str(response)
        if s and s.strip():
            return s.strip()
    except Exception:
        logger.exception("Failed extracting answer")
    return None


SYSTEM_PROMPT = (
    "You are PrepGraph — an accurate, concise AI tutor specialized in academic and technical content.\n"
    "Rules:\n"
    "1) Always prioritize answering the CURRENT user question directly and clearly.\n"
    "2) Refer to provided CONTEXT (delimited below) if relevant. Cite which doc (filename) or say 'from provided context' when applicable.\n"
    "3) If the current query is unclear, use ONLY the immediate previous user question to infer intent — not older ones.\n"
    "4) Provide step-by-step explanations when appropriate, using short, structured points.\n"
    "5) Include ASCII diagrams or flowcharts if they help understanding (e.g., for protocols, layers, architectures, etc.).\n"
    "6) If the context is insufficient or ambiguous, clearly say 'I’m unsure' and specify what extra information is needed.\n"
    "7) Avoid repetition, speculation, and hallucination — answer precisely what is asked.\n\n"
    "CONTEXT:\n"
)

# ---- helper: call the LLM with a list of messages (SystemMessage + HumanMessage...) ----
def call_llm(messages: List):
    if not llm:
        raise RuntimeError("LLM client (ChatGroq) not configured or import failed. Set up langchain_groq and GROQ_API_KEY.")
    # many wrappers accept the langchain message objects; keep using llm.invoke
    response = llm.invoke(messages)
    return response


# ---- Gradio UI functions ----
def load_history(user_id: str):
    uid = (user_id or os.getenv("DEFAULT_USER", "vinayak")).strip() or "vinayak"
    try:
        hist = build_gradio_history(uid)
        logger.info("Loaded %d messages for user %s", len(hist), uid)
        return hist
    except Exception:
        logger.exception("Failed to load history for %s", uid)
        return []


def chat_interface(user_input: str, chat_state: List[dict], user_id: str):
    """
    Receives user_input (string), chat_state (list of {'role':..., 'content':...}),
    user_id (string). Returns: (clear_input_str, new_chat_state)
    """
    uid = (user_id or os.getenv("DEFAULT_USER", "vinayak")).strip() or "vinayak"
    history = chat_state or []

    # Save user's message immediately
    try:
        save_message(uid, "user", user_input)
    except Exception:
        logger.exception("Failed to persist user message")

    # Build rows to pass to retriever: get last messages from DB (ensures persistence)
    rows = get_last_messages(uid, limit=200)  # chronological order

    # Retrieve context using hybrid retriever (uses last 3 user messages internally)
    try:
        retrieved = retrieve_node_from_rows(rows)
        context = retrieved.get("context")
    except Exception:
        logger.exception("Retriever failed")
        context = None

    # Build prompt: SystemMessage + last 3 user messages (HumanMessage)
    prompt_msgs = []
    system_content = SYSTEM_PROMPT + (context or "No context found.")
    prompt_msgs.append(SystemMessage(content=system_content))

    # collect last 3 user messages (from rows)
    last_users = [r[1] for r in rows if r[0] == "user"][-3:]
    if not last_users:
        # fallback to current input if DB empty
        last_users = [user_input]
    # append each of the last user messages as HumanMessage (preserves order)
    for u in last_users:
        prompt_msgs.append(HumanMessage(content=u))

    # send to LLM
    try:
        raw = call_llm(prompt_msgs)
        answer = _extract_answer_from_response(raw) or ""
    except Exception as e:
        logger.exception("LLM call failed")
        answer = f"Sorry — I couldn't process that right now ({e})."

    # persist assistant reply
    try:
        save_message(uid, "assistant", answer)
    except Exception:
        logger.exception("Failed to persist assistant message")

    # update gradio chat state: append current user and assistant
    history = history or load_history(uid)  # in case front-end was empty, rehydrate
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": answer})

    # return: clear the input box (""), updated history for gr.Chatbot(type="messages")
    return "", history


# ---- Minimal / attractive Gradio UI ----
with gr.Blocks(css=".gradio-container {max-width:900px; margin:0 auto;}") as demo:
    gr.Markdown("# 🤖 PrepGraph — RAG Tutor")
    with gr.Row():
        user_id_input = gr.Textbox(label="User ID (will be used to persist your memory)", value=os.getenv("DEFAULT_USER", "vinayak"))
    chatbot = gr.Chatbot(label="Conversation", type="messages")

    with gr.Row():
        msg = gr.Textbox(placeholder="Ask anything about your course material...", show_label=False)
        send = gr.Button("Send")

    with gr.Row():
        clear_ui = gr.Button("Clear Chat")

    # Load history at page load (and when user_id changes)
    demo.load(load_history, [user_id_input], [chatbot])
    user_id_input.change(load_history, [user_id_input], [chatbot])

    # Bind send
    msg.submit(chat_interface, [msg, chatbot, user_id_input], [msg, chatbot])
    send.click(chat_interface, [msg, chatbot, user_id_input], [msg, chatbot])

    # just clears the UI, not the DB
    clear_ui.click(lambda: [], None, chatbot)

if __name__ == "__main__":
    demo.launch()
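
Note: the chat path above can also be exercised without launching the Gradio UI. The sketch below is a minimal smoke test under the assumption that GROQ_API_KEY is set in .env and data/ holds at least one PDF or PPTX; the user id and question are placeholders.

# smoke_test_chat.py — minimal check of the chat path without the Gradio UI (illustrative only)
from chatbot_graph import chat_interface, load_history

history = load_history("demo")                      # rehydrate any persisted turns for this user id
_, history = chat_interface("Explain the OSI model in brief.", history, "demo")
print(history[-1]["content"])                       # assistant reply (also persisted to sqlite)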
chatbot_retriever.py ADDED
@@ -0,0 +1,417 @@
# chatbot_retriever.py
"""
Hybrid retriever:
- loads PDFs & PPTX (robust imports)
- chunks via RecursiveCharacterTextSplitter
- BM25 (rank_bm25) + FAISS (IVF when possible) using SentenceTransformers
- returns a combined context string limited by MAX_CONTEXT_CHARS
"""

import os
import re
import pickle
import logging
import shutil
import random
from typing import List, Optional, Dict, Any

import numpy as np
import faiss

from rank_bm25 import BM25Okapi

# Document loaders: try langchain first, then the community package
try:
    from langchain.document_loaders import PyPDFLoader, UnstructuredPowerPointLoader
except Exception:
    # fallback to the langchain-community package
    try:
        from langchain_community.document_loaders import PyPDFLoader, UnstructuredPowerPointLoader
    except Exception:
        raise ImportError("Please install langchain + langchain-community (or upgrade).")

from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer

# ---------- Config ----------
DATA_DIR = os.getenv("DATA_DIR", "data")
CACHE_DIR = os.getenv("CACHE_DIR", ".ragg_cache")
CHUNKS_CACHE = os.path.join(CACHE_DIR, "chunks.pkl")
BM25_CACHE = os.path.join(CACHE_DIR, "bm25.pkl")

FAISS_DIR = os.getenv("FAISS_DIR", "faiss_index")
FAISS_INDEX_PATH = os.path.join(FAISS_DIR, "index.faiss")
FAISS_META_PATH = os.path.join(FAISS_DIR, "meta.pkl")

os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(FAISS_DIR, exist_ok=True)

CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", 400))
CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", 80))
EMBED_MODEL = os.getenv("EMBED_MODEL", "all-MiniLM-L6-v2")

TOP_K_DOCS = int(os.getenv("TOP_K_DOCS", 3))
MAX_CONTEXT_CHARS = int(os.getenv("MAX_CONTEXT_CHARS", 4000))

# FAISS params
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 256))
FAISS_NLIST = int(os.getenv("FAISS_NLIST", 100))
FAISS_TRAIN_SIZE = int(os.getenv("FAISS_TRAIN_SIZE", 2000))
FAISS_NPROBE = int(os.getenv("FAISS_NPROBE", 10))
SEARCH_EXPANSION = int(os.getenv("FAISS_SEARCH_EXPANSION", 5))

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def detect_subject(fname: str) -> Optional[str]:
    # light heuristic to guess subject code from filename
    t = (fname or "").lower()
    if "network" in t or "cn" in t:
        return "cn"
    if "distributed" in t or "dos" in t:
        return "dos"
    if "software" in t or "se" in t:
        return "se"
    return None


def extract_year(s: str) -> Optional[str]:
    m = re.search(r"\b(20\d{2})\b", s)
    return m.group(1) if m else None


# ---------- Embeddings wrapper (SentenceTransformers) ----------
class Embeddings:
    def __init__(self, model_name=EMBED_MODEL):
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        vecs = self.model.encode(texts, show_progress_bar=False, convert_to_numpy=True)
        return [v.astype("float32") for v in vecs]

    def embed_query(self, text: str) -> List[float]:
        v = self.model.encode([text], show_progress_bar=False, convert_to_numpy=True)[0]
        return v.astype("float32")


# ---------- Load documents ----------
def load_all_docs(base_dir: str = DATA_DIR) -> List:
    docs = []
    if not os.path.isdir(base_dir):
        logger.warning("Data dir does not exist: %s", base_dir)
        return docs

    def load_file(path: str, filename: str, category: str):
        try:
            fname = filename.lower()
            if fname.endswith(".pdf"):
                loader = PyPDFLoader(path)
            elif fname.endswith(".pptx"):
                loader = UnstructuredPowerPointLoader(path)
            else:
                return []
            file_docs = loader.load()
            subject = detect_subject(fname)
            year = extract_year(fname)
            for d in file_docs:
                d.metadata["subject"] = subject
                d.metadata["filename"] = filename
                d.metadata["category"] = category
                if year:
                    d.metadata["year"] = year
            return file_docs
        except Exception:
            logger.exception("Failed to load %s", filename)
            return []

    # root files
    for file in os.listdir(base_dir):
        path = os.path.join(base_dir, file)
        if os.path.isfile(path) and (file.lower().endswith(".pdf") or file.lower().endswith(".pptx")):
            docs.extend(load_file(path, file, "syllabus"))

    # optional pyqs directory
    pyqs_dir = os.path.join(base_dir, "pyqs")
    if os.path.isdir(pyqs_dir):
        for file in os.listdir(pyqs_dir):
            path = os.path.join(pyqs_dir, file)
            if os.path.isfile(path) and file.lower().endswith(".pdf"):
                docs.extend(load_file(path, file, "pyq"))

    logger.info("Loaded %d raw document pages", len(docs))
    return docs


# ---------- Build / load FAISS + BM25 ----------
def build_or_load_indexes(force_reindex: bool = False):
    if os.getenv("FORCE_REINDEX", "0").lower() in ("1", "true", "yes"):
        force_reindex = True

    docs = load_all_docs(DATA_DIR)
    if not docs:
        logger.warning("No documents found. Returning empty indexes.")
        return [], None, [], [], None

    # chunking
    if os.path.exists(CHUNKS_CACHE) and not force_reindex:
        with open(CHUNKS_CACHE, "rb") as f:
            chunks = pickle.load(f)
        logger.info("Loaded %d chunks from cache.", len(chunks))
    else:
        splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
        chunks = splitter.split_documents(docs)
        with open(CHUNKS_CACHE, "wb") as f:
            pickle.dump(chunks, f)
        logger.info("Created and cached %d chunks.", len(chunks))

    corpus_texts = [c.page_content for c in chunks]

    # BM25
    if os.path.exists(BM25_CACHE) and not force_reindex:
        try:
            with open(BM25_CACHE, "rb") as f:
                bm25_data = pickle.load(f)
            bm25 = bm25_data.get("bm25")
            tokenized = bm25_data.get("tokenized", [])
            logger.info("Loaded BM25 from cache (n=%d)", len(corpus_texts))
        except Exception:
            logger.exception("Failed to load BM25 cache — rebuilding")
            tokenized = [re.findall(r"\w+", t.lower()) for t in corpus_texts]
            bm25 = BM25Okapi(tokenized)
            with open(BM25_CACHE, "wb") as f:
                pickle.dump({"bm25": bm25, "tokenized": tokenized}, f)
    else:
        tokenized = [re.findall(r"\w+", t.lower()) for t in corpus_texts]
        bm25 = BM25Okapi(tokenized)
        try:
            with open(BM25_CACHE, "wb") as f:
                pickle.dump({"bm25": bm25, "tokenized": tokenized}, f)
        except Exception:
            logger.warning("Could not write BM25 cache")

    # Embeddings
    embeddings = Embeddings()

    metadatas = [c.metadata for c in chunks]

    # load existing faiss index
    if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(FAISS_META_PATH) and not force_reindex:
        try:
            index = faiss.read_index(FAISS_INDEX_PATH)
            with open(FAISS_META_PATH, "rb") as f:
                meta = pickle.load(f)
            texts = meta.get("texts", corpus_texts)
            try:
                index.nprobe = FAISS_NPROBE
            except Exception:
                pass
            logger.info("Loaded FAISS index from disk (%s), entries=%d", FAISS_INDEX_PATH, len(texts))
            return chunks, bm25, tokenized, corpus_texts, {"index": index, "texts": texts, "metadatas": metadatas, "embeddings": embeddings}
        except Exception:
            logger.exception("Failed to load FAISS index; rebuilding")

    # force reindex cleanup
    if force_reindex:
        try:
            shutil.rmtree(FAISS_DIR, ignore_errors=True)
            os.makedirs(FAISS_DIR, exist_ok=True)
        except Exception:
            pass

    # Build FAISS (memory-aware, batch)
    logger.info("Building FAISS index (nlist=%d). This may take a while...", FAISS_NLIST)
    total = len(corpus_texts)
    sample_size = min(total, FAISS_TRAIN_SIZE)
    sample_indices = random.sample(range(total), sample_size) if sample_size < total else list(range(total))

    sample_embs = []
    for i in range(0, len(sample_indices), BATCH_SIZE):
        batch_idx = sample_indices[i:i + BATCH_SIZE]
        batch_texts = [corpus_texts[j] for j in batch_idx]
        try:
            batch_vecs = embeddings.embed_documents(batch_texts)
        except Exception:
            batch_vecs = [embeddings.embed_query(t) for t in batch_texts]
        sample_embs.extend(batch_vecs)

    sample_np = np.array(sample_embs, dtype="float32")
    if sample_np.ndim == 1:
        sample_np = sample_np.reshape(1, -1)
    d = sample_np.shape[1]
    n_train_samples = sample_np.shape[0]

    use_ivf = True
    if n_train_samples < FAISS_NLIST:
        logger.warning("Not enough training samples (%d) for FAISS_NLIST=%d — using Flat index", n_train_samples, FAISS_NLIST)
        use_ivf = False

    try:
        if use_ivf:
            index_desc = f"IVF{FAISS_NLIST},Flat"
            index = faiss.index_factory(d, index_desc, faiss.METRIC_L2)
            if not index.is_trained:
                try:
                    index.train(sample_np)
                    logger.info("Trained IVF on %d samples", n_train_samples)
                except Exception:
                    logger.exception("IVF training failed — falling back to Flat")
                    index = faiss.index_factory(d, "Flat", faiss.METRIC_L2)
        else:
            index = faiss.index_factory(d, "Flat", faiss.METRIC_L2)
    except Exception:
        logger.exception("Failed to create FAISS index — using Flat")
        index = faiss.index_factory(d, "Flat", faiss.METRIC_L2)

    # add vectors in batches
    added = 0
    for i in range(0, total, BATCH_SIZE):
        batch_texts = corpus_texts[i:i + BATCH_SIZE]
        try:
            batch_vecs = embeddings.embed_documents(batch_texts)
        except Exception:
            batch_vecs = [embeddings.embed_query(t) for t in batch_texts]
        batch_np = np.array(batch_vecs, dtype="float32")
        if batch_np.ndim == 1:
            batch_np = batch_np.reshape(1, -1)
        index.add(batch_np)
        added += batch_np.shape[0]
        logger.info("FAISS: added %d / %d vectors", added, total)

    try:
        index.nprobe = FAISS_NPROBE
    except Exception:
        pass

    try:
        faiss.write_index(index, FAISS_INDEX_PATH)
        with open(FAISS_META_PATH, "wb") as f:
            pickle.dump({"texts": corpus_texts}, f)
        logger.info("FAISS index saved to %s (entries=%d)", FAISS_INDEX_PATH, total)
    except Exception:
        logger.exception("Failed to persist FAISS index on disk")

    return chunks, bm25, tokenized, corpus_texts, {"index": index, "texts": corpus_texts, "metadatas": metadatas, "embeddings": embeddings}


# ---------- Hybrid retrieve ----------
def _ensure_index_built():
    if not hasattr(hybrid_retrieve, "_index_built") or not hybrid_retrieve._index_built:
        hybrid_retrieve._chunks, hybrid_retrieve._bm25, hybrid_retrieve._tokenized, hybrid_retrieve._corpus, hybrid_retrieve._faiss = build_or_load_indexes()
        hybrid_retrieve._index_built = True


def _faiss_search(query: str, top_k: int = TOP_K_DOCS, subject: Optional[str] = None):
    faiss_data = hybrid_retrieve._faiss
    if not faiss_data:
        return []

    index = faiss_data.get("index")
    texts = faiss_data.get("texts", [])
    metadatas = faiss_data.get("metadatas", [{}] * len(texts))
    embeddings = faiss_data.get("embeddings")

    try:
        q_vec = embeddings.embed_query(query)
    except Exception:
        q_vec = embeddings.embed_documents([query])[0]

    q_np = np.array(q_vec, dtype="float32").reshape(1, -1)
    search_k = max(top_k * SEARCH_EXPANSION, top_k)
    try:
        distances, indices = index.search(q_np, int(search_k))
    except Exception:
        distances, indices = index.search(q_np, int(top_k))

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if idx < 0 or idx >= len(texts):
            continue
        meta = metadatas[idx]
        if subject and meta.get("subject") != subject:
            continue
        score_like = float(-dist)
        results.append((score_like, meta, texts[idx]))
        if len(results) >= top_k:
            break

    return results


def hybrid_retrieve(query: str, subject: Optional[str] = None, top_k: int = TOP_K_DOCS, max_chars: int = MAX_CONTEXT_CHARS) -> Dict[str, Any]:
    if not query:
        return {"context": None, "bm25_docs": [], "faiss_docs": [], "meta": []}

    _ensure_index_built()

    chunks = hybrid_retrieve._chunks
    bm25 = hybrid_retrieve._bm25

    # BM25
    results_bm25 = []
    try:
        if bm25:
            q_tokens = re.findall(r"\w+", query.lower())
            scores = bm25.get_scores(q_tokens)
            ranked_idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
            for i in ranked_idx:
                results_bm25.append((float(scores[i]), chunks[i].metadata, chunks[i].page_content))
    except Exception:
        logger.exception("BM25 search failed")

    # FAISS
    results_faiss = []
    try:
        results_faiss = _faiss_search(query, top_k=top_k, subject=subject)
    except Exception:
        logger.exception("FAISS search failed")

    # Merge and dedupe by text
    merged_texts = []
    merged_meta = []
    for score, meta, text in results_bm25:
        if text and text.strip() and text not in merged_texts:
            merged_texts.append(text)
            merged_meta.append({"source": meta.get("filename"), "subject": meta.get("subject"), "score": score})
    for score, meta, text in results_faiss:
        if text and text.strip() and text not in merged_texts:
            merged_texts.append(text)
            merged_meta.append({"source": meta.get("filename") if isinstance(meta, dict) else None, "subject": meta.get("subject") if isinstance(meta, dict) else None, "score": score})

    # compose context parts with headers
    context_parts = []
    for i, t in enumerate(merged_texts):
        header = f"\n\n===== DOC {i+1} =====\n"
        context_parts.append(header + t)
    context = "\n".join(context_parts).strip()
    if not context:
        return {"context": None, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta}

    if len(context) > max_chars:
        context = context[:max_chars].rstrip() + "..."

    return {"context": context, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta}


# ---------- retrieve_node (for reuse) ----------
def _last_n_user_messages(rows: List[tuple], n: int = 3) -> List[str]:
    """Return the last `n` user messages (chronological) for retrieval context."""
    users = [r[1] for r in rows if r[0] == "user"]
    return users[-n:]


def retrieve_node_from_rows(rows: List[tuple], top_k: int = TOP_K_DOCS) -> Dict[str, Any]:
    last_users = _last_n_user_messages(rows, n=3)
    current_query = " ".join(last_users).strip() if last_users else ""
    if not current_query:
        return {"context": None, "direct": False}
    detected = None
    try:
        detected = detect_subject(current_query)
    except Exception:
        detected = None
    result = hybrid_retrieve(current_query, subject=detected, top_k=top_k, max_chars=MAX_CONTEXT_CHARS)
    return {"context": result.get("context"), "direct": False}
main_api.py ADDED
@@ -0,0 +1,309 @@
# main_api.py
import os
import logging
import traceback
from typing import Optional, List, Dict, Any
import tiktoken

from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn

# import the existing modules (assumed in same directory)
from memory_store import init_db, save_message, get_last_messages, clear_user_memory, build_gradio_history
from chatbot_retriever import build_or_load_indexes, hybrid_retrieve, retrieve_node_from_rows, load_all_docs
from chatbot_graph import SYSTEM_PROMPT, call_llm, _extract_answer_from_response

# ----------------- CORS SETUP -----------------
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI(title="RAG Chat Backend", version="1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ------------------------------------------------

from dotenv import load_dotenv
load_dotenv()

logger = logging.getLogger("rag_api")
logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO)

# initialize DB now
init_db()

# Global in-memory flag/object to check indexes loaded (populated by build_or_load_indexes)
INDEXES = {"built": False, "info": None}


# ---------- Pydantic models ----------
class ChatRequest(BaseModel):
    user_id: Optional[str] = None
    message: str


class ChatResponse(BaseModel):
    user_id: str
    message: str
    assistant: str
    history: List[Dict[str, str]]


class RetrieveResponse(BaseModel):
    query: str
    context: Optional[str]
    meta: List[Dict[str, Any]]


# ---------- helpers ----------
def ensure_indexes(force_reindex: bool = False):
    """
    Build or load indexes synchronously. This wraps build_or_load_indexes from chatbot_retriever.
    """
    if INDEXES["built"] and not force_reindex:
        return INDEXES["info"]
    try:
        chunks, bm25, tokenized, corpus_texts, faiss_data = build_or_load_indexes(force_reindex=force_reindex)
        INDEXES["built"] = True
        INDEXES["info"] = {"chunks_len": len(chunks) if chunks else 0, "corpus_len": len(corpus_texts) if corpus_texts else 0}
        return INDEXES["info"]
    except Exception:
        logger.exception("Index build/load failed")
        raise


# ===== Token limiter helper =====
enc = tiktoken.get_encoding("cl100k_base")

def trim_to_token_limit(texts, limit=4000):
    """Join text chunks until the token limit is reached."""
    joined = ""
    for t in texts:
        if len(enc.encode(joined + t)) > limit:
            break
        joined += t + "\n"
    return joined


def extract_history_for_frontend(user_id: str, limit: int = 500):
    return build_gradio_history(user_id)


# ---------- Routes ----------
@app.get("/health")
def health():
    """Basic health check."""
    return {"status": "ok", "indexes_built": INDEXES["built"]}


@app.post("/reindex")
def reindex(force: Optional[bool] = False):
    """
    Force rebuild of indexes. This calls the same build_or_load_indexes used by the retriever module.
    Use ?force=true to force.
    """
    try:
        info = ensure_indexes(force_reindex=bool(force))
        return {"status": "ok", "info": info}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to build indexes: {e}")


@app.post("/upload")
async def upload_file(file: UploadFile = File(...), category: Optional[str] = Form("syllabus")):
    """
    Upload PDF/PPTX into DATA_DIR (same dir used by chatbot_retriever.load_all_docs).
    After upload you may call /reindex to include the file.
    """
    from chatbot_retriever import DATA_DIR  # keep using same constant
    os.makedirs(DATA_DIR, exist_ok=True)
    dest_path = os.path.join(DATA_DIR, file.filename)
    try:
        with open(dest_path, "wb") as f:
            content = await file.read()
            f.write(content)
        return {"status": "ok", "filename": file.filename, "saved_to": dest_path}
    except Exception as e:
        logger.exception("upload failed")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/docs_list")
def docs_list():
    """List files in DATA_DIR (documents available to retriever)."""
    from chatbot_retriever import DATA_DIR
    if not os.path.isdir(DATA_DIR):
        return {"files": []}
    files = [f for f in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, f))]
    return {"files": files}


@app.get("/retrieve", response_model=RetrieveResponse)
def retrieve(query: str, subject: Optional[str] = None, top_k: Optional[int] = None):
    """
    Directly call the hybrid retriever for a query. Returns context + meta.
    """
    try:
        # ensure indexes built (but don't force)
        ensure_indexes(force_reindex=False)
        # only override top_k when the caller supplied one; otherwise keep the retriever's default
        if top_k:
            res = hybrid_retrieve(query=query, subject=subject, top_k=int(top_k))
        else:
            res = hybrid_retrieve(query=query, subject=subject)
        return {"query": query, "context": res.get("context"), "meta": res.get("meta", [])}
    except Exception as e:
        logger.exception("retrieve failed")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/history/{user_id}")
def get_history(user_id: str, limit: Optional[int] = 500):
    """Return persisted history for a user (in the same format the frontend expects)."""
    try:
        hist = extract_history_for_frontend(user_id)
        if limit:
            hist = hist[-int(limit):]
        return {"user_id": user_id, "history": hist}
    except Exception as e:
        logger.exception("history fetch failed")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/memory/clear")
def clear_memory(user_id: str):
    """Clear stored memory for user."""
    try:
        deleted = clear_user_memory(user_id)
        return {"status": "ok", "deleted_rows": deleted}
    except Exception as e:
        logger.exception("clear failed")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """
    Main chat endpoint.
    - saves user message
    - fetches last messages from sqlite memory
    - runs retriever to get context
    - builds the system prompt + recent conversation
    - calls the LLM via call_llm (same wrapper imported from chatbot_graph)
    - saves assistant reply and returns it + updated history
    """
    uid = (req.user_id or os.getenv("DEFAULT_USER", "vinayak")).strip() or "vinayak"
    if not req.message:
        raise HTTPException(status_code=400, detail="message is required")

    try:
        # 1) persist user message
        save_message(uid, "user", req.message)

        # 2) get rows (chronological order) for retriever
        rows = get_last_messages(uid, limit=200)

        # 3) ensure indexes exist (non-force)
        try:
            ensure_indexes(force_reindex=False)
        except Exception:
            logger.warning("Indexes not built or failed. retriever may return no context.")

        # 4) run retrieve_node_from_rows to get context (same glue logic as the Gradio app)
        try:
            retrieved = retrieve_node_from_rows(rows)
            context = retrieved.get("context")
        except Exception:
            logger.exception("retriever call failed")
            context = None

        # 5) build system prompt content
        # ===== Combine retrieval context + recent conversation =====
        MAX_TOKENS_CONTEXT = 3000
        NUM_RECENT_TURNS = 2  # last 2 user + assistant pairs

        # Get last few messages (both user + assistant)
        recent_pairs = rows[-(NUM_RECENT_TURNS * 2):]
        recent_chat = "\n".join([f"{r[0].upper()}: {r[1]}" for r in recent_pairs])

        # Trim context to token-safe limit
        context_texts = context.split("\n\n") if context else []
        trimmed_context = trim_to_token_limit(context_texts, limit=MAX_TOKENS_CONTEXT)

        # Final system prompt
        system_content = SYSTEM_PROMPT
        if trimmed_context:
            system_content += "\n\n===== RETRIEVED CONTEXT =====\n" + trimmed_context

        # Always include recent conversation (to maintain chat flow)
        system_content += "\n\n===== RECENT CHAT =====\n" + recent_chat

        # build prompt messages; call_llm expects langchain message objects,
        # so reuse the same SystemMessage/HumanMessage classes as chatbot_graph
        from langchain_core.messages import SystemMessage, HumanMessage
        prompt_msgs = [SystemMessage(content=system_content)]

        # collect the most recent user message
        last_users = [r[1] for r in rows if r[0] == "user"][-1:]
        if not last_users:
            last_users = [req.message]
        for u in last_users:
            prompt_msgs.append(HumanMessage(content=u))

        # 6) call LLM
        try:
            raw = call_llm(prompt_msgs)
            answer = _extract_answer_from_response(raw) or ""
        except Exception as e:
            logger.exception("LLM call failed")
            # If LLM client not configured (ChatGroq missing or no API KEY), return helpful message
            detail = str(e)
            answer = f"LLM call failed: {detail}"

        # 7) persist assistant reply
        try:
            save_message(uid, "assistant", answer)
        except Exception:
            logger.exception("Failed to persist assistant message")

        # 8) build history to return
        history = extract_history_for_frontend(uid)
        return {
            "user_id": uid,
            "message": req.message,
            "assistant": answer,
            "history": history,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("chat failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


# Mount static files for frontend
FRONTEND_DIR = os.path.join(os.path.dirname(__file__), "frontend", "dist")
if os.path.exists(FRONTEND_DIR):
    app.mount("/assets", StaticFiles(directory=os.path.join(FRONTEND_DIR, "assets")), name="assets")

    @app.get("/{full_path:path}")
    async def serve_frontend(full_path: str):
        """Serve the React frontend for all non-API routes"""
        if full_path and not full_path.startswith("api"):
            file_path = os.path.join(FRONTEND_DIR, full_path)
            if os.path.exists(file_path) and os.path.isfile(file_path):
                return FileResponse(file_path)
        return FileResponse(os.path.join(FRONTEND_DIR, "index.html"))


# Run with: uvicorn main_api:app --reload --host 127.0.0.1 --port 8000
if __name__ == "__main__":
    uvicorn.run("main_api:app", host="127.0.0.1", port=8000, reload=True)
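
Once the backend is running (uvicorn main_api:app --reload), the endpoints can be exercised with a small client script. The sketch below uses the requests package (not listed in requirements.txt) and placeholder user_id/query values.

# client_example.py — illustrative calls against the local FastAPI backend
import requests

BASE = "http://127.0.0.1:8000"

print(requests.get(f"{BASE}/health").json())          # {"status": "ok", "indexes_built": ...}

resp = requests.post(f"{BASE}/chat", json={"user_id": "demo", "message": "Summarise TCP vs UDP."})
data = resp.json()
print(data["assistant"])                               # LLM answer
print(len(data["history"]), "messages in history")     # persisted conversation

# retrieval only, no LLM call
r = requests.get(f"{BASE}/retrieve", params={"query": "congestion control", "top_k": 3})
print(r.json()["context"])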
memory_store.py ADDED
@@ -0,0 +1,110 @@
# memory_store.py
import sqlite3
import os
import logging
from typing import List, Tuple

DB_PATH = os.getenv("MEMORY_DB", "chat_memory.db")
MAX_MESSAGES_PER_USER = int(os.getenv("MAX_MESSAGES_PER_USER", 500))

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _get_conn():
    # check_same_thread=False so Gradio threads can use the DB concurrently
    return sqlite3.connect(DB_PATH, timeout=10, check_same_thread=False)


def init_db():
    conn = _get_conn()
    try:
        with conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS memory (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id TEXT,
                    role TEXT,
                    message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
    finally:
        conn.close()


def save_message(user_id: str, role: str, message: str) -> None:
    if not user_id:
        raise ValueError("user_id is required")
    conn = _get_conn()
    try:
        with conn:
            conn.execute(
                "INSERT INTO memory (user_id, role, message) VALUES (?, ?, ?)",
                (user_id, role, message),
            )
            # prune if too many
            if MAX_MESSAGES_PER_USER and MAX_MESSAGES_PER_USER > 0:
                cur = conn.execute(
                    "SELECT id FROM memory WHERE user_id = ? ORDER BY id DESC",
                    (user_id,),
                )
                rows = cur.fetchall()
                if len(rows) > MAX_MESSAGES_PER_USER:
                    ids_to_delete = [r[0] for r in rows[MAX_MESSAGES_PER_USER:]]
                    conn.executemany("DELETE FROM memory WHERE id = ?", [(i,) for i in ids_to_delete])
    except Exception:
        logger.exception("Failed to save message for user %s", user_id)
        raise
    finally:
        conn.close()


def get_last_messages(user_id: str, limit: int = 200) -> List[Tuple[str, str, str]]:
    """
    Return last `limit` messages in chronological order as (role, message, created_at)
    """
    conn = _get_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT role, message, created_at FROM memory
            WHERE user_id = ?
            ORDER BY id DESC
            LIMIT ?
            """,
            (user_id, limit),
        )
        rows = cur.fetchall()
        return list(reversed(rows))
    except Exception:
        logger.exception("Failed to fetch messages for user %s", user_id)
        return []
    finally:
        conn.close()


def clear_user_memory(user_id: str) -> int:
    """Delete memory for user. Returns deleted rowcount."""
    conn = _get_conn()
    try:
        with conn:
            cur = conn.execute("DELETE FROM memory WHERE user_id = ?", (user_id,))
            return cur.rowcount
    except Exception:
        logger.exception("Failed to clear memory for user %s", user_id)
        raise
    finally:
        conn.close()


def build_gradio_history(user_id: str) -> List[dict]:
    """
    Return history formatted for gr.Chatbot with type='messages':
    a chronological list of dicts: {'role': 'user'|'assistant', 'content': '...'}
    """
    rows = get_last_messages(user_id, limit=500)
    return [{"role": r[0], "content": r[1]} for r in rows]
requirements.txt ADDED
@@ -0,0 +1,10 @@
langchain
langchain-community
langchain-groq
sentence-transformers
faiss-cpu
pypdf
unstructured
python-dotenv
gradio
sqlite3-binary
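
Note: the added modules also import fastapi, uvicorn, rank_bm25 (PyPI: rank-bm25), tiktoken, and numpy, which are not listed here, while sqlite3 ships with the Python standard library, so the sqlite3-binary entry is likely unnecessary. Reconciling the list with the actual imports would probably belong in a follow-up commit.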