Rajan Sharma committed on
Commit 5543737 · verified · 1 Parent(s): 023cf3a

Update session_rag.py

Files changed (1)
  1. session_rag.py +36 -68
session_rag.py CHANGED
@@ -1,49 +1,29 @@
-"""
-Session-level RAG with graceful FAISS fallback.
-
-- If FAISS is installed, uses a FAISS L2 index over normalized embeddings.
-- If FAISS is missing, falls back to pure NumPy cosine similarity.
-- Designed to work with extract_text_from_files(...) outputs:
-    * list[str]
-    * list[dict] with keys like "text" or "content"
-"""
 
+# session_rag.py
 from __future__ import annotations
-
-import logging
-import hashlib
-from typing import Iterable, List, Optional, Tuple
-
+import logging, hashlib
+from typing import Iterable, List, Optional, Dict, Any
 import numpy as np
 from sentence_transformers import SentenceTransformer
 
-# ----- Optional FAISS -----
 try:
     import faiss  # type: ignore
     _HAS_FAISS = True
 except Exception:
-    logging.warning(
-        "FAISS not installed — session RAG will use a NumPy cosine-similarity fallback. "
-        "Install faiss-cpu or faiss-gpu for faster retrieval."
-    )
+    logging.warning("FAISS not installed — using NumPy cosine fallback.")
     faiss = None  # type: ignore
     _HAS_FAISS = False
 
-
 def _normalize_rows(x: np.ndarray) -> np.ndarray:
-    """L2 normalize row vectors; avoids division by zero."""
     norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-10
     return x / norms
 
-
 def _hash_text(s: str) -> str:
     return hashlib.sha256(s.encode("utf-8")).hexdigest()
 
-
 def _coerce_texts(items: Iterable) -> List[str]:
-    """Accept str or dict items, pull text safely, drop empties, dedupe by hash."""
     out: List[str] = []
-    seen: set = set()
+    seen = set()
     for it in items or []:
         if isinstance(it, str):
             txt = it.strip()
@@ -60,45 +40,40 @@ def _coerce_texts(items: Iterable) -> List[str]:
         out.append(txt)
     return out
 
-
 def _simple_chunk(text: str, max_chars: int = 1200, overlap: int = 150) -> List[str]:
-    """Lightweight char-based chunking to improve recall on long docs."""
     if len(text) <= max_chars:
         return [text]
     chunks = []
     i = 0
     while i < len(text):
-        chunk = text[i : i + max_chars]
-        chunks.append(chunk)
+        chunks.append(text[i : i + max_chars])
         i += max_chars - overlap
     return chunks
 
-
 class SessionRAG:
     """
-    Ephemeral per-session retriever.
-
-    Methods:
-      - add_docs(items): add strings or dicts({"text"/"content": ...})
-      - retrieve(query, k=5): returns list[str] of top-k chunks
-      - clear(): drop index & memory
+    Ephemeral per-session retriever with artifact registry.
+
+    Public:
+      - add_docs(items)
+      - register_artifacts(arts)
+      - retrieve(query, k=5)
+      - get_latest_csv_columns()
+      - clear()
     """
-
     def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
         self.model = SentenceTransformer(model_name)
         self.texts: List[str] = []
-        self.embeddings: Optional[np.ndarray] = None  # shape: (N, D)
-        self.index = None  # FAISS index if available
+        self.embeddings: Optional[np.ndarray] = None
+        self.index = None
         self.dim: Optional[int] = None
+        self.artifacts: List[Dict[str, Any]] = []  # keeps structured info per upload
 
-    # ---------- Private helpers ----------
     def _fit_faiss(self) -> None:
         if not _HAS_FAISS or self.embeddings is None:
             return
-        # Use inner product on normalized vectors (cosine similarity)
         emb = _normalize_rows(self.embeddings.astype("float32"))
         self.dim = emb.shape[1]
-        # Build IP index
        self.index = faiss.IndexFlatIP(self.dim)
         self.index.add(emb)
 
@@ -107,33 +82,21 @@ class SessionRAG:
             self.embeddings = None
             self.index = None
             return
-        # Compute embeddings
         embs = self.model.encode(self.texts, batch_size=64, show_progress_bar=False)
         self.embeddings = np.asarray(embs, dtype="float32")
-        # Build FAISS if available
         if _HAS_FAISS:
             self._fit_faiss()
         else:
             self.index = None
 
-    # ---------- Public API ----------
     def add_docs(self, items: Iterable) -> int:
-        """
-        Add a batch of texts or dicts with 'text'/'content'.
-        Applies basic chunking and deduplication.
-        Returns the number of chunks added.
-        """
         raw_texts = _coerce_texts(items)
         if not raw_texts:
             return 0
-
-        # Chunk each long text into manageable pieces
         chunks: List[str] = []
         for t in raw_texts:
             chunks.extend(_simple_chunk(t))
-
-        # Deduplicate vs existing memory
-        existing_hashes = { _hash_text(t) for t in self.texts }
+        existing_hashes = {_hash_text(t) for t in self.texts}
         added = 0
         for c in chunks:
             h = _hash_text(c)
@@ -142,40 +105,45 @@ class SessionRAG:
             self.texts.append(c)
             existing_hashes.add(h)
             added += 1
-
-        # Recompute embeddings/index
         if added > 0:
             self._ensure_embeddings()
-
         return added
 
+    def register_artifacts(self, arts: Iterable[Dict[str, Any]]) -> int:
+        count = 0
+        for a in (arts or []):
+            if isinstance(a, dict):
+                self.artifacts.append(a)
+                count += 1
+        return count
+
     def retrieve(self, query: str, k: int = 5) -> List[str]:
-        """Return up to k most similar chunks for the query."""
         if not query or not self.texts:
             return []
-
-        # Encode query, normalize
         q_emb = self.model.encode([query], show_progress_bar=False)
         q = _normalize_rows(np.asarray(q_emb, dtype="float32"))
-
         if self.embeddings is None:
             return []
-
-        # FAISS path (inner product on normalized vectors)
         if _HAS_FAISS and self.index is not None:
             D, I = self.index.search(q, min(k, len(self.texts)))
             idxs = [i for i in I[0] if 0 <= i < len(self.texts)]
             return [self.texts[i] for i in idxs]
-
-        # NumPy fallback: cosine similarity via dot product on normalized vectors
         docs = _normalize_rows(self.embeddings)
-        sims = (q @ docs.T)[0]  # shape: (N,)
+        sims = (q @ docs.T)[0]
         top_idx = np.argsort(-sims)[: min(k, len(self.texts))]
         return [self.texts[i] for i in top_idx]
 
+    # ---------- helpers for structured Qs ----------
+    def get_latest_csv_columns(self) -> List[str]:
+        # scan artifacts in reverse insertion order
+        for a in reversed(self.artifacts):
+            if a.get("kind") == "csv" and a.get("columns"):
+                return list(map(str, a["columns"]))
+        return []
+
     def clear(self) -> None:
-        """Drop all in-memory data for this session."""
         self.texts = []
         self.embeddings = None
         self.index = None
         self.dim = None
+        self.artifacts = []
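
For reviewers, a minimal usage sketch of the API this commit introduces (register_artifacts and get_latest_csv_columns). The artifact dict shape with "kind" and "columns" keys follows what get_latest_csv_columns checks in the diff above; the file name, column names, and sample texts are illustrative only, not part of the commit.

# Illustrative usage of the updated SessionRAG; sample data is hypothetical.
from session_rag import SessionRAG

rag = SessionRAG()

# Index extracted text (plain strings or dicts with "text"/"content");
# add_docs() chunks long texts and deduplicates by hash internally.
added = rag.add_docs([
    "Quarterly revenue grew 12% year over year...",
    {"text": "Churn fell to 3.1% after the pricing change."},
])
print(f"chunks indexed: {added}")

# Register structured metadata for an uploaded CSV; "kind" and "columns"
# are the keys get_latest_csv_columns() looks for.
rag.register_artifacts([
    {"kind": "csv", "name": "sales.csv", "columns": ["date", "region", "revenue"]},
])

print(rag.retrieve("How did revenue change?", k=2))
print(rag.get_latest_csv_columns())  # -> ['date', 'region', 'revenue']

rag.clear()  # drops texts, embeddings, index, and artifacts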