Polarisailabs committed
Commit 01ae44e (verified) · 1 Parent(s): 48df24b

Upload 4 files

Files changed (5)
  1. .gitattributes +2 -0
  2. IPC.pdf +3 -0
  3. app.py +413 -0
  4. docs/.DS_Store +0 -0
  5. docs/IPC.pdf +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ docs/IPC.pdf filter=lfs diff=lfs merge=lfs -text
+ IPC.pdf filter=lfs diff=lfs merge=lfs -text
IPC.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:038c736730c09d5b72b1642ab8056607ca546c0b87631811da1a30accd08f81d
+ size 1529218
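Both IPC.pdf entries are Git LFS pointer files rather than the PDF bytes themselves: the ~1.5 MB binary lives in LFS storage, addressed by the sha256 oid above. As an illustration only (this helper is not part of the commit), a pointer like this parses into three key/value fields:

# Hypothetical helper, not in the repo: parse a Git LFS pointer file.
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value  # e.g. "oid" -> "sha256:038c7367..."
    return fields  # expected keys: "version", "oid", "size"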
app.py ADDED
@@ -0,0 +1,413 @@
+ import os
+ import io
+ import json
+ import pathlib
+ import shutil
+ from typing import List, Tuple, Dict, Optional
+
+ import gradio as gr
+ import numpy as np
+ import faiss
+ from sentence_transformers import SentenceTransformer
+ from pypdf import PdfReader
+ import fitz  # PyMuPDF
+ from collections import defaultdict
+ from openai import OpenAI
+
+ # =========================
+ # LLM Endpoint
+ # =========================
+ API_KEY = os.environ.get("API_KEY")
+ if not API_KEY:
+     raise RuntimeError("Missing API_KEY (set it in Hugging Face: Settings → Variables and secrets).")
+
+ client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY)
+
+ # Model configuration
+ # The model is hardcoded to "deepseek/deepseek-r1:free" as requested.
+ # The previous default was "Deepseek".
+ SINGLE_MODEL_NAME = "deepseek/deepseek-r1:free"
+
+ GEN_TEMPERATURE = 0.2
+ GEN_TOP_P = 0.95
+ GEN_MAX_TOKENS = 1024
+ EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
+
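+ # Pick a writable store directory: prefer the persistent /data volume
+ # (present when the Space has persistent storage enabled), else ./store.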
+ def choose_store_dir() -> Tuple[str, bool]:
+     data_root = "/data"
+     if os.path.isdir(data_root) and os.access(data_root, os.W_OK):
+         d = os.path.join(data_root, "rag_store")
+         try:
+             os.makedirs(d, exist_ok=True)
+             testf = os.path.join(d, ".write_test")
+             with open(testf, "w", encoding="utf-8") as f:
+                 f.write("ok")
+             os.remove(testf)
+             return d, True
+         except Exception:
+             pass
+     d = os.path.join(os.getcwd(), "store")
+     os.makedirs(d, exist_ok=True)
+     return d, False
+
+ STORE_DIR, IS_PERSISTENT = choose_store_dir()
+ META_PATH = os.path.join(STORE_DIR, "meta.json")
+ INDEX_PATH = os.path.join(STORE_DIR, "faiss.index")
+ LEGACY_STORE_DIR = os.path.join(os.getcwd(), "store")
+
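+ # One-time migration: if an index exists only in the legacy ./store
+ # location, copy it into the persistent store directory.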
+ def migrate_legacy_if_any():
+     try:
+         if IS_PERSISTENT:
+             legacy_meta = os.path.join(LEGACY_STORE_DIR, "meta.json")
+             legacy_index = os.path.join(LEGACY_STORE_DIR, "faiss.index")
+             if (not os.path.exists(META_PATH) or not os.path.exists(INDEX_PATH)) \
+                     and os.path.isdir(LEGACY_STORE_DIR) \
+                     and os.path.exists(legacy_meta) and os.path.exists(legacy_index):
+                 shutil.copyfile(legacy_meta, META_PATH)
+                 shutil.copyfile(legacy_index, INDEX_PATH)
+     except Exception:
+         pass
+
+ migrate_legacy_if_any()
+
+ _emb_model = None
+ _index: Optional[faiss.Index] = None
+ _meta: Dict[str, Dict] = {}
+
+ DEFAULT_TOP_K = 6
+ DEFAULT_POOL_K = 40
+ DEFAULT_PER_SOURCE_CAP = 2
+ DEFAULT_STRATEGY = "mmr"
+ DEFAULT_MMR_LAMBDA = 0.5
+
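+ # Load the SentenceTransformer lazily on first use to keep startup fast.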
+ def get_emb_model():
+     global _emb_model
+     if _emb_model is None:
+         _emb_model = SentenceTransformer(EMB_MODEL_NAME)
+     return _emb_model
+
+ def _ensure_index(dim: int):
+     global _index
+     if _index is None:
+         _index = faiss.IndexFlatIP(dim)
+
+ def _persist():
+     faiss.write_index(_index, INDEX_PATH)
+     with open(META_PATH, "w", encoding="utf-8") as f:
+         json.dump(_meta, f, ensure_ascii=False)
+
+ def _load_if_any():
+     global _index, _meta
+     if os.path.exists(INDEX_PATH) and os.path.exists(META_PATH):
+         _index = faiss.read_index(INDEX_PATH)
+         with open(META_PATH, "r", encoding="utf-8") as f:
+             _meta = json.load(f)
+
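+ # Greedy fixed-size chunking: ~800 characters per chunk with a
+ # 120-character overlap so text spanning a boundary appears in both.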
+ def _chunk_text(text: str, chunk_size: int = 800, overlap: int = 120) -> List[str]:
+     text = text.replace("\u0000", "")
+     res, i, n = [], 0, len(text)
+     while i < n:
+         j = min(i + chunk_size, n)
+         seg = text[i:j].strip()
+         if seg:
+             res.append(seg)
+         i = max(0, j - overlap)
+         if j >= n:
+             break
+     return res
+
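+ # Normalize the shapes gr.File may return (a dict with path/data, a
+ # filesystem path, or a file-like object) into raw bytes.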
+ def _read_bytes(file) -> bytes:
+     if isinstance(file, dict):
+         p = file.get("path") or file.get("name")
+         if p and os.path.exists(p):
+             with open(p, "rb") as f:
+                 return f.read()
+         if "data" in file and isinstance(file["data"], (bytes, bytearray)):
+             return bytes(file["data"])
+     if isinstance(file, (str, pathlib.Path)):
+         with open(file, "rb") as f:
+             return f.read()
+     if hasattr(file, "read"):
+         try:
+             if hasattr(file, "seek"):
+                 try:
+                     file.seek(0)
+                 except Exception:
+                     pass
+             return file.read()
+         finally:
+             try:
+                 file.close()
+             except Exception:
+                 pass
+     raise ValueError("Unsupported file type from gr.File")
+
+ def _decode_best_effort(raw: bytes) -> str:
+     for enc in ["utf-8", "cp932", "shift_jis", "cp950", "big5", "gb18030", "latin-1"]:
+         try:
+             return raw.decode(enc)
+         except Exception:
+             continue
+     return raw.decode("utf-8", errors="ignore")
+
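+ # PDF text extraction: try PyMuPDF first, fall back to pypdf if it
+ # fails or returns only whitespace.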
+ def _read_pdf(file_bytes: bytes) -> str:
+     try:
+         with fitz.open(stream=file_bytes, filetype="pdf") as doc:
+             if doc.is_encrypted:
+                 try:
+                     doc.authenticate("")
+                 except Exception:
+                     pass
+             texts = [(page.get_text("text") or "") for page in doc]
+             txt = "\n".join(texts)
+             if txt.strip():
+                 return txt
+     except Exception:
+         pass
+     try:
+         reader = PdfReader(io.BytesIO(file_bytes))
+         pages = []
+         for p in reader.pages:
+             try:
+                 pages.append(p.extract_text() or "")
+             except Exception:
+                 pages.append("")
+         return "\n".join(pages)
+     except Exception:
+         return ""
+
+ def _read_any(file) -> str:
+     if isinstance(file, dict):
+         name = (file.get("orig_name") or file.get("name") or file.get("path") or "upload").lower()
+     else:
+         name = getattr(file, "name", None) or (str(file) if isinstance(file, (str, pathlib.Path)) else "upload")
+         name = name.lower()
+     raw = _read_bytes(file)
+     if name.endswith(".pdf"):
+         return _read_pdf(raw).replace("\u0000", "")
+     return _decode_best_effort(raw).replace("\u0000", "")
+
+ DOCS_DIR = os.path.join(os.getcwd(), "docs")
+
+ def get_docs_files() -> List[str]:
+     if not os.path.isdir(DOCS_DIR):
+         return []
+     files = []
+     for fname in os.listdir(DOCS_DIR):
+         if fname.lower().endswith((".pdf", ".txt")):
+             files.append(os.path.join(DOCS_DIR, fname))
+     return files
+
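+ # Rebuild the index from every PDF/TXT in ./docs: chunk, embed with the
+ # E5 "passage: " prefix (the query side uses "query: "), add to FAISS,
+ # and persist both the index and the chunk metadata.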
+ def build_corpus_from_docs():
+     global _index, _meta
+     files = get_docs_files()
+     if not files:
+         return "No files found in docs folder."
+     emb_model = get_emb_model()
+     chunks, sources, failed = [], [], []
+     _index = None
+     _meta = {}
+     for f in files:
+         fname = os.path.basename(f)
+         try:
+             text = _read_any(f) or ""
+             parts = _chunk_text(text)
+             if not parts:
+                 failed.append(fname)
+                 continue
+             chunks.extend(parts)
+             sources.extend([fname] * len(parts))
+         except Exception:
+             failed.append(fname)
+     if not chunks:
+         return "No text extracted from docs."
+     passages = [f"passage: {c}" for c in chunks]
+     vec = emb_model.encode(passages, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
+     _ensure_index(vec.shape[1])
+     _index.add(vec)
+     for i, (src, c) in enumerate(zip(sources, chunks)):
+         _meta[str(i)] = {"source": src, "text": c}
+     _persist()
+     msg = f"Indexed {len(chunks)} chunks from {len(files)} files."
+     if failed:
+         msg += f" Failed files: {', '.join(failed)}"
+     return msg
+
+ def _encode_query_vec(query: str) -> np.ndarray:
+     return get_emb_model().encode([f"query: {query}"], convert_to_numpy=True, normalize_embeddings=True)
+
+ def retrieve_candidates(qvec: np.ndarray, pool_k: int = 40) -> List[Tuple[str, float]]:
+     if _index is None or _index.ntotal == 0:
+         return []
+     pool_k = min(pool_k, _index.ntotal)
+     D, I = _index.search(qvec, pool_k)
+     return [(str(idx), float(score)) for idx, score in zip(I[0], D[0]) if idx != -1]
+
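+ # Source-diversity selection: cap hits per source file, round-robin
+ # across sources, then backfill from the raw ranking if still short.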
+ def select_diverse_by_source(cands: List[Tuple[str, float]], top_k: int = 6, per_source_cap: int = 2) -> List[Tuple[str, float]]:
+     if not cands:
+         return []
+     by_src: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
+     for cid, s in cands:
+         m = _meta.get(cid)
+         if not m:
+             continue
+         by_src[m["source"]].append((cid, s))
+     for src in by_src:
+         by_src[src] = by_src[src][:per_source_cap]
+     picked, src_items, ptrs = [], [(s, it) for s, it in by_src.items()], {s: 0 for s in by_src}
+     while len(picked) < top_k:
+         advanced = False
+         for src, items in src_items:
+             i = ptrs[src]
+             if i < len(items):
+                 picked.append(items[i])
+                 ptrs[src] = i + 1
+                 advanced = True
+                 if len(picked) >= top_k:
+                     break
+         if not advanced:
+             break
+     if len(picked) < top_k:
+         seen = {cid for cid, _ in picked}
+         for cid, s in cands:
+             if cid not in seen:
+                 picked.append((cid, s))
+                 seen.add(cid)
+                 if len(picked) >= top_k:
+                     break
+     return picked[:top_k]
+
+ def _encode_chunks_text(cids: List[str]) -> np.ndarray:
+     texts = [f"passage: {(_meta.get(cid) or {}).get('text','')}" for cid in cids]
+     return get_emb_model().encode(texts, convert_to_numpy=True, normalize_embeddings=True)
+
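+ # Maximal Marginal Relevance: greedily add the candidate that best
+ # trades off query similarity against redundancy with already-selected
+ # chunks; mmr_lambda=1.0 is pure relevance, 0.0 is pure diversity.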
+ def select_diverse_mmr(cands: List[Tuple[str, float]], qvec: np.ndarray, top_k: int = 6, mmr_lambda: float = 0.5) -> List[Tuple[str, float]]:
+     if not cands:
+         return []
+     cids = [cid for cid, _ in cands]
+     cvecs = _encode_chunks_text(cids)
+     sim_to_q = (cvecs @ qvec.T).reshape(-1)
+     selected, remaining = [], set(range(len(cids)))
+     while len(selected) < min(top_k, len(cids)):
+         if not selected:
+             i = int(np.argmax(sim_to_q))
+             selected.append(i)
+             remaining.remove(i)
+             continue
+         S = cvecs[selected]
+         sim_to_S = (cvecs[list(remaining)] @ S.T)
+         max_sim_to_S = sim_to_S.max(axis=1) if sim_to_S.size > 0 else np.zeros((len(remaining),), dtype=np.float32)
+         sim_q_rem = sim_to_q[list(remaining)]
+         mmr_scores = mmr_lambda * sim_q_rem - (1.0 - mmr_lambda) * max_sim_to_S
+         j_rel = int(np.argmax(mmr_scores))
+         j = list(remaining)[j_rel]
+         selected.append(j)
+         remaining.remove(j)
+     return [(cids[i], float(sim_to_q[i])) for i in selected][:top_k]
+
+ def retrieve_diverse(query: str,
+                      top_k: int = 6,
+                      pool_k: int = 40,
+                      per_source_cap: int = 2,
+                      strategy: str = "mmr",
+                      mmr_lambda: float = 0.5) -> List[Tuple[str, float]]:
+     qvec = _encode_query_vec(query)
+     cands = retrieve_candidates(qvec, pool_k=pool_k)
+     if strategy == "mmr":
+         return select_diverse_mmr(cands, qvec, top_k=top_k, mmr_lambda=mmr_lambda)
+     return select_diverse_by_source(cands, top_k=top_k, per_source_cap=per_source_cap)
+
+ def _format_ctx(hits: List[Tuple[str, float]]) -> str:
+     if not hits:
+         return ""
+     lines = []
+     for cid, _ in hits:
+         m = _meta.get(cid)
+         if not m:
+             continue
+         source_clean = m.get("source", "")
+         text_clean = (m.get("text", "") or "").replace("\n", " ")
+         lines.append(f"[{cid}] ({source_clean}) " + text_clean)
+     return "\n".join(lines[:10])
+
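+ # Streaming chat handler: lazily build the index if needed, retrieve
+ # diverse context, inject it into the system prompt, and stream tokens.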
+ def chat_fn(message, history):
+     model_name = SINGLE_MODEL_NAME
+     if _index is None or _index.ntotal == 0:
+         status = build_corpus_from_docs()
+         if not (_index and _index.ntotal > 0):
+             yield f"**Index Status:** {status}\n\nPlease ensure you have a 'docs' folder with PDF/TXT files and try again."
+             return
+
+     hits = retrieve_diverse(
+         message,
+         top_k=6,
+         pool_k=40,
+         per_source_cap=2,
+         strategy="mmr",
+         mmr_lambda=0.5,
+     )
+
+     ctx = _format_ctx(hits) if hits else "(Current index is empty or no matching chunks found)"
+
+     sys_blocks = ["You are a research assistant with an excellent factual understanding of the legal policies, regulations, and compliance of enterprises, governments, and global organizations. You read legal papers and provide factual answers to queries. If you do not know the answer, convey that to the user instead of hallucinating. Answers must be based on the retrieved content, with evidence and source numbers cited, e.g., [3]. If retrieval is insufficient, clearly explain the shortcomings."]
+     # Pass the retrieved chunks to the model; without this the citations
+     # the prompt asks for would have nothing to refer to.
+     sys_blocks.append(f"Retrieved context:\n{ctx}")
+     messages = [{"role": "system", "content": "\n\n".join(sys_blocks)}]
+     for u, a in history:
+         messages.append({"role": "user", "content": u})
+         messages.append({"role": "assistant", "content": a})
+     messages.append({"role": "user", "content": message})
+
+     try:
+         response = client.chat.completions.create(
+             model=model_name,
+             messages=messages,
+             temperature=GEN_TEMPERATURE,
+             top_p=GEN_TOP_P,
+             max_tokens=GEN_MAX_TOKENS,
+             stream=True,
+         )
+
+         partial_message = ""
+         for chunk in response:
+             if hasattr(chunk.choices[0], "delta") and chunk.choices[0].delta.content is not None:
+                 partial_message += chunk.choices[0].delta.content
+                 yield partial_message
+             elif hasattr(chunk.choices[0], "message") and chunk.choices[0].message.content is not None:
+                 partial_message += chunk.choices[0].message.content
+                 yield partial_message
+     except Exception as e:
+         yield f"[Exception] {repr(e)}"
+
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as legalprodigy:
+     gr.Markdown("")
+     with gr.Row():
+         query_box = gr.Textbox(
+             placeholder="Try: Explain Arbitration Process",
+             scale=5
+         )
+         send_btn = gr.Button("Send", scale=1)
+     with gr.Row():
+         chatbot = gr.Chatbot(label="LegalProdigy")
+     state = gr.State([])
+
+     def chat_wrapper(user_message, history):
+         history = history or []
+         gen = chat_fn(user_message, history)
+         result = ""
+         for chunk in gen:
+             result = chunk
+         history.append((user_message, result))
+         return history, history
+
+     send_btn.click(
+         chat_wrapper,
+         inputs=[query_box, state],
+         outputs=[chatbot, state]
+     )
+
+ try:
+     _load_if_any()
+ except Exception:
+     pass
+
+ if __name__ == "__main__":
+     legalprodigy.launch()
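For a quick retrieval-only sanity check of the app above, the LLM call can be bypassed entirely. A sketch, assuming API_KEY is set in the environment (importing app.py runs its top-level setup) and ./docs contains the PDFs:

# Illustrative smoke test, not part of the commit.
from app import build_corpus_from_docs, retrieve_diverse, _meta

print(build_corpus_from_docs())  # e.g. "Indexed N chunks from M files."
for cid, score in retrieve_diverse("arbitration process", top_k=3):
    print(f"{score:.3f} [{cid}]", _meta[cid]["source"])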
docs/.DS_Store ADDED
Binary file (6.15 kB)
 
docs/IPC.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:038c736730c09d5b72b1642ab8056607ca546c0b87631811da1a30accd08f81d
+ size 1529218