Alshargi commited on
Commit
6e10eef
·
verified ·
1 Parent(s): 5a7ff52

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -0
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import time
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import faiss
11
+ from flask import Flask, request, jsonify
12
+ from flask_cors import CORS
13
+ from sentence_transformers import SentenceTransformer
14
+
15
+
16
# ---- Configuration (overridable via environment variables) ----
INDEX_PATH = os.getenv("HADITH_INDEX_PATH", "hadith_ar.faiss")    # FAISS vector index file
META_PATH = os.getenv("HADITH_META_PATH", "hadith_meta.parquet")  # per-row hadith metadata
MODEL_NAME = os.getenv("HADITH_MODEL_NAME", "intfloat/multilingual-e5-base")  # embedding model

DEFAULT_TOP_K = 10  # results returned when the caller does not specify k
MAX_TOP_K = 50      # hard cap on k, bounds search cost and payload size

# Set to False to omit the (large) text fields from responses by default.
DEFAULT_INCLUDE_TEXT = True
30
+ # =========================
31
+ # Arabic normalization
32
+ # =========================
33
+ _AR_DIACRITICS = re.compile(r"""
34
+ [\u0610-\u061A]
35
+ | [\u064B-\u065F]
36
+ | [\u0670]
37
+ | [\u06D6-\u06ED]
38
+ """, re.VERBOSE)
39
+
40
+ def normalize_ar(text: str) -> str:
41
+ """Remove tashkeel + normalize common Arabic letter variants."""
42
+ if text is None:
43
+ return ""
44
+ text = str(text)
45
+ text = _AR_DIACRITICS.sub("", text)
46
+ text = text.replace("ـ", "")
47
+ text = re.sub(r"[إأآٱ]", "ا", text)
48
+ text = text.replace("ى", "ي")
49
+ text = text.replace("ؤ", "و")
50
+ text = text.replace("ئ", "ي")
51
+ text = re.sub(r"\s+", " ", text).strip()
52
+ return text
53
+
54
+
55
# =========================
# One-time startup: embedding model, FAISS index, metadata table
# =========================
for _path, _label in ((INDEX_PATH, "FAISS index"), (META_PATH, "Meta parquet")):
    if not os.path.exists(_path):
        raise FileNotFoundError(f"{_label} not found: {_path}")

model = SentenceTransformer(MODEL_NAME)
index = faiss.read_index(INDEX_PATH)
meta = pd.read_parquet(META_PATH)

# Fail fast when the metadata schema is incomplete.
required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
missing = required_cols - set(meta.columns)
if missing:
    raise ValueError(f"Meta is missing required columns: {missing}")

# Optional column: backfill so downstream code can rely on its presence.
if "arabic_clean" not in meta.columns:
    meta["arabic_clean"] = ""

# Coerce text columns to plain strings so NaN never leaks into responses.
for col in ("arabic", "english", "arabic_clean", "collection"):
    if col in meta.columns:
        meta[col] = meta[col].fillna("").astype(str)
80
+
81
+
82
def semantic_search(query: str, top_k: int = DEFAULT_TOP_K) -> pd.DataFrame:
    """Embed *query* and return the nearest hadiths from the FAISS index.

    Parameters
    ----------
    query : str
        Free-text query; Arabic-normalized before embedding.
    top_k : int
        Number of neighbours to retrieve; clamped to [1, MAX_TOP_K].

    Returns
    -------
    pd.DataFrame
        Rows of ``meta`` with an added float ``score`` column, sorted by
        score descending. Empty frame for a blank query. Rows whose
        Arabic text is empty are dropped.
    """
    q = str(query or "").strip()
    if not q:
        return meta.iloc[0:0].copy()

    top_k = max(1, min(int(top_k), MAX_TOP_K))

    # E5-family models expect a "query: " prefix; embeddings are
    # normalized so inner-product search behaves like cosine similarity.
    q_norm = normalize_ar(q)
    q_emb = model.encode(["query: " + q_norm], normalize_embeddings=True).astype("float32")
    scores, idx = index.search(q_emb, top_k)

    # FAISS pads results with -1 when the index holds fewer than top_k
    # vectors; keep only valid positions. (Previously each -1 silently
    # resolved to the *last* meta row via iloc's negative indexing.)
    valid = idx[0] >= 0
    positions = idx[0][valid]

    res = meta.iloc[positions].copy()
    res["score"] = scores[0][valid].astype(float)
    res = res.sort_values("score", ascending=False)

    # Ensure no empty Arabic (avoid useless results)
    res["arabic"] = res["arabic"].fillna("").astype(str)
    res = res[res["arabic"].str.strip() != ""]

    return res
102
+
103
+
104
def row_to_json(row: pd.Series, include_text: bool = True) -> Dict[str, Any]:
    """Serialize one result row into a JSON-safe dict.

    Always includes score and identifying fields; the text fields
    (arabic, arabic_clean, english) are attached only when
    *include_text* is true. A missing ``arabic_clean`` is derived on the
    fly with ``normalize_ar``.
    """
    arabic_raw = str(row.get("arabic", "") or "")
    clean = str(row.get("arabic_clean", "") or "").strip()
    if not clean:
        clean = normalize_ar(arabic_raw)

    payload: Dict[str, Any] = {
        "score": float(row.get("score", 0.0)),
        "hadithID": int(row.get("hadithID")),
        "collection": str(row.get("collection", "")),
        "hadith_number": int(row.get("hadith_number")),
    }

    if include_text:
        payload["arabic"] = arabic_raw
        payload["arabic_clean"] = clean
        payload["english"] = str(row.get("english", "") or "")

    return payload
125
+
126
+
127
# =========================
# Flask API app
# =========================
app = Flask(__name__)
# Wide-open CORS on every route so browser front-ends hosted elsewhere
# can call this API directly.
CORS(app, resources={r"/*": {"origins": "*"}}) # allow calls from other hosts
132
+
133
+
134
@app.get("/health")
def health():
    """Liveness probe: metadata row count, index size, and model name."""
    status = {
        "ok": True,
        "rows": int(len(meta)),
        "index_ntotal": int(getattr(index, "ntotal", -1)),
        "model": MODEL_NAME,
    }
    return jsonify(status)
142
+
143
+
144
@app.post("/search")
def search():
    """Semantic search over the hadith index.

    JSON body:
    {
      "q": "الزرق و سبيل الرزق",
      "k": 10,
      "include_text": true
    }
    """
    body = request.get_json(silent=True) or {}

    # Guard clause: a query string is mandatory.
    q = (body.get("q") or "").strip()
    if not q:
        return jsonify({"ok": False, "error": "Missing 'q'"}), 400

    # Coerce k to an int and clamp it into [1, MAX_TOP_K].
    try:
        k = int(body.get("k", DEFAULT_TOP_K))
    except Exception:
        k = DEFAULT_TOP_K
    k = max(1, min(k, MAX_TOP_K))

    include_text = bool(body.get("include_text", DEFAULT_INCLUDE_TEXT))

    started = time.time()
    hits = semantic_search(q, top_k=k)
    took_ms = int((time.time() - started) * 1000)

    results = [row_to_json(row, include_text=include_text) for _, row in hits.iterrows()]

    return jsonify({
        "ok": True,
        "query": q,
        "query_norm": normalize_ar(q),
        "k": k,
        "took_ms": took_ms,
        "results_count": len(results),
        "results": results,
    })
183
+
184
+
185
@app.get("/search")
def search_get():
    """Browser-friendly variant of /search.

    GET /search?q=...&k=10&include_text=1
    Useful for quick testing in browser.
    """
    # Guard clause: a query string is mandatory.
    q = (request.args.get("q") or "").strip()
    if not q:
        return jsonify({"ok": False, "error": "Missing 'q'"}), 400

    # Coerce k and clamp it into [1, MAX_TOP_K].
    try:
        k_int = int(request.args.get("k", str(DEFAULT_TOP_K)))
    except Exception:
        k_int = DEFAULT_TOP_K
    k_int = max(1, min(k_int, MAX_TOP_K))

    # Any value except an explicit falsy token enables the text fields.
    flag = request.args.get("include_text", "1")
    include_text_bool = flag not in ("0", "false", "False", "")

    started = time.time()
    hits = semantic_search(q, top_k=k_int)
    took_ms = int((time.time() - started) * 1000)

    results = [row_to_json(row, include_text=include_text_bool) for _, row in hits.iterrows()]

    return jsonify({
        "ok": True,
        "query": q,
        "query_norm": normalize_ar(q),
        "k": k_int,
        "took_ms": took_ms,
        "results_count": len(results),
        "results": results,
    })
221
+
222
+
223
if __name__ == "__main__":
    # For local debug only. On HF Spaces, gunicorn/uvicorn handles it.
    # Port 7860 is the Hugging Face Spaces convention; bind all
    # interfaces so the container port mapping works.
    app.run(host="0.0.0.0", port=7860, debug=False)