rairo committed on
Commit
4160d8e
·
verified ·
1 Parent(s): 2a26eee

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +458 -0
main.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os, json, logging, warnings, time, certifi, pymysql, requests
from contextlib import contextmanager
from datetime import date
from flask import Flask, request, jsonify
from flask_cors import CORS

# ---- Optional Google GenAI (Gemini) ----
from google import genai
from google.genai import types

# NOTE(review): `time`, `requests`, `date` and `types` are imported but not
# used anywhere in this file — confirm before removing.
warnings.filterwarnings("ignore")

# ───────────────────────────────────────────────────────────────────────────────
# CONFIG
# ───────────────────────────────────────────────────────────────────────────────
# TiDB connection settings; every value is overridable via environment variables.
DB_NAME = os.getenv("TIDB_DB", "test")
TIDB_HOST = os.getenv("TIDB_HOST", "")
TIDB_PORT = int(os.getenv("TIDB_PORT", "4000"))
TIDB_USER = os.getenv("TIDB_USER", "")
TIDB_PASS = os.getenv("TIDB_PASS", "")

# Vector-search settings. Embeddings are padded/truncated to VEC_DIM floats so
# they fit the VECTOR(...) column regardless of the model's native dimension.
VEC_DIM = int(os.getenv("VEC_DIM", "1536"))
EMBED_MODEL = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
USE_GPU = os.getenv("USE_GPU", "0") == "1" # Spaces are usually CPU; works either way

# Policy windows (server is single source of truth for the client)
# Each window: ISO date bounds ("to" of None means open-ended) plus a citation URL.
POLICY_WINDOWS = [
    {
        "code": "NAZI_ERA",
        "label": "Washington Conference Principles (1933–1945)",
        "from": "1933-01-01",
        "to": "1945-12-31",
        "ref": "https://www.state.gov/washington-conference-principles-on-nazi-confiscated-art"
    },
    {
        "code": "UNESCO_1970",
        "label": "UNESCO 1970 Convention",
        "from": "1970-11-14",
        "to": None,
        "ref": "https://www.unesco.org/en/legal-affairs/convention-means-prohibiting-and-preventing-illicit-import-export-and-transfer-ownership-cultural"
    }
]
43
+
44
+ # ───────────────────────────────────────────────────────────────────────────────
45
+ # APP + LOGGING
46
+ # ───────────────────────────────────────────────────────────────────────────────
47
# Module-wide logger; INFO level so model/DB lifecycle messages show in Spaces logs.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("provenance-api")

app = Flask(__name__)
CORS(app)  # enable cross-origin requests for all routes (browser clients)
52
+
53
+ # ───────────────────────────────────────────────────────────────────────────────
54
+ # DB CONNECTION (autocommit + TLS + auto-reconnect)
55
+ # ───────────────────────────────────────────────────────────────────────────────
56
# Single cached connection, lazily (re)created; see _ensure_conn for liveness.
_CONN = None

def _connect():
    """(Re)connect to TiDB with TLS + autocommit; DictCursor for JSON friendliness."""
    global _CONN
    stale, _CONN = _CONN, None
    if stale is not None:
        try:
            stale.close()
        except Exception:
            pass  # best-effort close of a possibly-dead connection
    _CONN = pymysql.connect(
        host=TIDB_HOST,
        port=TIDB_PORT,
        user=TIDB_USER,
        password=TIDB_PASS,
        database=DB_NAME,
        ssl={"ca": certifi.where()},
        ssl_verify_cert=True,
        ssl_verify_identity=True,
        autocommit=True, # ensure object rows are durable before child rows (FKs)
        charset="utf8mb4",
        cursorclass=pymysql.cursors.DictCursor,
    )
79
+
80
def _ensure_conn():
    """Return a live connection, pinging the cached one and reconnecting on failure."""
    global _CONN
    if _CONN is not None:
        try:
            _CONN.ping(reconnect=True)
            return _CONN
        except Exception:
            pass  # fall through to a full reconnect
    _connect()
    return _CONN
90
+
91
@contextmanager
def cursor():
    """Yield a DictCursor on a live (auto-pinged) connection; use in each route."""
    with _ensure_conn().cursor() as cur:
        yield cur
97
+
98
+ # ───────────────────────────────────────────────────────────────────────────────
99
+ # EMBEDDINGS (lazy-load; same model as ingest; pad to 1536)
100
+ # ───────────────────────────────────────────────────────────────────────────────
101
# Lazily-loaded embedding model and the device it ended up on.
_MODEL = None
_DEVICE_INFO = "cpu"

def _pad(vec, dim=VEC_DIM):
    """Return a new list: *vec* truncated or zero-padded to exactly *dim* entries."""
    if len(vec) >= dim:
        return vec[:dim]
    return vec + [0.0] * (dim - len(vec))
106
+
107
def _load_model():
    """Load the SentenceTransformer once; prefer CUDA only when USE_GPU=1 and available."""
    global _MODEL, _DEVICE_INFO
    if _MODEL is None:
        if USE_GPU:
            try:
                import torch
                if torch.cuda.is_available():
                    _DEVICE_INFO = "cuda"
            except Exception:
                _DEVICE_INFO = "cpu"  # torch missing/broken: stay on CPU
        # Imported here so the (heavy) dependency loads only on first use.
        from sentence_transformers import SentenceTransformer
        _MODEL = SentenceTransformer(EMBED_MODEL, device=_DEVICE_INFO)
        log.info(f"Loaded embedding model on '{_DEVICE_INFO}': {EMBED_MODEL}")
    return _MODEL
122
+
123
def embed_text_to_vec1536(text: str):
    """Encode *text* with the shared model and pad/truncate to VEC_DIM floats."""
    encoded = _load_model().encode(
        [text], batch_size=1, show_progress_bar=False, convert_to_numpy=True
    )
    return _pad(encoded[0].tolist(), VEC_DIM)
129
+
130
+ # ───────────────────────────────────────────────────────────────────────────────
131
+ # GEMINI (explanations / descriptions)
132
+ # ───────────────────────────────────────────────────────────────────────────────
133
# API key comes from the Space secret named "Gemini"; client is built lazily.
GEMINI_KEY = os.environ.get("Gemini")
_gclient = None

def _gemini():
    """Lazily build and memoize the Gemini client; return None when unconfigured."""
    global _gclient
    if _gclient is None and GEMINI_KEY:
        try:
            _gclient = genai.Client(api_key=GEMINI_KEY)
            log.info("Gemini client initialized.")
        except Exception as e:
            log.warning(f"Gemini init failed: {e}")
            return None
    return _gclient
149
+
150
EXPLAIN_MODEL = "gemini-2.5-flash"

def gemini_explain(prompt: str, sys: str = None, model: str = EXPLAIN_MODEL) -> str:
    """Ask Gemini for a short explanation, degrading gracefully when unconfigured.

    Args:
        prompt: the content/question to explain.
        sys: optional system-style preamble, sent as a first chat message.
        model: Gemini model name to use.

    Returns:
        The response text, "" when the model returned no text, or a stub
        string when no API key is configured.
    """
    g = _gemini()
    if g is None:
        # Graceful fallback so the API still works without a key
        return "(Gemini not configured) " + prompt[:180]
    # chat-style to mirror your original pattern
    chat = g.chats.create(model=model)
    # Add a light system preamble for style/constraints
    if sys:
        chat.send_message(f"[SYSTEM]\n{sys}")
    resp = chat.send_message(prompt)
    # BUGFIX: resp.text can be None (attribute present but empty response),
    # which previously crashed on .strip(); guard before stripping.
    text = getattr(resp, "text", None)
    return text.strip() if text else ""
164
+
165
+ # ───────────────────────────────────────────────────────────────────────────────
166
+ # UTIL: Build graph & timeline from events (+ risk overlays)
167
+ # ───────────────────────────────────────────────────────────────────────────────
168
def _policy_hits_for_date(d: str):
    """Return policy codes whose window contains the given ISO date string.

    Relies on lexicographic comparison of ISO-8601 dates; a None/empty bound
    means the window is open on that side. Empty/None input matches nothing.
    """
    if not d:
        return []
    return [
        w["code"]
        for w in POLICY_WINDOWS
        if (not w["from"] or d >= w["from"]) and (not w["to"] or d <= w["to"])
    ]
179
+
180
def build_graph_from_events(obj_row, events):
    """Build a Cytoscape.js-style {nodes, edges} dict for one object.

    The object is the hub; each event contributes an actor->object edge
    (labelled with the event type) and an object->place "LOCATED" edge,
    both annotated with the policy windows the event date falls into.
    """
    hub_id = f"obj:{obj_row['object_id']}"
    nodes_map = {
        hub_id: {
            "id": hub_id,
            "label": f"{obj_row.get('title') or 'Untitled'} ({obj_row.get('source')})",
            "type": "object",
        }
    }
    edges = []

    def _node(kind, label):
        # Register (deduplicated) node; returns None for empty labels.
        if not label:
            return None
        nid = f"{kind}:{label}"
        nodes_map.setdefault(nid, {"id": nid, "label": label, "type": kind})
        return nid

    for ev in events:
        etype = ev.get("event_type") or "UNKNOWN"
        d_iso = (ev.get("date_from") or "")[:10] if ev.get("date_from") else None
        policy = _policy_hits_for_date(d_iso)  # same for both edges; hoisted
        ref = ev.get("source_ref")

        # Edge semantics: actor -> object; place is context (not endpoint)
        actor_id = _node("actor", ev.get("actor"))
        if actor_id:
            edges.append({
                "source": actor_id,
                "target": hub_id,
                "label": etype,
                "date": d_iso,
                "weight": 1.0, # client may recompute with risk overlays
                "source_ref": ref,
                "policy": policy,
            })

        # Optional: object -> place (to visualize locations)
        place_id = _node("place", ev.get("place"))
        if place_id:
            edges.append({
                "source": hub_id,
                "target": place_id,
                "label": "LOCATED",
                "date": d_iso,
                "weight": 0.5,
                "source_ref": ref,
                "policy": policy,
            })

    return {"nodes": list(nodes_map.values()), "edges": edges}
235
+
236
def build_timeline_from_events_and_sentences(events, sentences):
    """Convert structured events (+ raw sentences) into timeline-widget items.

    Each item carries the event's type and date range plus a supporting text
    snippet: the sentence whose seq matches the event's position in the list,
    falling back to the lowest-seq sentence among seqs 0-3 when there is no
    positional match.

    BUGFIX: the original scanned seqs (0,1,2,3) identically for every event
    (the loop was loop-invariant), so all items shared the same snippet
    despite the "nearest sentence by seq" intent; items now prefer the
    sentence at the event's own index.
    """
    s_by_seq = {s["seq"]: s["sentence"] for s in sentences}
    # Previous behavior, kept as the fallback: first present seq among 0..3.
    fallback = next((s_by_seq[k] for k in (0, 1, 2, 3) if k in s_by_seq), "")
    items = []
    for i, ev in enumerate(events):
        items.append({
            "title": ev.get("event_type") or "Event",
            "start_date": ev.get("date_from"),
            "end_date": ev.get("date_to"),
            "text": s_by_seq.get(i) or fallback,
            "source_ref": ev.get("source_ref"),
        })
    return items
258
+
259
+ # ───────────────────────────────────────────────────────────────────────────────
260
+ # ROUTES
261
+ # ───────────────────────────────────────────────────────────────────────────────
262
+
263
+ @app.get("/")
264
+ def root():
265
+ return jsonify({"ok": True, "service": "provenance-radar-api", "device": _DEVICE_INFO})
266
+
267
+ @app.get("/api/health")
268
+ def health():
269
+ try:
270
+ with cursor() as cur:
271
+ cur.execute("SELECT COUNT(*) AS c FROM objects"); objects = cur.fetchone()["c"]
272
+ cur.execute("SELECT COUNT(*) AS c FROM provenance_sentences"); sentences = cur.fetchone()["c"]
273
+ cur.execute("SELECT COUNT(*) AS c FROM risk_signals"); risks = cur.fetchone()["c"]
274
+ return jsonify({"ok": True, "device": _DEVICE_INFO, "counts": {
275
+ "objects": objects, "sentences": sentences, "risk_signals": risks}})
276
+ except Exception as e:
277
+ log.exception("health failed")
278
+ return jsonify({"ok": False, "error": str(e)}), 500
279
+
280
+ @app.get("/api/policy/windows")
281
+ def policy_windows():
282
+ return jsonify({"ok": True, "windows": POLICY_WINDOWS})
283
+
284
+ @app.get("/api/leads")
285
+ def leads():
286
+ limit = max(1, min(int(request.args.get("limit", 50)), 200))
287
+ min_score = float(request.args.get("min_score", 0))
288
+ source = request.args.get("source")
289
+ sql = (
290
+ "SELECT object_id, source, title, creator, risk_score, top_signals "
291
+ "FROM flagged_leads WHERE risk_score >= %s "
292
+ )
293
+ args = [min_score]
294
+ if source:
295
+ sql += " AND source = %s "
296
+ args.append(source)
297
+ sql += " LIMIT %s"
298
+ args.append(limit)
299
+ with cursor() as cur:
300
+ cur.execute(sql, args)
301
+ rows = cur.fetchall()
302
+ return jsonify({"ok": True, "data": rows})
303
+
304
+ @app.get("/api/object/<int:object_id>")
305
+ def object_detail(object_id: int):
306
+ with cursor() as cur:
307
+ cur.execute("SELECT * FROM objects WHERE object_id=%s", (object_id,))
308
+ obj = cur.fetchone()
309
+ if not obj:
310
+ return jsonify({"ok": False, "error": "not_found"}), 404
311
+ cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
312
+ sents = cur.fetchall()
313
+ cur.execute("""SELECT event_type, date_from, date_to, place, actor, method, source_ref
314
+ FROM provenance_events WHERE object_id=%s
315
+ ORDER BY COALESCE(date_from,'0001-01-01')""", (object_id,))
316
+ events = cur.fetchall()
317
+ cur.execute("SELECT code, detail, weight FROM risk_signals WHERE object_id=%s ORDER BY weight DESC", (object_id,))
318
+ risks = cur.fetchall()
319
+ return jsonify({"ok": True, "object": obj, "sentences": sents, "events": events, "risks": risks})
320
+
321
+ @app.get("/api/graph/<int:object_id>")
322
+ def graph(object_id: int):
323
+ with cursor() as cur:
324
+ cur.execute("SELECT object_id, source, title FROM objects WHERE object_id=%s", (object_id,))
325
+ obj = cur.fetchone()
326
+ if not obj:
327
+ return jsonify({"ok": False, "error": "not_found"}), 404
328
+ cur.execute("""SELECT event_type, date_from, date_to, place, actor, source_ref
329
+ FROM provenance_events WHERE object_id=%s
330
+ ORDER BY COALESCE(date_from,'0001-01-01')""", (object_id,))
331
+ events = cur.fetchall()
332
+ return jsonify({"ok": True, **build_graph_from_events(obj, events)})
333
+
334
+ @app.get("/api/timeline/<int:object_id>")
335
+ def timeline(object_id: int):
336
+ with cursor() as cur:
337
+ cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
338
+ sents = cur.fetchall()
339
+ cur.execute("""SELECT event_type, date_from, date_to, place, actor, source_ref
340
+ FROM provenance_events WHERE object_id=%s
341
+ ORDER BY COALESCE(date_from,'0001-01-01')""", (object_id,))
342
+ events = cur.fetchall()
343
+ items = build_timeline_from_events_and_sentences(events, sents)
344
+ return jsonify({"ok": True, "items": items})
345
+
346
+ @app.get("/api/keyword")
347
+ def keyword_search():
348
+ q = (request.args.get("q") or "").strip()
349
+ limit = max(1, min(int(request.args.get("limit", 50)), 200))
350
+ if not q:
351
+ return jsonify({"ok": False, "error": "q required"}), 400
352
+ like = "%" + q.replace("%","").replace("_","") + "%"
353
+ with cursor() as cur:
354
+ cur.execute(
355
+ """SELECT ps.object_id, ps.seq, ps.sentence, o.source, o.title, o.creator
356
+ FROM provenance_sentences ps
357
+ JOIN objects o ON o.object_id = ps.object_id
358
+ WHERE ps.sentence LIKE %s
359
+ LIMIT %s""", (like, limit)
360
+ )
361
+ rows = cur.fetchall()
362
+ return jsonify({"ok": True, "query": q, "data": rows})
363
+
364
+ @app.post("/api/similar")
365
+ def similar_search():
366
+ payload = request.get_json(force=True) or {}
367
+ text = (payload.get("text") or "").strip()
368
+ limit = max(1, min(int(payload.get("limit", 20)), 100))
369
+ if not text:
370
+ return jsonify({"ok": False, "error": "text required"}), 400
371
+ vec = embed_text_to_vec1536(text)
372
+ vec_json = json.dumps(vec)
373
+ sql = (
374
+ "SELECT ps.object_id, ps.seq, ps.sentence, o.source, o.title, o.creator, "
375
+ f"VEC_COSINE_DISTANCE(ps.embedding, CAST(%s AS VECTOR({VEC_DIM}))) AS distance "
376
+ "FROM provenance_sentences ps "
377
+ "JOIN objects o ON o.object_id = ps.object_id "
378
+ "ORDER BY distance ASC "
379
+ "LIMIT %s"
380
+ )
381
+ with cursor() as cur:
382
+ cur.execute(sql, (vec_json, limit))
383
+ rows = cur.fetchall()
384
+ return jsonify({"ok": True, "device": _DEVICE_INFO, "query": text, "data": rows})
385
+
386
+ @app.get("/api/vocab")
387
+ def vocab():
388
+ field = (request.args.get("field") or "").strip().lower()
389
+ limit = max(1, min(int(request.args.get("limit", 100)), 500))
390
+ if field not in {"actor", "place", "source", "culture"}:
391
+ return jsonify({"ok": False, "error": "field must be one of actor|place|source|culture"}), 400
392
+ if field in {"actor", "place"}:
393
+ sql = f"SELECT {field} AS v, COUNT(*) AS n FROM provenance_events WHERE {field} IS NOT NULL AND {field}<>'' GROUP BY {field} ORDER BY n DESC LIMIT %s"
394
+ elif field == "source":
395
+ sql = "SELECT source AS v, COUNT(*) AS n FROM objects GROUP BY source ORDER BY n DESC LIMIT %s"
396
+ else: # culture
397
+ sql = "SELECT culture AS v, COUNT(*) AS n FROM objects WHERE culture IS NOT NULL AND culture<>'' GROUP BY culture ORDER BY n DESC LIMIT %s"
398
+ with cursor() as cur:
399
+ cur.execute(sql, (limit,))
400
+ rows = cur.fetchall()
401
+ return jsonify({"ok": True, "field": field, "data": rows})
402
+
403
+ # ── Gemini-powered explanations ────────────────────────────────────────────────
404
+
405
+ @app.get("/api/explain/object/<int:object_id>")
406
+ def explain_object(object_id: int):
407
+ """Generate a concise, policy-aware research note for an object."""
408
+ with cursor() as cur:
409
+ cur.execute("SELECT object_id, source, title, creator, date_display, risk_score FROM objects WHERE object_id=%s", (object_id,))
410
+ obj = cur.fetchone()
411
+ if not obj:
412
+ return jsonify({"ok": False, "error": "not_found"}), 404
413
+ cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
414
+ sents = cur.fetchall()
415
+ cur.execute("SELECT event_type, date_from, date_to, place, actor, source_ref FROM provenance_events WHERE object_id=%s ORDER BY COALESCE(date_from,'0001-01-01')", (object_id,))
416
+ events = cur.fetchall()
417
+
418
+ # Build a compact prompt (few sentences) to keep latency low
419
+ bullets = []
420
+ for s in sents[:8]: # keep prompt small
421
+ bullets.append(f"- {s['sentence']}")
422
+ evsumm = []
423
+ for e in events[:8]:
424
+ evsumm.append(f"{e.get('event_type')} @ {e.get('place') or 'β€”'} on {e.get('date_from') or 'β€”'} (actor: {e.get('actor') or 'β€”'})")
425
+
426
+ sys = ("You are assisting provenance researchers. Write a neutral, concise brief (120–180 words) that:\n"
427
+ "1) summarizes the chain of custody in plain language; 2) clearly marks any timeline gaps; "
428
+ "3) calls out potential red flags (e.g., confiscated/looted, sales during 1933–45, exports post-1970) "
429
+ "without making legal conclusions; 4) ends with a short 'Next leads' list (max 3).")
430
+ prompt = (
431
+ f"Object: {obj.get('title') or 'Untitled'} β€” {obj.get('creator') or ''} (source {obj['source']}). "
432
+ f"Display date: {obj.get('date_display') or 'n/a'}. Current risk_score={obj.get('risk_score', 0)}.\n\n"
433
+ f"Provenance sentences:\n" + "\n".join(bullets) + "\n\n"
434
+ f"Structured events (first 8):\n- " + "\n- ".join(evsumm) + "\n\n"
435
+ f"Policy windows to consider: Nazi era 1933–1945; UNESCO 1970 onwards."
436
+ )
437
+ text = gemini_explain(prompt, sys=sys)
438
+ return jsonify({"ok": True, "model": EXPLAIN_MODEL, "note": text})
439
+
440
+ @app.post("/api/explain/text")
441
+ def explain_text():
442
+ """Explain a specific provenance sentence or user query with policy context."""
443
+ payload = request.get_json(force=True) or {}
444
+ sentence = (payload.get("text") or "").strip()
445
+ if not sentence:
446
+ return jsonify({"ok": False, "error": "text required"}), 400
447
+ sys = ("Explain this text as a provenance note for curators. "
448
+ "Be precise and cautious; highlight possible red flags tied to 1933–1945 and post-1970 export rules.")
449
+ prompt = f"Explain and contextualize this provenance fragment:\n\nβ€œ{sentence}”."
450
+ text = gemini_explain(prompt, sys=sys)
451
+ return jsonify({"ok": True, "model": EXPLAIN_MODEL, "explanation": text})
452
+
453
+ # ───────────────────────────────────────────────────────────────────────────────
454
+ # MAIN (Spaces expects 7860)
455
+ # ───────────────────────────────────────────────────────────────────────────────
456
+ if __name__ == "__main__":
457
+ port = int(os.environ.get("PORT", "7860"))
458
+ app.run(host="0.0.0.0", port=port, debug=False)