Update app.py
app.py
CHANGED
@@ -1,28 +1,16 @@
 #!/usr/bin/env python3
 """
-JusticeAI Backend — app.py
-
-
-
-
-
-
-
-
-
-
- - persists the user message and the reply into engine_user.user_memory and prunes to the last 10 messages per user
- - blocks storing toxic messages using the moderator pipeline (if available)
- - All endpoints included: /chat, /response, /add, /add-bulk, /leaderboard, /reembed, /model-status,
-   /health, /metrics_stream, /metrics_recent, /verify-admin, /cleardatabase, / (frontend).
- - Ollama integration: uses HTTP (if ollama serve) or CLI (ollama run) to infer topic semantically if possible.
- - Optional models: SentenceTransformer for embeddings and transformers (Helsinki) for translation; code runs without them using fallbacks.
-
-Deployment notes:
- - Set DATABASE_URL and KNOWLEDGEDATABASE_URL environment variables.
- - Optionally install dependencies for better features:
-     pip install sentence-transformers transformers torch langdetect emoji hf-cli
- - To enable Ollama model auto-pull at startup set OLLAMA_AUTO_PULL=1 and ensure ollama CLI exists.
+JusticeAI Backend — merged app.py
+
+This file:
+ - Consolidates the JusticeAI backend (knowledge DB, user DB, /chat and other endpoints)
+ - Integrates Ollama topic inference (HTTP/CLI optional)
+ - Integrates optional embeddings (SentenceTransformer) and optional Helsinki translation models
+ - Adds a TTS /speak endpoint (voice cloning) using TTS.api (Coqui TTS) with optimizations for speed
+ - Keeps strict separation: user chat stored only in DATABASE_URL.user_memory and never used to mutate the global knowledge DB
+ - Prunes user_memory to the last 10 messages per user
+ - Attempts to minimize TTS latency by preloading, using GPU if available, using inference_mode / autocast,
+   and caching identical speaker samples by file hash
 """
 
 from sqlalchemy.pool import NullPool
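Reviewer note: a minimal smoke test for the merged backend, as a sketch. It assumes the server is running locally on the default PORT (7860) set at the bottom of this file; the payload keys follow the /chat handler in this diff, and the user_id value is illustrative.

    import requests

    # /chat accepts "message" or "text"; user_id drives per-user memory pruning
    resp = requests.post(
        "http://localhost:7860/chat",
        json={"message": "Hello, what topics do you know?", "user_id": "demo-user"},
        timeout=30,
    )
    # keys per this diff: reply, topic, language, emoji, confidence, flags
    print(resp.json())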
@@ -36,31 +24,45 @@ import subprocess
 import shutil
 import logging
 import random
+import tempfile
+import uuid
+import asyncio
 from datetime import datetime, timezone
 from collections import deque
 from typing import Optional, Dict, Any, List
 
-from fastapi import FastAPI, Request, Body, Query, Header
-from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse
+from fastapi import FastAPI, Request, Body, Query, Header, BackgroundTasks, File, UploadFile, Form, HTTPException, status
+from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse, FileResponse
 from sqlalchemy import create_engine, text as sql_text
 
 # external helpers
 import requests
 
+# ML libs (optional)
+try:
+    import torch
+except Exception:
+    torch = None
+
 try:
     from sentence_transformers import SentenceTransformer
 except Exception:
     SentenceTransformer = None
 
 try:
-    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline as hf_pipeline
+    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline as hf_pipeline
 except Exception:
     AutoTokenizer = None
     AutoModelForSeq2SeqLM = None
-    AutoModelForCausalLM = None
     hf_pipeline = None
 
+# Optional TTS library (Coqui TTS)
+try:
+    from TTS.api import TTS
+    TTS_AVAILABLE = True
+except Exception:
+    TTS_AVAILABLE = False
+
 # Optional local modules
 try:
     import language as language_module  # type: ignore
@@ -83,7 +85,7 @@ try:
 except Exception:
     detect_lang = None
 
-# Moderator pipeline (
+# Moderator pipeline (optional)
 moderator = None
 try:
     if hf_pipeline is not None:
@@ -91,7 +93,7 @@ try:
 except Exception:
     moderator = None
 
-# Config
+# Config (env)
 ADMIN_KEY = os.environ.get("ADMIN_KEY")
 DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///justice_user.db")
 KNOWLEDGEDATABASE_URL = os.environ.get("KNOWLEDGEDATABASE_URL", DATABASE_URL)
@@ -104,15 +106,23 @@ OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "llama3")
 OLLAMA_HTTP_URL = os.environ.get("OLLAMA_HTTP_URL", "http://localhost:11434")
 OLLAMA_AUTO_PULL = os.environ.get("OLLAMA_AUTO_PULL", "0") in ("1", "true", "yes")
 
+# TTS configuration and speed options
+TTS_MODEL_NAME = os.environ.get("TTS_MODEL_NAME", "tts_models/multilingual/multi-dataset/xtts_v2")
+TTS_DEVICE = os.environ.get("TTS_DEVICE", "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu")
+TTS_USE_HALF = os.environ.get("TTS_USE_HALF", "1") in ("1", "true", "yes")
+
+# Non-TTS operation timeout (for blocking calls we choose to limit)
+MODEL_TIMEOUT = float(os.environ.get("MODEL_TIMEOUT", "10"))
+
 # Logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("justicebrain")
 
-#
+# Heartbeat & startup
 last_heartbeat = {"time": datetime.utcnow().replace(tzinfo=timezone.utc).isoformat(), "ok": True}
 app_start_time = time.time()
 
-# Engines (
+# Engines (separate DBs)
 engine_user = create_engine(
     DATABASE_URL,
     poolclass=NullPool,
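Reviewer note: all of the new TTS and timeout knobs are read from the environment. A sketch of a local configuration, using only variable names that appear in this diff (the values are examples, not required defaults):

    import os

    os.environ.setdefault("DATABASE_URL", "sqlite:///justice_user.db")        # user memory DB
    os.environ.setdefault("KNOWLEDGEDATABASE_URL", "sqlite:///knowledge.db")  # global knowledge DB
    os.environ.setdefault("TTS_MODEL_NAME", "tts_models/multilingual/multi-dataset/xtts_v2")
    os.environ.setdefault("TTS_DEVICE", "cpu")    # or "cuda" when a GPU is present
    os.environ.setdefault("TTS_USE_HALF", "0")    # half precision only helps on CUDA
    os.environ.setdefault("MODEL_TIMEOUT", "10")  # seconds for non-TTS blocking calls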
@@ -128,7 +138,7 @@ app = FastAPI(title="Justice Brain — Backend")
 
 # --- Database schema setup ---
 def ensure_tables():
-    # knowledge table
+    # knowledge table
     dialect_k = engine_knowledge.dialect.name
     with engine_knowledge.begin() as conn:
         if dialect_k == "sqlite":
@@ -165,7 +175,7 @@ def ensure_tables():
                 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
             );
         """))
-    # user memory table
+    # user memory table
     dialect_u = engine_user.dialect.name
     with engine_user.begin() as conn:
         if dialect_u == "sqlite":
@@ -209,7 +219,6 @@ def ensure_tables():
 
 ensure_tables()
 
-# add columns if missing (best-effort; uses engine_user but applied generically)
 def ensure_column_exists(table: str, column: str, col_def_sql: str):
     dialect = engine_user.dialect.name
     try:
@@ -292,7 +301,7 @@ def emoji_sentiment_score(emojis: List[str]) -> float:
             score += 0.1
     return max(-1.0, min(1.0, score / max(1, len(emojis))))
 
-# Language detection & translation
+# --- Language detection & translation ---
 _translation_model_cache: Dict[str, Any] = {}
 
 def detect_language_safe(text: str) -> str:
@@ -352,7 +361,6 @@ def translate_text(text: str, src: str, tgt: str) -> str:
             return out
     except Exception:
         pass
-    # Helsinki fallback
    src_code = (src or "und").split("-")[0].lower()
    tgt_code = (tgt or "und").split("-")[0].lower()
    if not re.fullmatch(r"[a-z]{2,3}", src_code) or not re.fullmatch(r"[a-z]{2,3}", tgt_code):
@@ -391,7 +399,7 @@ def translate_from_english(text: str, tgt_lang: str) -> str:
         return text
     return translate_text(text, "en", tgt)
 
-#
+# --- Embeddings utilities ---
 embed_model = None
 def try_load_embed():
     global embed_model
@@ -414,71 +422,13 @@ def embed_to_bytes(text: str) -> Optional[bytes]:
     except Exception:
         return None
 
-#
-def is_boilerplate_candidate(s: str) -> bool:
-    s_low = (s or "").strip().lower()
-    generic = ["i don't know", "not sure", "maybe", "perhaps", "justiceai is a unified intelligence dashboard"]
-    if len(s_low) < 8:
-        return True
-    return any(g in s_low for g in generic)
-
-def generate_creative_reply(candidates: List[str]) -> str:
-    all_sent = []
-    seen = set()
-    for c in candidates:
-        for s in re.split(r'(?<=[.?!])\s+', c):
-            st = s.strip()
-            if not st or st in seen or is_boilerplate_candidate(st):
-                continue
-            seen.add(st)
-            all_sent.append(st)
-    if not all_sent:
-        return "I don't have enough context yet — can you give more details?"
-    return "\n".join(all_sent[:5])
-
-# Duplicate detection within topic
-def knowledge_text_exists_in_topic(text: str, topic: str, threshold: float = 0.92) -> bool:
-    t = (text or "").strip()
-    if not t:
-        return False
-    try:
-        with engine_knowledge.begin() as conn:
-            rows = conn.execute(sql_text("SELECT id, text FROM knowledge WHERE topic = :topic LIMIT 200"), {"topic": topic}).fetchall()
-        for r in rows:
-            existing = (r[1] or "").strip()
-            if existing.lower() == t.lower():
-                return True
-        if embed_model is not None and rows:
-            texts = [r[1] or "" for r in rows]
-            embs = embed_model.encode(texts, convert_to_tensor=True)
-            q_emb = embed_model.encode([t], convert_to_tensor=True)[0]
-            import torch
-            sims = torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), embs)
-            if float(torch.max(sims).item()) >= threshold:
-                return True
-    except Exception:
-        pass
-    return False
-
-# Topic inference fallback (embeddings/keywords)
-def infer_topic_from_message(msg: str, known_topics: List[str]) -> str:
-    msg_low = (msg or "").lower()
-    for topic in known_topics or []:
-        if topic and topic.lower() in msg_low:
-            return topic
-    if embed_model is not None and known_topics:
-        try:
-            import torch
-            topic_embs = embed_model.encode(known_topics, convert_to_tensor=True)
-            q_emb = embed_model.encode([msg], convert_to_tensor=True)[0]
-            sims = torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), topic_embs)
-            best_idx = int(torch.argmax(sims).item())
-            return known_topics[best_idx]
-        except Exception:
-            pass
-    return "general"
+# --- Helpers for running blocking code with a timeout (for non-TTS operations) ---
+async def run_blocking_with_timeout(func, *args, timeout: float = MODEL_TIMEOUT):
+    loop = asyncio.get_running_loop()
+    fut = loop.run_in_executor(None, lambda: func(*args))
+    return await asyncio.wait_for(fut, timeout=timeout)
 
-# Ollama helpers
+# --- Ollama helpers (HTTP & CLI) ---
 def ollama_cli_available() -> bool:
     return shutil.which("ollama") is not None
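Reviewer note: run_blocking_with_timeout pushes a synchronous call onto the default thread-pool executor and bounds the wait with asyncio.wait_for. A self-contained sketch of the same pattern (slow_op is a hypothetical stand-in):

    import asyncio, time

    async def demo():
        def slow_op(x):
            time.sleep(2)   # stands in for embed_model.encode or similar
            return x * 2
        loop = asyncio.get_running_loop()
        fut = loop.run_in_executor(None, lambda: slow_op(21))
        try:
            print(await asyncio.wait_for(fut, timeout=1.0))
        except asyncio.TimeoutError:
            # callers in this diff fall back to None; note the worker thread
            # keeps running after a timeout, only the awaiting coroutine gives up
            print("timed out, falling back")

    asyncio.run(demo())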
@@ -489,19 +439,18 @@ def ollama_http_available() -> bool:
     except Exception:
         return False
 
-def call_ollama_http(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = 10) -> Optional[str]:
+def call_ollama_http(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = MODEL_TIMEOUT) -> Optional[str]:
     try:
         url = f"{OLLAMA_HTTP_URL}/api/generate"
         payload = {"model": model, "prompt": prompt, "max_tokens": 256}
         headers = {"Content-Type": "application/json"}
-        r = requests.post(url, json=payload, headers=headers, timeout=timeout_s)
+        r = requests.post(url, json=payload, headers=headers, timeout=min(timeout_s, MODEL_TIMEOUT))
         if r.status_code == 200:
             try:
                 obj = r.json()
-                for key in ("output", "text", "result", "generations"):
-                    if key in obj:
-                        return obj[key] if isinstance(obj[key], str) else json.dumps(obj[key])
+                for key in ("output", "text", "result", "generations"):
+                    if key in obj:
+                        return obj[key] if isinstance(obj[key], str) else json.dumps(obj[key])
                 return r.text
             except Exception:
                 return r.text
@@ -512,11 +461,11 @@ def call_ollama_http(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = 10
         logger.debug(f"ollama HTTP call failed: {e}")
     return None
 
-def call_ollama_cli(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = 15) -> Optional[str]:
+def call_ollama_cli(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = MODEL_TIMEOUT) -> Optional[str]:
     if not ollama_cli_available():
         return None
     try:
-        proc = subprocess.run(["ollama", "run", model, "--prompt", prompt], capture_output=True, text=True, timeout=timeout_s)
+        proc = subprocess.run(["ollama", "run", model, "--prompt", prompt], capture_output=True, text=True, timeout=min(timeout_s, MODEL_TIMEOUT))
         if proc.returncode == 0:
             return proc.stdout.strip() or proc.stderr.strip()
         else:
@@ -526,7 +475,7 @@ def call_ollama_cli(prompt: str, model: str = OLLAMA_MODEL, timeout_s: int = 15)
         logger.debug(f"ollama CLI call exception: {e}")
         return None
 
-def infer_topic_with_ollama(msg: str, topics: List[str], model: str = OLLAMA_MODEL, timeout_s: int =
+def infer_topic_with_ollama(msg: str, topics: List[str], model: str = OLLAMA_MODEL, timeout_s: int = MODEL_TIMEOUT) -> Optional[str]:
     if not msg or not topics:
         return None
     topics_escaped = [t.replace('"','\\"') for t in topics]
@@ -580,12 +529,188 @@ def infer_topic_with_ollama(msg: str, topics: List[str], model: str = OLLAMA_MOD
             pass
     return None
 
-#
+# --- Boilerplate detection & reply synthesis helpers ---
+def is_boilerplate_candidate(s: str) -> bool:
+    s_low = (s or "").strip().lower()
+    generic = ["i don't know", "not sure", "maybe", "perhaps", "justiceai is a unified intelligence dashboard"]
+    if len(s_low) < 8:
+        return True
+    return any(g in s_low for g in generic)
+
+def generate_creative_reply(candidates: List[str]) -> str:
+    all_sent = []
+    seen = set()
+    for c in candidates:
+        for s in re.split(r'(?<=[.?!])\s+', c):
+            st = s.strip()
+            if not st or st in seen or is_boilerplate_candidate(st):
+                continue
+            seen.add(st)
+            all_sent.append(st)
+    if not all_sent:
+        return "I don't have enough context yet — can you give more details?"
+    return "\n".join(all_sent[:5])
+
+# --- TTS: optimized loader, caching speaker files ---
+_tts_model = None
+_tts_lock = threading.Lock()
+_speaker_hash_cache: Dict[str, str] = {}
+_tts_loaded_event = threading.Event()
+
+def compute_file_sha256(path: str) -> str:
+    h = hashlib.sha256()
+    with open(path, "rb") as f:
+        while True:
+            b = f.read(8192)
+            if not b:
+                break
+            h.update(b)
+    return h.hexdigest()
+
+def get_tts_model_blocking():
+    global _tts_model
+    if not TTS_AVAILABLE:
+        raise RuntimeError("TTS.api not available on server")
+    with _tts_lock:
+        if _tts_model is None:
+            model_name = os.environ.get("TTS_MODEL_NAME", TTS_MODEL_NAME)
+            device = os.environ.get("TTS_DEVICE", TTS_DEVICE)
+            logger.info(f"[TTS] Loading model {model_name} on device {device}")
+            _tts_model = TTS(model_name)
+            try:
+                if device and torch is not None:
+                    if device.startswith("cuda") and torch.cuda.is_available():
+                        try:
+                            _tts_model.to(device)
+                        except Exception:
+                            pass
+                        try:
+                            torch.backends.cudnn.benchmark = True
+                        except Exception:
+                            pass
+                        if TTS_USE_HALF:
+                            try:
+                                if hasattr(_tts_model, "model") and hasattr(_tts_model.model, "half"):
+                                    _tts_model.model.half()
+                            except Exception:
+                                pass
+                        try:
+                            torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "4")))
+                        except Exception:
+                            pass
+                    else:
+                        try:
+                            torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "4")))
+                        except Exception:
+                            pass
+            except Exception as e:
+                logger.debug(f"[TTS] model device tuning warning: {e}")
+            logger.info("[TTS] model loaded")
+            _tts_loaded_event.set()
+    return _tts_model
+
+def _save_upload_file_tmp(upload_file: UploadFile) -> str:
+    suffix = os.path.splitext(upload_file.filename)[1] or ".wav"
+    fd, tmp_path = tempfile.mkstemp(suffix=suffix, prefix="tts_speaker_")
+    os.close(fd)
+    with open(tmp_path, "wb") as f:
+        content = upload_file.file.read()
+        f.write(content)
+    return tmp_path
+
+# Preload TTS in background at process start
+if TTS_AVAILABLE:
+    threading.Thread(target=lambda: (get_tts_model_blocking()), daemon=True).start()
+
+@app.post("/speak")
+async def speak(
+    background_tasks: BackgroundTasks,
+    text: str = Form(...),
+    voice_wav: Optional[UploadFile] = File(None),
+    language: Optional[str] = Form(None),
+):
+    """
+    Generate speech for `text`. Optionally use an uploaded `voice_wav` (WAV) file as speaker sample.
+    This endpoint aims for speed by using a preloaded model and GPU/half precision if configured.
+    """
+    if not text or not text.strip():
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Field 'text' is required")
+    if not TTS_AVAILABLE:
+        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="TTS engine not available")
+
+    speaker_path = None
+    speaker_hash = None
+    if voice_wav is not None:
+        try:
+            speaker_path = _save_upload_file_tmp(voice_wav)
+            speaker_hash = compute_file_sha256(speaker_path)
+            cached = _speaker_hash_cache.get(speaker_hash)
+            if cached and os.path.exists(cached):
+                try:
+                    os.remove(speaker_path)
+                except Exception:
+                    pass
+                speaker_path = cached
+            else:
+                _speaker_hash_cache[speaker_hash] = speaker_path
+        except Exception as e:
+            logger.exception("Failed to save uploaded voice sample")
+            raise HTTPException(status_code=500, detail="Failed to process uploaded voice sample")
+
+    out_fd, out_path = tempfile.mkstemp(suffix=".wav", prefix="tts_out_")
+    os.close(out_fd)
+    background_tasks.add_task(lambda p: os.path.exists(p) and os.remove(p), out_path)
+
+    try:
+        tts = get_tts_model_blocking()
+    except Exception as e:
+        logger.exception("[TTS] model load failed")
+        try:
+            if os.path.exists(out_path):
+                os.remove(out_path)
+        except Exception:
+            pass
+        raise HTTPException(status_code=500, detail="Failed to load TTS model")
+
+    kwargs = {}
+    if speaker_path:
+        kwargs["speaker_wav"] = speaker_path
+    if language:
+        kwargs["language"] = language
+
+    try:
+        if torch is not None and torch.cuda.is_available() and TTS_USE_HALF:
+            try:
+                with torch.inference_mode():
+                    with torch.cuda.amp.autocast():
+                        tts.tts_to_file(text=text, file_path=out_path, **kwargs)
+            except Exception as e:
+                logger.debug(f"[TTS] autocast path failed: {e}, falling back")
+                with torch.inference_mode():
+                    tts.tts_to_file(text=text, file_path=out_path, **kwargs)
+        else:
+            if torch is not None:
+                with torch.inference_mode():
+                    tts.tts_to_file(text=text, file_path=out_path, **kwargs)
+            else:
+                tts.tts_to_file(text=text, file_path=out_path, **kwargs)
+    except Exception as e:
+        logger.exception("[TTS] synthesis failed")
+        try:
+            if os.path.exists(out_path):
+                os.remove(out_path)
+        except Exception:
+            pass
+        raise HTTPException(status_code=500, detail="TTS synthesis failed")
+
+    filename = f"speech-{uuid.uuid4().hex}.wav"
+    return FileResponse(path=out_path, filename=filename, media_type="audio/wav", background=background_tasks)
+
+# --- Metrics & caches ---
 recent_request_times = deque()
 recent_learning_timestamps = deque()
 response_time_ema: Optional[float] = None
 EMA_ALPHA = 0.2
-knowledge_version = 0
 
 def record_request(duration_s: float):
     global response_time_ema
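Reviewer note: a client-side sketch for the new /speak endpoint. The field names (text, voice_wav, language) match the Form/File parameters above; the host, port, and sample file path are assumptions. The server registers deletion of its temporary WAV as a background task, so the file is cleaned up after the response is streamed.

    import requests

    with open("speaker_sample.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:7860/speak",
            data={"text": "Hello from JusticeAI", "language": "en"},
            files={"voice_wav": ("speaker_sample.wav", f, "audio/wav")},
            timeout=120,  # first call may include model load
        )
    resp.raise_for_status()
    with open("out.wav", "wb") as out:
        out.write(resp.content)  # server returns audio/wav via FileResponse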
@@ -604,14 +729,20 @@ def record_learn_event():
     while recent_learning_timestamps and recent_learning_timestamps[0] < ts - 3600:
         recent_learning_timestamps.popleft()
 
-# Startup
+# --- Startup event: warm up optional components ---
 @app.on_event("startup")
 async def startup_event():
-    logger.info("[JusticeAI] startup
-
-
-
-
+    logger.info("[JusticeAI] startup event beginning")
+    # Try to warmup embedding model quickly in background
+    if SentenceTransformer is not None:
+        def _warm_embed():
+            try:
+                try_load_embed()
+                logger.info("[startup] embed model warmup complete")
+            except Exception as e:
+                logger.debug(f"[startup] embed warmup issue: {e}")
+        threading.Thread(target=_warm_embed, daemon=True).start()
+    # Optionally attempt ollama pull (best-effort)
     if OLLAMA_AUTO_PULL and ollama_cli_available():
         try:
             subprocess.run(["ollama", "pull", OLLAMA_MODEL], timeout=300)
@@ -620,8 +751,7 @@ async def startup_event():
         logger.debug(f"[startup] ollama pull failed: {e}")
     logger.info("[JusticeAI] startup complete")
 
-# ---
-
+# --- Knowledge management endpoints ---
 @app.post("/add")
 async def add_knowledge(data: dict = Body(...)):
     if not isinstance(data, dict):
@@ -642,7 +772,10 @@ async def add_knowledge(data: dict = Body(...)):
         return JSONResponse(status_code=400, content={"error": "translation failed"})
     emb_bytes = None
     if embed_model is not None:
-
+        try:
+            emb_bytes = await run_blocking_with_timeout(lambda: embed_to_bytes(text_data), timeout=MODEL_TIMEOUT)
+        except Exception:
+            emb_bytes = None
     try:
         with engine_knowledge.begin() as conn:
             if emb_bytes:
@@ -655,13 +788,8 @@ async def add_knowledge(data: dict = Body(...)):
                     "INSERT INTO knowledge (text, reply, language, category, topic, confidence, meta) "
                     "VALUES (:t, :r, :lang, 'manual', :topic, :conf, :meta)"
                 ), {"t": text_data, "r": reply, "lang": detected, "topic": topic, "conf": 0.9, "meta": json.dumps({"manual": True})})
-        global knowledge_version
-        knowledge_version += 1
         record_learn_event()
-
-        if not emb_bytes:
-            res["note"] = "stored without embedding"
-        return res
+        return {"status": "✅ Knowledge added", "text": text_data, "topic": topic, "language": detected}
     except Exception as e:
         logger.exception("add failed")
         return JSONResponse(status_code=500, content={"error": "failed to store knowledge", "details": str(e)})
@@ -684,32 +812,132 @@ async def add_bulk(data: List[dict] = Body(...)):
             detected = detect_language_safe(text_data) or "und"
             if detected not in ("en", "eng", "und"):
                 errors.append({"index": i, "error": "non-english; skip"}); continue
-            emb_bytes =
+            emb_bytes = None
+            if embed_model is not None:
+                try:
+                    emb_bytes = await run_blocking_with_timeout(lambda: embed_to_bytes(text_data), timeout=MODEL_TIMEOUT)
+                except Exception:
+                    emb_bytes = None
             with engine_knowledge.begin() as conn:
                 if emb_bytes:
                     conn.execute(sql_text(
-                        "INSERT INTO knowledge (text, reply, language, embedding, category, topic) "
-                        "VALUES (:t, :r, :lang, :e, 'manual', :topic)"
+                        "INSERT INTO knowledge (text, reply, language, embedding, category, topic) VALUES (:t, :r, :lang, :e, 'manual', :topic)"
                     ), {"t": text_data, "r": reply, "lang": "en", "e": emb_bytes, "topic": topic})
                 else:
                     conn.execute(sql_text(
-                        "INSERT INTO knowledge (text, reply, language, category, topic) "
-                        "VALUES (:t, :r, :lang, 'manual', :topic)"
+                        "INSERT INTO knowledge (text, reply, language, category, topic) VALUES (:t, :r, :lang, 'manual', :topic)"
                    ), {"t": text_data, "r": reply, "lang": "en", "topic": topic})
             added += 1
         except Exception as e:
             logger.exception("add-bulk item error")
             errors.append({"index": i, "error": str(e)})
     if added:
-        global knowledge_version
-        knowledge_version += 1
         record_learn_event()
     return {"added": added, "errors": errors}
 
+@app.get("/leaderboard")
+async def leaderboard(topic: str = Query("general")):
+    t = str(topic or "general").strip() or "general"
+    try:
+        with engine_knowledge.begin() as conn:
+            rows = conn.execute(sql_text("""
+                SELECT id, text, reply, language, category, confidence, created_at
+                FROM knowledge
+                WHERE topic = :topic
+                ORDER BY confidence DESC, created_at DESC
+                LIMIT 20
+            """), {"topic": t}).fetchall()
+        out = []
+        for r in rows:
+            text_en = r[1] or ""
+            lang = r[3] or "und"
+            display_text = text_en
+            if lang and lang not in ("en", "eng", "", "und"):
+                try:
+                    display_text = translate_to_english(text_en, lang)
+                except Exception:
+                    display_text = text_en
+            created_at = r[6]
+            out.append({
+                "id": r[0],
+                "text": display_text,
+                "reply": r[2],
+                "language": lang,
+                "category": r[4],
+                "confidence": round(r[5] or 0.0, 2),
+                "created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else str(created_at)
+            })
+        return {"topic": t, "top_20": out}
+    except Exception as e:
+        logger.exception("leaderboard failed")
+        return JSONResponse(status_code=500, content={"error": "failed to fetch leaderboard", "details": str(e)})
+
+@app.post("/reembed")
+async def reembed_all(data: dict = Body(...), x_admin_key: str = Header(None, alias="X-Admin-Key")):
+    if ADMIN_KEY is None:
+        return JSONResponse(status_code=403, content={"error": "Server not configured for admin operations."})
+    if x_admin_key != ADMIN_KEY:
+        return JSONResponse(status_code=403, content={"error": "Invalid admin key."})
+    if embed_model is None:
+        return JSONResponse(status_code=503, content={"error": "Embedding model not ready."})
+    confirm = str(data.get("confirm", "") or "").strip()
+    if confirm != "REEMBED":
+        return JSONResponse(status_code=400, content={"error": "confirm token required."})
+    batch_size = int(data.get("batch_size", 100))
+    try:
+        with engine_knowledge.begin() as conn:
+            rows = conn.execute(sql_text("SELECT id, text FROM knowledge ORDER BY id")).fetchall()
+        ids_texts = [(r[0], r[1]) for r in rows]
+        total = len(ids_texts)
+        updated = 0
+        for i in range(0, total, batch_size):
+            batch = ids_texts[i:i+batch_size]
+            texts = [t for _, t in batch]
+            try:
+                embs = await run_blocking_with_timeout(lambda: embed_model.encode(texts, convert_to_tensor=True), timeout=MODEL_TIMEOUT)
+            except Exception:
+                embs = None
+            if embs is None:
+                continue
+            for j, (kid, _) in enumerate(batch):
+                emb_bytes = embs[j].cpu().numpy().tobytes()
+                with engine_knowledge.begin() as conn:
+                    conn.execute(sql_text("UPDATE knowledge SET embedding = :e, updated_at = CURRENT_TIMESTAMP WHERE id = :id"), {"e": emb_bytes, "id": kid})
+                updated += 1
+        return {"status": "✅ Re-embed complete", "total_rows": total, "updated": updated}
+    except Exception as e:
+        logger.exception("reembed failed")
+        return JSONResponse(status_code=500, content={"error": "reembed failed", "details": str(e)})
+
+@app.get("/model-status")
+async def model_status():
+    return {
+        "embed_loaded": embed_model is not None,
+        "ollama_cli": ollama_cli_available(),
+        "ollama_http": ollama_http_available(),
+        "moderator": moderator is not None,
+        "language_module": LANGUAGE_MODULE_AVAILABLE,
+        "tts_available": TTS_AVAILABLE
+    }
+
+@app.get("/health")
+async def health():
+    try:
+        with engine_knowledge.connect() as c:
+            k = c.execute(sql_text("SELECT COUNT(*) FROM knowledge")).scalar() or 0
+    except Exception:
+        k = -1
+    try:
+        with engine_user.connect() as c:
+            u = c.execute(sql_text("SELECT COUNT(*) FROM user_memory")).scalar() or 0
+    except Exception:
+        u = -1
+    return {"ok": True, "knowledge_count": int(k), "user_memory_count": int(u), "uptime_s": round(time.time() - app_start_time, 2), "heartbeat": last_heartbeat}
+
 @app.post("/chat")
 async def chat(request: Request, data: dict = Body(...)):
     t0 = time.time()
-    # Accept
+    # Accept "message" or "text"
     if isinstance(data, dict):
         raw_msg = str(data.get("message", "") or data.get("text", "") or "").strip()
     else:
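Reviewer note: /reembed is gated three ways (ADMIN_KEY configured on the server, a matching X-Admin-Key header, and the literal confirm token "REEMBED"). A sketch of an admin call, assuming ADMIN_KEY is also exported in the caller's environment:

    import os, requests

    resp = requests.post(
        "http://localhost:7860/reembed",
        headers={"X-Admin-Key": os.environ["ADMIN_KEY"]},
        json={"confirm": "REEMBED", "batch_size": 100},
        timeout=600,  # re-embedding a large knowledge table is slow
    )
    print(resp.json())  # e.g. {"status": "✅ Re-embed complete", "total_rows": ..., "updated": ...}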
@@ -727,7 +955,7 @@ async def chat(request: Request, data: dict = Body(...)):
     detected_lang = detect_language_safe(raw_msg)
     reply_lang = detected_lang if detected_lang and detected_lang != "und" else "en"
 
-    # Translate incoming to English for retrieval
+    # Translate incoming to English for retrieval if needed
     en_msg = raw_msg
     if detected_lang not in ("en", "eng", "", "und"):
         try:
@@ -735,7 +963,7 @@ async def chat(request: Request, data: dict = Body(...)):
     except Exception:
         en_msg = raw_msg
 
-    #
+    # Determine topic: Ollama first, then embedding, then keyword
     topic = "general"
     try:
         if not topic_hint:
@@ -758,7 +986,7 @@ async def chat(request: Request, data: dict = Body(...)):
     except Exception:
         topic = topic_hint or "general"
 
-    # Moderation
+    # Moderation
     flags = {}
     try:
         if moderator is not None:
@@ -771,43 +999,45 @@ async def chat(request: Request, data: dict = Body(...)):
     except Exception:
         pass
 
-    #
-    # We'll store into engine_user.user_memory only (personal).
-
-    # Load knowledge entries for this topic only
+    # Load topic-scoped knowledge
     try:
         with engine_knowledge.begin() as conn:
-            rows = conn.execute(sql_text(
-                "SELECT id, text, reply, language, embedding FROM knowledge WHERE topic = :topic ORDER BY created_at DESC"
-            ), {"topic": topic}).fetchall()
+            rows = conn.execute(sql_text("SELECT id, text, reply, language, embedding FROM knowledge WHERE topic = :topic ORDER BY created_at DESC"), {"topic": topic}).fetchall()
     except Exception as e:
         record_request(time.time() - t0)
         return JSONResponse(status_code=500, content={"error": "failed to read knowledge", "details": str(e)})
 
     knowledge_rows = [{"id": r[0], "text": r[1] or "", "reply": r[2] or "", "lang": r[3] or "und", "embedding": r[4]} for r in rows]
 
-    # Retrieval (embedding-first
+    # Retrieval (embedding-first)
     matches: List[str] = []
     confidence = 0.0
     try:
         if embed_model is not None and knowledge_rows:
             texts = [kr["text"] for kr in knowledge_rows]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                embs = await run_blocking_with_timeout(lambda: embed_model.encode(texts, convert_to_tensor=True), timeout=MODEL_TIMEOUT)
+                q_emb = await run_blocking_with_timeout(lambda: embed_model.encode([en_msg], convert_to_tensor=True)[0], timeout=MODEL_TIMEOUT)
+                import torch as _torch
+                scores = _torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), embs)
+                cand = []
+                for i in range(scores.shape[0]):
+                    s = float(scores[i])
+                    kr = knowledge_rows[i]
+                    candidate_text = (kr["reply"] or kr["text"]).strip()
+                    if is_boilerplate_candidate(candidate_text):
+                        continue
+                    if s >= 0.30:
+                        cand.append({"text": candidate_text, "lang": kr["lang"], "score": s})
+                cand = sorted(cand, key=lambda x: -x["score"])
+                matches = [c["text"] for c in cand]
+                confidence = cand[0]["score"] if cand else 0.0
+            except asyncio.TimeoutError:
+                logger.warning("[retrieval] embedding encode timed out")
+                matches = []
+            except Exception as e:
+                logger.warning(f"[retrieval] embedding error: {e}")
+                matches = []
         else:
             cand = []
             for kr in knowledge_rows:
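Reviewer note: the embedding branch above ranks candidate replies by cosine similarity and drops anything under the 0.30 floor. A standalone sketch of the same ranking logic; the model name is a common sentence-transformers default and is an assumption, not taken from this diff:

    from sentence_transformers import SentenceTransformer
    import torch

    model = SentenceTransformer("all-MiniLM-L6-v2")  # hypothetical choice for the demo
    rows = ["Courts interpret statutes.", "Bail can be granted pre-trial.", "Bananas are yellow."]
    query = "When can bail be granted?"

    embs = model.encode(rows, convert_to_tensor=True)
    q = model.encode([query], convert_to_tensor=True)[0]
    scores = torch.nn.functional.cosine_similarity(q.unsqueeze(0), embs)
    ranked = sorted(zip(rows, scores.tolist()), key=lambda p: -p[1])
    matches = [(t, round(s, 2)) for t, s in ranked if s >= 0.30]  # same 0.30 floor as the handler
    print(matches)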
@@ -822,7 +1052,7 @@ async def chat(request: Request, data: dict = Body(...)):
         logger.warning(f"[retrieval] error: {e}")
         matches = []
 
-    # Compose reply from topic
+    # Compose reply strictly from topic matches
     if matches and confidence >= 0.6:
         reply_en = matches[0]
     elif matches:
@@ -835,7 +1065,6 @@ async def chat(request: Request, data: dict = Body(...)):
         except Exception:
             pass
         reply_final = base
-        # Persist user memory (even when no confident match), skipping toxic
         try:
             if not flags.get('toxic', False):
                 with engine_user.begin() as conn:
@@ -844,20 +1073,18 @@ async def chat(request: Request, data: dict = Body(...)):
                     "VALUES (:uid, :uname, :ip, :text, :reply, :lang, :mood, :conf, :topic, :source)"
                 ), {"uid": user_id, "uname": username, "ip": user_ip, "text": raw_msg, "reply": reply_final, "lang": detected_lang,
                     "mood": detect_mood(raw_msg + " " + reply_final), "conf": float(confidence), "topic": topic, "source": "chat"})
-                # prune to last 10 per user
                 conn.execute(sql_text(
-                    "DELETE FROM user_memory WHERE id NOT IN ("
-                    "SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid"
+                    "DELETE FROM user_memory WHERE id NOT IN (SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid"
                 ), {"uid": user_id})
         except Exception as e:
             logger.debug(f"user_memory store error: {e}")
         record_request(time.time() - t0)
-        return {"reply": reply_final, "topic": topic, "language": reply_lang, "emoji": "", "confidence": round(confidence,
+        return {"reply": reply_final, "topic": topic, "language": reply_lang, "emoji": "", "confidence": round(confidence,2), "flags": flags}
 
-    # Postprocess
+    # Postprocess reply_en
     reply_en = dedupe_sentences(reply_en)
 
-    #
+    # Translate to user's language if needed
     reply_final = reply_en
     lang_code = (reply_lang or "und").split("-")[0].lower()
     if lang_code not in ("en", "eng", "und", ""):
@@ -868,7 +1095,7 @@ async def chat(request: Request, data: dict = Body(...)):
             logger.warning(f"[translation] failed to translate reply_en -> {lang_code}: {exc}")
             reply_final = reply_en
 
-    # Mood & emoji
+    # Mood & emoji
     emoji = ""
     try:
         mood = detect_mood(raw_msg + " " + reply_final)
@@ -883,7 +1110,7 @@ async def chat(request: Request, data: dict = Body(...)):
     except Exception:
         emoji = ""
 
-    # Persist user memory
+    # Persist user memory (only in user DB) and prune to last 10
     try:
         if not flags.get('toxic', False):
             with engine_user.begin() as conn:
@@ -892,10 +1119,8 @@ async def chat(request: Request, data: dict = Body(...)):
                     "VALUES (:uid, :uname, :ip, :text, :reply, :lang, :mood, :conf, :topic, :source)"
                 ), {"uid": user_id, "uname": username, "ip": user_ip, "text": raw_msg, "reply": reply_final, "lang": detected_lang,
                     "mood": detect_mood(raw_msg + " " + reply_final), "conf": float(confidence), "topic": topic, "source": "chat"})
-                # prune to last 10 per user
                 conn.execute(sql_text(
-                    "DELETE FROM user_memory WHERE id NOT IN ("
-                    "SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid"
+                    "DELETE FROM user_memory WHERE id NOT IN (SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) AND user_id = :uid"
                 ), {"uid": user_id})
         except Exception as e:
             logger.debug(f"user_memory persist error: {e}")
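Reviewer note: both persistence paths now share the single-line prune statement, which keeps only each user's 10 newest user_memory rows. A runnable sqlite3 sketch of the same query shape (the table DDL is simplified here, not the full schema from ensure_tables):

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE user_memory (id INTEGER PRIMARY KEY, user_id TEXT, created_at TEXT)")
    con.executemany(
        "INSERT INTO user_memory (user_id, created_at) VALUES (?, ?)",
        [("u1", f"2024-01-{d:02d}") for d in range(1, 16)],  # 15 rows for one user
    )
    con.execute(
        "DELETE FROM user_memory WHERE id NOT IN ("
        "SELECT id FROM user_memory WHERE user_id = :uid ORDER BY created_at DESC LIMIT 10) "
        "AND user_id = :uid",
        {"uid": "u1"},
    )
    print(con.execute("SELECT COUNT(*) FROM user_memory").fetchone()[0])  # -> 10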
@@ -906,126 +1131,12 @@ async def chat(request: Request, data: dict = Body(...)):
     if include_steps:
         reply_final = f"{reply_final}\n\n[Debug: topic={topic} confidence={round(confidence,2)}]"
 
-    return {"reply": reply_final, "topic": topic, "language": reply_lang, "emoji": emoji, "confidence": round(confidence,
+    return {"reply": reply_final, "topic": topic, "language": reply_lang, "emoji": emoji, "confidence": round(confidence,2), "flags": flags}
 
 @app.post("/response")
 async def response_wrapper(request: Request, data: dict = Body(...)):
     return await chat(request, data)
 
-@app.get("/leaderboard")
-async def leaderboard(topic: str = Query("general")):
-    t = str(topic or "general").strip() or "general"
-    try:
-        with engine_knowledge.begin() as conn:
-            rows = conn.execute(sql_text("""
-                SELECT id, text, reply, language, category, confidence, created_at
-                FROM knowledge
-                WHERE topic = :topic
-                ORDER BY confidence DESC, created_at DESC
-                LIMIT 20
-            """), {"topic": t}).fetchall()
-        out = []
-        for r in rows:
-            text_en = r[1] or ""
-            lang = r[3] or "und"
-            display_text = text_en
-            if lang and lang not in ("en", "eng", "", "und"):
-                try:
-                    display_text = translate_to_english(text_en, lang)
-                except Exception:
-                    display_text = text_en
-            created_at = r[6]
-            out.append({
-                "id": r[0],
-                "text": display_text,
-                "reply": r[2],
-                "language": lang,
-                "category": r[4],
-                "confidence": round(r[5] or 0.0, 2),
-                "created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else str(created_at)
-            })
-        return {"topic": t, "top_20": out}
-    except Exception as e:
-        logger.exception("leaderboard failed")
-        return JSONResponse(status_code=500, content={"error": "failed to fetch leaderboard", "details": str(e)})
-
-@app.post("/reembed")
-async def reembed_all(data: dict = Body(...), x_admin_key: str = Header(None, alias="X-Admin-Key")):
-    if ADMIN_KEY is None:
-        return JSONResponse(status_code=403, content={"error": "Server not configured for admin operations."})
-    if x_admin_key != ADMIN_KEY:
-        return JSONResponse(status_code=403, content={"error": "Invalid admin key."})
-    if embed_model is None:
-        return JSONResponse(status_code=503, content={"error": "Embedding model not ready."})
-    confirm = str(data.get("confirm", "") or "").strip()
-    if confirm != "REEMBED":
-        return JSONResponse(status_code=400, content={"error": "confirm token required."})
-    batch_size = int(data.get("batch_size", 100))
-    try:
-        with engine_knowledge.begin() as conn:
-            rows = conn.execute(sql_text("SELECT id, text FROM knowledge ORDER BY id")).fetchall()
-        ids_texts = [(r[0], r[1]) for r in rows]
-        total = len(ids_texts)
-        updated = 0
-        for i in range(0, total, batch_size):
-            batch = ids_texts[i:i+batch_size]
-            texts = [t for _, t in batch]
-            embs = embed_model.encode(texts, convert_to_tensor=True)
-            for j, (kid, _) in enumerate(batch):
-                emb_bytes = embs[j].cpu().numpy().tobytes()
-                with engine_knowledge.begin() as conn:
-                    conn.execute(sql_text("UPDATE knowledge SET embedding = :e, updated_at = CURRENT_TIMESTAMP WHERE id = :id"), {"e": emb_bytes, "id": kid})
-                updated += 1
-        return {"status": "✅ Re-embed complete", "total_rows": total, "updated": updated}
-    except Exception as e:
-        logger.exception("reembed failed")
-        return JSONResponse(status_code=500, content={"error": "reembed failed", "details": str(e)})
-
-@app.get("/model-status")
-async def model_status():
-    return {
-        "embed_loaded": embed_model is not None,
-        "ollama_cli": ollama_cli_available(),
-        "ollama_http": ollama_http_available(),
-        "moderator": moderator is not None,
-        "language_module": LANGUAGE_MODULE_AVAILABLE
-    }
-
-@app.get("/health")
-async def health():
-    try:
-        with engine_knowledge.connect() as c:
-            k = c.execute(sql_text("SELECT COUNT(*) FROM knowledge")).scalar() or 0
-    except Exception:
-        k = -1
-    try:
-        with engine_user.connect() as c:
-            u = c.execute(sql_text("SELECT COUNT(*) FROM user_memory")).scalar() or 0
-    except Exception:
-        u = -1
-    return {"ok": True, "knowledge_count": int(k), "user_memory_count": int(u), "uptime_s": round(time.time() - app_start_time, 2), "heartbeat": last_heartbeat}
-
-async def metrics_producer():
-    while True:
-        try:
-            import psutil
-            cpu = psutil.cpu_percent(interval=None)
-            mem = psutil.virtual_memory()
-            mem_percent = mem.percent
-        except Exception:
-            cpu = 0.0; mem_percent = 0.0
-        payload = {"time": datetime.utcnow().isoformat(), "cpu_percent": cpu, "memory_percent": mem_percent}
-        yield f"data: {json.dumps(payload)}\n\n"
-        await asyncio.sleep(1.0)
-
-@app.get("/metrics_stream")
-async def metrics_stream():
-    return StreamingResponse(metrics_producer(), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
-
-@app.get("/metrics_recent")
-async def metrics_recent(limit: int = Query(100, ge=1, le=600)):
-    return {"count": 0, "metrics": []}
-
 @app.post("/verify-admin")
 async def verify_admin(x_admin_key: str = Header(None, alias="X-Admin-Key")):
     if ADMIN_KEY is None:
@@ -1085,33 +1196,19 @@ async def frontend_dashboard():
     html = html.replace("%%STARTUP_TIME%%", str(startup_time_local))
     return HTMLResponse(html)
 
-#
-
-
-
-    negative = ["sad", "bad", "problem", "angry", "hate", "fail", "no", "error", "issue"]
-    if any(w in lower for w in positive):
-        return "positive"
-    if any(w in lower for w in negative):
-        return "negative"
-    return "neutral"
-
-def should_append_emoji(user_text: str, reply_text: str, mood: str, flags: Dict) -> str:
-    if flags.get("toxic"):
-        return ""
-    if EMOJIS_AVAILABLE:
+# --- Start app ---
+if __name__ == "__main__":
+    # preload embed and TTS in background
+    if TTS_AVAILABLE:
         try:
-
-            return get_emoji(cat, 0.6)
+            threading.Thread(target=lambda: get_tts_model_blocking(), daemon=True).start()
         except Exception:
-
-
-
-
-
-
-    except Exception:
-        pass
+            pass
+    if SentenceTransformer is not None:
+        try:
+            threading.Thread(target=try_load_embed, daemon=True).start()
+        except Exception:
+            pass
     app_start_time = time.time()
     import uvicorn
     port = int(os.environ.get("PORT", 7860))